Refix bugs, stop using deprecated strstream header
[oota-llvm.git] / lib / Transforms / Instrumentation / TraceValues.cpp
1 // $Id$
2 //***************************************************************************
3 // File:
4 //      TraceValues.cpp
5 // 
6 // Purpose:
7 //      Support for inserting LLVM code to print values at basic block
8 //      and method exits.  Also exports functions to create a call
9 //      "printf" instruction with one of the signatures listed below.
10 // 
11 // History:
12 //      10/11/01         -  Vikram Adve  -  Created
13 //**************************************************************************/
14
15
16 #include "llvm/Transforms/Instrumentation/TraceValues.h"
17 #include "llvm/GlobalVariable.h"
18 #include "llvm/ConstPoolVals.h"
19 #include "llvm/Type.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Instruction.h"
22 #include "llvm/iMemory.h"
23 #include "llvm/iTerminators.h"
24 #include "llvm/iOther.h"
25 #include "llvm/BasicBlock.h"
26 #include "llvm/Method.h"
27 #include "llvm/Module.h"
28 #include "llvm/SymbolTable.h"
29 #include "llvm/Assembly/Writer.h"
30 #include <sstream>
31
32 static inline GlobalVariable *GetStringRef(Module *M, const string &str) {
33   ConstPoolArray *Init = ConstPoolArray::get(str);
34   GlobalVariable *GV = new GlobalVariable(Init->getType(), /*Const*/true, Init);
35   M->getGlobalList().push_back(GV);
36   return GV;
37 }
38
39
40 static inline bool
41 TraceThisOpCode(unsigned opCode)
42 {
43   // Explicitly test for opCodes *not* to trace so that any new opcodes will
44   // be traced by default (VoidTy's are already excluded)
45   // 
46   return (opCode  < Instruction::FirstOtherOp &&
47           opCode != Instruction::Alloca &&
48           opCode != Instruction::PHINode &&
49           opCode != Instruction::Cast);
50 }
51
52 // 
53 // Check if this instruction has any uses outside its basic block
54 // 
55 static inline bool
56 LiveAtBBExit(Instruction* I)
57 {
58   BasicBlock* bb = I->getParent();
59   bool isLive = false;
60   for (Value::use_const_iterator U = I->use_begin(); U != I->use_end(); ++U)
61     {
62       const Instruction* userI = dyn_cast<Instruction>(*U);
63       if (userI == NULL || userI->getParent() != bb)
64         isLive = true;
65     }
66   
67   return isLive;
68 }
69
70
71 static void 
72 FindValuesToTraceInBB(BasicBlock* bb, vector<Instruction*>& valuesToTraceInBB)
73 {
74   for (BasicBlock::iterator II = bb->begin(); II != bb->end(); ++II)
75     if ((*II)->getOpcode() == Instruction::Store
76         || (LiveAtBBExit(*II) &&
77             (*II)->getType()->isPrimitiveType() && 
78             (*II)->getType() != Type::VoidTy &&
79             TraceThisOpCode((*II)->getOpcode())))
80       {
81         valuesToTraceInBB.push_back(*II);
82       }
83 }
84
85 #if 0  // Code is disabled for now
86 // 
87 // Let's save this code for future use; it has been tested and works:
88 // 
89 // The signatures of the printf methods supported are:
90 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  int      intValue)
91 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  unsigned uintValue)
92 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  float    floatValue)
93 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  double   doubleValue)
94 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  char*    stringValue)
95 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  void*    ptrValue)
96 // 
97 // The invocation should be:
98 //       call "printf"(fmt, bbName, valueName, valueTypeName, value).
99 // 
100 Value *GetPrintfMethodForType(Module* module, const Type* valueType)
101 {
102   PointerType *ubytePtrTy = PointerType::get(ArrayType::get(Type::UByteTy));
103   vector<const Type*> argTypesVec(4, ubytePtrTy);
104   argTypesVec.push_back(valueType);
105     
106   MethodType *printMethodTy = MethodType::get(Type::IntTy, argTypesVec,
107                                               /*isVarArg*/ false);
108   
109   SymbolTable *ST = module->getSymbolTable();
110   if (Value *Meth = ST->lookup(PointerType::get(printMethodTy), "printf"))
111     return Meth;
112
113   // Create a new method and add it to the module
114   Method *printMethod = new Method(printMethodTy, "printf");
115   module->getMethodList().push_back(printMethod);
116   
117   return printMethod;
118 }
119
120
121 Instruction*
122 CreatePrintfInstr(Value* val,
123                   const BasicBlock* bb,
124                   Module* module,
125                   unsigned int indent,
126                   bool isMethodExit)
127 {
128   ostringstream fmtString, scopeNameString, valNameString;
129   vector<Value*> paramList;
130   const Type* valueType = val->getType();
131   Method* printMethod = GetPrintfMethodForType(module, valueType);
132   
133   if (! valueType->isPrimitiveType() ||
134       valueType->getPrimitiveID() == Type::VoidTyID ||
135       valueType->getPrimitiveID() == Type::TypeTyID ||
136       valueType->getPrimitiveID() == Type::LabelTyID)
137     {
138       assert(0 && "Unsupported type for printing");
139       return NULL;
140     }
141   
142   const Value* scopeToUse = (isMethodExit)? (const Value*) bb->getParent()
143                                           : (const Value*) bb;
144   if (scopeToUse->hasName())
145     scopeNameString << scopeToUse->getName() << ends;
146   else
147     scopeNameString << scopeToUse << ends;
148   
149   if (val->hasName())
150     valNameString << val->getName() << ends;
151   else
152     valNameString << val << ends;
153     
154   for (unsigned i=0; i < indent; i++)
155     fmtString << " ";
156   
157   fmtString << " At exit of "
158             << ((isMethodExit)? "Method " : "BB ")
159             << "%s : val %s = %s ";
160   
161   GlobalVariable* scopeNameVal = GetStringRef(module, scopeNameString.str());
162   GlobalVariable* valNameVal   = GetStringRef(module,valNameString.str());
163   GlobalVariable* typeNameVal  = GetStringRef(module,
164                                      val->getType()->getDescription().c_str());
165   
166   switch(valueType->getPrimitiveID())
167     {
168     case Type::BoolTyID:
169     case Type::UByteTyID: case Type::UShortTyID:
170     case Type::UIntTyID:  case Type::ULongTyID:
171     case Type::SByteTyID: case Type::ShortTyID:
172     case Type::IntTyID:   case Type::LongTyID:
173       fmtString << " %d\0A";
174       break;
175       
176     case Type::FloatTyID:     case Type::DoubleTyID:
177       fmtString << " %g\0A";
178       break;
179       
180     case Type::PointerTyID:
181       fmtString << " %p\0A";
182       break;
183       
184     default:
185       assert(0 && "Should not get here.  Check the IF expression above");
186       return NULL;
187     }
188   
189   fmtString << ends;
190   GlobalVariable* fmtVal = GetStringRef(module, fmtString.str());
191   
192   paramList.push_back(fmtVal);
193   paramList.push_back(scopeNameVal);
194   paramList.push_back(valNameVal);
195   paramList.push_back(typeNameVal);
196   paramList.push_back(val);
197   
198   return new CallInst(printMethod, paramList);
199 }
200 #endif
201
202
203 // The invocation should be:
204 //       call "printVal"(value).
205 // 
206 static Value *GetPrintMethodForType(Module *Mod, const Type *VTy) {
207   MethodType *MTy = MethodType::get(Type::VoidTy, vector<const Type*>(1, VTy),
208                                     /*isVarArg*/ false);
209   
210   SymbolTable *ST = Mod->getSymbolTableSure();
211   if (Value *V = ST->lookup(PointerType::get(MTy), "printVal"))
212     return V;
213
214   // Create a new method and add it to the module
215   Method *M = new Method(MTy, "printVal");
216   Mod->getMethodList().push_back(M);
217   return M;
218 }
219
220
221 static void
222 InsertPrintInsts(Value *Val,
223                  BasicBlock* BB,
224                  BasicBlock::iterator &BBI,
225                  Module *Mod,
226                  unsigned int indent,
227                  bool isMethodExit)
228 {
229   const Type* ValTy = Val->getType();
230   
231   assert(ValTy->isPrimitiveType() &&
232          ValTy->getPrimitiveID() != Type::VoidTyID &&
233          ValTy->getPrimitiveID() != Type::TypeTyID &&
234          ValTy->getPrimitiveID() != Type::LabelTyID && 
235          "Unsupported type for printing");
236   
237   const Value* scopeToUse = 
238     isMethodExit ? (const Value*)BB->getParent() : (const Value*)BB;
239
240   // Create the marker string...
241   ostringstream scopeNameString;
242   WriteAsOperand(scopeNameString, scopeToUse) << " : ";
243   WriteAsOperand(scopeNameString, Val) << " = " << ends;
244   string fmtString(indent, ' ');
245   
246   fmtString += string(" At exit of") + scopeNameString.str();
247   
248   // Turn the marker string into a global variable...
249   GlobalVariable *fmtVal = GetStringRef(Mod, fmtString);
250   
251   // Insert the first print instruction to print the string flag:
252   Instruction *I = new CallInst(GetPrintMethodForType(Mod, fmtVal->getType()),
253                                 vector<Value*>(1, fmtVal));
254   BBI = BB->getInstList().insert(BBI, I)+1;
255
256   // Insert the next print instruction to print the value:
257   I = new CallInst(GetPrintMethodForType(Mod, ValTy),
258                    vector<Value*>(1, Val));
259   BBI = BB->getInstList().insert(BBI, I)+1;
260
261   // Print out a newline
262   fmtVal = GetStringRef(Mod, "\n");
263   I = new CallInst(GetPrintMethodForType(Mod, fmtVal->getType()),
264                    vector<Value*>(1, fmtVal));
265   BBI = BB->getInstList().insert(BBI, I)+1;
266 }
267
268
269 static LoadInst*
270 InsertLoadInst(StoreInst* storeInst,
271                BasicBlock *bb,
272                BasicBlock::iterator &BBI)
273 {
274   LoadInst* loadInst = new LoadInst(storeInst->getPtrOperand(),
275                                     storeInst->getIndexVec());
276   BBI = bb->getInstList().insert(BBI, loadInst) + 1;
277   return loadInst;
278 }
279
280
281 // 
282 // Insert print instructions at the end of the basic block *bb
283 // for each value in valueVec[] that is live at the end of that basic block,
284 // or that is stored to memory in this basic block.
285 // If the value is stored to memory, we load it back before printing
286 // We also return all such loaded values in the vector valuesStoredInMethod
287 // for printing at the exit from the method.  (Note that in each invocation
288 // of the method, this will only get the last value stored for each static
289 // store instruction).
290 // *bb must be the block in which the value is computed;
291 // this is not checked here.
292 // 
293 static void
294 TraceValuesAtBBExit(const vector<Instruction*>& valueVec,
295                     BasicBlock* bb,
296                     Module* module,
297                     unsigned int indent,
298                     bool isMethodExit,
299                     vector<Instruction*>* valuesStoredInMethod)
300 {
301   // Get an iterator to point to the insertion location
302   // 
303   BasicBlock::InstListType& instList = bb->getInstList();
304   BasicBlock::iterator here = instList.end()-1;
305   assert((*here)->isTerminator());
306   
307   // Insert a print instruction for each value.
308   // 
309   for (unsigned i=0, N=valueVec.size(); i < N; i++)
310     {
311       Instruction* I = valueVec[i];
312       if (I->getOpcode() == Instruction::Store)
313         {
314           assert(valuesStoredInMethod != NULL &&
315                  "Should not be printing a store instruction at method exit");
316           I = InsertLoadInst((StoreInst*) I, bb, here);
317           valuesStoredInMethod->push_back(I);
318         }
319       InsertPrintInsts(I, bb, here, module, indent, isMethodExit);
320     }
321 }
322
323
324
325 static Instruction*
326 CreateMethodTraceInst(Method* method,
327                       unsigned int indent,
328                       const string& msg)
329 {
330   string fmtString(indent, ' ');
331   ostringstream methodNameString;
332   WriteAsOperand(methodNameString, method) << ends;
333   fmtString += msg + methodNameString.str();
334   
335   GlobalVariable *fmtVal = GetStringRef(method->getParent(), fmtString);
336   Instruction *printInst =
337     new CallInst(GetPrintMethodForType(method->getParent(), fmtVal->getType()),
338                  vector<Value*>(1, fmtVal));
339
340   return printInst;
341 }
342
343
344 static inline void
345 InsertCodeToShowMethodEntry(Method* method,
346                             BasicBlock* entryBB,
347                             unsigned int indent)
348 {
349   // Get an iterator to point to the insertion location
350   BasicBlock::InstListType& instList = entryBB->getInstList();
351   BasicBlock::iterator here = instList.begin();
352   
353   Instruction *printInst = CreateMethodTraceInst(method, indent, 
354                                                  "Entering Method"); 
355   
356   here = entryBB->getInstList().insert(here, printInst) + 1;
357 }
358
359
360 static inline void
361 InsertCodeToShowMethodExit(Method* method,
362                            BasicBlock* exitBB,
363                            unsigned int indent)
364 {
365   // Get an iterator to point to the insertion location
366   BasicBlock::InstListType& instList = exitBB->getInstList();
367   BasicBlock::iterator here = instList.end()-1;
368   assert((*here)->isTerminator());
369   
370   Instruction *printInst = CreateMethodTraceInst(method, indent,
371                                                  "Leaving Method"); 
372   
373   exitBB->getInstList().insert(here, printInst) + 1;
374 }
375
376
377 //************************** External Functions ****************************/
378
379
380 bool
381 InsertTraceCode::doInsertTraceCode(Method *M,
382                                    bool traceBasicBlockExits,
383                                    bool traceMethodExits)
384 {
385   vector<Instruction*> valuesStoredInMethod;
386   Module* module = M->getParent();
387   vector<BasicBlock*> exitBlocks;
388
389   if (M->isExternal() ||
390       (! traceBasicBlockExits && ! traceMethodExits))
391     return false;
392
393   if (traceMethodExits)
394     InsertCodeToShowMethodEntry(M, M->getEntryNode(), /*indent*/ 0);
395   
396   for (Method::iterator BI = M->begin(); BI != M->end(); ++BI)
397     {
398       BasicBlock* bb = *BI;
399       bool isExitBlock = false;
400       vector<Instruction*> valuesToTraceInBB;
401       
402       FindValuesToTraceInBB(bb, valuesToTraceInBB);
403       
404       if (bb->succ_begin() == bb->succ_end())
405         { // record this as an exit block
406           exitBlocks.push_back(bb);
407           isExitBlock = true;
408         }
409       
410       if (traceBasicBlockExits)
411         TraceValuesAtBBExit(valuesToTraceInBB, bb, module,
412                             /*indent*/ 4, /*isMethodExit*/ false,
413                             &valuesStoredInMethod);
414     }
415
416   if (traceMethodExits)
417     for (unsigned i=0; i < exitBlocks.size(); ++i)
418       {
419         TraceValuesAtBBExit(valuesStoredInMethod, exitBlocks[i], module,
420                             /*indent*/ 0, /*isMethodExit*/ true,
421                             /*valuesStoredInMethod*/ NULL);
422         InsertCodeToShowMethodExit(M, exitBlocks[i], /*indent*/ 0);
423       }
424
425   return true;
426 }