[SDAG] Handle LowerOperation returning its input consistently
[oota-llvm.git] / lib / CodeGen / SelectionDAG / LegalizeVectorTypes.cpp
index 27f63d278232fedb6f299eb929791d47eed9978a..63671f75bf309f877fdd28c29985c74a8bfc7791 100644 (file)
@@ -597,6 +597,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::LOAD:
     SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
     break;
+  case ISD::MLOAD:
+    SplitVecRes_MLOAD(cast<MaskedLoadSDNode>(N), Lo, Hi);
+    break;
   case ISD::SETCC:
     SplitVecRes_SETCC(N, Lo, Hi);
     break;
@@ -979,6 +982,67 @@ void DAGTypeLegalizer::SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo,
   ReplaceValueWith(SDValue(LD, 1), Ch);
 }
 
+void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
+                                         SDValue &Lo, SDValue &Hi) {
+  EVT LoVT, HiVT;
+  SDLoc dl(MLD);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MLD->getValueType(0));
+
+  SDValue Ch = MLD->getChain();
+  SDValue Ptr = MLD->getBasePtr();
+  SDValue Mask = MLD->getMask();
+  unsigned Alignment = MLD->getOriginalAlignment();
+  ISD::LoadExtType ExtType = MLD->getExtensionType();
+
+  // if Alignment is equal to the vector size,
+  // take the half of it for the second part
+  unsigned SecondHalfAlignment =
+    (Alignment == MLD->getValueType(0).getSizeInBits()/8) ?
+     Alignment/2 : Alignment;
+
+  SDValue MaskLo, MaskHi;
+  std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
+
+  EVT MemoryVT = MLD->getMemoryVT();
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue Src0 = MLD->getSrc0();
+  SDValue Src0Lo, Src0Hi;
+  std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, dl);
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MLD->getPointerInfo(), 
+                         MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
+                         Alignment, MLD->getAAInfo(), MLD->getRanges());
+
+  Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, MaskLo, Src0Lo, LoMemVT, MMO,
+                         ExtType);
+
+  unsigned IncrementSize = LoMemVT.getSizeInBits()/8;
+  Ptr = DAG.getNode(ISD::ADD, dl, Ptr.getValueType(), Ptr,
+                    DAG.getConstant(IncrementSize, Ptr.getValueType()));
+
+  MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MLD->getPointerInfo(), 
+                         MachineMemOperand::MOLoad,  HiMemVT.getStoreSize(),
+                         SecondHalfAlignment, MLD->getAAInfo(), MLD->getRanges());
+
+  Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, MaskHi, Src0Hi, HiMemVT, MMO,
+                         ExtType);
+
+
+  // Build a factor node to remember that this load is independent of the
+  // other one.
+  Ch = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
+                   Hi.getValue(1));
+
+  // Legalized the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(MLD, 1), Ch);
+
+}
+
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
   assert(N->getValueType(0).isVector() &&
          N->getOperand(0).getValueType().isVector() &&
@@ -1234,6 +1298,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     case ISD::STORE:
       Res = SplitVecOp_STORE(cast<StoreSDNode>(N), OpNo);
       break;
+    case ISD::MSTORE:
+      Res = SplitVecOp_MSTORE(cast<MaskedStoreSDNode>(N), OpNo);
+      break;
     case ISD::VSELECT:
       Res = SplitVecOp_VSELECT(N, OpNo);
       break;
@@ -1395,6 +1462,58 @@ SDValue DAGTypeLegalizer::SplitVecOp_EXTRACT_VECTOR_ELT(SDNode *N) {
                         MachinePointerInfo(), EltVT, false, false, false, 0);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_MSTORE(MaskedStoreSDNode *N,
+                                            unsigned OpNo) {
+  SDValue Ch  = N->getChain();
+  SDValue Ptr = N->getBasePtr();
+  SDValue Mask = N->getMask();
+  SDValue Data = N->getValue();
+  EVT MemoryVT = N->getMemoryVT();
+  unsigned Alignment = N->getOriginalAlignment();
+  SDLoc DL(N);
+  
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue DataLo, DataHi;
+  GetSplitVector(Data, DataLo, DataHi);
+  SDValue MaskLo, MaskHi;
+  GetSplitVector(Mask, MaskLo, MaskHi);
+
+  // if Alignment is equal to the vector size,
+  // take the half of it for the second part
+  unsigned SecondHalfAlignment =
+    (Alignment == Data->getValueType(0).getSizeInBits()/8) ?
+       Alignment/2 : Alignment;
+
+  SDValue Lo, Hi;
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(N->getPointerInfo(), 
+                         MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
+                         Alignment, N->getAAInfo(), N->getRanges());
+
+  Lo = DAG.getMaskedStore(Ch, DL, DataLo, Ptr, MaskLo, LoMemVT, MMO,
+                          N->isTruncatingStore());
+
+  unsigned IncrementSize = LoMemVT.getSizeInBits()/8;
+  Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
+                    DAG.getConstant(IncrementSize, Ptr.getValueType()));
+
+  MMO = DAG.getMachineFunction().
+    getMachineMemOperand(N->getPointerInfo(), 
+                         MachineMemOperand::MOStore,  HiMemVT.getStoreSize(),
+                         SecondHalfAlignment, N->getAAInfo(), N->getRanges());
+
+  Hi = DAG.getMaskedStore(Ch, DL, DataHi, Ptr, MaskHi, HiMemVT, MMO,
+                          N->isTruncatingStore());
+
+
+  // Build a factor node to remember that this store is independent of the
+  // other one.
+  return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi);
+
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo) {
   assert(N->isUnindexed() && "Indexed store of vector?");
   assert(OpNo == 1 && "Can only split the stored value");
@@ -1599,6 +1718,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::VECTOR_SHUFFLE:
     Res = WidenVecRes_VECTOR_SHUFFLE(cast<ShuffleVectorSDNode>(N));
     break;
+  case ISD::MLOAD:
+    Res = WidenVecRes_MLOAD(cast<MaskedLoadSDNode>(N));
+    break;
 
   case ISD::ADD:
   case ISD::AND:
@@ -2289,6 +2411,44 @@ SDValue DAGTypeLegalizer::WidenVecRes_LOAD(SDNode *N) {
   return Result;
 }
 
+SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
+  
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+  SDValue Mask = N->getMask();
+  EVT MaskVT = Mask.getValueType();
+  SDValue Src0 = GetWidenedVector(N->getSrc0());
+  ISD::LoadExtType ExtType = N->getExtensionType();
+  SDLoc dl(N);
+
+  if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector)
+    Mask = GetWidenedVector(Mask);
+  else {
+    EVT BoolVT = getSetCCResultType(WidenVT);
+
+    // We can't use ModifyToType() because we should fill the mask with
+    // zeroes
+    unsigned WidenNumElts = BoolVT.getVectorNumElements();
+    unsigned MaskNumElts = MaskVT.getVectorNumElements();
+
+    unsigned NumConcat = WidenNumElts / MaskNumElts;
+    SmallVector<SDValue, 16> Ops(NumConcat);
+    SDValue ZeroVal = DAG.getConstant(0, MaskVT);
+    Ops[0] = Mask;
+    for (unsigned i = 1; i != NumConcat; ++i)
+      Ops[i] = ZeroVal;
+
+    Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops);
+  }
+
+  SDValue Res = DAG.getMaskedLoad(WidenVT, dl, N->getChain(), N->getBasePtr(),
+                                  Mask, Src0, N->getMemoryVT(),
+                                  N->getMemOperand(), ExtType);
+  // Legalized the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_SCALAR_TO_VECTOR(SDNode *N) {
   EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N),
@@ -2434,6 +2594,7 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::EXTRACT_SUBVECTOR:  Res = WidenVecOp_EXTRACT_SUBVECTOR(N); break;
   case ISD::EXTRACT_VECTOR_ELT: Res = WidenVecOp_EXTRACT_VECTOR_ELT(N); break;
   case ISD::STORE:              Res = WidenVecOp_STORE(N); break;
+  case ISD::MSTORE:             Res = WidenVecOp_MSTORE(N, OpNo); break;
   case ISD::SETCC:              Res = WidenVecOp_SETCC(N); break;
 
   case ISD::ANY_EXTEND:
@@ -2632,6 +2793,42 @@ SDValue DAGTypeLegalizer::WidenVecOp_STORE(SDNode *N) {
     return DAG.getNode(ISD::TokenFactor, SDLoc(ST), MVT::Other, StChain);
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_MSTORE(SDNode *N, unsigned OpNo) {
+  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
+  SDValue Mask = MST->getMask();
+  EVT MaskVT = Mask.getValueType();
+  SDValue StVal = MST->getValue();
+  // Widen the value
+  SDValue WideVal = GetWidenedVector(StVal);
+  SDLoc dl(N);
+
+  if (OpNo == 2 || getTypeAction(MaskVT) == TargetLowering::TypeWidenVector)
+    Mask = GetWidenedVector(Mask);
+  else {
+    // The mask should be widened as well
+    EVT BoolVT = getSetCCResultType(WideVal.getValueType());
+    // We can't use ModifyToType() because we should fill the mask with
+    // zeroes
+    unsigned WidenNumElts = BoolVT.getVectorNumElements();
+    unsigned MaskNumElts = MaskVT.getVectorNumElements();
+
+    unsigned NumConcat = WidenNumElts / MaskNumElts;
+    SmallVector<SDValue, 16> Ops(NumConcat);
+    SDValue ZeroVal = DAG.getConstant(0, MaskVT);
+    Ops[0] = Mask;
+    for (unsigned i = 1; i != NumConcat; ++i)
+      Ops[i] = ZeroVal;
+
+    Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops);
+  }
+  assert(Mask.getValueType().getVectorNumElements() ==
+         WideVal.getValueType().getVectorNumElements() &&
+         "Mask and data vectors should have the same number of elements");
+  return DAG.getMaskedStore(MST->getChain(), dl, WideVal, MST->getBasePtr(),
+                            Mask, MST->getMemoryVT(), MST->getMemOperand(),
+                            false);
+}
+
 SDValue DAGTypeLegalizer::WidenVecOp_SETCC(SDNode *N) {
   SDValue InOp0 = GetWidenedVector(N->getOperand(0));
   SDValue InOp1 = GetWidenedVector(N->getOperand(1));