AVX-512: implement kunpck intrinsics.
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 364a8c260ba132aa7ce6158de995c79151fd2a98..f38ca2956ff325979dc8e1f154a213f07e9f2e97 100644 (file)
@@ -15998,19 +15998,26 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
   }
 
   if (Mask.getSimpleValueType() == MVT::i64 && Subtarget->is32Bit()) {
-    assert(MaskVT == MVT::v64i1 && "Unexpected mask VT!");
-    assert(Subtarget->hasBWI() && "Expected AVX512BW target!");
-    // In case 32bit mode, bitcast i64 is illegal, extend/split it.
-    SDValue Lo, Hi;
-    Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
-                        DAG.getConstant(0, dl, MVT::i32));
-    Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
-                        DAG.getConstant(1, dl, MVT::i32));
-
-    Lo = DAG.getNode(ISD::BITCAST, dl, MVT::v32i1, Lo);
-    Hi = DAG.getNode(ISD::BITCAST, dl, MVT::v32i1, Hi);
-
-    return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Hi, Lo);
+    if (MaskVT == MVT::v64i1) {
+      assert(Subtarget->hasBWI() && "Expected AVX512BW target!");
+      // In case 32bit mode, bitcast i64 is illegal, extend/split it.
+      SDValue Lo, Hi;
+      Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
+                          DAG.getConstant(0, dl, MVT::i32));
+      Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, MVT::i32, Mask,
+                          DAG.getConstant(1, dl, MVT::i32));
+
+      Lo = DAG.getBitcast(MVT::v32i1, Lo);
+      Hi = DAG.getBitcast(MVT::v32i1, Hi);
+
+      return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi);
+    } else {
+      // MaskVT require < 64bit. Truncate mask (should succeed in any case),
+      // and bitcast.
+      MVT TruncVT = MVT::getIntegerVT(MaskVT.getSizeInBits());
+      return DAG.getBitcast(MaskVT,
+                            DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Mask));
+    }
 
   } else {
     MVT BitcastVT = MVT::getVectorVT(MVT::i1,
@@ -16600,6 +16607,18 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget *Subtarget
       return DAG.getNode(IntrData->Opc0, dl, VT, VMask, Op.getOperand(1),
                          Op.getOperand(2));
     }
+    case KUNPCK: {
+      MVT VT = Op.getSimpleValueType();
+      MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getSizeInBits()/2);
+
+      SDValue Src1 = getMaskNode(Op.getOperand(1), MaskVT, Subtarget, DAG, dl);
+      SDValue Src2 = getMaskNode(Op.getOperand(2), MaskVT, Subtarget, DAG, dl);
+      // Arguments should be swapped.
+      SDValue Res = DAG.getNode(IntrData->Opc0, dl,
+                                MVT::getVectorVT(MVT::i1, VT.getSizeInBits()),
+                                Src2, Src1);
+      return DAG.getBitcast(VT, Res);
+    }
     default:
       break;
     }
@@ -20001,8 +20020,9 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
     }
   }
   case ISD::INTRINSIC_WO_CHAIN: {
-         Results.push_back(LowerINTRINSIC_WO_CHAIN(SDValue(N, 0), Subtarget, DAG));
-         return;
+    if (SDValue V = LowerINTRINSIC_WO_CHAIN(SDValue(N, 0), Subtarget, DAG))
+      Results.push_back(V);
+    return;
   }
   case ISD::READCYCLECOUNTER: {
     return getReadTimeStampCounter(N, dl, X86ISD::RDTSC_DAG, DAG, Subtarget,