Add a flag to choose between isolating a function or deleting the function from
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
1 //===- MutateStructTypes.cpp - Change struct defns ------------------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass is used to change structure accesses and type definitions in some
11 // way.  It can be used to arbitrarily permute structure fields, safely, without
12 // breaking code.  A transformation may only be done on a type if that type has
13 // been found to be "safe" by the 'FindUnsafePointerTypes' pass.  This pass will
14 // assert and die if you try to do an illegal transformation.
15 //
16 // This is an interprocedural pass that requires the entire program to do a
17 // transformation.
18 //
19 //===----------------------------------------------------------------------===//
20
21 #include "llvm/Transforms/MutateStructTypes.h"
22 #include "llvm/DerivedTypes.h"
23 #include "llvm/Module.h"
24 #include "llvm/SymbolTable.h"
25 #include "llvm/Instructions.h"
26 #include "llvm/Constants.h"
27 #include "Support/STLExtras.h"
28 #include "Support/Debug.h"
29 #include <algorithm>
30
31 using namespace llvm;
32
33 // ValuePlaceHolder - A stupid little marker value.  It appears as an
34 // instruction of type Instruction::UserOp1.
35 //
36 struct ValuePlaceHolder : public Instruction {
37   ValuePlaceHolder(const Type *Ty) : Instruction(Ty, UserOp1, "") {}
38
39   virtual Instruction *clone() const { abort(); return 0; }
40   virtual const char *getOpcodeName() const { return "placeholder"; }
41 };
42
43
44 // ConvertType - Convert from the old type system to the new one...
45 const Type *MutateStructTypes::ConvertType(const Type *Ty) {
46   if (Ty->isPrimitiveType() ||
47       isa<OpaqueType>(Ty)) return Ty;  // Don't convert primitives
48
49   std::map<const Type *, PATypeHolder>::iterator I = TypeMap.find(Ty);
50   if (I != TypeMap.end()) return I->second;
51
52   const Type *DestTy = 0;
53
54   PATypeHolder PlaceHolder = OpaqueType::get();
55   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
56
57   switch (Ty->getPrimitiveID()) {
58   case Type::FunctionTyID: {
59     const FunctionType *FT = cast<FunctionType>(Ty);
60     const Type *RetTy = ConvertType(FT->getReturnType());
61     std::vector<const Type*> ArgTypes;
62
63     for (FunctionType::param_iterator I = FT->param_begin(),
64            E = FT->param_end(); I != E; ++I)
65       ArgTypes.push_back(ConvertType(*I));
66     
67     DestTy = FunctionType::get(RetTy, ArgTypes, FT->isVarArg());
68     break;
69   }
70   case Type::StructTyID: {
71     const StructType *ST = cast<StructType>(Ty);
72     std::vector<const Type *> Types;
73
74     for (StructType::element_iterator I = ST->element_begin(),
75            E = ST->element_end(); I != E; ++I)
76       Types.push_back(ConvertType(*I));
77     DestTy = StructType::get(Types);
78     break;
79   }
80   case Type::ArrayTyID:
81     DestTy = ArrayType::get(ConvertType(cast<ArrayType>(Ty)->getElementType()),
82                             cast<ArrayType>(Ty)->getNumElements());
83     break;
84
85   case Type::PointerTyID:
86     DestTy = PointerType::get(
87                  ConvertType(cast<PointerType>(Ty)->getElementType()));
88     break;
89   default:
90     assert(0 && "Unknown type!");
91     return 0;
92   }
93
94   assert(DestTy && "Type didn't get created!?!?");
95
96   // Refine our little placeholder value into a real type...
97   ((DerivedType*)PlaceHolder.get())->refineAbstractTypeTo(DestTy);
98   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
99
100   return PlaceHolder.get();
101 }
102
103
104 // AdjustIndices - Convert the indices specified by Idx to the new changed form
105 // using the specified OldTy as the base type being indexed into.
106 //
107 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
108                                       std::vector<Value*> &Idx,
109                                       unsigned i) {
110   assert(i < Idx.size() && "i out of range!");
111   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
112   if (NewCT == OldTy) return;  // No adjustment unless type changes
113
114   if (const StructType *OldST = dyn_cast<StructType>(OldTy)) {
115     // Figure out what the current index is...
116     unsigned ElNum = cast<ConstantUInt>(Idx[i])->getValue();
117     assert(ElNum < OldST->getNumElements());
118
119     std::map<const StructType*, TransformType>::iterator
120       I = Transforms.find(OldST);
121     if (I != Transforms.end()) {
122       assert(ElNum < I->second.second.size());
123       // Apply the XForm specified by Transforms map...
124       unsigned NewElNum = I->second.second[ElNum];
125       Idx[i] = ConstantUInt::get(Idx[i]->getType(), NewElNum);
126     }
127   }
128
129   // Recursively process subtypes...
130   if (i+1 < Idx.size())
131     AdjustIndices(cast<CompositeType>(OldTy->getTypeAtIndex(Idx[i])), Idx, i+1);
132 }
133
134
135 // ConvertValue - Convert from the old value in the old type system to the new
136 // type system.
137 //
138 Value *MutateStructTypes::ConvertValue(const Value *V) {
139   // Ignore null values and simple constants..
140   if (V == 0) return 0;
141
142   if (const Constant *CPV = dyn_cast<Constant>(V)) {
143     if (V->getType()->isPrimitiveType())
144       return (Value*)CPV;
145
146     if (isa<ConstantPointerNull>(CPV))
147       return ConstantPointerNull::get(
148                       cast<PointerType>(ConvertType(V->getType())));
149     assert(0 && "Unable to convert constpool val of this type!");
150   }
151
152   // Check to see if this is an out of function reference first...
153   if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
154     // Check to see if the value is in the map...
155     std::map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
156     if (I == GlobalMap.end())
157       return (Value*)GV;  // Not mapped, just return value itself
158     return I->second;
159   }
160   
161   std::map<const Value*, Value*>::iterator I = LocalValueMap.find(V);
162   if (I != LocalValueMap.end()) return I->second;
163
164   if (const BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
165     // Create placeholder block to represent the basic block we haven't seen yet
166     // This will be used when the block gets created.
167     //
168     return LocalValueMap[V] = new BasicBlock(BB->getName());
169   }
170
171   DEBUG(std::cerr << "NPH: " << V << "\n");
172
173   // Otherwise make a constant to represent it
174   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
175 }
176
177
178 // setTransforms - Take a map that specifies what transformation to do for each
179 // field of the specified structure types.  There is one element of the vector
180 // for each field of the structure.  The value specified indicates which slot of
181 // the destination structure the field should end up in.  A negative value 
182 // indicates that the field should be deleted entirely.
183 //
184 void MutateStructTypes::setTransforms(const TransformsType &XForm) {
185
186   // Loop over the types and insert dummy entries into the type map so that 
187   // recursive types are resolved properly...
188   for (std::map<const StructType*, std::vector<int> >::const_iterator
189          I = XForm.begin(), E = XForm.end(); I != E; ++I) {
190     const StructType *OldTy = I->first;
191     TypeMap.insert(std::make_pair(OldTy, OpaqueType::get()));
192   }
193
194   // Loop over the type specified and figure out what types they should become
195   for (std::map<const StructType*, std::vector<int> >::const_iterator
196          I = XForm.begin(), E = XForm.end(); I != E; ++I) {
197     const StructType  *OldTy = I->first;
198     const std::vector<int> &InVec = I->second;
199
200     assert(OldTy->getNumElements() == InVec.size() &&
201            "Action not specified for every element of structure type!");
202
203     std::vector<const Type *> NewType;
204
205     // Convert the elements of the type over, including the new position mapping
206     int Idx = 0;
207     std::vector<int>::const_iterator TI = find(InVec.begin(), InVec.end(), Idx);
208     while (TI != InVec.end()) {
209       unsigned Offset = TI-InVec.begin();
210       const Type *NewEl = ConvertType(OldTy->getContainedType(Offset));
211       assert(NewEl && "Element not found!");
212       NewType.push_back(NewEl);
213
214       TI = find(InVec.begin(), InVec.end(), ++Idx);
215     }
216
217     // Create a new type that corresponds to the destination type
218     PATypeHolder NSTy = StructType::get(NewType);
219
220     // Refine the old opaque type to the new type to properly handle recursive
221     // types...
222     //
223     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
224     ((DerivedType*)OldTypeStub)->refineAbstractTypeTo(NSTy);
225
226     // Add the transformation to the Transforms map.
227     Transforms.insert(std::make_pair(OldTy,
228                        std::make_pair(cast<StructType>(NSTy.get()), InVec)));
229
230     DEBUG(std::cerr << "Mutate " << OldTy << "\nTo " << NSTy << "\n");
231   }
232 }
233
234 void MutateStructTypes::clearTransforms() {
235   Transforms.clear();
236   TypeMap.clear();
237   GlobalMap.clear();
238   assert(LocalValueMap.empty() &&
239          "Local Value Map should always be empty between transformations!");
240 }
241
242 // processGlobals - This loops over global constants defined in the
243 // module, converting them to their new type.
244 //
245 void MutateStructTypes::processGlobals(Module &M) {
246   // Loop through the functions in the module and create a new version of the
247   // function to contained the transformed code.  Also, be careful to not
248   // process the values that we add.
249   //
250   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
251     if (!I->isExternal()) {
252       const FunctionType *NewMTy = 
253         cast<FunctionType>(ConvertType(I->getFunctionType()));
254       
255       // Create a new function to put stuff into...
256       Function *NewMeth = new Function(NewMTy, I->getLinkage(), I->getName());
257       if (I->hasName())
258         I->setName("OLD."+I->getName());
259
260       // Insert the new function into the function list... to be filled in later
261       M.getFunctionList().push_back(NewMeth);
262       
263       // Keep track of the association...
264       GlobalMap[I] = NewMeth;
265     }
266
267   // TODO: HANDLE GLOBAL VARIABLES
268
269   // Remap the symbol table to refer to the types in a nice way
270   //
271   SymbolTable &ST = M.getSymbolTable();
272   SymbolTable::iterator I = ST.find(Type::TypeTy);
273   if (I != ST.end()) {    // Get the type plane for Type's
274     SymbolTable::VarMap &Plane = I->second;
275     for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
276          TI != TE; ++TI) {
277       // FIXME: This is gross, I'm reaching right into a symbol table and
278       // mucking around with it's internals... but oh well.
279       //
280       TI->second = (Value*)cast<Type>(ConvertType(cast<Type>(TI->second)));
281     }
282   }
283 }
284
285
286 // removeDeadGlobals - For this pass, all this does is remove the old versions
287 // of the functions and global variables that we no longer need.
288 void MutateStructTypes::removeDeadGlobals(Module &M) {
289   // Prepare for deletion of globals by dropping their interdependencies...
290   for(Module::iterator I = M.begin(); I != M.end(); ++I) {
291     if (GlobalMap.find(I) != GlobalMap.end())
292       I->dropAllReferences();
293   }
294
295   // Run through and delete the functions and global variables...
296 #if 0  // TODO: HANDLE GLOBAL VARIABLES
297   M->getGlobalList().delete_span(M.gbegin(), M.gbegin()+NumGVars/2);
298 #endif
299   for(Module::iterator I = M.begin(); I != M.end();) {
300     if (GlobalMap.find(I) != GlobalMap.end())
301       I = M.getFunctionList().erase(I);
302     else
303       ++I;
304   }
305 }
306
307
308
309 // transformFunction - This transforms the instructions of the function to use
310 // the new types.
311 //
312 void MutateStructTypes::transformFunction(Function *m) {
313   const Function *M = m;
314   std::map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
315   if (GMI == GlobalMap.end())
316     return;  // Do not affect one of our new functions that we are creating
317
318   Function *NewMeth = cast<Function>(GMI->second);
319
320   // Okay, first order of business, create the arguments...
321   for (Function::aiterator I = m->abegin(), E = m->aend(),
322          DI = NewMeth->abegin(); I != E; ++I, ++DI) {
323     DI->setName(I->getName());
324     LocalValueMap[I] = DI; // Keep track of value mapping
325   }
326
327
328   // Loop over all of the basic blocks copying instructions over...
329   for (Function::const_iterator BB = M->begin(), BBE = M->end(); BB != BBE;
330        ++BB) {
331     // Create a new basic block and establish a mapping between the old and new
332     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
333     NewMeth->getBasicBlockList().push_back(NewBB);  // Add block to function
334
335     // Copy over all of the instructions in the basic block...
336     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
337          II != IE; ++II) {
338
339       const Instruction &I = *II;   // Get the current instruction...
340       Instruction *NewI = 0;
341
342       switch (I.getOpcode()) {
343         // Terminator Instructions
344       case Instruction::Ret:
345         NewI = new ReturnInst(
346                    ConvertValue(cast<ReturnInst>(I).getReturnValue()));
347         break;
348       case Instruction::Br: {
349         const BranchInst &BI = cast<BranchInst>(I);
350         if (BI.isConditional()) {
351           NewI =
352               new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))),
353                              cast<BasicBlock>(ConvertValue(BI.getSuccessor(1))),
354                              ConvertValue(BI.getCondition()));
355         } else {
356           NewI = 
357             new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))));
358         }
359         break;
360       }
361       case Instruction::Switch:
362       case Instruction::Invoke:
363       case Instruction::Unwind:
364         assert(0 && "Insn not implemented!");
365
366         // Binary Instructions
367       case Instruction::Add:
368       case Instruction::Sub:
369       case Instruction::Mul:
370       case Instruction::Div:
371       case Instruction::Rem:
372         // Logical Operations
373       case Instruction::And:
374       case Instruction::Or:
375       case Instruction::Xor:
376
377         // Binary Comparison Instructions
378       case Instruction::SetEQ:
379       case Instruction::SetNE:
380       case Instruction::SetLE:
381       case Instruction::SetGE:
382       case Instruction::SetLT:
383       case Instruction::SetGT:
384         NewI = BinaryOperator::create((Instruction::BinaryOps)I.getOpcode(),
385                                       ConvertValue(I.getOperand(0)),
386                                       ConvertValue(I.getOperand(1)));
387         break;
388
389       case Instruction::Shr:
390       case Instruction::Shl:
391         NewI = new ShiftInst(cast<ShiftInst>(I).getOpcode(),
392                              ConvertValue(I.getOperand(0)),
393                              ConvertValue(I.getOperand(1)));
394         break;
395
396
397         // Memory Instructions
398       case Instruction::Alloca:
399         NewI = 
400           new MallocInst(
401                   ConvertType(cast<PointerType>(I.getType())->getElementType()),
402                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
403         break;
404       case Instruction::Malloc:
405         NewI = 
406           new MallocInst(
407                   ConvertType(cast<PointerType>(I.getType())->getElementType()),
408                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
409         break;
410
411       case Instruction::Free:
412         NewI = new FreeInst(ConvertValue(I.getOperand(0)));
413         break;
414
415       case Instruction::Load:
416         NewI = new LoadInst(ConvertValue(I.getOperand(0)));
417         break;
418       case Instruction::Store:
419         NewI = new StoreInst(ConvertValue(I.getOperand(0)),
420                              ConvertValue(I.getOperand(1)));
421         break;
422       case Instruction::GetElementPtr: {
423         const GetElementPtrInst &GEP = cast<GetElementPtrInst>(I);
424         std::vector<Value*> Indices(GEP.idx_begin(), GEP.idx_end());
425         if (!Indices.empty()) {
426           const Type *PTy =
427             cast<PointerType>(GEP.getOperand(0)->getType())->getElementType();
428           AdjustIndices(cast<CompositeType>(PTy), Indices);
429         }
430
431         NewI = new GetElementPtrInst(ConvertValue(GEP.getOperand(0)), Indices);
432         break;
433       }
434
435         // Miscellaneous Instructions
436       case Instruction::PHI: {
437         const PHINode &OldPN = cast<PHINode>(I);
438         PHINode *PN = new PHINode(ConvertType(OldPN.getType()));
439         for (unsigned i = 0; i < OldPN.getNumIncomingValues(); ++i)
440           PN->addIncoming(ConvertValue(OldPN.getIncomingValue(i)),
441                     cast<BasicBlock>(ConvertValue(OldPN.getIncomingBlock(i))));
442         NewI = PN;
443         break;
444       }
445       case Instruction::Cast:
446         NewI = new CastInst(ConvertValue(I.getOperand(0)),
447                             ConvertType(I.getType()));
448         break;
449       case Instruction::Call: {
450         Value *Meth = ConvertValue(I.getOperand(0));
451         std::vector<Value*> Operands;
452         for (unsigned i = 1; i < I.getNumOperands(); ++i)
453           Operands.push_back(ConvertValue(I.getOperand(i)));
454         NewI = new CallInst(Meth, Operands);
455         break;
456       }
457         
458       default:
459         assert(0 && "UNKNOWN INSTRUCTION ENCOUNTERED!\n");
460         break;
461       }
462
463       NewI->setName(I.getName());
464       NewBB->getInstList().push_back(NewI);
465
466       // Check to see if we had to make a placeholder for this value...
467       std::map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(&I);
468       if (LVMI != LocalValueMap.end()) {
469         // Yup, make sure it's a placeholder...
470         Instruction *I = cast<Instruction>(LVMI->second);
471         assert(I->getOpcode() == Instruction::UserOp1 && "Not a placeholder!");
472
473         // Replace all uses of the place holder with the real deal...
474         I->replaceAllUsesWith(NewI);
475         delete I;                    // And free the placeholder memory
476       }
477
478       // Keep track of the fact the the local implementation of this instruction
479       // is NewI.
480       LocalValueMap[&I] = NewI;
481     }
482   }
483
484   LocalValueMap.clear();
485 }
486
487
488 bool MutateStructTypes::run(Module &M) {
489   processGlobals(M);
490
491   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
492     transformFunction(I);
493
494   removeDeadGlobals(M);
495   return true;
496 }
497