Speed up the predicate used to decide when to inline by caching the size
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
1 //===- MutateStructTypes.cpp - Change struct defns --------------------------=//
2 //
3 // This pass is used to change structure accesses and type definitions in some
4 // way.  It can be used to arbitrarily permute structure fields, safely, without
5 // breaking code.  A transformation may only be done on a type if that type has
6 // been found to be "safe" by the 'FindUnsafePointerTypes' pass.  This pass will
7 // assert and die if you try to do an illegal transformation.
8 //
9 // This is an interprocedural pass that requires the entire program to do a
10 // transformation.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/MutateStructTypes.h"
15 #include "llvm/DerivedTypes.h"
16 #include "llvm/Module.h"
17 #include "llvm/SymbolTable.h"
18 #include "llvm/Instructions.h"
19 #include "llvm/Constants.h"
20 #include "Support/STLExtras.h"
21 #include "Support/Debug.h"
22 #include <algorithm>
23
24 // ValuePlaceHolder - A stupid little marker value.  It appears as an
25 // instruction of type Instruction::UserOp1.
26 //
27 struct ValuePlaceHolder : public Instruction {
28   ValuePlaceHolder(const Type *Ty) : Instruction(Ty, UserOp1, "") {}
29
30   virtual Instruction *clone() const { abort(); return 0; }
31   virtual const char *getOpcodeName() const { return "placeholder"; }
32 };
33
34
35 // ConvertType - Convert from the old type system to the new one...
36 const Type *MutateStructTypes::ConvertType(const Type *Ty) {
37   if (Ty->isPrimitiveType() ||
38       isa<OpaqueType>(Ty)) return Ty;  // Don't convert primitives
39
40   std::map<const Type *, PATypeHolder>::iterator I = TypeMap.find(Ty);
41   if (I != TypeMap.end()) return I->second;
42
43   const Type *DestTy = 0;
44
45   PATypeHolder PlaceHolder = OpaqueType::get();
46   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
47
48   switch (Ty->getPrimitiveID()) {
49   case Type::FunctionTyID: {
50     const FunctionType *MT = cast<FunctionType>(Ty);
51     const Type *RetTy = ConvertType(MT->getReturnType());
52     std::vector<const Type*> ArgTypes;
53
54     for (FunctionType::ParamTypes::const_iterator I = MT->getParamTypes().begin(),
55            E = MT->getParamTypes().end(); I != E; ++I)
56       ArgTypes.push_back(ConvertType(*I));
57     
58     DestTy = FunctionType::get(RetTy, ArgTypes, MT->isVarArg());
59     break;
60   }
61   case Type::StructTyID: {
62     const StructType *ST = cast<StructType>(Ty);
63     const StructType::ElementTypes &El = ST->getElementTypes();
64     std::vector<const Type *> Types;
65
66     for (StructType::ElementTypes::const_iterator I = El.begin(), E = El.end();
67          I != E; ++I)
68       Types.push_back(ConvertType(*I));
69     DestTy = StructType::get(Types);
70     break;
71   }
72   case Type::ArrayTyID:
73     DestTy = ArrayType::get(ConvertType(cast<ArrayType>(Ty)->getElementType()),
74                             cast<ArrayType>(Ty)->getNumElements());
75     break;
76
77   case Type::PointerTyID:
78     DestTy = PointerType::get(
79                  ConvertType(cast<PointerType>(Ty)->getElementType()));
80     break;
81   default:
82     assert(0 && "Unknown type!");
83     return 0;
84   }
85
86   assert(DestTy && "Type didn't get created!?!?");
87
88   // Refine our little placeholder value into a real type...
89   ((DerivedType*)PlaceHolder.get())->refineAbstractTypeTo(DestTy);
90   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
91
92   return PlaceHolder.get();
93 }
94
95
96 // AdjustIndices - Convert the indexes specifed by Idx to the new changed form
97 // using the specified OldTy as the base type being indexed into.
98 //
99 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
100                                       std::vector<Value*> &Idx,
101                                       unsigned i) {
102   assert(i < Idx.size() && "i out of range!");
103   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
104   if (NewCT == OldTy) return;  // No adjustment unless type changes
105
106   if (const StructType *OldST = dyn_cast<StructType>(OldTy)) {
107     // Figure out what the current index is...
108     unsigned ElNum = cast<ConstantUInt>(Idx[i])->getValue();
109     assert(ElNum < OldST->getElementTypes().size());
110
111     std::map<const StructType*, TransformType>::iterator
112       I = Transforms.find(OldST);
113     if (I != Transforms.end()) {
114       assert(ElNum < I->second.second.size());
115       // Apply the XForm specified by Transforms map...
116       unsigned NewElNum = I->second.second[ElNum];
117       Idx[i] = ConstantUInt::get(Type::UByteTy, NewElNum);
118     }
119   }
120
121   // Recursively process subtypes...
122   if (i+1 < Idx.size())
123     AdjustIndices(cast<CompositeType>(OldTy->getTypeAtIndex(Idx[i])), Idx, i+1);
124 }
125
126
127 // ConvertValue - Convert from the old value in the old type system to the new
128 // type system.
129 //
130 Value *MutateStructTypes::ConvertValue(const Value *V) {
131   // Ignore null values and simple constants..
132   if (V == 0) return 0;
133
134   if (const Constant *CPV = dyn_cast<Constant>(V)) {
135     if (V->getType()->isPrimitiveType())
136       return (Value*)CPV;
137
138     if (isa<ConstantPointerNull>(CPV))
139       return ConstantPointerNull::get(
140                       cast<PointerType>(ConvertType(V->getType())));
141     assert(0 && "Unable to convert constpool val of this type!");
142   }
143
144   // Check to see if this is an out of function reference first...
145   if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
146     // Check to see if the value is in the map...
147     std::map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
148     if (I == GlobalMap.end())
149       return (Value*)GV;  // Not mapped, just return value itself
150     return I->second;
151   }
152   
153   std::map<const Value*, Value*>::iterator I = LocalValueMap.find(V);
154   if (I != LocalValueMap.end()) return I->second;
155
156   if (const BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
157     // Create placeholder block to represent the basic block we haven't seen yet
158     // This will be used when the block gets created.
159     //
160     return LocalValueMap[V] = new BasicBlock(BB->getName());
161   }
162
163   DEBUG(std::cerr << "NPH: " << V << "\n");
164
165   // Otherwise make a constant to represent it
166   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
167 }
168
169
170 // setTransforms - Take a map that specifies what transformation to do for each
171 // field of the specified structure types.  There is one element of the vector
172 // for each field of the structure.  The value specified indicates which slot of
173 // the destination structure the field should end up in.  A negative value 
174 // indicates that the field should be deleted entirely.
175 //
176 void MutateStructTypes::setTransforms(const TransformsType &XForm) {
177
178   // Loop over the types and insert dummy entries into the type map so that 
179   // recursive types are resolved properly...
180   for (std::map<const StructType*, std::vector<int> >::const_iterator
181          I = XForm.begin(), E = XForm.end(); I != E; ++I) {
182     const StructType *OldTy = I->first;
183     TypeMap.insert(std::make_pair(OldTy, OpaqueType::get()));
184   }
185
186   // Loop over the type specified and figure out what types they should become
187   for (std::map<const StructType*, std::vector<int> >::const_iterator
188          I = XForm.begin(), E = XForm.end(); I != E; ++I) {
189     const StructType  *OldTy = I->first;
190     const std::vector<int> &InVec = I->second;
191
192     assert(OldTy->getElementTypes().size() == InVec.size() &&
193            "Action not specified for every element of structure type!");
194
195     std::vector<const Type *> NewType;
196
197     // Convert the elements of the type over, including the new position mapping
198     int Idx = 0;
199     std::vector<int>::const_iterator TI = find(InVec.begin(), InVec.end(), Idx);
200     while (TI != InVec.end()) {
201       unsigned Offset = TI-InVec.begin();
202       const Type *NewEl = ConvertType(OldTy->getContainedType(Offset));
203       assert(NewEl && "Element not found!");
204       NewType.push_back(NewEl);
205
206       TI = find(InVec.begin(), InVec.end(), ++Idx);
207     }
208
209     // Create a new type that corresponds to the destination type
210     PATypeHolder NSTy = StructType::get(NewType);
211
212     // Refine the old opaque type to the new type to properly handle recursive
213     // types...
214     //
215     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
216     ((DerivedType*)OldTypeStub)->refineAbstractTypeTo(NSTy);
217
218     // Add the transformation to the Transforms map.
219     Transforms.insert(std::make_pair(OldTy,
220                        std::make_pair(cast<StructType>(NSTy.get()), InVec)));
221
222     DEBUG(std::cerr << "Mutate " << OldTy << "\nTo " << NSTy << "\n");
223   }
224 }
225
226 void MutateStructTypes::clearTransforms() {
227   Transforms.clear();
228   TypeMap.clear();
229   GlobalMap.clear();
230   assert(LocalValueMap.empty() &&
231          "Local Value Map should always be empty between transformations!");
232 }
233
234 // processGlobals - This loops over global constants defined in the
235 // module, converting them to their new type.
236 //
237 void MutateStructTypes::processGlobals(Module &M) {
238   // Loop through the functions in the module and create a new version of the
239   // function to contained the transformed code.  Also, be careful to not
240   // process the values that we add.
241   //
242   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
243     if (!I->isExternal()) {
244       const FunctionType *NewMTy = 
245         cast<FunctionType>(ConvertType(I->getFunctionType()));
246       
247       // Create a new function to put stuff into...
248       Function *NewMeth = new Function(NewMTy, I->getLinkage(), I->getName());
249       if (I->hasName())
250         I->setName("OLD."+I->getName());
251
252       // Insert the new function into the function list... to be filled in later
253       M.getFunctionList().push_back(NewMeth);
254       
255       // Keep track of the association...
256       GlobalMap[I] = NewMeth;
257     }
258
259   // TODO: HANDLE GLOBAL VARIABLES
260
261   // Remap the symbol table to refer to the types in a nice way
262   //
263   SymbolTable &ST = M.getSymbolTable();
264   SymbolTable::iterator I = ST.find(Type::TypeTy);
265   if (I != ST.end()) {    // Get the type plane for Type's
266     SymbolTable::VarMap &Plane = I->second;
267     for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
268          TI != TE; ++TI) {
269       // FIXME: This is gross, I'm reaching right into a symbol table and
270       // mucking around with it's internals... but oh well.
271       //
272       TI->second = (Value*)cast<Type>(ConvertType(cast<Type>(TI->second)));
273     }
274   }
275 }
276
277
278 // removeDeadGlobals - For this pass, all this does is remove the old versions
279 // of the functions and global variables that we no longer need.
280 void MutateStructTypes::removeDeadGlobals(Module &M) {
281   // Prepare for deletion of globals by dropping their interdependencies...
282   for(Module::iterator I = M.begin(); I != M.end(); ++I) {
283     if (GlobalMap.find(I) != GlobalMap.end())
284       I->dropAllReferences();
285   }
286
287   // Run through and delete the functions and global variables...
288 #if 0  // TODO: HANDLE GLOBAL VARIABLES
289   M->getGlobalList().delete_span(M.gbegin(), M.gbegin()+NumGVars/2);
290 #endif
291   for(Module::iterator I = M.begin(); I != M.end();) {
292     if (GlobalMap.find(I) != GlobalMap.end())
293       I = M.getFunctionList().erase(I);
294     else
295       ++I;
296   }
297 }
298
299
300
301 // transformFunction - This transforms the instructions of the function to use
302 // the new types.
303 //
304 void MutateStructTypes::transformFunction(Function *m) {
305   const Function *M = m;
306   std::map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
307   if (GMI == GlobalMap.end())
308     return;  // Do not affect one of our new functions that we are creating
309
310   Function *NewMeth = cast<Function>(GMI->second);
311
312   // Okay, first order of business, create the arguments...
313   for (Function::aiterator I = m->abegin(), E = m->aend(),
314          DI = NewMeth->abegin(); I != E; ++I, ++DI) {
315     DI->setName(I->getName());
316     LocalValueMap[I] = DI; // Keep track of value mapping
317   }
318
319
320   // Loop over all of the basic blocks copying instructions over...
321   for (Function::const_iterator BB = M->begin(), BBE = M->end(); BB != BBE;
322        ++BB) {
323     // Create a new basic block and establish a mapping between the old and new
324     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
325     NewMeth->getBasicBlockList().push_back(NewBB);  // Add block to function
326
327     // Copy over all of the instructions in the basic block...
328     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
329          II != IE; ++II) {
330
331       const Instruction &I = *II;   // Get the current instruction...
332       Instruction *NewI = 0;
333
334       switch (I.getOpcode()) {
335         // Terminator Instructions
336       case Instruction::Ret:
337         NewI = new ReturnInst(
338                    ConvertValue(cast<ReturnInst>(I).getReturnValue()));
339         break;
340       case Instruction::Br: {
341         const BranchInst &BI = cast<BranchInst>(I);
342         if (BI.isConditional()) {
343           NewI =
344               new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))),
345                              cast<BasicBlock>(ConvertValue(BI.getSuccessor(1))),
346                              ConvertValue(BI.getCondition()));
347         } else {
348           NewI = 
349             new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))));
350         }
351         break;
352       }
353       case Instruction::Switch:
354       case Instruction::Invoke:
355       case Instruction::Unwind:
356         assert(0 && "Insn not implemented!");
357
358         // Binary Instructions
359       case Instruction::Add:
360       case Instruction::Sub:
361       case Instruction::Mul:
362       case Instruction::Div:
363       case Instruction::Rem:
364         // Logical Operations
365       case Instruction::And:
366       case Instruction::Or:
367       case Instruction::Xor:
368
369         // Binary Comparison Instructions
370       case Instruction::SetEQ:
371       case Instruction::SetNE:
372       case Instruction::SetLE:
373       case Instruction::SetGE:
374       case Instruction::SetLT:
375       case Instruction::SetGT:
376         NewI = BinaryOperator::create((Instruction::BinaryOps)I.getOpcode(),
377                                       ConvertValue(I.getOperand(0)),
378                                       ConvertValue(I.getOperand(1)));
379         break;
380
381       case Instruction::Shr:
382       case Instruction::Shl:
383         NewI = new ShiftInst(cast<ShiftInst>(I).getOpcode(),
384                              ConvertValue(I.getOperand(0)),
385                              ConvertValue(I.getOperand(1)));
386         break;
387
388
389         // Memory Instructions
390       case Instruction::Alloca:
391         NewI = 
392           new MallocInst(
393                   ConvertType(cast<PointerType>(I.getType())->getElementType()),
394                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
395         break;
396       case Instruction::Malloc:
397         NewI = 
398           new MallocInst(
399                   ConvertType(cast<PointerType>(I.getType())->getElementType()),
400                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
401         break;
402
403       case Instruction::Free:
404         NewI = new FreeInst(ConvertValue(I.getOperand(0)));
405         break;
406
407       case Instruction::Load:
408         NewI = new LoadInst(ConvertValue(I.getOperand(0)));
409         break;
410       case Instruction::Store:
411         NewI = new StoreInst(ConvertValue(I.getOperand(0)),
412                              ConvertValue(I.getOperand(1)));
413         break;
414       case Instruction::GetElementPtr: {
415         const GetElementPtrInst &GEP = cast<GetElementPtrInst>(I);
416         std::vector<Value*> Indices(GEP.idx_begin(), GEP.idx_end());
417         if (!Indices.empty()) {
418           const Type *PTy =
419             cast<PointerType>(GEP.getOperand(0)->getType())->getElementType();
420           AdjustIndices(cast<CompositeType>(PTy), Indices);
421         }
422
423         NewI = new GetElementPtrInst(ConvertValue(GEP.getOperand(0)), Indices);
424         break;
425       }
426
427         // Miscellaneous Instructions
428       case Instruction::PHINode: {
429         const PHINode &OldPN = cast<PHINode>(I);
430         PHINode *PN = new PHINode(ConvertType(OldPN.getType()));
431         for (unsigned i = 0; i < OldPN.getNumIncomingValues(); ++i)
432           PN->addIncoming(ConvertValue(OldPN.getIncomingValue(i)),
433                     cast<BasicBlock>(ConvertValue(OldPN.getIncomingBlock(i))));
434         NewI = PN;
435         break;
436       }
437       case Instruction::Cast:
438         NewI = new CastInst(ConvertValue(I.getOperand(0)),
439                             ConvertType(I.getType()));
440         break;
441       case Instruction::Call: {
442         Value *Meth = ConvertValue(I.getOperand(0));
443         std::vector<Value*> Operands;
444         for (unsigned i = 1; i < I.getNumOperands(); ++i)
445           Operands.push_back(ConvertValue(I.getOperand(i)));
446         NewI = new CallInst(Meth, Operands);
447         break;
448       }
449         
450       default:
451         assert(0 && "UNKNOWN INSTRUCTION ENCOUNTERED!\n");
452         break;
453       }
454
455       NewI->setName(I.getName());
456       NewBB->getInstList().push_back(NewI);
457
458       // Check to see if we had to make a placeholder for this value...
459       std::map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(&I);
460       if (LVMI != LocalValueMap.end()) {
461         // Yup, make sure it's a placeholder...
462         Instruction *I = cast<Instruction>(LVMI->second);
463         assert(I->getOpcode() == Instruction::UserOp1 && "Not a placeholder!");
464
465         // Replace all uses of the place holder with the real deal...
466         I->replaceAllUsesWith(NewI);
467         delete I;                    // And free the placeholder memory
468       }
469
470       // Keep track of the fact the the local implementation of this instruction
471       // is NewI.
472       LocalValueMap[&I] = NewI;
473     }
474   }
475
476   LocalValueMap.clear();
477 }
478
479
480 bool MutateStructTypes::run(Module &M) {
481   processGlobals(M);
482
483   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
484     transformFunction(I);
485
486   removeDeadGlobals(M);
487   return true;
488 }
489