Narrow right shifts need to encode their immediates differently from a normal
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 84dca47ef8b5674b1c2310eba92b13f7091d9f91..62244ccb3a03a1222e5014c25a5b0ae0ba823b82 100644 (file)
@@ -157,6 +157,10 @@ void SCEV::print(raw_ostream &OS) const {
     for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
       OS << ",+," << *AR->getOperand(i);
     OS << "}<";
+    if (AR->hasNoUnsignedWrap())
+      OS << "nuw><";
+    if (AR->hasNoSignedWrap())
+      OS << "nsw><";
     WriteAsOperand(OS, AR->getLoop()->getHeader(), /*PrintType=*/false);
     OS << ">";
     return;
@@ -815,6 +819,36 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
 
+  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
+  // eliminate all the truncates.
+  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
+    SmallVector<const SCEV *, 4> Operands;
+    bool hasTrunc = false;
+    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
+      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
+      hasTrunc = isa<SCEVTruncateExpr>(S);
+      Operands.push_back(S);
+    }
+    if (!hasTrunc)
+      return getAddExpr(Operands, false, false);
+    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
+  }
+
+  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
+  // eliminate all the truncates.
+  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
+    SmallVector<const SCEV *, 4> Operands;
+    bool hasTrunc = false;
+    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
+      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
+      hasTrunc = isa<SCEVTruncateExpr>(S);
+      Operands.push_back(S);
+    }
+    if (!hasTrunc)
+      return getMulExpr(Operands, false, false);
+    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
+  }
+
   // If the input value is a chrec scev, truncate the chrec's operands.
   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
     SmallVector<const SCEV *, 4> Operands;
@@ -866,6 +900,19 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
 
+  // zext(trunc(x)) --> zext(x) or x or trunc(x)
+  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+    // It's possible the bits taken off by the truncate were all zero bits. If
+    // so, we should be able to simplify this further.
+    const SCEV *X = ST->getOperand();
+    ConstantRange CR = getUnsignedRange(X);
+    unsigned TruncBits = getTypeSizeInBits(ST->getType());
+    unsigned NewBits = getTypeSizeInBits(Ty);
+    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
+            CR.zextOrTrunc(NewBits)))
+      return getTruncateOrZeroExtend(X, Ty);
+  }
+
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can zero extend all of the
   // operands (often constants).  This allows analysis of something like
@@ -990,6 +1037,10 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
     return getSignExtendExpr(SS->getOperand(), Ty);
 
+  // sext(zext(x)) --> zext(x)
+  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
+    return getZeroExtendExpr(SZ->getOperand(), Ty);
+
   // Before doing any expensive analysis, check to see if we've already
   // computed a SCEV for this Op and Ty.
   FoldingSetNodeID ID;
@@ -999,6 +1050,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
 
+  // If the input value is provably positive, build a zext instead.
+  if (isKnownNonNegative(Op))
+    return getZeroExtendExpr(Op, Ty);
+
+  // sext(trunc(x)) --> sext(x) or x or trunc(x)
+  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+    // It's possible the bits taken off by the truncate were all sign bits. If
+    // so, we should be able to simplify this further.
+    const SCEV *X = ST->getOperand();
+    ConstantRange CR = getSignedRange(X);
+    unsigned TruncBits = getTypeSizeInBits(ST->getType());
+    unsigned NewBits = getTypeSizeInBits(Ty);
+    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
+            CR.sextOrTrunc(NewBits)))
+      return getTruncateOrSignExtend(X, Ty);
+  }
+
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can sign extend all of the
   // operands (often constants).  This allows analysis of something like
@@ -2442,24 +2510,24 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   return getMinusSCEV(AllOnes, V);
 }
 
-/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
-///
-const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
-                                          const SCEV *RHS) {
+/// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1,
+/// and thus the HasNUW and HasNSW bits apply to the resultant add, not
+/// whether the sub would have overflowed.
+const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
+                                          bool HasNUW, bool HasNSW) {
   // Fast path: X - X --> 0.
   if (LHS == RHS)
     return getConstant(LHS->getType(), 0);
 
   // X - Y --> X + -Y
-  return getAddExpr(LHS, getNegativeSCEV(RHS));
+  return getAddExpr(LHS, getNegativeSCEV(RHS), HasNUW, HasNSW);
 }
 
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type.  If the type must be extended, it is zero
 /// extended.
 const SCEV *
-ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
-                                         const Type *Ty) {
+ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
@@ -2715,6 +2783,23 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
                   HasNUW = true;
                 if (OBO->hasNoSignedWrap())
                   HasNSW = true;
+              } else if (const GEPOperator *GEP = 
+                            dyn_cast<GEPOperator>(BEValueV)) {
+                // If the increment is a GEP, then we know it won't perform a
+                // signed overflow, because the address space cannot be
+                // wrapped around.
+                //
+                // NOTE: This isn't strictly true, because you could have an
+                // object straddling the 2G address boundary in a 32-bit address
+                // space (for example).  We really want to model this as a "has
+                // no signed/unsigned wrap" where the base pointer is treated as
+                // unsigned and the increment is known to not have signed
+                // wrapping.
+                //
+                // This is a highly theoretical concern though, and this is good
+                // enough for all cases we know of at this point. :)
+                //                
+                HasNSW |= GEP->isInBounds();
               }
 
               const SCEV *StartVal = getSCEV(StartValueV);
@@ -2785,6 +2870,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
   // Add expression, because the Instruction may be guarded by control flow
   // and the no-overflow bits may not be valid for the expression in any
   // context.
+  bool isInBounds = GEP->isInBounds();
 
   const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
   Value *Base = GEP->getOperand(0);
@@ -2813,7 +2899,8 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
       IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
 
       // Multiply the index by the element size to compute the element offset.
-      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize);
+      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, /*NUW*/ false,
+                                           /*NSW*/ isInBounds);
 
       // Add the element offset to the running total offset.
       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
@@ -2824,7 +2911,8 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
   const SCEV *BaseS = getSCEV(Base);
 
   // Add the total offset from all the GEP indices to the base.
-  return getAddExpr(BaseS, TotalOffset);
+  return getAddExpr(BaseS, TotalOffset, /*NUW*/ false,
+                    /*NSW*/ isInBounds);
 }
 
 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
@@ -3054,6 +3142,7 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
 ///
 ConstantRange
 ScalarEvolution::getSignedRange(const SCEV *S) {
+  // See if we've computed this range already.
   DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
   if (I != SignedRanges.end())
     return I->second;
@@ -3367,8 +3456,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
               // If C is a single bit, it may be in the sign-bit position
               // before the zero-extend. In this case, represent the xor
               // using an add, which is equivalent, and re-apply the zext.
-              APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize);
-              if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
+              APInt Trunc = CI->getValue().trunc(Z0TySize);
+              if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
                   Trunc.isSignBit())
                 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
                                          UTy);
@@ -3608,60 +3697,61 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
   // backedge-taken count, which could result in infinite recursion.
   std::pair<std::map<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
     BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
-  if (Pair.second) {
-    BackedgeTakenInfo BECount = ComputeBackedgeTakenCount(L);
-    if (BECount.Exact != getCouldNotCompute()) {
-      assert(isLoopInvariant(BECount.Exact, L) &&
-             isLoopInvariant(BECount.Max, L) &&
-             "Computed backedge-taken count isn't loop invariant for loop!");
-      ++NumTripCountsComputed;
+  if (!Pair.second)
+    return Pair.first->second;
+
+  BackedgeTakenInfo BECount = ComputeBackedgeTakenCount(L);
+  if (BECount.Exact != getCouldNotCompute()) {
+    assert(isLoopInvariant(BECount.Exact, L) &&
+           isLoopInvariant(BECount.Max, L) &&
+           "Computed backedge-taken count isn't loop invariant for loop!");
+    ++NumTripCountsComputed;
 
+    // Update the value in the map.
+    Pair.first->second = BECount;
+  } else {
+    if (BECount.Max != getCouldNotCompute())
       // Update the value in the map.
       Pair.first->second = BECount;
-    } else {
-      if (BECount.Max != getCouldNotCompute())
-        // Update the value in the map.
-        Pair.first->second = BECount;
-      if (isa<PHINode>(L->getHeader()->begin()))
-        // Only count loops that have phi nodes as not being computable.
-        ++NumTripCountsNotComputed;
-    }
-
-    // Now that we know more about the trip count for this loop, forget any
-    // existing SCEV values for PHI nodes in this loop since they are only
-    // conservative estimates made without the benefit of trip count
-    // information. This is similar to the code in forgetLoop, except that
-    // it handles SCEVUnknown PHI nodes specially.
-    if (BECount.hasAnyInfo()) {
-      SmallVector<Instruction *, 16> Worklist;
-      PushLoopPHIs(L, Worklist);
-
-      SmallPtrSet<Instruction *, 8> Visited;
-      while (!Worklist.empty()) {
-        Instruction *I = Worklist.pop_back_val();
-        if (!Visited.insert(I)) continue;
-
-        ValueExprMapType::iterator It =
-          ValueExprMap.find(static_cast<Value *>(I));
-        if (It != ValueExprMap.end()) {
-          const SCEV *Old = It->second;
-
-          // SCEVUnknown for a PHI either means that it has an unrecognized
-          // structure, or it's a PHI that's in the progress of being computed
-          // by createNodeForPHI.  In the former case, additional loop trip
-          // count information isn't going to change anything. In the later
-          // case, createNodeForPHI will perform the necessary updates on its
-          // own when it gets to that point.
-          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
-            forgetMemoizedResults(Old);
-            ValueExprMap.erase(It);
-          }
-          if (PHINode *PN = dyn_cast<PHINode>(I))
-            ConstantEvolutionLoopExitValue.erase(PN);
+    if (isa<PHINode>(L->getHeader()->begin()))
+      // Only count loops that have phi nodes as not being computable.
+      ++NumTripCountsNotComputed;
+  }
+
+  // Now that we know more about the trip count for this loop, forget any
+  // existing SCEV values for PHI nodes in this loop since they are only
+  // conservative estimates made without the benefit of trip count
+  // information. This is similar to the code in forgetLoop, except that
+  // it handles SCEVUnknown PHI nodes specially.
+  if (BECount.hasAnyInfo()) {
+    SmallVector<Instruction *, 16> Worklist;
+    PushLoopPHIs(L, Worklist);
+
+    SmallPtrSet<Instruction *, 8> Visited;
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.pop_back_val();
+      if (!Visited.insert(I)) continue;
+
+      ValueExprMapType::iterator It =
+        ValueExprMap.find(static_cast<Value *>(I));
+      if (It != ValueExprMap.end()) {
+        const SCEV *Old = It->second;
+
+        // SCEVUnknown for a PHI either means that it has an unrecognized
+        // structure, or it's a PHI that's in the progress of being computed
+        // by createNodeForPHI.  In the former case, additional loop trip
+        // count information isn't going to change anything. In the later
+        // case, createNodeForPHI will perform the necessary updates on its
+        // own when it gets to that point.
+        if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
+          forgetMemoizedResults(Old);
+          ValueExprMap.erase(It);
         }
-
-        PushDefUseChildren(I, Worklist);
+        if (PHINode *PN = dyn_cast<PHINode>(I))
+          ConstantEvolutionLoopExitValue.erase(PN);
       }
+
+      PushDefUseChildren(I, Worklist);
     }
   }
   return Pair.first->second;
@@ -3932,6 +4022,105 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
   return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
 }
 
+static const SCEVAddRecExpr *
+isSimpleUnwrappingAddRec(const SCEV *S, const Loop *L) {
+  const SCEVAddRecExpr *SA = dyn_cast<SCEVAddRecExpr>(S);
+  
+  // The SCEV must be an addrec of this loop.
+  if (!SA || SA->getLoop() != L || !SA->isAffine())
+    return 0;
+  
+  // The SCEV must be known to not wrap in some way to be interesting.
+  if (!SA->hasNoUnsignedWrap() && !SA->hasNoSignedWrap())
+    return 0;
+
+  // The stride must be a constant so that we know if it is striding up or down.
+  if (!isa<SCEVConstant>(SA->getOperand(1)))
+    return 0;
+  return SA;
+}
+
+/// getMinusSCEVForExitTest - When considering an exit test for a loop with a
+/// "x != y" exit test, we turn this into a computation that evaluates x-y != 0,
+/// and this function returns the expression to use for x-y.  We know and take
+/// advantage of the fact that this subtraction is only being used in a
+/// comparison by zero context.
+///
+static const SCEV *getMinusSCEVForExitTest(const SCEV *LHS, const SCEV *RHS,
+                                           const Loop *L, ScalarEvolution &SE) {
+  // If either LHS or RHS is an AddRec SCEV (of this loop) that is known to not
+  // wrap (either NSW or NUW), then we know that the value will either become
+  // the other one (and thus the loop terminates), that the loop will terminate
+  // through some other exit condition first, or that the loop has undefined
+  // behavior.  This information is useful when the addrec has a stride that is
+  // != 1 or -1, because it means we can't "miss" the exit value.
+  //
+  // In any of these three cases, it is safe to turn the exit condition into a
+  // "counting down" AddRec (to zero) by subtracting the two inputs as normal,
+  // but since we know that the "end cannot be missed" we can force the
+  // resulting AddRec to be a NUW addrec.  Since it is counting down, this means
+  // that the AddRec *cannot* pass zero.
+
+  // See if LHS and RHS are addrec's we can handle.
+  const SCEVAddRecExpr *LHSA = isSimpleUnwrappingAddRec(LHS, L);
+  const SCEVAddRecExpr *RHSA = isSimpleUnwrappingAddRec(RHS, L);
+  
+  // If neither addrec is interesting, just return a minus.
+  if (RHSA == 0 && LHSA == 0)
+    return SE.getMinusSCEV(LHS, RHS);
+  
+  // If only one of LHS and RHS are an AddRec of this loop, make sure it is LHS.
+  if (RHSA && LHSA == 0) {
+    // Safe because a-b === b-a for comparisons against zero.
+    std::swap(LHS, RHS);
+    std::swap(LHSA, RHSA);
+  }
+  
+  // Handle the case when only one is advancing in a non-overflowing way.
+  if (RHSA == 0) {
+    // If RHS is loop varying, then we can't predict when LHS will cross it.
+    if (!SE.isLoopInvariant(RHS, L))
+      return SE.getMinusSCEV(LHS, RHS);
+    
+    // If LHS has a positive stride, then we compute RHS-LHS, because the loop
+    // is counting up until it crosses RHS (which must be larger than LHS).  If
+    // it is negative, we compute LHS-RHS because we're counting down to RHS.
+    const ConstantInt *Stride =
+      cast<SCEVConstant>(LHSA->getOperand(1))->getValue();
+    if (Stride->getValue().isNegative())
+      std::swap(LHS, RHS);
+
+    return SE.getMinusSCEV(RHS, LHS, true /*HasNUW*/);
+  }
+  
+  // If both LHS and RHS are interesting, we have something like:
+  //  a+i*4 != b+i*8.
+  const ConstantInt *LHSStride =
+    cast<SCEVConstant>(LHSA->getOperand(1))->getValue();
+  const ConstantInt *RHSStride =
+    cast<SCEVConstant>(RHSA->getOperand(1))->getValue();
+  
+  // If the strides are equal, then this is just a (complex) loop invariant
+  // comparison of a and b.
+  if (LHSStride == RHSStride)
+    return SE.getMinusSCEV(LHSA->getStart(), RHSA->getStart());
+  
+  // If the signs of the strides differ, then the negative stride is counting
+  // down to the positive stride.
+  if (LHSStride->getValue().isNegative() != RHSStride->getValue().isNegative()){
+    if (RHSStride->getValue().isNegative())
+      std::swap(LHS, RHS);
+  } else {
+    // If LHS's stride is smaller than RHS's stride, then "b" must be less than
+    // "a" and "b" is RHS is counting up (catching up) to LHS.  This is true
+    // whether the strides are positive or negative.
+    if (RHSStride->getValue().slt(LHSStride->getValue()))
+      std::swap(LHS, RHS);
+  }
+    
+  return SE.getMinusSCEV(LHS, RHS, true /*HasNUW*/);
+}
+
 /// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
 /// backedge of the specified loop will execute if its exit condition
 /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
@@ -3991,7 +4180,8 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
   switch (Cond) {
   case ICmpInst::ICMP_NE: {                     // while (X != Y)
     // Convert to: while (X-Y != 0)
-    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEV(LHS, RHS), L);
+    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEVForExitTest(LHS, RHS, L,
+                                                                 *this), L);
     if (BTI.hasAnyInfo()) return BTI;
     break;
   }
@@ -4602,7 +4792,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
   // bit width during computations.
   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
   APInt Mod(BW + 1, 0);
-  Mod.set(BW - Mult2);  // Mod = N / D
+  Mod.setBit(BW - Mult2);  // Mod = N / D
   APInt I = AD.multiplicativeInverse(Mod);
 
   // 4. Compute the minimum unsigned root of the equation:
@@ -4694,58 +4884,26 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
   if (!AddRec || AddRec->getLoop() != L)
     return getCouldNotCompute();
 
-  if (AddRec->isAffine()) {
-    // If this is an affine expression, the execution count of this branch is
-    // the minimum unsigned root of the following equation:
-    //
-    //     Start + Step*N = 0 (mod 2^BW)
-    //
-    // equivalent to:
-    //
-    //             Step*N = -Start (mod 2^BW)
-    //
-    // where BW is the common bit width of Start and Step.
-
-    // Get the initial value for the loop.
-    const SCEV *Start = getSCEVAtScope(AddRec->getStart(),
-                                       L->getParentLoop());
-    const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1),
-                                      L->getParentLoop());
-
-    if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
-      // For now we handle only constant steps.
-
-      // First, handle unitary steps.
-      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
-        return getNegativeSCEV(Start);          //   N = -Start (as unsigned)
-      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
-        return Start;                           //    N = Start (as unsigned)
-
-      // Then, try to solve the above equation provided that Start is constant.
-      if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
-        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
-                                            -StartC->getValue()->getValue(),
-                                            *this);
-    }
-  } else if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
-    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
-    // the quadratic equation to solve it.
-    std::pair<const SCEV *,const SCEV *> Roots = SolveQuadraticEquation(AddRec,
-                                                                    *this);
+  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
+  // the quadratic equation to solve it.
+  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
+    std::pair<const SCEV *,const SCEV *> Roots =
+      SolveQuadraticEquation(AddRec, *this);
     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
-    if (R1) {
+    if (R1 && R2) {
 #if 0
       dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
              << "  sol#2: " << *R2 << "\n";
 #endif
       // Pick the smallest positive root value.
       if (ConstantInt *CB =
-          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
-                                   R1->getValue(), R2->getValue()))) {
+          dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
+                                                      R1->getValue(),
+                                                      R2->getValue()))) {
         if (CB->getZExtValue() == false)
           std::swap(R1, R2);   // R1 is the minimum root now.
-
+        
         // We can only use this value if the chrec ends up with an exact zero
         // value at this index.  When solving for "X*X != 5", for example, we
         // should not accept a root of 2.
@@ -4754,8 +4912,54 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
           return R1;  // We found a quadratic root!
       }
     }
+    return getCouldNotCompute();
   }
 
+  // Otherwise we can only handle this if it is affine.
+  if (!AddRec->isAffine())
+    return getCouldNotCompute();
+
+  // If this is an affine expression, the execution count of this branch is
+  // the minimum unsigned root of the following equation:
+  //
+  //     Start + Step*N = 0 (mod 2^BW)
+  //
+  // equivalent to:
+  //
+  //             Step*N = -Start (mod 2^BW)
+  //
+  // where BW is the common bit width of Start and Step.
+
+  // Get the initial value for the loop.
+  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
+  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
+
+  // If the AddRec is NUW, then (in an unsigned sense) it cannot be counting up
+  // to wrap to 0, it must be counting down to equal 0.  Also, while counting
+  // down, it cannot "miss" 0 (which would cause it to wrap), regardless of what
+  // the stride is.  As such, NUW addrec's will always become zero in
+  // "start / -stride" steps, and we know that the division is exact.
+  if (AddRec->hasNoUnsignedWrap())
+    // FIXME: We really want an "isexact" bit for udiv.
+    return getUDivExpr(Start, getNegativeSCEV(Step));
+  
+  // For now we handle only constant steps.
+  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
+  if (StepC == 0)
+    return getCouldNotCompute();
+
+  // First, handle unitary steps.
+  if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
+    return getNegativeSCEV(Start);          //   N = -Start (as unsigned)
+  
+  if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
+    return Start;                           //    N = Start (as unsigned)
+
+  // Then, try to solve the above equation provided that Start is constant.
+  if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
+    return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
+                                        -StartC->getValue()->getValue(),
+                                        *this);
   return getCouldNotCompute();
 }