Hrm, operator new and new[] do not belong here. We should not CSE them! :)
[oota-llvm.git] / tools / bugpoint / ExtractFunction.cpp
index 41b5641ff50630db69a7d4a01184087bca276b17..9f8dca8abd0e0e3a2a5b85d51c43be707c6941d6 100644 (file)
@@ -7,8 +7,8 @@
 // 
 //===----------------------------------------------------------------------===//
 //
-// This file implements a method that extracts a function from program, cleans
-// it up, and returns it as a new module.
+// This file implements several methods that are used to extract functions,
+// loops, or portions of a module from the rest of the module.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,7 @@
 #include "Support/CommandLine.h"
 #include "Support/Debug.h"
 #include "Support/FileUtilities.h"
+#include <set>
 using namespace llvm;
 
 namespace llvm {
@@ -76,6 +77,8 @@ Module *BugDriver::deleteInstructionFromProgram(const Instruction *I,
   // Make sure that the appropriate target data is always used...
   Passes.add(new TargetData("bugpoint", Result));
 
+  /// FIXME: If this used runPasses() like the methods below, we could get rid
+  /// of the -disable-* options!
   if (Simplification > 1 && !NoDCE)
     Passes.add(createDeadCodeEliminationPass());
   if (Simplification && !DisableSimplifyCFG)
@@ -110,25 +113,41 @@ Module *BugDriver::performFinalCleanups(Module *M, bool MayModifySemantics) {
     CleanupPasses.push_back(getPI(createDeadArgHackingPass()));
   else
     CleanupPasses.push_back(getPI(createDeadArgEliminationPass()));
-  
-  std::swap(Program, M);
-  std::string Filename;
-  bool Failed = runPasses(CleanupPasses, Filename);
-  std::swap(Program, M);
-
-  if (Failed) {
-    std::cerr << "Final cleanups failed.  Sorry.  :(\n";
-  } else {
-    delete M;
-    M = ParseInputFile(Filename);
-    if (M == 0) {
-      std::cerr << getToolName() << ": Error reading bytecode file '"
-                << Filename << "'!\n";
-      exit(1);
-    }
-    removeFile(Filename);
+
+  Module *New = runPassesOn(M, CleanupPasses);
+  if (New == 0) {
+    std::cerr << "Final cleanups failed.  Sorry. :(  Please report a bug!\n";
+  }
+  delete M;
+  return New;
+}
+
+
+/// ExtractLoop - Given a module, extract up to one loop from it into a new
+/// function.  This returns null if there are no extractable loops in the
+/// program or if the loop extractor crashes.
+Module *BugDriver::ExtractLoop(Module *M) {
+  std::vector<const PassInfo*> LoopExtractPasses;
+  LoopExtractPasses.push_back(getPI(createSingleLoopExtractorPass()));
+
+  Module *NewM = runPassesOn(M, LoopExtractPasses);
+  if (NewM == 0) {
+    Module *Old = swapProgramIn(M);
+    std::cout << "*** Loop extraction failed: ";
+    EmitProgressBytecode("loopextraction", true);
+    std::cout << "*** Sorry. :(  Please report a bug!\n";
+    swapProgramIn(Old);
+    return 0;
+  }
+
+  // Check to see if we created any new functions.  If not, no loops were
+  // extracted and we should return null.
+  if (M->size() == NewM->size()) {
+    delete NewM;
+    return 0;
   }
-  return M;
+  
+  return NewM;
 }
 
 
@@ -165,7 +184,9 @@ Module *llvm::SplitFunctionsOutOfModule(Module *M,
     I->setInitializer(0);  // Delete the initializer to make it external
 
   // Remove the Test functions from the Safe module
+  std::set<std::pair<std::string, const PointerType*> > TestFunctions;
   for (unsigned i = 0, e = F.size(); i != e; ++i) {
+    TestFunctions.insert(std::make_pair(F[i]->getName(), F[i]->getType()));
     Function *TNOF = M->getFunction(F[i]->getName(), F[i]->getFunctionType());
     DEBUG(std::cerr << "Removing function " << F[i]->getName() << "\n");
     assert(TNOF && "Function doesn't exist in module!");
@@ -173,14 +194,8 @@ Module *llvm::SplitFunctionsOutOfModule(Module *M,
   }
 
   // Remove the Safe functions from the Test module
-  for (Module::iterator I = New->begin(), E = New->end(); I != E; ++I) {
-    bool funcFound = false;
-    for (std::vector<Function*>::const_iterator FI = F.begin(), Fe = F.end();
-         FI != Fe; ++FI)
-      if (I->getName() == (*FI)->getName()) funcFound = true;
-
-    if (!funcFound)
+  for (Module::iterator I = New->begin(), E = New->end(); I != E; ++I)
+    if (!TestFunctions.count(std::make_pair(I->getName(), I->getType())))
       DeleteFunctionBody(I);
-  }
   return New;
 }