Add support for the new va_arg instruction
[oota-llvm.git] / lib / Transforms / IPO / FunctionResolution.cpp
index cbcfa2aba853c74070765c57ade30b5319fac846..1b1065e3e926da092b31307e1cc9c094ae687b5a 100644 (file)
 #include "llvm/Pass.h"
 #include "llvm/iOther.h"
 #include "llvm/Constants.h"
+#include "llvm/Assembly/Writer.h"  // FIXME: remove when varargs implemented
 #include "Support/Statistic.h"
 #include <algorithm>
 
-using std::vector;
-using std::string;
-using std::cerr;
-
 namespace {
   Statistic<>NumResolved("funcresolve", "Number of varargs functions resolved");
   Statistic<> NumGlobals("funcresolve", "Number of global variables resolved");
@@ -50,19 +47,32 @@ static void ConvertCallTo(CallInst *CI, Function *Dest) {
   // Keep an iterator to where we want to insert cast instructions if the
   // argument types don't agree.
   //
-  BasicBlock::iterator BBI = CI;
-  assert(CI->getNumOperands()-1 == ParamTys.size() &&
-         "Function calls resolved funny somehow, incompatible number of args");
+  unsigned NumArgsToCopy = CI->getNumOperands()-1;
+  if (NumArgsToCopy != ParamTys.size() &&
+      !(NumArgsToCopy > ParamTys.size() &&
+        Dest->getFunctionType()->isVarArg())) {
+    std::cerr << "WARNING: Call arguments do not match expected number of"
+              << " parameters.\n";
+    std::cerr << "WARNING: In function '"
+              << CI->getParent()->getParent()->getName() << "': call: " << *CI;
+    std::cerr << "Function resolved to: ";
+    WriteAsOperand(std::cerr, Dest);
+    std::cerr << "\n";
+    if (NumArgsToCopy > ParamTys.size())
+      NumArgsToCopy = ParamTys.size();
+  }
 
-  vector<Value*> Params;
+  std::vector<Value*> Params;
 
   // Convert all of the call arguments over... inserting cast instructions if
   // the types are not compatible.
-  for (unsigned i = 1; i < CI->getNumOperands(); ++i) {
+  for (unsigned i = 1; i <= NumArgsToCopy; ++i) {
     Value *V = CI->getOperand(i);
 
-    if (V->getType() != ParamTys[i-1])  // Must insert a cast...
-      V = new CastInst(V, ParamTys[i-1], "argcast", BBI);
+    if (i-1 < ParamTys.size() && V->getType() != ParamTys[i-1]) {
+      // Must insert a cast...
+      V = new CastInst(V, ParamTys[i-1], "argcast", CI);
+    }
 
     Params.push_back(V);
   }
@@ -70,20 +80,18 @@ static void ConvertCallTo(CallInst *CI, Function *Dest) {
   // Replace the old call instruction with a new call instruction that calls
   // the real function.
   //
-  Instruction *NewCall = new CallInst(Dest, Params, "", BBI);
-
-  // Remove the old call instruction from the program...
-  BB->getInstList().remove(BBI);
+  Instruction *NewCall = new CallInst(Dest, Params, "", CI);
+  std::string Name = CI->getName(); CI->setName("");
 
   // Transfer the name over...
   if (NewCall->getType() != Type::VoidTy)
-    NewCall->setName(CI->getName());
+    NewCall->setName(Name);
 
   // Replace uses of the old instruction with the appropriate values...
   //
   if (NewCall->getType() == CI->getType()) {
     CI->replaceAllUsesWith(NewCall);
-    NewCall->setName(CI->getName());
+    NewCall->setName(Name);
 
   } else if (NewCall->getType() == Type::VoidTy) {
     // Resolved function does not return a value but the prototype does.  This
@@ -101,18 +109,17 @@ static void ConvertCallTo(CallInst *CI, Function *Dest) {
     //
     if (!CI->use_empty()) {
       // Insert the new cast instruction...
-      CastInst *NewCast = new CastInst(NewCall, CI->getType(),
-                                       NewCall->getName(), BBI);
+      CastInst *NewCast = new CastInst(NewCall, CI->getType(), Name, CI);
       CI->replaceAllUsesWith(NewCast);
     }
   }
 
   // The old instruction is no longer needed, destroy it!
-  delete CI;
+  BB->getInstList().erase(CI);
 }
 
 
-static bool ResolveFunctions(Module &M, vector<GlobalValue*> &Globals,
+static bool ResolveFunctions(Module &M, std::vector<GlobalValue*> &Globals,
                              Function *Concrete) {
   bool Changed = false;
   for (unsigned i = 0; i != Globals.size(); ++i)
@@ -121,28 +128,40 @@ static bool ResolveFunctions(Module &M, vector<GlobalValue*> &Globals,
       const FunctionType *OldMT = Old->getFunctionType();
       const FunctionType *ConcreteMT = Concrete->getFunctionType();
       
-      assert(OldMT->getParamTypes().size() <=
-             ConcreteMT->getParamTypes().size() &&
-             "Concrete type must have more specified parameters!");
+      if (OldMT->getParamTypes().size() > ConcreteMT->getParamTypes().size() &&
+          !ConcreteMT->isVarArg())
+        if (!Old->use_empty()) {
+          std::cerr << "WARNING: Linking function '" << Old->getName()
+                    << "' is causing arguments to be dropped.\n";
+          std::cerr << "WARNING: Prototype: ";
+          WriteAsOperand(std::cerr, Old);
+          std::cerr << " resolved to ";
+          WriteAsOperand(std::cerr, Concrete);
+          std::cerr << "\n";
+        }
       
       // Check to make sure that if there are specified types, that they
       // match...
       //
-      for (unsigned i = 0; i < OldMT->getParamTypes().size(); ++i)
-        if (OldMT->getParamTypes()[i] != ConcreteMT->getParamTypes()[i]) {
-          cerr << "Parameter types conflict for: '" << OldMT
-               << "' and '" << ConcreteMT << "'\n";
-          return Changed;
-        }
+      unsigned NumArguments = std::min(OldMT->getParamTypes().size(),
+                                       ConcreteMT->getParamTypes().size());
+
+      if (!Old->use_empty() && !Concrete->use_empty())
+        for (unsigned i = 0; i < NumArguments; ++i)
+          if (OldMT->getParamTypes()[i] != ConcreteMT->getParamTypes()[i]) {
+            std::cerr << "WARNING: Function [" << Old->getName()
+                      << "]: Parameter types conflict for: '" << OldMT
+                      << "' and '" << ConcreteMT << "'\n";
+            return Changed;
+          }
       
-      // Attempt to convert all of the uses of the old function to the
-      // concrete form of the function.  If there is a use of the fn that
-      // we don't understand here we punt to avoid making a bad
-      // transformation.
+      // Attempt to convert all of the uses of the old function to the concrete
+      // form of the function.  If there is a use of the fn that we don't
+      // understand here we punt to avoid making a bad transformation.
       //
-      // At this point, we know that the return values are the same for
-      // our two functions and that the Old function has no varargs fns
-      // specified.  In otherwords it's just <retty> (...)
+      // At this point, we know that the return values are the same for our two
+      // functions and that the Old function has no varargs fns specified.  In
+      // otherwords it's just <retty> (...)
       //
       for (unsigned i = 0; i < Old->use_size(); ) {
         User *U = *(Old->use_begin()+i);
@@ -159,12 +178,24 @@ static bool ResolveFunctions(Module &M, vector<GlobalValue*> &Globals,
             Changed = true;
             ++NumResolved;
           } else {
-            cerr << "Couldn't cleanup this function call, must be an"
-                 << " argument or something!" << CI;
+            std::cerr << "Couldn't cleanup this function call, must be an"
+                      << " argument or something!" << CI;
+            ++i;
+          }
+        } else if (ConstantPointerRef *CPR = dyn_cast<ConstantPointerRef>(U)) {
+          if (CPR->use_size() == 1 && isa<ConstantExpr>(CPR->use_back()) &&
+              cast<ConstantExpr>(CPR->use_back())->getOpcode() == 
+                Instruction::Cast) {
+            ConstantExpr *CE = cast<ConstantExpr>(CPR->use_back());
+            Constant *NewCPR = ConstantPointerRef::get(Concrete);
+            CE->replaceAllUsesWith(ConstantExpr::getCast(NewCPR,CE->getType()));
+            CPR->destroyConstant();
+          } else {
+            std::cerr << "Cannot convert use of function: " << CPR << "\n";
             ++i;
           }
         } else {
-          cerr << "Cannot convert use of function: " << U << "\n";
+          std::cerr << "Cannot convert use of function: " << U << "\n";
           ++i;
         }
       }
@@ -173,37 +204,31 @@ static bool ResolveFunctions(Module &M, vector<GlobalValue*> &Globals,
 }
 
 
-static bool ResolveGlobalVariables(Module &M, vector<GlobalValue*> &Globals,
+static bool ResolveGlobalVariables(Module &M,
+                                   std::vector<GlobalValue*> &Globals,
                                    GlobalVariable *Concrete) {
   bool Changed = false;
   assert(isa<ArrayType>(Concrete->getType()->getElementType()) &&
          "Concrete version should be an array type!");
 
   // Get the type of the things that may be resolved to us...
-  const Type *AETy =
-    cast<ArrayType>(Concrete->getType()->getElementType())->getElementType();
-
-  std::vector<Constant*> Args;
-  Args.push_back(Constant::getNullValue(Type::LongTy));
-  Args.push_back(Constant::getNullValue(Type::LongTy));
-  ConstantExpr *Replacement =
-    ConstantExpr::getGetElementPtr(ConstantPointerRef::get(Concrete), Args);
-  
+  const ArrayType *CATy =cast<ArrayType>(Concrete->getType()->getElementType());
+  const Type *AETy = CATy->getElementType();
+
+  Constant *CCPR = ConstantPointerRef::get(Concrete);
+
   for (unsigned i = 0; i != Globals.size(); ++i)
     if (Globals[i] != Concrete) {
       GlobalVariable *Old = cast<GlobalVariable>(Globals[i]);
-      if (Old->getType()->getElementType() != AETy) {
+      const ArrayType *OATy = cast<ArrayType>(Old->getType()->getElementType());
+      if (OATy->getElementType() != AETy || OATy->getNumElements() != 0) {
         std::cerr << "WARNING: Two global variables exist with the same name "
                   << "that cannot be resolved!\n";
         return false;
       }
 
-      // In this case, Old is a pointer to T, Concrete is a pointer to array of
-      // T.  Because of this, replace all uses of Old with a constantexpr
-      // getelementptr that returns the address of the first element of the
-      // array.
-      //
-      Old->replaceAllUsesWith(Replacement);
+      Old->replaceAllUsesWith(ConstantExpr::getCast(CCPR, Old->getType()));
+
       // Since there are no uses of Old anymore, remove it from the module.
       M.getGlobalList().erase(Old);
 
@@ -214,11 +239,10 @@ static bool ResolveGlobalVariables(Module &M, vector<GlobalValue*> &Globals,
 }
 
 static bool ProcessGlobalsWithSameName(Module &M,
-                                       vector<GlobalValue*> &Globals) {
+                                       std::vector<GlobalValue*> &Globals) {
   assert(!Globals.empty() && "Globals list shouldn't be empty here!");
 
   bool isFunction = isa<Function>(Globals[0]);   // Is this group all functions?
-  bool Changed = false;
   GlobalValue *Concrete = 0;  // The most concrete implementation to resolve to
 
   assert((isFunction ^ isa<GlobalVariable>(Globals[0])) &&
@@ -250,32 +274,36 @@ static bool ProcessGlobalsWithSameName(Module &M,
       } else {
         Concrete = F;
       }
-      ++i;
     } else {
       // For global variables, we have to merge C definitions int A[][4] with
-      // int[6][4]
+      // int[6][4].  A[][4] is represented as A[0][4] by the CFE.
       GlobalVariable *GV = cast<GlobalVariable>(Globals[i]);
-      if (Concrete == 0) {
-        if (isa<ArrayType>(GV->getType()->getElementType()))
-          Concrete = GV;
-      } else {    // Must have different types... one is an array of the other?
-        const ArrayType *AT =
-          dyn_cast<ArrayType>(GV->getType()->getElementType());
-
-        // If GV is an array of Concrete, then GV is the array.
-        if (AT && AT->getElementType() == Concrete->getType()->getElementType())
-          Concrete = GV;
-        else {
-          // Concrete must be an array type, check to see if the element type of
-          // concrete is already GV.
-          AT = cast<ArrayType>(Concrete->getType()->getElementType());
-          if (AT->getElementType() != GV->getType()->getElementType())
-            Concrete = 0;           // Don't know how to handle it!
+      if (!isa<ArrayType>(GV->getType()->getElementType())) {
+        Concrete = 0;
+        break;  // Non array's cannot be compatible with other types.
+      } else if (Concrete == 0) {
+        Concrete = GV;
+      } else {
+        // Must have different types... allow merging A[0][4] w/ A[6][4] if
+        // A[0][4] is external.
+        const ArrayType *NAT = cast<ArrayType>(GV->getType()->getElementType());
+        const ArrayType *CAT =
+          cast<ArrayType>(Concrete->getType()->getElementType());
+
+        if (NAT->getElementType() != CAT->getElementType()) {
+          Concrete = 0;  // Non-compatible types
+          break;
+        } else if (NAT->getNumElements() == 0 && GV->isExternal()) {
+          // Concrete remains the same
+        } else if (CAT->getNumElements() == 0 && Concrete->isExternal()) {
+          Concrete = GV;   // Concrete becomes GV
+        } else {
+          Concrete = 0;    // Cannot merge these types...
+          break;
         }
       }
-      
-      ++i;
     }
+    ++i;
   }
 
   if (Globals.size() > 1) {         // Found a multiply defined global...
@@ -284,36 +312,35 @@ static bool ProcessGlobalsWithSameName(Module &M,
     // uses to use it instead.
     //
     if (!Concrete) {
-      cerr << "WARNING: Found function types that are not compatible:\n";
+      std::cerr << "WARNING: Found global types that are not compatible:\n";
       for (unsigned i = 0; i < Globals.size(); ++i) {
-        cerr << "\t" << Globals[i]->getType()->getDescription() << " %"
-             << Globals[i]->getName() << "\n";
+        std::cerr << "\t" << Globals[i]->getType()->getDescription() << " %"
+                  << Globals[i]->getName() << "\n";
       }
-      cerr << "  No linkage of globals named '" << Globals[0]->getName()
-           << "' performed!\n";
-      return Changed;
+      std::cerr << "  No linkage of globals named '" << Globals[0]->getName()
+                << "' performed!\n";
+      return false;
     }
 
     if (isFunction)
-      return Changed | ResolveFunctions(M, Globals, cast<Function>(Concrete));
+      return ResolveFunctions(M, Globals, cast<Function>(Concrete));
     else
-      return Changed | ResolveGlobalVariables(M, Globals,
-                                              cast<GlobalVariable>(Concrete));
+      return ResolveGlobalVariables(M, Globals,
+                                    cast<GlobalVariable>(Concrete));
   }
-  return Changed;
+  return false;
 }
 
 bool FunctionResolvingPass::run(Module &M) {
-  SymbolTable *ST = M.getSymbolTable();
-  if (!ST) return false;
+  SymbolTable &ST = M.getSymbolTable();
 
-  std::map<string, vector<GlobalValue*> > Globals;
+  std::map<std::string, std::vector<GlobalValue*> > Globals;
 
   // Loop over the entries in the symbol table. If an entry is a func pointer,
   // then add it to the Functions map.  We do a two pass algorithm here to avoid
   // problems with iterators getting invalidated if we did a one pass scheme.
   //
-  for (SymbolTable::iterator I = ST->begin(), E = ST->end(); I != E; ++I)
+  for (SymbolTable::iterator I = ST.begin(), E = ST.end(); I != E; ++I)
     if (const PointerType *PT = dyn_cast<PointerType>(I->first)) {
       SymbolTable::VarMap &Plane = I->second;
       for (SymbolTable::type_iterator PI = Plane.begin(), PE = Plane.end();
@@ -321,7 +348,7 @@ bool FunctionResolvingPass::run(Module &M) {
         GlobalValue *GV = cast<GlobalValue>(PI->second);
         assert(PI->first == GV->getName() &&
                "Global name and symbol table do not agree!");
-        if (GV->hasExternalLinkage())  // Only resolve decls to external fns
+        if (!GV->hasInternalLinkage())  // Only resolve decls to external fns
           Globals[PI->first].push_back(GV);
       }
     }
@@ -331,8 +358,8 @@ bool FunctionResolvingPass::run(Module &M) {
   // Now we have a list of all functions with a particular name.  If there is
   // more than one entry in a list, merge the functions together.
   //
-  for (std::map<string, vector<GlobalValue*> >::iterator I = Globals.begin(), 
-         E = Globals.end(); I != E; ++I)
+  for (std::map<std::string, std::vector<GlobalValue*> >::iterator
+         I = Globals.begin(), E = Globals.end(); I != E; ++I)
     Changed |= ProcessGlobalsWithSameName(M, I->second);
 
   // Now loop over all of the globals, checking to see if any are trivially