From e1ed72f4594726ef7ed4afcfe14c82a7097b6c29 Mon Sep 17 00:00:00 2001 From: Silviu Baranga Date: Mon, 26 Oct 2015 11:18:31 +0000 Subject: [PATCH] [SCEV] Factor out common visiting patterns for SCEV rewriters. NFC. Summary: Add a SCEVRewriteVisitor class which contains the common visiting patterns used when rewriting SCEVs. SCEVParameterRewriter and SCEVApplyRewriter now inherit from SCEVRewriteVisitor (and are therefore much simpler). Reviewers: anemet, mzolotukhin, sanjoy Subscribers: rengolin, llvm-commits, sanjoy Differential Revision: http://reviews.llvm.org/D13242 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251283 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../Analysis/ScalarEvolutionExpressions.h | 137 ++++++------------ 1 file changed, 44 insertions(+), 93 deletions(-) diff --git a/include/llvm/Analysis/ScalarEvolutionExpressions.h b/include/llvm/Analysis/ScalarEvolutionExpressions.h index d5a3fc4e9da..3d9522773bc 100644 --- a/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -553,64 +553,56 @@ namespace llvm { T.visitAll(Root); } - typedef DenseMap ValueToValueMap; - - /// The SCEVParameterRewriter takes a scalar evolution expression and updates - /// the SCEVUnknown components following the Map (Value -> Value). - struct SCEVParameterRewriter - : public SCEVVisitor { + /// Recursively visits a SCEV expression and re-writes it. + template + class SCEVRewriteVisitor : public SCEVVisitor { + protected: + ScalarEvolution &SE; public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - ValueToValueMap &Map, - bool InterpretConsts = false) { - SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); - return Rewriter.visit(Scev); - } - - SCEVParameterRewriter(ScalarEvolution &S, ValueToValueMap &M, bool C) - : SE(S), Map(M), InterpretConsts(C) {} + SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); return SE.getTruncateExpr(Operand, Expr->getType()); } const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); return SE.getZeroExtendExpr(Operand, Expr->getType()); } const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); return SE.getSignExtendExpr(Operand, Expr->getType()); } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getAddExpr(Operands); } const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getMulExpr(Operands); } const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); + return SE.getUDivExpr(((SC*)this)->visit(Expr->getLHS()), + ((SC*)this)->visit(Expr->getRHS())); } const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getAddRecExpr(Operands, Expr->getLoop(), Expr->getNoWrapFlags()); } @@ -618,17 +610,42 @@ namespace llvm { const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getSMaxExpr(Operands); } const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { SmallVector Operands; for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); + Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); return SE.getUMaxExpr(Operands); } + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + return Expr; + } + + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return Expr; + } + }; + + typedef DenseMap ValueToValueMap; + + /// The SCEVParameterRewriter takes a scalar evolution expression and updates + /// the SCEVUnknown components following the Map (Value -> Value). + class SCEVParameterRewriter : public SCEVRewriteVisitor { + public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + ValueToValueMap &Map, + bool InterpretConsts = false) { + SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); + return Rewriter.visit(Scev); + } + + SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C) + : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {} + const SCEV *visitUnknown(const SCEVUnknown *Expr) { Value *V = Expr->getValue(); if (Map.count(V)) { @@ -640,12 +657,7 @@ namespace llvm { return Expr; } - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } - private: - ScalarEvolution &SE; ValueToValueMap ⤅ bool InterpretConsts; }; @@ -654,8 +666,7 @@ namespace llvm { /// The SCEVApplyRewriter takes a scalar evolution expression and applies /// the Map (Loop -> SCEV) to all AddRecExprs. - struct SCEVApplyRewriter - : public SCEVVisitor { + class SCEVApplyRewriter : public SCEVRewriteVisitor { public: static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, ScalarEvolution &SE) { @@ -663,45 +674,8 @@ namespace llvm { return Rewriter.visit(Scev); } - SCEVApplyRewriter(ScalarEvolution &S, LoopToScevMapT &M) - : SE(S), Map(M) {} - - const SCEV *visitConstant(const SCEVConstant *Constant) { - return Constant; - } - - const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - return SE.getTruncateExpr(Operand, Expr->getType()); - } - - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - return SE.getZeroExtendExpr(Operand, Expr->getType()); - } - - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); - return SE.getSignExtendExpr(Operand, Expr->getType()); - } - - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getAddExpr(Operands); - } - - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getMulExpr(Operands); - } - - const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - return SE.getUDivExpr(visit(Expr->getLHS()), visit(Expr->getRHS())); - } + SCEVApplyRewriter(ScalarEvolution &SE, LoopToScevMapT &M) + : SCEVRewriteVisitor(SE), Map(M) {} const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { SmallVector Operands; @@ -718,30 +692,7 @@ namespace llvm { return Rec->evaluateAtIteration(Map[L], SE); } - const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getSMaxExpr(Operands); - } - - const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { - SmallVector Operands; - for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) - Operands.push_back(visit(Expr->getOperand(i))); - return SE.getUMaxExpr(Operands); - } - - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - return Expr; - } - - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } - private: - ScalarEvolution &SE; LoopToScevMapT ⤅ }; -- 2.34.1