Sink the collection of return instructions until after *all*
[oota-llvm.git] / lib / Transforms / Utils / Local.cpp
index bbbfcba4709c0db51a852aea6e38f40e80783928..d1c4d5968231f6046de76523de61e6de950fc7c0 100644 (file)
@@ -106,33 +106,32 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions) {
     // If we are switching on a constant, we can convert the switch into a
     // single branch instruction!
     ConstantInt *CI = dyn_cast<ConstantInt>(SI->getCondition());
-    BasicBlock *TheOnlyDest = SI->getSuccessor(0);  // The default dest
+    BasicBlock *TheOnlyDest = SI->getDefaultDest();
     BasicBlock *DefaultDest = TheOnlyDest;
-    assert(TheOnlyDest == SI->getDefaultDest() &&
-           "Default destination is not successor #0?");
 
     // Figure out which case it goes to.
-    for (unsigned i = 1, e = SI->getNumSuccessors(); i != e; ++i) {
+    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
+         i != e; ++i) {
       // Found case matching a constant operand?
-      if (SI->getSuccessorValue(i) == CI) {
-        TheOnlyDest = SI->getSuccessor(i);
+      if (i.getCaseValue() == CI) {
+        TheOnlyDest = i.getCaseSuccessor();
         break;
       }
 
       // Check to see if this branch is going to the same place as the default
       // dest.  If so, eliminate it as an explicit compare.
-      if (SI->getSuccessor(i) == DefaultDest) {
+      if (i.getCaseSuccessor() == DefaultDest) {
         // Remove this entry.
         DefaultDest->removePredecessor(SI->getParent());
         SI->removeCase(i);
-        --i; --e;  // Don't skip an entry...
+        --i; --e;
         continue;
       }
 
       // Otherwise, check to see if the switch only branches to one destination.
       // We do this by reseting "TheOnlyDest" to null when we find two non-equal
       // destinations.
-      if (SI->getSuccessor(i) != TheOnlyDest) TheOnlyDest = 0;
+      if (i.getCaseSuccessor() != TheOnlyDest) TheOnlyDest = 0;
     }
 
     if (CI && !TheOnlyDest) {
@@ -166,14 +165,16 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions) {
       return true;
     }
     
-    if (SI->getNumSuccessors() == 2) {
+    if (SI->getNumCases() == 1) {
       // Otherwise, we can fold this switch into a conditional branch
       // instruction if it has only one non-default destination.
+      SwitchInst::CaseIt FirstCase = SI->case_begin();
       Value *Cond = Builder.CreateICmpEQ(SI->getCondition(),
-                                         SI->getSuccessorValue(1), "cond");
+          FirstCase.getCaseValue(), "cond");
 
       // Insert the new branch.
-      Builder.CreateCondBr(Cond, SI->getSuccessor(1), SI->getSuccessor(0));
+      Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(),
+                           SI->getDefaultDest());
 
       // Delete the old switch.
       SI->eraseFromParent();
@@ -354,22 +355,27 @@ bool llvm::RecursivelyDeleteDeadPHINode(PHINode *PN) {
 /// instructions in other blocks as well in this block.
 bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, const TargetData *TD) {
   bool MadeChange = false;
-  for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) {
+
+#ifndef NDEBUG
+  // In debug builds, ensure that the terminator of the block is never replaced
+  // or deleted by these simplifications. The idea of simplification is that it
+  // cannot introduce new instructions, and there is no way to replace the
+  // terminator of a block without introducing a new instruction.
+  AssertingVH<Instruction> TerminatorVH(--BB->end());
+#endif
+
+  for (BasicBlock::iterator BI = BB->begin(), E = --BB->end(); BI != E; ) {
+    assert(!BI->isTerminator());
     Instruction *Inst = BI++;
-    
-    if (Value *V = SimplifyInstruction(Inst, TD)) {
-      WeakVH BIHandle(BI);
-      ReplaceAndSimplifyAllUses(Inst, V, TD);
+
+    WeakVH BIHandle(BI);
+    if (recursivelySimplifyInstruction(Inst, TD)) {
       MadeChange = true;
       if (BIHandle != BI)
         BI = BB->begin();
       continue;
     }
 
-    if (Inst->isTerminator())
-      break;
-
-    WeakVH BIHandle(BI);
     MadeChange |= RecursivelyDeleteTriviallyDeadInstructions(Inst);
     if (BIHandle != BI)
       BI = BB->begin();
@@ -407,17 +413,11 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred,
   WeakVH PhiIt = &BB->front();
   while (PHINode *PN = dyn_cast<PHINode>(PhiIt)) {
     PhiIt = &*++BasicBlock::iterator(cast<Instruction>(PhiIt));
+    Value *OldPhiIt = PhiIt;
 
-    Value *PNV = SimplifyInstruction(PN, TD);
-    if (PNV == 0) continue;
+    if (!recursivelySimplifyInstruction(PN, TD))
+      continue;
 
-    // If we're able to simplify the phi to a single value, substitute the new
-    // value into all of its uses.
-    assert(PNV != PN && "SimplifyInstruction broken!");
-    
-    Value *OldPhiIt = PhiIt;
-    ReplaceAndSimplifyAllUses(PN, PNV, TD);
-    
     // If recursive simplification ended up deleting the next PHI node we would
     // iterate to, then our iterator is invalid, restart scanning from the top
     // of the block.
@@ -494,22 +494,8 @@ static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) {
   if (Succ->getSinglePredecessor()) return true;
 
   // Make a list of the predecessors of BB
-  typedef SmallPtrSet<BasicBlock*, 16> BlockSet;
-  BlockSet BBPreds(pred_begin(BB), pred_end(BB));
-
-  // Use that list to make another list of common predecessors of BB and Succ
-  BlockSet CommonPreds;
-  for (pred_iterator PI = pred_begin(Succ), PE = pred_end(Succ);
-       PI != PE; ++PI) {
-    BasicBlock *P = *PI;
-    if (BBPreds.count(P))
-      CommonPreds.insert(P);
-  }
+  SmallPtrSet<BasicBlock*, 16> BBPreds(pred_begin(BB), pred_end(BB));
 
-  // Shortcut, if there are no common predecessors, merging is always safe
-  if (CommonPreds.empty())
-    return true;
-  
   // Look at all the phi nodes in Succ, to see if they present a conflict when
   // merging these blocks
   for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
@@ -520,28 +506,28 @@ static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) {
     // merge the phi nodes and then the blocks can still be merged
     PHINode *BBPN = dyn_cast<PHINode>(PN->getIncomingValueForBlock(BB));
     if (BBPN && BBPN->getParent() == BB) {
-      for (BlockSet::iterator PI = CommonPreds.begin(), PE = CommonPreds.end();
-            PI != PE; PI++) {
-        if (BBPN->getIncomingValueForBlock(*PI) 
-              != PN->getIncomingValueForBlock(*PI)) {
+      for (unsigned PI = 0, PE = PN->getNumIncomingValues(); PI != PE; ++PI) {
+        BasicBlock *IBB = PN->getIncomingBlock(PI);
+        if (BBPreds.count(IBB) &&
+            BBPN->getIncomingValueForBlock(IBB) != PN->getIncomingValue(PI)) {
           DEBUG(dbgs() << "Can't fold, phi node " << PN->getName() << " in " 
                 << Succ->getName() << " is conflicting with " 
                 << BBPN->getName() << " with regard to common predecessor "
-                << (*PI)->getName() << "\n");
+                << IBB->getName() << "\n");
           return false;
         }
       }
     } else {
       Value* Val = PN->getIncomingValueForBlock(BB);
-      for (BlockSet::iterator PI = CommonPreds.begin(), PE = CommonPreds.end();
-            PI != PE; PI++) {
+      for (unsigned PI = 0, PE = PN->getNumIncomingValues(); PI != PE; ++PI) {
         // See if the incoming value for the common predecessor is equal to the
         // one for BB, in which case this phi node will not prevent the merging
         // of the block.
-        if (Val != PN->getIncomingValueForBlock(*PI)) {
+        BasicBlock *IBB = PN->getIncomingBlock(PI);
+        if (BBPreds.count(IBB) && Val != PN->getIncomingValue(PI)) {
           DEBUG(dbgs() << "Can't fold, phi node " << PN->getName() << " in " 
                 << Succ->getName() << " is conflicting with regard to common "
-                << "predecessor " << (*PI)->getName() << "\n");
+                << "predecessor " << IBB->getName() << "\n");
           return false;
         }
       }
@@ -776,9 +762,8 @@ unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign,
   assert(V->getType()->isPointerTy() &&
          "getOrEnforceKnownAlignment expects a pointer!");
   unsigned BitWidth = TD ? TD->getPointerSizeInBits() : 64;
-  APInt Mask = APInt::getAllOnesValue(BitWidth);
   APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-  ComputeMaskedBits(V, Mask, KnownZero, KnownOne, TD);
+  ComputeMaskedBits(V, KnownZero, KnownOne, TD);
   unsigned TrailZ = KnownZero.countTrailingOnes();
   
   // Avoid trouble with rediculously large TrailZ values, such as