From: Elena Demikhovsky Date: Mon, 7 Dec 2015 13:39:24 +0000 (+0000) Subject: AVX-512: Fixed masked load / store instruction selection for KNL. X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=b06ff9b1e1a0e4e9b3a135efd2803424db3a2abf;p=oota-llvm.git AVX-512: Fixed masked load / store instruction selection for KNL. Patterns were missing for KNL target for <8 x i32>, <8 x float> masked load/store. This intrinsic comes with all legal types: <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 align, <8 x i1> %mask, <8 x float> %passThru), but still requires lowering, because VMASKMOVPS, VMASKMOVDQU32 work with 512-bit vectors only. All data operands should be widened to 512-bit vector. The mask operand should be widened to v16i1 with zeroes. Differential Revision: http://reviews.llvm.org/D15265 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@254909 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 1fb7b160a67..8295b2a19dd 100644 --- a/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -244,7 +244,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { Changed = true; return LegalizeOp(ExpandStore(Op)); } - } else if (Op.getOpcode() == ISD::MSCATTER) + } else if (Op.getOpcode() == ISD::MSCATTER || Op.getOpcode() == ISD::MSTORE) HasVectorValue = true; for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end(); @@ -344,6 +344,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { case ISD::MSCATTER: QueryType = cast(Node)->getValue().getValueType(); break; + case ISD::MSTORE: + QueryType = cast(Node)->getValue().getValueType(); + break; } switch (TLI.getOperationAction(Node->getOpcode(), QueryType)) { diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index f38ca2956ff..fa6f5c8be88 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -1384,6 +1384,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTruncStoreAction(MVT::v2i64, MVT::v2i32, Legal); setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal); setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal); + } else { + setOperationAction(ISD::MLOAD, MVT::v8i32, Custom); + setOperationAction(ISD::MLOAD, MVT::v8f32, Custom); + setOperationAction(ISD::MSTORE, MVT::v8i32, Custom); + setOperationAction(ISD::MSTORE, MVT::v8f32, Custom); } setOperationAction(ISD::TRUNCATE, MVT::i1, Custom); setOperationAction(ISD::TRUNCATE, MVT::v16i8, Custom); @@ -1459,6 +1464,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v8i1, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v16i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v16i1, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v8i1, Custom); setOperationAction(ISD::BUILD_VECTOR, MVT::v8i1, Custom); @@ -19685,6 +19691,47 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget *Subtarget, return DAG.getNode(ISD::MERGE_VALUES, dl, Tys, SinVal, CosVal); } +/// Widen a vector input to a vector of NVT. The +/// input vector must have the same element type as NVT. +static SDValue ExtendToType(SDValue InOp, MVT NVT, SelectionDAG &DAG, + bool FillWithZeroes = false) { + // Check if InOp already has the right width. + MVT InVT = InOp.getSimpleValueType(); + if (InVT == NVT) + return InOp; + + if (InOp.isUndef()) + return DAG.getUNDEF(NVT); + + assert(InVT.getVectorElementType() == NVT.getVectorElementType() && + "input and widen element type must match"); + + unsigned InNumElts = InVT.getVectorNumElements(); + unsigned WidenNumElts = NVT.getVectorNumElements(); + assert(WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0 && + "Unexpected request for vector widening"); + + EVT EltVT = NVT.getVectorElementType(); + + SDLoc dl(InOp); + if (ISD::isBuildVectorOfConstantSDNodes(InOp.getNode()) || + ISD::isBuildVectorOfConstantFPSDNodes(InOp.getNode())) { + SmallVector Ops; + for (unsigned i = 0; i < InNumElts; ++i) + Ops.push_back(InOp.getOperand(i)); + + SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, EltVT) : + DAG.getUNDEF(EltVT); + for (unsigned i = 0; i < WidenNumElts - InNumElts; ++i) + Ops.push_back(FillVal); + return DAG.getNode(ISD::BUILD_VECTOR, dl, NVT, Ops); + } + SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, NVT) : + DAG.getUNDEF(NVT); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, NVT, FillVal, + InOp, DAG.getIntPtrConstant(0, dl)); +} + static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget, SelectionDAG &DAG) { assert(Subtarget->hasAVX512() && @@ -19714,6 +19761,62 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget, return Op; } +static SDValue LowerMLOAD(SDValue Op, const X86Subtarget *Subtarget, + SelectionDAG &DAG) { + + MaskedLoadSDNode *N = cast(Op.getNode()); + MVT VT = Op.getSimpleValueType(); + SDValue Mask = N->getMask(); + SDLoc dl(Op); + + if (Subtarget->hasAVX512() && !Subtarget->hasVLX() && + !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) { + // This operation is legal for targets with VLX, but without + // VLX the vector should be widened to 512 bit + unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits(); + MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec); + MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec); + SDValue Src0 = N->getSrc0(); + Src0 = ExtendToType(Src0, WideDataVT, DAG); + Mask = ExtendToType(Mask, WideMaskVT, DAG, true); + SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(), + N->getBasePtr(), Mask, Src0, + N->getMemoryVT(), N->getMemOperand(), + N->getExtensionType()); + + SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, + NewLoad.getValue(0), + DAG.getIntPtrConstant(0, dl)); + SDValue RetOps[] = {Exract, NewLoad.getValue(1)}; + return DAG.getMergeValues(RetOps, dl); + } + return Op; +} + +static SDValue LowerMSTORE(SDValue Op, const X86Subtarget *Subtarget, + SelectionDAG &DAG) { + MaskedStoreSDNode *N = cast(Op.getNode()); + SDValue DataToStore = N->getValue(); + MVT VT = DataToStore.getSimpleValueType(); + SDValue Mask = N->getMask(); + SDLoc dl(Op); + + if (Subtarget->hasAVX512() && !Subtarget->hasVLX() && + !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) { + // This operation is legal for targets with VLX, but without + // VLX the vector should be widened to 512 bit + unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits(); + MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec); + MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec); + DataToStore = ExtendToType(DataToStore, WideDataVT, DAG); + Mask = ExtendToType(Mask, WideMaskVT, DAG, true); + return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(), + Mask, N->getMemoryVT(), N->getMemOperand(), + N->isTruncatingStore()); + } + return Op; +} + static SDValue LowerMGATHER(SDValue Op, const X86Subtarget *Subtarget, SelectionDAG &DAG) { assert(Subtarget->hasAVX512() && @@ -19873,6 +19976,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::UMAX: case ISD::UMIN: return LowerMINMAX(Op, DAG); case ISD::FSINCOS: return LowerFSINCOS(Op, Subtarget, DAG); + case ISD::MLOAD: return LowerMLOAD(Op, Subtarget, DAG); + case ISD::MSTORE: return LowerMSTORE(Op, Subtarget, DAG); case ISD::MGATHER: return LowerMGATHER(Op, Subtarget, DAG); case ISD::MSCATTER: return LowerMSCATTER(Op, Subtarget, DAG); case ISD::GC_TRANSITION_START: diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td index 60238f6ab23..58206c6acaa 100644 --- a/lib/Target/X86/X86InstrAVX512.td +++ b/lib/Target/X86/X86InstrAVX512.td @@ -2766,22 +2766,6 @@ def: Pat<(int_x86_avx512_mask_store_pd_512 addr:$ptr, (v8f64 VR512:$src), (VMOVAPDZmrk addr:$ptr, (v8i1 (COPY_TO_REGCLASS GR8:$mask, VK8WM)), VR512:$src)>; -let Predicates = [HasAVX512, NoVLX] in { -def: Pat<(X86mstore addr:$ptr, VK8WM:$mask, (v8f32 VR256:$src)), - (VMOVUPSZmrk addr:$ptr, - (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), - (INSERT_SUBREG (v16f32 (IMPLICIT_DEF)), VR256:$src, sub_ymm))>; - -def: Pat<(v8f32 (masked_load addr:$ptr, VK8WM:$mask, undef)), - (v8f32 (EXTRACT_SUBREG (v16f32 (VMOVUPSZrmkz - (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>; - -def: Pat<(v8f32 (masked_load addr:$ptr, VK8WM:$mask, (v8f32 VR256:$src0))), - (v8f32 (EXTRACT_SUBREG (v16f32 (VMOVUPSZrmk - (INSERT_SUBREG (v16f32 (IMPLICIT_DEF)), VR256:$src0, sub_ymm), - (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>; -} - defm VMOVDQA32 : avx512_alignedload_vl<0x6F, "vmovdqa32", avx512vl_i32_info, HasAVX512>, avx512_alignedstore_vl<0x7F, "vmovdqa32", avx512vl_i32_info, @@ -2843,17 +2827,6 @@ def : Pat<(v16i32 (vselect VK16WM:$mask, (v16i32 immAllZerosV), (v16i32 VR512:$src))), (VMOVDQU32Zrrkz (KNOTWrr VK16WM:$mask), VR512:$src)>; } -// NoVLX patterns -let Predicates = [HasAVX512, NoVLX] in { -def: Pat<(X86mstore addr:$ptr, VK8WM:$mask, (v8i32 VR256:$src)), - (VMOVDQU32Zmrk addr:$ptr, - (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), - (INSERT_SUBREG (v16i32 (IMPLICIT_DEF)), VR256:$src, sub_ymm))>; - -def: Pat<(v8i32 (masked_load addr:$ptr, VK8WM:$mask, undef)), - (v8i32 (EXTRACT_SUBREG (v16i32 (VMOVDQU32Zrmkz - (v16i1 (COPY_TO_REGCLASS VK8WM:$mask, VK16WM)), addr:$ptr)), sub_ymm))>; -} // Move Int Doubleword to Packed Double Int // diff --git a/test/CodeGen/X86/masked_memop.ll b/test/CodeGen/X86/masked_memop.ll index a720054c167..1a9cf008e86 100644 --- a/test/CodeGen/X86/masked_memop.ll +++ b/test/CodeGen/X86/masked_memop.ll @@ -139,18 +139,55 @@ define <4 x double> @test10(<4 x i32> %trigger, <4 x double>* %addr, <4 x double ret <4 x double> %res } -; AVX2-LABEL: test11 +; AVX2-LABEL: test11a ; AVX2: vmaskmovps ; AVX2: vblendvps -; SKX-LABEL: test11 -; SKX: vmovaps {{.*}}{%k1} -define <8 x float> @test11(<8 x i32> %trigger, <8 x float>* %addr, <8 x float> %dst) { +; SKX-LABEL: test11a +; SKX: vmovaps (%rdi), %ymm1 {%k1} +; AVX512-LABEL: test11a +; AVX512: kshiftlw $8 +; AVX512: kshiftrw $8 +; AVX512: vmovups (%rdi), %zmm1 {%k1} +define <8 x float> @test11a(<8 x i32> %trigger, <8 x float>* %addr, <8 x float> %dst) { %mask = icmp eq <8 x i32> %trigger, zeroinitializer %res = call <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 32, <8 x i1>%mask, <8 x float>%dst) ret <8 x float> %res } +; SKX-LABEL: test11b +; SKX: vmovdqu32 (%rdi), %ymm1 {%k1} +; AVX512-LABEL: test11b +; AVX512: kshiftlw $8 +; AVX512: kshiftrw $8 +; AVX512: vmovdqu32 (%rdi), %zmm1 {%k1} +define <8 x i32> @test11b(<8 x i1> %mask, <8 x i32>* %addr, <8 x i32> %dst) { + %res = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %addr, i32 4, <8 x i1>%mask, <8 x i32>%dst) + ret <8 x i32> %res +} + +; SKX-LABEL: test11c +; SKX: vmovaps (%rdi), %ymm0 {%k1} {z} +; AVX512-LABEL: test11c +; AVX512: kshiftlw $8 +; AVX512: kshiftrw $8 +; AVX512: vmovups (%rdi), %zmm0 {%k1} {z} +define <8 x float> @test11c(<8 x i1> %mask, <8 x float>* %addr) { + %res = call <8 x float> @llvm.masked.load.v8f32(<8 x float>* %addr, i32 32, <8 x i1> %mask, <8 x float> zeroinitializer) + ret <8 x float> %res +} + +; SKX-LABEL: test11d +; SKX: vmovdqu32 (%rdi), %ymm0 {%k1} {z} +; AVX512-LABEL: test11d +; AVX512: kshiftlw $8 +; AVX512: kshiftrw $8 +; AVX512: vmovdqu32 (%rdi), %zmm0 {%k1} {z} +define <8 x i32> @test11d(<8 x i1> %mask, <8 x i32>* %addr) { + %res = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %addr, i32 4, <8 x i1> %mask, <8 x i32> zeroinitializer) + ret <8 x i32> %res +} + ; AVX2-LABEL: test12 ; AVX2: vpmaskmovd %ymm @@ -291,6 +328,7 @@ declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>) declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>) declare <8 x float> @llvm.masked.load.v8f32(<8 x float>*, i32, <8 x i1>, <8 x float>) +declare <8 x i32> @llvm.masked.load.v8i32(<8 x i32>*, i32, <8 x i1>, <8 x i32>) declare <4 x float> @llvm.masked.load.v4f32(<4 x float>*, i32, <4 x i1>, <4 x float>) declare <2 x float> @llvm.masked.load.v2f32(<2 x float>*, i32, <2 x i1>, <2 x float>) declare <8 x double> @llvm.masked.load.v8f64(<8 x double>*, i32, <8 x i1>, <8 x double>)