Add support for the Switch instruction by running the lowerSwitch pass first
[oota-llvm.git] / lib / Transforms / IPO / PoolAllocate.cpp
1 //===-- PoolAllocate.cpp - Pool Allocation Pass ---------------------------===//
2 //
3 // This transform changes programs so that disjoint data structures are
4 // allocated out of different pools of memory, increasing locality.
5 //
6 //===----------------------------------------------------------------------===//
7
8 #include "llvm/Transforms/PoolAllocate.h"
9 #include "llvm/Transforms/Utils/Cloning.h"
10 #include "llvm/Analysis/DataStructure.h"
11 #include "llvm/Analysis/DSGraph.h"
12 #include "llvm/Module.h"
13 #include "llvm/DerivedTypes.h"
14 #include "llvm/Constants.h"
15 #include "llvm/Instructions.h"
16 #include "llvm/Target/TargetData.h"
17 #include "llvm/Support/InstVisitor.h"
18 #include "Support/Statistic.h"
19 #include "Support/VectorExtras.h"
20
21 using namespace PA;
22
23 namespace {
24   const Type *VoidPtrTy = PointerType::get(Type::SByteTy);
25   // The type to allocate for a pool descriptor: { sbyte*, uint }
26   const Type *PoolDescType =
27     StructType::get(make_vector<const Type*>(VoidPtrTy, Type::UIntTy, 0));
28   const PointerType *PoolDescPtr = PointerType::get(PoolDescType);
29
30   RegisterOpt<PoolAllocate>
31   X("poolalloc", "Pool allocate disjoint data structures");
32 }
33
34 void PoolAllocate::getAnalysisUsage(AnalysisUsage &AU) const {
35   AU.addRequired<BUDataStructures>();
36   AU.addRequired<TargetData>();
37 }
38
39 bool PoolAllocate::run(Module &M) {
40   if (M.begin() == M.end()) return false;
41   CurModule = &M;
42   
43   AddPoolPrototypes();
44   BU = &getAnalysis<BUDataStructures>();
45
46   std::map<Function*, Function*> FuncMap;
47
48   // Loop over only the function initially in the program, don't traverse newly
49   // added ones.  If the function uses memory, make its clone.
50   Module::iterator LastOrigFunction = --M.end();
51   for (Module::iterator I = M.begin(); ; ++I) {
52     if (!I->isExternal())
53       if (Function *R = MakeFunctionClone(*I))
54         FuncMap[I] = R;
55     if (I == LastOrigFunction) break;
56   }
57
58   ++LastOrigFunction;
59
60   // Now that all call targets are available, rewrite the function bodies of the
61   // clones.
62   for (Module::iterator I = M.begin(); I != LastOrigFunction; ++I)
63     if (!I->isExternal()) {
64       std::map<Function*, Function*>::iterator FI = FuncMap.find(I);
65       ProcessFunctionBody(*I, FI != FuncMap.end() ? *FI->second : *I);
66     }
67
68   FunctionInfo.clear();
69   return true;
70 }
71
72
73 // AddPoolPrototypes - Add prototypes for the pool functions to the specified
74 // module and update the Pool* instance variables to point to them.
75 //
76 void PoolAllocate::AddPoolPrototypes() {
77   CurModule->addTypeName("PoolDescriptor", PoolDescType);
78
79   // Get poolinit function...
80   FunctionType *PoolInitTy =
81     FunctionType::get(Type::VoidTy,
82                       make_vector<const Type*>(PoolDescPtr, Type::UIntTy, 0),
83                       false);
84   PoolInit = CurModule->getOrInsertFunction("poolinit", PoolInitTy);
85
86   // Get pooldestroy function...
87   std::vector<const Type*> PDArgs(1, PoolDescPtr);
88   FunctionType *PoolDestroyTy =
89     FunctionType::get(Type::VoidTy, PDArgs, false);
90   PoolDestroy = CurModule->getOrInsertFunction("pooldestroy", PoolDestroyTy);
91
92   // Get the poolalloc function...
93   FunctionType *PoolAllocTy = FunctionType::get(VoidPtrTy, PDArgs, false);
94   PoolAlloc = CurModule->getOrInsertFunction("poolalloc", PoolAllocTy);
95
96   // Get the poolfree function...
97   PDArgs.push_back(VoidPtrTy);       // Pointer to free
98   FunctionType *PoolFreeTy = FunctionType::get(Type::VoidTy, PDArgs, false);
99   PoolFree = CurModule->getOrInsertFunction("poolfree", PoolFreeTy);
100
101 #if 0
102   Args[0] = Type::UIntTy;            // Number of slots to allocate
103   FunctionType *PoolAllocArrayTy = FunctionType::get(VoidPtrTy, Args, true);
104   PoolAllocArray = CurModule->getOrInsertFunction("poolallocarray",
105                                                   PoolAllocArrayTy);
106 #endif
107 }
108
109
110 // MakeFunctionClone - If the specified function needs to be modified for pool
111 // allocation support, make a clone of it, adding additional arguments as
112 // neccesary, and return it.  If not, just return null.
113 //
114 Function *PoolAllocate::MakeFunctionClone(Function &F) {
115   DSGraph &G = BU->getDSGraph(F);
116   std::vector<DSNode*> &Nodes = G.getNodes();
117   if (Nodes.empty()) return 0;  // No memory activity, nothing is required
118
119   FuncInfo &FI = FunctionInfo[&F];   // Create a new entry for F
120   FI.Clone = 0;
121
122   // Find DataStructure nodes which are allocated in pools non-local to the
123   // current function.  This set will contain all of the DSNodes which require
124   // pools to be passed in from outside of the function.
125   hash_set<DSNode*> &MarkedNodes = FI.MarkedNodes;
126
127   // Mark globals and incomplete nodes as live... (this handles arguments)
128   if (F.getName() != "main")
129     for (unsigned i = 0, e = Nodes.size(); i != e; ++i)
130       if (Nodes[i]->NodeType & (DSNode::GlobalNode | DSNode::Incomplete) &&
131           Nodes[i]->NodeType & (DSNode::HeapNode))
132         Nodes[i]->markReachableNodes(MarkedNodes);
133
134   // Marked the returned node as alive...
135   if (DSNode *RetNode = G.getRetNode().getNode())
136     if (RetNode->NodeType & DSNode::HeapNode)
137       RetNode->markReachableNodes(MarkedNodes);
138
139   if (MarkedNodes.empty())   // We don't need to clone the function if there
140     return 0;                // are no incoming arguments to be added.
141
142   // Figure out what the arguments are to be for the new version of the function
143   const FunctionType *OldFuncTy = F.getFunctionType();
144   std::vector<const Type*> ArgTys;
145   ArgTys.reserve(OldFuncTy->getParamTypes().size() + MarkedNodes.size());
146
147   FI.ArgNodes.reserve(MarkedNodes.size());
148   for (hash_set<DSNode*>::iterator I = MarkedNodes.begin(),
149          E = MarkedNodes.end(); I != E; ++I)
150     if ((*I)->NodeType & DSNode::Incomplete) {
151       ArgTys.push_back(PoolDescPtr);      // Add the appropriate # of pool descs
152       FI.ArgNodes.push_back(*I);
153     }
154   if (FI.ArgNodes.empty()) return 0;      // No nodes to be pool allocated!
155
156   ArgTys.insert(ArgTys.end(), OldFuncTy->getParamTypes().begin(),
157                 OldFuncTy->getParamTypes().end());
158
159
160   // Create the new function prototype
161   FunctionType *FuncTy = FunctionType::get(OldFuncTy->getReturnType(), ArgTys,
162                                            OldFuncTy->isVarArg());
163   // Create the new function...
164   Function *New = new Function(FuncTy, GlobalValue::InternalLinkage,
165                                F.getName(), F.getParent());
166
167   // Set the rest of the new arguments names to be PDa<n> and add entries to the
168   // pool descriptors map
169   std::map<DSNode*, Value*> &PoolDescriptors = FI.PoolDescriptors;
170   Function::aiterator NI = New->abegin();
171   for (unsigned i = 0, e = FI.ArgNodes.size(); i != e; ++i, ++NI) {
172     NI->setName("PDa");  // Add pd entry
173     PoolDescriptors.insert(std::make_pair(FI.ArgNodes[i], NI));
174   }
175
176   // Map the existing arguments of the old function to the corresponding
177   // arguments of the new function.
178   std::map<const Value*, Value*> ValueMap;
179   for (Function::aiterator I = F.abegin(), E = F.aend(); I != E; ++I, ++NI) {
180     ValueMap[I] = NI;
181     NI->setName(I->getName());
182   }
183
184   // Populate the value map with all of the globals in the program.
185   // FIXME: This should be unneccesary!
186   Module &M = *F.getParent();
187   for (Module::iterator I = M.begin(), E=M.end(); I!=E; ++I)    ValueMap[I] = I;
188   for (Module::giterator I = M.gbegin(), E=M.gend(); I!=E; ++I) ValueMap[I] = I;
189
190   // Perform the cloning.
191   std::vector<ReturnInst*> Returns;
192   CloneFunctionInto(New, &F, ValueMap, Returns);
193
194   // Invert the ValueMap into the NewToOldValueMap
195   std::map<Value*, const Value*> &NewToOldValueMap = FI.NewToOldValueMap;
196   for (std::map<const Value*, Value*>::iterator I = ValueMap.begin(),
197          E = ValueMap.end(); I != E; ++I)
198     NewToOldValueMap.insert(std::make_pair(I->second, I->first));
199   
200   return FI.Clone = New;
201 }
202
203
204 // processFunction - Pool allocate any data structures which are contained in
205 // the specified function...
206 //
207 void PoolAllocate::ProcessFunctionBody(Function &F, Function &NewF) {
208   DSGraph &G = BU->getDSGraph(F);
209   std::vector<DSNode*> &Nodes = G.getNodes();
210   if (Nodes.empty()) return;     // Quick exit if nothing to do...
211
212   FuncInfo &FI = FunctionInfo[&F];   // Get FuncInfo for F
213   hash_set<DSNode*> &MarkedNodes = FI.MarkedNodes;
214  
215   DEBUG(std::cerr << "[" << F.getName() << "] Pool Allocate: ");
216
217   // Loop over all of the nodes which are non-escaping, adding pool-allocatable
218   // ones to the NodesToPA vector.
219   std::vector<DSNode*> NodesToPA;
220   for (unsigned i = 0, e = Nodes.size(); i != e; ++i)
221     if (Nodes[i]->NodeType & DSNode::HeapNode &&   // Pick nodes with heap elems
222         !(Nodes[i]->NodeType & DSNode::Array) &&   // Doesn't handle arrays yet.
223         !MarkedNodes.count(Nodes[i]))              // Can't be marked
224       NodesToPA.push_back(Nodes[i]);
225
226   DEBUG(std::cerr << NodesToPA.size() << " nodes to pool allocate\n");
227   if (!NodesToPA.empty()) {
228     // Create pool construction/destruction code
229     std::map<DSNode*, Value*> &PoolDescriptors = FI.PoolDescriptors;
230     CreatePools(NewF, NodesToPA, PoolDescriptors);
231   }
232
233   // Transform the body of the function now...
234   TransformFunctionBody(NewF, G, FI);
235 }
236
237
238 // CreatePools - This creates the pool initialization and destruction code for
239 // the DSNodes specified by the NodesToPA list.  This adds an entry to the
240 // PoolDescriptors map for each DSNode.
241 //
242 void PoolAllocate::CreatePools(Function &F,
243                                const std::vector<DSNode*> &NodesToPA,
244                                std::map<DSNode*, Value*> &PoolDescriptors) {
245   // Find all of the return nodes in the CFG...
246   std::vector<BasicBlock*> ReturnNodes;
247   for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I)
248     if (isa<ReturnInst>(I->getTerminator()))
249       ReturnNodes.push_back(I);
250
251   TargetData &TD = getAnalysis<TargetData>();
252
253   // Loop over all of the pools, inserting code into the entry block of the
254   // function for the initialization and code in the exit blocks for
255   // destruction.
256   //
257   Instruction *InsertPoint = F.front().begin();
258   for (unsigned i = 0, e = NodesToPA.size(); i != e; ++i) {
259     DSNode *Node = NodesToPA[i];
260
261     // Create a new alloca instruction for the pool...
262     Value *AI = new AllocaInst(PoolDescType, 0, "PD", InsertPoint);
263
264     Value *ElSize =
265       ConstantUInt::get(Type::UIntTy, TD.getTypeSize(Node->getType()));
266
267     // Insert the call to initialize the pool...
268     new CallInst(PoolInit, make_vector(AI, ElSize, 0), "", InsertPoint);
269
270     // Update the PoolDescriptors map
271     PoolDescriptors.insert(std::make_pair(Node, AI));
272
273     // Insert a call to pool destroy before each return inst in the function
274     for (unsigned r = 0, e = ReturnNodes.size(); r != e; ++r)
275       new CallInst(PoolDestroy, make_vector(AI, 0), "",
276                    ReturnNodes[r]->getTerminator());
277   }
278 }
279
280
281 namespace {
282   /// FuncTransform - This class implements transformation required of pool
283   /// allocated functions.
284   struct FuncTransform : public InstVisitor<FuncTransform> {
285     PoolAllocate &PAInfo;
286     DSGraph &G;
287     FuncInfo &FI;
288
289     FuncTransform(PoolAllocate &P, DSGraph &g, FuncInfo &fi)
290       : PAInfo(P), G(g), FI(fi) {}
291
292     void visitMallocInst(MallocInst &MI);
293     void visitFreeInst(FreeInst &FI);
294     void visitCallInst(CallInst &CI);
295
296   private:
297     DSNode *getDSNodeFor(Value *V) {
298       if (!FI.NewToOldValueMap.empty()) {
299         // If the NewToOldValueMap is in effect, use it.
300         std::map<Value*,const Value*>::iterator I = FI.NewToOldValueMap.find(V);
301         if (I != FI.NewToOldValueMap.end())
302           V = (Value*)I->second;
303       }
304
305       return G.getScalarMap()[V].getNode();
306     }
307     Value *getPoolHandle(Value *V) {
308       DSNode *Node = getDSNodeFor(V);
309       // Get the pool handle for this DSNode...
310       std::map<DSNode*, Value*>::iterator I = FI.PoolDescriptors.find(Node);
311       return I != FI.PoolDescriptors.end() ? I->second : 0;
312     }
313   };
314 }
315
316 void PoolAllocate::TransformFunctionBody(Function &F, DSGraph &G, FuncInfo &FI){
317   FuncTransform(*this, G, FI).visit(F);
318 }
319
320
321 void FuncTransform::visitMallocInst(MallocInst &MI) {
322   // Get the pool handle for the node that this contributes to...
323   Value *PH = getPoolHandle(&MI);
324   if (PH == 0) return;
325   
326   // Insert a call to poolalloc
327   Value *V = new CallInst(PAInfo.PoolAlloc, make_vector(PH, 0),
328                           MI.getName(), &MI);
329   MI.setName("");  // Nuke MIs name
330   
331   // Cast to the appropriate type...
332   Value *Casted = new CastInst(V, MI.getType(), V->getName(), &MI);
333   
334   // Update def-use info
335   MI.replaceAllUsesWith(Casted);
336   
337   // Remove old malloc instruction
338   MI.getParent()->getInstList().erase(&MI);
339   
340   hash_map<Value*, DSNodeHandle> &SM = G.getScalarMap();
341   hash_map<Value*, DSNodeHandle>::iterator MII = SM.find(&MI);
342   
343   // If we are modifying the original function, update the DSGraph... 
344   if (MII != SM.end()) {
345     // V and Casted now point to whatever the original malloc did...
346     SM.insert(std::make_pair(V, MII->second));
347     SM.insert(std::make_pair(Casted, MII->second));
348     SM.erase(MII);                     // The malloc is now destroyed
349   } else {             // Otherwise, update the NewToOldValueMap
350     std::map<Value*,const Value*>::iterator MII =
351       FI.NewToOldValueMap.find(&MI);
352     assert(MII != FI.NewToOldValueMap.end() && "MI not found in clone?");
353     FI.NewToOldValueMap.insert(std::make_pair(V, MII->second));
354     FI.NewToOldValueMap.insert(std::make_pair(Casted, MII->second));
355     FI.NewToOldValueMap.erase(MII);
356   }
357 }
358
359 void FuncTransform::visitFreeInst(FreeInst &FI) {
360   Value *Arg = FI.getOperand(0);
361   Value *PH = getPoolHandle(Arg);  // Get the pool handle for this DSNode...
362   if (PH == 0) return;
363   // Insert a cast and a call to poolfree...
364   Value *Casted = new CastInst(Arg, PointerType::get(Type::SByteTy),
365                                Arg->getName()+".casted", &FI);
366   new CallInst(PAInfo.PoolFree, make_vector(PH, Casted, 0), "", &FI);
367   
368   // Delete the now obsolete free instruction...
369   FI.getParent()->getInstList().erase(&FI);
370 }
371
372 static void CalcNodeMapping(DSNode *Caller, DSNode *Callee,
373                             std::map<DSNode*, DSNode*> &NodeMapping) {
374   if (Callee == 0) return;
375   assert(Caller && "Callee has node but caller doesn't??");
376
377   std::map<DSNode*, DSNode*>::iterator I = NodeMapping.find(Callee);
378   if (I != NodeMapping.end()) {   // Node already in map...
379     assert(I->second == Caller && "Node maps to different nodes on paths?");
380   } else {
381     NodeMapping.insert(I, std::make_pair(Callee, Caller));
382     
383     // Recursively add pointed to nodes...
384     for (unsigned i = 0, e = Callee->getNumLinks(); i != e; ++i)
385       CalcNodeMapping(Caller->getLink(i << DS::PointerShift).getNode(),
386                       Callee->getLink(i << DS::PointerShift).getNode(),
387                       NodeMapping);
388   }
389 }
390
391 void FuncTransform::visitCallInst(CallInst &CI) {
392   Function *CF = CI.getCalledFunction();
393   assert(CF && "FIXME: Pool allocation doesn't handle indirect calls!");
394
395   FuncInfo *CFI = PAInfo.getFuncInfo(*CF);
396   if (CFI == 0 || CFI->Clone == 0) return;  // Nothing to transform...
397
398   DEBUG(std::cerr << "  Handling call: " << CI);
399
400   DSGraph &CG = PAInfo.getBUDataStructures().getDSGraph(*CF);  // Callee graph
401
402   // We need to figure out which local pool descriptors correspond to the pool
403   // descriptor arguments passed into the function call.  Calculate a mapping
404   // from callee DSNodes to caller DSNodes.  We construct a partial isomophism
405   // between the graphs to figure out which pool descriptors need to be passed
406   // in.  The roots of this mapping is found from arguments and return values.
407   //
408   std::map<DSNode*, DSNode*> NodeMapping;
409
410   Function::aiterator AI = CF->abegin(), AE = CF->aend();
411   unsigned OpNum = 1;
412   for (; AI != AE; ++AI, ++OpNum)
413     CalcNodeMapping(getDSNodeFor(CI.getOperand(OpNum)),
414                     CG.getScalarMap()[AI].getNode(), NodeMapping);
415   assert(OpNum == CI.getNumOperands() && "Varargs calls not handled yet!");
416   
417   // Map the return value as well...
418   CalcNodeMapping(getDSNodeFor(&CI), CG.getRetNode().getNode(), NodeMapping);
419
420
421   // Okay, now that we have established our mapping, we can figure out which
422   // pool descriptors to pass in...
423   std::vector<Value*> Args;
424
425   // Add an argument for each pool which must be passed in...
426   for (unsigned i = 0, e = CFI->ArgNodes.size(); i != e; ++i) {
427     if (NodeMapping.count(CFI->ArgNodes[i])) {
428       assert(NodeMapping.count(CFI->ArgNodes[i]) && "Node not in mapping!");
429       DSNode *LocalNode = NodeMapping.find(CFI->ArgNodes[i])->second;
430       assert(FI.PoolDescriptors.count(LocalNode) && "Node not pool allocated?");
431       Args.push_back(FI.PoolDescriptors.find(LocalNode)->second);
432     } else {
433       Args.push_back(Constant::getNullValue(PoolDescPtr));
434     }
435   }
436
437   // Add the rest of the arguments...
438   Args.insert(Args.end(), CI.op_begin()+1, CI.op_end());
439
440   std::string Name = CI.getName(); CI.setName("");
441   Value *NewCall = new CallInst(CFI->Clone, Args, Name, &CI);
442   CI.replaceAllUsesWith(NewCall);
443
444   DEBUG(std::cerr << "  Result Call: " << *NewCall);
445   CI.getParent()->getInstList().erase(&CI);
446 }