Fix LoopAccessAnalysis when potentially nullptr check are involved
[oota-llvm.git] / lib / Analysis / ValueTracking.cpp
index 0fe87176f88b2388b6bc24831b397eeaae2a06a1..1187de7b59bd40e98aefc4536451199eca5213e2 100644 (file)
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -366,26 +367,30 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW,
 }
 
 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
-                                             APInt &KnownZero) {
+                                             APInt &KnownZero,
+                                             APInt &KnownOne) {
   unsigned BitWidth = KnownZero.getBitWidth();
   unsigned NumRanges = Ranges.getNumOperands() / 2;
   assert(NumRanges >= 1);
 
-  // Use the high end of the ranges to find leading zeros.
-  unsigned MinLeadingZeros = BitWidth;
+  KnownZero.setAllBits();
+  KnownOne.setAllBits();
+
   for (unsigned i = 0; i < NumRanges; ++i) {
     ConstantInt *Lower =
         mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
     ConstantInt *Upper =
         mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
     ConstantRange Range(Lower->getValue(), Upper->getValue());
-    if (Range.isWrappedSet())
-      MinLeadingZeros = 0; // -1 has no zeros
-    unsigned LeadingZeros = (Upper->getValue() - 1).countLeadingZeros();
-    MinLeadingZeros = std::min(LeadingZeros, MinLeadingZeros);
-  }
 
-  KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros);
+    // The first CommonPrefixBits of all values in Range are equal.
+    unsigned CommonPrefixBits =
+        (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros();
+
+    APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
+    KnownOne &= Range.getUnsignedMax() & Mask;
+    KnownZero &= ~Range.getUnsignedMax() & Mask;
+  }
 }
 
 static bool isEphemeralValueOf(Instruction *I, const Value *E) {
@@ -405,14 +410,8 @@ static bool isEphemeralValueOf(Instruction *I, const Value *E) {
       continue;
 
     // If all uses of this value are ephemeral, then so is this value.
-    bool FoundNEUse = false;
-    for (const User *I : V->users())
-      if (!EphValues.count(I)) {
-        FoundNEUse = true;
-        break;
-      }
-
-    if (!FoundNEUse) {
+    if (std::all_of(V->user_begin(), V->user_end(),
+                    [&](const User *U) { return EphValues.count(U); })) {
       if (V == E)
         return true;
 
@@ -1010,9 +1009,18 @@ static void computeKnownBitsFromShiftOperator(Operator *I,
   // calculation. Reusing the APInts here to prevent unnecessary allocations.
   KnownZero.clearAllBits(), KnownOne.clearAllBits();
 
+  // If we know the shifter operand is nonzero, we can sometimes infer more
+  // known bits. However this is expensive to compute, so be lazy about it and
+  // only compute it when absolutely necessary.
+  Optional<bool> ShifterOperandIsNonZero;
+
   // Early exit if we can't constrain any well-defined shift amount.
-  if (!(ShiftAmtKZ & (BitWidth-1)) && !(ShiftAmtKO & (BitWidth-1)))
-    return;
+  if (!(ShiftAmtKZ & (BitWidth - 1)) && !(ShiftAmtKO & (BitWidth - 1))) {
+    ShifterOperandIsNonZero =
+        isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q);
+    if (!*ShifterOperandIsNonZero)
+      return;
+  }
 
   computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q);
 
@@ -1024,6 +1032,16 @@ static void computeKnownBitsFromShiftOperator(Operator *I,
       continue;
     if ((ShiftAmt | ShiftAmtKO) != ShiftAmt)
       continue;
+    // If we know the shifter is nonzero, we may be able to infer more known
+    // bits. This check is sunk down as far as possible to avoid the expensive
+    // call to isKnownNonZero if the cheaper checks above fail.
+    if (ShiftAmt == 0) {
+      if (!ShifterOperandIsNonZero.hasValue())
+        ShifterOperandIsNonZero =
+            isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q);
+      if (*ShifterOperandIsNonZero)
+        continue;
+    }
 
     KnownZero &= KZF(KnownZero2, ShiftAmt);
     KnownOne  &= KOF(KnownOne2, ShiftAmt);
@@ -1048,7 +1066,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero,
   default: break;
   case Instruction::Load:
     if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range))
-      computeKnownBitsFromRangeMetadata(*MD, KnownZero);
+      computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne);
     break;
   case Instruction::And: {
     // If either the LHS or the RHS are Zero, the result is zero.
@@ -1440,7 +1458,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero,
   case Instruction::Call:
   case Instruction::Invoke:
     if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range))
-      computeKnownBitsFromRangeMetadata(*MD, KnownZero);
+      computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne);
     // If a range metadata is attached to this IntrinsicInst, intersect the
     // explicit range specified by the metadata and the implicit range of
     // the intrinsic.
@@ -4063,3 +4081,43 @@ ConstantRange llvm::getConstantRangeFromMetadata(MDNode &Ranges) {
 
   return CR;
 }
+
+bool llvm::isImpliedCondition(Value *LHS, Value *RHS) {
+  assert(LHS->getType() == RHS->getType() && "mismatched type");
+  Type *OpTy = LHS->getType();
+  assert(OpTy->getScalarType()->isIntegerTy(1));
+
+  // LHS ==> RHS by definition
+  if (LHS == RHS) return true;
+
+  if (OpTy->isVectorTy())
+    // TODO: extending the code below to handle vectors
+    return false;
+  assert(OpTy->isIntegerTy(1) && "implied by above");
+
+  ICmpInst::Predicate APred, BPred;
+  Value *I;
+  Value *L;
+  ConstantInt *CI;
+  // i +_{nsw} C_{>0} <s L ==> i <s L
+  if (match(LHS, m_ICmp(APred,
+                        m_NSWAdd(m_Value(I), m_ConstantInt(CI)),
+                        m_Value(L))) &&
+      APred == ICmpInst::ICMP_SLT &&
+      !CI->isNegative() &&
+      match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
+      BPred == ICmpInst::ICMP_SLT)
+    return true;
+
+  // i +_{nuw} C_{>0} <u L ==> i <u L
+  if (match(LHS, m_ICmp(APred,
+                        m_NUWAdd(m_Value(I), m_ConstantInt(CI)),
+                        m_Value(L))) &&
+      APred == ICmpInst::ICMP_ULT &&
+      !CI->isNegative() &&
+      match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
+      BPred == ICmpInst::ICMP_ULT)
+    return true;
+
+  return false;
+}