Non-functionality change just to make it more clear what is going on
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
index e06ee61412802d86799ba3907292ac3685da692d..7f62f2b50e1947c086100d4d6649003bbafe7d3e 100644 (file)
 #include "llvm/Transforms/IPO/MutateStructTypes.h"
 #include "llvm/DerivedTypes.h"
 #include "llvm/Module.h"
-#include "llvm/Function.h"
-#include "llvm/BasicBlock.h"
-#include "llvm/GlobalVariable.h"
 #include "llvm/SymbolTable.h"
 #include "llvm/iPHINode.h"
 #include "llvm/iMemory.h"
 #include "llvm/iTerminators.h"
 #include "llvm/iOther.h"
+#include "llvm/Constants.h"
 #include "Support/STLExtras.h"
+#include "Support/Statistic.h"
 #include <algorithm>
+
 using std::map;
 using std::vector;
 
-//FIXME: These headers are only included because the analyses are killed!!!
-#include "llvm/Analysis/CallGraph.h"
-#include "llvm/Analysis/FindUsedTypes.h"
-#include "llvm/Analysis/FindUnsafePointerTypes.h"
-//FIXME end
-
-// To enable debugging, uncomment this...
-//#define DEBUG_MST(x) x
-
-#ifndef DEBUG_MST
-#define DEBUG_MST(x)   // Disable debug code
-#endif
-
 // ValuePlaceHolder - A stupid little marker value.  It appears as an
 // instruction of type Instruction::UserOp1.
 //
@@ -105,7 +92,7 @@ const Type *MutateStructTypes::ConvertType(const Type *Ty) {
   assert(DestTy && "Type didn't get created!?!?");
 
   // Refine our little placeholder value into a real type...
-  cast<DerivedType>(PlaceHolder.get())->refineAbstractTypeTo(DestTy);
+  ((DerivedType*)PlaceHolder.get())->refineAbstractTypeTo(DestTy);
   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
 
   return PlaceHolder.get();
@@ -117,7 +104,7 @@ const Type *MutateStructTypes::ConvertType(const Type *Ty) {
 //
 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
                                       vector<Value*> &Idx,
-                                      unsigned i = 0) {
+                                      unsigned i) {
   assert(i < Idx.size() && "i out of range!");
   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
   if (NewCT == OldTy) return;  // No adjustment unless type changes
@@ -149,9 +136,9 @@ Value *MutateStructTypes::ConvertValue(const Value *V) {
   // Ignore null values and simple constants..
   if (V == 0) return 0;
 
-  if (Constant *CPV = dyn_cast<Constant>(V)) {
+  if (const Constant *CPV = dyn_cast<Constant>(V)) {
     if (V->getType()->isPrimitiveType())
-      return CPV;
+      return (Value*)CPV;
 
     if (isa<ConstantPointerNull>(CPV))
       return ConstantPointerNull::get(
@@ -160,11 +147,11 @@ Value *MutateStructTypes::ConvertValue(const Value *V) {
   }
 
   // Check to see if this is an out of function reference first...
-  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
+  if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
     // Check to see if the value is in the map...
     map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
     if (I == GlobalMap.end())
-      return GV;  // Not mapped, just return value itself
+      return (Value*)GV;  // Not mapped, just return value itself
     return I->second;
   }
   
@@ -178,7 +165,7 @@ Value *MutateStructTypes::ConvertValue(const Value *V) {
     return LocalValueMap[V] = new BasicBlock(BB->getName());
   }
 
-  DEBUG_MST(cerr << "NPH: " << V << endl);
+  DEBUG(std::cerr << "NPH: " << V << "\n");
 
   // Otherwise make a constant to represent it
   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
@@ -231,13 +218,13 @@ void MutateStructTypes::setTransforms(const TransformsType &XForm) {
     // types...
     //
     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
-    cast<DerivedType>(OldTypeStub)->refineAbstractTypeTo(NSTy);
+    ((DerivedType*)OldTypeStub)->refineAbstractTypeTo(NSTy);
 
     // Add the transformation to the Transforms map.
     Transforms.insert(std::make_pair(OldTy,
                        std::make_pair(cast<StructType>(NSTy.get()), InVec)));
 
-    DEBUG_MST(cerr << "Mutate " << OldTy << "\nTo " << NSTy << endl);
+    DEBUG(std::cerr << "Mutate " << OldTy << "\nTo " << NSTy << "\n");
   }
 }
 
@@ -249,52 +236,46 @@ void MutateStructTypes::clearTransforms() {
          "Local Value Map should always be empty between transformations!");
 }
 
-// doInitialization - This loops over global constants defined in the
+// processGlobals - This loops over global constants defined in the
 // module, converting them to their new type.
 //
-void MutateStructTypes::processGlobals(Module *M) {
+void MutateStructTypes::processGlobals(Module &M) {
   // Loop through the functions in the module and create a new version of the
-  // function to contained the transformed code.  Don't use an iterator, because
-  // we will be adding values to the end of the vector, and it could be
-  // reallocated.  Also, we don't want to process the values that we add.
+  // function to contained the transformed code.  Also, be careful to not
+  // process the values that we add.
   //
-  unsigned NumFunctions = M->size();
-  for (unsigned i = 0; i < NumFunctions; ++i) {
-    Function *Meth = M->begin()[i];
-
-    if (!Meth->isExternal()) {
+  for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
+    if (!I->isExternal()) {
       const FunctionType *NewMTy = 
-        cast<FunctionType>(ConvertType(Meth->getFunctionType()));
+        cast<FunctionType>(ConvertType(I->getFunctionType()));
       
       // Create a new function to put stuff into...
-      Function *NewMeth = new Function(NewMTy, Meth->hasInternalLinkage(),
-                                  Meth->getName());
-      if (Meth->hasName())
-        Meth->setName("OLD."+Meth->getName());
+      Function *NewMeth = new Function(NewMTy, I->hasInternalLinkage(),
+                                       I->getName());
+      if (I->hasName())
+        I->setName("OLD."+I->getName());
 
-      // Insert the new function into the method list... to be filled in later..
-      M->getFunctionList().push_back(NewMeth);
+      // Insert the new function into the function list... to be filled in later
+      M.getFunctionList().push_back(NewMeth);
       
       // Keep track of the association...
-      GlobalMap[Meth] = NewMeth;
+      GlobalMap[I] = NewMeth;
     }
-  }
 
   // TODO: HANDLE GLOBAL VARIABLES
 
   // Remap the symbol table to refer to the types in a nice way
   //
-  if (M->hasSymbolTable()) {
-    SymbolTable *ST = M->getSymbolTable();
+  if (SymbolTable *ST = M.getSymbolTable()) {    
     SymbolTable::iterator I = ST->find(Type::TypeTy);
     if (I != ST->end()) {    // Get the type plane for Type's
       SymbolTable::VarMap &Plane = I->second;
       for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
            TI != TE; ++TI) {
-        // This is gross, I'm reaching right into a symbol table and mucking
-        // around with it's internals... but oh well.
+        // FIXME: This is gross, I'm reaching right into a symbol table and
+        // mucking around with it's internals... but oh well.
         //
-        TI->second = cast<Type>(ConvertType(cast<Type>(TI->second)));
+        TI->second = (Value*)cast<Type>(ConvertType(cast<Type>(TI->second)));
       }
     }
   }
@@ -303,20 +284,20 @@ void MutateStructTypes::processGlobals(Module *M) {
 
 // removeDeadGlobals - For this pass, all this does is remove the old versions
 // of the functions and global variables that we no longer need.
-void MutateStructTypes::removeDeadGlobals(Module *M) {
+void MutateStructTypes::removeDeadGlobals(Module &M) {
   // Prepare for deletion of globals by dropping their interdependencies...
-  for(Module::iterator I = M->begin(); I != M->end(); ++I) {
-    if (GlobalMap.find(*I) != GlobalMap.end())
-      (*I)->Function::dropAllReferences();
+  for(Module::iterator I = M.begin(); I != M.end(); ++I) {
+    if (GlobalMap.find(I) != GlobalMap.end())
+      I->dropAllReferences();
   }
 
   // Run through and delete the functions and global variables...
 #if 0  // TODO: HANDLE GLOBAL VARIABLES
-  M->getGlobalList().delete_span(M->gbegin(), M->gbegin()+NumGVars/2);
+  M->getGlobalList().delete_span(M.gbegin(), M.gbegin()+NumGVars/2);
 #endif
-  for(Module::iterator I = M->begin(); I != M->end();) {
-    if (GlobalMap.find(*I) != GlobalMap.end())
-      delete M->getFunctionList().remove(I);
+  for(Module::iterator I = M.begin(); I != M.end();) {
+    if (GlobalMap.find(I) != GlobalMap.end())
+      I = M.getFunctionList().erase(I);
     else
       ++I;
   }
@@ -324,10 +305,10 @@ void MutateStructTypes::removeDeadGlobals(Module *M) {
 
 
 
-// transformMethod - This transforms the instructions of the function to use the
-// new types.
+// transformFunction - This transforms the instructions of the function to use
+// the new types.
 //
-void MutateStructTypes::transformMethod(Function *m) {
+void MutateStructTypes::transformFunction(Function *m) {
   const Function *M = m;
   map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
   if (GMI == GlobalMap.end())
@@ -336,55 +317,50 @@ void MutateStructTypes::transformMethod(Function *m) {
   Function *NewMeth = cast<Function>(GMI->second);
 
   // Okay, first order of business, create the arguments...
-  for (unsigned i = 0, e = M->getArgumentList().size(); i != e; ++i) {
-    const FunctionArgument *OFA = M->getArgumentList()[i];
-    FunctionArgument *NFA = new FunctionArgument(ConvertType(OFA->getType()),
-                                                 OFA->getName());
+  for (Function::aiterator I = m->abegin(), E = m->aend(); I != E; ++I) {
+    Argument *NFA = new Argument(ConvertType(I->getType()), I->getName());
     NewMeth->getArgumentList().push_back(NFA);
-    LocalValueMap[OFA] = NFA; // Keep track of value mapping
+    LocalValueMap[I] = NFA; // Keep track of value mapping
   }
 
 
   // Loop over all of the basic blocks copying instructions over...
-  for (Function::const_iterator BBI = M->begin(), BBE = M->end(); BBI != BBE;
-       ++BBI) {
-
+  for (Function::const_iterator BB = M->begin(), BBE = M->end(); BB != BBE;
+       ++BB) {
     // Create a new basic block and establish a mapping between the old and new
-    const BasicBlock *BB = *BBI;
     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
-    NewMeth->getBasicBlocks().push_back(NewBB);  // Add block to function
+    NewMeth->getBasicBlockList().push_back(NewBB);  // Add block to function
 
     // Copy over all of the instructions in the basic block...
     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
          II != IE; ++II) {
 
-      const Instruction *I = *II;   // Get the current instruction...
+      const Instruction &I = *II;   // Get the current instruction...
       Instruction *NewI = 0;
 
-      switch (I->getOpcode()) {
+      switch (I.getOpcode()) {
         // Terminator Instructions
       case Instruction::Ret:
         NewI = new ReturnInst(
-                   ConvertValue(cast<ReturnInst>(I)->getReturnValue()));
+                   ConvertValue(cast<ReturnInst>(I).getReturnValue()));
         break;
       case Instruction::Br: {
-        const BranchInst *BI = cast<BranchInst>(I);
-        NewI = new BranchInst(
-                           cast<BasicBlock>(ConvertValue(BI->getSuccessor(0))),
-                    cast_or_null<BasicBlock>(ConvertValue(BI->getSuccessor(1))),
-                              ConvertValue(BI->getCondition()));
+        const BranchInst &BI = cast<BranchInst>(I);
+        if (BI.isConditional()) {
+          NewI =
+              new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))),
+                             cast<BasicBlock>(ConvertValue(BI.getSuccessor(1))),
+                             ConvertValue(BI.getCondition()));
+        } else {
+          NewI = 
+            new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))));
+        }
         break;
       }
       case Instruction::Switch:
       case Instruction::Invoke:
         assert(0 && "Insn not implemented!");
 
-        // Unary Instructions
-      case Instruction::Not:
-        NewI = UnaryOperator::create((Instruction::UnaryOps)I->getOpcode(),
-                                     ConvertValue(I->getOperand(0)));
-        break;
-
         // Binary Instructions
       case Instruction::Add:
       case Instruction::Sub:
@@ -403,78 +379,76 @@ void MutateStructTypes::transformMethod(Function *m) {
       case Instruction::SetGE:
       case Instruction::SetLT:
       case Instruction::SetGT:
-        NewI = BinaryOperator::create((Instruction::BinaryOps)I->getOpcode(),
-                                      ConvertValue(I->getOperand(0)),
-                                      ConvertValue(I->getOperand(1)));
+        NewI = BinaryOperator::create((Instruction::BinaryOps)I.getOpcode(),
+                                      ConvertValue(I.getOperand(0)),
+                                      ConvertValue(I.getOperand(1)));
         break;
 
       case Instruction::Shr:
       case Instruction::Shl:
-        NewI = new ShiftInst(cast<ShiftInst>(I)->getOpcode(),
-                             ConvertValue(I->getOperand(0)),
-                             ConvertValue(I->getOperand(1)));
+        NewI = new ShiftInst(cast<ShiftInst>(I).getOpcode(),
+                             ConvertValue(I.getOperand(0)),
+                             ConvertValue(I.getOperand(1)));
         break;
 
 
         // Memory Instructions
       case Instruction::Alloca:
         NewI = 
-          new AllocaInst(ConvertType(I->getType()),
-                         I->getNumOperands()?ConvertValue(I->getOperand(0)):0);
+          new MallocInst(
+                  ConvertType(cast<PointerType>(I.getType())->getElementType()),
+                         I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
         break;
       case Instruction::Malloc:
         NewI = 
-          new MallocInst(ConvertType(I->getType()),
-                         I->getNumOperands()?ConvertValue(I->getOperand(0)):0);
+          new MallocInst(
+                  ConvertType(cast<PointerType>(I.getType())->getElementType()),
+                         I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
         break;
 
       case Instruction::Free:
-        NewI = new FreeInst(ConvertValue(I->getOperand(0)));
+        NewI = new FreeInst(ConvertValue(I.getOperand(0)));
         break;
 
       case Instruction::Load:
+        NewI = new LoadInst(ConvertValue(I.getOperand(0)));
+        break;
       case Instruction::Store:
+        NewI = new StoreInst(ConvertValue(I.getOperand(0)),
+                             ConvertValue(I.getOperand(1)));
+        break;
       case Instruction::GetElementPtr: {
-        const MemAccessInst *MAI = cast<MemAccessInst>(I);
-        vector<Value*> Indices(MAI->idx_begin(), MAI->idx_end());
-        const Value *Ptr = MAI->getPointerOperand();
-        Value *NewPtr = ConvertValue(Ptr);
+        const GetElementPtrInst &GEP = cast<GetElementPtrInst>(I);
+        vector<Value*> Indices(GEP.idx_begin(), GEP.idx_end());
         if (!Indices.empty()) {
-          const Type *PTy = cast<PointerType>(Ptr->getType())->getElementType();
+          const Type *PTy =
+            cast<PointerType>(GEP.getOperand(0)->getType())->getElementType();
           AdjustIndices(cast<CompositeType>(PTy), Indices);
         }
 
-        if (isa<LoadInst>(I)) {
-          NewI = new LoadInst(NewPtr, Indices);
-        } else if (isa<StoreInst>(I)) {
-          NewI = new StoreInst(ConvertValue(I->getOperand(0)), NewPtr, Indices);
-        } else if (isa<GetElementPtrInst>(I)) {
-          NewI = new GetElementPtrInst(NewPtr, Indices);
-        } else {
-          assert(0 && "Unknown memory access inst!!!");
-        }
+        NewI = new GetElementPtrInst(ConvertValue(GEP.getOperand(0)), Indices);
         break;
       }
 
         // Miscellaneous Instructions
       case Instruction::PHINode: {
-        const PHINode *OldPN = cast<PHINode>(I);
-        PHINode *PN = new PHINode(ConvertType(I->getType()));
-        for (unsigned i = 0; i < OldPN->getNumIncomingValues(); ++i)
-          PN->addIncoming(ConvertValue(OldPN->getIncomingValue(i)),
-                    cast<BasicBlock>(ConvertValue(OldPN->getIncomingBlock(i))));
+        const PHINode &OldPN = cast<PHINode>(I);
+        PHINode *PN = new PHINode(ConvertType(OldPN.getType()));
+        for (unsigned i = 0; i < OldPN.getNumIncomingValues(); ++i)
+          PN->addIncoming(ConvertValue(OldPN.getIncomingValue(i)),
+                    cast<BasicBlock>(ConvertValue(OldPN.getIncomingBlock(i))));
         NewI = PN;
         break;
       }
       case Instruction::Cast:
-        NewI = new CastInst(ConvertValue(I->getOperand(0)),
-                            ConvertType(I->getType()));
+        NewI = new CastInst(ConvertValue(I.getOperand(0)),
+                            ConvertType(I.getType()));
         break;
       case Instruction::Call: {
-        Value *Meth = ConvertValue(I->getOperand(0));
+        Value *Meth = ConvertValue(I.getOperand(0));
         vector<Value*> Operands;
-        for (unsigned i = 1; i < I->getNumOperands(); ++i)
-          Operands.push_back(ConvertValue(I->getOperand(i)));
+        for (unsigned i = 1; i < I.getNumOperands(); ++i)
+          Operands.push_back(ConvertValue(I.getOperand(i)));
         NewI = new CallInst(Meth, Operands);
         break;
       }
@@ -484,11 +458,11 @@ void MutateStructTypes::transformMethod(Function *m) {
         break;
       }
 
-      NewI->setName(I->getName());
+      NewI->setName(I.getName());
       NewBB->getInstList().push_back(NewI);
 
       // Check to see if we had to make a placeholder for this value...
-      map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(I);
+      map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(&I);
       if (LVMI != LocalValueMap.end()) {
         // Yup, make sure it's a placeholder...
         Instruction *I = cast<Instruction>(LVMI->second);
@@ -501,7 +475,7 @@ void MutateStructTypes::transformMethod(Function *m) {
 
       // Keep track of the fact the the local implementation of this instruction
       // is NewI.
-      LocalValueMap[I] = NewI;
+      LocalValueMap[&I] = NewI;
     }
   }
 
@@ -509,23 +483,13 @@ void MutateStructTypes::transformMethod(Function *m) {
 }
 
 
-bool MutateStructTypes::run(Module *M) {
+bool MutateStructTypes::run(Module &M) {
   processGlobals(M);
 
-  for_each(M->begin(), M->end(),
-           bind_obj(this, &MutateStructTypes::transformMethod));
+  for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
+    transformFunction(I);
 
   removeDeadGlobals(M);
   return true;
 }
 
-// getAnalysisUsageInfo - This function needs the results of the
-// FindUsedTypes and FindUnsafePointerTypes analysis passes...
-//
-void MutateStructTypes::getAnalysisUsageInfo(Pass::AnalysisSet &Required,
-                                             Pass::AnalysisSet &Destroyed,
-                                             Pass::AnalysisSet &Provided) {
-  Destroyed.push_back(FindUsedTypes::ID);
-  Destroyed.push_back(FindUnsafePointerTypes::ID);
-  Destroyed.push_back(CallGraph::ID);
-}