Fix integer cast code to handle vector types.
authorDan Gohman <gohman@apple.com>
Mon, 14 Dec 2009 23:40:38 +0000 (23:40 +0000)
committerDan Gohman <gohman@apple.com>
Mon, 14 Dec 2009 23:40:38 +0000 (23:40 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@91362 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/CodeGen/SelectionDAG/SelectionDAG.cpp
lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/vec-trunc-store.ll [new file with mode: 0644]

index aee2f20b9d00e7086c6a2df998c0b7085f037bad..027348c0f6d029b69f9213da99c34ca42fb20c31 100644 (file)
@@ -5196,7 +5196,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
     // SimplifyDemandedBits, which only works if the value has a single use.
     if (SimplifyDemandedBits(Value,
                              APInt::getLowBitsSet(
-                               Value.getValueSizeInBits(),
+                               Value.getValueType().getScalarType().getSizeInBits(),
                                ST->getMemoryVT().getSizeInBits())))
       return SDValue(N, 0);
   }
index c74a2e45c1f7bf5dc19adfd7b322061acdb96960..094625754a5acc8a7e5701faa5bad516abd836fa 100644 (file)
@@ -2354,6 +2354,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, DebugLoc DL,
     assert(VT.isFloatingPoint() &&
            Operand.getValueType().isFloatingPoint() && "Invalid FP cast!");
     if (Operand.getValueType() == VT) return Operand;  // noop conversion.
+    assert((!VT.isVector() ||
+            VT.getVectorNumElements() ==
+            Operand.getValueType().getVectorNumElements()) &&
+           "Vector element count mismatch!");
     if (Operand.getOpcode() == ISD::UNDEF)
       return getUNDEF(VT);
     break;
@@ -2361,8 +2365,12 @@ SDValue SelectionDAG::getNode(unsigned Opcode, DebugLoc DL,
     assert(VT.isInteger() && Operand.getValueType().isInteger() &&
            "Invalid SIGN_EXTEND!");
     if (Operand.getValueType() == VT) return Operand;   // noop extension
-    assert(Operand.getValueType().bitsLT(VT)
-           && "Invalid sext node, dst < src!");
+    assert(Operand.getValueType().getScalarType().bitsLT(VT.getScalarType()) &&
+           "Invalid sext node, dst < src!");
+    assert((!VT.isVector() ||
+            VT.getVectorNumElements() ==
+            Operand.getValueType().getVectorNumElements()) &&
+           "Vector element count mismatch!");
     if (OpOpcode == ISD::SIGN_EXTEND || OpOpcode == ISD::ZERO_EXTEND)
       return getNode(OpOpcode, DL, VT, Operand.getNode()->getOperand(0));
     break;
@@ -2370,8 +2378,12 @@ SDValue SelectionDAG::getNode(unsigned Opcode, DebugLoc DL,
     assert(VT.isInteger() && Operand.getValueType().isInteger() &&
            "Invalid ZERO_EXTEND!");
     if (Operand.getValueType() == VT) return Operand;   // noop extension
-    assert(Operand.getValueType().bitsLT(VT)
-           && "Invalid zext node, dst < src!");
+    assert(Operand.getValueType().getScalarType().bitsLT(VT.getScalarType()) &&
+           "Invalid zext node, dst < src!");
+    assert((!VT.isVector() ||
+            VT.getVectorNumElements() ==
+            Operand.getValueType().getVectorNumElements()) &&
+           "Vector element count mismatch!");
     if (OpOpcode == ISD::ZERO_EXTEND)   // (zext (zext x)) -> (zext x)
       return getNode(ISD::ZERO_EXTEND, DL, VT,
                      Operand.getNode()->getOperand(0));
@@ -2380,8 +2392,12 @@ SDValue SelectionDAG::getNode(unsigned Opcode, DebugLoc DL,
     assert(VT.isInteger() && Operand.getValueType().isInteger() &&
            "Invalid ANY_EXTEND!");
     if (Operand.getValueType() == VT) return Operand;   // noop extension
-    assert(Operand.getValueType().bitsLT(VT)
-           && "Invalid anyext node, dst < src!");
+    assert(Operand.getValueType().getScalarType().bitsLT(VT.getScalarType()) &&
+           "Invalid anyext node, dst < src!");
+    assert((!VT.isVector() ||
+            VT.getVectorNumElements() ==
+            Operand.getValueType().getVectorNumElements()) &&
+           "Vector element count mismatch!");
     if (OpOpcode == ISD::ZERO_EXTEND || OpOpcode == ISD::SIGN_EXTEND)
       // (ext (zext x)) -> (zext x)  and  (ext (sext x)) -> (sext x)
       return getNode(OpOpcode, DL, VT, Operand.getNode()->getOperand(0));
@@ -2390,14 +2406,19 @@ SDValue SelectionDAG::getNode(unsigned Opcode, DebugLoc DL,
     assert(VT.isInteger() && Operand.getValueType().isInteger() &&
            "Invalid TRUNCATE!");
     if (Operand.getValueType() == VT) return Operand;   // noop truncate
-    assert(Operand.getValueType().bitsGT(VT)
-           && "Invalid truncate node, src < dst!");
+    assert(Operand.getValueType().getScalarType().bitsGT(VT.getScalarType()) &&
+           "Invalid truncate node, src < dst!");
+    assert((!VT.isVector() ||
+            VT.getVectorNumElements() ==
+            Operand.getValueType().getVectorNumElements()) &&
+           "Vector element count mismatch!");
     if (OpOpcode == ISD::TRUNCATE)
       return getNode(ISD::TRUNCATE, DL, VT, Operand.getNode()->getOperand(0));
     else if (OpOpcode == ISD::ZERO_EXTEND || OpOpcode == ISD::SIGN_EXTEND ||
              OpOpcode == ISD::ANY_EXTEND) {
       // If the source is smaller than the dest, we still need an extend.
-      if (Operand.getNode()->getOperand(0).getValueType().bitsLT(VT))
+      if (Operand.getNode()->getOperand(0).getValueType().getScalarType()
+            .bitsLT(VT.getScalarType()))
         return getNode(OpOpcode, DL, VT, Operand.getNode()->getOperand(0));
       else if (Operand.getNode()->getOperand(0).getValueType().bitsGT(VT))
         return getNode(ISD::TRUNCATE, DL, VT, Operand.getNode()->getOperand(0));
@@ -3743,16 +3764,15 @@ SelectionDAG::getLoad(ISD::MemIndexedMode AM, DebugLoc dl,
     assert(VT == MemVT && "Non-extending load from different memory type!");
   } else {
     // Extending load.
-    if (VT.isVector())
-      assert(MemVT.getVectorNumElements() == VT.getVectorNumElements() &&
-             "Invalid vector extload!");
-    else
-      assert(MemVT.bitsLT(VT) &&
-             "Should only be an extending load, not truncating!");
-    assert((ExtType == ISD::EXTLOAD || VT.isInteger()) &&
-           "Cannot sign/zero extend a FP/Vector load!");
+    assert(MemVT.getScalarType().bitsLT(VT.getScalarType()) &&
+           "Should only be an extending load, not truncating!");
     assert(VT.isInteger() == MemVT.isInteger() &&
            "Cannot convert from FP to Int or Int -> FP!");
+    assert(VT.isVector() == MemVT.isVector() &&
+           "Cannot use trunc store to convert to or from a vector!");
+    assert((!VT.isVector() ||
+            VT.getVectorNumElements() == MemVT.getVectorNumElements()) &&
+           "Cannot use trunc store to change the number of vector elements!");
   }
 
   bool Indexed = AM != ISD::UNINDEXED;
@@ -3885,10 +3905,15 @@ SDValue SelectionDAG::getTruncStore(SDValue Chain, DebugLoc dl, SDValue Val,
   if (VT == SVT)
     return getStore(Chain, dl, Val, Ptr, MMO);
 
-  assert(VT.bitsGT(SVT) && "Not a truncation?");
+  assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
+         "Should only be a truncating store, not extending!");
   assert(VT.isInteger() == SVT.isInteger() &&
          "Can't do FP-INT conversion!");
-
+  assert(VT.isVector() == SVT.isVector() &&
+         "Cannot use trunc store to convert to or from a vector!");
+  assert((!VT.isVector() ||
+          VT.getVectorNumElements() == SVT.getVectorNumElements()) &&
+         "Cannot use trunc store to change the number of vector elements!");
 
   SDVTList VTs = getVTList(MVT::Other);
   SDValue Undef = getUNDEF(Ptr.getValueType());
index 8c3b707e8fcdb78b97e666c096699a701b3dc9a8..99f984512307d4848521d40bc0ed29cff28619f1 100644 (file)
@@ -596,6 +596,17 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
     setOperationAction(ISD::UINT_TO_FP, (MVT::SimpleValueType)VT, Expand);
     setOperationAction(ISD::SINT_TO_FP, (MVT::SimpleValueType)VT, Expand);
     setOperationAction(ISD::SIGN_EXTEND_INREG, (MVT::SimpleValueType)VT,Expand);
+    setOperationAction(ISD::TRUNCATE,  (MVT::SimpleValueType)VT, Expand);
+    setOperationAction(ISD::SIGN_EXTEND,  (MVT::SimpleValueType)VT, Expand);
+    setOperationAction(ISD::ZERO_EXTEND,  (MVT::SimpleValueType)VT, Expand);
+    setOperationAction(ISD::ANY_EXTEND,  (MVT::SimpleValueType)VT, Expand);
+    for (unsigned InnerVT = (unsigned)MVT::FIRST_VECTOR_VALUETYPE;
+         InnerVT <= (unsigned)MVT::LAST_VECTOR_VALUETYPE; ++InnerVT)
+      setTruncStoreAction((MVT::SimpleValueType)VT,
+                          (MVT::SimpleValueType)InnerVT, Expand);
+    setLoadExtAction(ISD::SEXTLOAD, (MVT::SimpleValueType)VT, Expand);
+    setLoadExtAction(ISD::ZEXTLOAD, (MVT::SimpleValueType)VT, Expand);
+    setLoadExtAction(ISD::EXTLOAD, (MVT::SimpleValueType)VT, Expand);
   }
 
   // FIXME: In order to prevent SSE instructions being expanded to MMX ones
@@ -672,8 +683,6 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
 
     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v4i16, Custom);
 
-    setTruncStoreAction(MVT::v8i16,             MVT::v8i8, Expand);
-    setOperationAction(ISD::TRUNCATE,           MVT::v8i8, Expand);
     setOperationAction(ISD::SELECT,             MVT::v8i8, Promote);
     setOperationAction(ISD::SELECT,             MVT::v4i16, Promote);
     setOperationAction(ISD::SELECT,             MVT::v2i32, Promote);
diff --git a/test/CodeGen/X86/vec-trunc-store.ll b/test/CodeGen/X86/vec-trunc-store.ll
new file mode 100644 (file)
index 0000000..ea1a151
--- /dev/null
@@ -0,0 +1,13 @@
+; RUN: llc < %s -march=x86-64 -disable-mmx | grep punpcklwd | count 2
+
+define void @foo() nounwind {
+  %cti69 = trunc <8 x i32> undef to <8 x i16>     ; <<8 x i16>> [#uses=1]
+  store <8 x i16> %cti69, <8 x i16>* undef
+  ret void
+}
+
+define void @bar() nounwind {
+  %cti44 = trunc <4 x i32> undef to <4 x i16>     ; <<4 x i16>> [#uses=1]
+  store <4 x i16> %cti44, <4 x i16>* undef
+  ret void
+}