Teach the X86 instruction selection to do some heroic transforms to
authorChandler Carruth <chandlerc@gmail.com>
Wed, 11 Jan 2012 08:41:08 +0000 (08:41 +0000)
committerChandler Carruth <chandlerc@gmail.com>
Wed, 11 Jan 2012 08:41:08 +0000 (08:41 +0000)
detect a pattern which can be implemented with a small 'shl' embedded in
the addressing mode scale. This happens in real code as follows:

  unsigned x = my_accelerator_table[input >> 11];

Here we have some lookup table that we look into using the high bits of
'input'. Each entity in the table is 4-bytes, which means this
implicitly gets turned into (once lowered out of a GEP):

  *(unsigned*)((char*)my_accelerator_table + ((input >> 11) << 2));

The shift right followed by a shift left is canonicalized to a smaller
shift right and masking off the low bits. That hides the shift right
which x86 has an addressing mode designed to support. We now detect
masks of this form, and produce the longer shift right followed by the
proper addressing mode. In addition to saving a (rather large)
instruction, this also reduces stalls in Intel chips on benchmarks I've
measured.

In order for all of this to work, one part of the DAG needs to be
canonicalized *still further* than it currently is. This involves
removing pointless 'trunc' nodes between a zextload and a zext. Without
that, we end up generating spurious masks and hiding the pattern.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@147936 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/Target/X86/X86ISelDAGToDAG.cpp
test/CodeGen/X86/fold-and-shift.ll

index 16bea7b9fbc5840cb70c406b0148570ed0eb565b..6d35b1200e06ca6fdd88d68bef9d00b78e868dda 100644 (file)
@@ -4254,6 +4254,29 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
     return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT,
                        N0.getOperand(0));
 
+  // fold (zext (truncate x)) -> (zext x) or
+  //      (zext (truncate x)) -> (truncate x)
+  // This is valid when the truncated bits of x are already zero.
+  // FIXME: We should extend this to work for vectors too.
+  if (N0.getOpcode() == ISD::TRUNCATE && !VT.isVector()) {
+    SDValue Op = N0.getOperand(0);
+    APInt TruncatedBits
+      = APInt::getBitsSet(Op.getValueSizeInBits(),
+                          N0.getValueSizeInBits(),
+                          std::min(Op.getValueSizeInBits(),
+                                   VT.getSizeInBits()));
+    APInt KnownZero, KnownOne;
+    DAG.ComputeMaskedBits(Op, TruncatedBits, KnownZero, KnownOne);
+    if (TruncatedBits == KnownZero) {
+      if (VT.bitsGT(Op.getValueType()))
+        return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT, Op);
+      if (VT.bitsLT(Op.getValueType()))
+        return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, Op);
+
+      return Op;
+    }
+  }
+
   // fold (zext (truncate (load x))) -> (zext (smaller load x))
   // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {
index 1cc19b25d5e41999712d8ca4facb571ea5714701..64820215aa45cee0dfd31003d33b2d7b27d8172e 100644 (file)
@@ -725,6 +725,140 @@ bool X86DAGToDAGISel::MatchAddress(SDValue N, X86ISelAddressMode &AM) {
   return false;
 }
 
+// Implement some heroics to detect shifts of masked values where the mask can
+// be replaced by extending the shift and undoing that in the addressing mode
+// scale. Patterns such as (shl (srl x, c1), c2) are canonicalized into (and
+// (srl x, SHIFT), MASK) by DAGCombines that don't know the shl can be done in
+// the addressing mode. This results in code such as:
+//
+//   int f(short *y, int *lookup_table) {
+//     ...
+//     return *y + lookup_table[*y >> 11];
+//   }
+//
+// Turning into:
+//   movzwl (%rdi), %eax
+//   movl %eax, %ecx
+//   shrl $11, %ecx
+//   addl (%rsi,%rcx,4), %eax
+//
+// Instead of:
+//   movzwl (%rdi), %eax
+//   movl %eax, %ecx
+//   shrl $9, %ecx
+//   andl $124, %rcx
+//   addl (%rsi,%rcx), %eax
+//
+static bool FoldMaskAndShiftToScale(SelectionDAG &DAG, SDValue N,
+                                    X86ISelAddressMode &AM) {
+  // Scale must not be used already.
+  if (AM.IndexReg.getNode() != 0 || AM.Scale != 1) return true;
+
+  SDValue Shift = N;
+  SDValue And = N.getOperand(0);
+  if (N.getOpcode() != ISD::SRL)
+    std::swap(Shift, And);
+  if (Shift.getOpcode() != ISD::SRL || And.getOpcode() != ISD::AND ||
+      !Shift.hasOneUse() ||
+      !isa<ConstantSDNode>(Shift.getOperand(1)) ||
+      !isa<ConstantSDNode>(And.getOperand(1)))
+    return true;
+  SDValue X = (N == Shift ? And.getOperand(0) : Shift.getOperand(0));
+
+  // We only handle up to 64-bit values here as those are what matter for
+  // addressing mode optimizations.
+  if (X.getValueSizeInBits() > 64) return true;
+
+  uint64_t Mask = And.getConstantOperandVal(1);
+  unsigned ShiftAmt = Shift.getConstantOperandVal(1);
+  unsigned MaskLZ = CountLeadingZeros_64(Mask);
+  unsigned MaskTZ = CountTrailingZeros_64(Mask);
+
+  // The amount of shift we're trying to fit into the addressing mode is taken
+  // from the trailing zeros of the mask. If the mask is pre-shift, we subtract
+  // the shift amount.
+  int AMShiftAmt = MaskTZ - (N == Shift ? ShiftAmt : 0);
+
+  // There is nothing we can do here unless the mask is removing some bits.
+  // Also, the addressing mode can only represent shifts of 1, 2, or 3 bits.
+  if (AMShiftAmt <= 0 || AMShiftAmt > 3) return true;
+
+  // We also need to ensure that mask is a continuous run of bits.
+  if (CountTrailingOnes_64(Mask >> MaskTZ) + MaskTZ + MaskLZ != 64) return true;
+
+  // Scale the leading zero count down based on the actual size of the value.
+  // Also scale it down based on the size of the shift if it was applied
+  // before the mask.
+  MaskLZ -= (64 - X.getValueSizeInBits()) + (N == Shift ? 0 : ShiftAmt);
+
+  // The final check is to ensure that any masked out high bits of X are
+  // already known to be zero. Otherwise, the mask has a semantic impact
+  // other than masking out a couple of low bits. Unfortunately, because of
+  // the mask, zero extensions will be removed from operands in some cases.
+  // This code works extra hard to look through extensions because we can
+  // replace them with zero extensions cheaply if necessary.
+  bool ReplacingAnyExtend = false;
+  if (X.getOpcode() == ISD::ANY_EXTEND) {
+    unsigned ExtendBits =
+      X.getValueSizeInBits() - X.getOperand(0).getValueSizeInBits();
+    // Assume that we'll replace the any-extend with a zero-extend, and
+    // narrow the search to the extended value.
+    X = X.getOperand(0);
+    MaskLZ = ExtendBits > MaskLZ ? 0 : MaskLZ - ExtendBits;
+    ReplacingAnyExtend = true;
+  }
+  APInt MaskedHighBits = APInt::getHighBitsSet(X.getValueSizeInBits(),
+                                               MaskLZ);
+  APInt KnownZero, KnownOne;
+  DAG.ComputeMaskedBits(X, MaskedHighBits, KnownZero, KnownOne);
+  if (MaskedHighBits != KnownZero) return true;
+
+  // We've identified a pattern that can be transformed into a single shift
+  // and an addressing mode. Make it so.
+  EVT VT = N.getValueType();
+  if (ReplacingAnyExtend) {
+    assert(X.getValueType() != VT);
+    // We looked through an ANY_EXTEND node, insert a ZERO_EXTEND.
+    SDValue NewX = DAG.getNode(ISD::ZERO_EXTEND, X.getDebugLoc(), VT, X);
+    if (NewX.getNode()->getNodeId() == -1 ||
+        NewX.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+      DAG.RepositionNode(N.getNode(), NewX.getNode());
+      NewX.getNode()->setNodeId(N.getNode()->getNodeId());
+    }
+    X = NewX;
+  }
+  DebugLoc DL = N.getDebugLoc();
+  SDValue NewSRLAmt = DAG.getConstant(ShiftAmt + AMShiftAmt, MVT::i8);
+  SDValue NewSRL = DAG.getNode(ISD::SRL, DL, VT, X, NewSRLAmt);
+  SDValue NewSHLAmt = DAG.getConstant(AMShiftAmt, MVT::i8);
+  SDValue NewSHL = DAG.getNode(ISD::SHL, DL, VT, NewSRL, NewSHLAmt);
+  if (NewSRLAmt.getNode()->getNodeId() == -1 ||
+      NewSRLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSRLAmt.getNode());
+    NewSRLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSRL.getNode()->getNodeId() == -1 ||
+      NewSRL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSRL.getNode());
+    NewSRL.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSHLAmt.getNode()->getNodeId() == -1 ||
+      NewSHLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSHLAmt.getNode());
+    NewSHLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSHL.getNode()->getNodeId() == -1 ||
+      NewSHL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSHL.getNode());
+    NewSHL.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  DAG.ReplaceAllUsesWith(N, NewSHL);
+
+  AM.Scale = 1 << AMShiftAmt;
+  AM.IndexReg = NewSRL;
+  return false;
+}
+
 bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
                                               unsigned Depth) {
   DebugLoc dl = N.getDebugLoc();
@@ -814,6 +948,13 @@ bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
     break;
     }
 
+  case ISD::SRL:
+    // Try to fold the mask and shift into the scale, and return false if we
+    // succeed.
+    if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
+      return false;
+    break;
+
   case ISD::SMUL_LOHI:
   case ISD::UMUL_LOHI:
     // A mul_lohi where we need the low part can be folded as a plain multiply.
@@ -1047,6 +1188,11 @@ bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
       }
     }
 
+    // Try to fold the mask and shift into the scale, and return false if we
+    // succeed.
+    if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
+      return false;
+
     // Handle "(X << C1) & C2" as "(X & (C2>>C1)) << C1" if safe and if this
     // allows us to fold the shift into this addressing mode.
     if (Shift.getOpcode() != ISD::SHL) break;
index c42a421a7c48d5a24b85582f38b2233e833d2a92..93baa0e0eee0ba636858da5af144b0ebd78a11e6 100644 (file)
@@ -31,3 +31,47 @@ entry:
   %tmp9 = load i32* %tmp78
   ret i32 %tmp9
 }
+
+define i32 @t3(i16* %i.ptr, i32* %arr) {
+; This case is tricky. The lshr followed by a gep will produce a lshr followed
+; by an and to remove the low bits. This can be simplified by doing the lshr by
+; a greater constant and using the addressing mode to scale the result back up.
+; To make matters worse, because of the two-phase zext of %i and their reuse in
+; the function, the DAG can get confusing trying to re-use both of them and
+; prevent easy analysis of the mask in order to match this.
+; CHECK: t3:
+; CHECK-NOT: and
+; CHECK: shrl
+; CHECK: addl (%{{...}},%{{...}},4),
+; CHECK: ret
+
+entry:
+  %i = load i16* %i.ptr
+  %i.zext = zext i16 %i to i32
+  %index = lshr i32 %i.zext, 11
+  %val.ptr = getelementptr inbounds i32* %arr, i32 %index
+  %val = load i32* %val.ptr
+  %sum = add i32 %val, %i.zext
+  ret i32 %sum
+}
+
+define i32 @t4(i16* %i.ptr, i32* %arr) {
+; A version of @t3 that has more zero extends and more re-use of intermediate
+; values. This exercise slightly different bits of canonicalization.
+; CHECK: t4:
+; CHECK-NOT: and
+; CHECK: shrl
+; CHECK: addl (%{{...}},%{{...}},4),
+; CHECK: ret
+
+entry:
+  %i = load i16* %i.ptr
+  %i.zext = zext i16 %i to i32
+  %index = lshr i32 %i.zext, 11
+  %index.zext = zext i32 %index to i64
+  %val.ptr = getelementptr inbounds i32* %arr, i64 %index.zext
+  %val = load i32* %val.ptr
+  %sum.1 = add i32 %val, %i.zext
+  %sum.2 = add i32 %sum.1, %index
+  ret i32 %sum.2
+}