(i32 sext_in_reg (i32 aext (i16 x)), i16) -> (i32 sext x). No known test case until...
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 671c5077056462564efeb93c31ca7c318cb4cbd1..9ba9bb5e38d0dac4cde7e98930b99de4872c96ad 100644 (file)
@@ -129,6 +129,7 @@ namespace {
     bool CombineToPreIndexedLoadStore(SDNode *N);
     bool CombineToPostIndexedLoadStore(SDNode *N);
 
+    SDValue PromoteIntBinOp(SDValue Op);
 
     /// combine - call the node-specific routine that knows how to fold each
     /// particular type of node. If that doesn't do anything, try the
@@ -633,6 +634,46 @@ bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) {
   return true;
 }
 
+static SDValue PromoteOperand(SDValue Op, EVT PVT, SelectionDAG &DAG) {
+  unsigned Opc = ISD::ZERO_EXTEND;
+  if (Op.getOpcode() == ISD::Constant) {
+    // Zero extend things like i1, sign extend everything else.  It shouldn't
+    // matter in theory which one we pick, but this tends to give better code?
+    // See DAGTypeLegalizer::PromoteIntRes_Constant.
+    if (Op.getValueType().isByteSized())
+      Opc = ISD::SIGN_EXTEND;
+  }
+  return DAG.getNode(Opc, Op.getDebugLoc(), PVT, Op);
+}
+
+/// PromoteIntBinOp - Promote the specified integer binary operation if the
+/// target indicates it is beneficial. e.g. On x86, it's usually better to
+/// promote i16 operations to i32 since i16 instructions are longer.
+SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
+  if (!LegalOperations)
+    return SDValue();
+
+  EVT VT = Op.getValueType();
+  if (VT.isVector() || !VT.isInteger())
+    return SDValue();
+
+  EVT PVT = VT;
+  if (TLI.PerformDAGCombinePromotion(Op, PVT)) {
+    assert(PVT != VT && "Don't know what type to promote to!");
+
+    SDValue N0 = PromoteOperand(Op.getOperand(0), PVT, DAG);
+    AddToWorkList(N0.getNode());
+
+    SDValue N1 = PromoteOperand(Op.getOperand(1), PVT, DAG);
+    AddToWorkList(N1.getNode());
+
+    DebugLoc dl = Op.getDebugLoc();
+    return DAG.getNode(ISD::TRUNCATE, dl, VT,
+                       DAG.getNode(Op.getOpcode(), dl, PVT, N0, N1));
+  }
+  return SDValue();
+}
+
 //===----------------------------------------------------------------------===//
 //  Main DAG Combiner implementation
 //===----------------------------------------------------------------------===//
@@ -1112,7 +1153,7 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
                                        N0.getOperand(0).getOperand(1),
                                        N0.getOperand(1)));
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitADDC(SDNode *N) {
@@ -1250,7 +1291,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
                                  VT);
     }
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitMUL(SDNode *N) {
@@ -1343,7 +1384,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   if (RMUL.getNode() != 0)
     return RMUL;
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitSDIV(SDNode *N) {
@@ -1987,7 +2028,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     }
   }
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 SDValue DAGCombiner::visitOR(SDNode *N) {
@@ -2113,7 +2154,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
   if (SDNode *Rot = MatchRotate(N0, N1, N->getDebugLoc()))
     return SDValue(Rot, 0);
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 /// MatchRotateHalf - Match "(X shl/srl V1) & V2" where V2 may not be present.
@@ -2422,7 +2463,7 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
       SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
 
-  return SDValue();
+  return PromoteIntBinOp(SDValue(N, 0));
 }
 
 /// visitShiftByConstant - Handle transforms common to the three shifts, when
@@ -2735,6 +2776,15 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
     return DAG.getNode(ISD::SRL, N->getDebugLoc(), VT, N0.getOperand(0),
                        DAG.getConstant(c1 + c2, N1.getValueType()));
   }
+  
+  // fold (srl (shl x, c), c) -> (and x, cst2)
+  if (N1C && N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
+      N0.getValueSizeInBits() <= 64) {
+    uint64_t ShAmt = N1C->getZExtValue()+64-N0.getValueSizeInBits();
+    return DAG.getNode(ISD::AND, N->getDebugLoc(), VT, N0.getOperand(0),
+                       DAG.getConstant(~0ULL >> ShAmt, VT));
+  }
+  
 
   // fold (srl (anyextend x), c) -> (anyextend (srl x, c))
   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
@@ -3628,7 +3678,7 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
   // Do not generate loads of non-round integer types since these can
   // be expensive (and would be wrong if the type is not byte sized).
   if (isa<LoadSDNode>(N0) && N0.hasOneUse() && ExtVT.isRound() &&
-      cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits() > EVTBits &&
+      cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits() >= EVTBits &&
       // Do not change the width of a volatile load.
       !cast<LoadSDNode>(N0)->isVolatile()) {
     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
@@ -3698,7 +3748,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
   // if x is small enough.
   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
     SDValue N00 = N0.getOperand(0);
-    if (N00.getValueType().getScalarType().getSizeInBits() < EVTBits)
+    if (N00.getValueType().getScalarType().getSizeInBits() <= EVTBits &&
+        (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
       return DAG.getNode(ISD::SIGN_EXTEND, N->getDebugLoc(), VT, N00, N1);
   }
 
@@ -5163,12 +5214,25 @@ CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
       !ISD::isNormalLoad(V->getOperand(0).getNode()))
     return Result;
   
-  // Check the chain and pointer.  The store should be chained directly to the
-  // load (TODO: Or through a TF node!) since it's to the same address.
+  // Check the chain and pointer.
   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
-  if (LD->getBasePtr() != Ptr ||
-      V->getOperand(0).getNode() != Chain.getNode())
-    return Result;
+  if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
+  
+  // The store should be chained directly to the load or be an operand of a
+  // tokenfactor.
+  if (LD == Chain.getNode())
+    ; // ok.
+  else if (Chain->getOpcode() != ISD::TokenFactor)
+    return Result; // Fail.
+  else {
+    bool isOk = false;
+    for (unsigned i = 0, e = Chain->getNumOperands(); i != e; ++i)
+      if (Chain->getOperand(i).getNode() == LD) {
+        isOk = true;
+        break;
+      }
+    if (!isOk) return Result;
+  }
   
   // This only handles simple types.
   if (V.getValueType() != MVT::i16 &&