[X86][SSE] Vectorized i64 uniform constant SRA shifts
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 6 Jul 2015 22:35:19 +0000 (22:35 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 6 Jul 2015 22:35:19 +0000 (22:35 +0000)
This patch adds vectorization support for uniform constant i64 arithmetic shift right operators.

Differential Revision: http://reviews.llvm.org/D9645

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@241514 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86TargetTransformInfo.cpp
test/Analysis/CostModel/X86/testshiftashr.ll
test/CodeGen/X86/vector-shift-ashr-128.ll
test/CodeGen/X86/vector-shift-ashr-256.ll
test/CodeGen/X86/vshift-3.ll
test/CodeGen/X86/widen_conv-2.ll

index 05b3604f851ce70c73ecb6b3dd50678019403189..a92ab5ae2a07e9a9fa97717b555682d6e66c451d 100644 (file)
@@ -1032,6 +1032,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::SHL,               MVT::v2i64, Custom);
     setOperationAction(ISD::SHL,               MVT::v4i32, Custom);
 
+    setOperationAction(ISD::SRA,               MVT::v2i64, Custom);
     setOperationAction(ISD::SRA,               MVT::v4i32, Custom);
   }
 
@@ -1211,6 +1212,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::SHL,               MVT::v4i64, Custom);
     setOperationAction(ISD::SHL,               MVT::v8i32, Custom);
 
+    setOperationAction(ISD::SRA,               MVT::v4i64, Custom);
     setOperationAction(ISD::SRA,               MVT::v8i32, Custom);
 
     // Custom lower several nodes for 256-bit types.
@@ -16948,6 +16950,38 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
     (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
 
+  auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) {
+    assert((VT == MVT::v2i64 || VT == MVT::v4i64) && "Unexpected SRA type");
+    MVT ExVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2);
+    SDValue Ex = DAG.getBitcast(ExVT, R);
+
+    if (ShiftAmt >= 32) {
+      // Splat sign to upper i32 dst, and SRA upper i32 src to lower i32.
+      SDValue Upper =
+          getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex, 31, DAG);
+      SDValue Lower = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex,
+                                                 ShiftAmt - 32, DAG);
+      if (VT == MVT::v2i64)
+        Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {5, 1, 7, 3});
+      if (VT == MVT::v4i64)
+        Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower,
+                                  {9, 1, 11, 3, 13, 5, 15, 7});
+    } else {
+      // SRA upper i32, SHL whole i64 and select lower i32.
+      SDValue Upper = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex,
+                                                 ShiftAmt, DAG);
+      SDValue Lower =
+          getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, DAG);
+      Lower = DAG.getBitcast(ExVT, Lower);
+      if (VT == MVT::v2i64)
+        Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {4, 1, 6, 3});
+      if (VT == MVT::v4i64)
+        Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower,
+                                  {8, 1, 10, 3, 12, 5, 14, 7});
+    }
+    return DAG.getBitcast(VT, Ex);
+  };
+
   // Optimize shl/srl/sra with constant shift amount.
   if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
     if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
@@ -16956,6 +16990,11 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
       if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
         return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
 
+      // i64 SRA needs to be performed as partial shifts.
+      if ((VT == MVT::v2i64 || (Subtarget->hasInt256() && VT == MVT::v4i64)) &&
+          Op.getOpcode() == ISD::SRA)
+        return ArithmeticShiftRight64(ShiftAmt);
+
       if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) {
         unsigned NumElts = VT.getVectorNumElements();
         MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
@@ -17039,7 +17078,12 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
       if (ShAmt != ShiftAmt)
         return SDValue();
     }
-    return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
+
+    if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
+      return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
+
+    if (Op.getOpcode() == ISD::SRA)
+      return ArithmeticShiftRight64(ShiftAmt);
   }
 
   return SDValue();
@@ -17121,7 +17165,9 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
         if (Vals[j] != Amt.getOperand(i + j))
           return SDValue();
     }
-    return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
+
+    if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode()))
+      return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
   }
   return SDValue();
 }
index 0c82a700952b8ce6e29ed6560803493ea1514b9c..a541035f13d607b7d8bfc80974c964c58957a116 100644 (file)
@@ -117,6 +117,8 @@ unsigned X86TTIImpl::getArithmeticInstrCost(
 
   static const CostTblEntry<MVT::SimpleValueType>
   AVX2UniformConstCostTable[] = {
+    { ISD::SRA,  MVT::v4i64,   4 }, // 2 x psrad + shuffle.
+
     { ISD::SDIV, MVT::v16i16,  6 }, // vpmulhw sequence
     { ISD::UDIV, MVT::v16i16,  6 }, // vpmulhuw sequence
     { ISD::SDIV, MVT::v8i32,  15 }, // vpmuldq sequence
@@ -211,6 +213,7 @@ unsigned X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA,  MVT::v16i8,  4 }, // psrlw, pand, pxor, psubb.
     { ISD::SRA,  MVT::v8i16,  1 }, // psraw.
     { ISD::SRA,  MVT::v4i32,  1 }, // psrad.
+    { ISD::SRA,  MVT::v2i64,  4 }, // 2 x psrad + shuffle.
 
     { ISD::SDIV, MVT::v8i16,  6 }, // pmulhw sequence
     { ISD::UDIV, MVT::v8i16,  6 }, // pmulhuw sequence
index ced2ffed455200153aa3ffdb3a22676b6fe617dc..ebb06cc3bba5e4e33ca20cfdb2d08ff8ba1a8d9f 100644 (file)
@@ -247,9 +247,9 @@ entry:
 define %shifttypec @shift2i16const(%shifttypec %a, %shifttypec %b) {
 entry:
   ; SSE2: shift2i16const
-  ; SSE2: cost of 20 {{.*}} ashr
+  ; SSE2: cost of 4 {{.*}} ashr
   ; SSE2-CODEGEN: shift2i16const
-  ; SSE2-CODEGEN: sarq $
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec %a , <i16 3, i16 3>
   ret %shifttypec %0
@@ -320,9 +320,9 @@ entry:
 define %shifttypec2i32 @shift2i32c(%shifttypec2i32 %a, %shifttypec2i32 %b) {
 entry:
   ; SSE2: shift2i32c
-  ; SSE2: cost of 20 {{.*}} ashr
+  ; SSE2: cost of 4 {{.*}} ashr
   ; SSE2-CODEGEN: shift2i32c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec2i32 %a , <i32 3, i32 3>
   ret %shifttypec2i32 %0
@@ -391,9 +391,9 @@ entry:
 define %shifttypec2i64 @shift2i64c(%shifttypec2i64 %a, %shifttypec2i64 %b) {
 entry:
   ; SSE2: shift2i64c
-  ; SSE2: cost of 20 {{.*}} ashr
+  ; SSE2: cost of 4 {{.*}} ashr
   ; SSE2-CODEGEN: shift2i64c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec2i64 %a , <i64 3, i64 3>
   ret %shifttypec2i64 %0
@@ -403,9 +403,9 @@ entry:
 define %shifttypec4i64 @shift4i64c(%shifttypec4i64 %a, %shifttypec4i64 %b) {
 entry:
   ; SSE2: shift4i64c
-  ; SSE2: cost of 40 {{.*}} ashr
+  ; SSE2: cost of 8 {{.*}} ashr
   ; SSE2-CODEGEN: shift4i64c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec4i64 %a , <i64 3, i64 3, i64 3, i64 3>
   ret %shifttypec4i64 %0
@@ -415,9 +415,9 @@ entry:
 define %shifttypec8i64 @shift8i64c(%shifttypec8i64 %a, %shifttypec8i64 %b) {
 entry:
   ; SSE2: shift8i64c
-  ; SSE2: cost of 80 {{.*}} ashr
+  ; SSE2: cost of 16 {{.*}} ashr
   ; SSE2-CODEGEN: shift8i64c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
  %0 = ashr %shifttypec8i64 %a , <i64 3, i64 3, i64 3, i64 3,
                                  i64 3, i64 3, i64 3, i64 3>
@@ -428,9 +428,9 @@ entry:
 define %shifttypec16i64 @shift16i64c(%shifttypec16i64 %a, %shifttypec16i64 %b) {
 entry:
   ; SSE2: shift16i64c
-  ; SSE2: cost of 160 {{.*}} ashr
+  ; SSE2: cost of 32 {{.*}} ashr
   ; SSE2-CODEGEN: shift16i64c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec16i64 %a , <i64 3, i64 3, i64 3, i64 3,
                                    i64 3, i64 3, i64 3, i64 3,
@@ -443,9 +443,9 @@ entry:
 define %shifttypec32i64 @shift32i64c(%shifttypec32i64 %a, %shifttypec32i64 %b) {
 entry:
   ; SSE2: shift32i64c
-  ; SSE2: cost of 320 {{.*}} ashr
+  ; SSE2: cost of 64 {{.*}} ashr
   ; SSE2-CODEGEN: shift32i64c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec32i64 %a ,<i64 3, i64 3, i64 3, i64 3,
                                   i64 3, i64 3, i64 3, i64 3,
@@ -462,9 +462,9 @@ entry:
 define %shifttypec2i8 @shift2i8c(%shifttypec2i8 %a, %shifttypec2i8 %b) {
 entry:
   ; SSE2: shift2i8c
-  ; SSE2: cost of 20 {{.*}} ashr
+  ; SSE2: cost of 4 {{.*}} ashr
   ; SSE2-CODEGEN: shift2i8c
-  ; SSE2-CODEGEN: sarq $3
+  ; SSE2-CODEGEN: psrad $3
 
   %0 = ashr %shifttypec2i8 %a , <i8 3, i8 3>
   ret %shifttypec2i8 %0
index 4fd2f8b51b8b2b3ed5488256ec15a40010fd82cb..0e7ca6325d9d6404903396cf3735869fc2f22ae3 100644 (file)
@@ -954,38 +954,35 @@ define <16 x i8> @constant_shift_v16i8(<16 x i8> %a) {
 define <2 x i64> @splatconstant_shift_v2i64(<2 x i64> %a) {
 ; SSE2-LABEL: splatconstant_shift_v2i64:
 ; SSE2:       # BB#0:
-; SSE2-NEXT:    movd       %xmm0, %rax
-; SSE2-NEXT:    sarq       $7, %rax
-; SSE2-NEXT:    movd       %rax, %xmm1
-; SSE2-NEXT:    pshufd     {{.*#+}} xmm0 = xmm0[2,3,0,1]
-; SSE2-NEXT:    movd       %xmm0, %rax
-; SSE2-NEXT:    sarq       $7, %rax
-; SSE2-NEXT:    movd       %rax, %xmm0
-; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
-; SSE2-NEXT:    movdqa     %xmm1, %xmm0
+; SSE2-NEXT:    movdqa    %xmm0, %xmm1
+; SSE2-NEXT:    psrad     $7, %xmm1
+; SSE2-NEXT:    pshufd    {{.*#+}} xmm1 = xmm1[1,3,2,3]
+; SSE2-NEXT:    psrlq     $7, %xmm0
+; SSE2-NEXT:    pshufd    {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
 ; SSE2-NEXT:    retq
 ;
 ; SSE41-LABEL: splatconstant_shift_v2i64:
 ; SSE41:       # BB#0:
-; SSE41-NEXT:    pextrq     $1, %xmm0, %rax
-; SSE41-NEXT:    sarq       $7, %rax
-; SSE41-NEXT:    movd       %rax, %xmm1
-; SSE41-NEXT:    movd       %xmm0, %rax
-; SSE41-NEXT:    sarq       $7, %rax
-; SSE41-NEXT:    movd       %rax, %xmm0
-; SSE41-NEXT:    punpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; SSE41-NEXT:    movdqa  %xmm0, %xmm1
+; SSE41-NEXT:    psrad   $7, %xmm1
+; SSE41-NEXT:    psrlq   $7, %xmm0
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
 ; SSE41-NEXT:    retq
 ;
-; AVX-LABEL: splatconstant_shift_v2i64:
-; AVX:       # BB#0:
-; AVX-NEXT:    vpextrq     $1, %xmm0, %rax
-; AVX-NEXT:    sarq        $7, %rax
-; AVX-NEXT:    vmovq       %rax, %xmm1
-; AVX-NEXT:    vmovq       %xmm0, %rax
-; AVX-NEXT:    sarq        $7, %rax
-; AVX-NEXT:    vmovq       %rax, %xmm0
-; AVX-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
-; AVX-NEXT:    retq
+; AVX1-LABEL: splatconstant_shift_v2i64:
+; AVX1:       # BB#0:
+; AVX1-NEXT:    vpsrad   $7, %xmm0, %xmm1
+; AVX1-NEXT:    vpsrlq   $7, %xmm0, %xmm0
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: splatconstant_shift_v2i64:
+; AVX2:       # BB#0:
+; AVX2-NEXT:    vpsrad   $7, %xmm0, %xmm1
+; AVX2-NEXT:    vpsrlq   $7, %xmm0, %xmm0
+; AVX2-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; AVX2-NEXT:    retq
   %shift = ashr <2 x i64> %a, <i64 7, i64 7>
   ret <2 x i64> %shift
 }
index 3fc377af56500932e98892c001100a7b3d08f5f5..89996bb20418def846796189e9243b8cfbcdbc7a 100644 (file)
@@ -663,41 +663,20 @@ define <4 x i64> @splatconstant_shift_v4i64(<4 x i64> %a) {
 ; AVX1-LABEL: splatconstant_shift_v4i64:
 ; AVX1:       # BB#0:
 ; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
-; AVX1-NEXT:    vpextrq $1, %xmm1, %rax
-; AVX1-NEXT:    sarq $7, %rax
-; AVX1-NEXT:    vmovq %rax, %xmm2
-; AVX1-NEXT:    vmovq %xmm1, %rax
-; AVX1-NEXT:    sarq $7, %rax
-; AVX1-NEXT:    vmovq %rax, %xmm1
-; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; AVX1-NEXT:    vpextrq $1, %xmm0, %rax
-; AVX1-NEXT:    sarq $7, %rax
-; AVX1-NEXT:    vmovq %rax, %xmm2
-; AVX1-NEXT:    vmovq %xmm0, %rax
-; AVX1-NEXT:    sarq $7, %rax
-; AVX1-NEXT:    vmovq %rax, %xmm0
-; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT:    vpsrad       $7, %xmm1, %xmm2
+; AVX1-NEXT:    vpsrlq       $7, %xmm1, %xmm1
+; AVX1-NEXT:    vpblendw     {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3],xmm1[4,5],xmm2[6,7]
+; AVX1-NEXT:    vpsrad       $7, %xmm0, %xmm2
+; AVX1-NEXT:    vpsrlq       $7, %xmm0, %xmm0
+; AVX1-NEXT:    vpblendw     {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
+; AVX1-NEXT:    vinsertf128  $1, %xmm1, %ymm0, %ymm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: splatconstant_shift_v4i64:
 ; AVX2:       # BB#0:
-; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
-; AVX2-NEXT:    vpextrq $1, %xmm1, %rax
-; AVX2-NEXT:    sarq $7, %rax
-; AVX2-NEXT:    vmovq %rax, %xmm2
-; AVX2-NEXT:    vmovq %xmm1, %rax
-; AVX2-NEXT:    sarq $7, %rax
-; AVX2-NEXT:    vmovq %rax, %xmm1
-; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; AVX2-NEXT:    vpextrq $1, %xmm0, %rax
-; AVX2-NEXT:    sarq $7, %rax
-; AVX2-NEXT:    vmovq %rax, %xmm2
-; AVX2-NEXT:    vmovq %xmm0, %rax
-; AVX2-NEXT:    sarq $7, %rax
-; AVX2-NEXT:    vmovq %rax, %xmm0
-; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm2[0]
-; AVX2-NEXT:    vinserti128 $1, %xmm1, %ymm0, %ymm0
+; AVX2-NEXT:    vpsrad   $7, %ymm0, %ymm1
+; AVX2-NEXT:    vpsrlq   $7, %ymm0, %ymm0
+; AVX2-NEXT:    vpblendd {{.*#+}} ymm0 = ymm0[0],ymm1[1],ymm0[2],ymm1[3],ymm0[4],ymm1[5],ymm0[6],ymm1[7]
 ; AVX2-NEXT:    retq
   %shift = ashr <4 x i64> %a, <i64 7, i64 7, i64 7, i64 7>
   ret <4 x i64> %shift
index 0bdb32fcb86e13638bd5245940e39c319ee65d97..f368029e4b4947101eb62f158c243f4a4af818df 100644 (file)
@@ -3,13 +3,12 @@
 ; test vector shifts converted to proper SSE2 vector shifts when the shift
 ; amounts are the same.
 
-; Note that x86 does have ashr 
+; Note that x86 does have ashr
 
-; shift1a can't use a packed shift
 define void @shift1a(<2 x i64> %val, <2 x i64>* %dst) nounwind {
 entry:
 ; CHECK-LABEL: shift1a:
-; CHECK: sarl
+; CHECK: psrad $31
   %ashr = ashr <2 x i64> %val, < i64 32, i64 32 >
   store <2 x i64> %ashr, <2 x i64>* %dst
   ret void
index 906f7cdafb9580146f8c9a13769afed0da8cbaf7..c8646c6489a15796edaa195b95e8792d6c00b908 100644 (file)
@@ -1,8 +1,9 @@
 ; RUN: llc < %s -march=x86 -mattr=+sse4.2 | FileCheck %s
-; CHECK: {{cwtl|movswl}}
-; CHECK: {{cwtl|movswl}}
+; CHECK: psllq $48, %xmm0
+; CHECK: psrad $16, %xmm0
+; CHECK: pshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
 
-; sign extension v2i32 to v2i16
+; sign extension v2i16 to v2i32
 
 define void @convert(<2 x i32>* %dst.addr, <2 x i16> %src) nounwind {
 entry: