R600/SI: Try to keep i32 mul on SALU
[oota-llvm.git] / lib / Target / X86 / X86TargetTransformInfo.cpp
index 173bb4e6ec8841c974003bd63ac3fa708bfcdfb8..cd0336fa92f3462ab36838fdaad84cda1f656e99 100644 (file)
@@ -84,7 +84,8 @@ public:
   unsigned getRegisterBitWidth(bool Vector) const override;
   unsigned getMaximumUnrollFactor() const override;
   unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind,
-                                  OperandValueKind) const override;
+                                  OperandValueKind, OperandValueProperties,
+                                  OperandValueProperties) const override;
   unsigned getShuffleCost(ShuffleKind Kind, Type *Tp,
                           int Index, Type *SubTp) const override;
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst,
@@ -178,15 +179,37 @@ unsigned X86TTI::getMaximumUnrollFactor() const {
   return 2;
 }
 
-unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty,
-                                        OperandValueKind Op1Info,
-                                        OperandValueKind Op2Info) const {
+unsigned X86TTI::getArithmeticInstrCost(
+    unsigned Opcode, Type *Ty, OperandValueKind Op1Info,
+    OperandValueKind Op2Info, OperandValueProperties Opd1PropInfo,
+    OperandValueProperties Opd2PropInfo) const {
   // Legalize the type.
   std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Ty);
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  if (ISD == ISD::SDIV &&
+      Op2Info == TargetTransformInfo::OK_UniformConstantValue &&
+      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+    // On X86, vector signed division by constants power-of-two are
+    // normally expanded to the sequence SRA + SRL + ADD + SRA.
+    // The OperandValue properties many not be same as that of previous
+    // operation;conservatively assume OP_None.
+    unsigned Cost =
+        2 * getArithmeticInstrCost(Instruction::AShr, Ty, Op1Info, Op2Info,
+                                   TargetTransformInfo::OP_None,
+                                   TargetTransformInfo::OP_None);
+    Cost += getArithmeticInstrCost(Instruction::LShr, Ty, Op1Info, Op2Info,
+                                   TargetTransformInfo::OP_None,
+                                   TargetTransformInfo::OP_None);
+    Cost += getArithmeticInstrCost(Instruction::Add, Ty, Op1Info, Op2Info,
+                                   TargetTransformInfo::OP_None,
+                                   TargetTransformInfo::OP_None);
+
+    return Cost;
+  }
+
   static const CostTblEntry<MVT::SimpleValueType>
   AVX2UniformConstCostTable[] = {
     { ISD::SDIV, MVT::v16i16,  6 }, // vpmulhw sequence