When checking whether the special handling for an addrec increment which
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index f540bcbdcd1bbbc03b8c0b55b6a5b767abb1fc55..0ef5d84be5abe572f653a60665e577675c83e6cc 100644 (file)
@@ -1355,9 +1355,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
         }
         LargeOps.push_back(T->getOperand());
       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
-        // This could be either sign or zero extension, but sign extension
-        // is much more likely to be foldable here.
-        LargeOps.push_back(getSignExtendExpr(C, SrcType));
+        LargeOps.push_back(getAnyExtendExpr(C, SrcType));
       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
         SmallVector<const SCEV *, 8> LargeMulOps;
         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
@@ -1370,9 +1368,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
             LargeMulOps.push_back(T->getOperand());
           } else if (const SCEVConstant *C =
                        dyn_cast<SCEVConstant>(M->getOperand(j))) {
-            // This could be either sign or zero extension, but sign extension
-            // is much more likely to be foldable here.
-            LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
+            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
           } else {
             Ok = false;
             break;
@@ -3420,10 +3416,22 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         // fall through
       case ICmpInst::ICMP_SGT:
       case ICmpInst::ICMP_SGE:
-        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
-          return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
-        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
-          return getSMinExpr(getSCEV(LHS), getSCEV(RHS));
+        // a >s b ? a+x : b+x  ->  smax(a, b)+x
+        // a >s b ? b+x : a+x  ->  smin(a, b)+x
+        if (LHS->getType() == U->getType()) {
+          const SCEV *LS = getSCEV(LHS);
+          const SCEV *RS = getSCEV(RHS);
+          const SCEV *LA = getSCEV(U->getOperand(1));
+          const SCEV *RA = getSCEV(U->getOperand(2));
+          const SCEV *LDiff = getMinusSCEV(LA, LS);
+          const SCEV *RDiff = getMinusSCEV(RA, RS);
+          if (LDiff == RDiff)
+            return getAddExpr(getSMaxExpr(LS, RS), LDiff);
+          LDiff = getMinusSCEV(LA, RS);
+          RDiff = getMinusSCEV(RA, LS);
+          if (LDiff == RDiff)
+            return getAddExpr(getSMinExpr(LS, RS), LDiff);
+        }
         break;
       case ICmpInst::ICMP_ULT:
       case ICmpInst::ICMP_ULE:
@@ -3431,28 +3439,52 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         // fall through
       case ICmpInst::ICMP_UGT:
       case ICmpInst::ICMP_UGE:
-        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
-          return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
-        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
-          return getUMinExpr(getSCEV(LHS), getSCEV(RHS));
+        // a >u b ? a+x : b+x  ->  umax(a, b)+x
+        // a >u b ? b+x : a+x  ->  umin(a, b)+x
+        if (LHS->getType() == U->getType()) {
+          const SCEV *LS = getSCEV(LHS);
+          const SCEV *RS = getSCEV(RHS);
+          const SCEV *LA = getSCEV(U->getOperand(1));
+          const SCEV *RA = getSCEV(U->getOperand(2));
+          const SCEV *LDiff = getMinusSCEV(LA, LS);
+          const SCEV *RDiff = getMinusSCEV(RA, RS);
+          if (LDiff == RDiff)
+            return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+          LDiff = getMinusSCEV(LA, RS);
+          RDiff = getMinusSCEV(RA, LS);
+          if (LDiff == RDiff)
+            return getAddExpr(getUMinExpr(LS, RS), LDiff);
+        }
         break;
       case ICmpInst::ICMP_NE:
-        // n != 0 ? n : 1  ->  umax(n, 1)
-        if (LHS == U->getOperand(1) &&
-            isa<ConstantInt>(U->getOperand(2)) &&
-            cast<ConstantInt>(U->getOperand(2))->isOne() &&
+        // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
+        if (LHS->getType() == U->getType() &&
             isa<ConstantInt>(RHS) &&
-            cast<ConstantInt>(RHS)->isZero())
-          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2)));
+            cast<ConstantInt>(RHS)->isZero()) {
+          const SCEV *One = getConstant(LHS->getType(), 1);
+          const SCEV *LS = getSCEV(LHS);
+          const SCEV *LA = getSCEV(U->getOperand(1));
+          const SCEV *RA = getSCEV(U->getOperand(2));
+          const SCEV *LDiff = getMinusSCEV(LA, LS);
+          const SCEV *RDiff = getMinusSCEV(RA, One);
+          if (LDiff == RDiff)
+            return getAddExpr(getUMaxExpr(LS, One), LDiff);
+        }
         break;
       case ICmpInst::ICMP_EQ:
-        // n == 0 ? 1 : n  ->  umax(n, 1)
-        if (LHS == U->getOperand(2) &&
-            isa<ConstantInt>(U->getOperand(1)) &&
-            cast<ConstantInt>(U->getOperand(1))->isOne() &&
+        // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
+        if (LHS->getType() == U->getType() &&
             isa<ConstantInt>(RHS) &&
-            cast<ConstantInt>(RHS)->isZero())
-          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1)));
+            cast<ConstantInt>(RHS)->isZero()) {
+          const SCEV *One = getConstant(LHS->getType(), 1);
+          const SCEV *LS = getSCEV(LHS);
+          const SCEV *LA = getSCEV(U->getOperand(1));
+          const SCEV *RA = getSCEV(U->getOperand(2));
+          const SCEV *LDiff = getMinusSCEV(LA, One);
+          const SCEV *RDiff = getMinusSCEV(RA, LS);
+          if (LDiff == RDiff)
+            return getAddExpr(getUMaxExpr(LS, One), LDiff);
+        }
         break;
       default:
         break;
@@ -3889,6 +3921,57 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
       }
 
+  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
+  // adding or subtracting 1 from one of the operands.
+  switch (Cond) {
+  case ICmpInst::ICMP_SLE:
+    if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
+      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
+                       /*HasNUW=*/false, /*HasNSW=*/true);
+      Cond = ICmpInst::ICMP_SLT;
+    } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
+      LHS = getAddExpr(getConstant(RHS->getType(), -1, true), LHS,
+                       /*HasNUW=*/false, /*HasNSW=*/true);
+      Cond = ICmpInst::ICMP_SLT;
+    }
+    break;
+  case ICmpInst::ICMP_SGE:
+    if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
+      RHS = getAddExpr(getConstant(RHS->getType(), -1, true), RHS,
+                       /*HasNUW=*/false, /*HasNSW=*/true);
+      Cond = ICmpInst::ICMP_SGT;
+    } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
+      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
+                       /*HasNUW=*/false, /*HasNSW=*/true);
+      Cond = ICmpInst::ICMP_SGT;
+    }
+    break;
+  case ICmpInst::ICMP_ULE:
+    if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
+      RHS = getAddExpr(getConstant(RHS->getType(), 1, false), RHS,
+                       /*HasNUW=*/true, /*HasNSW=*/false);
+      Cond = ICmpInst::ICMP_ULT;
+    } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
+      LHS = getAddExpr(getConstant(RHS->getType(), -1, false), LHS,
+                       /*HasNUW=*/true, /*HasNSW=*/false);
+      Cond = ICmpInst::ICMP_ULT;
+    }
+    break;
+  case ICmpInst::ICMP_UGE:
+    if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
+      RHS = getAddExpr(getConstant(RHS->getType(), -1, false), RHS,
+                       /*HasNUW=*/true, /*HasNSW=*/false);
+      Cond = ICmpInst::ICMP_UGT;
+    } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
+      LHS = getAddExpr(getConstant(RHS->getType(), 1, false), LHS,
+                       /*HasNUW=*/true, /*HasNSW=*/false);
+      Cond = ICmpInst::ICMP_UGT;
+    }
+    break;
+  default:
+    break;
+  }
+
   switch (Cond) {
   case ICmpInst::ICMP_NE: {                     // while (X != Y)
     // Convert to: while (X-Y != 0)
@@ -4723,6 +4806,204 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
   return false;
 }
 
+/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
+/// predicate Pred. Return true iff any changes were made.
+///
+bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
+                                           const SCEV *&LHS, const SCEV *&RHS) {
+  bool Changed = false;
+
+  // Canonicalize a constant to the right side.
+  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
+    // Check for both operands constant.
+    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+      if (ConstantExpr::getICmp(Pred,
+                                LHSC->getValue(),
+                                RHSC->getValue())->isNullValue())
+        goto trivially_false;
+      else
+        goto trivially_true;
+    }
+    // Otherwise swap the operands to put the constant on the right.
+    std::swap(LHS, RHS);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+    Changed = true;
+  }
+
+  // If we're comparing an addrec with a value which is loop-invariant in the
+  // addrec's loop, put the addrec on the left.
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
+    if (LHS->isLoopInvariant(AR->getLoop())) {
+      std::swap(LHS, RHS);
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+      Changed = true;
+    }
+
+  // If there's a constant operand, canonicalize comparisons with boundary
+  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
+  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
+    const APInt &RA = RC->getValue()->getValue();
+    switch (Pred) {
+    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+    case ICmpInst::ICMP_EQ:
+    case ICmpInst::ICMP_NE:
+      break;
+    case ICmpInst::ICMP_UGE:
+      if ((RA - 1).isMinValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA - 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMaxValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        Changed = true;
+        break;
+      }
+      if (RA.isMinValue()) goto trivially_true;
+
+      Pred = ICmpInst::ICMP_UGT;
+      RHS = getConstant(RA - 1);
+      Changed = true;
+      break;
+    case ICmpInst::ICMP_ULE:
+      if ((RA + 1).isMaxValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA + 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMinValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        Changed = true;
+        break;
+      }
+      if (RA.isMaxValue()) goto trivially_true;
+
+      Pred = ICmpInst::ICMP_ULT;
+      RHS = getConstant(RA + 1);
+      Changed = true;
+      break;
+    case ICmpInst::ICMP_SGE:
+      if ((RA - 1).isMinSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA - 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        Changed = true;
+        break;
+      }
+      if (RA.isMinSignedValue()) goto trivially_true;
+
+      Pred = ICmpInst::ICMP_SGT;
+      RHS = getConstant(RA - 1);
+      Changed = true;
+      break;
+    case ICmpInst::ICMP_SLE:
+      if ((RA + 1).isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA + 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMinSignedValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        Changed = true;
+        break;
+      }
+      if (RA.isMaxSignedValue()) goto trivially_true;
+
+      Pred = ICmpInst::ICMP_SLT;
+      RHS = getConstant(RA + 1);
+      Changed = true;
+      break;
+    case ICmpInst::ICMP_UGT:
+      if (RA.isMinValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        Changed = true;
+        break;
+      }
+      if ((RA + 1).isMaxValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        RHS = getConstant(RA + 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMaxValue()) goto trivially_false;
+      break;
+    case ICmpInst::ICMP_ULT:
+      if (RA.isMaxValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        Changed = true;
+        break;
+      }
+      if ((RA - 1).isMinValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        RHS = getConstant(RA - 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMinValue()) goto trivially_false;
+      break;
+    case ICmpInst::ICMP_SGT:
+      if (RA.isMinSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        Changed = true;
+        break;
+      }
+      if ((RA + 1).isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        RHS = getConstant(RA + 1);
+        Changed = true;
+        break;
+      }
+      if (RA.isMaxSignedValue()) goto trivially_false;
+      break;
+    case ICmpInst::ICMP_SLT:
+      if (RA.isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        Changed = true;
+        break;
+      }
+      if ((RA - 1).isMinSignedValue()) {
+       Pred = ICmpInst::ICMP_EQ;
+       RHS = getConstant(RA - 1);
+        Changed = true;
+       break;
+      }
+      if (RA.isMinSignedValue()) goto trivially_false;
+      break;
+    }
+  }
+
+  // Check for obvious equality.
+  if (HasSameValue(LHS, RHS)) {
+    if (ICmpInst::isTrueWhenEqual(Pred))
+      goto trivially_true;
+    if (ICmpInst::isFalseWhenEqual(Pred))
+      goto trivially_false;
+  }
+
+  // TODO: More simplifications are possible here.
+
+  return Changed;
+
+trivially_true:
+  // Return 0 == 0.
+  LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0);
+  Pred = ICmpInst::ICMP_EQ;
+  return true;
+
+trivially_false:
+  // Return 0 != 0.
+  LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0);
+  Pred = ICmpInst::ICMP_NE;
+  return true;
+}
+
 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
   return getSignedRange(S).getSignedMax().isNegative();
 }
@@ -4745,6 +5026,9 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
 
 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
+  // Canonicalize the inputs first.
+  (void)SimplifyICmpOperands(Pred, LHS, RHS);
+
   // If LHS or RHS is an addrec, check to see if the condition is true in
   // every iteration of the loop.
   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
@@ -4960,117 +5244,12 @@ bool ScalarEvolution::isImpliedCond(Value *CondValue,
 
   // Canonicalize the query to match the way instcombine will have
   // canonicalized the comparison.
-  // First, put a constant operand on the right.
-  if (isa<SCEVConstant>(LHS)) {
-    std::swap(LHS, RHS);
-    Pred = ICmpInst::getSwappedPredicate(Pred);
-  }
-  // Then, canonicalize comparisons with boundary cases.
-  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
-    const APInt &RA = RC->getValue()->getValue();
-    switch (Pred) {
-    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
-    case ICmpInst::ICMP_EQ:
-    case ICmpInst::ICMP_NE:
-      break;
-    case ICmpInst::ICMP_UGE:
-      if ((RA - 1).isMinValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        RHS = getConstant(RA - 1);
-        break;
-      }
-      if (RA.isMaxValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        break;
-      }
-      if (RA.isMinValue()) return true;
-      break;
-    case ICmpInst::ICMP_ULE:
-      if ((RA + 1).isMaxValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        RHS = getConstant(RA + 1);
-        break;
-      }
-      if (RA.isMinValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        break;
-      }
-      if (RA.isMaxValue()) return true;
-      break;
-    case ICmpInst::ICMP_SGE:
-      if ((RA - 1).isMinSignedValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        RHS = getConstant(RA - 1);
-        break;
-      }
-      if (RA.isMaxSignedValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        break;
-      }
-      if (RA.isMinSignedValue()) return true;
-      break;
-    case ICmpInst::ICMP_SLE:
-      if ((RA + 1).isMaxSignedValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        RHS = getConstant(RA + 1);
-        break;
-      }
-      if (RA.isMinSignedValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        break;
-      }
-      if (RA.isMaxSignedValue()) return true;
-      break;
-    case ICmpInst::ICMP_UGT:
-      if (RA.isMinValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        break;
-      }
-      if ((RA + 1).isMaxValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        RHS = getConstant(RA + 1);
-        break;
-      }
-      if (RA.isMaxValue()) return false;
-      break;
-    case ICmpInst::ICMP_ULT:
-      if (RA.isMaxValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        break;
-      }
-      if ((RA - 1).isMinValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        RHS = getConstant(RA - 1);
-        break;
-      }
-      if (RA.isMinValue()) return false;
-      break;
-    case ICmpInst::ICMP_SGT:
-      if (RA.isMinSignedValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        break;
-      }
-      if ((RA + 1).isMaxSignedValue()) {
-        Pred = ICmpInst::ICMP_EQ;
-        RHS = getConstant(RA + 1);
-        break;
-      }
-      if (RA.isMaxSignedValue()) return false;
-      break;
-    case ICmpInst::ICMP_SLT:
-      if (RA.isMaxSignedValue()) {
-        Pred = ICmpInst::ICMP_NE;
-        break;
-      }
-      if ((RA - 1).isMinSignedValue()) {
-       Pred = ICmpInst::ICMP_EQ;
-       RHS = getConstant(RA - 1);
-       break;
-      }
-      if (RA.isMinSignedValue()) return false;
-      break;
-    }
-  }
+  if (SimplifyICmpOperands(Pred, LHS, RHS))
+    if (LHS == RHS)
+      return Pred == ICmpInst::ICMP_EQ;
+  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
+    if (FoundLHS == FoundRHS)
+      return Pred == ICmpInst::ICMP_NE;
 
   // Check to see if we can make the LHS or RHS match.
   if (LHS == FoundRHS || RHS == FoundLHS) {