Fix a typo 'iff' => 'if'
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index 122f9486c8ecb5436c8bca4bbfba65f57879cc30..6d81d6dff8e0e0428261d378363bcfc9d0ead2cc 100644 (file)
@@ -256,22 +256,18 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  // Simplify mul instructions with a constant RHS...
+  // Simplify mul instructions with a constant RHS.
   if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
     if (ConstantFP *Op1F = dyn_cast<ConstantFP>(Op1C)) {
       // "In IEEE floating point, x*1 is not equivalent to x for nans.  However,
       // ANSI says we can drop signals, so we can do this anyway." (from GCC)
       if (Op1F->isExactlyValue(1.0))
         return ReplaceInstUsesWith(I, Op0);  // Eliminate 'fmul double %X, 1.0'
-    } else if (Op1C->getType()->isVectorTy()) {
-      if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1C)) {
-        // As above, vector X*splat(1.0) -> X in all defined cases.
-        if (Constant *Splat = Op1V->getSplatValue()) {
-          if (ConstantFP *F = dyn_cast<ConstantFP>(Splat))
-            if (F->isExactlyValue(1.0))
-              return ReplaceInstUsesWith(I, Op0);
-        }
-      }
+    } else if (ConstantDataVector *Op1V = dyn_cast<ConstantDataVector>(Op1C)) {
+      // As above, vector X*splat(1.0) -> X in all defined cases.
+      if (ConstantFP *F = dyn_cast_or_null<ConstantFP>(Op1V->getSplatValue()))
+        if (F->isExactlyValue(1.0))
+          return ReplaceInstUsesWith(I, Op0);
     }
 
     // Try to fold constant mul into select arguments.
@@ -466,11 +462,25 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
     }
   }
 
+  // (x lshr C1) udiv C2 --> x udiv (C2 << C1)
+  if (ConstantInt *C2 = dyn_cast<ConstantInt>(Op1)) {
+    Value *X;
+    ConstantInt *C1;
+    if (match(Op0, m_LShr(m_Value(X), m_ConstantInt(C1)))) {
+      APInt NC = C2->getValue().shl(C1->getLimitedValue(C1->getBitWidth()-1));
+      return BinaryOperator::CreateUDiv(X, Builder->getInt(NC));
+    }
+  }
+
   // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
   { const APInt *CI; Value *N;
-    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N)))) {
+    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N))) ||
+        match(Op1, m_ZExt(m_Shl(m_Power2(CI), m_Value(N))))) {
       if (*CI != 1)
-        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(),CI->logBase2()));
+        N = Builder->CreateAdd(N,
+                               ConstantInt::get(N->getType(), CI->logBase2()));
+      if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1))
+        N = Builder->CreateZExt(N, Z->getDestTy());
       if (I.isExact())
         return BinaryOperator::CreateExactLShr(Op0, N);
       return BinaryOperator::CreateLShr(Op0, N);
@@ -540,7 +550,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
     APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()));
     if (MaskedValueIsZero(Op0, Mask)) {
       if (MaskedValueIsZero(Op1, Mask)) {
-        // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
+        // X sdiv Y -> X udiv Y, if X and Y don't have sign bit set
         return BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
       }
       
@@ -682,34 +692,42 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
   if (I.getType()->isIntegerTy()) {
     APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()));
     if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) {
-      // X srem Y -> X urem Y, iff X and Y don't have sign bit set
+      // X srem Y -> X urem Y, if X and Y don't have sign bit set
       return BinaryOperator::CreateURem(Op0, Op1, I.getName());
     }
   }
 
   // If it's a constant vector, flip any negative values positive.
-  if (ConstantVector *RHSV = dyn_cast<ConstantVector>(Op1)) {
-    unsigned VWidth = RHSV->getNumOperands();
+  if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
+    Constant *C = cast<Constant>(Op1);
+    unsigned VWidth = C->getType()->getVectorNumElements();
 
     bool hasNegative = false;
-    for (unsigned i = 0; !hasNegative && i != VWidth; ++i)
-      if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i)))
+    bool hasMissing = false;
+    for (unsigned i = 0; i != VWidth; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (Elt == 0) {
+        hasMissing = true;
+        break;
+      }
+
+      if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elt))
         if (RHS->isNegative())
           hasNegative = true;
+    }
 
-    if (hasNegative) {
+    if (hasNegative && !hasMissing) {
       SmallVector<Constant *, 16> Elts(VWidth);
       for (unsigned i = 0; i != VWidth; ++i) {
-        if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) {
+        Elts[i] = C->getAggregateElement(i);  // Handle undef, etc.
+        if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
           if (RHS->isNegative())
             Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
-          else
-            Elts[i] = RHS;
         }
       }
 
       Constant *NewRHSV = ConstantVector::get(Elts);
-      if (NewRHSV != RHSV) {
+      if (NewRHSV != C) {  // Don't loop on -MININT
         Worklist.AddValue(I.getOperand(1));
         I.setOperand(1, NewRHSV);
         return &I;