Teach ScalarEvolution to sharpen range information.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index cf45fe86359ea8c610cd9f6ed3c4a940ead4572d..8d1c632008a21be6d6a0d6cc3a3ec2b315bbb56e 100644 (file)
@@ -59,6 +59,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
@@ -79,6 +80,7 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -3662,6 +3664,31 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
   return 0;
 }
 
+/// GetRangeFromMetadata - Helper method to assign a range to V from
+/// metadata present in the IR.
+static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
+  if (Instruction *I = dyn_cast<Instruction>(V)) {
+    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
+      ConstantRange TotalRange(
+          cast<IntegerType>(I->getType())->getBitWidth(), false);
+
+      unsigned NumRanges = MD->getNumOperands() / 2;
+      assert(NumRanges >= 1);
+
+      for (unsigned i = 0; i < NumRanges; ++i) {
+        ConstantInt *Lower = cast<ConstantInt>(MD->getOperand(2*i + 0));
+        ConstantInt *Upper = cast<ConstantInt>(MD->getOperand(2*i + 1));
+        ConstantRange Range(Lower->getValue(), Upper->getValue());
+        TotalRange = TotalRange.unionWith(Range);
+      }
+
+      return TotalRange;
+    }
+  }
+
+  return None;
+}
+
 /// getUnsignedRange - Determine the unsigned range for a particular SCEV.
 ///
 ConstantRange
@@ -3791,6 +3818,11 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
   }
 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // Check if the IR explicitly contains !range metadata.
+    Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+    if (MDRange.hasValue())
+      ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
+
     // For a SCEVUnknown, ask ValueTracking.
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
     computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT);
@@ -3942,6 +3974,11 @@ ScalarEvolution::getSignedRange(const SCEV *S) {
   }
 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // Check if the IR explicitly contains !range metadata.
+    Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+    if (MDRange.hasValue())
+      ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
+
     // For a SCEVUnknown, ask ValueTracking.
     if (!U->getValue()->getType()->isIntegerTy() && !DL)
       return setSignedRange(U, ConservativeResult);
@@ -4330,6 +4367,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 //                   Iteration Count Computation Code
 //
 
+unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
+  if (BasicBlock *ExitingBB = L->getExitingBlock())
+    return getSmallConstantTripCount(L, ExitingBB);
+
+  // No trip count information for multiple exits.
+  return 0;
+}
+
 /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
 /// normal unsigned value. Returns 0 if the trip count is unknown or not
 /// constant. Will also return 0 if the maximum trip count is very large (>=
@@ -4342,6 +4387,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 /// prematurely via another branch.
 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
                                                     BasicBlock *ExitingBlock) {
+  assert(ExitingBlock && "Must pass a non-null exiting block!");
+  assert(L->isLoopExiting(ExitingBlock) &&
+         "Exiting block must actually branch out of the loop!");
   const SCEVConstant *ExitCount =
       dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
   if (!ExitCount)
@@ -4357,6 +4405,14 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
   return ((unsigned)ExitConst->getZExtValue()) + 1;
 }
 
+unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
+  if (BasicBlock *ExitingBB = L->getExitingBlock())
+    return getSmallConstantTripMultiple(L, ExitingBB);
+
+  // No trip multiple information for multiple exits.
+  return 0;
+}
+
 /// getSmallConstantTripMultiple - Returns the largest constant divisor of the
 /// trip count of this loop as a normal unsigned value, if possible. This
 /// means that the actual trip count is always a multiple of the returned
@@ -4372,6 +4428,9 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
 unsigned
 ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
                                               BasicBlock *ExitingBlock) {
+  assert(ExitingBlock && "Must pass a non-null exiting block!");
+  assert(L->isLoopExiting(ExitingBlock) &&
+         "Exiting block must actually branch out of the loop!");
   const SCEV *ExitCount = getExitCount(L, ExitingBlock);
   if (ExitCount == getCouldNotCompute())
     return 1;
@@ -6541,6 +6600,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return true;
 
+  if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+
   BasicBlock *Latch = L->getLoopLatch();
   if (!Latch)
     return false;
@@ -6576,6 +6637,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return false;
 
+  if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
   // leading to the original header.
@@ -6719,6 +6782,44 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
                                    RHS, LHS, FoundLHS, FoundRHS);
   }
 
+  if (FoundPred == ICmpInst::ICMP_NE) {
+    // If we are predicated on something with range [a, b) known not
+    // equal to a, the range can be sharpened to [a + 1, b).  Use this
+    // fact.
+    auto CheckRange = [this, Pred, LHS, RHS](const SCEV *C, const SCEV *V) {
+      if (!isa<SCEVConstant>(C)) return false;
+
+      ConstantInt *CI = cast<SCEVConstant>(C)->getValue();
+      APInt Min = ICmpInst::isSigned(Pred) ?
+        getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
+
+      if (Min!= CI->getValue()) return false; // nothing to sharpen
+      Min++;
+
+      // We know V >= Min, in the same signedness as in Pred.  We can
+      // use this to simplify slt and ult.  If the range had a single
+      // value, Min, we now know that FoundValue can never be true;
+      // and any answer is a correct answer.
+
+      switch (Pred) {
+        case ICmpInst::ICMP_SGT:
+        case ICmpInst::ICMP_UGT:
+          return isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min));
+
+        default:
+          llvm_unreachable("don't call with predicates other than ICMP_SGT "
+                           "and ICMP_UGT");
+      }
+    };
+
+    if ((Pred == ICmpInst::ICMP_SGT) || (Pred == ICmpInst::ICMP_UGT)) {
+      // Inequality is reflexive -- check both the combinations.
+      if (CheckRange(FoundLHS, FoundRHS) || CheckRange(FoundRHS, FoundLHS)) {
+        return true;
+      }
+    }
+  }
+
   // Check whether the actual condition is beyond sufficient.
   if (FoundPred == ICmpInst::ICMP_EQ)
     if (ICmpInst::isTrueWhenEqual(Pred))