InstCombine: Respect recursion depth in visitUDivOperand
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index 6c6e7d8151634b37bbce43ff4ec7977a0ef67ec0..3f86ddfd104e01a1e505eeb170ab4fa4a90c7ab7 100644 (file)
@@ -97,6 +97,21 @@ static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) {
   return MulExt.slt(Min) || MulExt.sgt(Max);
 }
 
+/// \brief True if C2 is a multiple of C1. Quotient contains C2/C1.
+static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient,
+                       bool IsSigned) {
+  assert(C1.getBitWidth() == C2.getBitWidth() &&
+         "Inconsistent width of constants!");
+
+  APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned);
+  if (IsSigned)
+    APInt::sdivrem(C1, C2, Quotient, Remainder);
+  else
+    APInt::udivrem(C1, C2, Quotient, Remainder);
+
+  return Remainder.isMinValue();
+}
+
 /// \brief A helper routine of InstCombiner::visitMul().
 ///
 /// If C is a vector of known powers of 2, then this function returns
@@ -596,36 +611,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
       }
     }
 
-    // B * (uitofp i1 C) -> select C, B, 0
-    if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) {
-      Value *LHS = Op0, *RHS = Op1;
-      Value *B, *C;
-      if (!match(RHS, m_UIToFP(m_Value(C))))
-        std::swap(LHS, RHS);
-
-      if (match(RHS, m_UIToFP(m_Value(C))) &&
-          C->getType()->getScalarType()->isIntegerTy(1)) {
-        B = LHS;
-        Value *Zero = ConstantFP::getNegativeZero(B->getType());
-        return SelectInst::Create(C, B, Zero);
-      }
-    }
-
-    // A * (1 - uitofp i1 C) -> select C, 0, A
-    if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) {
-      Value *LHS = Op0, *RHS = Op1;
-      Value *A, *C;
-      if (!match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C)))))
-        std::swap(LHS, RHS);
-
-      if (match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C)))) &&
-          C->getType()->getScalarType()->isIntegerTy(1)) {
-        A = LHS;
-        Value *Zero = ConstantFP::getNegativeZero(A->getType());
-        return SelectInst::Create(C, Zero, A);
-      }
-    }
-
     if (!isa<Constant>(Op1))
       std::swap(Opnd0, Opnd1);
     else
@@ -725,8 +710,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
     return &I;
 
   if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
-    // (X / C1) / C2  -> X / (C1*C2)
-    if (Instruction *LHS = dyn_cast<Instruction>(Op0))
+    if (Instruction *LHS = dyn_cast<Instruction>(Op0)) {
+      // (X / C1) / C2  -> X / (C1*C2)
       if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode())
         if (ConstantInt *LHSRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) {
           if (MultiplyOverflows(RHS, LHSRHS,
@@ -736,6 +721,64 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
                                         ConstantExpr::getMul(RHS, LHSRHS));
         }
 
+      Value *X;
+      const APInt *C1, *C2;
+      if (match(RHS, m_APInt(C2))) {
+        bool IsSigned = I.getOpcode() == Instruction::SDiv;
+        if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
+            (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) {
+          APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
+
+          // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
+          if (IsMultiple(*C2, *C1, Quotient, IsSigned)) {
+            BinaryOperator *BO = BinaryOperator::Create(
+                I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient));
+            BO->setIsExact(I.isExact());
+            return BO;
+          }
+
+          // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2.
+          if (IsMultiple(*C1, *C2, Quotient, IsSigned)) {
+            BinaryOperator *BO = BinaryOperator::Create(
+                Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient));
+            BO->setHasNoUnsignedWrap(
+                !IsSigned &&
+                cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap());
+            BO->setHasNoSignedWrap(
+                cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap());
+            return BO;
+          }
+        }
+
+        if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1)))) ||
+            (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) {
+          APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
+          APInt C1Shifted = APInt::getOneBitSet(
+              C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue()));
+
+          // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1.
+          if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) {
+            BinaryOperator *BO = BinaryOperator::Create(
+                I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient));
+            BO->setIsExact(I.isExact());
+            return BO;
+          }
+
+          // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2.
+          if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) {
+            BinaryOperator *BO = BinaryOperator::Create(
+                Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient));
+            BO->setHasNoUnsignedWrap(
+                !IsSigned &&
+                cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap());
+            BO->setHasNoSignedWrap(
+                cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap());
+            return BO;
+          }
+        }
+      }
+    }
+
     if (!RHS->isZero()) { // avoid X udiv 0
       if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
         if (Instruction *R = FoldOpIntoSelect(I, SI))
@@ -893,10 +936,10 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I,
     return 0;
 
   if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
-    if (size_t LHSIdx = visitUDivOperand(Op0, SI->getOperand(1), I, Actions))
-      if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions)) {
-        Actions.push_back(UDivFoldAction((FoldUDivOperandCb)nullptr, Op1,
-                                         LHSIdx-1));
+    if (size_t LHSIdx =
+            visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth))
+      if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) {
+        Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1));
         return Actions.size();
       }