Re-commit 141203, but much more conservative.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index f3d10611ad2a709d2d30e65701fe1b770250e8e3..7f48125a97ab6e89f493a604335ba1078b6f3da5 100644 (file)
@@ -29,40 +29,47 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
   // code.
   if (!V->hasOneUse()) return 0;
   
+  bool MadeChange = false;
+
+  // ((1 << A) >>u B) --> (1 << (A-B))
+  // Because V cannot be zero, we know that B is less than A.
+  Value *A = 0, *B = 0, *PowerOf2 = 0;
+  if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))),
+                      m_Value(B))) &&
+      // The "1" can be any value known to be a power of 2.
+      isPowerOfTwo(PowerOf2, IC.getTargetData())) {
+    A = IC.Builder->CreateSub(A, B);
+    return IC.Builder->CreateShl(PowerOf2, A);
+  }
   
   // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it
   // inexact.  Similarly for <<.
   if (BinaryOperator *I = dyn_cast<BinaryOperator>(V))
     if (I->isLogicalShift() &&
         isPowerOfTwo(I->getOperand(0), IC.getTargetData())) {
+      // We know that this is an exact/nuw shift and that the input is a
+      // non-zero context as well.
+      if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) {
+        I->setOperand(0, V2);
+        MadeChange = true;
+      }
+      
       if (I->getOpcode() == Instruction::LShr && !I->isExact()) {
         I->setIsExact();
-        return I;
+        MadeChange = true;
       }
       
       if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) {
         I->setHasNoUnsignedWrap();
-        return I;
+        MadeChange = true;
       }
     }
-      
-  // ((1 << A) >>u B) --> (1 << (A-B))
-  // Because V cannot be zero, we know that B is less than A.
-  Value *A = 0, *B = 0, *PowerOf2 = 0;
-  if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))),
-                      m_Value(B))) &&
-      // The "1" can be any value known to be a power of 2.
-      isPowerOfTwo(PowerOf2, IC.getTargetData())) {
-    A = IC.Builder->CreateSub(A, B, "tmp");
-    return IC.Builder->CreateShl(PowerOf2, A);
-  }
-  
+
   // TODO: Lots more we could do here:
-  //    "1 >> X" could get an "isexact" bit.
   //    If V is a phi node, we can call this on each of its operands.
   //    "select cond, X, 0" can simplify to "X".
   
-  return 0;
+  return MadeChange ? V : 0;
 }
 
 
@@ -124,10 +131,33 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
     { Value *X; ConstantInt *C1;
       if (Op0->hasOneUse() &&
           match(Op0, m_Add(m_Value(X), m_ConstantInt(C1)))) {
-        Value *Add = Builder->CreateMul(X, CI, "tmp");
+        Value *Add = Builder->CreateMul(X, CI);
         return BinaryOperator::CreateAdd(Add, Builder->CreateMul(C1, CI));
       }
     }
+
+    // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n
+    // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n
+    // The "* (2**n)" thus becomes a potential shifting opportunity.
+    {
+      const APInt &   Val = CI->getValue();
+      const APInt &PosVal = Val.abs();
+      if (Val.isNegative() && PosVal.isPowerOf2()) {
+        Value *X = 0, *Y = 0;
+        if (Op0->hasOneUse()) {
+          ConstantInt *C1;
+          Value *Sub = 0;
+          if (match(Op0, m_Sub(m_Value(Y), m_Value(X))))
+            Sub = Builder->CreateSub(X, Y, "suba");
+          else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1))))
+            Sub = Builder->CreateSub(Builder->CreateNeg(C1), Y, "subc");
+          if (Sub)
+            return
+              BinaryOperator::CreateMul(Sub,
+                                        ConstantInt::get(Y->getType(), PosVal));
+        }
+      }
+    }
   }
   
   // Simplify mul instructions with a constant RHS.
@@ -214,7 +244,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
 
     if (BoolCast) {
       Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()),
-                                    BoolCast, "tmp");
+                                    BoolCast);
       return BinaryOperator::CreateAnd(V, OtherOp);
     }
   }
@@ -391,7 +421,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
 
 /// dyn_castZExtVal - Checks if V is a zext or constant that can
 /// be truncated to Ty without losing bits.
-static Value *dyn_castZExtVal(Value *V, const Type *Ty) {
+static Value *dyn_castZExtVal(Value *V, Type *Ty) {
   if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) {
     if (Z->getSrcTy() == Ty)
       return Z->getOperand(0);
@@ -436,8 +466,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
   { const APInt *CI; Value *N;
     if (match(Op1, m_Shl(m_Power2(CI), m_Value(N)))) {
       if (*CI != 1)
-        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(), CI->logBase2()),
-                               "tmp");
+        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(),CI->logBase2()));
       if (I.isExact())
         return BinaryOperator::CreateExactLShr(Op0, N);
       return BinaryOperator::CreateLShr(Op0, N);
@@ -600,7 +629,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
   // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)  
   if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
     Constant *N1 = Constant::getAllOnesValue(I.getType());
-    Value *Add = Builder->CreateAdd(Op1, N1, "tmp");
+    Value *Add = Builder->CreateAdd(Op1, N1);
     return BinaryOperator::CreateAnd(Op0, Add);
   }
 
@@ -661,14 +690,14 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
     bool hasNegative = false;
     for (unsigned i = 0; !hasNegative && i != VWidth; ++i)
       if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i)))
-        if (RHS->getValue().isNegative())
+        if (RHS->isNegative())
           hasNegative = true;
 
     if (hasNegative) {
       std::vector<Constant *> Elts(VWidth);
       for (unsigned i = 0; i != VWidth; ++i) {
         if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) {
-          if (RHS->getValue().isNegative())
+          if (RHS->isNegative())
             Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
           else
             Elts[i] = RHS;