From d8018eeac9861e47c4216a10a2296887ed20c19a Mon Sep 17 00:00:00 2001
From: Chandler Carruth <chandlerc@gmail.com>
Date: Sat, 30 May 2015 04:05:11 +0000
Subject: [PATCH] [x86] Restore the bitcasts I removed when refactoring this
 to avoid shifting vectors of bytes as x86 doesn't have direct support for
 that.

This removes a bunch of redundant masking in the generated code for SSE2
and SSE3.

In order to avoid the really significant code size growth this would have
triggered, I also factored the completely repetitive logic for shifting and
masking into two lambdas, which in turn makes all of this much easier to
read, IMO.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@238637 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Target/X86/X86ISelLowering.cpp    | 84 +++++++++++++--------
 test/CodeGen/X86/vector-popcnt-128.ll | 12 +---
 2 files changed, 43 insertions(+), 53 deletions(-)

diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 49be23a2204..6e37b7355fc 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -17479,7 +17479,6 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
          "Only 128-bit vector bitmath lowering supported.");
 
   int VecSize = VT.getSizeInBits();
-  int NumElts = VT.getVectorNumElements();
   MVT EltVT = VT.getVectorElementType();
   int Len = EltVT.getSizeInBits();
 
@@ -17490,48 +17489,52 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   // this when we don't have SSSE3 which allows a LUT-based lowering that is
   // much faster, even faster than using native popcnt instructions.
 
-  SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
-                                  EltVT);
-  SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL,
-                                  EltVT);
-  SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL,
-                                  EltVT);
+  auto GetShift = [&](unsigned OpCode, SDValue V, int Shifter) {
+    MVT VT = V.getSimpleValueType();
+    SmallVector<SDValue, 8> Shifters(
+        VT.getVectorNumElements(),
+        DAG.getConstant(Shifter, DL, VT.getVectorElementType()));
+    return DAG.getNode(OpCode, DL, VT, V,
+                       DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Shifters));
+  };
+  auto GetMask = [&](SDValue V, APInt Mask) {
+    MVT VT = V.getSimpleValueType();
+    SmallVector<SDValue, 8> Masks(
+        VT.getVectorNumElements(),
+        DAG.getConstant(Mask, DL, VT.getVectorElementType()));
+    return DAG.getNode(ISD::AND, DL, VT, V,
+                       DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Masks));
+  };
+
+  // We don't want to incur the implicit masks required to SRL vNi8 vectors on
+  // x86, so set the SRL type to have elements at least i16 wide. This is
+  // correct because all of our SRLs are followed immediately by a mask anyways
+  // that handles any bits that sneak into the high bits of the byte elements.
+  MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16);
 
   SDValue V = Op;
 
   // v = v - ((v >> 1) & 0x55555555...)
-  SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
-  SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
-  SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
-
-  SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
-  SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
-  SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
-
+  SDValue Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 1));
+  SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55)));
   V = DAG.getNode(ISD::SUB, DL, VT, V, And);
 
   // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
-  SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
-  SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
-  SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
-
-  SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
-  SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
-  Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
-  SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
-
+  SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33)));
+  Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 2));
+  SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33)));
   V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
 
   // v = (v + (v >> 4)) & 0x0F0F0F0F...
-  SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL, EltVT));
-  SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Fours);
-  Srl = DAG.getNode(ISD::SRL, DL, VT, V, FoursV);
+  Srl = DAG.getNode(
+      ISD::BITCAST, DL, VT,
+      GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 4));
   SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
-
-  SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
-  SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
-
-  V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
+  V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F)));
 
   // At this point, V contains the byte-wise population count, and we are
   // merely doing a horizontal sum if necessary to get the wider element
@@ -17543,26 +17546,21 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, SDLoc DL,
   MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
   MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
   V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
-  SmallVector<SDValue, 8> Csts;
   assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
   assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
   for (int i = Len; i > 8; i /= 2) {
-    Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
     SDValue Shl = DAG.getNode(
-        ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
-        DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
-    V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
-                    DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
+        ISD::BITCAST, DL, ByteVT,
+        GetShift(ISD::SHL, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V), i / 2));
+    V = DAG.getNode(ISD::ADD, DL, ByteVT, V, Shl);
   }
 
   // The high byte now contains the sum of the element bytes. Shift it right
   // (if needed) to make it the low byte.
   V = DAG.getNode(ISD::BITCAST, DL, VT, V);
-  if (Len > 8) {
-    Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
-    V = DAG.getNode(ISD::SRL, DL, VT, V,
-                    DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
-  }
+  if (Len > 8)
+    V = GetShift(ISD::SRL, V, Len - 8);
+
   return V;
 }
 
diff --git a/test/CodeGen/X86/vector-popcnt-128.ll b/test/CodeGen/X86/vector-popcnt-128.ll
index dc99fec3d47..f55b054deb0 100644
--- a/test/CodeGen/X86/vector-popcnt-128.ll
+++ b/test/CodeGen/X86/vector-popcnt-128.ll
@@ -339,21 +339,17 @@ define <16 x i8> @testv16i8(<16 x i8> %in) {
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $1, %xmm1
 ; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    psubb %xmm1, %xmm0
 ; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE2-NEXT:    movdqa %xmm0, %xmm2
 ; SSE2-NEXT:    pand %xmm1, %xmm2
 ; SSE2-NEXT:    psrlw $2, %xmm0
-; SSE2-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE2-NEXT:    pand %xmm1, %xmm0
 ; SSE2-NEXT:    paddb %xmm2, %xmm0
 ; SSE2-NEXT:    movdqa %xmm0, %xmm1
 ; SSE2-NEXT:    psrlw $4, %xmm1
-; SSE2-NEXT:    movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE2-NEXT:    pand %xmm2, %xmm1
 ; SSE2-NEXT:    paddb %xmm0, %xmm1
-; SSE2-NEXT:    pand %xmm2, %xmm1
+; SSE2-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE2-NEXT:    movdqa %xmm1, %xmm0
 ; SSE2-NEXT:    retq
 ;
@@ -362,21 +358,17 @@ define <16 x i8> @testv16i8(<16 x i8> %in) {
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $1, %xmm1
 ; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
-; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    psubb %xmm1, %xmm0
 ; SSE3-NEXT:    movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
 ; SSE3-NEXT:    movdqa %xmm0, %xmm2
 ; SSE3-NEXT:    pand %xmm1, %xmm2
 ; SSE3-NEXT:    psrlw $2, %xmm0
-; SSE3-NEXT:    pand {{.*}}(%rip), %xmm0
 ; SSE3-NEXT:    pand %xmm1, %xmm0
 ; SSE3-NEXT:    paddb %xmm2, %xmm0
 ; SSE3-NEXT:    movdqa %xmm0, %xmm1
 ; SSE3-NEXT:    psrlw $4, %xmm1
-; SSE3-NEXT:    movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE3-NEXT:    pand %xmm2, %xmm1
 ; SSE3-NEXT:    paddb %xmm0, %xmm1
-; SSE3-NEXT:    pand %xmm2, %xmm1
+; SSE3-NEXT:    pand {{.*}}(%rip), %xmm1
 ; SSE3-NEXT:    movdqa %xmm1, %xmm0
 ; SSE3-NEXT:    retq
 ;
-- 
2.34.1
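
Note: the bit-math being lowered here is the vectorized form of the "best"
parallel popcount from the Stanford bithacks page cited in the code. A
minimal scalar C++ sketch of the same steps, specialized to 32-bit elements
(the popcount32 name is illustrative, not part of this patch or of LLVM):

    #include <cstdint>

    // Scalar model of the lowering: each masked shift below corresponds to
    // one GetShift/GetMask cluster in the DAG code above.
    uint32_t popcount32(uint32_t v) {
      v = v - ((v >> 1) & 0x55555555);                // 2-bit field sums
      v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // 4-bit field sums
      v = (v + (v >> 4)) & 0x0F0F0F0F;                // per-byte sums
      // The DAG lowering accumulates the byte sums into the *high* byte
      // with SHL+ADD and a final SRL by Len - 8; the right-shift
      // accumulation here is the equivalent, more familiar scalar form.
      v = v + (v >> 8);
      v = v + (v >> 16);
      return v & 0x3F; // a 32-bit word has at most 32 set bits
    }
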
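The SrlVT change rests on the invariant the new comment states: every SRL in
the algorithm is immediately followed by a mask whose per-byte pattern clears
exactly the bits that a 16-bit shift lets leak across byte boundaries, so
shifting vNi8 data as v(N/2)i16 (one psrlw) is safe. A self-contained C++
check of that invariant, exhaustive over all 16-bit byte pairs and the exact
shift/mask combinations the lowering uses (an assumed sketch, not LLVM code):

    #include <cassert>
    #include <cstdint>

    int main() {
      // The (shift, mask) pairs used by the popcount lowering above.
      const struct { int Amt; uint8_t Mask; } Steps[] = {
          {1, 0x55}, {2, 0x33}, {4, 0x0F}};
      for (auto S : Steps) {
        for (unsigned V = 0; V <= 0xFFFF; ++V) {
          uint8_t Lo = V & 0xFF, Hi = V >> 8;
          // Reference: shift each byte independently, then mask.
          uint16_t Ref = (uint16_t)((((Hi >> S.Amt) & S.Mask) << 8) |
                                    ((Lo >> S.Amt) & S.Mask));
          // What the widened lowering emits: one 16-bit shift, same mask.
          uint16_t Got = (uint16_t)((V >> S.Amt) & ((S.Mask << 8) | S.Mask));
          assert(Ref == Got && "mask must clear all cross-byte leakage");
        }
      }
      return 0;
    }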