Fix a typo 'iff' => 'if'
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index 94b619b2037b2639b716cec9b7e59656aefe8cd5..6d81d6dff8e0e0428261d378363bcfc9d0ead2cc 100644 (file)
@@ -38,7 +38,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
                       m_Value(B))) &&
       // The "1" can be any value known to be a power of 2.
       isPowerOfTwo(PowerOf2, IC.getTargetData())) {
-    A = IC.Builder->CreateSub(A, B, "tmp");
+    A = IC.Builder->CreateSub(A, B);
     return IC.Builder->CreateShl(PowerOf2, A);
   }
   
@@ -131,24 +131,30 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
     { Value *X; ConstantInt *C1;
       if (Op0->hasOneUse() &&
           match(Op0, m_Add(m_Value(X), m_ConstantInt(C1)))) {
-        Value *Add = Builder->CreateMul(X, CI, "tmp");
+        Value *Add = Builder->CreateMul(X, CI);
         return BinaryOperator::CreateAdd(Add, Builder->CreateMul(C1, CI));
       }
     }
 
-    // (1 - X) * (-2) -> (x - 1) * 2, for all positive nonzero powers of 2
-    // The "* 2" thus becomes a potential shifting opportunity.
+    // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n
+    // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n
+    // The "* (2**n)" thus becomes a potential shifting opportunity.
     {
       const APInt &   Val = CI->getValue();
       const APInt &PosVal = Val.abs();
       if (Val.isNegative() && PosVal.isPowerOf2()) {
-        Value *X = 0;
-        if (match(Op0, m_Sub(m_One(), m_Value(X)))) {
-          // ConstantInt::get(Op0->getType(), 2);
-          Value *Sub = Builder->CreateSub(X, ConstantInt::get(X->getType(), 1),
-                                          "dec1");
-          return BinaryOperator::CreateMul(Sub, ConstantInt::get(X->getType(),
-                                                                 PosVal));
+        Value *X = 0, *Y = 0;
+        if (Op0->hasOneUse()) {
+          ConstantInt *C1;
+          Value *Sub = 0;
+          if (match(Op0, m_Sub(m_Value(Y), m_Value(X))))
+            Sub = Builder->CreateSub(X, Y, "suba");
+          else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1))))
+            Sub = Builder->CreateSub(Builder->CreateNeg(C1), Y, "subc");
+          if (Sub)
+            return
+              BinaryOperator::CreateMul(Sub,
+                                        ConstantInt::get(Y->getType(), PosVal));
         }
       }
     }
@@ -238,7 +244,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
 
     if (BoolCast) {
       Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()),
-                                    BoolCast, "tmp");
+                                    BoolCast);
       return BinaryOperator::CreateAnd(V, OtherOp);
     }
   }
@@ -250,22 +256,18 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  // Simplify mul instructions with a constant RHS...
+  // Simplify mul instructions with a constant RHS.
   if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
     if (ConstantFP *Op1F = dyn_cast<ConstantFP>(Op1C)) {
       // "In IEEE floating point, x*1 is not equivalent to x for nans.  However,
       // ANSI says we can drop signals, so we can do this anyway." (from GCC)
       if (Op1F->isExactlyValue(1.0))
         return ReplaceInstUsesWith(I, Op0);  // Eliminate 'fmul double %X, 1.0'
-    } else if (Op1C->getType()->isVectorTy()) {
-      if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1C)) {
-        // As above, vector X*splat(1.0) -> X in all defined cases.
-        if (Constant *Splat = Op1V->getSplatValue()) {
-          if (ConstantFP *F = dyn_cast<ConstantFP>(Splat))
-            if (F->isExactlyValue(1.0))
-              return ReplaceInstUsesWith(I, Op0);
-        }
-      }
+    } else if (ConstantDataVector *Op1V = dyn_cast<ConstantDataVector>(Op1C)) {
+      // As above, vector X*splat(1.0) -> X in all defined cases.
+      if (ConstantFP *F = dyn_cast_or_null<ConstantFP>(Op1V->getSplatValue()))
+        if (F->isExactlyValue(1.0))
+          return ReplaceInstUsesWith(I, Op0);
     }
 
     // Try to fold constant mul into select arguments.
@@ -415,7 +417,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
 
 /// dyn_castZExtVal - Checks if V is a zext or constant that can
 /// be truncated to Ty without losing bits.
-static Value *dyn_castZExtVal(Value *V, const Type *Ty) {
+static Value *dyn_castZExtVal(Value *V, Type *Ty) {
   if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) {
     if (Z->getSrcTy() == Ty)
       return Z->getOperand(0);
@@ -435,19 +437,23 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
   // Handle the integer div common cases
   if (Instruction *Common = commonIDivTransforms(I))
     return Common;
-
-  if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) {
+  
+  { 
     // X udiv 2^C -> X >> C
     // Check to see if this is an unsigned division with an exact power of 2,
     // if so, convert to a right shift.
-    if (C->getValue().isPowerOf2()) { // 0 not included in isPowerOf2
+    const APInt *C;
+    if (match(Op1, m_Power2(C))) {
       BinaryOperator *LShr =
-        BinaryOperator::CreateLShr(Op0, 
-            ConstantInt::get(Op0->getType(), C->getValue().logBase2()));
+      BinaryOperator::CreateLShr(Op0, 
+                                 ConstantInt::get(Op0->getType(), 
+                                                  C->logBase2()));
       if (I.isExact()) LShr->setIsExact();
       return LShr;
     }
+  }
 
+  if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) {
     // X udiv C, where C >= signbit
     if (C->getValue().isNegative()) {
       Value *IC = Builder->CreateICmpULT(Op0, C);
@@ -456,12 +462,25 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
     }
   }
 
+  // (x lshr C1) udiv C2 --> x udiv (C2 << C1)
+  if (ConstantInt *C2 = dyn_cast<ConstantInt>(Op1)) {
+    Value *X;
+    ConstantInt *C1;
+    if (match(Op0, m_LShr(m_Value(X), m_ConstantInt(C1)))) {
+      APInt NC = C2->getValue().shl(C1->getLimitedValue(C1->getBitWidth()-1));
+      return BinaryOperator::CreateUDiv(X, Builder->getInt(NC));
+    }
+  }
+
   // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
   { const APInt *CI; Value *N;
-    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N)))) {
+    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N))) ||
+        match(Op1, m_ZExt(m_Shl(m_Power2(CI), m_Value(N))))) {
       if (*CI != 1)
-        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(), CI->logBase2()),
-                               "tmp");
+        N = Builder->CreateAdd(N,
+                               ConstantInt::get(N->getType(), CI->logBase2()));
+      if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1))
+        N = Builder->CreateZExt(N, Z->getDestTy());
       if (I.isExact())
         return BinaryOperator::CreateExactLShr(Op0, N);
       return BinaryOperator::CreateLShr(Op0, N);
@@ -531,7 +550,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
     APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()));
     if (MaskedValueIsZero(Op0, Mask)) {
       if (MaskedValueIsZero(Op1, Mask)) {
-        // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
+        // X sdiv Y -> X udiv Y, if X and Y don't have sign bit set
         return BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
       }
       
@@ -624,7 +643,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
   // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)  
   if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
     Constant *N1 = Constant::getAllOnesValue(I.getType());
-    Value *Add = Builder->CreateAdd(Op1, N1, "tmp");
+    Value *Add = Builder->CreateAdd(Op1, N1);
     return BinaryOperator::CreateAnd(Op0, Add);
   }
 
@@ -673,34 +692,42 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
   if (I.getType()->isIntegerTy()) {
     APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()));
     if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) {
-      // X srem Y -> X urem Y, iff X and Y don't have sign bit set
+      // X srem Y -> X urem Y, if X and Y don't have sign bit set
       return BinaryOperator::CreateURem(Op0, Op1, I.getName());
     }
   }
 
   // If it's a constant vector, flip any negative values positive.
-  if (ConstantVector *RHSV = dyn_cast<ConstantVector>(Op1)) {
-    unsigned VWidth = RHSV->getNumOperands();
+  if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
+    Constant *C = cast<Constant>(Op1);
+    unsigned VWidth = C->getType()->getVectorNumElements();
 
     bool hasNegative = false;
-    for (unsigned i = 0; !hasNegative && i != VWidth; ++i)
-      if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i)))
-        if (RHS->getValue().isNegative())
+    bool hasMissing = false;
+    for (unsigned i = 0; i != VWidth; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (Elt == 0) {
+        hasMissing = true;
+        break;
+      }
+
+      if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elt))
+        if (RHS->isNegative())
           hasNegative = true;
+    }
 
-    if (hasNegative) {
-      std::vector<Constant *> Elts(VWidth);
+    if (hasNegative && !hasMissing) {
+      SmallVector<Constant *, 16> Elts(VWidth);
       for (unsigned i = 0; i != VWidth; ++i) {
-        if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) {
-          if (RHS->getValue().isNegative())
+        Elts[i] = C->getAggregateElement(i);  // Handle undef, etc.
+        if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
+          if (RHS->isNegative())
             Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
-          else
-            Elts[i] = RHS;
         }
       }
 
       Constant *NewRHSV = ConstantVector::get(Elts);
-      if (NewRHSV != RHSV) {
+      if (NewRHSV != C) {  // Don't loop on -MININT
         Worklist.AddValue(I.getOperand(1));
         I.setOperand(1, NewRHSV);
         return &I;