Add missing newlines at EOF (for clang++).
[oota-llvm.git] / lib / Transforms / Scalar / InstructionCombining.cpp
index efda6608e44d1a67d44ead393a6f611c730ab34e..a6e0eef854d4e3b719a423ed54dbe44253e6145f 100644 (file)
@@ -2163,8 +2163,8 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
   
   // Add has the property that adding any two 2's complement numbers can only 
   // have one carry bit which can change a sign.  As such, if LHS and RHS each
-  // have at least two sign bits, we know that the addition of the two values will
-  // sign extend fine.
+  // have at least two sign bits, we know that the addition of the two values
+  // will sign extend fine.
   if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
     return true;
   
@@ -2184,15 +2184,12 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   bool Changed = SimplifyCommutative(I);
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
 
-  if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
-    // X + undef -> undef
-    if (isa<UndefValue>(RHS))
-      return ReplaceInstUsesWith(I, RHS);
-
-    // X + 0 --> X
-    if (RHSC->isNullValue())
-      return ReplaceInstUsesWith(I, LHS);
+  if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
+                                 I.hasNoUnsignedWrap(), TD))
+    return ReplaceInstUsesWith(I, V);
 
+  
+  if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(RHSC)) {
       // X + (signbit) --> X ^ signbit
       const APInt& Val = CI->getValue();
@@ -4070,6 +4067,21 @@ Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS,
 /// FoldAndOfICmps - Fold (icmp)&(icmp) if possible.
 Instruction *InstCombiner::FoldAndOfICmps(Instruction &I,
                                           ICmpInst *LHS, ICmpInst *RHS) {
+  // (icmp eq A, null) & (icmp eq B, null) -->
+  //     (icmp eq (ptrtoint(A)|ptrtoint(B)), 0)
+  if (TD &&
+      LHS->getPredicate() == ICmpInst::ICMP_EQ &&
+      RHS->getPredicate() == ICmpInst::ICMP_EQ &&
+      isa<ConstantPointerNull>(LHS->getOperand(1)) &&
+      isa<ConstantPointerNull>(RHS->getOperand(1))) {
+    const Type *IntPtrTy = TD->getIntPtrType(I.getContext());
+    Value *A = Builder->CreatePtrToInt(LHS->getOperand(0), IntPtrTy);
+    Value *B = Builder->CreatePtrToInt(RHS->getOperand(0), IntPtrTy);
+    Value *NewOr = Builder->CreateOr(A, B);
+    return new ICmpInst(ICmpInst::ICMP_EQ, NewOr,
+                        Constant::getNullValue(IntPtrTy));
+  }
+  
   Value *Val, *Val2;
   ConstantInt *LHSCst, *RHSCst;
   ICmpInst::Predicate LHSCC, RHSCC;
@@ -4081,12 +4093,20 @@ Instruction *InstCombiner::FoldAndOfICmps(Instruction &I,
                          m_ConstantInt(RHSCst))))
     return 0;
   
-  // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C)
-  // where C is a power of 2
-  if (LHSCst == RHSCst && LHSCC == RHSCC && LHSCC == ICmpInst::ICMP_ULT &&
-      LHSCst->getValue().isPowerOf2()) {
-    Value *NewOr = Builder->CreateOr(Val, Val2);
-    return new ICmpInst(LHSCC, NewOr, LHSCst);
+  if (LHSCst == RHSCst && LHSCC == RHSCC) {
+    // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C)
+    // where C is a power of 2
+    if (LHSCC == ICmpInst::ICMP_ULT &&
+        LHSCst->getValue().isPowerOf2()) {
+      Value *NewOr = Builder->CreateOr(Val, Val2);
+      return new ICmpInst(LHSCC, NewOr, LHSCst);
+    }
+    
+    // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0)
+    if (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero()) {
+      Value *NewOr = Builder->CreateOr(Val, Val2);
+      return new ICmpInst(LHSCC, NewOr, LHSCst);
+    }
   }
   
   // From here on, we only handle:
@@ -4322,7 +4342,6 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
 
   if (Value *V = SimplifyAndInst(Op0, Op1, TD))
     return ReplaceInstUsesWith(I, V);
-    
 
   // See if we can simplify any instructions used by the instruction whose sole 
   // purpose is to compute bits we don't care about.
@@ -4743,16 +4762,37 @@ static Instruction *MatchSelectFromAndOr(Value *A, Value *B,
 /// FoldOrOfICmps - Fold (icmp)|(icmp) if possible.
 Instruction *InstCombiner::FoldOrOfICmps(Instruction &I,
                                          ICmpInst *LHS, ICmpInst *RHS) {
+  // (icmp ne A, null) | (icmp ne B, null) -->
+  //     (icmp ne (ptrtoint(A)|ptrtoint(B)), 0)
+  if (TD &&
+      LHS->getPredicate() == ICmpInst::ICMP_NE &&
+      RHS->getPredicate() == ICmpInst::ICMP_NE &&
+      isa<ConstantPointerNull>(LHS->getOperand(1)) &&
+      isa<ConstantPointerNull>(RHS->getOperand(1))) {
+    const Type *IntPtrTy = TD->getIntPtrType(I.getContext());
+    Value *A = Builder->CreatePtrToInt(LHS->getOperand(0), IntPtrTy);
+    Value *B = Builder->CreatePtrToInt(RHS->getOperand(0), IntPtrTy);
+    Value *NewOr = Builder->CreateOr(A, B);
+    return new ICmpInst(ICmpInst::ICMP_NE, NewOr,
+                        Constant::getNullValue(IntPtrTy));
+  }
+  
   Value *Val, *Val2;
   ConstantInt *LHSCst, *RHSCst;
   ICmpInst::Predicate LHSCC, RHSCC;
   
   // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
-  if (!match(LHS, m_ICmp(LHSCC, m_Value(Val),
-             m_ConstantInt(LHSCst))) ||
-      !match(RHS, m_ICmp(RHSCC, m_Value(Val2),
-             m_ConstantInt(RHSCst))))
+  if (!match(LHS, m_ICmp(LHSCC, m_Value(Val), m_ConstantInt(LHSCst))) ||
+      !match(RHS, m_ICmp(RHSCC, m_Value(Val2), m_ConstantInt(RHSCst))))
     return 0;
+
+  
+  // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
+  if (LHSCst == RHSCst && LHSCC == RHSCC &&
+      LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) {
+    Value *NewOr = Builder->CreateOr(Val, Val2);
+    return new ICmpInst(LHSCC, NewOr, LHSCst);
+  }
   
   // From here on, we only handle:
   //    (icmp1 A, C1) | (icmp2 A, C2) --> something simpler.
@@ -6316,24 +6356,26 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         // comparison into the select arms, which will cause one to be
         // constant folded and the select turned into a bitwise or.
         Value *Op1 = 0, *Op2 = 0;
-        if (LHSI->hasOneUse()) {
-          if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) {
-            // Fold the known value into the constant operand.
-            Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
-            // Insert a new ICmp of the other select operand.
-            Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2),
-                                      RHSC, I.getName());
-          } else if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) {
-            // Fold the known value into the constant operand.
-            Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
-            // Insert a new ICmp of the other select operand.
+        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1)))
+          Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2)))
+          Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+
+        // We only want to perform this transformation if it will not lead to
+        // additional code. This is true if either both sides of the select
+        // fold to a constant (in which case the icmp is replaced with a select
+        // which will usually simplify) or this is the only user of the
+        // select (in which case we are trading a select+icmp for a simpler
+        // select+icmp).
+        if ((Op1 && Op2) || (LHSI->hasOneUse() && (Op1 || Op2))) {
+          if (!Op1)
             Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1),
                                       RHSC, I.getName());
-          }
-        }
-
-        if (Op1)
+          if (!Op2)
+            Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2),
+                                      RHSC, I.getName());
           return SelectInst::Create(LHSI->getOperand(0), Op1, Op2);
+        }
         break;
       }
       case Instruction::Call:
@@ -6412,7 +6454,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     //   if (X) ...
     // For generality, we handle any zero-extension of any operand comparison
     // with a constant or another cast from the same type.
-    if (isa<ConstantInt>(Op1) || isa<CastInst>(Op1))
+    if (isa<Constant>(Op1) || isa<CastInst>(Op1))
       if (Instruction *R = visitICmpInstWithCastAndCast(I))
         return R;
   }
@@ -7259,19 +7301,17 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
 
   // If the re-extended constant didn't change...
   if (Res2 == CI) {
-    // Make sure that sign of the Cmp and the sign of the Cast are the same.
-    // For example, we might have:
-    //    %A = sext i16 %X to i32
-    //    %B = icmp ugt i32 %A, 1330
-    // It is incorrect to transform this into 
-    //    %B = icmp ugt i16 %X, 1330
-    // because %A may have negative value. 
-    //
-    // However, we allow this when the compare is EQ/NE, because they are
-    // signless.
-    if (isSignedExt == isSignedCmp || ICI.isEquality())
+    // Deal with equality cases early.
+    if (ICI.isEquality())
       return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1);
-    return 0;
+
+    // A signed comparison of sign extended values simplifies into a
+    // signed comparison.
+    if (isSignedExt && isSignedCmp)
+      return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1);
+
+    // The other three cases all fold into an unsigned comparison.
+    return new ICmpInst(ICI.getUnsignedPredicate(), LHSCIOp, Res1);
   }
 
   // The re-extended constant changed so the constant cannot be represented 
@@ -8539,6 +8579,47 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
     }
   }
 
+  // icmp ne A, B is equal to xor A, B when A and B only really have one bit.
+  // It is also profitable to transform icmp eq into not(xor(A, B)) because that
+  // may lead to additional simplifications.
+  if (ICI->isEquality() && CI.getType() == ICI->getOperand(0)->getType()) {
+    if (const IntegerType *ITy = dyn_cast<IntegerType>(CI.getType())) {
+      uint32_t BitWidth = ITy->getBitWidth();
+      Value *LHS = ICI->getOperand(0);
+      Value *RHS = ICI->getOperand(1);
+
+      APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0);
+      APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0);
+      APInt TypeMask(APInt::getAllOnesValue(BitWidth));
+      ComputeMaskedBits(LHS, TypeMask, KnownZeroLHS, KnownOneLHS);
+      ComputeMaskedBits(RHS, TypeMask, KnownZeroRHS, KnownOneRHS);
+
+      if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) {
+        APInt KnownBits = KnownZeroLHS | KnownOneLHS;
+        APInt UnknownBit = ~KnownBits;
+        if (UnknownBit.countPopulation() == 1) {
+          if (!DoXform) return ICI;
+
+          Value *Result = Builder->CreateXor(LHS, RHS);
+
+          // Mask off any bits that are set and won't be shifted away.
+          if (KnownOneLHS.uge(UnknownBit))
+            Result = Builder->CreateAnd(Result,
+                                        ConstantInt::get(ITy, UnknownBit));
+
+          // Shift the bit we're testing down to the lsb.
+          Result = Builder->CreateLShr(
+               Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros()));
+
+          if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
+            Result = Builder->CreateXor(Result, ConstantInt::get(ITy, 1));
+          Result->takeName(ICI);
+          return ReplaceInstUsesWith(CI, Result);
+        }
+      }
+    }
+  }
+
   return 0;
 }
 
@@ -9815,9 +9896,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
                         Intrinsic::getDeclaration(M, MemCpyID, Tys, 1));
           Changed = true;
         }
+    }
 
+    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
       // memmove(x,x,size) -> noop.
-      if (MMI->getSource() == MMI->getDest())
+      if (MTI->getSource() == MTI->getDest())
         return EraseInstFromFunction(CI);
     }
 
@@ -9842,6 +9925,126 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       if (Operand->getIntrinsicID() == Intrinsic::bswap)
         return ReplaceInstUsesWith(CI, Operand->getOperand(1));
     break;
+  case Intrinsic::uadd_with_overflow: {
+    Value *LHS = II->getOperand(1), *RHS = II->getOperand(2);
+    const IntegerType *IT = cast<IntegerType>(II->getOperand(1)->getType());
+    uint32_t BitWidth = IT->getBitWidth();
+    APInt Mask = APInt::getSignBit(BitWidth);
+    APInt LHSKnownZero(BitWidth, 0);
+    APInt LHSKnownOne(BitWidth, 0);
+    ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne);
+    bool LHSKnownNegative = LHSKnownOne[BitWidth - 1];
+    bool LHSKnownPositive = LHSKnownZero[BitWidth - 1];
+
+    if (LHSKnownNegative || LHSKnownPositive) {
+      APInt RHSKnownZero(BitWidth, 0);
+      APInt RHSKnownOne(BitWidth, 0);
+      ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne);
+      bool RHSKnownNegative = RHSKnownOne[BitWidth - 1];
+      bool RHSKnownPositive = RHSKnownZero[BitWidth - 1];
+      if (LHSKnownNegative && RHSKnownNegative) {
+        // The sign bit is set in both cases: this MUST overflow.
+        // Create a simple add instruction, and insert it into the struct.
+        Instruction *Add = BinaryOperator::CreateAdd(LHS, RHS, "", &CI);
+        Worklist.Add(Add);
+        Constant *V[] = {
+          UndefValue::get(LHS->getType()), ConstantInt::getTrue(*Context)
+        };
+        Constant *Struct = ConstantStruct::get(*Context, V, 2, false);
+        return InsertValueInst::Create(Struct, Add, 0);
+      }
+      
+      if (LHSKnownPositive && RHSKnownPositive) {
+        // The sign bit is clear in both cases: this CANNOT overflow.
+        // Create a simple add instruction, and insert it into the struct.
+        Instruction *Add = BinaryOperator::CreateNUWAdd(LHS, RHS, "", &CI);
+        Worklist.Add(Add);
+        Constant *V[] = {
+          UndefValue::get(LHS->getType()), ConstantInt::getFalse(*Context)
+        };
+        Constant *Struct = ConstantStruct::get(*Context, V, 2, false);
+        return InsertValueInst::Create(Struct, Add, 0);
+      }
+    }
+  }
+  // FALL THROUGH uadd into sadd
+  case Intrinsic::sadd_with_overflow:
+    // Canonicalize constants into the RHS.
+    if (isa<Constant>(II->getOperand(1)) &&
+        !isa<Constant>(II->getOperand(2))) {
+      Value *LHS = II->getOperand(1);
+      II->setOperand(1, II->getOperand(2));
+      II->setOperand(2, LHS);
+      return II;
+    }
+
+    // X + undef -> undef
+    if (isa<UndefValue>(II->getOperand(2)))
+      return ReplaceInstUsesWith(CI, UndefValue::get(II->getType()));
+      
+    if (ConstantInt *RHS = dyn_cast<ConstantInt>(II->getOperand(2))) {
+      // X + 0 -> {X, false}
+      if (RHS->isZero()) {
+        Constant *V[] = {
+          UndefValue::get(II->getOperand(0)->getType()),
+          ConstantInt::getFalse(*Context)
+        };
+        Constant *Struct = ConstantStruct::get(*Context, V, 2, false);
+        return InsertValueInst::Create(Struct, II->getOperand(1), 0);
+      }
+    }
+    break;
+  case Intrinsic::usub_with_overflow:
+  case Intrinsic::ssub_with_overflow:
+    // undef - X -> undef
+    // X - undef -> undef
+    if (isa<UndefValue>(II->getOperand(1)) ||
+        isa<UndefValue>(II->getOperand(2)))
+      return ReplaceInstUsesWith(CI, UndefValue::get(II->getType()));
+      
+    if (ConstantInt *RHS = dyn_cast<ConstantInt>(II->getOperand(2))) {
+      // X - 0 -> {X, false}
+      if (RHS->isZero()) {
+        Constant *V[] = {
+          UndefValue::get(II->getOperand(1)->getType()),
+          ConstantInt::getFalse(*Context)
+        };
+        Constant *Struct = ConstantStruct::get(*Context, V, 2, false);
+        return InsertValueInst::Create(Struct, II->getOperand(1), 0);
+      }
+    }
+    break;
+  case Intrinsic::umul_with_overflow:
+  case Intrinsic::smul_with_overflow:
+    // Canonicalize constants into the RHS.
+    if (isa<Constant>(II->getOperand(1)) &&
+        !isa<Constant>(II->getOperand(2))) {
+      Value *LHS = II->getOperand(1);
+      II->setOperand(1, II->getOperand(2));
+      II->setOperand(2, LHS);
+      return II;
+    }
+
+    // X * undef -> undef
+    if (isa<UndefValue>(II->getOperand(2)))
+      return ReplaceInstUsesWith(CI, UndefValue::get(II->getType()));
+      
+    if (ConstantInt *RHSI = dyn_cast<ConstantInt>(II->getOperand(2))) {
+      // X*0 -> {0, false}
+      if (RHSI->isZero())
+        return ReplaceInstUsesWith(CI, Constant::getNullValue(II->getType()));
+      
+      // X * 1 -> {X, false}
+      if (RHSI->equalsInt(1)) {
+        Constant *V[] = {
+          UndefValue::get(II->getOperand(1)->getType()),
+          ConstantInt::getFalse(*Context)
+        };
+        Constant *Struct = ConstantStruct::get(*Context, V, 2, false);
+        return InsertValueInst::Create(Struct, II->getOperand(1), 0);
+      }
+    }
+    break;
   case Intrinsic::ppc_altivec_lvx:
   case Intrinsic::ppc_altivec_lvxl:
   case Intrinsic::x86_sse_loadu_ps:
@@ -10999,8 +11202,9 @@ namespace llvm {
       return LHS.PN == RHS.PN && LHS.Shift == RHS.Shift &&
              LHS.Width == RHS.Width;
     }
-    static bool isPod() { return true; }
   };
+  template <>
+  struct isPodLike<LoweredPHIRecord> { static const bool value = true; };
 }
 
 
@@ -11282,21 +11486,16 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) {
 }
 
 Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
+
+  if (Value *V = SimplifyGEPInst(&Ops[0], Ops.size(), TD))
+    return ReplaceInstUsesWith(GEP, V);
+
   Value *PtrOp = GEP.getOperand(0);
-  // Eliminate 'getelementptr %P, i32 0' and 'getelementptr %P', they are noops.
-  if (GEP.getNumOperands() == 1)
-    return ReplaceInstUsesWith(GEP, PtrOp);
 
   if (isa<UndefValue>(GEP.getOperand(0)))
     return ReplaceInstUsesWith(GEP, UndefValue::get(GEP.getType()));
 
-  bool HasZeroPointerIndex = false;
-  if (Constant *C = dyn_cast<Constant>(GEP.getOperand(1)))
-    HasZeroPointerIndex = C->isNullValue();
-
-  if (GEP.getNumOperands() == 2 && HasZeroPointerIndex)
-    return ReplaceInstUsesWith(GEP, PtrOp);
-
   // Eliminate unneeded casts for indices.
   if (TD) {
     bool MadeChange = false;
@@ -11401,6 +11600,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       return 0;
     }
     
+    bool HasZeroPointerIndex = false;
+    if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1)))
+      HasZeroPointerIndex = C->isZero();
+    
     // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ...
     // into     : GEP [10 x i8]* X, i32 0, ...
     //
@@ -11952,12 +12155,6 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
   Value *Val = SI.getOperand(0);
   Value *Ptr = SI.getOperand(1);
 
-  if (isa<UndefValue>(Ptr)) {     // store X, undef -> noop (even if volatile)
-    EraseInstFromFunction(SI);
-    ++NumCombined;
-    return 0;
-  }
-  
   // If the RHS is an alloca with a single use, zapify the store, making the
   // alloca dead.
   // If the RHS is an alloca with a two uses, the other one being a 
@@ -12917,29 +13114,33 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
     if (isa<UndefValue>(RHS)) {
       std::vector<unsigned> LHSMask = getShuffleMask(LHSSVI);
 
-      std::vector<unsigned> NewMask;
-      for (unsigned i = 0, e = Mask.size(); i != e; ++i)
-        if (Mask[i] >= 2*e)
-          NewMask.push_back(2*e);
-        else
-          NewMask.push_back(LHSMask[Mask[i]]);
+      if (LHSMask.size() == Mask.size()) {
+        std::vector<unsigned> NewMask;
+        for (unsigned i = 0, e = Mask.size(); i != e; ++i)
+          if (Mask[i] >= e)
+            NewMask.push_back(2*e);
+          else
+            NewMask.push_back(LHSMask[Mask[i]]);
       
-      // If the result mask is equal to the src shuffle or this shuffle mask, do
-      // the replacement.
-      if (NewMask == LHSMask || NewMask == Mask) {
-        unsigned LHSInNElts =
-          cast<VectorType>(LHSSVI->getOperand(0)->getType())->getNumElements();
-        std::vector<Constant*> Elts;
-        for (unsigned i = 0, e = NewMask.size(); i != e; ++i) {
-          if (NewMask[i] >= LHSInNElts*2) {
-            Elts.push_back(UndefValue::get(Type::getInt32Ty(*Context)));
-          } else {
-            Elts.push_back(ConstantInt::get(Type::getInt32Ty(*Context), NewMask[i]));
+        // If the result mask is equal to the src shuffle or this
+        // shuffle mask, do the replacement.
+        if (NewMask == LHSMask || NewMask == Mask) {
+          unsigned LHSInNElts =
+            cast<VectorType>(LHSSVI->getOperand(0)->getType())->
+            getNumElements();
+          std::vector<Constant*> Elts;
+          for (unsigned i = 0, e = NewMask.size(); i != e; ++i) {
+            if (NewMask[i] >= LHSInNElts*2) {
+              Elts.push_back(UndefValue::get(Type::getInt32Ty(*Context)));
+            } else {
+              Elts.push_back(ConstantInt::get(Type::getInt32Ty(*Context),
+                                              NewMask[i]));
+            }
           }
+          return new ShuffleVectorInst(LHSSVI->getOperand(0),
+                                       LHSSVI->getOperand(1),
+                                       ConstantVector::get(Elts));
         }
-        return new ShuffleVectorInst(LHSSVI->getOperand(0),
-                                     LHSSVI->getOperand(1),
-                                     ConstantVector::get(Elts));
       }
     }
   }