For PR1043:
[oota-llvm.git] / lib / Transforms / IPO / LoopExtractor.cpp
index 84359beb9b2e57333a6c98065a582c575adb51f7..5a6e7671d1b9cf3b452f4b6810c2d0ca271b637d 100644 (file)
@@ -1,10 +1,10 @@
 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
-// 
+//
 //                     The LLVM Compiler Infrastructure
 //
 // This file was developed by the LLVM research group and is distributed under
 // the University of Illinois Open Source License. See LICENSE.TXT for details.
-// 
+//
 //===----------------------------------------------------------------------===//
 //
 // A pass wrapper around the ExtractLoop() scalar transformation to extract each
@@ -14,6 +14,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#define DEBUG_TYPE "loop-extract"
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Instructions.h"
 #include "llvm/Module.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/FunctionUtils.h"
-#include "Support/Statistic.h"
+#include "llvm/ADT/Statistic.h"
 using namespace llvm;
 
+STATISTIC(NumExtracted, "Number of loops extracted");
+
 namespace {
-  Statistic<> NumExtracted("loop-extract", "Number of loops extracted");
-  
   // FIXME: This is not a function pass, but the PassManager doesn't allow
   // Module passes to require FunctionPasses, so we can't get loop info if we're
   // not a function pass.
@@ -37,7 +38,7 @@ namespace {
     LoopExtractor(unsigned numLoops = ~0) : NumLoops(numLoops) {}
 
     virtual bool runOnFunction(Function &F);
-    
+
     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
       AU.addRequiredID(BreakCriticalEdgesID);
       AU.addRequiredID(LoopSimplifyID);
@@ -46,7 +47,7 @@ namespace {
     }
   };
 
-  RegisterOpt<LoopExtractor> 
+  RegisterPass<LoopExtractor>
   X("loop-extract", "Extract loops into new functions");
 
   /// SingleLoopExtractor - For bugpoint.
@@ -54,9 +55,14 @@ namespace {
     SingleLoopExtractor() : LoopExtractor(1) {}
   };
 
-  RegisterOpt<SingleLoopExtractor> 
+  RegisterPass<SingleLoopExtractor>
   Y("loop-extract-single", "Extract at most one loop into a new function");
-} // End anonymous namespace 
+} // End anonymous namespace
+
+// createLoopExtractorPass - This pass extracts all natural loops from the
+// program into a function if it can.
+//
+FunctionPass *llvm::createLoopExtractorPass() { return new LoopExtractor(); }
 
 bool LoopExtractor::runOnFunction(Function &F) {
   LoopInfo &LI = getAnalysis<LoopInfo>();
@@ -82,11 +88,11 @@ bool LoopExtractor::runOnFunction(Function &F) {
     // than a minimal wrapper around the loop, extract the loop.
     Loop *TLL = *LI.begin();
     bool ShouldExtractLoop = false;
-    
+
     // Extract the loop if the entry block doesn't branch to the loop header.
     TerminatorInst *EntryTI = F.getEntryBlock().getTerminator();
     if (!isa<BranchInst>(EntryTI) ||
-        !cast<BranchInst>(EntryTI)->isUnconditional() || 
+        !cast<BranchInst>(EntryTI)->isUnconditional() ||
         EntryTI->getSuccessor(0) != TLL->getHeader())
       ShouldExtractLoop = true;
     else {
@@ -100,7 +106,7 @@ bool LoopExtractor::runOnFunction(Function &F) {
           break;
         }
     }
-    
+
     if (ShouldExtractLoop) {
       if (NumLoops == 0) return Changed;
       --NumLoops;
@@ -126,6 +132,59 @@ bool LoopExtractor::runOnFunction(Function &F) {
 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
 // program into a function if it can.  This is used by bugpoint.
 //
-Pass *llvm::createSingleLoopExtractorPass() {
+FunctionPass *llvm::createSingleLoopExtractorPass() {
   return new SingleLoopExtractor();
 }
+
+
+namespace {
+  /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks
+  /// from the module into their own functions except for those specified by the
+  /// BlocksToNotExtract list.
+  class BlockExtractorPass : public ModulePass {
+    std::vector<BasicBlock*> BlocksToNotExtract;
+  public:
+    BlockExtractorPass(std::vector<BasicBlock*> &B) : BlocksToNotExtract(B) {}
+    BlockExtractorPass() {}
+
+    bool runOnModule(Module &M);
+  };
+  RegisterPass<BlockExtractorPass>
+  XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)");
+}
+
+// createBlockExtractorPass - This pass extracts all blocks (except those
+// specified in the argument list) from the functions in the module.
+//
+ModulePass *llvm::createBlockExtractorPass(std::vector<BasicBlock*> &BTNE) {
+  return new BlockExtractorPass(BTNE);
+}
+
+bool BlockExtractorPass::runOnModule(Module &M) {
+  std::set<BasicBlock*> TranslatedBlocksToNotExtract;
+  for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) {
+    BasicBlock *BB = BlocksToNotExtract[i];
+    Function *F = BB->getParent();
+
+    // Map the corresponding function in this module.
+    Function *MF = M.getFunction(F->getName(), F->getFunctionType());
+
+    // Figure out which index the basic block is in its function.
+    Function::iterator BBI = MF->begin();
+    std::advance(BBI, std::distance(F->begin(), Function::iterator(BB)));
+    TranslatedBlocksToNotExtract.insert(BBI);
+  }
+
+  // Now that we know which blocks to not extract, figure out which ones we WANT
+  // to extract.
+  std::vector<BasicBlock*> BlocksToExtract;
+  for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F)
+    for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB)
+      if (!TranslatedBlocksToNotExtract.count(BB))
+        BlocksToExtract.push_back(BB);
+
+  for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i)
+    ExtractBasicBlock(BlocksToExtract[i]);
+
+  return !BlocksToExtract.empty();
+}