Cleanup. No functional change intended.
[oota-llvm.git] / lib / Target / X86 / X86TargetTransformInfo.cpp
index 8a699afa6a40686ec7915eaf0a84c2735be64cf3..488c2a42b2d41a816cd9ed9b012b03402f4d37c5 100644 (file)
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Target/TargetLowering.h"
+#include "llvm/Target/CostTable.h"
 using namespace llvm;
 
 // Declare the pass initialization routine locally as target-specific passes
@@ -75,7 +76,6 @@ public:
 
   /// \name Scalar TTI Implementations
   /// @{
-
   virtual PopcntSupportKind getPopcntSupport(unsigned TyWidth) const;
 
   /// @}
@@ -84,7 +84,11 @@ public:
   /// @{
 
   virtual unsigned getNumberOfRegisters(bool Vector) const;
-  virtual unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty) const;
+  virtual unsigned getRegisterBitWidth(bool Vector) const;
+  virtual unsigned getMaximumUnrollFactor() const;
+  virtual unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty,
+                                          OperandValueKind,
+                                          OperandValueKind) const;
   virtual unsigned getShuffleCost(ShuffleKind Kind, Type *Tp,
                                   int Index, Type *SubTp) const;
   virtual unsigned getCastInstrCost(unsigned Opcode, Type *Dst,
@@ -118,45 +122,6 @@ llvm::createX86TargetTransformInfoPass(const X86TargetMachine *TM) {
 //
 //===----------------------------------------------------------------------===//
 
-namespace {
-struct X86CostTblEntry {
-  int ISD;
-  MVT Type;
-  unsigned Cost;
-};
-}
-
-static int
-FindInTable(const X86CostTblEntry *Tbl, unsigned len, int ISD, MVT Ty) {
-  for (unsigned int i = 0; i < len; ++i)
-    if (Tbl[i].ISD == ISD && Tbl[i].Type == Ty)
-      return i;
-
-  // Could not find an entry.
-  return -1;
-}
-
-namespace {
-struct X86TypeConversionCostTblEntry {
-  int ISD;
-  MVT Dst;
-  MVT Src;
-  unsigned Cost;
-};
-}
-
-static int
-FindInConvertTable(const X86TypeConversionCostTblEntry *Tbl, unsigned len,
-                   int ISD, MVT Dst, MVT Src) {
-  for (unsigned int i = 0; i < len; ++i)
-    if (Tbl[i].ISD == ISD && Tbl[i].Src == Src && Tbl[i].Dst == Dst)
-      return i;
-
-  // Could not find an entry.
-  return -1;
-}
-
-
 X86TTI::PopcntSupportKind X86TTI::getPopcntSupport(unsigned TyWidth) const {
   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
   // TODO: Currently the __builtin_popcount() implementation using SSE3
@@ -166,45 +131,194 @@ X86TTI::PopcntSupportKind X86TTI::getPopcntSupport(unsigned TyWidth) const {
 }
 
 unsigned X86TTI::getNumberOfRegisters(bool Vector) const {
+  if (Vector && !ST->hasSSE1())
+    return 0;
+
   if (ST->is64Bit())
     return 16;
   return 8;
 }
 
-unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty) const {
+unsigned X86TTI::getRegisterBitWidth(bool Vector) const {
+  if (Vector) {
+    if (ST->hasAVX()) return 256;
+    if (ST->hasSSE1()) return 128;
+    return 0;
+  }
+
+  if (ST->is64Bit())
+    return 64;
+  return 32;
+
+}
+
+unsigned X86TTI::getMaximumUnrollFactor() const {
+  if (ST->isAtom())
+    return 1;
+
+  // Sandybridge and Haswell have multiple execution ports and pipelined
+  // vector units.
+  if (ST->hasAVX())
+    return 4;
+
+  return 2;
+}
+
+unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
+                                        OperandValueKind Op1Info,
+                                        OperandValueKind Op2Info) const {
   // Legalize the type.
   std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Ty);
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
-  static const X86CostTblEntry AVX1CostTable[] = {
+  static const CostTblEntry<MVT> AVX2CostTable[] = {
+    // Shifts on v4i64/v8i32 on AVX2 is legal even though we declare to
+    // customize them to detect the cases where shift amount is a scalar one.
+    { ISD::SHL,     MVT::v4i32,    1 },
+    { ISD::SRL,     MVT::v4i32,    1 },
+    { ISD::SRA,     MVT::v4i32,    1 },
+    { ISD::SHL,     MVT::v8i32,    1 },
+    { ISD::SRL,     MVT::v8i32,    1 },
+    { ISD::SRA,     MVT::v8i32,    1 },
+    { ISD::SHL,     MVT::v2i64,    1 },
+    { ISD::SRL,     MVT::v2i64,    1 },
+    { ISD::SHL,     MVT::v4i64,    1 },
+    { ISD::SRL,     MVT::v4i64,    1 },
+
+    { ISD::SHL,  MVT::v32i8,  42 }, // cmpeqb sequence.
+    { ISD::SHL,  MVT::v16i16,  16*10 }, // Scalarized.
+
+    { ISD::SRL,  MVT::v32i8,  32*10 }, // Scalarized.
+    { ISD::SRL,  MVT::v16i16,  8*10 }, // Scalarized.
+
+    { ISD::SRA,  MVT::v32i8,  32*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v16i16,  16*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v4i64,  4*10 }, // Scalarized.
+  };
+
+  // Look for AVX2 lowering tricks.
+  if (ST->hasAVX2()) {
+    int Idx = CostTableLookup<MVT>(AVX2CostTable, array_lengthof(AVX2CostTable),
+                                   ISD, LT.second);
+    if (Idx != -1)
+      return LT.first * AVX2CostTable[Idx].Cost;
+  }
+
+  static const CostTblEntry<MVT> SSE2UniformConstCostTable[] = {
+    // We don't correctly identify costs of casts because they are marked as
+    // custom.
+    // Constant splats are cheaper for the following instructions.
+    { ISD::SHL,  MVT::v16i8,  1 }, // psllw.
+    { ISD::SHL,  MVT::v8i16,  1 }, // psllw.
+    { ISD::SHL,  MVT::v4i32,  1 }, // pslld
+    { ISD::SHL,  MVT::v2i64,  1 }, // psllq.
+
+    { ISD::SRL,  MVT::v16i8,  1 }, // psrlw.
+    { ISD::SRL,  MVT::v8i16,  1 }, // psrlw.
+    { ISD::SRL,  MVT::v4i32,  1 }, // psrld.
+    { ISD::SRL,  MVT::v2i64,  1 }, // psrlq.
+
+    { ISD::SRA,  MVT::v16i8,  4 }, // psrlw, pand, pxor, psubb.
+    { ISD::SRA,  MVT::v8i16,  1 }, // psraw.
+    { ISD::SRA,  MVT::v4i32,  1 }, // psrad.
+  };
+
+  if (Op2Info == TargetTransformInfo::OK_UniformConstantValue &&
+      ST->hasSSE2()) {
+    int Idx = CostTableLookup<MVT>(SSE2UniformConstCostTable,
+                                   array_lengthof(SSE2UniformConstCostTable),
+                                   ISD, LT.second);
+    if (Idx != -1)
+      return LT.first * SSE2UniformConstCostTable[Idx].Cost;
+  }
+
+
+  static const CostTblEntry<MVT> SSE2CostTable[] = {
+    // We don't correctly identify costs of casts because they are marked as
+    // custom.
+    // For some cases, where the shift amount is a scalar we would be able
+    // to generate better code. Unfortunately, when this is the case the value
+    // (the splat) will get hoisted out of the loop, thereby making it invisible
+    // to ISel. The cost model must return worst case assumptions because it is
+    // used for vectorization and we don't want to make vectorized code worse
+    // than scalar code.
+    { ISD::SHL,  MVT::v16i8,  30 }, // cmpeqb sequence.
+    { ISD::SHL,  MVT::v8i16,  8*10 }, // Scalarized.
+    { ISD::SHL,  MVT::v4i32,  2*5 }, // We optimized this using mul.
+    { ISD::SHL,  MVT::v2i64,  2*10 }, // Scalarized.
+
+    { ISD::SRL,  MVT::v16i8,  16*10 }, // Scalarized.
+    { ISD::SRL,  MVT::v8i16,  8*10 }, // Scalarized.
+    { ISD::SRL,  MVT::v4i32,  4*10 }, // Scalarized.
+    { ISD::SRL,  MVT::v2i64,  2*10 }, // Scalarized.
+
+    { ISD::SRA,  MVT::v16i8,  16*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v8i16,  8*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v4i32,  4*10 }, // Scalarized.
+    { ISD::SRA,  MVT::v2i64,  2*10 }, // Scalarized.
+  };
+
+  if (ST->hasSSE2()) {
+    int Idx = CostTableLookup<MVT>(SSE2CostTable, array_lengthof(SSE2CostTable),
+                                   ISD, LT.second);
+    if (Idx != -1)
+      return LT.first * SSE2CostTable[Idx].Cost;
+  }
+
+  static const CostTblEntry<MVT> AVX1CostTable[] = {
     // We don't have to scalarize unsupported ops. We can issue two half-sized
     // operations and we only need to extract the upper YMM half.
     // Two ops + 1 extract + 1 insert = 4.
     { ISD::MUL,     MVT::v8i32,    4 },
     { ISD::SUB,     MVT::v8i32,    4 },
     { ISD::ADD,     MVT::v8i32,    4 },
-    { ISD::MUL,     MVT::v4i64,    4 },
     { ISD::SUB,     MVT::v4i64,    4 },
     { ISD::ADD,     MVT::v4i64,    4 },
-    };
+    // A v4i64 multiply is custom lowered as two split v2i64 vectors that then
+    // are lowered as a series of long multiplies(3), shifts(4) and adds(2)
+    // Because we believe v4i64 to be a legal type, we must also include the
+    // split factor of two in the cost table. Therefore, the cost here is 18
+    // instead of 9.
+    { ISD::MUL,     MVT::v4i64,    18 },
+  };
 
   // Look for AVX1 lowering tricks.
-  if (ST->hasAVX()) {
-    int Idx = FindInTable(AVX1CostTable, array_lengthof(AVX1CostTable), ISD,
-                          LT.second);
+  if (ST->hasAVX() && !ST->hasAVX2()) {
+    int Idx = CostTableLookup<MVT>(AVX1CostTable, array_lengthof(AVX1CostTable),
+                                   ISD, LT.second);
     if (Idx != -1)
       return LT.first * AVX1CostTable[Idx].Cost;
   }
+
+  // Custom lowering of vectors.
+  static const CostTblEntry<MVT> CustomLowered[] = {
+    // A v2i64/v4i64 and multiply is custom lowered as a series of long
+    // multiplies(3), shifts(4) and adds(2).
+    { ISD::MUL,     MVT::v2i64,    9 },
+    { ISD::MUL,     MVT::v4i64,    9 },
+  };
+  int Idx = CostTableLookup<MVT>(CustomLowered, array_lengthof(CustomLowered),
+                                 ISD, LT.second);
+  if (Idx != -1)
+    return LT.first * CustomLowered[Idx].Cost;
+
+  // Special lowering of v4i32 mul on sse2, sse3: Lower v4i32 mul as 2x shuffle,
+  // 2x pmuludq, 2x shuffle.
+  if (ISD == ISD::MUL && LT.second == MVT::v4i32 && ST->hasSSE2() &&
+      !ST->hasSSE41())
+    return 6;
+
   // Fallback to the default implementation.
-  return TargetTransformInfo::getArithmeticInstrCost(Opcode, Ty);
+  return TargetTransformInfo::getArithmeticInstrCost(Opcode, Ty, Op1Info,
+                                                     Op2Info);
 }
 
 unsigned X86TTI::getShuffleCost(ShuffleKind Kind, Type *Tp, int Index,
                                 Type *SubTp) const {
   // We only estimate the cost of reverse shuffles.
-  if (Kind != Reverse)
+  if (Kind != SK_Reverse)
     return TargetTransformInfo::getShuffleCost(Kind, Tp, Index, SubTp);
 
   std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Tp);
@@ -220,32 +334,89 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  std::pair<unsigned, MVT> LTSrc = TLI->getTypeLegalizationCost(Src);
+  std::pair<unsigned, MVT> LTDest = TLI->getTypeLegalizationCost(Dst);
+
+  static const TypeConversionCostTblEntry<MVT> SSE2ConvTbl[] = {
+    // These are somewhat magic numbers justified by looking at the output of
+    // Intel's IACA, running some kernels and making sure when we take
+    // legalization into account the throughput will be overestimated.
+    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 2*10 },
+    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v4i32, 4*10 },
+    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v8i16, 8*10 },
+    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v16i8, 16*10 },
+    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 2*10 },
+    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v4i32, 4*10 },
+    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v8i16, 8*10 },
+    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v16i8, 16*10 },
+    // There are faster sequences for float conversions.
+    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v2i64, 15 },
+    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 15 },
+    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v8i16, 15 },
+    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v16i8, 8 },
+    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v2i64, 15 },
+    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 15 },
+    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v8i16, 15 },
+    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v16i8, 8 },
+  };
+
+  if (ST->hasSSE2() && !ST->hasAVX()) {
+    int Idx = ConvertCostTableLookup<MVT>(SSE2ConvTbl,
+                                          array_lengthof(SSE2ConvTbl),
+                                          ISD, LTDest.second, LTSrc.second);
+    if (Idx != -1)
+      return LTSrc.first * SSE2ConvTbl[Idx].Cost;
+  }
+
   EVT SrcTy = TLI->getValueType(Src);
   EVT DstTy = TLI->getValueType(Dst);
 
-  if (!SrcTy.isSimple() || !DstTy.isSimple())
-    return TargetTransformInfo::getCastInstrCost(Opcode, Dst, Src);
-
-  static const X86TypeConversionCostTblEntry AVXConversionTbl[] = {
+  static const TypeConversionCostTblEntry<MVT> AVXConversionTbl[] = {
     { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 1 },
     { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 1 },
     { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 1 },
     { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 1 },
     { ISD::TRUNCATE,    MVT::v4i32, MVT::v4i64, 1 },
     { ISD::TRUNCATE,    MVT::v8i16, MVT::v8i32, 1 },
-    { ISD::SINT_TO_FP,  MVT::v8f32, MVT::v8i8,  1 },
-    { ISD::SINT_TO_FP,  MVT::v4f32, MVT::v4i8,  1 },
-    { ISD::UINT_TO_FP,  MVT::v8f32, MVT::v8i8,  1 },
-    { ISD::UINT_TO_FP,  MVT::v4f32, MVT::v4i8,  1 },
+
+    { ISD::SINT_TO_FP,  MVT::v8f32, MVT::v8i1,  8 },
+    { ISD::SINT_TO_FP,  MVT::v8f32, MVT::v8i8,  8 },
+    { ISD::SINT_TO_FP,  MVT::v8f32, MVT::v8i16, 5 },
+    { ISD::SINT_TO_FP,  MVT::v8f32, MVT::v8i32, 1 },
+    { ISD::SINT_TO_FP,  MVT::v4f32, MVT::v4i1,  3 },
+    { ISD::SINT_TO_FP,  MVT::v4f32, MVT::v4i8,  3 },
+    { ISD::SINT_TO_FP,  MVT::v4f32, MVT::v4i16, 3 },
+    { ISD::SINT_TO_FP,  MVT::v4f32, MVT::v4i32, 1 },
+    { ISD::SINT_TO_FP,  MVT::v4f64, MVT::v4i1,  3 },
+    { ISD::SINT_TO_FP,  MVT::v4f64, MVT::v4i8,  3 },
+    { ISD::SINT_TO_FP,  MVT::v4f64, MVT::v4i16, 3 },
+    { ISD::SINT_TO_FP,  MVT::v4f64, MVT::v4i32, 1 },
+
+    { ISD::UINT_TO_FP,  MVT::v8f32, MVT::v8i1,  6 },
+    { ISD::UINT_TO_FP,  MVT::v8f32, MVT::v8i8,  5 },
+    { ISD::UINT_TO_FP,  MVT::v8f32, MVT::v8i16, 5 },
+    { ISD::UINT_TO_FP,  MVT::v8f32, MVT::v8i32, 9 },
+    { ISD::UINT_TO_FP,  MVT::v4f32, MVT::v4i1,  7 },
+    { ISD::UINT_TO_FP,  MVT::v4f32, MVT::v4i8,  2 },
+    { ISD::UINT_TO_FP,  MVT::v4f32, MVT::v4i16, 2 },
+    { ISD::UINT_TO_FP,  MVT::v4f32, MVT::v4i32, 6 },
+    { ISD::UINT_TO_FP,  MVT::v4f64, MVT::v4i1,  7 },
+    { ISD::UINT_TO_FP,  MVT::v4f64, MVT::v4i8,  2 },
+    { ISD::UINT_TO_FP,  MVT::v4f64, MVT::v4i16, 2 },
+    { ISD::UINT_TO_FP,  MVT::v4f64, MVT::v4i32, 6 },
+
     { ISD::FP_TO_SINT,  MVT::v8i8,  MVT::v8f32, 1 },
     { ISD::FP_TO_SINT,  MVT::v4i8,  MVT::v4f32, 1 },
     { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i1,  6 },
     { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i1,  9 },
+    { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i1,  8 },
+    { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i8,  6 },
+    { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 6 },
     { ISD::TRUNCATE,    MVT::v8i32, MVT::v8i64, 3 },
   };
 
   if (ST->hasAVX()) {
-    int Idx = FindInConvertTable(AVXConversionTbl,
+    int Idx = ConvertCostTableLookup<MVT>(AVXConversionTbl,
                                  array_lengthof(AVXConversionTbl),
                                  ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT());
     if (Idx != -1)
@@ -265,7 +436,7 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
-  static const X86CostTblEntry SSE42CostTbl[] = {
+  static const CostTblEntry<MVT> SSE42CostTbl[] = {
     { ISD::SETCC,   MVT::v2f64,   1 },
     { ISD::SETCC,   MVT::v4f32,   1 },
     { ISD::SETCC,   MVT::v2i64,   1 },
@@ -274,7 +445,7 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
     { ISD::SETCC,   MVT::v16i8,   1 },
   };
 
-  static const X86CostTblEntry AVX1CostTbl[] = {
+  static const CostTblEntry<MVT> AVX1CostTbl[] = {
     { ISD::SETCC,   MVT::v4f64,   1 },
     { ISD::SETCC,   MVT::v8f32,   1 },
     // AVX1 does not support 8-wide integer compare.
@@ -284,7 +455,7 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
     { ISD::SETCC,   MVT::v32i8,   4 },
   };
 
-  static const X86CostTblEntry AVX2CostTbl[] = {
+  static const CostTblEntry<MVT> AVX2CostTbl[] = {
     { ISD::SETCC,   MVT::v4i64,   1 },
     { ISD::SETCC,   MVT::v8i32,   1 },
     { ISD::SETCC,   MVT::v16i16,  1 },
@@ -292,19 +463,19 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
   };
 
   if (ST->hasAVX2()) {
-    int Idx = FindInTable(AVX2CostTbl, array_lengthof(AVX2CostTbl), ISD, MTy);
+    int Idx = CostTableLookup<MVT>(AVX2CostTbl, array_lengthof(AVX2CostTbl), ISD, MTy);
     if (Idx != -1)
       return LT.first * AVX2CostTbl[Idx].Cost;
   }
 
   if (ST->hasAVX()) {
-    int Idx = FindInTable(AVX1CostTbl, array_lengthof(AVX1CostTbl), ISD, MTy);
+    int Idx = CostTableLookup<MVT>(AVX1CostTbl, array_lengthof(AVX1CostTbl), ISD, MTy);
     if (Idx != -1)
       return LT.first * AVX1CostTbl[Idx].Cost;
   }
 
   if (ST->hasSSE42()) {
-    int Idx = FindInTable(SSE42CostTbl, array_lengthof(SSE42CostTbl), ISD, MTy);
+    int Idx = CostTableLookup<MVT>(SSE42CostTbl, array_lengthof(SSE42CostTbl), ISD, MTy);
     if (Idx != -1)
       return LT.first * SSE42CostTbl[Idx].Cost;
   }