[SCEV] Factor out common visiting patterns for SCEV rewriters. NFC.
authorSilviu Baranga <silviu.baranga@arm.com>
Mon, 26 Oct 2015 11:18:31 +0000 (11:18 +0000)
committerSilviu Baranga <silviu.baranga@arm.com>
Mon, 26 Oct 2015 11:18:31 +0000 (11:18 +0000)
Summary:
Add a SCEVRewriteVisitor class which contains the common
visiting patterns used when rewriting SCEVs.

SCEVParameterRewriter and SCEVApplyRewriter now inherit
from SCEVRewriteVisitor (and are therefore much simpler).

Reviewers: anemet, mzolotukhin, sanjoy

Subscribers: rengolin, llvm-commits, sanjoy

Differential Revision: http://reviews.llvm.org/D13242

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251283 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolutionExpressions.h

index d5a3fc4e9dacb62d953dec5dba9946d250eb90b9..3d9522773bcf644bb5ebad5596168ca8bbc6498f 100644 (file)
@@ -553,64 +553,56 @@ namespace llvm {
     T.visitAll(Root);
   }
 
-  typedef DenseMap<const Value*, Value*> ValueToValueMap;
-
-  /// The SCEVParameterRewriter takes a scalar evolution expression and updates
-  /// the SCEVUnknown components following the Map (Value -> Value).
-  struct SCEVParameterRewriter
-    : public SCEVVisitor<SCEVParameterRewriter, const SCEV*> {
+  /// Recursively visits a SCEV expression and re-writes it.
+  template<typename SC>
+  class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
+  protected:
+    ScalarEvolution &SE;
   public:
-    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
-                               ValueToValueMap &Map,
-                               bool InterpretConsts = false) {
-      SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
-      return Rewriter.visit(Scev);
-    }
-
-    SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M, bool C)
-      : SE(S), Map(M), InterpretConsts(C) {}
+    SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
 
     const SCEV *visitConstant(const SCEVConstant *Constant) {
       return Constant;
     }
 
     const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
-      const SCEV *Operand = visit(Expr->getOperand());
+      const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
       return SE.getTruncateExpr(Operand, Expr->getType());
     }
 
     const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
-      const SCEV *Operand = visit(Expr->getOperand());
+      const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
       return SE.getZeroExtendExpr(Operand, Expr->getType());
     }
 
     const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
-      const SCEV *Operand = visit(Expr->getOperand());
+      const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
       return SE.getSignExtendExpr(Operand, Expr->getType());
     }
 
     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
       SmallVector<const SCEV *, 2> Operands;
       for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
+        Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
       return SE.getAddExpr(Operands);
     }
 
     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
       SmallVector<const SCEV *, 2> Operands;
       for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
+        Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
       return SE.getMulExpr(Operands);
     }
 
     const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
-      return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS()));
+      return SE.getUDivExpr(((SC*)this)->visit(Expr->getLHS()),
+                            ((SC*)this)->visit(Expr->getRHS()));
     }
 
     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
       SmallVector<const SCEV *, 2> Operands;
       for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
+        Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
       return SE.getAddRecExpr(Operands, Expr->getLoop(),
                               Expr->getNoWrapFlags());
     }
@@ -618,17 +610,42 @@ namespace llvm {
     const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
       SmallVector<const SCEV *, 2> Operands;
       for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
+        Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
       return SE.getSMaxExpr(Operands);
     }
 
     const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
       SmallVector<const SCEV *, 2> Operands;
       for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
+        Operands.push_back(((SC*)this)->visit(Expr->getOperand(i)));
       return SE.getUMaxExpr(Operands);
     }
 
+    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+      return Expr;
+    }
+
+    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+      return Expr;
+    }
+  };
+
+  typedef DenseMap<const Value*, Value*> ValueToValueMap;
+
+  /// The SCEVParameterRewriter takes a scalar evolution expression and updates
+  /// the SCEVUnknown components following the Map (Value -> Value).
+  class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
+  public:
+    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                               ValueToValueMap &Map,
+                               bool InterpretConsts = false) {
+      SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
+      return Rewriter.visit(Scev);
+    }
+
+    SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C)
+      : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {}
+
     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
       Value *V = Expr->getValue();
       if (Map.count(V)) {
@@ -640,12 +657,7 @@ namespace llvm {
       return Expr;
     }
 
-    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
-      return Expr;
-    }
-
   private:
-    ScalarEvolution &SE;
     ValueToValueMap &Map;
     bool InterpretConsts;
   };
@@ -654,8 +666,7 @@ namespace llvm {
 
   /// The SCEVApplyRewriter takes a scalar evolution expression and applies
   /// the Map (Loop -> SCEV) to all AddRecExprs.
-  struct SCEVApplyRewriter
-    : public SCEVVisitor<SCEVApplyRewriter, const SCEV*> {
+  class SCEVApplyRewriter : public SCEVRewriteVisitor<SCEVApplyRewriter> {
   public:
     static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
                                ScalarEvolution &SE) {
@@ -663,45 +674,8 @@ namespace llvm {
       return Rewriter.visit(Scev);
     }
 
-    SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M)
-      : SE(S), Map(M) {}
-
-    const SCEV *visitConstant(const SCEVConstant *Constant) {
-      return Constant;
-    }
-
-    const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
-      const SCEV *Operand = visit(Expr->getOperand());
-      return SE.getTruncateExpr(Operand, Expr->getType());
-    }
-
-    const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
-      const SCEV *Operand = visit(Expr->getOperand());
-      return SE.getZeroExtendExpr(Operand, Expr->getType());
-    }
-
-    const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
-      const SCEV *Operand = visit(Expr->getOperand());
-      return SE.getSignExtendExpr(Operand, Expr->getType());
-    }
-
-    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
-      SmallVector<const SCEV *, 2> Operands;
-      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
-      return SE.getAddExpr(Operands);
-    }
-
-    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
-      SmallVector<const SCEV *, 2> Operands;
-      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
-      return SE.getMulExpr(Operands);
-    }
-
-    const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
-      return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS()));
-    }
+    SCEVApplyRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
+      : SCEVRewriteVisitor(SE), Map(M) {}
 
     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
       SmallVector<const SCEV *, 2> Operands;
@@ -718,30 +692,7 @@ namespace llvm {
       return Rec->evaluateAtIteration(Map[L], SE);
     }
 
-    const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
-      SmallVector<const SCEV *, 2> Operands;
-      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
-      return SE.getSMaxExpr(Operands);
-    }
-
-    const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
-      SmallVector<const SCEV *, 2> Operands;
-      for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
-        Operands.push_back(visit(Expr->getOperand(i)));
-      return SE.getUMaxExpr(Operands);
-    }
-
-    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
-      return Expr;
-    }
-
-    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
-      return Expr;
-    }
-
   private:
-    ScalarEvolution &SE;
     LoopToScevMapT &Map;
   };