Insert code to trace values at basic block and method exits.
[oota-llvm.git] / lib / Transforms / Instrumentation / TraceValues.cpp
1 // $Id$
2 //***************************************************************************
3 // File:
4 //      TraceValues.cpp
5 // 
6 // Purpose:
7 //      Support for inserting LLVM code to print values at basic block
8 //      and method exits.  Also exports functions to create a call
9 //      "printf" instruction with one of the signatures listed below.
10 // 
11 // History:
12 //      10/11/01         -  Vikram Adve  -  Created
13 //**************************************************************************/
14
15
16 #include "llvm/Transforms/Instrumentation/TraceValues.h"
17 #include "llvm/GlobalVariable.h"
18 #include "llvm/ConstPoolVals.h"
19 #include "llvm/Type.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Instruction.h"
22 #include "llvm/iTerminators.h"
23 #include "llvm/iOther.h"
24 #include "llvm/BasicBlock.h"
25 #include "llvm/Method.h"
26 #include "llvm/Module.h"
27 #include "llvm/SymbolTable.h"
28 #include "llvm/Support/HashExtras.h"
29 #include <hash_map>
30 #include <vector>
31 #include <strstream.h>
32
33
34 //*********************** Internal Data Structures *************************/
35
36 const char* const PRINTF = "printf";
37
38 #undef DONT_EMBED_STRINGS_IN_FMT
39
40
41 //************************** Internal Functions ****************************/
42
43 #undef USE_PTRREF
44 #ifdef USE_PTRREF
45 inline ConstPoolPointerReference*
46 GetStringRef(Module* module, const char* str)
47 {
48   static hash_map<string, ConstPoolPointerReference*> stringRefCache;
49   static Module* lastModule = NULL;
50   
51   if (lastModule != module)
52     { // Let's make sure we create separate global references in each module
53       stringRefCache.clear();
54       lastModule = module;
55     }
56   
57   ConstPoolPointerReference* result = stringRefCache[str];
58   if (result == NULL)
59     {
60       ConstPoolArray* charArray = ConstPoolArray::get(str);
61       GlobalVariable* stringVar =
62         new GlobalVariable(charArray->getType(),/*isConst*/true,charArray,str);
63       module->getGlobalList().push_back(stringVar);
64       result = ConstPoolPointerReference::get(stringVar);
65       assert(result && "Failed to create reference to string constant");
66       stringRefCache[str] = result;
67     }
68   
69   return result;
70 }
71 #endif USE_PTRREF
72
73 inline GlobalVariable*
74 GetStringRef(Module* module, const char* str)
75 {
76   static hash_map<string, GlobalVariable*> stringRefCache;
77   static Module* lastModule = NULL;
78   
79   if (lastModule != module)
80     { // Let's make sure we create separate global references in each module
81       stringRefCache.clear();
82       lastModule = module;
83     }
84   
85   GlobalVariable* result = stringRefCache[str];
86   if (result == NULL)
87     {
88       ConstPoolArray* charArray = ConstPoolArray::get(str);
89       GlobalVariable* stringVar =
90         new GlobalVariable(charArray->getType(),/*isConst*/true,charArray);
91       module->getGlobalList().push_back(stringVar);
92       result = stringVar;
93       // result = ConstPoolPointerReference::get(stringVar);
94       assert(result && "Failed to create reference to string constant");
95       stringRefCache[str] = result;
96     }
97   
98   return result;
99 }
100
101
102 inline bool
103 TraceThisOpCode(unsigned opCode)
104 {
105   // Explicitly test for opCodes *not* to trace so that any new opcodes will
106   // be traced by default (or will fail in a later assertion on VoidTy)
107   // 
108   return (opCode  < Instruction::FirstOtherOp &&
109           opCode != Instruction::Ret &&
110           opCode != Instruction::Br &&
111           opCode != Instruction::Switch &&
112           opCode != Instruction::Free &&
113           opCode != Instruction::Alloca &&
114           opCode != Instruction::Store &&
115           opCode != Instruction::PHINode &&
116           opCode != Instruction::Cast);
117 }
118
119
120 static void
121 FindValuesToTraceInBB(BasicBlock* bb,
122                       vector<Value*>& valuesToTraceInBB)
123 {
124   for (BasicBlock::iterator II = bb->begin(); II != bb->end(); ++II)
125     if ((*II)->getType()->isPrimitiveType() &&
126         TraceThisOpCode((*II)->getOpcode()))
127       {
128         valuesToTraceInBB.push_back(*II);
129       }
130 }
131
132
133 // 
134 // Insert print instructions at the end of the basic block *bb
135 // for each value in valueVec[].  *bb must postdominate the block
136 // in which the value is computed; this is not checked here.
137 // 
138 static void
139 TraceValuesAtBBExit(const vector<Value*>& valueVec,
140                     BasicBlock* bb,
141                     Module* module,
142                     unsigned int indent,
143                     bool isMethodExit)
144 {
145   // Get an iterator to point to the insertion location
146   // 
147   BasicBlock::InstListType& instList = bb->getInstList();
148   TerminatorInst* termInst = bb->getTerminator(); 
149   BasicBlock::InstListType::iterator here = instList.end();
150   while ((*here) != termInst && here != instList.begin())
151     --here;
152   assert((*here) == termInst);
153   
154   // Insert a print instruction for each value.
155   // 
156   for (unsigned i=0, N=valueVec.size(); i < N; i++)
157     {
158       Instruction* traceInstr =
159         CreatePrintInstr(valueVec[i], bb, module, indent, isMethodExit);
160       here = instList.insert(here, traceInstr);
161     }
162 }
163
164 void
165 InsertCodeToShowMethodEntry(BasicBlock* entryBB)
166 {
167 }
168
169 void
170 InsertCodeToShowMethodExit(BasicBlock* exitBB)
171 {
172 }
173
174
175 //************************** External Functions ****************************/
176
177 // 
178 // The signatures of the print methods supported are:
179 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  int      intValue)
180 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  unsigned uintValue)
181 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  float    floatValue)
182 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  double   doubleValue)
183 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  char*    stringValue)
184 //   int printf(ubyte*,  ubyte*,  ubyte*,  ubyte*,  void*    ptrValue)
185 // 
186 // The invocation should be:
187 //       call "printf"(fmt, bbName, valueName, valueTypeName, value).
188 // 
189 Method*
190 GetPrintMethodForType(Module* module, const Type* valueType)
191 {
192 #ifdef DONT_EMBED_STRINGS_IN_FMT
193   static const int LASTARGINDEX = 4;
194 #else
195   static const int LASTARGINDEX = 1;
196 #endif
197   static PointerType* ubytePtrTy = NULL;
198   static vector<const Type*> argTypesVec(LASTARGINDEX + 1);
199   
200   if (ubytePtrTy == NULL)
201     { // create these once since they are invariant
202       ubytePtrTy = PointerType::get(ArrayType::get(Type::UByteTy));
203       argTypesVec[0] = ubytePtrTy;
204 #ifdef DONT_EMBED_STRINGS_IN_FMT
205       argTypesVec[1] = ubytePtrTy;
206       argTypesVec[2] = ubytePtrTy;
207       argTypesVec[3] = ubytePtrTy;
208 #endif DONT_EMBED_STRINGS_IN_FMT
209     }
210   
211   SymbolTable* symtab = module->getSymbolTable();
212   argTypesVec[LASTARGINDEX] = valueType;
213   MethodType* printMethodTy = MethodType::get(Type::IntTy, argTypesVec,
214                                               /*isVarArg*/ false);
215   
216   Method* printMethod =
217     cast<Method>(symtab->lookup(PointerType::get(printMethodTy), PRINTF));
218   if (printMethod == NULL)
219     { // Create a new method and add it to the module
220       printMethod = new Method(printMethodTy, PRINTF);
221       module->getMethodList().push_back(printMethod);
222       
223       // Create the argument list for the method so that the full signature
224       // can be declared.  The args can be anonymous.
225       Method::ArgumentListType &argList = printMethod->getArgumentList();
226       for (unsigned i=0; i < argTypesVec.size(); ++i)
227         argList.push_back(new MethodArgument(argTypesVec[i]));
228     }
229   
230   return printMethod;
231 }
232
233
234 Instruction*
235 CreatePrintInstr(Value* val,
236                  const BasicBlock* bb,
237                  Module* module,
238                  unsigned int indent,
239                  bool isMethodExit)
240 {
241   strstream fmtString, scopeNameString, valNameString;
242   vector<Value*> paramList;
243   const Type* valueType = val->getType();
244   Method* printMethod = GetPrintMethodForType(module, valueType);
245   
246   if (! valueType->isPrimitiveType() ||
247       valueType->getPrimitiveID() == Type::VoidTyID ||
248       valueType->getPrimitiveID() == Type::TypeTyID ||
249       valueType->getPrimitiveID() == Type::LabelTyID)
250     {
251       assert(0 && "Unsupported type for printing");
252       return NULL;
253     }
254   
255   const Value* scopeToUse = (isMethodExit)? (const Value*) bb->getParent()
256                                           : (const Value*) bb;
257   if (scopeToUse->hasName())
258     scopeNameString << scopeToUse->getName() << ends;
259   else
260     scopeNameString << scopeToUse << ends;
261   
262   if (val->hasName())
263     valNameString << val->getName() << ends;
264   else
265     valNameString << val << ends;
266     
267   for (unsigned i=0; i < indent; i++)
268     fmtString << " ";
269   
270 #undef DONT_EMBED_STRINGS_IN_FMT
271 #ifdef DONT_EMBED_STRINGS_IN_FMT
272   fmtString << " At exit of "
273             << ((isMethodExit)? "Method " : "BB ")
274             << "%s : val %s = %s ";
275   
276   GlobalVariable* scopeNameVal = GetStringRef(module, scopeNameString.str());
277   GlobalVariable* valNameVal   = GetStringRef(module,valNameString.str());
278   GlobalVariable* typeNameVal  = GetStringRef(module,
279                                      val->getType()->getDescription().c_str());
280 #else
281   fmtString << " At exit of "
282             << ((isMethodExit)? "Method " : "BB ")
283             << scopeNameString.str() << " : "
284             << valNameString.str()   << " = "
285             << val->getType()->getDescription().c_str();
286 #endif DONT_EMBED_STRINGS_IN_FMT
287   
288   switch(valueType->getPrimitiveID())
289     {
290     case Type::BoolTyID:
291     case Type::UByteTyID: case Type::UShortTyID:
292     case Type::UIntTyID:  case Type::ULongTyID:
293     case Type::SByteTyID: case Type::ShortTyID:
294     case Type::IntTyID:   case Type::LongTyID:
295       fmtString << " %d\0A";
296       break;
297       
298     case Type::FloatTyID:     case Type::DoubleTyID:
299       fmtString << " %g\0A";
300       break;
301       
302     case Type::PointerTyID:
303       fmtString << " %p\0A";
304       break;
305       
306     default:
307       assert(0 && "Should not get here.  Check the IF expression above");
308       return NULL;
309     }
310   
311   fmtString << ends;
312   GlobalVariable* fmtVal = GetStringRef(module, fmtString.str());
313   
314 #ifdef DONT_EMBED_STRINGS_IN_FMT
315   paramList.push_back(fmtVal);
316   paramList.push_back(scopeNameVal);
317   paramList.push_back(valNameVal);
318   paramList.push_back(typeNameVal);
319   paramList.push_back(val);
320 #else
321   paramList.push_back(fmtVal);
322   paramList.push_back(val);
323 #endif DONT_EMBED_STRINGS_IN_FMT
324   
325   free(fmtString.str());
326   free(scopeNameString.str());
327   free(valNameString.str());
328   
329   return new CallInst(printMethod, paramList);
330 }
331
332
333 void
334 InsertCodeToTraceValues(Method* method,
335                         bool traceBasicBlockExits,
336                         bool traceMethodExits)
337 {
338   vector<Value*> valuesToTraceInMethod;
339   Module* module = method->getParent();
340   BasicBlock* exitBB = NULL;
341   
342   if (method->isExternal() ||
343       (! traceBasicBlockExits && ! traceMethodExits))
344     return;
345   
346   if (traceMethodExits)
347     {
348       InsertCodeToShowMethodEntry(method->getEntryNode());
349       exitBB = method->getExitNode();
350     }
351   
352   for (Method::iterator BI = method->begin(); BI != method->end(); ++BI)
353     {
354       BasicBlock* bb = *BI;
355       vector<Value*> valuesToTraceInBB;
356       FindValuesToTraceInBB(bb, valuesToTraceInBB);
357       
358       if (traceBasicBlockExits && bb != exitBB)
359         TraceValuesAtBBExit(valuesToTraceInBB, bb, module,
360                             /*indent*/ 4, /*isMethodExit*/ false);
361       
362       if (traceMethodExits)
363         valuesToTraceInMethod.insert(valuesToTraceInMethod.end(),
364                                      valuesToTraceInBB.begin(),
365                                      valuesToTraceInBB.end());
366     }
367   
368   if (traceMethodExits)
369     {
370       TraceValuesAtBBExit(valuesToTraceInMethod, exitBB, module,
371                           /*indent*/ 0, /*isMethodExit*/ true);
372       InsertCodeToShowMethodExit(exitBB);
373     }
374 }