InstCombine: ((A & ~B) ^ (~A & B)) to A ^ B
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineShifts.cpp
index 536b148d29cdc9251c18484213115b6682322b5b..3d0cc05f30c6fb407325da6910b6122adf415fd5 100644 (file)
@@ -19,6 +19,8 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+#define DEBUG_TYPE "instcombine"
+
 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
   assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -50,7 +52,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
     return &I;
   }
 
-  return 0;
+  return nullptr;
 }
 
 /// CanEvaluateShifted - See if we can compute the specified value, but shifted
@@ -78,7 +80,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
   // if the needed bits are already zero in the input.  This allows us to reuse
   // the value which means that we don't care if the shift has multiple uses.
   //  TODO:  Handle opposite shift by exact value.
-  ConstantInt *CI = 0;
+  ConstantInt *CI = nullptr;
   if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
       (!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
     if (CI->getZExtValue() == NumBits) {
@@ -115,7 +117,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
   case Instruction::Shl: {
     // We can often fold the shift into shifts-by-a-constant.
     CI = dyn_cast<ConstantInt>(I->getOperand(1));
-    if (CI == 0) return false;
+    if (!CI) return false;
 
     // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
     if (isLeftShift) return true;
@@ -139,7 +141,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
   case Instruction::LShr: {
     // We can often fold the shift into shifts-by-a-constant.
     CI = dyn_cast<ConstantInt>(I->getOperand(1));
-    if (CI == 0) return false;
+    if (!CI) return false;
 
     // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
     if (!isLeftShift) return true;
@@ -335,28 +337,19 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
                  GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this));
   }
 
-
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
   uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
 
-  // shl i32 X, 32 = 0 and srl i8 Y, 9 = 0, ... just don't eliminate
-  // a signed shift.
-  //
-  if (COp1->uge(TypeBits)) {
-    if (I.getOpcode() != Instruction::AShr)
-      return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType()));
-    // ashr i32 X, 32 --> ashr i32 X, 31
-    I.setOperand(1, ConstantInt::get(I.getType(), TypeBits-1));
-    return &I;
-  }
+  assert(!COp1->uge(TypeBits) &&
+         "Shift over the type width should have been removed already");
 
   // ((X*C1) << C2) == (X * (C1 << C2))
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0))
     if (BO->getOpcode() == Instruction::Mul && isLeftShift)
       if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1)))
         return BinaryOperator::CreateMul(BO->getOperand(0),
-                                        ConstantExpr::getShl(BOOp, COp1));
+                                        ConstantExpr::getShl(BOOp, Op1));
 
   // Try to fold constant and into select arguments.
   if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
@@ -432,8 +425,12 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
           Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
                                           Op0BO->getOperand(1)->getName());
           uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
-          return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(),
-                     APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val)));
+
+          APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
+          Constant *Mask = ConstantInt::get(I.getContext(), Bits);
+          if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
+            Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+          return BinaryOperator::CreateAnd(X, Mask);
         }
 
         // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
@@ -464,8 +461,12 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
           Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
                                           Op0BO->getOperand(0)->getName());
           uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
-          return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getContext(),
-                     APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val)));
+
+          APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
+          Constant *Mask = ConstantInt::get(I.getContext(), Bits);
+          if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
+            Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+          return BinaryOperator::CreateAnd(X, Mask);
         }
 
         // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
@@ -487,7 +488,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
       }
 
 
-      // If the operand is an bitwise operator with a constant RHS, and the
+      // If the operand is a bitwise operator with a constant RHS, and the
       // shift is the only use, we can pull it out of the shift.
       if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) {
         bool isValid = true;     // Valid only for And, Or, Xor
@@ -533,7 +534,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
   // Find out if this is a shift of a shift by a constant.
   BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0);
   if (ShiftOp && !ShiftOp->isShift())
-    ShiftOp = 0;
+    ShiftOp = nullptr;
 
   if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) {
 
@@ -553,7 +554,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
     uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
     uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
     assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
-    if (ShiftAmt1 == 0) return 0;  // Will be simplified in the future.
+    if (ShiftAmt1 == 0) return nullptr;  // Will be simplified in the future.
     Value *X = ShiftOp->getOperand(0);
 
     IntegerType *Ty = cast<IntegerType>(I.getType());
@@ -681,10 +682,13 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
       }
     }
   }
-  return 0;
+  return nullptr;
 }
 
 Instruction *InstCombiner::visitShl(BinaryOperator &I) {
+  if (Value *V = SimplifyVectorOp(I))
+    return ReplaceInstUsesWith(I, V);
+
   if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
                                  DL))
@@ -719,10 +723,13 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
       match(I.getOperand(1), m_Constant(C2)))
     return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
 
-  return 0;
+  return nullptr;
 }
 
 Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
+  if (Value *V = SimplifyVectorOp(I))
+    return ReplaceInstUsesWith(I, V);
+
   if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1),
                                   I.isExact(), DL))
     return ReplaceInstUsesWith(I, V);
@@ -759,10 +766,13 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
     }
   }
 
-  return 0;
+  return nullptr;
 }
 
 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
+  if (Value *V = SimplifyVectorOp(I))
+    return ReplaceInstUsesWith(I, V);
+
   if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1),
                                   I.isExact(), DL))
     return ReplaceInstUsesWith(I, V);
@@ -779,11 +789,6 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
     // have a sign-extend idiom.
     Value *X;
     if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) {
-      // If the left shift is just shifting out partial signbits, delete the
-      // extension.
-      if (cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap())
-        return ReplaceInstUsesWith(I, X);
-
       // If the input is an extension from the shifted amount value, e.g.
       //   %x = zext i8 %A to i32
       //   %y = shl i32 %x, 24
@@ -810,11 +815,5 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
                         APInt::getSignBit(I.getType()->getScalarSizeInBits())))
     return BinaryOperator::CreateLShr(Op0, Op1);
 
-  // Arithmetic shifting an all-sign-bit value is a no-op.
-  unsigned NumSignBits = ComputeNumSignBits(Op0);
-  if (NumSignBits == Op0->getType()->getScalarSizeInBits())
-    return ReplaceInstUsesWith(I, Op0);
-
-  return 0;
+  return nullptr;
 }
-