[IndVarSimplify] Make cost estimation in RewriteLoopExitValues smarter
authorIgor Laevsky <igmyrj@gmail.com>
Mon, 10 Aug 2015 18:23:58 +0000 (18:23 +0000)
committerIgor Laevsky <igmyrj@gmail.com>
Mon, 10 Aug 2015 18:23:58 +0000 (18:23 +0000)
Differential Revision: http://reviews.llvm.org/D11687

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@244474 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolutionExpander.h
lib/Analysis/ScalarEvolutionExpander.cpp
lib/Transforms/Scalar/IndVarSimplify.cpp

index 8ec2078258d111be7a78ee6bbfe48f13cbd7fd24..e5647860db5c136f5d043f36f7c60931937399a4 100644 (file)
@@ -117,9 +117,13 @@ namespace llvm {
 
     /// \brief Return true for expressions that may incur non-trivial cost to
     /// evaluate at runtime.
-    bool isHighCostExpansion(const SCEV *Expr, Loop *L) {
+    /// 'At' is an optional parameter which specifies point in code where user
+    /// is going to expand this expression. Sometimes this knowledge can lead to
+    /// a more accurate cost estimation.
+    bool isHighCostExpansion(const SCEV *Expr, Loop *L,
+                             const Instruction *At = nullptr) {
       SmallPtrSet<const SCEV *, 8> Processed;
-      return isHighCostExpansionHelper(Expr, L, Processed);
+      return isHighCostExpansionHelper(Expr, L, At, Processed);
     }
 
     /// \brief This method returns the canonical induction variable of the
@@ -193,11 +197,20 @@ namespace llvm {
 
     void setChainedPhi(PHINode *PN) { ChainedPhis.insert(PN); }
 
+    /// \brief Try to find LLVM IR value for 'S' available at the point 'At'.
+    // 'L' is a hint which tells in which loop to look for the suitable value.
+    // On success return value which is equivalent to the expanded 'S' at point
+    // 'At'. Return nullptr if value was not found.
+    // Note that this function does not perform exhaustive search. I.e if it
+    // didn't find any value it does not mean that there is no such value.
+    Value *findExistingExpansion(const SCEV *S, const Instruction *At, Loop *L);
+
   private:
     LLVMContext &getContext() const { return SE.getContext(); }
 
     /// \brief Recursive helper function for isHighCostExpansion.
     bool isHighCostExpansionHelper(const SCEV *S, Loop *L,
+                                   const Instruction *At,
                                    SmallPtrSetImpl<const SCEV *> &Processed);
 
     /// \brief Insert the specified binary operator, doing a small amount
index fe2a2a5005e8287a89a1287508fe4ef0d190ec92..42069d294038c1c50530ea6f8332fdaaf2737af4 100644 (file)
@@ -1812,8 +1812,46 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
   return NumElim;
 }
 
+Value *SCEVExpander::findExistingExpansion(const SCEV *S,
+                                           const Instruction *At, Loop *L) {
+  using namespace llvm::PatternMatch;
+
+  SmallVector<BasicBlock *, 4> Latches;
+  L->getLoopLatches(Latches);
+
+  // Look for suitable value in simple conditions at the loop latches.
+  for (BasicBlock *BB : Latches) {
+    ICmpInst::Predicate Pred;
+    Instruction *LHS, *RHS;
+    BasicBlock *TrueBB, *FalseBB;
+
+    if (!match(BB->getTerminator(),
+               m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)),
+                    TrueBB, FalseBB)))
+      continue;
+
+    if (SE.getSCEV(LHS) == S && SE.DT->dominates(LHS, At))
+      return LHS;
+
+    if (SE.getSCEV(RHS) == S && SE.DT->dominates(RHS, At))
+      return RHS;
+  }
+
+  // There is potential to make this significantly smarter, but this simple
+  // heuristic already gets some interesting cases.
+
+  // Can not find suitable value.
+  return nullptr;
+}
+
 bool SCEVExpander::isHighCostExpansionHelper(
-    const SCEV *S, Loop *L, SmallPtrSetImpl<const SCEV *> &Processed) {
+    const SCEV *S, Loop *L, const Instruction *At,
+    SmallPtrSetImpl<const SCEV *> &Processed) {
+
+  // If we can find an existing value for this scev avaliable at the point "At"
+  // then consider the expression cheap.
+  if (At && findExistingExpansion(S, At, L) != nullptr)
+    return false;
 
   // Zero/One operand expressions
   switch (S->getSCEVType()) {
@@ -1821,14 +1859,14 @@ bool SCEVExpander::isHighCostExpansionHelper(
   case scConstant:
     return false;
   case scTruncate:
-    return isHighCostExpansionHelper(cast<SCEVTruncateExpr>(S)->getOperand(), L,
-                                     Processed);
+    return isHighCostExpansionHelper(cast<SCEVTruncateExpr>(S)->getOperand(),
+                                     L, At, Processed);
   case scZeroExtend:
     return isHighCostExpansionHelper(cast<SCEVZeroExtendExpr>(S)->getOperand(),
-                                     L, Processed);
+                                     L, At, Processed);
   case scSignExtend:
     return isHighCostExpansionHelper(cast<SCEVSignExtendExpr>(S)->getOperand(),
-                                     L, Processed);
+                                     L, At, Processed);
   }
 
   if (!Processed.insert(S).second)
@@ -1884,7 +1922,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
   if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
          I != E; ++I) {
-      if (isHighCostExpansionHelper(*I, L, Processed))
+      if (isHighCostExpansionHelper(*I, L, At, Processed))
         return true;
     }
   }
index 2a954d9961f22dedf4219d7c0e162db7a5df2967..9c2e11be3f9e124586414a196fafc6af280257a0 100644 (file)
@@ -138,8 +138,7 @@ namespace {
     void SinkUnusedInvariants(Loop *L);
 
     Value *ExpandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L,
-                              Instruction *InsertPt, Type *Ty,
-                              bool &IsHighCostExpansion);
+                              Instruction *InsertPt, Type *Ty);
   };
 }
 
@@ -503,47 +502,13 @@ struct RewritePhi {
 
 Value *IndVarSimplify::ExpandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S,
                                           Loop *L, Instruction *InsertPt,
-                                          Type *ResultTy,
-                                          bool &IsHighCostExpansion) {
-  using namespace llvm::PatternMatch;
-
-  if (!Rewriter.isHighCostExpansion(S, L)) {
-    IsHighCostExpansion = false;
-    return Rewriter.expandCodeFor(S, ResultTy, InsertPt);
-  }
-
+                                          Type *ResultTy) {
   // Before expanding S into an expensive LLVM expression, see if we can use an
-  // already existing value as the expansion for S.  There is potential to make
-  // this significantly smarter, but this simple heuristic already gets some
-  // interesting cases.
-
-  SmallVector<BasicBlock *, 4> Latches;
-  L->getLoopLatches(Latches);
-
-  for (BasicBlock *BB : Latches) {
-    ICmpInst::Predicate Pred;
-    Instruction *LHS, *RHS;
-    BasicBlock *TrueBB, *FalseBB;
-
-    if (!match(BB->getTerminator(),
-               m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)),
-                    TrueBB, FalseBB)))
-      continue;
-
-    if (SE->getSCEV(LHS) == S && DT->dominates(LHS, InsertPt)) {
-      IsHighCostExpansion = false;
-      return LHS;
-    }
-
-    if (SE->getSCEV(RHS) == S && DT->dominates(RHS, InsertPt)) {
-      IsHighCostExpansion = false;
-      return RHS;
-    }
-  }
+  // already existing value as the expansion for S.
+  if (Value *RetValue = Rewriter.findExistingExpansion(S, InsertPt, L))
+    return RetValue;
 
   // We didn't find anything, fall back to using SCEVExpander.
-  assert(Rewriter.isHighCostExpansion(S, L) && "this should not have changed!");
-  IsHighCostExpansion = true;
   return Rewriter.expandCodeFor(S, ResultTy, InsertPt);
 }
 
@@ -679,9 +644,9 @@ void IndVarSimplify::RewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) {
             continue;
         }
 
-        bool HighCost = false;
-        Value *ExitVal = ExpandSCEVIfNeeded(Rewriter, ExitValue, L, Inst,
-                                            PN->getType(), HighCost);
+        bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst);
+        Value *ExitVal =
+            ExpandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, PN->getType());
 
         DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << '\n'
                      << "  LoopVal = " << *Inst << "\n");