Add comment.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineShifts.cpp
index 9f7d98ed794d0d17870dc5e69e1c6a94f887936f..811f94976f68164232e8e7e4ec0da13c94709370 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "InstCombine.h"
 #include "llvm/IntrinsicInst.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -21,25 +22,6 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
   assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  // shl X, 0 == X and shr X, 0 == X
-  // shl 0, X == 0 and shr 0, X == 0
-  if (Op1 == Constant::getNullValue(Op1->getType()) ||
-      Op0 == Constant::getNullValue(Op0->getType()))
-    return ReplaceInstUsesWith(I, Op0);
-  
-  if (isa<UndefValue>(Op0)) {            
-    if (I.getOpcode() == Instruction::AShr) // undef >>s X -> undef
-      return ReplaceInstUsesWith(I, Op0);
-    else                                    // undef << X -> 0, undef >>u X -> 0
-      return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-  }
-  if (isa<UndefValue>(Op1)) {
-    if (I.getOpcode() == Instruction::AShr)  // X >>s undef -> X
-      return ReplaceInstUsesWith(I, Op0);          
-    else                                     // X << undef, X >>u undef -> 0
-      return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-  }
-
   // See if we can fold away this shift.
   if (SimplifyDemandedInstructionBits(I))
     return &I;
@@ -53,6 +35,20 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
   if (ConstantInt *CUI = dyn_cast<ConstantInt>(Op1))
     if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
       return Res;
+
+  // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2.
+  // Because shifts by negative values (which could occur if A were negative)
+  // are undefined.
+  Value *A; const APInt *B;
+  if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) {
+    // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
+    // demand the sign bit (and many others) here??
+    Value *Rem = Builder->CreateAnd(A, ConstantInt::get(I.getType(), *B-1),
+                                    Op1->getName());
+    I.setOperand(1, Rem);
+    return &I;
+  }
+  
   return 0;
 }
 
@@ -81,7 +77,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
   // if the needed bits are already zero in the input.  This allows us to reuse
   // the value which means that we don't care if the shift has multiple uses.
   //  TODO:  Handle opposite shift by exact value.
-  ConstantInt *CI;
+  ConstantInt *CI = 0;
   if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
       (!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
     if (CI->getZExtValue() == NumBits) {
@@ -157,7 +153,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
     if (CI->getZExtValue() > NumBits) {
       unsigned LowBits = CI->getZExtValue() - NumBits;
       if (MaskedValueIsZero(I->getOperand(0),
-                          APInt::getLowBitsSet(TypeWidth, LowBits) << NumBits))
+                          APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits))
         return true;
     }
       
@@ -622,16 +618,56 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1,
 }
 
 Instruction *InstCombiner::visitShl(BinaryOperator &I) {
-  return commonShiftTransforms(I);
+  if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
+                                 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+                                 TD))
+    return ReplaceInstUsesWith(I, V);
+  
+  if (Instruction *V = commonShiftTransforms(I))
+    return V;
+  
+  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) {
+    unsigned ShAmt = Op1C->getZExtValue();
+    
+    // If the shifted-out value is known-zero, then this is a NUW shift.
+    if (!I.hasNoUnsignedWrap() && 
+        MaskedValueIsZero(I.getOperand(0),
+                          APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt))) {
+          I.setHasNoUnsignedWrap();
+          return &I;
+        }
+    
+    // If the shifted out value is all signbits, this is a NSW shift.
+    if (!I.hasNoSignedWrap() &&
+        ComputeNumSignBits(I.getOperand(0)) > ShAmt) {
+      I.setHasNoSignedWrap();
+      return &I;
+    }
+  }
+
+  // (C1 << A) << C2 -> (C1 << C2) << A
+  Constant *C1, *C2;
+  Value *A;
+  if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) &&
+      match(I.getOperand(1), m_Constant(C2)))
+    return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
+
+  return 0;    
 }
 
 Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
+  if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1),
+                                  I.isExact(), TD))
+    return ReplaceInstUsesWith(I, V);
+
   if (Instruction *R = commonShiftTransforms(I))
     return R;
   
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   
-  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1))
+  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
+    unsigned ShAmt = Op1C->getZExtValue();
+
     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
       unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
       // ctlz.i32(x)>>5  --> zext(x == 0)
@@ -640,7 +676,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
       if ((II->getIntrinsicID() == Intrinsic::ctlz ||
            II->getIntrinsicID() == Intrinsic::cttz ||
            II->getIntrinsicID() == Intrinsic::ctpop) &&
-          isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == Op1C->getZExtValue()){
+          isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) {
         bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop;
         Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0);
         Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
@@ -648,29 +684,37 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
       }
     }
   
+    // If the shifted-out value is known-zero, then this is an exact shift.
+    if (!I.isExact() && 
+        MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){
+      I.setIsExact();
+      return &I;
+    }    
+  }
+  
   return 0;
 }
 
 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
+  if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1),
+                                  I.isExact(), TD))
+    return ReplaceInstUsesWith(I, V);
+
   if (Instruction *R = commonShiftTransforms(I))
     return R;
   
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-  
-  if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0)) {
-    // ashr int -1, X = -1   (for any arithmetic shift rights of ~0)
-    if (CSI->isAllOnesValue())
-      return ReplaceInstUsesWith(I, CSI);
-  }
-  
+
   if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
+    unsigned ShAmt = Op1C->getZExtValue();
+    
     // If the input is a SHL by the same constant (ashr (shl X, C), C), then we
     // have a sign-extend idiom.
     Value *X;
     if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) {
-      // If the input value is known to already be sign extended enough, delete
-      // the extension.
-      if (ComputeNumSignBits(X) > Op1C->getZExtValue())
+      // If the left shift is just shifting out partial signbits, delete the
+      // extension.
+      if (cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap())
         return ReplaceInstUsesWith(I, X);
 
       // If the input is an extension from the shifted amount value, e.g.
@@ -685,6 +729,13 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
           return new SExtInst(ZI->getOperand(0), ZI->getType());
       }
     }
+
+    // If the shifted-out value is known-zero, then this is an exact shift.
+    if (!I.isExact() && 
+        MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){
+      I.setIsExact();
+      return &I;
+    }
   }            
   
   // See if we can turn a signed shr into an unsigned shr.