Change errs() to dbgs().
[oota-llvm.git] / lib / Transforms / Utils / LoopUnroll.cpp
index 9b68df05b844fa59afba0a518f6ca3d5042efe59..53117a01a3dc512fc1e472df3d04f3189e44e0b9 100644 (file)
@@ -49,6 +49,52 @@ static inline void RemapInstruction(Instruction *I,
   }
 }
 
+/// FoldBlockIntoPredecessor - Folds a basic block into its predecessor if it
+/// only has one predecessor, and that predecessor only has one successor.
+/// The LoopInfo Analysis that is passed will be kept consistent.
+/// Returns the new combined block.
+static BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI) {
+  // Merge basic blocks into their predecessor if there is only one distinct
+  // pred, and if there is only one distinct successor of the predecessor, and
+  // if there are no PHI nodes.
+  BasicBlock *OnlyPred = BB->getSinglePredecessor();
+  if (!OnlyPred) return 0;
+
+  if (OnlyPred->getTerminator()->getNumSuccessors() != 1)
+    return 0;
+
+  DEBUG(dbgs() << "Merging: " << *BB << "into: " << *OnlyPred);
+
+  // Resolve any PHI nodes at the start of the block.  They are all
+  // guaranteed to have exactly one entry if they exist, unless there are
+  // multiple duplicate (but guaranteed to be equal) entries for the
+  // incoming edges.  This occurs when there are multiple edges from
+  // OnlyPred to OnlySucc.
+  FoldSingleEntryPHINodes(BB);
+
+  // Delete the unconditional branch from the predecessor...
+  OnlyPred->getInstList().pop_back();
+
+  // Move all definitions in the successor to the predecessor...
+  OnlyPred->getInstList().splice(OnlyPred->end(), BB->getInstList());
+
+  // Make all PHI nodes that referred to BB now refer to Pred as their
+  // source...
+  BB->replaceAllUsesWith(OnlyPred);
+
+  std::string OldName = BB->getName();
+
+  // Erase basic block from the function...
+  LI->removeBlock(BB);
+  BB->eraseFromParent();
+
+  // Inherit predecessor's name if it exists...
+  if (!OldName.empty() && !OnlyPred->hasName())
+    OnlyPred->setName(OldName);
+
+  return OnlyPred;
+}
+
 /// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true
 /// if unrolling was succesful, or false if the loop was unmodified. Unrolling
 /// can only fail when the loop's latch block is not terminated by a conditional
@@ -62,13 +108,24 @@ static inline void RemapInstruction(Instruction *I,
 bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM) {
   assert(L->isLCSSAForm());
 
-  BasicBlock *Header = L->getHeader();
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    DEBUG(dbgs() << "  Can't unroll; loop preheader-insertion failed.\n");
+    return false;
+  }
+
   BasicBlock *LatchBlock = L->getLoopLatch();
+  if (!LatchBlock) {
+    DEBUG(dbgs() << "  Can't unroll; loop exit-block-insertion failed.\n");
+    return false;
+  }
+
+  BasicBlock *Header = L->getHeader();
   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
   
   if (!BI || BI->isUnconditional()) {
     // The loop-rotate pass can be helpful to avoid this in many cases.
-    DEBUG(errs() <<
+    DEBUG(dbgs() <<
              "  Can't unroll; loop not terminated by a conditional branch.\n");
     return false;
   }
@@ -81,9 +138,9 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
     TripMultiple = L->getSmallConstantTripMultiple();
 
   if (TripCount != 0)
-    DEBUG(errs() << "  Trip Count = " << TripCount << "\n");
+    DEBUG(dbgs() << "  Trip Count = " << TripCount << "\n");
   if (TripMultiple != 1)
-    DEBUG(errs() << "  Trip Multiple = " << TripMultiple << "\n");
+    DEBUG(dbgs() << "  Trip Multiple = " << TripMultiple << "\n");
 
   // Effectively "DCE" unrolled iterations that are beyond the tripcount
   // and will never be executed.
@@ -109,17 +166,17 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
   }
 
   if (CompletelyUnroll) {
-    DEBUG(errs() << "COMPLETELY UNROLLING loop %" << Header->getName()
+    DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName()
           << " with trip count " << TripCount << "!\n");
   } else {
-    DEBUG(errs() << "UNROLLING loop %" << Header->getName()
+    DEBUG(dbgs() << "UNROLLING loop %" << Header->getName()
           << " by " << Count);
     if (TripMultiple == 0 || BreakoutTrip != TripMultiple) {
-      DEBUG(errs() << " with a breakout at trip " << BreakoutTrip);
+      DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip);
     } else if (TripMultiple != 1) {
-      DEBUG(errs() << " with " << TripMultiple << " trips per branch");
+      DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
     }
-    DEBUG(errs() << "!\n");
+    DEBUG(dbgs() << "!\n");
   }
 
   std::vector<BasicBlock*> LoopBlocks = L->getBlocks();
@@ -137,7 +194,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
     OrigPHINode.push_back(PN);
     if (Instruction *I = 
                 dyn_cast<Instruction>(PN->getIncomingValueForBlock(LatchBlock)))
-      if (L->contains(I->getParent()))
+      if (L->contains(I))
         LastValueMap[I] = I;
   }
 
@@ -165,7 +222,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
           PHINode *NewPHI = cast<PHINode>(ValueMap[OrigPHINode[i]]);
           Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock);
           if (Instruction *InValI = dyn_cast<Instruction>(InVal))
-            if (It > 1 && L->contains(InValI->getParent()))
+            if (It > 1 && L->contains(InValI))
               InVal = LastValueMap[InValI];
           ValueMap[OrigPHINode[i]] = InVal;
           New->getInstList().erase(NewPHI);
@@ -187,7 +244,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
              UI != UE;) {
           Instruction *UseInst = cast<Instruction>(*UI);
           ++UI;
-          if (isa<PHINode>(UseInst) && !L->contains(UseInst->getParent())) {
+          if (isa<PHINode>(UseInst) && !L->contains(UseInst)) {
             PHINode *phi = cast<PHINode>(UseInst);
             Value *Incoming = phi->getIncomingValueForBlock(*BB);
             phi->addIncoming(Incoming, New);
@@ -238,7 +295,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
       // If this value was defined in the loop, take the value defined by the
       // last iteration of the loop.
       if (Instruction *InValI = dyn_cast<Instruction>(InVal)) {
-        if (L->contains(InValI->getParent()))
+        if (L->contains(InValI))
           InVal = LastValueMap[InVal];
       }
       PN->addIncoming(InVal, LastIterationBB);
@@ -287,7 +344,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
     } else {
       Term->setUnconditionalDest(Dest);
       // Merge adjacent basic blocks, if possible.
-      if (BasicBlock *Fold = MergeBlockIntoPredecessor(Dest, LI)) {
+      if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest, LI)) {
         std::replace(Latches.begin(), Latches.end(), Dest, Fold);
         std::replace(Headers.begin(), Headers.end(), Dest, Fold);
       }
@@ -305,8 +362,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM)
 
       if (isInstructionTriviallyDead(Inst))
         (*BB)->getInstList().erase(Inst);
-      else if (Constant *C = ConstantFoldInstruction(Inst, 
-                                                     Header->getContext())) {
+      else if (Constant *C = ConstantFoldInstruction(Inst)) {
         Inst->replaceAllUsesWith(C);
         (*BB)->getInstList().erase(Inst);
       }