Teach legalize to promote SetCC results.
[oota-llvm.git] / tools / bugpoint / ExtractFunction.cpp
index 5119c801a024f368d8f1842c268b5808ae0e41cb..078c7baf4e90e404f0d857693e0f6e93084b3e6a 100644 (file)
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/FunctionUtils.h"
 #include "llvm/Target/TargetData.h"
-#include "Support/CommandLine.h"
-#include "Support/Debug.h"
-#include "Support/FileUtilities.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FileUtilities.h"
+#include <set>
 using namespace llvm;
 
 namespace llvm {
@@ -140,10 +142,17 @@ Module *BugDriver::ExtractLoop(Module *M) {
   }
 
   // Check to see if we created any new functions.  If not, no loops were
-  // extracted and we should return null.
-  if (M->size() == NewM->size()) {
+  // extracted and we should return null.  Limit the number of loops we extract
+  // to avoid taking forever.
+  static unsigned NumExtracted = 32;
+  if (M->size() == NewM->size() || --NumExtracted == 0) {
     delete NewM;
     return 0;
+  } else {
+    assert(M->size() < NewM->size() && "Loop extract removed functions?");
+    Module::iterator MI = NewM->begin();
+    for (unsigned i = 0, e = M->size(); i != e; ++i)
+      ++MI;
   }
   
   return NewM;
@@ -183,7 +192,9 @@ Module *llvm::SplitFunctionsOutOfModule(Module *M,
     I->setInitializer(0);  // Delete the initializer to make it external
 
   // Remove the Test functions from the Safe module
+  std::set<std::pair<std::string, const PointerType*> > TestFunctions;
   for (unsigned i = 0, e = F.size(); i != e; ++i) {
+    TestFunctions.insert(std::make_pair(F[i]->getName(), F[i]->getType()));
     Function *TNOF = M->getFunction(F[i]->getName(), F[i]->getFunctionType());
     DEBUG(std::cerr << "Removing function " << F[i]->getName() << "\n");
     assert(TNOF && "Function doesn't exist in module!");
@@ -191,14 +202,78 @@ Module *llvm::SplitFunctionsOutOfModule(Module *M,
   }
 
   // Remove the Safe functions from the Test module
-  for (Module::iterator I = New->begin(), E = New->end(); I != E; ++I) {
-    bool funcFound = false;
-    for (std::vector<Function*>::const_iterator FI = F.begin(), Fe = F.end();
-         FI != Fe; ++FI)
-      if (I->getName() == (*FI)->getName()) funcFound = true;
-
-    if (!funcFound)
+  for (Module::iterator I = New->begin(), E = New->end(); I != E; ++I)
+    if (!TestFunctions.count(std::make_pair(I->getName(), I->getType())))
       DeleteFunctionBody(I);
-  }
   return New;
 }
+
+//===----------------------------------------------------------------------===//
+// Basic Block Extraction Code
+//===----------------------------------------------------------------------===//
+
+namespace {
+  std::vector<BasicBlock*> BlocksToNotExtract;
+
+  /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks
+  /// from the module into their own functions except for those specified by the
+  /// BlocksToNotExtract list.
+  class BlockExtractorPass : public ModulePass {
+    bool runOnModule(Module &M);
+  };
+  RegisterOpt<BlockExtractorPass>
+  XX("extract-bbs", "Extract Basic Blocks From Module (for bugpoint use)");
+}
+
+bool BlockExtractorPass::runOnModule(Module &M) {
+  std::set<BasicBlock*> TranslatedBlocksToNotExtract;
+  for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) {
+    BasicBlock *BB = BlocksToNotExtract[i];
+    Function *F = BB->getParent();
+
+    // Map the corresponding function in this module.
+    Function *MF = M.getFunction(F->getName(), F->getFunctionType());
+
+    // Figure out which index the basic block is in its function.
+    Function::iterator BBI = MF->begin();
+    std::advance(BBI, std::distance(F->begin(), Function::iterator(BB)));
+    TranslatedBlocksToNotExtract.insert(BBI);
+  }
+
+  // Now that we know which blocks to not extract, figure out which ones we WANT
+  // to extract.
+  std::vector<BasicBlock*> BlocksToExtract;
+  for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F)
+    for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB)
+      if (!TranslatedBlocksToNotExtract.count(BB))
+        BlocksToExtract.push_back(BB);
+
+  for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i)
+    ExtractBasicBlock(BlocksToExtract[i]);
+  
+  return !BlocksToExtract.empty();
+}
+
+/// ExtractMappedBlocksFromModule - Extract all but the specified basic blocks
+/// into their own functions.  The only detail is that M is actually a module
+/// cloned from the one the BBs are in, so some mapping needs to be performed.
+/// If this operation fails for some reason (ie the implementation is buggy),
+/// this function should return null, otherwise it returns a new Module.
+Module *BugDriver::ExtractMappedBlocksFromModule(const
+                                                 std::vector<BasicBlock*> &BBs,
+                                                 Module *M) {
+  // Set the global list so that pass will be able to access it.
+  BlocksToNotExtract = BBs;
+
+  std::vector<const PassInfo*> PI;
+  PI.push_back(getPI(new BlockExtractorPass()));
+  Module *Ret = runPassesOn(M, PI);
+  BlocksToNotExtract.clear();
+  if (Ret == 0) {
+    std::cout << "*** Basic Block extraction failed, please report a bug!\n";
+    M = swapProgramIn(M);
+    EmitProgressBytecode("basicblockextractfail", true);
+    M = swapProgramIn(M);
+  }
+  return Ret;
+}