Updates to work with recent Statistic's changes:
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
1 //===- MutateStructTypes.cpp - Change struct defns --------------------------=//
2 //
3 // This pass is used to change structure accesses and type definitions in some
4 // way.  It can be used to arbitrarily permute structure fields, safely, without
5 // breaking code.  A transformation may only be done on a type if that type has
6 // been found to be "safe" by the 'FindUnsafePointerTypes' pass.  This pass will
7 // assert and die if you try to do an illegal transformation.
8 //
9 // This is an interprocedural pass that requires the entire program to do a
10 // transformation.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/IPO/MutateStructTypes.h"
15 #include "llvm/DerivedTypes.h"
16 #include "llvm/Module.h"
17 #include "llvm/SymbolTable.h"
18 #include "llvm/iPHINode.h"
19 #include "llvm/iMemory.h"
20 #include "llvm/iTerminators.h"
21 #include "llvm/iOther.h"
22 #include "llvm/Constants.h"
23 #include "Support/STLExtras.h"
24 #include "Support/Statistic.h"
25 #include <algorithm>
26
27 using std::map;
28 using std::vector;
29
30 // ValuePlaceHolder - A stupid little marker value.  It appears as an
31 // instruction of type Instruction::UserOp1.
32 //
33 struct ValuePlaceHolder : public Instruction {
34   ValuePlaceHolder(const Type *Ty) : Instruction(Ty, UserOp1, "") {}
35
36   virtual Instruction *clone() const { abort(); return 0; }
37   virtual const char *getOpcodeName() const { return "placeholder"; }
38 };
39
40
41 // ConvertType - Convert from the old type system to the new one...
42 const Type *MutateStructTypes::ConvertType(const Type *Ty) {
43   if (Ty->isPrimitiveType() ||
44       isa<OpaqueType>(Ty)) return Ty;  // Don't convert primitives
45
46   map<const Type *, PATypeHolder>::iterator I = TypeMap.find(Ty);
47   if (I != TypeMap.end()) return I->second;
48
49   const Type *DestTy = 0;
50
51   PATypeHolder PlaceHolder = OpaqueType::get();
52   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
53
54   switch (Ty->getPrimitiveID()) {
55   case Type::FunctionTyID: {
56     const FunctionType *MT = cast<FunctionType>(Ty);
57     const Type *RetTy = ConvertType(MT->getReturnType());
58     vector<const Type*> ArgTypes;
59
60     for (FunctionType::ParamTypes::const_iterator I = MT->getParamTypes().begin(),
61            E = MT->getParamTypes().end(); I != E; ++I)
62       ArgTypes.push_back(ConvertType(*I));
63     
64     DestTy = FunctionType::get(RetTy, ArgTypes, MT->isVarArg());
65     break;
66   }
67   case Type::StructTyID: {
68     const StructType *ST = cast<StructType>(Ty);
69     const StructType::ElementTypes &El = ST->getElementTypes();
70     vector<const Type *> Types;
71
72     for (StructType::ElementTypes::const_iterator I = El.begin(), E = El.end();
73          I != E; ++I)
74       Types.push_back(ConvertType(*I));
75     DestTy = StructType::get(Types);
76     break;
77   }
78   case Type::ArrayTyID:
79     DestTy = ArrayType::get(ConvertType(cast<ArrayType>(Ty)->getElementType()),
80                             cast<ArrayType>(Ty)->getNumElements());
81     break;
82
83   case Type::PointerTyID:
84     DestTy = PointerType::get(
85                  ConvertType(cast<PointerType>(Ty)->getElementType()));
86     break;
87   default:
88     assert(0 && "Unknown type!");
89     return 0;
90   }
91
92   assert(DestTy && "Type didn't get created!?!?");
93
94   // Refine our little placeholder value into a real type...
95   ((DerivedType*)PlaceHolder.get())->refineAbstractTypeTo(DestTy);
96   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
97
98   return PlaceHolder.get();
99 }
100
101
102 // AdjustIndices - Convert the indexes specifed by Idx to the new changed form
103 // using the specified OldTy as the base type being indexed into.
104 //
105 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
106                                       vector<Value*> &Idx,
107                                       unsigned i) {
108   assert(i < Idx.size() && "i out of range!");
109   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
110   if (NewCT == OldTy) return;  // No adjustment unless type changes
111
112   if (const StructType *OldST = dyn_cast<StructType>(OldTy)) {
113     // Figure out what the current index is...
114     unsigned ElNum = cast<ConstantUInt>(Idx[i])->getValue();
115     assert(ElNum < OldST->getElementTypes().size());
116
117     map<const StructType*, TransformType>::iterator I = Transforms.find(OldST);
118     if (I != Transforms.end()) {
119       assert(ElNum < I->second.second.size());
120       // Apply the XForm specified by Transforms map...
121       unsigned NewElNum = I->second.second[ElNum];
122       Idx[i] = ConstantUInt::get(Type::UByteTy, NewElNum);
123     }
124   }
125
126   // Recursively process subtypes...
127   if (i+1 < Idx.size())
128     AdjustIndices(cast<CompositeType>(OldTy->getTypeAtIndex(Idx[i])), Idx, i+1);
129 }
130
131
132 // ConvertValue - Convert from the old value in the old type system to the new
133 // type system.
134 //
135 Value *MutateStructTypes::ConvertValue(const Value *V) {
136   // Ignore null values and simple constants..
137   if (V == 0) return 0;
138
139   if (const Constant *CPV = dyn_cast<Constant>(V)) {
140     if (V->getType()->isPrimitiveType())
141       return (Value*)CPV;
142
143     if (isa<ConstantPointerNull>(CPV))
144       return ConstantPointerNull::get(
145                       cast<PointerType>(ConvertType(V->getType())));
146     assert(0 && "Unable to convert constpool val of this type!");
147   }
148
149   // Check to see if this is an out of function reference first...
150   if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
151     // Check to see if the value is in the map...
152     map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
153     if (I == GlobalMap.end())
154       return (Value*)GV;  // Not mapped, just return value itself
155     return I->second;
156   }
157   
158   map<const Value*, Value*>::iterator I = LocalValueMap.find(V);
159   if (I != LocalValueMap.end()) return I->second;
160
161   if (const BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
162     // Create placeholder block to represent the basic block we haven't seen yet
163     // This will be used when the block gets created.
164     //
165     return LocalValueMap[V] = new BasicBlock(BB->getName());
166   }
167
168   DEBUG(std::cerr << "NPH: " << V << "\n");
169
170   // Otherwise make a constant to represent it
171   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
172 }
173
174
175 // setTransforms - Take a map that specifies what transformation to do for each
176 // field of the specified structure types.  There is one element of the vector
177 // for each field of the structure.  The value specified indicates which slot of
178 // the destination structure the field should end up in.  A negative value 
179 // indicates that the field should be deleted entirely.
180 //
181 void MutateStructTypes::setTransforms(const TransformsType &XForm) {
182
183   // Loop over the types and insert dummy entries into the type map so that 
184   // recursive types are resolved properly...
185   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
186          E = XForm.end(); I != E; ++I) {
187     const StructType *OldTy = I->first;
188     TypeMap.insert(std::make_pair(OldTy, OpaqueType::get()));
189   }
190
191   // Loop over the type specified and figure out what types they should become
192   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
193          E = XForm.end(); I != E; ++I) {
194     const StructType  *OldTy = I->first;
195     const vector<int> &InVec = I->second;
196
197     assert(OldTy->getElementTypes().size() == InVec.size() &&
198            "Action not specified for every element of structure type!");
199
200     vector<const Type *> NewType;
201
202     // Convert the elements of the type over, including the new position mapping
203     int Idx = 0;
204     vector<int>::const_iterator TI = find(InVec.begin(), InVec.end(), Idx);
205     while (TI != InVec.end()) {
206       unsigned Offset = TI-InVec.begin();
207       const Type *NewEl = ConvertType(OldTy->getContainedType(Offset));
208       assert(NewEl && "Element not found!");
209       NewType.push_back(NewEl);
210
211       TI = find(InVec.begin(), InVec.end(), ++Idx);
212     }
213
214     // Create a new type that corresponds to the destination type
215     PATypeHolder NSTy = StructType::get(NewType);
216
217     // Refine the old opaque type to the new type to properly handle recursive
218     // types...
219     //
220     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
221     ((DerivedType*)OldTypeStub)->refineAbstractTypeTo(NSTy);
222
223     // Add the transformation to the Transforms map.
224     Transforms.insert(std::make_pair(OldTy,
225                        std::make_pair(cast<StructType>(NSTy.get()), InVec)));
226
227     DEBUG(std::cerr << "Mutate " << OldTy << "\nTo " << NSTy << "\n");
228   }
229 }
230
231 void MutateStructTypes::clearTransforms() {
232   Transforms.clear();
233   TypeMap.clear();
234   GlobalMap.clear();
235   assert(LocalValueMap.empty() &&
236          "Local Value Map should always be empty between transformations!");
237 }
238
239 // processGlobals - This loops over global constants defined in the
240 // module, converting them to their new type.
241 //
242 void MutateStructTypes::processGlobals(Module &M) {
243   // Loop through the functions in the module and create a new version of the
244   // function to contained the transformed code.  Also, be careful to not
245   // process the values that we add.
246   //
247   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
248     if (!I->isExternal()) {
249       const FunctionType *NewMTy = 
250         cast<FunctionType>(ConvertType(I->getFunctionType()));
251       
252       // Create a new function to put stuff into...
253       Function *NewMeth = new Function(NewMTy, I->hasInternalLinkage(),
254                                        I->getName());
255       if (I->hasName())
256         I->setName("OLD."+I->getName());
257
258       // Insert the new function into the function list... to be filled in later
259       M.getFunctionList().push_back(NewMeth);
260       
261       // Keep track of the association...
262       GlobalMap[I] = NewMeth;
263     }
264
265   // TODO: HANDLE GLOBAL VARIABLES
266
267   // Remap the symbol table to refer to the types in a nice way
268   //
269   if (SymbolTable *ST = M.getSymbolTable()) {    
270     SymbolTable::iterator I = ST->find(Type::TypeTy);
271     if (I != ST->end()) {    // Get the type plane for Type's
272       SymbolTable::VarMap &Plane = I->second;
273       for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
274            TI != TE; ++TI) {
275         // FIXME: This is gross, I'm reaching right into a symbol table and
276         // mucking around with it's internals... but oh well.
277         //
278         TI->second = (Value*)cast<Type>(ConvertType(cast<Type>(TI->second)));
279       }
280     }
281   }
282 }
283
284
285 // removeDeadGlobals - For this pass, all this does is remove the old versions
286 // of the functions and global variables that we no longer need.
287 void MutateStructTypes::removeDeadGlobals(Module &M) {
288   // Prepare for deletion of globals by dropping their interdependencies...
289   for(Module::iterator I = M.begin(); I != M.end(); ++I) {
290     if (GlobalMap.find(I) != GlobalMap.end())
291       I->dropAllReferences();
292   }
293
294   // Run through and delete the functions and global variables...
295 #if 0  // TODO: HANDLE GLOBAL VARIABLES
296   M->getGlobalList().delete_span(M.gbegin(), M.gbegin()+NumGVars/2);
297 #endif
298   for(Module::iterator I = M.begin(); I != M.end();) {
299     if (GlobalMap.find(I) != GlobalMap.end())
300       I = M.getFunctionList().erase(I);
301     else
302       ++I;
303   }
304 }
305
306
307
308 // transformFunction - This transforms the instructions of the function to use
309 // the new types.
310 //
311 void MutateStructTypes::transformFunction(Function *m) {
312   const Function *M = m;
313   map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
314   if (GMI == GlobalMap.end())
315     return;  // Do not affect one of our new functions that we are creating
316
317   Function *NewMeth = cast<Function>(GMI->second);
318
319   // Okay, first order of business, create the arguments...
320   for (Function::aiterator I = m->abegin(), E = m->aend(); I != E; ++I) {
321     Argument *NFA = new Argument(ConvertType(I->getType()), I->getName());
322     NewMeth->getArgumentList().push_back(NFA);
323     LocalValueMap[I] = NFA; // Keep track of value mapping
324   }
325
326
327   // Loop over all of the basic blocks copying instructions over...
328   for (Function::const_iterator BB = M->begin(), BBE = M->end(); BB != BBE;
329        ++BB) {
330     // Create a new basic block and establish a mapping between the old and new
331     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
332     NewMeth->getBasicBlockList().push_back(NewBB);  // Add block to function
333
334     // Copy over all of the instructions in the basic block...
335     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
336          II != IE; ++II) {
337
338       const Instruction &I = *II;   // Get the current instruction...
339       Instruction *NewI = 0;
340
341       switch (I.getOpcode()) {
342         // Terminator Instructions
343       case Instruction::Ret:
344         NewI = new ReturnInst(
345                    ConvertValue(cast<ReturnInst>(I).getReturnValue()));
346         break;
347       case Instruction::Br: {
348         const BranchInst &BI = cast<BranchInst>(I);
349         if (BI.isConditional()) {
350           NewI =
351               new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))),
352                              cast<BasicBlock>(ConvertValue(BI.getSuccessor(1))),
353                              ConvertValue(BI.getCondition()));
354         } else {
355           NewI = 
356             new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))));
357         }
358         break;
359       }
360       case Instruction::Switch:
361       case Instruction::Invoke:
362         assert(0 && "Insn not implemented!");
363
364         // Binary Instructions
365       case Instruction::Add:
366       case Instruction::Sub:
367       case Instruction::Mul:
368       case Instruction::Div:
369       case Instruction::Rem:
370         // Logical Operations
371       case Instruction::And:
372       case Instruction::Or:
373       case Instruction::Xor:
374
375         // Binary Comparison Instructions
376       case Instruction::SetEQ:
377       case Instruction::SetNE:
378       case Instruction::SetLE:
379       case Instruction::SetGE:
380       case Instruction::SetLT:
381       case Instruction::SetGT:
382         NewI = BinaryOperator::create((Instruction::BinaryOps)I.getOpcode(),
383                                       ConvertValue(I.getOperand(0)),
384                                       ConvertValue(I.getOperand(1)));
385         break;
386
387       case Instruction::Shr:
388       case Instruction::Shl:
389         NewI = new ShiftInst(cast<ShiftInst>(I).getOpcode(),
390                              ConvertValue(I.getOperand(0)),
391                              ConvertValue(I.getOperand(1)));
392         break;
393
394
395         // Memory Instructions
396       case Instruction::Alloca:
397         NewI = 
398           new MallocInst(
399                   ConvertType(cast<PointerType>(I.getType())->getElementType()),
400                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
401         break;
402       case Instruction::Malloc:
403         NewI = 
404           new MallocInst(
405                   ConvertType(cast<PointerType>(I.getType())->getElementType()),
406                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
407         break;
408
409       case Instruction::Free:
410         NewI = new FreeInst(ConvertValue(I.getOperand(0)));
411         break;
412
413       case Instruction::Load:
414         NewI = new LoadInst(ConvertValue(I.getOperand(0)));
415         break;
416       case Instruction::Store:
417         NewI = new StoreInst(ConvertValue(I.getOperand(0)),
418                              ConvertValue(I.getOperand(1)));
419         break;
420       case Instruction::GetElementPtr: {
421         const GetElementPtrInst &GEP = cast<GetElementPtrInst>(I);
422         vector<Value*> Indices(GEP.idx_begin(), GEP.idx_end());
423         if (!Indices.empty()) {
424           const Type *PTy =
425             cast<PointerType>(GEP.getOperand(0)->getType())->getElementType();
426           AdjustIndices(cast<CompositeType>(PTy), Indices);
427         }
428
429         NewI = new GetElementPtrInst(ConvertValue(GEP.getOperand(0)), Indices);
430         break;
431       }
432
433         // Miscellaneous Instructions
434       case Instruction::PHINode: {
435         const PHINode &OldPN = cast<PHINode>(I);
436         PHINode *PN = new PHINode(ConvertType(OldPN.getType()));
437         for (unsigned i = 0; i < OldPN.getNumIncomingValues(); ++i)
438           PN->addIncoming(ConvertValue(OldPN.getIncomingValue(i)),
439                     cast<BasicBlock>(ConvertValue(OldPN.getIncomingBlock(i))));
440         NewI = PN;
441         break;
442       }
443       case Instruction::Cast:
444         NewI = new CastInst(ConvertValue(I.getOperand(0)),
445                             ConvertType(I.getType()));
446         break;
447       case Instruction::Call: {
448         Value *Meth = ConvertValue(I.getOperand(0));
449         vector<Value*> Operands;
450         for (unsigned i = 1; i < I.getNumOperands(); ++i)
451           Operands.push_back(ConvertValue(I.getOperand(i)));
452         NewI = new CallInst(Meth, Operands);
453         break;
454       }
455         
456       default:
457         assert(0 && "UNKNOWN INSTRUCTION ENCOUNTERED!\n");
458         break;
459       }
460
461       NewI->setName(I.getName());
462       NewBB->getInstList().push_back(NewI);
463
464       // Check to see if we had to make a placeholder for this value...
465       map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(&I);
466       if (LVMI != LocalValueMap.end()) {
467         // Yup, make sure it's a placeholder...
468         Instruction *I = cast<Instruction>(LVMI->second);
469         assert(I->getOpcode() == Instruction::UserOp1 && "Not a placeholder!");
470
471         // Replace all uses of the place holder with the real deal...
472         I->replaceAllUsesWith(NewI);
473         delete I;                    // And free the placeholder memory
474       }
475
476       // Keep track of the fact the the local implementation of this instruction
477       // is NewI.
478       LocalValueMap[&I] = NewI;
479     }
480   }
481
482   LocalValueMap.clear();
483 }
484
485
486 bool MutateStructTypes::run(Module &M) {
487   processGlobals(M);
488
489   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
490     transformFunction(I);
491
492   removeDeadGlobals(M);
493   return true;
494 }
495