PR16726: extend rol/ror matching
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 0e316db55cb4f475c329041695af1138efd9c78e..b18c69b52a799822a63a925595c6669972b39cfe 100644 (file)
@@ -35,6 +35,7 @@
 #include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Target/TargetSubtargetInfo.h"
 #include <algorithm>
 using namespace llvm;
 
@@ -1823,20 +1824,24 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   // fold (mul x, 0) -> 0
   if (N1IsConst && ConstValue1 == 0)
     return N1;
+  // We require a splat of the entire scalar bit width for non-contiguous
+  // bit patterns.
+  bool IsFullSplat =
+    ConstValue1.getBitWidth() == VT.getScalarType().getSizeInBits();
   // fold (mul x, 1) -> x
-  if (N1IsConst && ConstValue1 == 1)
+  if (N1IsConst && ConstValue1 == 1 && IsFullSplat)
     return N0;
   // fold (mul x, -1) -> 0-x
   if (N1IsConst && ConstValue1.isAllOnesValue())
     return DAG.getNode(ISD::SUB, SDLoc(N), VT,
                        DAG.getConstant(0, VT), N0);
   // fold (mul x, (1 << c)) -> x << c
-  if (N1IsConst && ConstValue1.isPowerOf2())
+  if (N1IsConst && ConstValue1.isPowerOf2() && IsFullSplat)
     return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
                        DAG.getConstant(ConstValue1.logBase2(),
                                        getShiftAmountTy(N0.getValueType())));
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
-  if (N1IsConst && (-ConstValue1).isPowerOf2()) {
+  if (N1IsConst && (-ConstValue1).isPowerOf2() && IsFullSplat) {
     unsigned Log2Val = (-ConstValue1).logBase2();
     // FIXME: If the input is something that is easily negated (e.g. a
     // single-use add), we should put the negate there.
@@ -2861,6 +2866,14 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     }
   }
 
+  // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
+  if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
+    SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
+                                       N0.getOperand(1), false);
+    if (BSwap.getNode())
+      return BSwap;
+  }
+
   return SDValue();
 }
 
@@ -2945,13 +2958,23 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
   if (N00 != N10)
     return SDValue();
 
-  // Make sure everything beyond the low halfword is zero since the SRL 16
-  // will clear the top bits.
+  // Make sure everything beyond the low halfword gets set to zero since the SRL
+  // 16 will clear the top bits.
   unsigned OpSizeInBits = VT.getSizeInBits();
-  if (DemandHighBits && OpSizeInBits > 16 &&
-      (!LookPassAnd0 || !LookPassAnd1) &&
-      !DAG.MaskedValueIsZero(N10, APInt::getHighBitsSet(OpSizeInBits, 16)))
-    return SDValue();
+  if (DemandHighBits && OpSizeInBits > 16) {
+    // If the left-shift isn't masked out then the only way this is a bswap is
+    // if all bits beyond the low 8 are 0. In that case the entire pattern
+    // reduces to a left shift anyway: leave it for other parts of the combiner.
+    if (!LookPassAnd0)
+      return SDValue();
+
+    // However, if the right shift isn't masked out then it might be because
+    // it's not needed. See if we can spot that too.
+    if (!LookPassAnd1 &&
+        !DAG.MaskedValueIsZero(
+            N10, APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - 16)))
+      return SDValue();
+  }
 
   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
   if (OpSizeInBits > 16)
@@ -3318,6 +3341,7 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
   unsigned OpSizeInBits = VT.getSizeInBits();
   SDValue LHSShiftArg = LHSShift.getOperand(0);
   SDValue LHSShiftAmt = LHSShift.getOperand(1);
+  SDValue RHSShiftArg = RHSShift.getOperand(0);
   SDValue RHSShiftAmt = RHSShift.getOperand(1);
 
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
@@ -3401,6 +3425,23 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
           return DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
                              LHSShiftArg,
                              HasROTL ? LHSShiftAmt : RHSShiftAmt).getNode();
+        else if (LHSShiftArg.getOpcode() == ISD::ZERO_EXTEND ||
+                 LHSShiftArg.getOpcode() == ISD::ANY_EXTEND) {
+          // fold (or (shl (*ext x), (*ext y)),
+          //          (srl (*ext x), (*ext (sub 32, y)))) ->
+          //   (*ext (rotl x, y))
+          // fold (or (shl (*ext x), (*ext y)),
+          //          (srl (*ext x), (*ext (sub 32, y)))) ->
+          //   (*ext (rotr x, (sub 32, y)))
+          SDValue LArgExtOp0 = LHSShiftArg.getOperand(0);
+          EVT LArgVT = LArgExtOp0.getValueType();
+          if (LArgVT.getSizeInBits() == SUBC->getAPIntValue()) {
+            SDValue V = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, LArgVT,
+                             LArgExtOp0,
+                             HasROTL ? LHSShiftAmt : RHSShiftAmt);
+            return DAG.getNode(LHSShiftArg.getOpcode(), DL, VT, V).getNode();
+          }
+        }
     } else if (LExtOp0.getOpcode() == ISD::SUB &&
                RExtOp0 == LExtOp0.getOperand(1)) {
       // fold (or (shl x, (*ext (sub 32, y))), (srl x, (*ext y))) ->
@@ -3413,6 +3454,23 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
           return DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, VT,
                              LHSShiftArg,
                              HasROTR ? RHSShiftAmt : LHSShiftAmt).getNode();
+        else if (RHSShiftArg.getOpcode() == ISD::ZERO_EXTEND ||
+                 RHSShiftArg.getOpcode() == ISD::ANY_EXTEND) {
+          // fold (or (shl (*ext x), (*ext (sub 32, y))),
+          //          (srl (*ext x), (*ext y))) ->
+          //   (*ext (rotl x, y))
+          // fold (or (shl (*ext x), (*ext (sub 32, y))),
+          //          (srl (*ext x), (*ext y))) ->
+          //   (*ext (rotr x, (sub 32, y)))
+          SDValue RArgExtOp0 = RHSShiftArg.getOperand(0);
+          EVT RArgVT = RArgExtOp0.getValueType();
+          if (RArgVT.getSizeInBits() == SUBC->getAPIntValue()) {
+            SDValue V = DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, RArgVT,
+                             RArgExtOp0,
+                             HasROTR ? RHSShiftAmt : LHSShiftAmt);
+            return DAG.getNode(RHSShiftArg.getOpcode(), DL, VT, V).getNode();
+          }
+        }
     }
   }
 
@@ -7470,7 +7528,9 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
     }
   }
 
-  if (CombinerAA) {
+  bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA) {
     // Walk up chain skipping non-aliasing memory nodes.
     SDValue BetterChain = FindBetterChain(N, Chain);
 
@@ -7859,17 +7919,28 @@ struct BaseIndexOffset {
   static BaseIndexOffset match(SDValue Ptr) {
     bool IsIndexSignExt = false;
 
-    // Just Base or possibly anything else.
+    // We only can pattern match BASE + INDEX + OFFSET. If Ptr is not an ADD
+    // instruction, then it could be just the BASE or everything else we don't
+    // know how to handle. Just use Ptr as BASE and give up.
     if (Ptr->getOpcode() != ISD::ADD)
       return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
 
-    // Base + offset.
+    // We know that we have at least an ADD instruction. Try to pattern match
+    // the simple case of BASE + OFFSET.
     if (isa<ConstantSDNode>(Ptr->getOperand(1))) {
       int64_t Offset = cast<ConstantSDNode>(Ptr->getOperand(1))->getSExtValue();
       return  BaseIndexOffset(Ptr->getOperand(0), SDValue(), Offset,
                               IsIndexSignExt);
     }
 
+    // Inside a loop the current BASE pointer is calculated using an ADD and a
+    // MUL instruction. In this case Ptr is the actual BASE pointer.
+    // (i64 add (i64 %array_ptr)
+    //          (i64 mul (i64 %induction_var)
+    //                   (i64 %element_size)))
+    if (Ptr->getOperand(1)->getOpcode() == ISD::MUL)
+      return BaseIndexOffset(Ptr, SDValue(), 0, IsIndexSignExt);
+
     // Look at Base + Index + Offset cases.
     SDValue Base = Ptr->getOperand(0);
     SDValue IndexOffset = Ptr->getOperand(1);
@@ -8412,7 +8483,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
     // transform should not be done in this case.
     if (Value.getOpcode() != ISD::TargetConstantFP) {
       SDValue Tmp;
-      switch (CFP->getValueType(0).getSimpleVT().SimpleTy) {
+      switch (CFP->getSimpleValueType(0).SimpleTy) {
       default: llvm_unreachable("Unknown FP type");
       case MVT::f16:    // We don't do this for these yet.
       case MVT::f80:
@@ -8490,7 +8561,9 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
   if (NewST.getNode())
     return NewST;
 
-  if (CombinerAA) {
+  bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA) {
     // Walk up chain skipping non-aliasing memory nodes.
     SDValue BetterChain = FindBetterChain(N, Chain);
 
@@ -9865,7 +9938,7 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1,
         SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(),
                                           Cond, One, Zero);
         AddToWorkList(CstOffset.getNode());
-        CPIdx = DAG.getNode(ISD::ADD, DL, TLI.getPointerTy(), CPIdx,
+        CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx,
                             CstOffset);
         AddToWorkList(CPIdx.getNode());
         return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
@@ -10200,7 +10273,9 @@ bool DAGCombiner::isAlias(SDValue Ptr1, int64_t Size1,
       return false;
   }
 
-  if (CombinerGlobalAA) {
+  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0 ? CombinerGlobalAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA && SrcValue1 && SrcValue2) {
     // Use alias analysis information.
     int64_t MinOffset = std::min(SrcValueOffset1, SrcValueOffset2);
     int64_t Overlap1 = Size1 + SrcValueOffset1 - MinOffset;