Prevent binary-tree deterioration in sparse switch statements.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / LegalizeVectorOps.cpp
index 8c02e65c5c4906ecf5ebdecdd0bb1db02f057000..eac404c5036596ef4ef993c715e05122fa26d707 100644 (file)
@@ -200,7 +200,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
     LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
     ISD::LoadExtType ExtType = LD->getExtensionType();
     if (LD->getMemoryVT().isVector() && ExtType != ISD::NON_EXTLOAD)
-      switch (TLI.getLoadExtAction(LD->getExtensionType(), LD->getMemoryVT())) {
+      switch (TLI.getLoadExtAction(LD->getExtensionType(), LD->getValueType(0),
+                                   LD->getMemoryVT())) {
       default: llvm_unreachable("This action is not supported yet!");
       case TargetLowering::Legal:
         return TranslateLegalizeResults(Op, Result);
@@ -290,6 +291,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::FP_TO_UINT:
   case ISD::FNEG:
   case ISD::FABS:
+  case ISD::FMINNUM:
+  case ISD::FMAXNUM:
   case ISD::FCOPYSIGN:
   case ISD::FSQRT:
   case ISD::FSIN:
@@ -370,9 +373,11 @@ SDValue VectorLegalizer::Promote(SDValue Op) {
     return PromoteFP_TO_INT(Op, Op->getOpcode() == ISD::FP_TO_SINT);
   }
 
-  // The rest of the time, vector "promotion" is basically just bitcasting and
-  // doing the operation in a different type.  For example, x86 promotes
-  // ISD::AND on v2i32 to v1i64.
+  // There are currently two cases of vector promotion:
+  // 1) Bitcasting a vector of integers to a different type to a vector of the
+  //    same overall length. For example, x86 promotes ISD::AND on v2i32 to v1i64.
+  // 2) Extending a vector of floats to a vector of the same number oflarger
+  //    floats. For example, AArch64 promotes ISD::FADD on v4f16 to v4f32.
   MVT VT = Op.getSimpleValueType();
   assert(Op.getNode()->getNumValues() == 1 &&
          "Can't promote a vector with multiple results!");
@@ -382,14 +387,23 @@ SDValue VectorLegalizer::Promote(SDValue Op) {
 
   for (unsigned j = 0; j != Op.getNumOperands(); ++j) {
     if (Op.getOperand(j).getValueType().isVector())
-      Operands[j] = DAG.getNode(ISD::BITCAST, dl, NVT, Op.getOperand(j));
+      if (Op.getOperand(j)
+              .getValueType()
+              .getVectorElementType()
+              .isFloatingPoint())
+        Operands[j] = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op.getOperand(j));
+      else
+        Operands[j] = DAG.getNode(ISD::BITCAST, dl, NVT, Op.getOperand(j));
     else
       Operands[j] = Op.getOperand(j);
   }
 
   Op = DAG.getNode(Op.getOpcode(), dl, NVT, Operands);
-
-  return DAG.getNode(ISD::BITCAST, dl, VT, Op);
+  if (VT.isFloatingPoint() ||
+      (VT.isVector() && VT.getVectorElementType().isFloatingPoint()))
+    return DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0));
+  else
+    return DAG.getNode(ISD::BITCAST, dl, VT, Op);
 }
 
 SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) {
@@ -507,8 +521,8 @@ SDValue VectorLegalizer::ExpandLoad(SDValue Op) {
         ScalarLoad = DAG.getExtLoad(ISD::EXTLOAD, dl, WideVT, Chain, BasePTR,
                                     LD->getPointerInfo().getWithOffset(Offset),
                                     LoadVT, LD->isVolatile(),
-                                    LD->isNonTemporal(), LD->getAlignment(),
-                                    LD->getAAInfo());
+                                    LD->isNonTemporal(), LD->isInvariant(),
+                                    LD->getAlignment(), LD->getAAInfo());
       }
 
       RemainingBytes -= LoadBytes;
@@ -578,7 +592,7 @@ SDValue VectorLegalizer::ExpandLoad(SDValue Op) {
                 Op.getNode()->getValueType(0).getScalarType(),
                 Chain, BasePTR, LD->getPointerInfo().getWithOffset(Idx * Stride),
                 SrcVT.getScalarType(),
-                LD->isVolatile(), LD->isNonTemporal(),
+                LD->isVolatile(), LD->isNonTemporal(), LD->isInvariant(),
                 LD->getAlignment(), LD->getAAInfo());
 
       BasePTR = DAG.getNode(ISD::ADD, dl, BasePTR.getValueType(), BasePTR,