Push analysis passes to InstSimplify when they're around anyways.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index 617d8e7a5c1ce20f6ddebacab3f7542fba2e1164..2c292ce29d0a22e99bb2044f59b2914889b52286 100644 (file)
@@ -227,7 +227,8 @@ Instruction *InstCombiner::
 FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
                              CmpInst &ICI, ConstantInt *AndCst) {
   // We need TD information to know the pointer size unless this is inbounds.
-  if (!GEP->isInBounds() && TD == 0) return 0;
+  if (!GEP->isInBounds() && TD == 0)
+    return 0;
 
   Constant *Init = GV->getInitializer();
   if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
@@ -562,16 +563,18 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
     }
   }
 
+
+
   // Okay, we know we have a single variable index, which must be a
   // pointer/array/vector index.  If there is no offset, life is simple, return
   // the index.
-  unsigned IntPtrWidth = TD.getPointerSizeInBits();
+  Type *IntPtrTy = TD.getIntPtrType(GEP->getOperand(0)->getType());
+  unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth();
   if (Offset == 0) {
     // Cast to intptrty in case a truncation occurs.  If an extension is needed,
     // we don't need to bother extending: the extension won't affect where the
     // computation crosses zero.
     if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) {
-      Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext());
       VariableIdx = IC.Builder->CreateTrunc(VariableIdx, IntPtrTy);
     }
     return VariableIdx;
@@ -593,7 +596,6 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
     return 0;
 
   // Okay, we can do this evaluation.  Start by converting the index to intptr.
-  Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext());
   if (VariableIdx->getType() != IntPtrTy)
     VariableIdx = IC.Builder->CreateIntCast(VariableIdx, IntPtrTy,
                                             true /*Signed*/);
@@ -647,8 +649,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 
       // If all indices are the same, just compare the base pointers.
       if (IndicesTheSame)
-        return new ICmpInst(ICmpInst::getSignedPredicate(Cond),
-                            GEPLHS->getOperand(0), GEPRHS->getOperand(0));
+        return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0));
 
       // If we're comparing GEPs with two base pointers that only differ in type
       // and both GEPs have only constant indices or just one use, then fold
@@ -679,7 +680,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
       }
     if (AllZeros)
       return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
-                          ICmpInst::getSwappedPredicate(Cond), I);
+                         ICmpInst::getSwappedPredicate(Cond), I);
 
     // If the other GEP has all zero indices, recurse.
     AllZeros = true;
@@ -738,10 +739,9 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 }
 
 /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X".
-Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI,
+Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
-                                            ICmpInst::Predicate Pred,
-                                            Value *TheAdd) {
+                                            ICmpInst::Predicate Pred) {
   // If we have X+0, exit early (simplifying logic below) and let it get folded
   // elsewhere.   icmp X+0, X  -> icmp X, X
   if (CI->isZero()) {
@@ -1127,6 +1127,18 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
                               Builder->getInt(RHSV ^ NotSignBit));
         }
       }
+
+      // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C)
+      //   iff -C is a power of 2
+      if (ICI.getPredicate() == ICmpInst::ICMP_UGT &&
+          XorCST->getValue() == ~RHSV && (RHSV + 1).isPowerOf2())
+        return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCST);
+
+      // (icmp ult (xor X, C), -C) -> (icmp uge X, C)
+      //   iff -C is a power of 2
+      if (ICI.getPredicate() == ICmpInst::ICMP_ULT &&
+          XorCST->getValue() == -RHSV && RHSV.isPowerOf2())
+        return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCST);
     }
     break;
   case Instruction::And:         // (icmp pred (and X, AndCST), RHS)
@@ -1278,6 +1290,15 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
               return Res;
           }
     }
+
+    // X & -C == -C -> X >  u ~C
+    // X & -C != -C -> X <= u ~C
+    //   iff C is a power of 2
+    if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2())
+      return new ICmpInst(
+          ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT
+                                                  : ICmpInst::ICMP_ULE,
+          LHSI->getOperand(0), SubOne(RHS));
     break;
 
   case Instruction::Or: {
@@ -1319,10 +1340,80 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
   }
 
   case Instruction::Shl: {       // (icmp pred (shl X, ShAmt), CI)
+    uint32_t TypeBits = RHSV.getBitWidth();
     ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
-    if (!ShAmt) break;
+    if (!ShAmt) {
+      Value *X;
+      // (1 << X) pred P2 -> X pred Log2(P2)
+      if (match(LHSI, m_Shl(m_One(), m_Value(X)))) {
+        bool RHSVIsPowerOf2 = RHSV.isPowerOf2();
+        ICmpInst::Predicate Pred = ICI.getPredicate();
+        if (ICI.isUnsigned()) {
+          if (!RHSVIsPowerOf2) {
+            // (1 << X) <  30 -> X <= 4
+            // (1 << X) <= 30 -> X <= 4
+            // (1 << X) >= 30 -> X >  4
+            // (1 << X) >  30 -> X >  4
+            if (Pred == ICmpInst::ICMP_ULT)
+              Pred = ICmpInst::ICMP_ULE;
+            else if (Pred == ICmpInst::ICMP_UGE)
+              Pred = ICmpInst::ICMP_UGT;
+          }
+          unsigned RHSLog2 = RHSV.logBase2();
+
+          // (1 << X) >= 2147483648 -> X >= 31 -> X == 31
+          // (1 << X) >  2147483648 -> X >  31 -> false
+          // (1 << X) <= 2147483648 -> X <= 31 -> true
+          // (1 << X) <  2147483648 -> X <  31 -> X != 31
+          if (RHSLog2 == TypeBits-1) {
+            if (Pred == ICmpInst::ICMP_UGE)
+              Pred = ICmpInst::ICMP_EQ;
+            else if (Pred == ICmpInst::ICMP_UGT)
+              return ReplaceInstUsesWith(ICI, Builder->getFalse());
+            else if (Pred == ICmpInst::ICMP_ULE)
+              return ReplaceInstUsesWith(ICI, Builder->getTrue());
+            else if (Pred == ICmpInst::ICMP_ULT)
+              Pred = ICmpInst::ICMP_NE;
+          }
 
-    uint32_t TypeBits = RHSV.getBitWidth();
+          return new ICmpInst(Pred, X,
+                              ConstantInt::get(RHS->getType(), RHSLog2));
+        } else if (ICI.isSigned()) {
+          if (RHSV.isAllOnesValue()) {
+            // (1 << X) <= -1 -> X == 31
+            if (Pred == ICmpInst::ICMP_SLE)
+              return new ICmpInst(ICmpInst::ICMP_EQ, X,
+                                  ConstantInt::get(RHS->getType(), TypeBits-1));
+
+            // (1 << X) >  -1 -> X != 31
+            if (Pred == ICmpInst::ICMP_SGT)
+              return new ICmpInst(ICmpInst::ICMP_NE, X,
+                                  ConstantInt::get(RHS->getType(), TypeBits-1));
+          } else if (!RHSV) {
+            // (1 << X) <  0 -> X == 31
+            // (1 << X) <= 0 -> X == 31
+            if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
+              return new ICmpInst(ICmpInst::ICMP_EQ, X,
+                                  ConstantInt::get(RHS->getType(), TypeBits-1));
+
+            // (1 << X) >= 0 -> X != 31
+            // (1 << X) >  0 -> X != 31
+            if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
+              return new ICmpInst(ICmpInst::ICMP_NE, X,
+                                  ConstantInt::get(RHS->getType(), TypeBits-1));
+          }
+        } else if (ICI.isEquality()) {
+          if (RHSVIsPowerOf2)
+            return new ICmpInst(
+                Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
+
+          return ReplaceInstUsesWith(
+              ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse()
+                                             : Builder->getTrue());
+        }
+      }
+      break;
+    }
 
     // Check that the shift amount is in range.  If not, don't perform
     // undefined shifts.  When the shift is visited it will be
@@ -1443,6 +1534,30 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         return R;
     break;
 
+  case Instruction::Sub: {
+    ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(0));
+    if (!LHSC) break;
+    const APInt &LHSV = LHSC->getValue();
+
+    // C1-X <u C2 -> (X|(C2-1)) == C1
+    //   iff C1 & (C2-1) == C2-1
+    //       C2 is a power of 2
+    if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() &&
+        RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1))
+      return new ICmpInst(ICmpInst::ICMP_EQ,
+                          Builder->CreateOr(LHSI->getOperand(1), RHSV - 1),
+                          LHSC);
+
+    // C1-X >u C2 -> (X|C2) != C1
+    //   iff C1 & C2 == C2
+    //       C2+1 is a power of 2
+    if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() &&
+        (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV)
+      return new ICmpInst(ICmpInst::ICMP_NE,
+                          Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC);
+    break;
+  }
+
   case Instruction::Add:
     // Fold: icmp pred (add X, C1), C2
     if (!ICI.isEquality()) {
@@ -1470,6 +1585,24 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
                               Builder->getInt(CR.getLower()));
         }
       }
+
+      // X-C1 <u C2 -> (X & -C2) == C1
+      //   iff C1 & (C2-1) == 0
+      //       C2 is a power of 2
+      if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() &&
+          RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0)
+        return new ICmpInst(ICmpInst::ICMP_EQ,
+                            Builder->CreateAnd(LHSI->getOperand(0), -RHSV),
+                            ConstantExpr::getNeg(LHSC));
+
+      // X-C1 >u C2 -> (X & ~C2) != C1
+      //   iff C1 & C2 == 0
+      //       C2+1 is a power of 2
+      if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() &&
+          (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0)
+        return new ICmpInst(ICmpInst::ICMP_NE,
+                            Builder->CreateAnd(LHSI->getOperand(0), ~RHSV),
+                            ConstantExpr::getNeg(LHSC));
     }
     break;
   }
@@ -1903,14 +2036,59 @@ static APInt DemandedBitsLHSMask(ICmpInst &I,
 
 }
 
+/// \brief Check if the order of \p Op0 and \p Op1 as operand in an ICmpInst
+/// should be swapped.
+/// The descision is based on how many times these two operands are reused
+/// as subtract operands and their positions in those instructions.
+/// The rational is that several architectures use the same instruction for
+/// both subtract and cmp, thus it is better if the order of those operands
+/// match.
+/// \return true if Op0 and Op1 should be swapped.
+static bool swapMayExposeCSEOpportunities(const Value * Op0,
+                                          const Value * Op1) {
+  // Filter out pointer value as those cannot appears directly in subtract.
+  // FIXME: we may want to go through inttoptrs or bitcasts.
+  if (Op0->getType()->isPointerTy())
+    return false;
+  // Count every uses of both Op0 and Op1 in a subtract.
+  // Each time Op0 is the first operand, count -1: swapping is bad, the
+  // subtract has already the same layout as the compare.
+  // Each time Op0 is the second operand, count +1: swapping is good, the
+  // subtract has a diffrent layout as the compare.
+  // At the end, if the benefit is greater than 0, Op0 should come second to
+  // expose more CSE opportunities.
+  int GlobalSwapBenefits = 0;
+  for (Value::const_use_iterator UI = Op0->use_begin(), UIEnd = Op0->use_end(); UI != UIEnd; ++UI) {
+    const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(*UI);
+    if (!BinOp || BinOp->getOpcode() != Instruction::Sub)
+      continue;
+    // If Op0 is the first argument, this is not beneficial to swap the
+    // arguments.
+    int LocalSwapBenefits = -1;
+    unsigned Op1Idx = 1;
+    if (BinOp->getOperand(Op1Idx) == Op0) {
+      Op1Idx = 0;
+      LocalSwapBenefits = 1;
+    }
+    if (BinOp->getOperand(Op1Idx) != Op1)
+      continue;
+    GlobalSwapBenefits += LocalSwapBenefits;
+  }
+  return GlobalSwapBenefits > 0;
+}
+
 Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  unsigned Op0Cplxity = getComplexity(Op0);
+  unsigned Op1Cplxity = getComplexity(Op1);
 
   /// Orders the operands of the compare so that they are listed from most
   /// complex to least complex.  This puts constants before unary operators,
   /// before binary operators.
-  if (getComplexity(Op0) < getComplexity(Op1)) {
+  if (Op0Cplxity < Op1Cplxity ||
+        (Op0Cplxity == Op1Cplxity &&
+         swapMayExposeCSEOpportunities(Op0, Op1))) {
     I.swapOperands();
     std::swap(Op0, Op1);
     Changed = true;
@@ -2345,7 +2523,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       case Instruction::IntToPtr:
         // icmp pred inttoptr(X), null -> icmp pred X, 0
         if (RHSC->isNullValue() && TD &&
-            TD->getIntPtrType(RHSC->getContext()) ==
+            TD->getIntPtrType(RHSC->getType()) ==
                LHSI->getOperand(0)->getType())
           return new ICmpInst(I.getPredicate(), LHSI->getOperand(0),
                         Constant::getNullValue(LHSI->getOperand(0)->getType()));
@@ -2798,20 +2976,15 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Value *X; ConstantInt *Cst;
     // icmp X+Cst, X
     if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X)
-      return FoldICmpAddOpCst(I, X, Cst, I.getPredicate(), Op0);
+      return FoldICmpAddOpCst(I, X, Cst, I.getPredicate());
 
     // icmp X, X+Cst
     if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X)
-      return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate(), Op1);
+      return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate());
   }
   return Changed ? &I : 0;
 }
 
-
-
-
-
-
 /// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible.
 ///
 Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,