Improve LSR's dead-phi detection to handle use-def cycles
[oota-llvm.git] / lib / Transforms / Scalar / LoopStrengthReduce.cpp
index c65b83e9a56a5f9c77906f4c7075eff753c5f6f8..d825ea789b596a7862387ed18501df5f01227952 100644 (file)
@@ -31,6 +31,7 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Target/TargetData.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Support/Debug.h"
@@ -136,7 +137,7 @@ namespace {
 
     /// DeadInsts - Keep track of instructions we may have made dead, so that
     /// we can remove them after we are done working.
-    SmallPtrSet<Instruction*,16> DeadInsts;
+    SetVector<Instruction*> DeadInsts;
 
     /// TLI - Keep a pointer of a TargetLowering to consult for determining
     /// transformation profitability.
@@ -192,7 +193,7 @@ private:
     void StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
                                       IVUsersOfOneStride &Uses,
                                       Loop *L, bool isOnlyStride);
-    void DeleteTriviallyDeadInstructions(SmallPtrSet<Instruction*,16> &Insts);
+    void DeleteTriviallyDeadInstructions(SetVector<Instruction*> &Insts);
   };
 }
 
@@ -226,10 +227,10 @@ Value *LoopStrengthReduce::getCastedVersionOf(Instruction::CastOps opcode,
 /// specified set are trivially dead, delete them and see if this makes any of
 /// their operands subsequently dead.
 void LoopStrengthReduce::
-DeleteTriviallyDeadInstructions(SmallPtrSet<Instruction*,16> &Insts) {
+DeleteTriviallyDeadInstructions(SetVector<Instruction*> &Insts) {
   while (!Insts.empty()) {
-    Instruction *I = *Insts.begin();
-    Insts.erase(I);
+    Instruction *I = Insts.back();
+    Insts.pop_back();
 
     if (PHINode *PN = dyn_cast<PHINode>(I)) {
       // If all incoming values to the Phi are the same, we can replace the Phi
@@ -246,8 +247,8 @@ DeleteTriviallyDeadInstructions(SmallPtrSet<Instruction*,16> &Insts) {
     }
 
     if (isInstructionTriviallyDead(I)) {
-      for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
-        if (Instruction *U = dyn_cast<Instruction>(I->getOperand(i)))
+      for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i)
+        if (Instruction *U = dyn_cast<Instruction>(*i))
           Insts.insert(U);
       SE->deleteValueFromRecords(I);
       I->eraseFromParent();
@@ -289,24 +290,25 @@ SCEVHandle LoopStrengthReduce::GetExpressionSCEV(Instruction *Exp) {
 
   gep_type_iterator GTI = gep_type_begin(GEP);
   
-  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
+  for (User::op_iterator i = GEP->op_begin() + 1, e = GEP->op_end();
+       i != e; ++i, ++GTI) {
     // If this is a use of a recurrence that we can analyze, and it comes before
     // Op does in the GEP operand list, we will handle this when we process this
     // operand.
     if (const StructType *STy = dyn_cast<StructType>(*GTI)) {
       const StructLayout *SL = TD->getStructLayout(STy);
-      unsigned Idx = cast<ConstantInt>(GEP->getOperand(i))->getZExtValue();
+      unsigned Idx = cast<ConstantInt>(*i)->getZExtValue();
       uint64_t Offset = SL->getElementOffset(Idx);
       GEPVal = SE->getAddExpr(GEPVal,
                              SE->getIntegerSCEV(Offset, UIntPtrTy));
     } else {
       unsigned GEPOpiBits = 
-        GEP->getOperand(i)->getType()->getPrimitiveSizeInBits();
+        (*i)->getType()->getPrimitiveSizeInBits();
       unsigned IntPtrBits = UIntPtrTy->getPrimitiveSizeInBits();
       Instruction::CastOps opcode = (GEPOpiBits < IntPtrBits ? 
           Instruction::SExt : (GEPOpiBits > IntPtrBits ? Instruction::Trunc :
             Instruction::BitCast));
-      Value *OpVal = getCastedVersionOf(opcode, GEP->getOperand(i));
+      Value *OpVal = getCastedVersionOf(opcode, *i);
       SCEVHandle Idx = SE->getSCEV(OpVal);
 
       uint64_t TypeSize = TD->getABITypeSize(GTI.getIndexedType());
@@ -377,7 +379,7 @@ static bool getSCEVStartAndStride(const SCEVHandle &SH, Loop *L,
 /// should use the post-inc value).
 static bool IVUseShouldUsePostIncValue(Instruction *User, Instruction *IV,
                                        Loop *L, DominatorTree *DT, Pass *P,
-                                       SmallPtrSet<Instruction*,16> &DeadInsts){
+                                       SetVector<Instruction*> &DeadInsts){
   // If the user is in the loop, use the preinc value.
   if (L->contains(User->getParent())) return false;
   
@@ -545,7 +547,7 @@ namespace {
     void RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
                                         Instruction *InsertPt,
                                        SCEVExpander &Rewriter, Loop *L, Pass *P,
-                                       SmallPtrSet<Instruction*,16> &DeadInsts);
+                                       SetVector<Instruction*> &DeadInsts);
     
     Value *InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, 
                                        SCEVExpander &Rewriter,
@@ -584,9 +586,8 @@ Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase,
   }
   
   // If there is no immediate value, skip the next part.
-  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Imm))
-    if (SC->getValue()->isZero())
-      return Rewriter.expandCodeFor(NewBase, BaseInsertPt);
+  if (Imm->isZero())
+    return Rewriter.expandCodeFor(NewBase, BaseInsertPt);
 
   Value *Base = Rewriter.expandCodeFor(NewBase, BaseInsertPt);
 
@@ -611,7 +612,7 @@ Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase,
 void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
                                                Instruction *NewBasePt,
                                       SCEVExpander &Rewriter, Loop *L, Pass *P,
-                                      SmallPtrSet<Instruction*,16> &DeadInsts) {
+                                      SetVector<Instruction*> &DeadInsts) {
   if (!isa<PHINode>(Inst)) {
     // By default, insert code at the user instruction.
     BasicBlock::iterator InsertPt = Inst;
@@ -628,7 +629,8 @@ void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
       if (NewBasePt && isa<PHINode>(OperandValToReplace)) {
         InsertPt = NewBasePt;
         ++InsertPt;
-      } else if (Instruction *OpInst = dyn_cast<Instruction>(OperandValToReplace)) { 
+      } else if (Instruction *OpInst
+                 = dyn_cast<Instruction>(OperandValToReplace)) {
         InsertPt = OpInst;
         while (isa<PHINode>(InsertPt)) ++InsertPt;
       }
@@ -888,8 +890,7 @@ static void SeparateSubExprs(std::vector<SCEVHandle> &SubExprs,
 
       SeparateSubExprs(SubExprs, SARE->getOperand(0), SE);
     }
-  } else if (!isa<SCEVConstant>(Expr) ||
-             !cast<SCEVConstant>(Expr)->getValue()->isZero()) {
+  } else if (!Expr->isZero()) {
     // Do not add zero.
     SubExprs.push_back(Expr);
   }
@@ -976,14 +977,6 @@ RemoveCommonExpressionsFromUseBases(std::vector<BasedUser> &Uses,
   return Result;
 }
 
-/// isZero - returns true if the scalar evolution expression is zero.
-///
-static bool isZero(const SCEVHandle &V) {
-  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(V))
-    return SC->getValue()->isZero();
-  return false;
-}
-
 /// ValidStride - Check whether the given Scale is valid for all loads and 
 /// stores in UsersToProcess.
 ///
@@ -1006,7 +999,7 @@ bool LoopStrengthReduce::ValidStride(bool HasBaseReg,
     TargetLowering::AddrMode AM;
     if (SCEVConstant *SC = dyn_cast<SCEVConstant>(UsersToProcess[i].Imm))
       AM.BaseOffs = SC->getValue()->getSExtValue();
-    AM.HasBaseReg = HasBaseReg || !isZero(UsersToProcess[i].Base);
+    AM.HasBaseReg = HasBaseReg || !UsersToProcess[i].Base->isZero();
     AM.Scale = Scale;
 
     // If load[imm+r*scale] is illegal, bail out.
@@ -1066,7 +1059,7 @@ unsigned LoopStrengthReduce::CheckForIVReuse(bool HasBaseReg,
                IE = SI->second.IVs.end(); II != IE; ++II)
           // FIXME: Only handle base == 0 for now.
           // Only reuse previous IV if it would not require a type conversion.
-          if (isZero(II->Base) &&
+          if (II->Base->isZero() &&
               !RequiresTypeConversion(II->Base->getType(), Ty)) {
             IV = *II;
             return Scale;
@@ -1119,11 +1112,6 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) {
         if (II->getOperand(1) == OperandVal)
           isAddress = true;
         break;
-      case Intrinsic::x86_sse2_loadh_pd:
-      case Intrinsic::x86_sse2_loadl_pd:
-        if (II->getOperand(2) == OperandVal)
-          isAddress = true;
-        break;
     }
   }
   return isAddress;
@@ -1235,7 +1223,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
   // their value in a register and add it in for each use. This will take up
   // a register operand, which potentially restricts what stride values are
   // valid.
-  bool HaveCommonExprs = !isZero(CommonExprs);
+  bool HaveCommonExprs = !CommonExprs->isZero();
   
   // If all uses are addresses, check if it is possible to reuse an IV with a
   // stride that is a factor of this stride. And that the multiple is a number
@@ -1598,8 +1586,8 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
         ? UIntPtrTy->getPrimitiveSizeInBits()
         : NewCmpTy->getPrimitiveSizeInBits();
       if (RequiresTypeConversion(NewCmpTy, CmpTy)) {
-        // Check if it is possible to rewrite it using a iv / stride of a smaller
-        // integer type.
+        // Check if it is possible to rewrite it using
+        // an iv / stride of a smaller integer type.
         bool TruncOk = false;
         if (NewCmpTy->isInteger()) {
           unsigned Bits = NewTyBits;
@@ -1631,7 +1619,7 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
       // Avoid rewriting the compare instruction with an iv of new stride
       // if it's likely the new stride uses will be rewritten using the
       if (AllUsesAreAddresses &&
-          ValidStride(!isZero(CommonExprs), Scale, UsersToProcess)) {        
+          ValidStride(!CommonExprs->isZero(), Scale, UsersToProcess)) {
         NewCmpVal = CmpVal;
         continue;
       }
@@ -1646,6 +1634,18 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
     }
   }
 
+  // Forgo this transformation if it the increment happens to be
+  // unfortunately positioned after the condition, and the condition
+  // has multiple uses which prevent it from being moved immediately
+  // before the branch. See
+  // test/Transforms/LoopStrengthReduce/change-compare-stride-trickiness-*.ll
+  // for an example of this situation.
+  if (!Cond->hasOneUse())
+    for (BasicBlock::iterator I = Cond, E = Cond->getParent()->end();
+         I != E; ++I)
+      if (I == NewIncV)
+        return Cond;
+
   if (NewCmpVal != CmpVal) {
     // Create a new compare instruction using new stride / iv.
     ICmpInst *OldCond = Cond;
@@ -1657,9 +1657,9 @@ ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
       RHS = SCEVExpander::InsertCastOfTo(Instruction::IntToPtr, RHS, NewCmpTy);
     }
     // Insert new compare instruction.
-    Cond = new ICmpInst(Predicate, NewIncV, RHS);
-    Cond->setName(L->getHeader()->getName() + ".termcond");
-    OldCond->getParent()->getInstList().insert(OldCond, Cond);
+    Cond = new ICmpInst(Predicate, NewIncV, RHS,
+                        L->getHeader()->getName() + ".termcond",
+                        OldCond);
 
     // Remove the old compare instruction. The old indvar is probably dead too.
     DeadInsts.insert(cast<Instruction>(CondUse->OperandValToReplace));
@@ -1810,31 +1810,28 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) {
     DeleteTriviallyDeadInstructions(DeadInsts);
 
     BasicBlock::iterator I = L->getHeader()->begin();
-    PHINode *PN;
-    while ((PN = dyn_cast<PHINode>(I))) {
-      ++I;  // Preincrement iterator to avoid invalidating it when deleting PN.
-
-      // At this point, we know that we have killed one or more GEP
-      // instructions.  It is worth checking to see if the cann indvar is also
-      // dead, so that we can remove it as well.  The requirements for the cann
-      // indvar to be considered dead are:
-      // 1. the cann indvar has one use
-      // 2. the use is an add instruction
-      // 3. the add has one use
-      // 4. the add is used by the cann indvar
-      // If all four cases above are true, then we can remove both the add and
-      // the cann indvar.
+    while (PHINode *PN = dyn_cast<PHINode>(I++)) {
+      // At this point, we know that we have killed one or more IV users.
+      // It is worth checking to see if the cann indvar is also
+      // dead, so that we can remove it as well.
+      //
+      // We can remove a PHI if it is on a cycle in the def-use graph
+      // where each node in the cycle has degree one, i.e. only one use,
+      // and is an instruction with no side effects.
+      //
       // FIXME: this needs to eliminate an induction variable even if it's being
       // compared against some value to decide loop termination.
       if (PN->hasOneUse()) {
-        Instruction *BO = dyn_cast<Instruction>(*PN->use_begin());
-        if (BO && (isa<BinaryOperator>(BO) || isa<CmpInst>(BO))) {
-          if (BO->hasOneUse() && PN == *(BO->use_begin())) {
-            DeadInsts.insert(BO);
-            // Break the cycle, then delete the PHI.
+        for (Instruction *J = dyn_cast<Instruction>(*PN->use_begin());
+             J && J->hasOneUse() && !J->mayWriteToMemory();
+             J = dyn_cast<Instruction>(*J->use_begin())) {
+          // If we find the original PHI, we've discovered a cycle.
+          if (J == PN) {
+            // Break the cycle and mark the PHI for deletion.
             SE->deleteValueFromRecords(PN);
             PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
-            PN->eraseFromParent();
+            DeadInsts.insert(PN);
+            break;
           }
         }
       }