[SCEV] Reapply 'Exploit A < B => (A+K) < (B+K) when possible'
authorSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 25 Sep 2015 23:53:45 +0000 (23:53 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 25 Sep 2015 23:53:45 +0000 (23:53 +0000)
Summary:

This change teaches SCEV's `isImpliedCond` two new identities:

  A u< B u< -C          =>  (A + C) u< (B + C)
  A s< B s< INT_MIN - C =>  (A + C) s< (B + C)

While these are useful on their own, they're really intended to support
D12950.

The original checkin, r248606 had to be backed out due to an issue with
a ObjCXX unit test.  That issue is now fixed, so re-landing.

Reviewers: atrick, reames, majnemer, nlewycky, hfinkel

Subscribers: aadg, sanjoy, llvm-commits

Differential Revision: http://reviews.llvm.org/D12948

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@248637 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
test/Transforms/IndVarSimplify/eliminate-comparison.ll

index 72975d423a7d9f06eaa2eb681edd5931619a0959..62d66c246d717bedfa2bee84fc1b48b18c664209 100644 (file)
@@ -536,6 +536,17 @@ namespace llvm {
                                         const SCEV *FoundLHS,
                                         const SCEV *FoundRHS);
 
+    /// Test whether the condition described by Pred, LHS, and RHS is true
+    /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+    /// true.
+    ///
+    /// This routine tries to rule out certain kinds of integer overflow, and
+    /// then tries to reason about arithmetic properties of the predicates.
+    bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS, const SCEV *RHS,
+                                            const SCEV *FoundLHS,
+                                            const SCEV *FoundRHS);
+
     /// If we know that the specified Phi is in the header of its containing
     /// loop, we know the loop executes a constant number of times, and the PHI
     /// node is just a recurrence involving constants, fold it.
index c8cd4ec8ae111aebd4ed5ec4b9a5898053d0d900..f87a5596c86689d6e812c82bf7a6e2fa6fb5beb0 100644 (file)
@@ -7288,6 +7288,146 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   return false;
 }
 
+// Return true if More == (Less + C), where C is a constant.
+static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More,
+                        APInt &C) {
+  // We avoid subtracting expressions here because this function is usually
+  // fairly deep in the call stack (i.e. is called many times).
+
+  auto SplitBinaryAdd = [](const SCEV *Expr, const SCEV *&L, const SCEV *&R) {
+    const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
+    if (!AE || AE->getNumOperands() != 2)
+      return false;
+
+    L = AE->getOperand(0);
+    R = AE->getOperand(1);
+    return true;
+  };
+
+  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
+    const auto *LAR = cast<SCEVAddRecExpr>(Less);
+    const auto *MAR = cast<SCEVAddRecExpr>(More);
+
+    if (LAR->getLoop() != MAR->getLoop())
+      return false;
+
+    // We look at affine expressions only; not for correctness but to keep
+    // getStepRecurrence cheap.
+    if (!LAR->isAffine() || !MAR->isAffine())
+      return false;
+
+    if (LAR->getStepRecurrence(SE) != MAR->getStepRecurrence(SE))
+      return false;
+
+    Less = LAR->getStart();
+    More = MAR->getStart();
+
+    // fall through
+  }
+
+  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
+    const auto &M = cast<SCEVConstant>(More)->getValue()->getValue();
+    const auto &L = cast<SCEVConstant>(Less)->getValue()->getValue();
+    C = M - L;
+    return true;
+  }
+
+  const SCEV *L, *R;
+  if (SplitBinaryAdd(Less, L, R))
+    if (const auto *LC = dyn_cast<SCEVConstant>(L))
+      if (R == More) {
+        C = -(LC->getValue()->getValue());
+        return true;
+      }
+
+  if (SplitBinaryAdd(More, L, R))
+    if (const auto *LC = dyn_cast<SCEVConstant>(L))
+      if (R == Less) {
+        C = LC->getValue()->getValue();
+        return true;
+      }
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    const SCEV *FoundLHS, const SCEV *FoundRHS) {
+  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
+    return false;
+
+  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!AddRecLHS)
+    return false;
+
+  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
+  if (!AddRecFoundLHS)
+    return false;
+
+  // We'd like to let SCEV reason about control dependencies, so we constrain
+  // both the inequalities to be about add recurrences on the same loop.  This
+  // way we can use isLoopEntryGuardedByCond later.
+
+  const Loop *L = AddRecFoundLHS->getLoop();
+  if (L != AddRecLHS->getLoop())
+    return false;
+
+  //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1)
+  //
+  //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
+  //                                                                  ... (2)
+  //
+  // Informal proof for (2), assuming (1) [*]:
+  //
+  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
+  //
+  // Then
+  //
+  //       FoundLHS s< FoundRHS s< INT_MIN - C
+  // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
+  // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
+  // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s<
+  //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
+  // <=>  FoundLHS + C s< FoundRHS + C
+  //
+  // [*]: (1) can be proved by ruling out overflow.
+  //
+  // [**]: This can be proved by analyzing all the four possibilities:
+  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
+  //    (A s>= 0, B s>= 0).
+  //
+  // Note:
+  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
+  // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS
+  // = (i8 -127) and C = (i8 -100).  Then INT_MIN - C = (i8 -28), and FoundRHS
+  // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is
+  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
+  // C)".
+
+  APInt LDiff, RDiff;
+  if (!IsConstDiff(*this, FoundLHS, LHS, LDiff) ||
+      !IsConstDiff(*this, FoundRHS, RHS, RDiff) ||
+      LDiff != RDiff)
+    return false;
+
+  if (LDiff == 0)
+    return true;
+
+  unsigned Width = cast<IntegerType>(RHS->getType())->getBitWidth();
+  APInt FoundRHSLimit;
+
+  if (Pred == CmpInst::ICMP_ULT) {
+    FoundRHSLimit = -RDiff;
+  } else {
+    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
+    FoundRHSLimit = APInt::getSignedMinValue(Width) - RDiff;
+  }
+
+  // Try to prove (1) or (2), as needed.
+  return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+                                  getConstant(FoundRHSLimit));
+}
+
 /// isImpliedCondOperands - Test whether the condition described by Pred,
 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
 /// and FoundRHS is true.
@@ -7298,6 +7438,9 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
+  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                      FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y
index 4d14b3681c5d89eee3ea107d635647d613607807..261ba387028bb1dc2ff4a3fea2d2a5c86d35061e 100644 (file)
@@ -209,3 +209,152 @@ assert77:                                         ; preds = %noassert68
 unrolledend:                                      ; preds = %forcond38
   ret i32 0
 }
+
+declare void @side_effect()
+
+define void @func_13(i32* %len.ptr) {
+; CHECK-LABEL: @func_13(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.sub.1 = add i32 %len, -1
+  %len.is.zero = icmp eq i32 %len, 0
+  br i1 %len.is.zero, label %leave, label %loop
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.cmp = icmp ult i32 %iv, %len
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp ult i32 %iv, %len.sub.1
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_14(i32* %len.ptr) {
+; CHECK-LABEL: @func_14(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.sub.1 = add i32 %len, -1
+  %len.is.zero = icmp eq i32 %len, 0
+  %len.is.int_min = icmp eq i32 %len, 2147483648
+  %no.entry = or i1 %len.is.zero, %len.is.int_min
+  br i1 %no.entry, label %leave, label %loop
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.cmp = icmp slt i32 %iv, %len
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp slt i32 %iv, %len.sub.1
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_15(i32* %len.ptr) {
+; CHECK-LABEL: @func_15(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.add.1 = add i32 %len, 1
+  %len.add.1.is.zero = icmp eq i32 %len.add.1, 0
+  br i1 %len.add.1.is.zero, label %leave, label %loop
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.cmp = icmp ult i32 %iv, %len.add.1
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp ult i32 %iv, %len
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_16(i32* %len.ptr) {
+; CHECK-LABEL: @func_16(
+ entry:
+  %len = load i32, i32* %len.ptr, !range !0
+  %len.add.5 = add i32 %len, 5
+  %entry.cond.0 = icmp slt i32 %len, 2147483643
+  %entry.cond.1 = icmp slt i32 4, %len.add.5
+  %entry.cond = and i1 %entry.cond.0, %entry.cond.1
+  br i1 %entry.cond, label %loop, label %leave
+
+ loop:
+; CHECK: loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.add.4 = add i32 %iv, 4
+  %iv.cmp = icmp slt i32 %iv.add.4, %len.add.5
+  br i1 %iv.cmp, label %be, label %leave
+; CHECK: br i1 true, label %be, label %leave
+
+ be:
+  call void @side_effect()
+  %be.cond = icmp slt i32 %iv, %len
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+define void @func_17(i32* %len.ptr) {
+; CHECK-LABEL: @func_17(
+ entry:
+  %len = load i32, i32* %len.ptr
+  %len.add.5 = add i32 %len, -5
+  %entry.cond.0 = icmp slt i32 %len, 2147483653 ;; 2147483653 == INT_MIN - (-5)
+  %entry.cond.1 = icmp slt i32 -6, %len.add.5
+  %entry.cond = and i1 %entry.cond.0, %entry.cond.1
+  br i1 %entry.cond, label %loop, label %leave
+
+ loop:
+; CHECK: loop:
+  %iv.2 = phi i32 [ 0, %entry ], [ %iv.2.inc, %be ]
+  %iv = phi i32 [ -6, %entry ], [ %iv.inc, %be ]
+  call void @side_effect()
+  %iv.inc = add i32 %iv, 1
+  %iv.2.inc = add i32 %iv.2, 1
+  %iv.cmp = icmp slt i32 %iv, %len.add.5
+
+; Deduces {-5,+,1} s< (-5 + %len) from {0,+,1} < %len
+; since %len s< INT_MIN - (-5) from the entry condition
+
+; CHECK: br i1 true, label %be, label %leave
+  br i1 %iv.cmp, label %be, label %leave
+
+ be:
+; CHECK: be:
+  call void @side_effect()
+  %be.cond = icmp slt i32 %iv.2, %len
+  br i1 %be.cond, label %loop, label %leave
+
+ leave:
+  ret void
+}
+
+!0 = !{i32 0, i32 2147483647}
+