Remove dead code
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
1 //===- MutateStructTypes.cpp - Change struct defns --------------------------=//
2 //
3 // This pass is used to change structure accesses and type definitions in some
4 // way.  It can be used to arbitrarily permute structure fields, safely, without
5 // breaking code.  A transformation may only be done on a type if that type has
6 // been found to be "safe" by the 'FindUnsafePointerTypes' pass.  This pass will
7 // assert and die if you try to do an illegal transformation.
8 //
9 // This is an interprocedural pass that requires the entire program to do a
10 // transformation.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/IPO/MutateStructTypes.h"
15 #include "llvm/DerivedTypes.h"
16 #include "llvm/Module.h"
17 #include "llvm/Function.h"
18 #include "llvm/BasicBlock.h"
19 #include "llvm/GlobalVariable.h"
20 #include "llvm/SymbolTable.h"
21 #include "llvm/iPHINode.h"
22 #include "llvm/iMemory.h"
23 #include "llvm/iTerminators.h"
24 #include "llvm/iOther.h"
25 #include "llvm/Argument.h"
26 #include "llvm/Constants.h"
27 #include "Support/STLExtras.h"
28 #include <algorithm>
29 using std::map;
30 using std::vector;
31
32 // To enable debugging, uncomment this...
33 //#define DEBUG_MST(x) x
34
35 #ifndef DEBUG_MST
36 #define DEBUG_MST(x)   // Disable debug code
37 #endif
38
39 // ValuePlaceHolder - A stupid little marker value.  It appears as an
40 // instruction of type Instruction::UserOp1.
41 //
42 struct ValuePlaceHolder : public Instruction {
43   ValuePlaceHolder(const Type *Ty) : Instruction(Ty, UserOp1, "") {}
44
45   virtual Instruction *clone() const { abort(); return 0; }
46   virtual const char *getOpcodeName() const { return "placeholder"; }
47 };
48
49
50 // ConvertType - Convert from the old type system to the new one...
51 const Type *MutateStructTypes::ConvertType(const Type *Ty) {
52   if (Ty->isPrimitiveType() ||
53       isa<OpaqueType>(Ty)) return Ty;  // Don't convert primitives
54
55   map<const Type *, PATypeHolder>::iterator I = TypeMap.find(Ty);
56   if (I != TypeMap.end()) return I->second;
57
58   const Type *DestTy = 0;
59
60   PATypeHolder PlaceHolder = OpaqueType::get();
61   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
62
63   switch (Ty->getPrimitiveID()) {
64   case Type::FunctionTyID: {
65     const FunctionType *MT = cast<FunctionType>(Ty);
66     const Type *RetTy = ConvertType(MT->getReturnType());
67     vector<const Type*> ArgTypes;
68
69     for (FunctionType::ParamTypes::const_iterator I = MT->getParamTypes().begin(),
70            E = MT->getParamTypes().end(); I != E; ++I)
71       ArgTypes.push_back(ConvertType(*I));
72     
73     DestTy = FunctionType::get(RetTy, ArgTypes, MT->isVarArg());
74     break;
75   }
76   case Type::StructTyID: {
77     const StructType *ST = cast<StructType>(Ty);
78     const StructType::ElementTypes &El = ST->getElementTypes();
79     vector<const Type *> Types;
80
81     for (StructType::ElementTypes::const_iterator I = El.begin(), E = El.end();
82          I != E; ++I)
83       Types.push_back(ConvertType(*I));
84     DestTy = StructType::get(Types);
85     break;
86   }
87   case Type::ArrayTyID:
88     DestTy = ArrayType::get(ConvertType(cast<ArrayType>(Ty)->getElementType()),
89                             cast<ArrayType>(Ty)->getNumElements());
90     break;
91
92   case Type::PointerTyID:
93     DestTy = PointerType::get(
94                  ConvertType(cast<PointerType>(Ty)->getElementType()));
95     break;
96   default:
97     assert(0 && "Unknown type!");
98     return 0;
99   }
100
101   assert(DestTy && "Type didn't get created!?!?");
102
103   // Refine our little placeholder value into a real type...
104   cast<DerivedType>(PlaceHolder.get())->refineAbstractTypeTo(DestTy);
105   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
106
107   return PlaceHolder.get();
108 }
109
110
111 // AdjustIndices - Convert the indexes specifed by Idx to the new changed form
112 // using the specified OldTy as the base type being indexed into.
113 //
114 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
115                                       vector<Value*> &Idx,
116                                       unsigned i = 0) {
117   assert(i < Idx.size() && "i out of range!");
118   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
119   if (NewCT == OldTy) return;  // No adjustment unless type changes
120
121   if (const StructType *OldST = dyn_cast<StructType>(OldTy)) {
122     // Figure out what the current index is...
123     unsigned ElNum = cast<ConstantUInt>(Idx[i])->getValue();
124     assert(ElNum < OldST->getElementTypes().size());
125
126     map<const StructType*, TransformType>::iterator I = Transforms.find(OldST);
127     if (I != Transforms.end()) {
128       assert(ElNum < I->second.second.size());
129       // Apply the XForm specified by Transforms map...
130       unsigned NewElNum = I->second.second[ElNum];
131       Idx[i] = ConstantUInt::get(Type::UByteTy, NewElNum);
132     }
133   }
134
135   // Recursively process subtypes...
136   if (i+1 < Idx.size())
137     AdjustIndices(cast<CompositeType>(OldTy->getTypeAtIndex(Idx[i])), Idx, i+1);
138 }
139
140
141 // ConvertValue - Convert from the old value in the old type system to the new
142 // type system.
143 //
144 Value *MutateStructTypes::ConvertValue(const Value *V) {
145   // Ignore null values and simple constants..
146   if (V == 0) return 0;
147
148   if (Constant *CPV = dyn_cast<Constant>(V)) {
149     if (V->getType()->isPrimitiveType())
150       return CPV;
151
152     if (isa<ConstantPointerNull>(CPV))
153       return ConstantPointerNull::get(
154                       cast<PointerType>(ConvertType(V->getType())));
155     assert(0 && "Unable to convert constpool val of this type!");
156   }
157
158   // Check to see if this is an out of function reference first...
159   if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
160     // Check to see if the value is in the map...
161     map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
162     if (I == GlobalMap.end())
163       return GV;  // Not mapped, just return value itself
164     return I->second;
165   }
166   
167   map<const Value*, Value*>::iterator I = LocalValueMap.find(V);
168   if (I != LocalValueMap.end()) return I->second;
169
170   if (const BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
171     // Create placeholder block to represent the basic block we haven't seen yet
172     // This will be used when the block gets created.
173     //
174     return LocalValueMap[V] = new BasicBlock(BB->getName());
175   }
176
177   DEBUG_MST(cerr << "NPH: " << V << endl);
178
179   // Otherwise make a constant to represent it
180   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
181 }
182
183
184 // setTransforms - Take a map that specifies what transformation to do for each
185 // field of the specified structure types.  There is one element of the vector
186 // for each field of the structure.  The value specified indicates which slot of
187 // the destination structure the field should end up in.  A negative value 
188 // indicates that the field should be deleted entirely.
189 //
190 void MutateStructTypes::setTransforms(const TransformsType &XForm) {
191
192   // Loop over the types and insert dummy entries into the type map so that 
193   // recursive types are resolved properly...
194   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
195          E = XForm.end(); I != E; ++I) {
196     const StructType *OldTy = I->first;
197     TypeMap.insert(std::make_pair(OldTy, OpaqueType::get()));
198   }
199
200   // Loop over the type specified and figure out what types they should become
201   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
202          E = XForm.end(); I != E; ++I) {
203     const StructType  *OldTy = I->first;
204     const vector<int> &InVec = I->second;
205
206     assert(OldTy->getElementTypes().size() == InVec.size() &&
207            "Action not specified for every element of structure type!");
208
209     vector<const Type *> NewType;
210
211     // Convert the elements of the type over, including the new position mapping
212     int Idx = 0;
213     vector<int>::const_iterator TI = find(InVec.begin(), InVec.end(), Idx);
214     while (TI != InVec.end()) {
215       unsigned Offset = TI-InVec.begin();
216       const Type *NewEl = ConvertType(OldTy->getContainedType(Offset));
217       assert(NewEl && "Element not found!");
218       NewType.push_back(NewEl);
219
220       TI = find(InVec.begin(), InVec.end(), ++Idx);
221     }
222
223     // Create a new type that corresponds to the destination type
224     PATypeHolder NSTy = StructType::get(NewType);
225
226     // Refine the old opaque type to the new type to properly handle recursive
227     // types...
228     //
229     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
230     cast<DerivedType>(OldTypeStub)->refineAbstractTypeTo(NSTy);
231
232     // Add the transformation to the Transforms map.
233     Transforms.insert(std::make_pair(OldTy,
234                        std::make_pair(cast<StructType>(NSTy.get()), InVec)));
235
236     DEBUG_MST(cerr << "Mutate " << OldTy << "\nTo " << NSTy << endl);
237   }
238 }
239
240 void MutateStructTypes::clearTransforms() {
241   Transforms.clear();
242   TypeMap.clear();
243   GlobalMap.clear();
244   assert(LocalValueMap.empty() &&
245          "Local Value Map should always be empty between transformations!");
246 }
247
248 // doInitialization - This loops over global constants defined in the
249 // module, converting them to their new type.
250 //
251 void MutateStructTypes::processGlobals(Module *M) {
252   // Loop through the functions in the module and create a new version of the
253   // function to contained the transformed code.  Don't use an iterator, because
254   // we will be adding values to the end of the vector, and it could be
255   // reallocated.  Also, we don't want to process the values that we add.
256   //
257   unsigned NumFunctions = M->size();
258   for (unsigned i = 0; i < NumFunctions; ++i) {
259     Function *Meth = M->begin()[i];
260
261     if (!Meth->isExternal()) {
262       const FunctionType *NewMTy = 
263         cast<FunctionType>(ConvertType(Meth->getFunctionType()));
264       
265       // Create a new function to put stuff into...
266       Function *NewMeth = new Function(NewMTy, Meth->hasInternalLinkage(),
267                                    Meth->getName());
268       if (Meth->hasName())
269         Meth->setName("OLD."+Meth->getName());
270
271       // Insert the new function into the function list... to be filled in later
272       M->getFunctionList().push_back(NewMeth);
273       
274       // Keep track of the association...
275       GlobalMap[Meth] = NewMeth;
276     }
277   }
278
279   // TODO: HANDLE GLOBAL VARIABLES
280
281   // Remap the symbol table to refer to the types in a nice way
282   //
283   if (M->hasSymbolTable()) {
284     SymbolTable *ST = M->getSymbolTable();
285     SymbolTable::iterator I = ST->find(Type::TypeTy);
286     if (I != ST->end()) {    // Get the type plane for Type's
287       SymbolTable::VarMap &Plane = I->second;
288       for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
289            TI != TE; ++TI) {
290         // This is gross, I'm reaching right into a symbol table and mucking
291         // around with it's internals... but oh well.
292         //
293         TI->second = cast<Type>(ConvertType(cast<Type>(TI->second)));
294       }
295     }
296   }
297 }
298
299
300 // removeDeadGlobals - For this pass, all this does is remove the old versions
301 // of the functions and global variables that we no longer need.
302 void MutateStructTypes::removeDeadGlobals(Module *M) {
303   // Prepare for deletion of globals by dropping their interdependencies...
304   for(Module::iterator I = M->begin(); I != M->end(); ++I) {
305     if (GlobalMap.find(*I) != GlobalMap.end())
306       (*I)->Function::dropAllReferences();
307   }
308
309   // Run through and delete the functions and global variables...
310 #if 0  // TODO: HANDLE GLOBAL VARIABLES
311   M->getGlobalList().delete_span(M->gbegin(), M->gbegin()+NumGVars/2);
312 #endif
313   for(Module::iterator I = M->begin(); I != M->end();) {
314     if (GlobalMap.find(*I) != GlobalMap.end())
315       delete M->getFunctionList().remove(I);
316     else
317       ++I;
318   }
319 }
320
321
322
323 // transformFunction - This transforms the instructions of the function to use
324 // the new types.
325 //
326 void MutateStructTypes::transformFunction(Function *m) {
327   const Function *M = m;
328   map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
329   if (GMI == GlobalMap.end())
330     return;  // Do not affect one of our new functions that we are creating
331
332   Function *NewMeth = cast<Function>(GMI->second);
333
334   // Okay, first order of business, create the arguments...
335   for (unsigned i = 0, e = M->getArgumentList().size(); i != e; ++i) {
336     const Argument *OFA = M->getArgumentList()[i];
337     Argument *NFA = new Argument(ConvertType(OFA->getType()), OFA->getName());
338     NewMeth->getArgumentList().push_back(NFA);
339     LocalValueMap[OFA] = NFA; // Keep track of value mapping
340   }
341
342
343   // Loop over all of the basic blocks copying instructions over...
344   for (Function::const_iterator BBI = M->begin(), BBE = M->end(); BBI != BBE;
345        ++BBI) {
346
347     // Create a new basic block and establish a mapping between the old and new
348     const BasicBlock *BB = *BBI;
349     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
350     NewMeth->getBasicBlocks().push_back(NewBB);  // Add block to function
351
352     // Copy over all of the instructions in the basic block...
353     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
354          II != IE; ++II) {
355
356       const Instruction *I = *II;   // Get the current instruction...
357       Instruction *NewI = 0;
358
359       switch (I->getOpcode()) {
360         // Terminator Instructions
361       case Instruction::Ret:
362         NewI = new ReturnInst(
363                    ConvertValue(cast<ReturnInst>(I)->getReturnValue()));
364         break;
365       case Instruction::Br: {
366         const BranchInst *BI = cast<BranchInst>(I);
367         if (BI->isConditional()) {
368           NewI =
369             new BranchInst(cast<BasicBlock>(ConvertValue(BI->getSuccessor(0))),
370                            cast<BasicBlock>(ConvertValue(BI->getSuccessor(1))),
371                            ConvertValue(BI->getCondition()));
372         } else {
373           NewI = 
374             new BranchInst(cast<BasicBlock>(ConvertValue(BI->getSuccessor(0))));
375         }
376         break;
377       }
378       case Instruction::Switch:
379       case Instruction::Invoke:
380         assert(0 && "Insn not implemented!");
381
382         // Unary Instructions
383       case Instruction::Not:
384         NewI = UnaryOperator::create((Instruction::UnaryOps)I->getOpcode(),
385                                      ConvertValue(I->getOperand(0)));
386         break;
387
388         // Binary Instructions
389       case Instruction::Add:
390       case Instruction::Sub:
391       case Instruction::Mul:
392       case Instruction::Div:
393       case Instruction::Rem:
394         // Logical Operations
395       case Instruction::And:
396       case Instruction::Or:
397       case Instruction::Xor:
398
399         // Binary Comparison Instructions
400       case Instruction::SetEQ:
401       case Instruction::SetNE:
402       case Instruction::SetLE:
403       case Instruction::SetGE:
404       case Instruction::SetLT:
405       case Instruction::SetGT:
406         NewI = BinaryOperator::create((Instruction::BinaryOps)I->getOpcode(),
407                                       ConvertValue(I->getOperand(0)),
408                                       ConvertValue(I->getOperand(1)));
409         break;
410
411       case Instruction::Shr:
412       case Instruction::Shl:
413         NewI = new ShiftInst(cast<ShiftInst>(I)->getOpcode(),
414                              ConvertValue(I->getOperand(0)),
415                              ConvertValue(I->getOperand(1)));
416         break;
417
418
419         // Memory Instructions
420       case Instruction::Alloca:
421         NewI = 
422           new AllocaInst(ConvertType(I->getType()),
423                          I->getNumOperands()?ConvertValue(I->getOperand(0)):0);
424         break;
425       case Instruction::Malloc:
426         NewI = 
427           new MallocInst(ConvertType(I->getType()),
428                          I->getNumOperands()?ConvertValue(I->getOperand(0)):0);
429         break;
430
431       case Instruction::Free:
432         NewI = new FreeInst(ConvertValue(I->getOperand(0)));
433         break;
434
435       case Instruction::Load:
436       case Instruction::Store:
437       case Instruction::GetElementPtr: {
438         const MemAccessInst *MAI = cast<MemAccessInst>(I);
439         vector<Value*> Indices(MAI->idx_begin(), MAI->idx_end());
440         const Value *Ptr = MAI->getPointerOperand();
441         Value *NewPtr = ConvertValue(Ptr);
442         if (!Indices.empty()) {
443           const Type *PTy = cast<PointerType>(Ptr->getType())->getElementType();
444           AdjustIndices(cast<CompositeType>(PTy), Indices);
445         }
446
447         if (isa<LoadInst>(I)) {
448           NewI = new LoadInst(NewPtr, Indices);
449         } else if (isa<StoreInst>(I)) {
450           NewI = new StoreInst(ConvertValue(I->getOperand(0)), NewPtr, Indices);
451         } else if (isa<GetElementPtrInst>(I)) {
452           NewI = new GetElementPtrInst(NewPtr, Indices);
453         } else {
454           assert(0 && "Unknown memory access inst!!!");
455         }
456         break;
457       }
458
459         // Miscellaneous Instructions
460       case Instruction::PHINode: {
461         const PHINode *OldPN = cast<PHINode>(I);
462         PHINode *PN = new PHINode(ConvertType(I->getType()));
463         for (unsigned i = 0; i < OldPN->getNumIncomingValues(); ++i)
464           PN->addIncoming(ConvertValue(OldPN->getIncomingValue(i)),
465                     cast<BasicBlock>(ConvertValue(OldPN->getIncomingBlock(i))));
466         NewI = PN;
467         break;
468       }
469       case Instruction::Cast:
470         NewI = new CastInst(ConvertValue(I->getOperand(0)),
471                             ConvertType(I->getType()));
472         break;
473       case Instruction::Call: {
474         Value *Meth = ConvertValue(I->getOperand(0));
475         vector<Value*> Operands;
476         for (unsigned i = 1; i < I->getNumOperands(); ++i)
477           Operands.push_back(ConvertValue(I->getOperand(i)));
478         NewI = new CallInst(Meth, Operands);
479         break;
480       }
481         
482       default:
483         assert(0 && "UNKNOWN INSTRUCTION ENCOUNTERED!\n");
484         break;
485       }
486
487       NewI->setName(I->getName());
488       NewBB->getInstList().push_back(NewI);
489
490       // Check to see if we had to make a placeholder for this value...
491       map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(I);
492       if (LVMI != LocalValueMap.end()) {
493         // Yup, make sure it's a placeholder...
494         Instruction *I = cast<Instruction>(LVMI->second);
495         assert(I->getOpcode() == Instruction::UserOp1 && "Not a placeholder!");
496
497         // Replace all uses of the place holder with the real deal...
498         I->replaceAllUsesWith(NewI);
499         delete I;                    // And free the placeholder memory
500       }
501
502       // Keep track of the fact the the local implementation of this instruction
503       // is NewI.
504       LocalValueMap[I] = NewI;
505     }
506   }
507
508   LocalValueMap.clear();
509 }
510
511
512 bool MutateStructTypes::run(Module *M) {
513   processGlobals(M);
514
515   for_each(M->begin(), M->end(),
516            bind_obj(this, &MutateStructTypes::transformFunction));
517
518   removeDeadGlobals(M);
519   return true;
520 }
521