Make instcombine a little more aggressive in combining vector shuffles.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineVectorOps.cpp
index a58124d7032e0c046a00e3f389146270ae641939..634add86a81cd785fd256fef1f7b4f463b81f5a8 100644 (file)
@@ -515,37 +515,44 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
   // we are absolutely afraid of producing a shuffle mask not in the input
   // program, because the code gen may not be smart enough to turn a merged
   // shuffle into two specific shuffles: it may produce worse code.  As such,
-  // we only merge two shuffles if the result is one of the two input shuffle
-  // masks.  In this case, merging the shuffles just removes one instruction,
-  // which we know is safe.  This is good for things like turning:
-  // (splat(splat)) -> splat.
+  // we only merge two shuffles if the result is either a splat or one of the
+  // two input shuffle masks.  In this case, merging the shuffles just removes
+  // one instruction, which we know is safe.  This is good for things like
+  // turning: (splat(splat)) -> splat.
   if (ShuffleVectorInst *LHSSVI = dyn_cast<ShuffleVectorInst>(LHS)) {
     if (isa<UndefValue>(RHS)) {
       std::vector<unsigned> LHSMask = getShuffleMask(LHSSVI);
       
       if (LHSMask.size() == Mask.size()) {
         std::vector<unsigned> NewMask;
-        for (unsigned i = 0, e = Mask.size(); i != e; ++i)
-          if (Mask[i] >= e)
-            NewMask.push_back(2*e);
-          else
-            NewMask.push_back(LHSMask[Mask[i]]);
+        bool isSplat = true;
+        unsigned SplatElt = 2 * Mask.size(); // undef
+        for (unsigned i = 0, e = Mask.size(); i != e; ++i) {
+          unsigned MaskElt = 2 * e; // undef
+          if (Mask[i] < e)
+            MaskElt = LHSMask[Mask[i]];
+          // Check if this could still be a splat.
+          if (MaskElt < 2*e) {
+            if (SplatElt < 2*e && SplatElt != MaskElt)
+              isSplat = false;
+            SplatElt = MaskElt;
+          }
+          NewMask.push_back(MaskElt);
+        }
         
         // If the result mask is equal to the src shuffle or this
         // shuffle mask, do the replacement.
-        if (NewMask == LHSMask || NewMask == Mask) {
+        if (isSplat || NewMask == LHSMask || NewMask == Mask) {
           unsigned LHSInNElts =
           cast<VectorType>(LHSSVI->getOperand(0)->getType())->
           getNumElements();
           std::vector<Constant*> Elts;
+          const Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
           for (unsigned i = 0, e = NewMask.size(); i != e; ++i) {
             if (NewMask[i] >= LHSInNElts*2) {
-              Elts.push_back(UndefValue::get(
-                                             Type::getInt32Ty(SVI.getContext())));
+              Elts.push_back(UndefValue::get(Int32Ty));
             } else {
-              Elts.push_back(ConstantInt::get(
-                                              Type::getInt32Ty(SVI.getContext()),
-                                              NewMask[i]));
+              Elts.push_back(ConstantInt::get(Int32Ty, NewMask[i]));
             }
           }
           return new ShuffleVectorInst(LHSSVI->getOperand(0),