Loop Strength Reduce: Scaling factor cost.
[oota-llvm.git] / lib / Transforms / Scalar / LoopStrengthReduce.cpp
index 4e4cb864641411a95871e659094005507db50744..b107fef35a0fdd9380e11c7690983a5b2a88ca4f 100644 (file)
@@ -773,6 +773,16 @@ DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakVH> &DeadInsts) {
   return Changed;
 }
 
+namespace {
+class LSRUse;
+}
+// Check if it is legal to fold 2 base registers.
+static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
+                             const Formula &F);
+// Get the cost of the scaling factor used in F for LU.
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F);
+
 namespace {
 
 /// Cost - This class is used to measure and compare candidate formulae.
@@ -785,11 +795,12 @@ class Cost {
   unsigned NumBaseAdds;
   unsigned ImmCost;
   unsigned SetupCost;
+  unsigned ScaleCost;
 
 public:
   Cost()
     : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0),
-      SetupCost(0) {}
+      SetupCost(0), ScaleCost(0) {}
 
   bool operator<(const Cost &Other) const;
 
@@ -799,9 +810,9 @@ public:
   // Once any of the metrics loses, they must all remain losers.
   bool isValid() {
     return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds
-             | ImmCost | SetupCost) != ~0u)
+             | ImmCost | SetupCost | ScaleCost) != ~0u)
       || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds
-           & ImmCost & SetupCost) == ~0u);
+           & ImmCost & SetupCost & ScaleCost) == ~0u);
   }
 #endif
 
@@ -810,12 +821,14 @@ public:
     return NumRegs == ~0u;
   }
 
-  void RateFormula(const Formula &F,
+  void RateFormula(const TargetTransformInfo &TTI,
+                   const Formula &F,
                    SmallPtrSet<const SCEV *, 16> &Regs,
                    const DenseSet<const SCEV *> &VisitedRegs,
                    const Loop *L,
                    const SmallVectorImpl<int64_t> &Offsets,
                    ScalarEvolution &SE, DominatorTree &DT,
+                   const LSRUse &LU,
                    SmallPtrSet<const SCEV *, 16> *LoserRegs = 0);
 
   void print(raw_ostream &OS) const;
@@ -895,17 +908,19 @@ void Cost::RatePrimaryRegister(const SCEV *Reg,
   }
   if (Regs.insert(Reg)) {
     RateRegister(Reg, Regs, L, SE, DT);
-    if (isLoser())
+    if (LoserRegs && isLoser())
       LoserRegs->insert(Reg);
   }
 }
 
-void Cost::RateFormula(const Formula &F,
+void Cost::RateFormula(const TargetTransformInfo &TTI,
+                       const Formula &F,
                        SmallPtrSet<const SCEV *, 16> &Regs,
                        const DenseSet<const SCEV *> &VisitedRegs,
                        const Loop *L,
                        const SmallVectorImpl<int64_t> &Offsets,
                        ScalarEvolution &SE, DominatorTree &DT,
+                       const LSRUse &LU,
                        SmallPtrSet<const SCEV *, 16> *LoserRegs) {
   // Tally up the registers.
   if (const SCEV *ScaledReg = F.ScaledReg) {
@@ -932,7 +947,12 @@ void Cost::RateFormula(const Formula &F,
   // Determine how many (unfolded) adds we'll need inside the loop.
   size_t NumBaseParts = F.BaseRegs.size() + (F.UnfoldedOffset != 0);
   if (NumBaseParts > 1)
-    NumBaseAdds += NumBaseParts - 1;
+    // Do not count the base and a possible second register if the target
+    // allows to fold 2 registers.
+    NumBaseAdds += NumBaseParts - (1 + isLegal2RegAMUse(TTI, LU, F));
+
+  // Accumulate non-free scaling amounts.
+  ScaleCost += getScalingFactorCost(TTI, LU, F);
 
   // Tally up the non-zero immediates.
   for (SmallVectorImpl<int64_t>::const_iterator I = Offsets.begin(),
@@ -955,6 +975,7 @@ void Cost::Loose() {
   NumBaseAdds = ~0u;
   ImmCost = ~0u;
   SetupCost = ~0u;
+  ScaleCost = ~0u;
 }
 
 /// operator< - Choose the lower cost.
@@ -967,6 +988,8 @@ bool Cost::operator<(const Cost &Other) const {
     return NumIVMuls < Other.NumIVMuls;
   if (NumBaseAdds != Other.NumBaseAdds)
     return NumBaseAdds < Other.NumBaseAdds;
+  if (ScaleCost != Other.ScaleCost)
+    return ScaleCost < Other.ScaleCost;
   if (ImmCost != Other.ImmCost)
     return ImmCost < Other.ImmCost;
   if (SetupCost != Other.SetupCost)
@@ -983,6 +1006,8 @@ void Cost::print(raw_ostream &OS) const {
   if (NumBaseAdds != 0)
     OS << ", plus " << NumBaseAdds << " base add"
        << (NumBaseAdds == 1 ? "" : "s");
+  if (ScaleCost != 0)
+    OS << ", plus " << ScaleCost << " scale cost";
   if (ImmCost != 0)
     OS << ", plus " << ImmCost << " imm cost";
   if (SetupCost != 0)
@@ -1359,6 +1384,58 @@ static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset,
                     F.BaseOffset, F.HasBaseReg, F.Scale);
 }
 
+static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
+                             const Formula &F) {
+  // If F is used as an Addressing Mode, it may fold one Base plus one
+  // scaled register. If the scaled register is nil, do as if another
+  // element of the base regs is a 1-scaled register.
+  // This is possible if BaseRegs has at least 2 registers.
+
+  // If this is not an address calculation, this is not an addressing mode
+  // use.
+  if (LU.Kind !=  LSRUse::Address)
+    return false;
+
+  // F is already scaled.
+  if (F.Scale != 0)
+    return false;
+
+  // We need to keep one register for the base and one to scale.
+  if (F.BaseRegs.size() < 2)
+    return false;
+
+  return isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy,
+                    F.BaseGV, F.BaseOffset, F.HasBaseReg, 1);
+ }
+
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F) {
+  if (!F.Scale)
+    return 0;
+  assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
+                    LU.AccessTy, F) && "Illegal formula in use.");
+
+  switch (LU.Kind) {
+  case LSRUse::Address: {
+    int CurScaleCost = TTI.getScalingFactorCost(LU.AccessTy, F.BaseGV,
+                                                F.BaseOffset, F.HasBaseReg,
+                                                F.Scale);
+    assert(CurScaleCost >= 0 && "Legal addressing mode has an illegal cost!");
+    return CurScaleCost;
+  }
+  case LSRUse::ICmpZero:
+    // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg.
+    // Therefore, return 0 in case F.Scale == -1. 
+    return F.Scale != -1;
+
+  case LSRUse::Basic:
+  case LSRUse::Special:
+    return 0;
+  }
+
+  llvm_unreachable("Invalid LSRUse Kind!");
+}
+
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, Type *AccessTy,
                              GlobalValue *BaseGV, int64_t BaseOffset,
@@ -1895,15 +1972,13 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
   if (ICmpInst::isTrueWhenEqual(Pred)) {
     // Look for n+1, and grab n.
     if (AddOperator *BO = dyn_cast<AddOperator>(Sel->getOperand(1)))
-      if (isa<ConstantInt>(BO->getOperand(1)) &&
-          cast<ConstantInt>(BO->getOperand(1))->isOne() &&
-          SE.getSCEV(BO->getOperand(0)) == MaxRHS)
-        NewRHS = BO->getOperand(0);
+      if (ConstantInt *BO1 = dyn_cast<ConstantInt>(BO->getOperand(1)))
+         if (BO1->isOne() && SE.getSCEV(BO->getOperand(0)) == MaxRHS)
+           NewRHS = BO->getOperand(0);
     if (AddOperator *BO = dyn_cast<AddOperator>(Sel->getOperand(2)))
-      if (isa<ConstantInt>(BO->getOperand(1)) &&
-          cast<ConstantInt>(BO->getOperand(1))->isOne() &&
-          SE.getSCEV(BO->getOperand(0)) == MaxRHS)
-        NewRHS = BO->getOperand(0);
+      if (ConstantInt *BO1 = dyn_cast<ConstantInt>(BO->getOperand(1)))
+        if (BO1->isOne() && SE.getSCEV(BO->getOperand(0)) == MaxRHS)
+          NewRHS = BO->getOperand(0);
     if (!NewRHS)
       return Cond;
   } else if (SE.getSCEV(Sel->getOperand(1)) == MaxRHS)
@@ -2716,6 +2791,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
   // by LSR.
   const IVInc &Head = Chain.Incs[0];
   User::op_iterator IVOpEnd = Head.UserInst->op_end();
+  // findIVOperand returns IVOpEnd if it can no longer find a valid IV user.
   User::op_iterator IVOpIter = findIVOperand(Head.UserInst->op_begin(),
                                              IVOpEnd, L, SE);
   Value *IVSrc = 0;
@@ -3608,7 +3684,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {
                    abs64(NewF.BaseOffset)) &&
                   (C->getValue()->getValue() +
                    NewF.BaseOffset).countTrailingZeros() >=
-                   CountTrailingZeros_64(NewF.BaseOffset))
+                   countTrailingZeros<uint64_t>(NewF.BaseOffset))
                 goto skip_formula;
 
           // Ok, looks good.
@@ -3691,7 +3767,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() {
       // the corresponding bad register from the Regs set.
       Cost CostF;
       Regs.clear();
-      CostF.RateFormula(F, Regs, VisitedRegs, L, LU.Offsets, SE, DT,
+      CostF.RateFormula(TTI, F, Regs, VisitedRegs, L, LU.Offsets, SE, DT, LU,
                         &LoserRegs);
       if (CostF.isLoser()) {
         // During initial formula generation, undesirable formulae are generated
@@ -3727,7 +3803,8 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() {
 
         Cost CostBest;
         Regs.clear();
-        CostBest.RateFormula(Best, Regs, VisitedRegs, L, LU.Offsets, SE, DT);
+        CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, LU.Offsets, SE,
+                             DT, LU);
         if (CostF < CostBest)
           std::swap(F, Best);
         DEBUG(dbgs() << "  Filtering out formula "; F.print(dbgs());
@@ -4080,7 +4157,8 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution,
     // the current best, prune the search at that point.
     NewCost = CurCost;
     NewRegs = CurRegs;
-    NewCost.RateFormula(F, NewRegs, VisitedRegs, L, LU.Offsets, SE, DT);
+    NewCost.RateFormula(TTI, F, NewRegs, VisitedRegs, L, LU.Offsets, SE, DT,
+                        LU);
     if (NewCost < SolutionCost) {
       Workspace.push_back(&F);
       if (Workspace.size() != Uses.size()) {