PR16726: extend rol/ror matching
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 98806551e4d70b9e25cb28a43e00270b0fb97203..b18c69b52a799822a63a925595c6669972b39cfe 100644 (file)
@@ -35,6 +35,7 @@
 #include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Target/TargetSubtargetInfo.h"
 #include <algorithm>
 using namespace llvm;
 
@@ -154,7 +155,7 @@ namespace {
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
 
-    void ExtendSetCCUses(SmallVector<SDNode*, 4> SetCCs,
+    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                          SDValue Trunc, SDValue ExtLoad, SDLoc DL,
                          ISD::NodeType ExtType);
 
@@ -279,7 +280,7 @@ namespace {
     /// GatherAllAliases - Walk up chain skipping non-aliasing memory nodes,
     /// looking for aliasing nodes and adding them to the Aliases vector.
     void GatherAllAliases(SDNode *N, SDValue OriginalChain,
-                          SmallVector<SDValue, 8> &Aliases);
+                          SmallVectorImpl<SDValue> &Aliases);
 
     /// isAlias - Return true if there is any possibility that the two addresses
     /// overlap.
@@ -1823,20 +1824,24 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   // fold (mul x, 0) -> 0
   if (N1IsConst && ConstValue1 == 0)
     return N1;
+  // We require a splat of the entire scalar bit width for non-contiguous
+  // bit patterns.
+  bool IsFullSplat =
+    ConstValue1.getBitWidth() == VT.getScalarType().getSizeInBits();
   // fold (mul x, 1) -> x
-  if (N1IsConst && ConstValue1 == 1)
+  if (N1IsConst && ConstValue1 == 1 && IsFullSplat)
     return N0;
   // fold (mul x, -1) -> 0-x
   if (N1IsConst && ConstValue1.isAllOnesValue())
     return DAG.getNode(ISD::SUB, SDLoc(N), VT,
                        DAG.getConstant(0, VT), N0);
   // fold (mul x, (1 << c)) -> x << c
-  if (N1IsConst && ConstValue1.isPowerOf2())
+  if (N1IsConst && ConstValue1.isPowerOf2() && IsFullSplat)
     return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
                        DAG.getConstant(ConstValue1.logBase2(),
                                        getShiftAmountTy(N0.getValueType())));
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
-  if (N1IsConst && (-ConstValue1).isPowerOf2()) {
+  if (N1IsConst && (-ConstValue1).isPowerOf2() && IsFullSplat) {
     unsigned Log2Val = (-ConstValue1).logBase2();
     // FIXME: If the input is something that is easily negated (e.g. a
     // single-use add), we should put the negate there.
@@ -2675,6 +2680,19 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
         return DAG.getSetCC(SDLoc(N), VT, ORNode, LR, Op1);
       }
     }
+    // Simplify (and (setne X, 0), (setne X, -1)) -> (setuge (add X, 1), 2)
+    if (LL == RL && isa<ConstantSDNode>(LR) && isa<ConstantSDNode>(RR) &&
+        Op0 == Op1 && LL.getValueType().isInteger() &&
+      Op0 == ISD::SETNE && ((cast<ConstantSDNode>(LR)->isNullValue() &&
+                                 cast<ConstantSDNode>(RR)->isAllOnesValue()) ||
+                                (cast<ConstantSDNode>(LR)->isAllOnesValue() &&
+                                 cast<ConstantSDNode>(RR)->isNullValue()))) {
+      SDValue ADDNode = DAG.getNode(ISD::ADD, SDLoc(N0), LL.getValueType(),
+                                    LL, DAG.getConstant(1, LL.getValueType()));
+      AddToWorkList(ADDNode.getNode());
+      return DAG.getSetCC(SDLoc(N), VT, ADDNode,
+                          DAG.getConstant(2, LL.getValueType()), ISD::SETUGE);
+    }
     // canonicalize equivalent to ll == rl
     if (LL == RR && LR == RL) {
       Op1 = ISD::getSetCCSwappedOperands(Op1);
@@ -2848,6 +2866,14 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     }
   }
 
+  // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
+  if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
+    SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
+                                       N0.getOperand(1), false);
+    if (BSwap.getNode())
+      return BSwap;
+  }
+
   return SDValue();
 }
 
@@ -2932,13 +2958,23 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
   if (N00 != N10)
     return SDValue();
 
-  // Make sure everything beyond the low halfword is zero since the SRL 16
-  // will clear the top bits.
+  // Make sure everything beyond the low halfword gets set to zero since the SRL
+  // 16 will clear the top bits.
   unsigned OpSizeInBits = VT.getSizeInBits();
-  if (DemandHighBits && OpSizeInBits > 16 &&
-      (!LookPassAnd0 || !LookPassAnd1) &&
-      !DAG.MaskedValueIsZero(N10, APInt::getHighBitsSet(OpSizeInBits, 16)))
-    return SDValue();
+  if (DemandHighBits && OpSizeInBits > 16) {
+    // If the left-shift isn't masked out then the only way this is a bswap is
+    // if all bits beyond the low 8 are 0. In that case the entire pattern
+    // reduces to a left shift anyway: leave it for other parts of the combiner.
+    if (!LookPassAnd0)
+      return SDValue();
+
+    // However, if the right shift isn't masked out then it might be because
+    // it's not needed. See if we can spot that too.
+    if (!LookPassAnd1 &&
+        !DAG.MaskedValueIsZero(
+            N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16)))
+      return SDValue();
+  }
 
   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
   if (OpSizeInBits > 16)
@@ -2950,7 +2986,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
 /// isBSwapHWordElement - Return true if the specified node is an element
 /// that makes up a 32-bit packed halfword byteswap. i.e.
 /// ((x&0xff)<<8)|((x&0xff00)>>8)|((x&0x00ff0000)<<8)|((x&0xff000000)>>8)
-static bool isBSwapHWordElement(SDValue N, SmallVector<SDNode*,4> &Parts) {
+static bool isBSwapHWordElement(SDValue N, SmallVectorImpl<SDNode *> &Parts) {
   if (!N.getNode()->hasOneUse())
     return false;
 
@@ -3305,6 +3341,7 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
   unsigned OpSizeInBits = VT.getSizeInBits();
   SDValue LHSShiftArg = LHSShift.getOperand(0);
   SDValue LHSShiftAmt = LHSShift.getOperand(1);
+  SDValue RHSShiftArg = RHSShift.getOperand(0);
   SDValue RHSShiftAmt = RHSShift.getOperand(1);
 
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
@@ -3388,6 +3425,23 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
           return DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
                              LHSShiftArg,
                              HasROTL ? LHSShiftAmt : RHSShiftAmt).getNode();
+        else if (LHSShiftArg.getOpcode() == ISD::ZERO_EXTEND ||
+                 LHSShiftArg.getOpcode() == ISD::ANY_EXTEND) {
+          // fold (or (shl (*ext x), (*ext y)),
+          //          (srl (*ext x), (*ext (sub 32, y)))) ->
+          //   (*ext (rotl x, y))
+          // fold (or (shl (*ext x), (*ext y)),
+          //          (srl (*ext x), (*ext (sub 32, y)))) ->
+          //   (*ext (rotr x, (sub 32, y)))
+          SDValue LArgExtOp0 = LHSShiftArg.getOperand(0);
+          EVT LArgVT = LArgExtOp0.getValueType();
+          if (LArgVT.getSizeInBits() == SUBC->getAPIntValue()) {
+            SDValue V = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, LArgVT,
+                             LArgExtOp0,
+                             HasROTL ? LHSShiftAmt : RHSShiftAmt);
+            return DAG.getNode(LHSShiftArg.getOpcode(), DL, VT, V).getNode();
+          }
+        }
     } else if (LExtOp0.getOpcode() == ISD::SUB &&
                RExtOp0 == LExtOp0.getOperand(1)) {
       // fold (or (shl x, (*ext (sub 32, y))), (srl x, (*ext y))) ->
@@ -3400,6 +3454,23 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
           return DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, VT,
                              LHSShiftArg,
                              HasROTR ? RHSShiftAmt : LHSShiftAmt).getNode();
+        else if (RHSShiftArg.getOpcode() == ISD::ZERO_EXTEND ||
+                 RHSShiftArg.getOpcode() == ISD::ANY_EXTEND) {
+          // fold (or (shl (*ext x), (*ext (sub 32, y))),
+          //          (srl (*ext x), (*ext y))) ->
+          //   (*ext (rotl x, y))
+          // fold (or (shl (*ext x), (*ext (sub 32, y))),
+          //          (srl (*ext x), (*ext y))) ->
+          //   (*ext (rotr x, (sub 32, y)))
+          SDValue RArgExtOp0 = RHSShiftArg.getOperand(0);
+          EVT RArgVT = RArgExtOp0.getValueType();
+          if (RArgVT.getSizeInBits() == SUBC->getAPIntValue()) {
+            SDValue V = DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, RArgVT,
+                             RArgExtOp0,
+                             HasROTR ? RHSShiftAmt : LHSShiftAmt);
+            return DAG.getNode(RHSShiftArg.getOpcode(), DL, VT, V).getNode();
+          }
+        }
     }
   }
 
@@ -4309,7 +4380,7 @@ SDValue DAGCombiner::visitSETCC(SDNode *N) {
 // mentioned transformation is profitable.
 static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0,
                                     unsigned ExtOpc,
-                                    SmallVector<SDNode*, 4> &ExtendNodes,
+                                    SmallVectorImpl<SDNode *> &ExtendNodes,
                                     const TargetLowering &TLI) {
   bool HasCopyToRegUses = false;
   bool isTruncFree = TLI.isTruncateFree(N->getValueType(0), N0.getValueType());
@@ -4367,7 +4438,7 @@ static bool ExtendUsesToFormExtLoad(SDNode *N, SDValue N0,
   return true;
 }
 
-void DAGCombiner::ExtendSetCCUses(SmallVector<SDNode*, 4> SetCCs,
+void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                                   SDValue Trunc, SDValue ExtLoad, SDLoc DL,
                                   ISD::NodeType ExtType) {
   // Extend SetCC uses if necessary.
@@ -5448,7 +5519,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
     SDValue EltNo = N0->getOperand(1);
     if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
       int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
-      EVT IndexTy = N0->getOperand(1).getValueType();
+      EVT IndexTy = TLI.getVectorIdxTy();
       int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
 
       SDValue V = DAG.getNode(ISD::BITCAST, SDLoc(N),
@@ -5680,8 +5751,8 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
   // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
   // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
   // This often reduces constant pool loads.
-  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(VT)) ||
-       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(VT))) &&
+  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
+       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
       N0.getNode()->hasOneUse() && VT.isInteger() &&
       !VT.isVector() && !N0.getValueType().isVector()) {
     SDValue NewConv = DAG.getNode(ISD::BITCAST, SDLoc(N0), VT,
@@ -6151,7 +6222,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
       if (N10 == N0 && isNegatibleForFree(N11, LegalOperations, TLI,
                                           &DAG.getTarget().Options))
         return GetNegatedExpression(N11, DAG, LegalOperations);
-      
+
       if (N11 == N0 && isNegatibleForFree(N10, LegalOperations, TLI,
                                           &DAG.getTarget().Options))
         return GetNegatedExpression(N10, DAG, LegalOperations);
@@ -6172,7 +6243,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
 
     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
     // Note: Commutes FSUB operands.
-    if (N1.getOpcode() == ISD::FMUL && N1->hasOneUse()) 
+    if (N1.getOpcode() == ISD::FMUL && N1->hasOneUse())
       return DAG.getNode(ISD::FMA, dl, VT,
                          DAG.getNode(ISD::FNEG, dl, VT,
                          N1.getOperand(0)),
@@ -7457,7 +7528,9 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
     }
   }
 
-  if (CombinerAA) {
+  bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA) {
     // Walk up chain skipping non-aliasing memory nodes.
     SDValue BetterChain = FindBetterChain(N, Chain);
 
@@ -7846,17 +7919,28 @@ struct BaseIndexOffset {
   static BaseIndexOffset match(SDValue Ptr) {
     bool IsIndexSignExt = false;
 
-    // Just Base or possibly anything else.
+    // We only can pattern match BASE + INDEX + OFFSET. If Ptr is not an ADD
+    // instruction, then it could be just the BASE or everything else we don't
+    // know how to handle. Just use Ptr as BASE and give up.
     if (Ptr->getOpcode() != ISD::ADD)
       return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
 
-    // Base + offset.
+    // We know that we have at least an ADD instruction. Try to pattern match
+    // the simple case of BASE + OFFSET.
     if (isa<ConstantSDNode>(Ptr->getOperand(1))) {
       int64_t Offset = cast<ConstantSDNode>(Ptr->getOperand(1))->getSExtValue();
       return  BaseIndexOffset(Ptr->getOperand(0), SDValue(), Offset,
                               IsIndexSignExt);
     }
 
+    // Inside a loop the current BASE pointer is calculated using an ADD and a
+    // MUL instruction. In this case Ptr is the actual BASE pointer.
+    // (i64 add (i64 %array_ptr)
+    //          (i64 mul (i64 %induction_var)
+    //                   (i64 %element_size)))
+    if (Ptr->getOperand(1)->getOpcode() == ISD::MUL)
+      return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
+
     // Look at Base + Index + Offset cases.
     SDValue Base = Ptr->getOperand(0);
     SDValue IndexOffset = Ptr->getOperand(1);
@@ -8399,7 +8483,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
     // transform should not be done in this case.
     if (Value.getOpcode() != ISD::TargetConstantFP) {
       SDValue Tmp;
-      switch (CFP->getValueType(0).getSimpleVT().SimpleTy) {
+      switch (CFP->getSimpleValueType(0).SimpleTy) {
       default: llvm_unreachable("Unknown FP type");
       case MVT::f16:    // We don't do this for these yet.
       case MVT::f80:
@@ -8477,7 +8561,9 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
   if (NewST.getNode())
     return NewST;
 
-  if (CombinerAA) {
+  bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA) {
     // Walk up chain skipping non-aliasing memory nodes.
     SDValue BetterChain = FindBetterChain(N, Chain);
 
@@ -8612,7 +8698,9 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
   // be converted to a BUILD_VECTOR).  Fill in the Ops vector with the
   // vector elements.
   SmallVector<SDValue, 8> Ops;
-  if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
+  // Do not combine these two vectors if the output vector will not replace
+  // the input vector.
+  if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
     Ops.append(InVec.getNode()->op_begin(),
                InVec.getNode()->op_end());
   } else if (InVec.getOpcode() == ISD::UNDEF) {
@@ -8685,7 +8773,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
       OrigElt -= NumElem;
     }
 
-    EVT IndexTy = N->getOperand(1).getValueType();
+    EVT IndexTy = TLI.getVectorIdxTy();
     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), NVT,
                        InVec, DAG.getConstant(OrigElt, IndexTy));
   }
@@ -9356,10 +9444,10 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
     for (unsigned i = 0; i != NumElts; ++i) {
       int Idx = SVN->getMaskElt(i);
       if (Idx >= 0) {
-        if (Idx < (int)NumElts)
-          Idx += NumElts;
-        else
+        if (Idx >= (int)NumElts)
           Idx -= NumElts;
+        else
+          Idx = -1; // remove reference to lhs
       }
       NewMask.push_back(Idx);
     }
@@ -9850,7 +9938,7 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
         SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(),
                                           Cond, One, Zero);
         AddToWorkList(CstOffset.getNode());
-        CPIdx = DAG.getNode(ISD::ADD, DL, TLI.getPointerTy(), CPIdx,
+        CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx,
                             CstOffset);
         AddToWorkList(CPIdx.getNode());
         return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
@@ -10185,7 +10273,9 @@ bool DAGCombiner::isAlias(SDValue Ptr1, int64_t Size1,
       return false;
   }
 
-  if (CombinerGlobalAA) {
+  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0 ? CombinerGlobalAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA && SrcValue1 && SrcValue2) {
     // Use alias analysis information.
     int64_t MinOffset = std::min(SrcValueOffset1, SrcValueOffset2);
     int64_t Overlap1 = Size1 + SrcValueOffset1 - MinOffset;
@@ -10240,7 +10330,7 @@ bool DAGCombiner::FindAliasInfo(SDNode *N,
 /// GatherAllAliases - Walk up chain skipping non-aliasing memory nodes,
 /// looking for aliasing nodes and adding them to the Aliases vector.
 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
-                                   SmallVector<SDValue, 8> &Aliases) {
+                                   SmallVectorImpl<SDValue> &Aliases) {
   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.