fix grammaro's pointed out by daniel
[oota-llvm.git] / lib / Transforms / IPO / PartialSpecialization.cpp
index 1774c97f8cb9124a3d0b41990bd769425538d6a7..084b94e53566132bcfab5a3255c4abc9654b3fb4 100644 (file)
 // This pass finds function arguments that are often a common constant and 
 // specializes a version of the called function for that constant.
 //
-// The initial heuristic favors constant arguments that used in control flow.
+// This pass simply does the cloning for functions it specializes.  It depends
+// on IPSCCP and DAE to clean up the results.
+//
+// The initial heuristic favors constant arguments that are used in control 
+// flow.
 //
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Pass.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Transforms/Utils/Cloning.h"
-#include "llvm/Support/Compiler.h"
+#include "llvm/Support/CallSite.h"
+#include "llvm/ADT/DenseSet.h"
 #include <map>
-#include <vector>
 using namespace llvm;
 
 STATISTIC(numSpecialized, "Number of specialized functions created");
 
-//Call must be used at least occasionally
+// Call must be used at least occasionally
 static const int CallsMin = 5;
-//Must have 10% of calls having the same constant to specialize on
+
+// Must have 10% of calls having the same constant to specialize on
 static const double ConstValPercent = .1;
 
 namespace {
-  class VISIBILITY_HIDDEN PartSpec : public ModulePass {
-    void scanForInterest(Function&, std::vector<int>&);
-    void replaceUsersFor(Function&, int, Constant*, Function*);
+  class PartSpec : public ModulePass {
+    void scanForInterest(Function&, SmallVector<int, 6>&);
     int scanDistribution(Function&, int, std::map<Constant*, int>&);
   public :
     static char ID; // Pass identification, replacement for typeid
-    PartSpec() : ModulePass((intptr_t)&ID) {}
+    PartSpec() : ModulePass(&ID) {}
     bool runOnModule(Module &M);
   };
 }
@@ -50,46 +54,112 @@ char PartSpec::ID = 0;
 static RegisterPass<PartSpec>
 X("partialspecialization", "Partial Specialization");
 
+// Specialize F by replacing the arguments (keys) in replacements with the 
+// constants (values).  Replace all calls to F with those constants with
+// a call to the specialized function.  Returns the specialized function
+static Function* 
+SpecializeFunction(Function* F, 
+                   DenseMap<const Value*, Value*>& replacements) {
+  // arg numbers of deleted arguments
+  DenseSet<unsigned> deleted;
+  for (DenseMap<const Value*, Value*>::iterator 
+         repb = replacements.begin(), repe = replacements.end();
+       repb != repe; ++repb)
+    deleted.insert(cast<Argument>(repb->first)->getArgNo());
+
+  Function* NF = CloneFunction(F, replacements);
+  NF->setLinkage(GlobalValue::InternalLinkage);
+  F->getParent()->getFunctionList().push_back(NF);
+
+  for (Value::use_iterator ii = F->use_begin(), ee = F->use_end(); 
+       ii != ee; ) {
+    Value::use_iterator i = ii;
+    ++ii;
+    if (isa<CallInst>(i) || isa<InvokeInst>(i)) {
+      CallSite CS(cast<Instruction>(i));
+      if (CS.getCalledFunction() == F) {
+        
+        SmallVector<Value*, 6> args;
+        for (unsigned x = 0; x < CS.arg_size(); ++x)
+          if (!deleted.count(x))
+            args.push_back(CS.getArgument(x));
+        Value* NCall;
+        if (CallInst *CI = dyn_cast<CallInst>(i)) {
+          NCall = CallInst::Create(NF, args.begin(), args.end(), 
+                                   CI->getName(), CI);
+          cast<CallInst>(NCall)->setTailCall(CI->isTailCall());
+          cast<CallInst>(NCall)->setCallingConv(CI->getCallingConv());
+        } else {
+          InvokeInst *II = cast<InvokeInst>(i);
+          NCall = InvokeInst::Create(NF, II->getNormalDest(),
+                                     II->getUnwindDest(),
+                                     args.begin(), args.end(), 
+                                     II->getName(), II);
+          cast<InvokeInst>(NCall)->setCallingConv(II->getCallingConv());
+        }
+        CS.getInstruction()->replaceAllUsesWith(NCall);
+        CS.getInstruction()->eraseFromParent();
+      }
+    }
+  }
+  return NF;
+}
+
+
 bool PartSpec::runOnModule(Module &M) {
   bool Changed = false;
   for (Module::iterator I = M.begin(); I != M.end(); ++I) {
     Function &F = *I;
-    if (!F.isDeclaration()) {
-      std::vector<int> interestingArgs;
-      scanForInterest(F, interestingArgs);
-      //Find the first interesting Argument that we can specialize on
-      //If there are multiple intersting Arguments, then those will be found
-      //when processing the cloned function.
-      bool breakOuter = false;
-      for (unsigned int x = 0; !breakOuter && x < interestingArgs.size(); ++x) {
-        std::map<Constant*, int> distribution;
-        int total = scanDistribution(F, interestingArgs[x], distribution);
-        if (total > CallsMin) 
-          for (std::map<Constant*, int>::iterator ii = distribution.begin(),
-                 ee = distribution.end(); ii != ee; ++ii)
-            if ( total > ii->second  && ii->first &&
-                 ii->second > total * ConstValPercent ) {
-              Function* NF = CloneFunction(&F);
-              NF->setLinkage(GlobalValue::InternalLinkage);
-              M.getFunctionList().push_back(NF);
-              replaceUsersFor(F, interestingArgs[x], ii->first, NF);
-              breakOuter = true;
-              Changed = true;
-            }
-      }
+    if (F.isDeclaration() || F.mayBeOverridden()) continue;
+    SmallVector<int, 6> interestingArgs;
+    scanForInterest(F, interestingArgs);
+
+    // Find the first interesting Argument that we can specialize on
+    // If there are multiple interesting Arguments, then those will be found
+    // when processing the cloned function.
+    bool breakOuter = false;
+    for (unsigned int x = 0; !breakOuter && x < interestingArgs.size(); ++x) {
+      std::map<Constant*, int> distribution;
+      int total = scanDistribution(F, interestingArgs[x], distribution);
+      if (total > CallsMin) 
+        for (std::map<Constant*, int>::iterator ii = distribution.begin(),
+               ee = distribution.end(); ii != ee; ++ii)
+          if (total > ii->second && ii->first &&
+               ii->second > total * ConstValPercent) {
+            DenseMap<const Value*, Value*> m;
+            Function::arg_iterator arg = F.arg_begin();
+            for (int y = 0; y < interestingArgs[x]; ++y)
+              ++arg;
+            m[&*arg] = ii->first;
+            SpecializeFunction(&F, m);
+            ++numSpecialized;
+            breakOuter = true;
+            Changed = true;
+          }
     }
   }
   return Changed;
 }
 
 /// scanForInterest - This function decides which arguments would be worth
-///                    specializing on.
-void PartSpec::scanForInterest(Function& F, std::vector<int>& args) {
+/// specializing on.
+void PartSpec::scanForInterest(Function& F, SmallVector<int, 6>& args) {
   for(Function::arg_iterator ii = F.arg_begin(), ee = F.arg_end();
       ii != ee; ++ii) {
     for(Value::use_iterator ui = ii->use_begin(), ue = ii->use_end();
         ui != ue; ++ui) {
-      if (isa<CmpInst>(ui)) {
+
+      bool interesting = false;
+
+      if (isa<CmpInst>(ui)) interesting = true;
+      else if (isa<CallInst>(ui))
+        interesting = ui->getOperand(0) == ii;
+      else if (isa<InvokeInst>(ui))
+        interesting = ui->getOperand(0) == ii;
+      else if (isa<SwitchInst>(ui)) interesting = true;
+      else if (isa<BranchInst>(ui)) interesting = true;
+
+      if (interesting) {
         args.push_back(std::distance(F.arg_begin(), ii));
         break;
       }
@@ -97,30 +167,23 @@ void PartSpec::scanForInterest(Function& F, std::vector<int>& args) {
   }
 }
 
-/// replaceUsersFor - Replace direct calls to F with NF if the arg argnum is
-/// the constant val
-void PartSpec::replaceUsersFor(Function& F , int argnum, Constant* val, 
-                               Function* NF) {
-  ++numSpecialized;
-  for(Value::use_iterator ii = F.use_begin(), ee = F.use_end();
-      ii != ee; ++ii)
-    if (CallInst* CI = dyn_cast<CallInst>(ii))
-      if (CI->getOperand(0) == &F && CI->getOperand(argnum + 1) == val)
-        CI->setOperand(0, NF);
-}
-
+/// scanDistribution - Construct a histogram of constants for arg of F at arg.
 int PartSpec::scanDistribution(Function& F, int arg, 
                                std::map<Constant*, int>& dist) {
-  bool hasInd = false;
+  bool hasIndirect = false;
   int total = 0;
   for(Value::use_iterator ii = F.use_begin(), ee = F.use_end();
       ii != ee; ++ii)
-    if (CallInst* CI = dyn_cast<CallInst>(ii)) {
-      ++dist[dyn_cast<Constant>(CI->getOperand(arg + 1))];
+    if ((isa<CallInst>(ii) || isa<InvokeInst>(ii))
+        && ii->getOperand(0) == &F) {
+      ++dist[dyn_cast<Constant>(ii->getOperand(arg + 1))];
       ++total;
     } else
-      hasInd = true;
-  if (hasInd) ++total;
+      hasIndirect = true;
+
+  // Preserve the original address taken function even if all other uses
+  // will be specialized.
+  if (hasIndirect) ++total;
   return total;
 }