[IndVars] Make loop varying predicates loop invariant.
authorSanjoy Das <sanjoy@playingwithpointers.com>
Mon, 27 Jul 2015 21:42:49 +0000 (21:42 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Mon, 27 Jul 2015 21:42:49 +0000 (21:42 +0000)
Summary:
Was D9784: "Remove loop variant range check when induction variable is
strictly increasing"

This change re-implements D9784 with the two differences:

 1. It does not use SCEVExpander and does not generate new
    instructions.  Instead, it does a quick local search for existing
    `llvm::Value`s that it needs when modifying the `icmp`
    instruction.

 2. It is more general -- it deals with both increasing and decreasing
    induction variables.

I've added all of the tests included with D9784, and two more.

As an example on what this change does (copied from D9784):

Given C code:

```
for (int i = M; i < N; i++) // i is known not to overflow
  if (i < 0) break;
  a[i] = 0;
}
```

This transformation produces:

```
for (int i = M; i < N; i++)
  if (M < 0) break;
  a[i] = 0;
}
```

Which can be unswitched into:

```
if (!(M < 0))
  for (int i = M; i < N; i++)
    a[i] = 0;
}
```

I went back and forth on whether the top level logic should live in
`SimplifyIndvar::eliminateIVComparison` or be put into its own
routine.  Right now I've put it under `eliminateIVComparison` because
even though the `icmp` is not *eliminated*, it no longer is an IV
comparison.  I'm open to putting it in its own helper routine if you
think that is better.

Reviewers: reames, nicholas, atrick

Subscribers: llvm-commits

Differential Revision: http://reviews.llvm.org/D11278

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@243331 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
lib/Transforms/Utils/SimplifyIndVar.cpp
test/Transforms/IndVarSimplify/loop-invariant-conditions.ll [new file with mode: 0644]

index a4ad5573d804530c5349cff1efb738e33501020c..0f3f397d50bc083777a0e3d7cdccb9c58510fb9b 100644 (file)
@@ -48,6 +48,7 @@ namespace llvm {
   class LoopInfo;
   class Operator;
   class SCEVUnknown;
+  class SCEVAddRecExpr;
   class SCEV;
   template<> struct FoldingSetTrait<SCEV>;
 
@@ -579,6 +580,21 @@ namespace llvm {
     bool proveNoWrapByVaryingStart(const SCEV *Start, const SCEV *Step,
                                    const Loop *L);
 
+    bool isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
+                                  ICmpInst::Predicate Pred, bool &Increasing);
+
+    /// Return true if, for all loop invariant X, the predicate "LHS `Pred` X"
+    /// is monotonically increasing or decreasing.  In the former case set
+    /// `Increasing` to true and in the latter case set `Increasing` to false.
+    ///
+    /// A predicate is said to be monotonically increasing if may go from being
+    /// false to being true as the loop iterates, but never the other way
+    /// around.  A predicate is said to be monotonically decreasing if may go
+    /// from being true to being false as the loop iterates, but never the other
+    /// way around.
+    bool isMonotonicPredicate(const SCEVAddRecExpr *LHS,
+                              ICmpInst::Predicate Pred, bool &Increasing);
+
   public:
     static char ID; // Pass identification, replacement for typeid
     ScalarEvolution();
@@ -900,6 +916,16 @@ namespace llvm {
     bool isKnownPredicate(ICmpInst::Predicate Pred,
                           const SCEV *LHS, const SCEV *RHS);
 
+    /// Return true if the result of the predicate LHS `Pred` RHS is loop
+    /// invariant with respect to L.  Set InvariantPred, InvariantLHS and
+    /// InvariantLHS so that InvariantLHS `InvariantPred` InvariantRHS is the
+    /// loop invariant form of LHS `Pred` RHS.
+    bool isLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS,
+                                  const SCEV *RHS, const Loop *L,
+                                  ICmpInst::Predicate &InvariantPred,
+                                  const SCEV *&InvariantLHS,
+                                  const SCEV *&InvariantRHS);
+
     /// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
     /// predicate Pred. Return true iff any changes were made. If the
     /// operands are provably equal or unequal, LHS and RHS are set to
index 85555a0ca50f3762f875d66de6d678e88431b1e0..8e65c1a97f0d6365aa5d9d7d43abccf5ef66c6e2 100644 (file)
@@ -6616,6 +6616,139 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
   return isKnownPredicateWithRanges(Pred, LHS, RHS);
 }
 
+bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
+                                           ICmpInst::Predicate Pred,
+                                           bool &Increasing) {
+  bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing);
+
+#ifndef NDEBUG
+  // Verify an invariant: inverting the predicate should turn a monotonically
+  // increasing change to a monotonically decreasing one, and vice versa.
+  bool IncreasingSwapped;
+  bool ResultSwapped = isMonotonicPredicateImpl(
+      LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped);
+
+  assert(Result == ResultSwapped && "should be able to analyze both!");
+  if (ResultSwapped)
+    assert(Increasing == !IncreasingSwapped &&
+           "monotonicity should flip as we flip the predicate");
+#endif
+
+  return Result;
+}
+
+bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
+                                               ICmpInst::Predicate Pred,
+                                               bool &Increasing) {
+  SCEV::NoWrapFlags FlagsRequired = SCEV::FlagAnyWrap;
+  bool IncreasingOnNonNegativeStep = false;
+
+  switch (Pred) {
+  default:
+    return false; // Conservative answer
+
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    FlagsRequired = SCEV::FlagNUW;
+    IncreasingOnNonNegativeStep = true;
+    break;
+
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    FlagsRequired = SCEV::FlagNUW;
+    IncreasingOnNonNegativeStep = false;
+    break;
+
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
+    FlagsRequired = SCEV::FlagNSW;
+    IncreasingOnNonNegativeStep = true;
+    break;
+
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE:
+    FlagsRequired = SCEV::FlagNSW;
+    IncreasingOnNonNegativeStep = false;
+    break;
+  }
+
+  if (!LHS->getNoWrapFlags(FlagsRequired))
+    return false;
+
+  // A zero step value for LHS means the induction variable is essentially a
+  // loop invariant value. We don't really depend on the predicate actually
+  // flipping from false to true (for increasing predicates, and the other way
+  // around for decreasing predicates), all we care about is that *if* the
+  // predicate changes then it only changes from false to true.
+  //
+  // A zero step value in itself is not very useful, but there may be places
+  // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
+  // as general as possible.
+
+  if (isKnownNonNegative(LHS->getStepRecurrence(*this))) {
+    Increasing = IncreasingOnNonNegativeStep;
+    return true;
+  }
+
+  if (isKnownNonPositive(LHS->getStepRecurrence(*this))) {
+    Increasing = !IncreasingOnNonNegativeStep;
+    return true;
+  }
+
+  return false;
+}
+
+bool ScalarEvolution::isLoopInvariantPredicate(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+    ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
+    const SCEV *&InvariantRHS) {
+
+  // If there is a loop-invariant, force it into the RHS, otherwise bail out.
+  if (!isLoopInvariant(RHS, L)) {
+    if (!isLoopInvariant(LHS, L))
+      return false;
+
+    std::swap(LHS, RHS);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!ArLHS || ArLHS->getLoop() != L)
+    return false;
+
+  bool Increasing;
+  if (!isMonotonicPredicate(ArLHS, Pred, Increasing))
+    return false;
+
+  // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
+  // true as the loop iterates, and the backedge is control dependent on
+  // "ArLHS `Pred` RHS" == true then we can reason as follows:
+  //
+  //   * if the predicate was false in the first iteration then the predicate
+  //     is never evaluated again, since the loop exits without taking the
+  //     backedge.
+  //   * if the predicate was true in the first iteration then it will
+  //     continue to be true for all future iterations since it is
+  //     monotonically increasing.
+  //
+  // For both the above possibilities, we can replace the loop varying
+  // predicate with its value on the first iteration of the loop (which is
+  // loop invariant).
+  //
+  // A similar reasoning applies for a monotonically decreasing predicate, by
+  // replacing true with false and false with true in the above two bullets.
+
+  auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
+
+  if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
+    return false;
+
+  InvariantPred = Pred;
+  InvariantLHS = ArLHS->getStart();
+  InvariantRHS = RHS;
+  return true;
+}
+
 bool
 ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS) {
index ab30aa17c76bc912a56849517cd680de70cf27e8..b3055a492be8a55d0e8a1f0a998c867a4b1537c0 100644 (file)
@@ -166,19 +166,68 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) {
   S = SE->getSCEVAtScope(S, ICmpLoop);
   X = SE->getSCEVAtScope(X, ICmpLoop);
 
+  ICmpInst::Predicate InvariantPredicate;
+  const SCEV *InvariantLHS, *InvariantRHS;
+
+  const char *Verb = nullptr;
+
   // If the condition is always true or always false, replace it with
   // a constant value.
-  if (SE->isKnownPredicate(Pred, S, X))
+  if (SE->isKnownPredicate(Pred, S, X)) {
     ICmp->replaceAllUsesWith(ConstantInt::getTrue(ICmp->getContext()));
-  else if (SE->isKnownPredicate(ICmpInst::getInversePredicate(Pred), S, X))
+    DeadInsts.emplace_back(ICmp);
+    Verb = "Eliminated";
+  } else if (SE->isKnownPredicate(ICmpInst::getInversePredicate(Pred), S, X)) {
     ICmp->replaceAllUsesWith(ConstantInt::getFalse(ICmp->getContext()));
-  else
+    DeadInsts.emplace_back(ICmp);
+    Verb = "Eliminated";
+  } else if (isa<PHINode>(IVOperand) &&
+             SE->isLoopInvariantPredicate(Pred, S, X, ICmpLoop,
+                                          InvariantPredicate, InvariantLHS,
+                                          InvariantRHS)) {
+
+    // Rewrite the comparision to a loop invariant comparision if it can be done
+    // cheaply, where cheaply means "we don't need to emit any new
+    // instructions".
+
+    Value *NewLHS = nullptr, *NewRHS = nullptr;
+
+    if (S == InvariantLHS || X == InvariantLHS)
+      NewLHS =
+          ICmp->getOperand(S == InvariantLHS ? IVOperIdx : (1 - IVOperIdx));
+
+    if (S == InvariantRHS || X == InvariantRHS)
+      NewRHS =
+          ICmp->getOperand(S == InvariantRHS ? IVOperIdx : (1 - IVOperIdx));
+
+    for (Value *Incoming : cast<PHINode>(IVOperand)->incoming_values()) {
+      if (NewLHS && NewRHS)
+        break;
+
+      const SCEV *IncomingS = SE->getSCEV(Incoming);
+
+      if (!NewLHS && IncomingS == InvariantLHS)
+        NewLHS = Incoming;
+      if (!NewRHS && IncomingS == InvariantRHS)
+        NewRHS = Incoming;
+    }
+
+    if (!NewLHS || !NewRHS)
+      // We could not find an existing value to replace either LHS or RHS.
+      // Generating new instructions has subtler tradeoffs, so avoid doing that
+      // for now.
+      return;
+
+    Verb = "Simplified";
+    ICmp->setPredicate(InvariantPredicate);
+    ICmp->setOperand(0, NewLHS);
+    ICmp->setOperand(1, NewRHS);
+  } else
     return;
 
-  DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n');
+  DEBUG(dbgs() << "INDVARS: " << Verb << " comparison: " << *ICmp << '\n');
   ++NumElimCmp;
   Changed = true;
-  DeadInsts.emplace_back(ICmp);
 }
 
 /// SimplifyIVUsers helper for eliminating useless
diff --git a/test/Transforms/IndVarSimplify/loop-invariant-conditions.ll b/test/Transforms/IndVarSimplify/loop-invariant-conditions.ll
new file mode 100644 (file)
index 0000000..eee321d
--- /dev/null
@@ -0,0 +1,279 @@
+; RUN: opt -S -indvars %s | FileCheck %s
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define void @test1(i64 %start) {
+; CHECK-LABEL: @test1
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %loop ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+; CHECK: %cmp1 = icmp slt i64 %start, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test2(i64 %start) {
+; CHECK-LABEL: @test2
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %loop ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+; CHECK: %cmp1 = icmp sle i64 %start, -1
+  %cmp1 = icmp sle i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+; As long as the test dominates the backedge, we're good
+define void @test3(i64 %start) {
+; CHECK-LABEL: @test3
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+  br i1 %cmp, label %backedge, label %for.end
+
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+; CHECK: %cmp1 = icmp slt i64 %start, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test4(i64 %start) {
+; CHECK-LABEL: @test4
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+  br i1 %cmp, label %backedge, label %for.end
+
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+; CHECK: %cmp1 = icmp sgt i64 %start, -1
+  %cmp1 = icmp sgt i64 %indvars.iv, -1
+  br i1 %cmp1, label %loop, label %for.end
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test5(i64 %start) {
+; CHECK-LABEL: @test5
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nuw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+  br i1 %cmp, label %backedge, label %for.end
+
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+; CHECK: %cmp1 = icmp ugt i64 %start, 100
+  %cmp1 = icmp ugt i64 %indvars.iv, 100
+  br i1 %cmp1, label %loop, label %for.end
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test6(i64 %start) {
+; CHECK-LABEL: @test6
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nuw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+  br i1 %cmp, label %backedge, label %for.end
+
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+; CHECK: %cmp1 = icmp ult i64 %start, 100
+  %cmp1 = icmp ult i64 %indvars.iv, 100
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test7(i64 %start, i64* %inc_ptr) {
+; CHECK-LABEL: @test7
+entry:
+  %inc = load i64, i64* %inc_ptr, !range !0
+  %ok = icmp sge i64 %inc, 0
+  br i1 %ok, label %loop, label %for.end
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %loop ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, %inc
+; CHECK: %cmp1 = icmp slt i64 %start, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+!0 = !{i64 0, i64 100}
+
+; Negative test - we can't show that the internal branch executes, so we can't
+; fold the test to a loop invariant one.
+define void @test1_neg(i64 %start) {
+; CHECK-LABEL: @test1_neg
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+  br i1 %cmp, label %backedge, label %skip
+skip:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+; CHECK: %cmp1 = icmp slt i64 %indvars.iv, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %backedge
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+  br label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+; Slightly subtle version of @test4 where the icmp dominates the backedge,
+; but the exit branch doesn't.  
+define void @test2_neg(i64 %start) {
+; CHECK-LABEL: @test2_neg
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+; CHECK: %cmp1 = icmp slt i64 %indvars.iv, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp, label %backedge, label %skip
+skip:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+  br i1 %cmp1, label %for.end, label %backedge
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+  br label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+; The branch has to exit the loop if the condition is true
+define void @test3_neg(i64 %start) {
+; CHECK-LABEL: @test3_neg
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %loop ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+; CHECK: %cmp1 = icmp slt i64 %indvars.iv, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %loop, label %for.end
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test4_neg(i64 %start) {
+; CHECK-LABEL: @test4_neg
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %backedge ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, 25
+  br i1 %cmp, label %backedge, label %for.end
+
+backedge:
+  ; prevent flattening, needed to make sure we're testing what we intend
+  call void @foo() 
+; CHECK: %cmp1 = icmp sgt i64 %indvars.iv, -1
+  %cmp1 = icmp sgt i64 %indvars.iv, -1
+
+; %cmp1 can be made loop invariant only if the branch below goes to
+; %the header when %cmp1 is true.
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test5_neg(i64 %start, i64 %inc) {
+; CHECK-LABEL: @test5_neg
+entry:
+  br label %loop
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %loop ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, %inc
+; CHECK: %cmp1 = icmp slt i64 %indvars.iv, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+define void @test8(i64 %start, i64* %inc_ptr) {
+; CHECK-LABEL: @test8
+entry:
+  %inc = load i64, i64* %inc_ptr, !range !1
+  %ok = icmp sge i64 %inc, 0
+  br i1 %ok, label %loop, label %for.end
+
+loop:
+  %indvars.iv = phi i64 [ %start, %entry ], [ %indvars.iv.next, %loop ]
+  %indvars.iv.next = add nsw i64 %indvars.iv, %inc
+; CHECK: %cmp1 = icmp slt i64 %indvars.iv, -1
+  %cmp1 = icmp slt i64 %indvars.iv, -1
+  br i1 %cmp1, label %for.end, label %loop
+
+for.end:                                          ; preds = %if.end, %entry
+  ret void
+}
+
+!1 = !{i64 -1, i64 100}
+
+
+declare void @foo()