Fix LoopAccessAnalysis when potentially nullptr check are involved
[oota-llvm.git] / lib / Analysis / ValueTracking.cpp
index 1171149180bf74312e4f7bd2b0c661ec7c672c00..1187de7b59bd40e98aefc4536451199eca5213e2 100644 (file)
@@ -367,26 +367,30 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW,
 }
 
 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
-                                             APInt &KnownZero) {
+                                             APInt &KnownZero,
+                                             APInt &KnownOne) {
   unsigned BitWidth = KnownZero.getBitWidth();
   unsigned NumRanges = Ranges.getNumOperands() / 2;
   assert(NumRanges >= 1);
 
-  // Use the high end of the ranges to find leading zeros.
-  unsigned MinLeadingZeros = BitWidth;
+  KnownZero.setAllBits();
+  KnownOne.setAllBits();
+
   for (unsigned i = 0; i < NumRanges; ++i) {
     ConstantInt *Lower =
         mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
     ConstantInt *Upper =
         mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
     ConstantRange Range(Lower->getValue(), Upper->getValue());
-    if (Range.isWrappedSet())
-      MinLeadingZeros = 0; // -1 has no zeros
-    unsigned LeadingZeros = (Upper->getValue() - 1).countLeadingZeros();
-    MinLeadingZeros = std::min(LeadingZeros, MinLeadingZeros);
-  }
 
-  KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros);
+    // The first CommonPrefixBits of all values in Range are equal.
+    unsigned CommonPrefixBits =
+        (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros();
+
+    APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
+    KnownOne &= Range.getUnsignedMax() & Mask;
+    KnownZero &= ~Range.getUnsignedMax() & Mask;
+  }
 }
 
 static bool isEphemeralValueOf(Instruction *I, const Value *E) {
@@ -1062,7 +1066,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero,
   default: break;
   case Instruction::Load:
     if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range))
-      computeKnownBitsFromRangeMetadata(*MD, KnownZero);
+      computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne);
     break;
   case Instruction::And: {
     // If either the LHS or the RHS are Zero, the result is zero.
@@ -1454,7 +1458,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero,
   case Instruction::Call:
   case Instruction::Invoke:
     if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range))
-      computeKnownBitsFromRangeMetadata(*MD, KnownZero);
+      computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne);
     // If a range metadata is attached to this IntrinsicInst, intersect the
     // explicit range specified by the metadata and the implicit range of
     // the intrinsic.
@@ -4077,3 +4081,43 @@ ConstantRange llvm::getConstantRangeFromMetadata(MDNode &Ranges) {
 
   return CR;
 }
+
+bool llvm::isImpliedCondition(Value *LHS, Value *RHS) {
+  assert(LHS->getType() == RHS->getType() && "mismatched type");
+  Type *OpTy = LHS->getType();
+  assert(OpTy->getScalarType()->isIntegerTy(1));
+
+  // LHS ==> RHS by definition
+  if (LHS == RHS) return true;
+
+  if (OpTy->isVectorTy())
+    // TODO: extending the code below to handle vectors
+    return false;
+  assert(OpTy->isIntegerTy(1) && "implied by above");
+
+  ICmpInst::Predicate APred, BPred;
+  Value *I;
+  Value *L;
+  ConstantInt *CI;
+  // i +_{nsw} C_{>0} <s L ==> i <s L
+  if (match(LHS, m_ICmp(APred,
+                        m_NSWAdd(m_Value(I), m_ConstantInt(CI)),
+                        m_Value(L))) &&
+      APred == ICmpInst::ICMP_SLT &&
+      !CI->isNegative() &&
+      match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
+      BPred == ICmpInst::ICMP_SLT)
+    return true;
+
+  // i +_{nuw} C_{>0} <u L ==> i <u L
+  if (match(LHS, m_ICmp(APred,
+                        m_NUWAdd(m_Value(I), m_ConstantInt(CI)),
+                        m_Value(L))) &&
+      APred == ICmpInst::ICMP_ULT &&
+      !CI->isNegative() &&
+      match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
+      BPred == ICmpInst::ICMP_ULT)
+    return true;
+
+  return false;
+}