* getExitNode() doesn't exist in method anymore
[oota-llvm.git] / lib / Transforms / Instrumentation / TraceValues.cpp
1 // $Id$
2 //***************************************************************************
3 // File:
4 //      TraceValues.cpp
5 // 
6 // Purpose:
7 //      Support for inserting LLVM code to print values at basic block
8 //      and method exits.  Also exports functions to create a call
9 //      "printf" instruction with one of the signatures listed below.
10 // 
11 // History:
12 //      10/11/01         -  Vikram Adve  -  Created
13 //**************************************************************************/
14
15
16 #include "llvm/Transforms/Instrumentation/TraceValues.h"
17 #include "llvm/GlobalVariable.h"
18 #include "llvm/ConstPoolVals.h"
19 #include "llvm/Type.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Instruction.h"
22 #include "llvm/iTerminators.h"
23 #include "llvm/iOther.h"
24 #include "llvm/BasicBlock.h"
25 #include "llvm/Method.h"
26 #include "llvm/Module.h"
27 #include "llvm/SymbolTable.h"
28 #include "llvm/Support/HashExtras.h"
29 #include <hash_map>
30 #include <strstream.h>
31
32
33 //*********************** Internal Data Structures *************************/
34
35 const char* const PRINTF = "printf";
36
37 #undef DONT_EMBED_STRINGS_IN_FMT
38
39
40 //************************** Internal Functions ****************************/
41
42 #undef USE_PTRREF
43 #ifdef USE_PTRREF
44 static inline ConstPoolPointerReference*
45 GetStringRef(Module* module, const char* str)
46 {
47   static hash_map<string, ConstPoolPointerReference*> stringRefCache;
48   static Module* lastModule = NULL;
49   
50   if (lastModule != module)
51     { // Let's make sure we create separate global references in each module
52       stringRefCache.clear();
53       lastModule = module;
54     }
55   
56   ConstPoolPointerReference* result = stringRefCache[str];
57   if (result == NULL)
58     {
59       ConstPoolArray* charArray = ConstPoolArray::get(str);
60       GlobalVariable* stringVar =
61         new GlobalVariable(charArray->getType(),/*isConst*/true,charArray,str);
62       module->getGlobalList().push_back(stringVar);
63       result = ConstPoolPointerReference::get(stringVar);
64       assert(result && "Failed to create reference to string constant");
65       stringRefCache[str] = result;
66     }
67   
68   return result;
69 }
70 #endif USE_PTRREF
71
72 static inline GlobalVariable*
73 GetStringRef(Module* module, const char* str)
74 {
75   static hash_map<string, GlobalVariable*> stringRefCache;
76   static Module* lastModule = NULL;
77   
78   if (lastModule != module)
79     { // Let's make sure we create separate global references in each module
80       stringRefCache.clear();
81       lastModule = module;
82     }
83   
84   GlobalVariable* result = stringRefCache[str];
85   if (result == NULL)
86     {
87       ConstPoolArray* charArray = ConstPoolArray::get(str);
88       GlobalVariable* stringVar =
89         new GlobalVariable(charArray->getType(),/*isConst*/true,charArray);
90       module->getGlobalList().push_back(stringVar);
91       result = stringVar;
92       // result = ConstPoolPointerReference::get(stringVar);
93       assert(result && "Failed to create reference to string constant");
94       stringRefCache[str] = result;
95     }
96   
97   return result;
98 }
99
100
101 static inline bool
102 TraceThisOpCode(unsigned opCode)
103 {
104   // Explicitly test for opCodes *not* to trace so that any new opcodes will
105   // be traced by default (or will fail in a later assertion on VoidTy)
106   // 
107   return (opCode  < Instruction::FirstOtherOp &&
108           opCode != Instruction::Ret &&
109           opCode != Instruction::Br &&
110           opCode != Instruction::Switch &&
111           opCode != Instruction::Free &&
112           opCode != Instruction::Alloca &&
113           opCode != Instruction::Store &&
114           opCode != Instruction::PHINode &&
115           opCode != Instruction::Cast);
116 }
117
118
119 static void
120 FindValuesToTraceInBB(BasicBlock* bb,
121                       vector<Value*>& valuesToTraceInBB)
122 {
123   for (BasicBlock::iterator II = bb->begin(); II != bb->end(); ++II)
124     if ((*II)->getType()->isPrimitiveType() &&
125         TraceThisOpCode((*II)->getOpcode()))
126       {
127         valuesToTraceInBB.push_back(*II);
128       }
129 }
130
131
132 // 
133 // Insert print instructions at the end of the basic block *bb
134 // for each value in valueVec[].  *bb must postdominate the block
135 // in which the value is computed; this is not checked here.
136 // 
137 static void
138 TraceValuesAtBBExit(const vector<Value*>& valueVec,
139                     BasicBlock* bb,
140                     Module* module,
141                     unsigned int indent,
142                     bool isMethodExit)
143 {
144   // Get an iterator to point to the insertion location
145   // 
146   BasicBlock::InstListType& instList = bb->getInstList();
147   TerminatorInst* termInst = bb->getTerminator(); 
148   BasicBlock::InstListType::iterator here = instList.end();
149   while ((*here) != termInst && here != instList.begin())
150     --here;
151   assert((*here) == termInst);
152   
153   // Insert a print instruction for each value.
154   // 
155   for (unsigned i=0, N=valueVec.size(); i < N; i++)
156     {
157       Instruction* traceInstr =
158         CreatePrintInstr(valueVec[i], bb, module, indent, isMethodExit);
159       here = instList.insert(here, traceInstr);
160     }
161 }
162
163 static void
164 InsertCodeToShowMethodEntry(BasicBlock* entryBB)
165 {
166 }
167
168 static void
169 InsertCodeToShowMethodExit(BasicBlock* exitBB)
170 {
171 }
172
173
174 //************************** External Functions ****************************/
175
176 // 
177 // The signatures of the print methods supported are:
178 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  int      intValue)
179 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  unsigned uintValue)
180 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  float    floatValue)
181 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  double   doubleValue)
182 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  char*    stringValue)
183 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  void*    ptrValue)
184 // 
185 // The invocation should be:
186 //       call "printf"(fmt, bbName, valueName, valueTypeName, value).
187 // 
188 Method*
189 GetPrintMethodForType(Module* module, const Type* valueType)
190 {
191 #ifdef DONT_EMBED_STRINGS_IN_FMT
192   static const int LASTARGINDEX = 4;
193 #else
194   static const int LASTARGINDEX = 1;
195 #endif
196   static PointerType* ubytePtrTy = NULL;
197   static vector<const Type*> argTypesVec(LASTARGINDEX + 1);
198   
199   if (ubytePtrTy == NULL)
200     { // create these once since they are invariant
201       ubytePtrTy = PointerType::get(ArrayType::get(Type::UByteTy));
202       argTypesVec[0] = ubytePtrTy;
203 #ifdef DONT_EMBED_STRINGS_IN_FMT
204       argTypesVec[1] = ubytePtrTy;
205       argTypesVec[2] = ubytePtrTy;
206       argTypesVec[3] = ubytePtrTy;
207 #endif DONT_EMBED_STRINGS_IN_FMT
208     }
209   
210   SymbolTable* symtab = module->getSymbolTable();
211   argTypesVec[LASTARGINDEX] = valueType;
212   MethodType* printMethodTy = MethodType::get(Type::IntTy, argTypesVec,
213                                               /*isVarArg*/ false);
214   
215   Method* printMethod =
216     cast<Method>(symtab->lookup(PointerType::get(printMethodTy), PRINTF));
217   if (printMethod == NULL)
218     { // Create a new method and add it to the module
219       printMethod = new Method(printMethodTy, PRINTF);
220       module->getMethodList().push_back(printMethod);
221       
222       // Create the argument list for the method so that the full signature
223       // can be declared.  The args can be anonymous.
224       Method::ArgumentListType &argList = printMethod->getArgumentList();
225       for (unsigned i=0; i < argTypesVec.size(); ++i)
226         argList.push_back(new MethodArgument(argTypesVec[i]));
227     }
228   
229   return printMethod;
230 }
231
232
233 Instruction*
234 CreatePrintInstr(Value* val,
235                  const BasicBlock* bb,
236                  Module* module,
237                  unsigned int indent,
238                  bool isMethodExit)
239 {
240   strstream fmtString, scopeNameString, valNameString;
241   vector<Value*> paramList;
242   const Type* valueType = val->getType();
243   Method* printMethod = GetPrintMethodForType(module, valueType);
244   
245   if (! valueType->isPrimitiveType() ||
246       valueType->getPrimitiveID() == Type::VoidTyID ||
247       valueType->getPrimitiveID() == Type::TypeTyID ||
248       valueType->getPrimitiveID() == Type::LabelTyID)
249     {
250       assert(0 && "Unsupported type for printing");
251       return NULL;
252     }
253   
254   const Value* scopeToUse = (isMethodExit)? (const Value*) bb->getParent()
255                                           : (const Value*) bb;
256   if (scopeToUse->hasName())
257     scopeNameString << scopeToUse->getName() << ends;
258   else
259     scopeNameString << scopeToUse << ends;
260   
261   if (val->hasName())
262     valNameString << val->getName() << ends;
263   else
264     valNameString << val << ends;
265     
266   for (unsigned i=0; i < indent; i++)
267     fmtString << " ";
268   
269 #undef DONT_EMBED_STRINGS_IN_FMT
270 #ifdef DONT_EMBED_STRINGS_IN_FMT
271   fmtString << " At exit of "
272             << ((isMethodExit)? "Method " : "BB ")
273             << "%s : val %s = %s ";
274   
275   GlobalVariable* scopeNameVal = GetStringRef(module, scopeNameString.str());
276   GlobalVariable* valNameVal   = GetStringRef(module,valNameString.str());
277   GlobalVariable* typeNameVal  = GetStringRef(module,
278                                      val->getType()->getDescription().c_str());
279 #else
280   fmtString << " At exit of "
281             << ((isMethodExit)? "Method " : "BB ")
282             << scopeNameString.str() << " : "
283             << valNameString.str()   << " = "
284             << val->getType()->getDescription().c_str();
285 #endif DONT_EMBED_STRINGS_IN_FMT
286   
287   switch(valueType->getPrimitiveID())
288     {
289     case Type::BoolTyID:
290     case Type::UByteTyID: case Type::UShortTyID:
291     case Type::UIntTyID:  case Type::ULongTyID:
292     case Type::SByteTyID: case Type::ShortTyID:
293     case Type::IntTyID:   case Type::LongTyID:
294       fmtString << " %d\0A";
295       break;
296       
297     case Type::FloatTyID:     case Type::DoubleTyID:
298       fmtString << " %g\0A";
299       break;
300       
301     case Type::PointerTyID:
302       fmtString << " %p\0A";
303       break;
304       
305     default:
306       assert(0 && "Should not get here.  Check the IF expression above");
307       return NULL;
308     }
309   
310   fmtString << ends;
311   GlobalVariable* fmtVal = GetStringRef(module, fmtString.str());
312   
313 #ifdef DONT_EMBED_STRINGS_IN_FMT
314   paramList.push_back(fmtVal);
315   paramList.push_back(scopeNameVal);
316   paramList.push_back(valNameVal);
317   paramList.push_back(typeNameVal);
318   paramList.push_back(val);
319 #else
320   paramList.push_back(fmtVal);
321   paramList.push_back(val);
322 #endif DONT_EMBED_STRINGS_IN_FMT
323   
324   free(fmtString.str());
325   free(scopeNameString.str());
326   free(valNameString.str());
327   
328   return new CallInst(printMethod, paramList);
329 }
330
331
332 void
333 InsertCodeToTraceValues(Method* method,
334                         bool traceBasicBlockExits,
335                         bool traceMethodExits)
336 {
337   vector<Value*> valuesToTraceInMethod;
338   Module* module = method->getParent();
339   BasicBlock* exitBB = NULL;
340   
341   if (method->isExternal() ||
342       (! traceBasicBlockExits && ! traceMethodExits))
343     return;
344   
345   if (traceMethodExits)
346     {
347       InsertCodeToShowMethodEntry(method->getEntryNode());
348 #ifdef TODO_LATER
349       exitBB = method->getExitNode();
350 #endif
351     }
352   
353   for (Method::iterator BI = method->begin(); BI != method->end(); ++BI)
354     {
355       BasicBlock* bb = *BI;
356       vector<Value*> valuesToTraceInBB;
357       FindValuesToTraceInBB(bb, valuesToTraceInBB);
358       
359       if (traceBasicBlockExits && bb != exitBB)
360         TraceValuesAtBBExit(valuesToTraceInBB, bb, module,
361                             /*indent*/ 4, /*isMethodExit*/ false);
362       
363       if (traceMethodExits)
364         valuesToTraceInMethod.insert(valuesToTraceInMethod.end(),
365                                      valuesToTraceInBB.begin(),
366                                      valuesToTraceInBB.end());
367     }
368   
369   if (traceMethodExits)
370     {
371       TraceValuesAtBBExit(valuesToTraceInMethod, exitBB, module,
372                           /*indent*/ 0, /*isMethodExit*/ true);
373       InsertCodeToShowMethodExit(exitBB);
374     }
375 }