InstCombine: Remove a special case pattern
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index f204eeb548c7bdc6aa740b416e5466c9c21e6c0f..7a2d175b58e6c443ad6255cc93c02b43b239ed4d 100644 (file)
@@ -683,26 +683,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     }
 
     // If one of the GEPs has all zero indices, recurse.
-    bool AllZeros = true;
-    for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i)
-      if (!isa<Constant>(GEPLHS->getOperand(i)) ||
-          !cast<Constant>(GEPLHS->getOperand(i))->isNullValue()) {
-        AllZeros = false;
-        break;
-      }
-    if (AllZeros)
+    if (GEPLHS->hasAllZeroIndices())
       return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
                          ICmpInst::getSwappedPredicate(Cond), I);
 
     // If the other GEP has all zero indices, recurse.
-    AllZeros = true;
-    for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
-      if (!isa<Constant>(GEPRHS->getOperand(i)) ||
-          !cast<Constant>(GEPRHS->getOperand(i))->isNullValue()) {
-        AllZeros = false;
-        break;
-      }
-    if (AllZeros)
+    if (GEPRHS->hasAllZeroIndices())
       return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
 
     bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
@@ -754,21 +740,6 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
                                             ICmpInst::Predicate Pred) {
-  // If we have X+0, exit early (simplifying logic below) and let it get folded
-  // elsewhere.   icmp X+0, X  -> icmp X, X
-  if (CI->isZero()) {
-    bool isTrue = ICmpInst::isTrueWhenEqual(Pred);
-    return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue));
-  }
-
-  // (X+4) == X -> false.
-  if (Pred == ICmpInst::ICMP_EQ)
-    return ReplaceInstUsesWith(ICI, Builder->getFalse());
-
-  // (X+4) != X -> true.
-  if (Pred == ICmpInst::ICMP_NE)
-    return ReplaceInstUsesWith(ICI, Builder->getTrue());
-
   // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
   // so the values can never be equal.  Similarly for all other "or equals"
   // operators.
@@ -1058,6 +1029,90 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
   return nullptr;
 }
 
+/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" ->
+/// (icmp eq/ne A, Log2(const2/const1)) ->
+/// (icmp eq/ne A, Log2(const2) - Log2(const1)).
+Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
+                                             ConstantInt *CI1,
+                                             ConstantInt *CI2) {
+  assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+  auto getConstant = [&I, this](bool IsTrue) {
+    if (I.getPredicate() == I.ICMP_NE)
+      IsTrue = !IsTrue;
+    return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+  };
+
+  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+    if (I.getPredicate() == I.ICMP_NE)
+      Pred = CmpInst::getInversePredicate(Pred);
+    return new ICmpInst(Pred, LHS, RHS);
+  };
+
+  APInt AP1 = CI1->getValue();
+  APInt AP2 = CI2->getValue();
+
+  if (!AP1) {
+    if (!AP2) {
+      // Both Constants are 0.
+      return getConstant(true);
+    }
+
+    if (cast<BinaryOperator>(Op)->isExact())
+      return getConstant(false);
+
+    if (AP2.isNegative()) {
+      // MSB is set, so a lshr with a large enough 'A' would be undefined.
+      return getConstant(false);
+    }
+
+    // 'A' must be large enough to shift out the highest set bit.
+    return getICmp(I.ICMP_UGT, A,
+                   ConstantInt::get(A->getType(), AP2.logBase2()));
+  }
+
+  if (!AP2) {
+    // Shifting 0 by any value gives 0.
+    return getConstant(false);
+  }
+
+  bool IsAShr = isa<AShrOperator>(Op);
+  if (AP1 == AP2) {
+    if (AP1.isAllOnesValue() && IsAShr) {
+      // Arithmatic shift of -1 is always -1.
+      return getConstant(true);
+    }
+    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+  }
+
+  if (IsAShr) {
+    if (AP1.isNegative() != AP2.isNegative()) {
+      // Arithmetic shift will never change the sign.
+      return getConstant(false);
+    }
+    // Both the constants are negative, take their positive to calculate
+    // log.
+    if (AP1.isNegative()) {
+      AP1 = -AP1;
+      AP2 = -AP2;
+    }
+  }
+
+  if (AP1.ugt(AP2)) {
+    // Right-shifting will not increase the value.
+    return getConstant(false);
+  }
+
+  // Get the distance between the highest bit that's set.
+  int Shift = AP2.logBase2() - AP1.logBase2();
+
+  // Use lshr here, since we've canonicalized to +ve numbers.
+  if (AP1 == AP2.lshr(Shift))
+    return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+
+  // Shifting const2 will never be equal to const1.
+  return getConstant(false);
+}
 
 /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)".
 ///
@@ -1296,6 +1351,48 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         return &ICI;
       }
 
+      // (icmp pred (and (or (lshr X, Y), X), 1), 0) -->
+      //    (icmp pred (and X, (or (shl 1, Y), 1), 0))
+      //
+      // iff pred isn't signed
+      {
+        Value *X, *Y, *LShr;
+        if (!ICI.isSigned() && RHSV == 0) {
+          if (match(LHSI->getOperand(1), m_One())) {
+            Constant *One = cast<Constant>(LHSI->getOperand(1));
+            Value *Or = LHSI->getOperand(0);
+            if (match(Or, m_Or(m_Value(LShr), m_Value(X))) &&
+                match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) {
+              unsigned UsesRemoved = 0;
+              if (LHSI->hasOneUse())
+                ++UsesRemoved;
+              if (Or->hasOneUse())
+                ++UsesRemoved;
+              if (LShr->hasOneUse())
+                ++UsesRemoved;
+              Value *NewOr = nullptr;
+              // Compute X & ((1 << Y) | 1)
+              if (auto *C = dyn_cast<Constant>(Y)) {
+                if (UsesRemoved >= 1)
+                  NewOr =
+                      ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One);
+              } else {
+                if (UsesRemoved >= 3)
+                  NewOr = Builder->CreateOr(Builder->CreateShl(One, Y,
+                                                               LShr->getName(),
+                                                               /*HasNUW=*/true),
+                                            One, Or->getName());
+              }
+              if (NewOr) {
+                Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName());
+                ICI.setOperand(0, NewAnd);
+                return &ICI;
+              }
+            }
+          }
+        }
+      }
+
       // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any
       // bit set in (X & AndCst) will produce a result greater than RHSV.
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
@@ -1391,16 +1488,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           unsigned RHSLog2 = RHSV.logBase2();
 
           // (1 << X) >= 2147483648 -> X >= 31 -> X == 31
-          // (1 << X) >  2147483648 -> X >  31 -> false
-          // (1 << X) <= 2147483648 -> X <= 31 -> true
           // (1 << X) <  2147483648 -> X <  31 -> X != 31
           if (RHSLog2 == TypeBits-1) {
             if (Pred == ICmpInst::ICMP_UGE)
               Pred = ICmpInst::ICMP_EQ;
-            else if (Pred == ICmpInst::ICMP_UGT)
-              return ReplaceInstUsesWith(ICI, Builder->getFalse());
-            else if (Pred == ICmpInst::ICMP_ULE)
-              return ReplaceInstUsesWith(ICI, Builder->getTrue());
             else if (Pred == ICmpInst::ICMP_ULT)
               Pred = ICmpInst::ICMP_NE;
           }
@@ -1435,10 +1526,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           if (RHSVIsPowerOf2)
             return new ICmpInst(
                 Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
-
-          return ReplaceInstUsesWith(
-              ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse()
-                                             : Builder->getTrue());
         }
       }
       break;
@@ -2483,6 +2570,15 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                           Builder->getInt(CI->getValue()-1));
     }
 
+    // (icmp eq/ne (ashr/lshr const2, A), const1)
+    if (I.isEquality()) {
+      ConstantInt *CI2;
+      if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
+          match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
+        return FoldICmpCstShrCst(I, Op0, A, CI, CI2);
+      }
+    }
+
     // If this comparison is a normal comparison, it demands all
     // bits, if it is a sign bit comparison, it only demands the sign bit.
     bool UnusedBit;