Fix a problem that nate reduced for me.
[oota-llvm.git] / lib / Transforms / Utils / UnifyFunctionExitNodes.cpp
index 064bba5dddd0f36de4596f3bc4ce6940c4f82271..0c1eda7c0cc82922b7f14e4dadd26b59c9bd61d2 100644 (file)
@@ -1,10 +1,10 @@
 //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===//
-// 
+//
 //                     The LLVM Compiler Infrastructure
 //
 // This file was developed by the LLVM research group and is distributed under
 // the University of Illinois Open Source License. See LICENSE.TXT for details.
-// 
+//
 //===----------------------------------------------------------------------===//
 //
 // This pass is used to ensure that functions have at most one return
@@ -18,8 +18,7 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/BasicBlock.h"
 #include "llvm/Function.h"
-#include "llvm/iTerminators.h"
-#include "llvm/iPHINode.h"
+#include "llvm/Instructions.h"
 #include "llvm/Type.h"
 using namespace llvm;
 
@@ -47,13 +46,16 @@ bool UnifyFunctionExitNodes::runOnFunction(Function &F) {
   //
   std::vector<BasicBlock*> ReturningBlocks;
   std::vector<BasicBlock*> UnwindingBlocks;
+  std::vector<BasicBlock*> UnreachableBlocks;
   for(Function::iterator I = F.begin(), E = F.end(); I != E; ++I)
     if (isa<ReturnInst>(I->getTerminator()))
       ReturningBlocks.push_back(I);
     else if (isa<UnwindInst>(I->getTerminator()))
       UnwindingBlocks.push_back(I);
+    else if (isa<UnreachableInst>(I->getTerminator()))
+      UnreachableBlocks.push_back(I);
 
-  // Handle unwinding blocks first...
+  // Handle unwinding blocks first.
   if (UnwindingBlocks.empty()) {
     UnwindBlock = 0;
   } else if (UnwindingBlocks.size() == 1) {
@@ -62,15 +64,32 @@ bool UnifyFunctionExitNodes::runOnFunction(Function &F) {
     UnwindBlock = new BasicBlock("UnifiedUnwindBlock", &F);
     new UnwindInst(UnwindBlock);
 
-    for (std::vector<BasicBlock*>::iterator I = UnwindingBlocks.begin(), 
+    for (std::vector<BasicBlock*>::iterator I = UnwindingBlocks.begin(),
            E = UnwindingBlocks.end(); I != E; ++I) {
       BasicBlock *BB = *I;
-      BB->getInstList().pop_back();  // Remove the return insn
-      new BranchInst(UnwindBlock, 0, 0, BB);
+      BB->getInstList().pop_back();  // Remove the unwind insn
+      new BranchInst(UnwindBlock, BB);
+    }
+  }
+
+  // Then unreachable blocks.
+  if (UnreachableBlocks.empty()) {
+    UnreachableBlock = 0;
+  } else if (UnreachableBlocks.size() == 1) {
+    UnreachableBlock = UnreachableBlocks.front();
+  } else {
+    UnreachableBlock = new BasicBlock("UnifiedUnreachableBlock", &F);
+    new UnreachableInst(UnreachableBlock);
+
+    for (std::vector<BasicBlock*>::iterator I = UnreachableBlocks.begin(),
+           E = UnreachableBlocks.end(); I != E; ++I) {
+      BasicBlock *BB = *I;
+      BB->getInstList().pop_back();  // Remove the unreachable inst.
+      new BranchInst(UnreachableBlock, BB);
     }
   }
 
-  // Now handle return blocks...
+  // Now handle return blocks.
   if (ReturningBlocks.empty()) {
     ReturnBlock = 0;
     return false;                          // No blocks return
@@ -80,7 +99,7 @@ bool UnifyFunctionExitNodes::runOnFunction(Function &F) {
   }
 
   // Otherwise, we need to insert a new basic block into the function, add a PHI
-  // node (if the function returns a value), and convert all of the return 
+  // node (if the function returns a value), and convert all of the return
   // instructions into unconditional branches.
   //
   BasicBlock *NewRetBlock = new BasicBlock("UnifiedReturnBlock", &F);
@@ -96,7 +115,7 @@ bool UnifyFunctionExitNodes::runOnFunction(Function &F) {
   // Loop over all of the blocks, replacing the return instruction with an
   // unconditional branch.
   //
-  for (std::vector<BasicBlock*>::iterator I = ReturningBlocks.begin(), 
+  for (std::vector<BasicBlock*>::iterator I = ReturningBlocks.begin(),
          E = ReturningBlocks.end(); I != E; ++I) {
     BasicBlock *BB = *I;