X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=2b8277864ebfd93f310c8eb965698db5fac086a3;hb=baeb911d60401818dc9fe0db6182cd048e4fdd03;hp=ee077d56b628ae3da838f9f9a75465a4cc86ae44;hpb=4dfad295475387211012f17bf640a9f8330872be;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index ee077d56b62..2b8277864eb 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -14,7 +14,7 @@ // There are several aspects to this library. First is the representation of // scalar expressions, which are represented as subclasses of the SCEV class. // These classes are used to represent certain types of subexpressions that we -// can handle. These classes are reference counted, managed by the SCEVHandle +// can handle. These classes are reference counted, managed by the const SCEV* // class. We only create one SCEV of a particular shape, so pointer-comparisons // for equality are legal. // @@ -68,6 +68,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Assembly/Writer.h" #include "llvm/Target/TargetData.h" #include "llvm/Support/CommandLine.h" @@ -75,7 +76,6 @@ #include "llvm/Support/ConstantRange.h" #include "llvm/Support/GetElementPtrTypeIterator.h" #include "llvm/Support/InstIterator.h" -#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/ADT/Statistic.h" @@ -95,7 +95,8 @@ STATISTIC(NumBruteForceTripCountsComputed, static cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " - "symbolically execute a constant derived loop"), + "symbolically execute a constant " + "derived loop"), cl::init(100)); static RegisterPass @@ -132,8 +133,18 @@ bool SCEV::isOne() const { return false; } -SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} -SCEVCouldNotCompute::~SCEVCouldNotCompute() {} +bool SCEV::isAllOnesValue() const { + if (const SCEVConstant *SC = dyn_cast(this)) + return SC->getValue()->isAllOnesValue(); + return false; +} + +SCEVCouldNotCompute::SCEVCouldNotCompute() : + SCEV(scCouldNotCompute) {} + +void SCEVCouldNotCompute::Profile(FoldingSetNodeID &ID) const { + assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); +} bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const { assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); @@ -150,10 +161,11 @@ bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const { return false; } -SCEVHandle SCEVCouldNotCompute:: -replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, - const SCEVHandle &Conc, - ScalarEvolution &SE) const { +const SCEV * +SCEVCouldNotCompute::replaceSymbolicValuesWithConcrete( + const SCEV *Sym, + const SCEV *Conc, + ScalarEvolution &SE) const { return this; } @@ -165,25 +177,30 @@ bool SCEVCouldNotCompute::classof(const SCEV *S) { return S->getSCEVType() == scCouldNotCompute; } +const SCEV* ScalarEvolution::getConstant(ConstantInt *V) { + FoldingSetNodeID ID; + ID.AddInteger(scConstant); + ID.AddPointer(V); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVConstant(V); + UniqueSCEVs.InsertNode(S, IP); + return S; +} -// SCEVConstants - Only allow the creation of one SCEVConstant for any -// particular value. Don't use a SCEVHandle here, or else the object will -// never be deleted! -static ManagedStatic > SCEVConstants; - - -SCEVConstant::~SCEVConstant() { - SCEVConstants->erase(V); +const SCEV* ScalarEvolution::getConstant(const APInt& Val) { + return getConstant(ConstantInt::get(Val)); } -SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) { - SCEVConstant *&R = (*SCEVConstants)[V]; - if (R == 0) R = new SCEVConstant(V); - return R; +const SCEV* +ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) { + return getConstant(ConstantInt::get(cast(Ty), V, isSigned)); } -SCEVHandle ScalarEvolution::getConstant(const APInt& Val) { - return getConstant(ConstantInt::get(Val)); +void SCEVConstant::Profile(FoldingSetNodeID &ID) const { + ID.AddInteger(scConstant); + ID.AddPointer(V); } const Type *SCEVConstant::getType() const { return V->getType(); } @@ -193,89 +210,52 @@ void SCEVConstant::print(raw_ostream &OS) const { } SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy, - const SCEVHandle &op, const Type *ty) + const SCEV* op, const Type *ty) : SCEV(SCEVTy), Op(op), Ty(ty) {} -SCEVCastExpr::~SCEVCastExpr() {} +void SCEVCastExpr::Profile(FoldingSetNodeID &ID) const { + ID.AddInteger(getSCEVType()); + ID.AddPointer(Op); + ID.AddPointer(Ty); +} bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { return Op->dominates(BB, DT); } -// SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any -// particular input. Don't use a SCEVHandle here, or else the object will -// never be deleted! -static ManagedStatic, - SCEVTruncateExpr*> > SCEVTruncates; - -SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty) +SCEVTruncateExpr::SCEVTruncateExpr(const SCEV* op, const Type *ty) : SCEVCastExpr(scTruncate, op, ty) { assert((Op->getType()->isInteger() || isa(Op->getType())) && (Ty->isInteger() || isa(Ty)) && "Cannot truncate non-integer value!"); } -SCEVTruncateExpr::~SCEVTruncateExpr() { - SCEVTruncates->erase(std::make_pair(Op, Ty)); -} - void SCEVTruncateExpr::print(raw_ostream &OS) const { OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; } -// SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any -// particular input. Don't use a SCEVHandle here, or else the object will never -// be deleted! -static ManagedStatic, - SCEVZeroExtendExpr*> > SCEVZeroExtends; - -SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty) +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEV* op, const Type *ty) : SCEVCastExpr(scZeroExtend, op, ty) { assert((Op->getType()->isInteger() || isa(Op->getType())) && (Ty->isInteger() || isa(Ty)) && "Cannot zero extend non-integer value!"); } -SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { - SCEVZeroExtends->erase(std::make_pair(Op, Ty)); -} - void SCEVZeroExtendExpr::print(raw_ostream &OS) const { OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; } -// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any -// particular input. Don't use a SCEVHandle here, or else the object will never -// be deleted! -static ManagedStatic, - SCEVSignExtendExpr*> > SCEVSignExtends; - -SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEV* op, const Type *ty) : SCEVCastExpr(scSignExtend, op, ty) { assert((Op->getType()->isInteger() || isa(Op->getType())) && (Ty->isInteger() || isa(Ty)) && "Cannot sign extend non-integer value!"); } -SCEVSignExtendExpr::~SCEVSignExtendExpr() { - SCEVSignExtends->erase(std::make_pair(Op, Ty)); -} - void SCEVSignExtendExpr::print(raw_ostream &OS) const { OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; } -// SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any -// particular input. Don't use a SCEVHandle here, or else the object will never -// be deleted! -static ManagedStatic >, - SCEVCommutativeExpr*> > SCEVCommExprs; - -SCEVCommutativeExpr::~SCEVCommutativeExpr() { - std::vector SCEVOps(Operands.begin(), Operands.end()); - SCEVCommExprs->erase(std::make_pair(getSCEVType(), SCEVOps)); -} - void SCEVCommutativeExpr::print(raw_ostream &OS) const { assert(Operands.size() > 1 && "This plus expr shouldn't exist!"); const char *OpStr = getOperationStr(); @@ -285,15 +265,16 @@ void SCEVCommutativeExpr::print(raw_ostream &OS) const { OS << ")"; } -SCEVHandle SCEVCommutativeExpr:: -replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, - const SCEVHandle &Conc, - ScalarEvolution &SE) const { +const SCEV * +SCEVCommutativeExpr::replaceSymbolicValuesWithConcrete( + const SCEV *Sym, + const SCEV *Conc, + ScalarEvolution &SE) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - SCEVHandle H = + const SCEV* H = getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); if (H != getOperand(i)) { - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(getNumOperands()); for (unsigned j = 0; j != i; ++j) NewOps.push_back(getOperand(j)); @@ -317,6 +298,13 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, return this; } +void SCEVNAryExpr::Profile(FoldingSetNodeID &ID) const { + ID.AddInteger(getSCEVType()); + ID.AddInteger(Operands.size()); + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + ID.AddPointer(Operands[i]); +} + bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { if (!getOperand(i)->dominates(BB, DT)) @@ -325,15 +313,10 @@ bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { return true; } - -// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular -// input. Don't use a SCEVHandle here, or else the object will never be -// deleted! -static ManagedStatic, - SCEVUDivExpr*> > SCEVUDivs; - -SCEVUDivExpr::~SCEVUDivExpr() { - SCEVUDivs->erase(std::make_pair(LHS, RHS)); +void SCEVUDivExpr::Profile(FoldingSetNodeID &ID) const { + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); } bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { @@ -353,27 +336,23 @@ const Type *SCEVUDivExpr::getType() const { return RHS->getType(); } -// SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any -// particular input. Don't use a SCEVHandle here, or else the object will never -// be deleted! -static ManagedStatic >, - SCEVAddRecExpr*> > SCEVAddRecExprs; - -SCEVAddRecExpr::~SCEVAddRecExpr() { - std::vector SCEVOps(Operands.begin(), Operands.end()); - SCEVAddRecExprs->erase(std::make_pair(L, SCEVOps)); +void SCEVAddRecExpr::Profile(FoldingSetNodeID &ID) const { + ID.AddInteger(scAddRecExpr); + ID.AddInteger(Operands.size()); + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + ID.AddPointer(Operands[i]); + ID.AddPointer(L); } -SCEVHandle SCEVAddRecExpr:: -replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, - const SCEVHandle &Conc, - ScalarEvolution &SE) const { +const SCEV * +SCEVAddRecExpr::replaceSymbolicValuesWithConcrete(const SCEV *Sym, + const SCEV *Conc, + ScalarEvolution &SE) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - SCEVHandle H = + const SCEV* H = getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); if (H != getOperand(i)) { - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(getNumOperands()); for (unsigned j = 0; j != i; ++j) NewOps.push_back(getOperand(j)); @@ -390,12 +369,22 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { - // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't - // contain L and if the start is invariant. // Add recurrences are never invariant in the function-body (null loop). - return QueryLoop && - !QueryLoop->contains(L->getHeader()) && - getOperand(0)->isLoopInvariant(QueryLoop); + if (!QueryLoop) + return false; + + // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L. + if (QueryLoop->contains(L->getHeader())) + return false; + + // This recurrence is variant w.r.t. QueryLoop if any of its operands + // are variant. + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) + if (!getOperand(i)->isLoopInvariant(QueryLoop)) + return false; + + // Otherwise it's loop-invariant. + return true; } @@ -406,12 +395,10 @@ void SCEVAddRecExpr::print(raw_ostream &OS) const { OS << "}<" << L->getHeader()->getName() + ">"; } -// SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular -// value. Don't use a SCEVHandle here, or else the object will never be -// deleted! -static ManagedStatic > SCEVUnknowns; - -SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); } +void SCEVUnknown::Profile(FoldingSetNodeID &ID) const { + ID.AddInteger(scUnknown); + ID.AddPointer(V); +} bool SCEVUnknown::isLoopInvariant(const Loop *L) const { // All non-instruction values are loop invariant. All instructions are loop @@ -567,7 +554,7 @@ namespace { /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. /// -static void GroupByComplexity(SmallVectorImpl &Ops, +static void GroupByComplexity(SmallVectorImpl &Ops, LoopInfo *LI) { if (Ops.size() < 2) return; // Noop if (Ops.size() == 2) { @@ -610,7 +597,7 @@ static void GroupByComplexity(SmallVectorImpl &Ops, /// BinomialCoefficient - Compute BC(It, K). The result has width W. /// Assume, K > 0. -static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, +static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K, ScalarEvolution &SE, const Type* ResultTy) { // Handle the simplest case efficiently. @@ -627,7 +614,7 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, // safe in modular arithmetic. // // However, this code doesn't use exactly that formula; the formula it uses - // is something like the following, where T is the number of factors of 2 in + // is something like the following, where T is the number of factors of 2 in // K! (i.e. trailing zeros in the binary representation of K!), and ^ is // exponentiation: // @@ -639,7 +626,7 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, // arithmetic. To do exact division in modular arithmetic, all we have // to do is multiply by the inverse. Therefore, this step can be done at // width W. - // + // // The next issue is how to safely do the division by 2^T. The way this // is done is by doing the multiplication step at a width of at least W + T // bits. This way, the bottom W+T bits of the product are accurate. Then, @@ -703,15 +690,15 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, // Calculate the product, at width T+W const IntegerType *CalculationTy = IntegerType::get(CalculationBits); - SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + const SCEV* Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { - SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + const SCEV* S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); Dividend = SE.getMulExpr(Dividend, SE.getTruncateOrZeroExtend(S, CalculationTy)); } // Divide by 2^T - SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + const SCEV* DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); // Truncate the result, and divide by K! / 2^T. @@ -728,14 +715,14 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, /// /// where BC(It, k) stands for binomial coefficient. /// -SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, +const SCEV* SCEVAddRecExpr::evaluateAtIteration(const SCEV* It, ScalarEvolution &SE) const { - SCEVHandle Result = getStart(); + const SCEV* Result = getStart(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient. - SCEVHandle Coeff = BinomialCoefficient(It, i, SE, getType()); + const SCEV* Coeff = BinomialCoefficient(It, i, SE, getType()); if (isa(Coeff)) return Coeff; @@ -748,7 +735,7 @@ SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, // SCEV Expression folder implementations //===----------------------------------------------------------------------===// -SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, +const SCEV* ScalarEvolution::getTruncateExpr(const SCEV* Op, const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && "This is not a truncating conversion!"); @@ -757,8 +744,8 @@ SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, Ty = getEffectiveSCEVType(Ty); if (const SCEVConstant *SC = dyn_cast(Op)) - return getUnknown( - ConstantExpr::getTrunc(SC->getValue(), Ty)); + return getConstant( + cast(ConstantExpr::getTrunc(SC->getValue(), Ty))); // trunc(trunc(x)) --> trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast(Op)) @@ -772,21 +759,27 @@ SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getTruncateOrZeroExtend(SZ->getOperand(), Ty); - // If the input value is a chrec scev made out of constants, truncate - // all of the constants. + // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); return getAddRecExpr(Operands, AddRec->getLoop()); } - SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scTruncate); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVTruncateExpr(Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; } -SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, +const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op, const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); @@ -798,7 +791,7 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, const Type *IntTy = getEffectiveSCEVType(Ty); Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy); if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty); - return getUnknown(C); + return getConstant(cast(C)); } // zext(zext(x)) --> zext(x) @@ -819,28 +812,28 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + const SCEV* MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. - SCEVHandle Start = AR->getStart(); - SCEVHandle Step = AR->getStepRecurrence(*this); + const SCEV* Start = AR->getStart(); + const SCEV* Step = AR->getStepRecurrence(*this); // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. - SCEVHandle CastedMaxBECount = + const SCEV* CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType()); - SCEVHandle RecastedMaxBECount = + const SCEV* RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { const Type *WideTy = IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. - SCEVHandle ZMul = + const SCEV* ZMul = getMulExpr(CastedMaxBECount, getTruncateOrZeroExtend(Step, Start->getType())); - SCEVHandle Add = getAddExpr(Start, ZMul); - SCEVHandle OperandExtendedAdd = + const SCEV* Add = getAddExpr(Start, ZMul); + const SCEV* OperandExtendedAdd = getAddExpr(getZeroExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getZeroExtendExpr(Step, WideTy))); @@ -852,7 +845,7 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, // Similar to above, only this time treat the step value as signed. // This covers loops that count down. - SCEVHandle SMul = + const SCEV* SMul = getMulExpr(CastedMaxBECount, getTruncateOrSignExtend(Step, Start->getType())); Add = getAddExpr(Start, SMul); @@ -869,12 +862,19 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, } } - SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scZeroExtend); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVZeroExtendExpr(Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; } -SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, +const SCEV* ScalarEvolution::getSignExtendExpr(const SCEV* Op, const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); @@ -886,7 +886,7 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type *IntTy = getEffectiveSCEVType(Ty); Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy); if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty); - return getUnknown(C); + return getConstant(cast(C)); } // sext(sext(x)) --> sext(x) @@ -907,28 +907,28 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + const SCEV* MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. - SCEVHandle Start = AR->getStart(); - SCEVHandle Step = AR->getStepRecurrence(*this); + const SCEV* Start = AR->getStart(); + const SCEV* Step = AR->getStepRecurrence(*this); // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. - SCEVHandle CastedMaxBECount = + const SCEV* CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType()); - SCEVHandle RecastedMaxBECount = + const SCEV* RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { const Type *WideTy = IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); // Check whether Start+Step*MaxBECount has no signed overflow. - SCEVHandle SMul = + const SCEV* SMul = getMulExpr(CastedMaxBECount, getTruncateOrSignExtend(Step, Start->getType())); - SCEVHandle Add = getAddExpr(Start, SMul); - SCEVHandle OperandExtendedAdd = + const SCEV* Add = getAddExpr(Start, SMul); + const SCEV* OperandExtendedAdd = getAddExpr(getSignExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getSignExtendExpr(Step, WideTy))); @@ -941,15 +941,22 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, } } - SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scSignExtend); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVSignExtendExpr(Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; } /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. /// -SCEVHandle ScalarEvolution::getAnyExtendExpr(const SCEVHandle &Op, +const SCEV* ScalarEvolution::getAnyExtendExpr(const SCEV* Op, const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); @@ -964,19 +971,19 @@ SCEVHandle ScalarEvolution::getAnyExtendExpr(const SCEVHandle &Op, // Peel off a truncate cast. if (const SCEVTruncateExpr *T = dyn_cast(Op)) { - SCEVHandle NewOp = T->getOperand(); + const SCEV* NewOp = T->getOperand(); if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) return getAnyExtendExpr(NewOp, Ty); return getTruncateOrNoop(NewOp, Ty); } // Next try a zext cast. If the cast is folded, use it. - SCEVHandle ZExt = getZeroExtendExpr(Op, Ty); + const SCEV* ZExt = getZeroExtendExpr(Op, Ty); if (!isa(ZExt)) return ZExt; // Next try a sext cast. If the cast is folded, use it. - SCEVHandle SExt = getSignExtendExpr(Op, Ty); + const SCEV* SExt = getSignExtendExpr(Op, Ty); if (!isa(SExt)) return SExt; @@ -988,9 +995,103 @@ SCEVHandle ScalarEvolution::getAnyExtendExpr(const SCEVHandle &Op, return ZExt; } +/// CollectAddOperandsWithScales - Process the given Ops list, which is +/// a list of operands to be added under the given scale, update the given +/// map. This is a helper function for getAddRecExpr. As an example of +/// what it does, given a sequence of operands that would form an add +/// expression like this: +/// +/// m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r) +/// +/// where A and B are constants, update the map with these values: +/// +/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0) +/// +/// and add 13 + A*B*29 to AccumulatedConstant. +/// This will allow getAddRecExpr to produce this: +/// +/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B) +/// +/// This form often exposes folding opportunities that are hidden in +/// the original operand list. +/// +/// Return true iff it appears that any interesting folding opportunities +/// may be exposed. This helps getAddRecExpr short-circuit extra work in +/// the common case where no interesting opportunities are present, and +/// is also used as a check to avoid infinite recursion. +/// +static bool +CollectAddOperandsWithScales(DenseMap &M, + SmallVector &NewOps, + APInt &AccumulatedConstant, + const SmallVectorImpl &Ops, + const APInt &Scale, + ScalarEvolution &SE) { + bool Interesting = false; + + // Iterate over the add operands. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + const SCEVMulExpr *Mul = dyn_cast(Ops[i]); + if (Mul && isa(Mul->getOperand(0))) { + APInt NewScale = + Scale * cast(Mul->getOperand(0))->getValue()->getValue(); + if (Mul->getNumOperands() == 2 && isa(Mul->getOperand(1))) { + // A multiplication of a constant with another add; recurse. + Interesting |= + CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, + cast(Mul->getOperand(1)) + ->getOperands(), + NewScale, SE); + } else { + // A multiplication of a constant with some other value. Update + // the map. + SmallVector MulOps(Mul->op_begin()+1, Mul->op_end()); + const SCEV* Key = SE.getMulExpr(MulOps); + std::pair::iterator, bool> Pair = + M.insert(std::make_pair(Key, NewScale)); + if (Pair.second) { + NewOps.push_back(Pair.first->first); + } else { + Pair.first->second += NewScale; + // The map already had an entry for this value, which may indicate + // a folding opportunity. + Interesting = true; + } + } + } else if (const SCEVConstant *C = dyn_cast(Ops[i])) { + // Pull a buried constant out to the outside. + if (Scale != 1 || AccumulatedConstant != 0 || C->isZero()) + Interesting = true; + AccumulatedConstant += Scale * C->getValue()->getValue(); + } else { + // An ordinary operand. Update the map. + std::pair::iterator, bool> Pair = + M.insert(std::make_pair(Ops[i], Scale)); + if (Pair.second) { + NewOps.push_back(Pair.first->first); + } else { + Pair.first->second += Scale; + // The map already had an entry for this value, which may indicate + // a folding opportunity. + Interesting = true; + } + } + } + + return Interesting; +} + +namespace { + struct APIntCompare { + bool operator()(const APInt &LHS, const APInt &RHS) const { + return LHS.ult(RHS); + } + }; +} + /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. -SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { +const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1012,8 +1113,8 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // We found two constants, fold them together! Ops[0] = getConstant(LHSC->getValue()->getValue() + RHSC->getValue()->getValue()); + if (Ops.size() == 2) return Ops[0]; Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; LHSC = cast(Ops[0]); } @@ -1034,8 +1135,8 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 // Found a match, merge the two values into a multiply, and add any // remaining values to the result. - SCEVHandle Two = getIntegerSCEV(2, Ty); - SCEVHandle Mul = getMulExpr(Ops[i], Two); + const SCEV* Two = getIntegerSCEV(2, Ty); + const SCEV* Mul = getMulExpr(Ops[i], Two); if (Ops.size() == 2) return Mul; Ops.erase(Ops.begin()+i, Ops.begin()+i+2); @@ -1051,7 +1152,7 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { const SCEVTruncateExpr *Trunc = cast(Ops[Idx]); const Type *DstType = Trunc->getType(); const Type *SrcType = Trunc->getOperand()->getType(); - SmallVector LargeOps; + SmallVector LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the // source type of the truncate. @@ -1067,7 +1168,7 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // is much more likely to be foldable here. LargeOps.push_back(getSignExtendExpr(C, SrcType)); } else if (const SCEVMulExpr *M = dyn_cast(Ops[i])) { - SmallVector LargeMulOps; + SmallVector LargeMulOps; for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { if (const SCEVTruncateExpr *T = dyn_cast(M->getOperand(j))) { @@ -1095,7 +1196,7 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { } if (Ok) { // Evaluate the expression in the larger type. - SCEVHandle Fold = getAddExpr(LargeOps); + const SCEV* Fold = getAddExpr(LargeOps); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, DstType); @@ -1128,6 +1229,39 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) ++Idx; + // Check to see if there are any folding opportunities present with + // operands multiplied by constant values. + if (Idx < Ops.size() && isa(Ops[Idx])) { + uint64_t BitWidth = getTypeSizeInBits(Ty); + DenseMap M; + SmallVector NewOps; + APInt AccumulatedConstant(BitWidth, 0); + if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, + Ops, APInt(BitWidth, 1), *this)) { + // Some interesting folding opportunity is present, so its worthwhile to + // re-generate the operands list. Group the operands by constant scale, + // to avoid multiplying by the same constant scale multiple times. + std::map, APIntCompare> MulOpLists; + for (SmallVector::iterator I = NewOps.begin(), + E = NewOps.end(); I != E; ++I) + MulOpLists[M.find(*I)->second].push_back(*I); + // Re-generate the operands list. + Ops.clear(); + if (AccumulatedConstant != 0) + Ops.push_back(getConstant(AccumulatedConstant)); + for (std::map, APIntCompare>::iterator + I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) + if (I->first != 0) + Ops.push_back(getMulExpr(getConstant(I->first), + getAddExpr(I->second))); + if (Ops.empty()) + return getIntegerSCEV(0, Ty); + if (Ops.size() == 1) + return Ops[0]; + return getAddExpr(Ops); + } + } + // If we are adding something to a multiply expression, make sure the // something is not already an operand of the multiply. If so, merge it into // the multiply. @@ -1138,17 +1272,17 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) if (MulOpSCEV == Ops[AddOp] && !isa(Ops[AddOp])) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) - SCEVHandle InnerMul = Mul->getOperand(MulOp == 0); + const SCEV* InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. - SmallVector MulOps(Mul->op_begin(), Mul->op_end()); + SmallVector MulOps(Mul->op_begin(), Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); InnerMul = getMulExpr(MulOps); } - SCEVHandle One = getIntegerSCEV(1, Ty); - SCEVHandle AddOne = getAddExpr(InnerMul, One); - SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]); + const SCEV* One = getIntegerSCEV(1, Ty); + const SCEV* AddOne = getAddExpr(InnerMul, One); + const SCEV* OuterMul = getMulExpr(AddOne, Ops[AddOp]); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -1172,21 +1306,22 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { OMulOp != e; ++OMulOp) if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) - SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0); + const SCEV* InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { - SmallVector MulOps(Mul->op_begin(), Mul->op_end()); + SmallVector MulOps(Mul->op_begin(), + Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); InnerMul1 = getMulExpr(MulOps); } - SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0); + const SCEV* InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { - SmallVector MulOps(OtherMul->op_begin(), - OtherMul->op_end()); + SmallVector MulOps(OtherMul->op_begin(), + OtherMul->op_end()); MulOps.erase(MulOps.begin()+OMulOp); InnerMul2 = getMulExpr(MulOps); } - SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2); - SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); + const SCEV* InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + const SCEV* OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); @@ -1207,7 +1342,7 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this add and add them to the vector if // they are loop invariant w.r.t. the recurrence. - SmallVector LIOps; + SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { @@ -1221,11 +1356,11 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); - SmallVector AddRecOps(AddRec->op_begin(), + SmallVector AddRecOps(AddRec->op_begin(), AddRec->op_end()); AddRecOps[0] = getAddExpr(LIOps); - SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); + const SCEV* NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1247,7 +1382,8 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { const SCEVAddRecExpr *OtherAddRec = cast(Ops[OtherIdx]); if (AddRec->getLoop() == OtherAddRec->getLoop()) { // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} - SmallVector NewOps(AddRec->op_begin(), AddRec->op_end()); + SmallVector NewOps(AddRec->op_begin(), + AddRec->op_end()); for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { if (i >= NewOps.size()) { NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i, @@ -1256,7 +1392,7 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { } NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i)); } - SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop()); + const SCEV* NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop()); if (Ops.size() == 2) return NewAddRec; @@ -1273,17 +1409,23 @@ SCEVHandle ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. - std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr, - SCEVOps)]; - if (Result == 0) Result = new SCEVAddExpr(Ops); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scAddExpr); + ID.AddInteger(Ops.size()); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVAddExpr(Ops); + UniqueSCEVs.InsertNode(S, IP); + return S; } /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. -SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { +const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty mul!"); #ifndef NDEBUG for (unsigned i = 1, e = Ops.size(); i != e; ++i) @@ -1311,7 +1453,7 @@ SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * RHSC->getValue()->getValue()); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element @@ -1364,7 +1506,7 @@ SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this mul and add them to the vector if // they are loop invariant w.r.t. the recurrence. - SmallVector LIOps; + SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { @@ -1376,7 +1518,7 @@ SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); if (LIOps.size() == 1) { const SCEV *Scale = LIOps[0]; @@ -1384,13 +1526,13 @@ SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); } else { for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - SmallVector MulOps(LIOps.begin(), LIOps.end()); + SmallVector MulOps(LIOps.begin(), LIOps.end()); MulOps.push_back(AddRec->getOperand(i)); NewOps.push_back(getMulExpr(MulOps)); } } - SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); + const SCEV* NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1414,14 +1556,14 @@ SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { if (AddRec->getLoop() == OtherAddRec->getLoop()) { // F * G --> {A,+,B} * {C,+,D} --> {A*C,+,F*D + G*B + B*D} const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec; - SCEVHandle NewStart = getMulExpr(F->getStart(), + const SCEV* NewStart = getMulExpr(F->getStart(), G->getStart()); - SCEVHandle B = F->getStepRecurrence(*this); - SCEVHandle D = G->getStepRecurrence(*this); - SCEVHandle NewStep = getAddExpr(getMulExpr(F, D), + const SCEV* B = F->getStepRecurrence(*this); + const SCEV* D = G->getStepRecurrence(*this); + const SCEV* NewStep = getAddExpr(getMulExpr(F, D), getMulExpr(G, B), getMulExpr(B, D)); - SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep, + const SCEV* NewAddRec = getAddRecExpr(NewStart, NewStep, F->getLoop()); if (Ops.size() == 2) return NewAddRec; @@ -1438,18 +1580,23 @@ SCEVHandle ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need an mul expr. Check to see if we // already have one, otherwise create a new one. - std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr, - SCEVOps)]; - if (Result == 0) - Result = new SCEVMulExpr(Ops); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scMulExpr); + ID.AddInteger(Ops.size()); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVMulExpr(Ops); + UniqueSCEVs.InsertNode(S, IP); + return S; } /// getUDivExpr - Get a canonical multiply expression, or something simpler if /// possible. -SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, - const SCEVHandle &RHS) { +const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, + const SCEV *RHS) { assert(getEffectiveSCEVType(LHS->getType()) == getEffectiveSCEVType(RHS->getType()) && "SCEVUDivExpr operand types don't match!"); @@ -1482,24 +1629,24 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop())) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i) Operands.push_back(getUDivExpr(AR->getOperand(i), RHS)); return getAddRecExpr(Operands, AR->getLoop()); } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) // Find an operand that's safely divisible. for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { - SCEVHandle Op = M->getOperand(i); - SCEVHandle Div = getUDivExpr(Op, RHSC); + const SCEV* Op = M->getOperand(i); + const SCEV* Div = getUDivExpr(Op, RHSC); if (!isa(Div) && getMulExpr(Div, RHSC) == Op) { - const SmallVectorImpl &MOperands = M->getOperands(); - Operands = SmallVector(MOperands.begin(), + const SmallVectorImpl &MOperands = M->getOperands(); + Operands = SmallVector(MOperands.begin(), MOperands.end()); Operands[i] = Div; return getMulExpr(Operands); @@ -1508,13 +1655,13 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, } // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. if (const SCEVAddRecExpr *A = dyn_cast(LHS)) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { Operands.clear(); for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { - SCEVHandle Op = getUDivExpr(A->getOperand(i), RHS); + const SCEV* Op = getUDivExpr(A->getOperand(i), RHS); if (isa(Op) || getMulExpr(Op, RHS) != A->getOperand(i)) break; Operands.push_back(Op); @@ -1528,21 +1675,29 @@ SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, if (const SCEVConstant *LHSC = dyn_cast(LHS)) { Constant *LHSCV = LHSC->getValue(); Constant *RHSCV = RHSC->getValue(); - return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV)); + return getConstant(cast(ConstantExpr::getUDiv(LHSCV, + RHSCV))); } } - SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVUDivExpr(LHS, RHS); + UniqueSCEVs.InsertNode(S, IP); + return S; } /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. -SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start, - const SCEVHandle &Step, const Loop *L) { - SmallVector Operands; +const SCEV* ScalarEvolution::getAddRecExpr(const SCEV* Start, + const SCEV* Step, const Loop *L) { + SmallVector Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) if (StepChrec->getLoop() == L) { @@ -1557,8 +1712,9 @@ SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start, /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. -SCEVHandle ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, - const Loop *L) { +const SCEV * +ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, + const Loop *L) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG for (unsigned i = 1, e = Operands.size(); i != e; ++i) @@ -1576,31 +1732,59 @@ SCEVHandle ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { const Loop* NestedLoop = NestedAR->getLoop(); if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { - SmallVector NestedOperands(NestedAR->op_begin(), + SmallVector NestedOperands(NestedAR->op_begin(), NestedAR->op_end()); - SCEVHandle NestedARHandle(NestedAR); Operands[0] = NestedAR->getStart(); - NestedOperands[0] = getAddRecExpr(Operands, L); - return getAddRecExpr(NestedOperands, NestedLoop); + // AddRecs require their operands be loop-invariant with respect to their + // loops. Don't perform this transformation if it would break this + // requirement. + bool AllInvariant = true; + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + if (!Operands[i]->isLoopInvariant(L)) { + AllInvariant = false; + break; + } + if (AllInvariant) { + NestedOperands[0] = getAddRecExpr(Operands, L); + AllInvariant = true; + for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i) + if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) { + AllInvariant = false; + break; + } + if (AllInvariant) + // Ok, both add recurrences are valid after the transformation. + return getAddRecExpr(NestedOperands, NestedLoop); + } + // Reset Operands to its original state. + Operands[0] = NestedAR; } } - std::vector SCEVOps(Operands.begin(), Operands.end()); - SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, SCEVOps)]; - if (Result == 0) Result = new SCEVAddRecExpr(Operands, L); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddInteger(Operands.size()); + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + ID.AddPointer(Operands[i]); + ID.AddPointer(L); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVAddRecExpr(Operands, L); + UniqueSCEVs.InsertNode(S, IP); + return S; } -SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS, - const SCEVHandle &RHS) { - SmallVector Ops; +const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, + const SCEV *RHS) { + SmallVector Ops; Ops.push_back(LHS); Ops.push_back(RHS); return getSMaxExpr(Ops); } -SCEVHandle -ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { +const SCEV* +ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty smax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1629,10 +1813,14 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { LHSC = cast(Ops[0]); } - // If we are left with a constant -inf, strip it off. + // If we are left with a constant minimum-int, strip it off. if (cast(Ops[0])->getValue()->isMinValue(true)) { Ops.erase(Ops.begin()); --Idx; + } else if (cast(Ops[0])->getValue()->isMaxValue(true)) { + // If we have an smax with a constant maximum-int, it will always be + // maximum-int. + return Ops[0]; } } @@ -1671,23 +1859,29 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need an smax expr. Check to see if we // already have one, otherwise create a new one. - std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, - SCEVOps)]; - if (Result == 0) Result = new SCEVSMaxExpr(Ops); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scSMaxExpr); + ID.AddInteger(Ops.size()); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVSMaxExpr(Ops); + UniqueSCEVs.InsertNode(S, IP); + return S; } -SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS, - const SCEVHandle &RHS) { - SmallVector Ops; +const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, + const SCEV *RHS) { + SmallVector Ops; Ops.push_back(LHS); Ops.push_back(RHS); return getUMaxExpr(Ops); } -SCEVHandle -ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { +const SCEV* +ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { assert(!Ops.empty() && "Cannot get empty umax!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1716,10 +1910,14 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { LHSC = cast(Ops[0]); } - // If we are left with a constant zero, strip it off. + // If we are left with a constant minimum-int, strip it off. if (cast(Ops[0])->getValue()->isMinValue(false)) { Ops.erase(Ops.begin()); --Idx; + } else if (cast(Ops[0])->getValue()->isMaxValue(false)) { + // If we have an umax with a constant maximum-int, it will always be + // maximum-int. + return Ops[0]; } } @@ -1758,21 +1956,46 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need a umax expr. Check to see if we // already have one, otherwise create a new one. - std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr, - SCEVOps)]; - if (Result == 0) Result = new SCEVUMaxExpr(Ops); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scUMaxExpr); + ID.AddInteger(Ops.size()); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVUMaxExpr(Ops); + UniqueSCEVs.InsertNode(S, IP); + return S; } -SCEVHandle ScalarEvolution::getUnknown(Value *V) { - if (ConstantInt *CI = dyn_cast(V)) - return getConstant(CI); - if (isa(V)) - return getIntegerSCEV(0, V->getType()); - SCEVUnknown *&Result = (*SCEVUnknowns)[V]; - if (Result == 0) Result = new SCEVUnknown(V); - return Result; +const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, + const SCEV *RHS) { + // ~smax(~x, ~y) == smin(x, y). + return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); +} + +const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, + const SCEV *RHS) { + // ~umax(~x, ~y) == umin(x, y) + return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); +} + +const SCEV* ScalarEvolution::getUnknown(Value *V) { + // Don't attempt to do anything other than create a SCEVUnknown object + // here. createSCEV only calls getUnknown after checking for all other + // interesting possibilities, and any other code that calls getUnknown + // is doing so in order to hide a value from SCEV canonicalization. + + FoldingSetNodeID ID; + ID.AddInteger(scUnknown); + ID.AddPointer(V); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVUnknown(V); + UniqueSCEVs.InsertNode(S, IP); + return S; } //===----------------------------------------------------------------------===// @@ -1825,8 +2048,8 @@ const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const { return TD->getIntPtrType(); } -SCEVHandle ScalarEvolution::getCouldNotCompute() { - return CouldNotCompute; +const SCEV* ScalarEvolution::getCouldNotCompute() { + return &CouldNotCompute; } /// hasSCEV - Return true if the SCEV for this value has already been @@ -1837,36 +2060,28 @@ bool ScalarEvolution::hasSCEV(Value *V) const { /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the /// expression and create a new one. -SCEVHandle ScalarEvolution::getSCEV(Value *V) { +const SCEV* ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); - std::map::iterator I = Scalars.find(V); + std::map::iterator I = Scalars.find(V); if (I != Scalars.end()) return I->second; - SCEVHandle S = createSCEV(V); + const SCEV* S = createSCEV(V); Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S)); return S; } -/// getIntegerSCEV - Given an integer or FP type, create a constant for the +/// getIntegerSCEV - Given a SCEVable type, create a constant for the /// specified signed integer value and return a SCEV for the constant. -SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { - Ty = getEffectiveSCEVType(Ty); - Constant *C; - if (Val == 0) - C = Constant::getNullValue(Ty); - else if (Ty->isFloatingPoint()) - C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : - APFloat::IEEEdouble, Val)); - else - C = ConstantInt::get(Ty, Val); - return getUnknown(C); +const SCEV* ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { + const IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); + return getConstant(ConstantInt::get(ITy, Val)); } /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V /// -SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { +const SCEV* ScalarEvolution::getNegativeSCEV(const SCEV* V) { if (const SCEVConstant *VC = dyn_cast(V)) - return getUnknown(ConstantExpr::getNeg(VC->getValue())); + return getConstant(cast(ConstantExpr::getNeg(VC->getValue()))); const Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); @@ -1874,20 +2089,20 @@ SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { } /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V -SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) { +const SCEV* ScalarEvolution::getNotSCEV(const SCEV* V) { if (const SCEVConstant *VC = dyn_cast(V)) - return getUnknown(ConstantExpr::getNot(VC->getValue())); + return getConstant(cast(ConstantExpr::getNot(VC->getValue()))); const Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); - SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty)); + const SCEV* AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty)); return getMinusSCEV(AllOnes, V); } /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. /// -SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, - const SCEVHandle &RHS) { +const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, + const SCEV *RHS) { // X - Y --> X + -Y return getAddExpr(LHS, getNegativeSCEV(RHS)); } @@ -1895,8 +2110,8 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the /// input value to the specified type. If the type must be extended, it is zero /// extended. -SCEVHandle -ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, +const SCEV* +ScalarEvolution::getTruncateOrZeroExtend(const SCEV* V, const Type *Ty) { const Type *SrcTy = V->getType(); assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && @@ -1912,8 +2127,8 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the /// input value to the specified type. If the type must be extended, it is sign /// extended. -SCEVHandle -ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V, +const SCEV* +ScalarEvolution::getTruncateOrSignExtend(const SCEV* V, const Type *Ty) { const Type *SrcTy = V->getType(); assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && @@ -1929,8 +2144,8 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V, /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the /// input value to the specified type. If the type must be extended, it is zero /// extended. The conversion must not be narrowing. -SCEVHandle -ScalarEvolution::getNoopOrZeroExtend(const SCEVHandle &V, const Type *Ty) { +const SCEV* +ScalarEvolution::getNoopOrZeroExtend(const SCEV* V, const Type *Ty) { const Type *SrcTy = V->getType(); assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && (Ty->isInteger() || (TD && isa(Ty))) && @@ -1945,8 +2160,8 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEVHandle &V, const Type *Ty) { /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the /// input value to the specified type. If the type must be extended, it is sign /// extended. The conversion must not be narrowing. -SCEVHandle -ScalarEvolution::getNoopOrSignExtend(const SCEVHandle &V, const Type *Ty) { +const SCEV* +ScalarEvolution::getNoopOrSignExtend(const SCEV* V, const Type *Ty) { const Type *SrcTy = V->getType(); assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && (Ty->isInteger() || (TD && isa(Ty))) && @@ -1962,8 +2177,8 @@ ScalarEvolution::getNoopOrSignExtend(const SCEVHandle &V, const Type *Ty) { /// the input value to the specified type. If the type must be extended, /// it is extended with unspecified bits. The conversion must not be /// narrowing. -SCEVHandle -ScalarEvolution::getNoopOrAnyExtend(const SCEVHandle &V, const Type *Ty) { +const SCEV* +ScalarEvolution::getNoopOrAnyExtend(const SCEV* V, const Type *Ty) { const Type *SrcTy = V->getType(); assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && (Ty->isInteger() || (TD && isa(Ty))) && @@ -1977,8 +2192,8 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEVHandle &V, const Type *Ty) { /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the /// input value to the specified type. The conversion must not be widening. -SCEVHandle -ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) { +const SCEV* +ScalarEvolution::getTruncateOrNoop(const SCEV* V, const Type *Ty) { const Type *SrcTy = V->getType(); assert((SrcTy->isInteger() || (TD && isa(SrcTy))) && (Ty->isInteger() || (TD && isa(Ty))) && @@ -1990,17 +2205,50 @@ ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) { return getTruncateExpr(V, Ty); } +/// getUMaxFromMismatchedTypes - Promote the operands to the wider of +/// the types using zero-extension, and then perform a umax operation +/// with them. +const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, + const SCEV *RHS) { + const SCEV* PromotedLHS = LHS; + const SCEV* PromotedRHS = RHS; + + if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) + PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); + else + PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); + + return getUMaxExpr(PromotedLHS, PromotedRHS); +} + +/// getUMinFromMismatchedTypes - Promote the operands to the wider of +/// the types using zero-extension, and then perform a umin operation +/// with them. +const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, + const SCEV *RHS) { + const SCEV* PromotedLHS = LHS; + const SCEV* PromotedRHS = RHS; + + if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) + PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); + else + PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); + + return getUMinExpr(PromotedLHS, PromotedRHS); +} + /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for /// the specified instruction and replaces any references to the symbolic value /// SymName with the specified value. This is used during PHI resolution. -void ScalarEvolution:: -ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName, - const SCEVHandle &NewVal) { - std::map::iterator SI = +void +ScalarEvolution::ReplaceSymbolicValueWithConcrete(Instruction *I, + const SCEV *SymName, + const SCEV *NewVal) { + std::map::iterator SI = Scalars.find(SCEVCallbackVH(I, this)); if (SI == Scalars.end()) return; - SCEVHandle NV = + const SCEV* NV = SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this); if (NV == SI->second) return; // No change. @@ -2016,7 +2264,7 @@ ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName, /// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in /// a loop header, making it a potential recurrence, or it doesn't. /// -SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { +const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) { if (PN->getNumIncomingValues() == 2) // The loops have been canonicalized. if (const Loop *L = LI->getLoopFor(PN->getParent())) if (L->getHeader() == PN->getParent()) { @@ -2026,14 +2274,14 @@ SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { unsigned BackEdge = IncomingEdge^1; // While we are analyzing this PHI node, handle its value symbolically. - SCEVHandle SymbolicName = getUnknown(PN); + const SCEV* SymbolicName = getUnknown(PN); assert(Scalars.find(PN) == Scalars.end() && "PHI node already processed?"); Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. - SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge)); + const SCEV* BEValue = getSCEV(PN->getIncomingValue(BackEdge)); // NOTE: If BEValue is loop invariant, we know that the PHI node just // has a special value for the first iteration of the loop. @@ -2053,19 +2301,21 @@ SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { if (FoundIndex != Add->getNumOperands()) { // Create an add with everything but the specified operand. - SmallVector Ops; + SmallVector Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(Add->getOperand(i)); - SCEVHandle Accum = getAddExpr(Ops); + const SCEV* Accum = getAddExpr(Ops); // This is not a valid addrec if the step amount is varying each // loop iteration, but is not itself an addrec in this loop. if (Accum->isLoopInvariant(L) || (isa(Accum) && cast(Accum)->getLoop() == L)) { - SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); - SCEVHandle PHISCEV = getAddRecExpr(StartVal, Accum, L); + const SCEV *StartVal = + getSCEV(PN->getIncomingValue(IncomingEdge)); + const SCEV *PHISCEV = + getAddRecExpr(StartVal, Accum, L); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and update all of the @@ -2084,13 +2334,13 @@ SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { // Because the other in-value of i (0) fits the evolution of BEValue // i really is an addrec evolution. if (AddRec->getLoop() == L && AddRec->isAffine()) { - SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); + const SCEV* StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); // If StartVal = j.start - j.stride, we can use StartVal as the // initial step of the addrec evolution. if (StartVal == getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) { - SCEVHandle PHISCEV = + const SCEV* PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1), L); // Okay, for the entire analysis of this edge we assumed the PHI @@ -2114,14 +2364,14 @@ SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { /// createNodeForGEP - Expand GEP instructions into add and multiply /// operations. This allows them to be analyzed by regular SCEV code. /// -SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) { +const SCEV* ScalarEvolution::createNodeForGEP(User *GEP) { const Type *IntPtrTy = TD->getIntPtrType(); Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. if (!cast(Base->getType())->getElementType()->isSized()) return getUnknown(GEP); - SCEVHandle TotalOffset = getIntegerSCEV(0, IntPtrTy); + const SCEV* TotalOffset = getIntegerSCEV(0, IntPtrTy); gep_type_iterator GTI = gep_type_begin(GEP); for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()), E = GEP->op_end(); @@ -2137,7 +2387,7 @@ SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) { getIntegerSCEV(Offset, IntPtrTy)); } else { // For an array, add the element offset, explicitly scaled. - SCEVHandle LocalOffset = getSCEV(Index); + const SCEV* LocalOffset = getSCEV(Index); if (!isa(LocalOffset->getType())) // Getelementptr indicies are signed. LocalOffset = getTruncateOrSignExtend(LocalOffset, @@ -2156,77 +2406,170 @@ SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) { /// guaranteed to end in (at every loop iteration). It is, at the same time, /// the minimum number of times S is divisible by 2. For example, given {4,+,8} /// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. -static uint32_t GetMinTrailingZeros(SCEVHandle S, const ScalarEvolution &SE) { +uint32_t +ScalarEvolution::GetMinTrailingZeros(const SCEV* S) { if (const SCEVConstant *C = dyn_cast(S)) return C->getValue()->getValue().countTrailingZeros(); if (const SCEVTruncateExpr *T = dyn_cast(S)) - return std::min(GetMinTrailingZeros(T->getOperand(), SE), - (uint32_t)SE.getTypeSizeInBits(T->getType())); + return std::min(GetMinTrailingZeros(T->getOperand()), + (uint32_t)getTypeSizeInBits(T->getType())); if (const SCEVZeroExtendExpr *E = dyn_cast(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE); - return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ? - SE.getTypeSizeInBits(E->getType()) : OpRes; + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? + getTypeSizeInBits(E->getType()) : OpRes; } if (const SCEVSignExtendExpr *E = dyn_cast(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE); - return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ? - SE.getTypeSizeInBits(E->getType()) : OpRes; + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? + getTypeSizeInBits(E->getType()) : OpRes; } if (const SCEVAddExpr *A = dyn_cast(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); return MinOpRes; } if (const SCEVMulExpr *M = dyn_cast(S)) { // The result is the sum of all operands results. - uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), SE); - uint32_t BitWidth = SE.getTypeSizeInBits(M->getType()); + uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); + uint32_t BitWidth = getTypeSizeInBits(M->getType()); for (unsigned i = 1, e = M->getNumOperands(); SumOpRes != BitWidth && i != e; ++i) - SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), SE), + SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); return SumOpRes; } if (const SCEVAddRecExpr *A = dyn_cast(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); return MinOpRes; } if (const SCEVSMaxExpr *M = dyn_cast(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); return MinOpRes; } if (const SCEVUMaxExpr *M = dyn_cast(S)) { // The result is the min of all operands results. - uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE)); + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); return MinOpRes; } - // SCEVUDivExpr, SCEVUnknown + if (const SCEVUnknown *U = dyn_cast(S)) { + // For a SCEVUnknown, ask ValueTracking. + unsigned BitWidth = getTypeSizeInBits(U->getType()); + APInt Mask = APInt::getAllOnesValue(BitWidth); + APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); + ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones); + return Zeros.countTrailingOnes(); + } + + // SCEVUDivExpr return 0; } +uint32_t +ScalarEvolution::GetMinLeadingZeros(const SCEV* S) { + // TODO: Handle other SCEV expression types here. + + if (const SCEVConstant *C = dyn_cast(S)) + return C->getValue()->getValue().countLeadingZeros(); + + if (const SCEVZeroExtendExpr *C = dyn_cast(S)) { + // A zero-extension cast adds zero bits. + return GetMinLeadingZeros(C->getOperand()) + + (getTypeSizeInBits(C->getType()) - + getTypeSizeInBits(C->getOperand()->getType())); + } + + if (const SCEVUnknown *U = dyn_cast(S)) { + // For a SCEVUnknown, ask ValueTracking. + unsigned BitWidth = getTypeSizeInBits(U->getType()); + APInt Mask = APInt::getAllOnesValue(BitWidth); + APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); + ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD); + return Zeros.countLeadingOnes(); + } + + return 1; +} + +uint32_t +ScalarEvolution::GetMinSignBits(const SCEV* S) { + // TODO: Handle other SCEV expression types here. + + if (const SCEVConstant *C = dyn_cast(S)) { + const APInt &A = C->getValue()->getValue(); + return A.isNegative() ? A.countLeadingOnes() : + A.countLeadingZeros(); + } + + if (const SCEVSignExtendExpr *C = dyn_cast(S)) { + // A sign-extension cast adds sign bits. + return GetMinSignBits(C->getOperand()) + + (getTypeSizeInBits(C->getType()) - + getTypeSizeInBits(C->getOperand()->getType())); + } + + if (const SCEVAddExpr *A = dyn_cast(S)) { + unsigned BitWidth = getTypeSizeInBits(A->getType()); + + // Special case decrementing a value (ADD X, -1): + if (const SCEVConstant *CRHS = dyn_cast(A->getOperand(0))) + if (CRHS->isAllOnesValue()) { + SmallVector OtherOps(A->op_begin() + 1, A->op_end()); + const SCEV *OtherOpsAdd = getAddExpr(OtherOps); + unsigned LZ = GetMinLeadingZeros(OtherOpsAdd); + + // If the input is known to be 0 or 1, the output is 0/-1, which is all + // sign bits set. + if (LZ == BitWidth - 1) + return BitWidth; + + // If we are subtracting one from a positive number, there is no carry + // out of the result. + if (LZ > 0) + return GetMinSignBits(OtherOpsAdd); + } + + // Add can have at most one carry bit. Thus we know that the output + // is, at worst, one more bit than the inputs. + unsigned Min = BitWidth; + for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { + unsigned N = GetMinSignBits(A->getOperand(i)); + Min = std::min(Min, N) - 1; + if (Min == 0) return 1; + } + return 1; + } + + if (const SCEVUnknown *U = dyn_cast(S)) { + // For a SCEVUnknown, ask ValueTracking. + return ComputeNumSignBits(U->getValue(), TD); + } + + return 1; +} + /// createSCEV - We know that there is no SCEV for the specified value. /// Analyze the expression. /// -SCEVHandle ScalarEvolution::createSCEV(Value *V) { +const SCEV* ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); @@ -2235,6 +2578,12 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { Opcode = I->getOpcode(); else if (ConstantExpr *CE = dyn_cast(V)) Opcode = CE->getOpcode(); + else if (ConstantInt *CI = dyn_cast(V)) + return getConstant(CI); + else if (isa(V)) + return getIntegerSCEV(0, V->getType()); + else if (isa(V)) + return getIntegerSCEV(0, V->getType()); else return getUnknown(V); @@ -2261,14 +2610,27 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (CI->isAllOnesValue()) return getSCEV(U->getOperand(0)); const APInt &A = CI->getValue(); - unsigned Ones = A.countTrailingOnes(); - if (APIntOps::isMask(Ones, A)) + + // Instcombine's ShrinkDemandedConstant may strip bits out of + // constants, obscuring what would otherwise be a low-bits mask. + // Use ComputeMaskedBits to compute what ShrinkDemandedConstant + // knew about to reconstruct a low-bits mask value. + unsigned LZ = A.countLeadingZeros(); + unsigned BitWidth = A.getBitWidth(); + APInt AllOnes = APInt::getAllOnesValue(BitWidth); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD); + + APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ); + + if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask)) return getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)), - IntegerType::get(Ones)), + IntegerType::get(BitWidth - LZ)), U->getType()); } break; + case Instruction::Or: // If the RHS of the Or is a constant, we may have something like: // X*4+1 which got turned into X*4|1. Handle this as an Add so loop @@ -2277,9 +2639,9 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { // In order for this transformation to be safe, the LHS must be of the // form X*(2^n) and the Or constant must be less than 2^n. if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { - SCEVHandle LHS = getSCEV(U->getOperand(0)); + const SCEV* LHS = getSCEV(U->getOperand(0)); const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS, *this) >= + if (GetMinTrailingZeros(LHS) >= (CIVal.getBitWidth() - CIVal.countLeadingZeros())) return getAddExpr(LHS, getSCEV(U->getOperand(1))); } @@ -2305,9 +2667,27 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (BO->getOpcode() == Instruction::And && LCI->getValue() == CI->getValue()) if (const SCEVZeroExtendExpr *Z = - dyn_cast(getSCEV(U->getOperand(0)))) - return getZeroExtendExpr(getNotSCEV(Z->getOperand()), - U->getType()); + dyn_cast(getSCEV(U->getOperand(0)))) { + const Type *UTy = U->getType(); + const SCEV* Z0 = Z->getOperand(); + const Type *Z0Ty = Z0->getType(); + unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + + // If C is a low-bits mask, the zero extend is zerving to + // mask off the high bits. Complement the operand and + // re-apply the zext. + if (APIntOps::isMask(Z0TySize, CI->getValue())) + return getZeroExtendExpr(getNotSCEV(Z0), UTy); + + // If C is a single bit, it may be in the sign-bit position + // before the zero-extend. In this case, represent the xor + // using an add, which is equivalent, and re-apply the zext. + APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize); + if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() && + Trunc.isSignBit()) + return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), + UTy); + } } break; @@ -2398,10 +2778,7 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) return getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) - // ~smax(~x, ~y) == smin(x, y). - return getNotSCEV(getSMaxExpr( - getNotSCEV(getSCEV(LHS)), - getNotSCEV(getSCEV(RHS)))); + return getSMinExpr(getSCEV(LHS), getSCEV(RHS)); break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2412,9 +2789,25 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) return getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) - // ~umax(~x, ~y) == umin(x, y) - return getNotSCEV(getUMaxExpr(getNotSCEV(getSCEV(LHS)), - getNotSCEV(getSCEV(RHS)))); + return getUMinExpr(getSCEV(LHS), getSCEV(RHS)); + break; + case ICmpInst::ICMP_NE: + // n != 0 ? n : 1 -> umax(n, 1) + if (LHS == U->getOperand(1) && + isa(U->getOperand(2)) && + cast(U->getOperand(2))->isOne() && + isa(RHS) && + cast(RHS)->isZero()) + return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2))); + break; + case ICmpInst::ICMP_EQ: + // n == 0 ? 1 : n -> umax(n, 1) + if (LHS == U->getOperand(2) && + isa(U->getOperand(1)) && + cast(U->getOperand(1))->isOne() && + isa(RHS) && + cast(RHS)->isZero()) + return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1))); break; default: break; @@ -2445,14 +2838,14 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) { /// loop-invariant backedge-taken count (see /// hasLoopInvariantBackedgeTakenCount). /// -SCEVHandle ScalarEvolution::getBackedgeTakenCount(const Loop *L) { +const SCEV* ScalarEvolution::getBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).Exact; } /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except /// return the least SCEV value that is known never to be less than the /// actual backedge taken count. -SCEVHandle ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { +const SCEV* ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).Max; } @@ -2467,7 +2860,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute())); if (Pair.second) { BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L); - if (ItCount.Exact != CouldNotCompute) { + if (ItCount.Exact != getCouldNotCompute()) { assert(ItCount.Exact->isLoopInvariant(L) && ItCount.Max->isLoopInvariant(L) && "Computed trip count isn't loop invariant for loop!"); @@ -2475,9 +2868,13 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Update the value in the map. Pair.first->second = ItCount; - } else if (isa(L->getHeader()->begin())) { - // Only count loops that have phi nodes as not being computable. - ++NumTripCountsNotComputed; + } else { + if (ItCount.Max != getCouldNotCompute()) + // Update the value in the map. + Pair.first->second = ItCount; + if (isa(L->getHeader()->begin())) + // Only count loops that have phi nodes as not being computable. + ++NumTripCountsNotComputed; } // Now that we know more about the trip count for this loop, forget any @@ -2515,7 +2912,8 @@ void ScalarEvolution::forgetLoopPHIs(const Loop *L) { SmallVector Worklist; for (BasicBlock::iterator I = Header->begin(); PHINode *PN = dyn_cast(I); ++I) { - std::map::iterator It = Scalars.find((Value*)I); + std::map::iterator It = + Scalars.find((Value*)I); if (It != Scalars.end() && !isa(It->second)) Worklist.push_back(PN); } @@ -2533,25 +2931,51 @@ void ScalarEvolution::forgetLoopPHIs(const Loop *L) { /// of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { - // If the loop has a non-one exit block count, we can't analyze it. - BasicBlock *ExitBlock = L->getExitBlock(); - if (!ExitBlock) - return CouldNotCompute; - - // Okay, there is one exit block. Try to find the condition that causes the - // loop to be exited. - BasicBlock *ExitingBlock = L->getExitingBlock(); - if (!ExitingBlock) - return CouldNotCompute; // More than one block exiting! - - // Okay, we've computed the exiting block. See what condition causes us to - // exit. + SmallVector ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // Examine all exits and pick the most conservative values. + const SCEV* BECount = getCouldNotCompute(); + const SCEV* MaxBECount = getCouldNotCompute(); + bool CouldNotComputeBECount = false; + for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { + BackedgeTakenInfo NewBTI = + ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]); + + if (NewBTI.Exact == getCouldNotCompute()) { + // We couldn't compute an exact value for this exit, so + // we won't be able to compute an exact value for the loop. + CouldNotComputeBECount = true; + BECount = getCouldNotCompute(); + } else if (!CouldNotComputeBECount) { + if (BECount == getCouldNotCompute()) + BECount = NewBTI.Exact; + else + BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact); + } + if (MaxBECount == getCouldNotCompute()) + MaxBECount = NewBTI.Max; + else if (NewBTI.Max != getCouldNotCompute()) + MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max); + } + + return BackedgeTakenInfo(BECount, MaxBECount); +} + +/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge +/// of the specified loop will execute if it exits via the specified block. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L, + BasicBlock *ExitingBlock) { + + // Okay, we've chosen an exiting block. See what condition causes us to + // exit at this block. // // FIXME: we should be able to handle switch instructions (with a single exit) BranchInst *ExitBr = dyn_cast(ExitingBlock->getTerminator()); - if (ExitBr == 0) return CouldNotCompute; + if (ExitBr == 0) return getCouldNotCompute(); assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!"); - + // At this point, we know we have a conditional branch that determines whether // the loop is exited. However, we don't know if the branch is executed each // time through the loop. If not, then the execution count of the branch will @@ -2560,23 +2984,154 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // Currently we check for this by checking to see if the Exit branch goes to // the loop header. If so, we know it will always execute the same number of // times as the loop. We also handle the case where the exit block *is* the - // loop header. This is common for un-rotated loops. More extensive analysis - // could be done to handle more cases here. + // loop header. This is common for un-rotated loops. + // + // If both of those tests fail, walk up the unique predecessor chain to the + // header, stopping if there is an edge that doesn't exit the loop. If the + // header is reached, the execution count of the branch will be equal to the + // trip count of the loop. + // + // More extensive analysis could be done to handle more cases here. + // if (ExitBr->getSuccessor(0) != L->getHeader() && ExitBr->getSuccessor(1) != L->getHeader() && - ExitBr->getParent() != L->getHeader()) - return CouldNotCompute; - - ICmpInst *ExitCond = dyn_cast(ExitBr->getCondition()); + ExitBr->getParent() != L->getHeader()) { + // The simple checks failed, try climbing the unique predecessor chain + // up to the header. + bool Ok = false; + for (BasicBlock *BB = ExitBr->getParent(); BB; ) { + BasicBlock *Pred = BB->getUniquePredecessor(); + if (!Pred) + return getCouldNotCompute(); + TerminatorInst *PredTerm = Pred->getTerminator(); + for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) { + BasicBlock *PredSucc = PredTerm->getSuccessor(i); + if (PredSucc == BB) + continue; + // If the predecessor has a successor that isn't BB and isn't + // outside the loop, assume the worst. + if (L->contains(PredSucc)) + return getCouldNotCompute(); + } + if (Pred == L->getHeader()) { + Ok = true; + break; + } + BB = Pred; + } + if (!Ok) + return getCouldNotCompute(); + } + + // Procede to the next level to examine the exit condition expression. + return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(), + ExitBr->getSuccessor(0), + ExitBr->getSuccessor(1)); +} + +/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the +/// backedge of the specified loop will execute if its exit condition +/// were a conditional branch of ExitCond, TBB, and FBB. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L, + Value *ExitCond, + BasicBlock *TBB, + BasicBlock *FBB) { + // Check if the controlling expression for this loop is an And or Or. + if (BinaryOperator *BO = dyn_cast(ExitCond)) { + if (BO->getOpcode() == Instruction::And) { + // Recurse on the operands of the and. + BackedgeTakenInfo BTI0 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB); + BackedgeTakenInfo BTI1 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB); + const SCEV* BECount = getCouldNotCompute(); + const SCEV* MaxBECount = getCouldNotCompute(); + if (L->contains(TBB)) { + // Both conditions must be true for the loop to continue executing. + // Choose the less conservative count. + if (BTI0.Exact == getCouldNotCompute() || + BTI1.Exact == getCouldNotCompute()) + BECount = getCouldNotCompute(); + else + BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact); + if (BTI0.Max == getCouldNotCompute()) + MaxBECount = BTI1.Max; + else if (BTI1.Max == getCouldNotCompute()) + MaxBECount = BTI0.Max; + else + MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max); + } else { + // Both conditions must be true for the loop to exit. + assert(L->contains(FBB) && "Loop block has no successor in loop!"); + if (BTI0.Exact != getCouldNotCompute() && + BTI1.Exact != getCouldNotCompute()) + BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); + if (BTI0.Max != getCouldNotCompute() && + BTI1.Max != getCouldNotCompute()) + MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); + } + + return BackedgeTakenInfo(BECount, MaxBECount); + } + if (BO->getOpcode() == Instruction::Or) { + // Recurse on the operands of the or. + BackedgeTakenInfo BTI0 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB); + BackedgeTakenInfo BTI1 = + ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB); + const SCEV* BECount = getCouldNotCompute(); + const SCEV* MaxBECount = getCouldNotCompute(); + if (L->contains(FBB)) { + // Both conditions must be false for the loop to continue executing. + // Choose the less conservative count. + if (BTI0.Exact == getCouldNotCompute() || + BTI1.Exact == getCouldNotCompute()) + BECount = getCouldNotCompute(); + else + BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact); + if (BTI0.Max == getCouldNotCompute()) + MaxBECount = BTI1.Max; + else if (BTI1.Max == getCouldNotCompute()) + MaxBECount = BTI0.Max; + else + MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max); + } else { + // Both conditions must be false for the loop to exit. + assert(L->contains(TBB) && "Loop block has no successor in loop!"); + if (BTI0.Exact != getCouldNotCompute() && + BTI1.Exact != getCouldNotCompute()) + BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); + if (BTI0.Max != getCouldNotCompute() && + BTI1.Max != getCouldNotCompute()) + MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); + } + + return BackedgeTakenInfo(BECount, MaxBECount); + } + } + + // With an icmp, it may be feasible to compute an exact backedge-taken count. + // Procede to the next level to examine the icmp. + if (ICmpInst *ExitCondICmp = dyn_cast(ExitCond)) + return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB); // If it's not an integer or pointer comparison then compute it the hard way. - if (ExitCond == 0) - return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(), - ExitBr->getSuccessor(0) == ExitBlock); + return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB)); +} + +/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the +/// backedge of the specified loop will execute if its exit condition +/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, + ICmpInst *ExitCond, + BasicBlock *TBB, + BasicBlock *FBB) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; - if (ExitBr->getSuccessor(1) == ExitBlock) + if (!L->contains(FBB)) Cond = ExitCond->getPredicate(); else Cond = ExitCond->getInversePredicate(); @@ -2584,19 +3139,24 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // Handle common loops like: for (X = "string"; *X; ++X) if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { - SCEVHandle ItCnt = + const SCEV* ItCnt = ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond); - if (!isa(ItCnt)) return ItCnt; + if (!isa(ItCnt)) { + unsigned BitWidth = getTypeSizeInBits(ItCnt->getType()); + return BackedgeTakenInfo(ItCnt, + isa(ItCnt) ? ItCnt : + getConstant(APInt::getMaxValue(BitWidth)-1)); + } } - SCEVHandle LHS = getSCEV(ExitCond->getOperand(0)); - SCEVHandle RHS = getSCEV(ExitCond->getOperand(1)); + const SCEV* LHS = getSCEV(ExitCond->getOperand(0)); + const SCEV* RHS = getSCEV(ExitCond->getOperand(1)); // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); RHS = getSCEVAtScope(RHS, L); - // At this point, we would like to compute how many iterations of the + // At this point, we would like to compute how many iterations of the // loop the predicate will return true for these inputs. if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) { // If there is a loop-invariant, force it into the RHS. @@ -2613,20 +3173,20 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { ConstantRange CompRange( ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue())); - SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, *this); + const SCEV* Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; } switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - SCEVHandle TC = HowFarToZero(getMinusSCEV(LHS, RHS), L); + const SCEV* TC = HowFarToZero(getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; } case ICmpInst::ICMP_EQ: { // Convert to: while (X-Y == 0) // while (X == Y) - SCEVHandle TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + const SCEV* TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; } @@ -2658,21 +3218,20 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { if (ExitCond->getOperand(0)->getType()->isUnsigned()) errs() << "[unsigned] "; errs() << *LHS << " " - << Instruction::getOpcodeName(Instruction::ICmp) + << Instruction::getOpcodeName(Instruction::ICmp) << " " << *RHS << "\n"; #endif break; } return - ComputeBackedgeTakenCountExhaustively(L, ExitCond, - ExitBr->getSuccessor(0) == ExitBlock); + ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB)); } static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE) { - SCEVHandle InVal = SE.getConstant(C); - SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE); + const SCEV* InVal = SE.getConstant(C); + const SCEV* Val = AddRec->evaluateAtIteration(InVal, SE); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); return cast(Val)->getValue(); @@ -2715,15 +3274,17 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, /// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of /// 'icmp op load X, cst', try to see if we can compute the backedge /// execution count. -SCEVHandle ScalarEvolution:: -ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, - const Loop *L, - ICmpInst::Predicate predicate) { - if (LI->isVolatile()) return CouldNotCompute; +const SCEV * +ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( + LoadInst *LI, + Constant *RHS, + const Loop *L, + ICmpInst::Predicate predicate) { + if (LI->isVolatile()) return getCouldNotCompute(); // Check to see if the loaded pointer is a getelementptr of a global. GetElementPtrInst *GEP = dyn_cast(LI->getOperand(0)); - if (!GEP) return CouldNotCompute; + if (!GEP) return getCouldNotCompute(); // Make sure that it is really a constant global we are gepping, with an // initializer, and make sure the first IDX is really 0. @@ -2731,7 +3292,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, if (!GV || !GV->isConstant() || !GV->hasInitializer() || GEP->getNumOperands() < 3 || !isa(GEP->getOperand(1)) || !cast(GEP->getOperand(1))->isNullValue()) - return CouldNotCompute; + return getCouldNotCompute(); // Okay, we allow one non-constant index into the GEP instruction. Value *VarIdx = 0; @@ -2741,7 +3302,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, if (ConstantInt *CI = dyn_cast(GEP->getOperand(i))) { Indexes.push_back(CI); } else if (!isa(GEP->getOperand(i))) { - if (VarIdx) return CouldNotCompute; // Multiple non-constant idx's. + if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's. VarIdx = GEP->getOperand(i); VarIdxNum = i-2; Indexes.push_back(0); @@ -2749,7 +3310,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. // Check to see if X is a loop variant variable value now. - SCEVHandle Idx = getSCEV(VarIdx); + const SCEV* Idx = getSCEV(VarIdx); Idx = getSCEVAtScope(Idx, L); // We can only recognize very limited forms of loop index expressions, in @@ -2758,12 +3319,12 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) || !isa(IdxExpr->getOperand(0)) || !isa(IdxExpr->getOperand(1))) - return CouldNotCompute; + return getCouldNotCompute(); unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { ConstantInt *ItCst = - ConstantInt::get(IdxExpr->getType(), IterationNum); + ConstantInt::get(cast(IdxExpr->getType()), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); // Form the GEP offset. @@ -2785,7 +3346,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, return getConstant(ItCst); // Found terminating iteration! } } - return CouldNotCompute; + return getCouldNotCompute(); } @@ -2874,8 +3435,10 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. -Constant *ScalarEvolution:: -getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){ +Constant * +ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, + const APInt& BEs, + const Loop *L) { std::map::iterator I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) @@ -2924,11 +3487,13 @@ getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){ /// constant number of times (the condition evolves only from constants), /// try to evaluate a few iterations of the loop until we get the exit /// condition gets a value of ExitWhen (true or false). If we cannot -/// evaluate the trip count of the loop, return CouldNotCompute. -SCEVHandle ScalarEvolution:: -ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { +/// evaluate the trip count of the loop, return getCouldNotCompute(). +const SCEV * +ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, + Value *Cond, + bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); - if (PN == 0) return CouldNotCompute; + if (PN == 0) return getCouldNotCompute(); // Since the loop is canonicalized, the PHI node must have two entries. One // entry must be a constant (coming in from outside of the loop), and the @@ -2936,11 +3501,11 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); Constant *StartCST = dyn_cast(PN->getIncomingValue(!SecondIsBackedge)); - if (StartCST == 0) return CouldNotCompute; // Must be a constant. + if (StartCST == 0) return getCouldNotCompute(); // Must be a constant. Value *BEValue = PN->getIncomingValue(SecondIsBackedge); PHINode *PN2 = getConstantEvolvingPHI(BEValue, L); - if (PN2 != PN) return CouldNotCompute; // Not derived from same PHI. + if (PN2 != PN) return getCouldNotCompute(); // Not derived from same PHI. // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of @@ -2953,23 +3518,22 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) dyn_cast_or_null(EvaluateExpression(Cond, PHIVal)); // Couldn't symbolically evaluate. - if (!CondVal) return CouldNotCompute; + if (!CondVal) return getCouldNotCompute(); if (CondVal->getValue() == uint64_t(ExitWhen)) { - ConstantEvolutionLoopExitValue[PN] = PHIVal; ++NumBruteForceTripCountsComputed; - return getConstant(ConstantInt::get(Type::Int32Ty, IterationNum)); + return getConstant(Type::Int32Ty, IterationNum); } // Compute the value of the PHI node for the next iteration. Constant *NextPHI = EvaluateExpression(BEValue, PHIVal); if (NextPHI == 0 || NextPHI == PHIVal) - return CouldNotCompute; // Couldn't evaluate or not making progress... + return getCouldNotCompute();// Couldn't evaluate or not making progress... PHIVal = NextPHI; } // Too many iterations were needed to evaluate. - return CouldNotCompute; + return getCouldNotCompute(); } /// getSCEVAtScope - Return a SCEV expression handle for the specified value @@ -2982,7 +3546,7 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) /// /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. -SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { +const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // FIXME: this should be turned into a virtual method on SCEV! if (isa(V)) return V; @@ -2999,7 +3563,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // to see if the loop that contains it has a known backedge-taken // count. If so, we may be able to force computation of the exit // value. - SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(LI); + const SCEV* BackedgeTakenCount = getBackedgeTakenCount(LI); if (const SCEVConstant *BTCC = dyn_cast(BackedgeTakenCount)) { // Okay, we know how many times the containing loop executes. If @@ -3008,7 +3572,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { Constant *RV = getConstantEvolutionLoopExitValue(PN, BTCC->getValue()->getValue(), LI); - if (RV) return getUnknown(RV); + if (RV) return getSCEV(RV); } } @@ -3022,7 +3586,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { std::pair::iterator, bool> Pair = Values.insert(std::make_pair(L, static_cast(0))); if (!Pair.second) - return Pair.first->second ? &*getUnknown(Pair.first->second) : V; + return Pair.first->second ? &*getSCEV(Pair.first->second) : V; std::vector Operands; Operands.reserve(I->getNumOperands()); @@ -3037,7 +3601,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (!isSCEVable(Op->getType())) return V; - SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L); + const SCEV* OpV = getSCEVAtScope(getSCEV(Op), L); if (const SCEVConstant *SC = dyn_cast(OpV)) { Constant *C = SC->getValue(); if (C->getType() != Op->getType()) @@ -3062,7 +3626,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { } } } - + Constant *C; if (const CmpInst *CI = dyn_cast(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), @@ -3071,7 +3635,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), &Operands[0], Operands.size()); Pair.first->second = C; - return getUnknown(C); + return getSCEV(C); } } @@ -3083,11 +3647,12 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { - SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + const SCEV* OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); if (OpAtScope != Comm->getOperand(i)) { // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. - SmallVector NewOps(Comm->op_begin(), Comm->op_begin()+i); + SmallVector NewOps(Comm->op_begin(), + Comm->op_begin()+i); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) { @@ -3110,8 +3675,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { } if (const SCEVUDivExpr *Div = dyn_cast(V)) { - SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); - SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); + const SCEV* LHS = getSCEVAtScope(Div->getLHS(), L); + const SCEV* RHS = getSCEVAtScope(Div->getRHS(), L); if (LHS == Div->getLHS() && RHS == Div->getRHS()) return Div; // must be loop invariant return getUDivExpr(LHS, RHS); @@ -3123,8 +3688,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (!L || !AddRec->getLoop()->contains(L->getHeader())) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. - SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); - if (BackedgeTakenCount == CouldNotCompute) return AddRec; + const SCEV* BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; // Then, evaluate the AddRec. return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); @@ -3133,21 +3698,21 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { } if (const SCEVZeroExtendExpr *Cast = dyn_cast(V)) { - SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L); + const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getZeroExtendExpr(Op, Cast->getType()); } if (const SCEVSignExtendExpr *Cast = dyn_cast(V)) { - SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L); + const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getSignExtendExpr(Op, Cast->getType()); } if (const SCEVTruncateExpr *Cast = dyn_cast(V)) { - SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L); + const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getTruncateExpr(Op, Cast->getType()); @@ -3159,7 +3724,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { /// getSCEVAtScope - This is a convenience function which does /// getSCEVAtScope(getSCEV(V), L). -SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { +const SCEV* ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } @@ -3172,7 +3737,7 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { /// A and B isn't important. /// /// If the equation does not have a solution, SCEVCouldNotCompute is returned. -static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, +static const SCEV* SolveLinEquationWithOverflow(const APInt &A, const APInt &B, ScalarEvolution &SE) { uint32_t BW = A.getBitWidth(); assert(BW == B.getBitWidth() && "Bit widths must be the same."); @@ -3215,7 +3780,7 @@ static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which /// might be the same) or two SCEVCouldNotCompute objects. /// -static std::pair +static std::pair SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast(AddRec->getOperand(0)); @@ -3235,7 +3800,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { APInt Two(BitWidth, 2); APInt Four(BitWidth, 4); - { + { using namespace APIntOps; const APInt& C = L; // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C @@ -3255,7 +3820,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // integer value or else APInt::sqrt() will assert. APInt SqrtVal(SqrtTerm.sqrt()); - // Compute the two solutions for the quadratic formula. + // Compute the two solutions for the quadratic formula. // The divisions must be performed as signed divisions. APInt NegB(-B); APInt TwoA( A << 1 ); @@ -3267,24 +3832,24 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); - return std::make_pair(SE.getConstant(Solution1), + return std::make_pair(SE.getConstant(Solution1), SE.getConstant(Solution2)); } // end APIntOps namespace } /// HowFarToZero - Return the number of times a backedge comparing the specified /// value to zero will execute. If not computable, return CouldNotCompute. -SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { +const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. if (C->getValue()->isZero()) return C; - return CouldNotCompute; // Otherwise it will loop infinitely. + return getCouldNotCompute(); // Otherwise it will loop infinitely. } const SCEVAddRecExpr *AddRec = dyn_cast(V); if (!AddRec || AddRec->getLoop() != L) - return CouldNotCompute; + return getCouldNotCompute(); if (AddRec->isAffine()) { // If this is an affine expression, the execution count of this branch is @@ -3299,8 +3864,10 @@ SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // where BW is the common bit width of Start and Step. // Get the initial value for the loop. - SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); - SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); + const SCEV *Start = getSCEVAtScope(AddRec->getStart(), + L->getParentLoop()); + const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), + L->getParentLoop()); if (const SCEVConstant *StepC = dyn_cast(Step)) { // For now we handle only constant steps. @@ -3320,7 +3887,7 @@ SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. - std::pair Roots = SolveQuadraticEquation(AddRec, + std::pair Roots = SolveQuadraticEquation(AddRec, *this); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); @@ -3331,7 +3898,7 @@ SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { #endif // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -3339,20 +3906,20 @@ SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // We can only use this value if the chrec ends up with an exact zero // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. - SCEVHandle Val = AddRec->evaluateAtIteration(R1, *this); + const SCEV* Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) return R1; // We found a quadratic root! } } } - return CouldNotCompute; + return getCouldNotCompute(); } /// HowFarToNonZero - Return the number of times a backedge checking the /// specified value for nonzero will execute. If not computable, return /// CouldNotCompute -SCEVHandle ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { +const SCEV* ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. @@ -3362,12 +3929,12 @@ SCEVHandle ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { if (const SCEVConstant *C = dyn_cast(V)) { if (!C->getValue()->isNullValue()) return getIntegerSCEV(0, C->getType()); - return CouldNotCompute; // Otherwise it will loop infinitely. + return getCouldNotCompute(); // Otherwise it will loop infinitely. } // We could implement others, but I really doubt anyone writes loops like // this, and if they did, they would already be constant folded. - return CouldNotCompute; + return getCouldNotCompute(); } /// getLoopPredecessor - If the given loop's header has exactly one unique @@ -3407,6 +3974,29 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { return 0; } +/// HasSameValue - SCEV structural equivalence is usually sufficient for +/// testing whether two expressions are equal, however for the purposes of +/// looking for a condition guarding a loop, it can be useful to be a little +/// more general, since a front-end may have replicated the controlling +/// expression. +/// +static bool HasSameValue(const SCEV* A, const SCEV* B) { + // Quick check to see if they are the same SCEV. + if (A == B) return true; + + // Otherwise, if they're both SCEVUnknown, it's possible that they hold + // two different instructions with the same value. Check for this case. + if (const SCEVUnknown *AU = dyn_cast(A)) + if (const SCEVUnknown *BU = dyn_cast(B)) + if (const Instruction *AI = dyn_cast(AU->getValue())) + if (const Instruction *BI = dyn_cast(BU->getValue())) + if (AI->isIdenticalTo(BI)) + return true; + + // Otherwise assume they may have a different value. + return false; +} + /// isLoopGuardedByCond - Test whether entry to the loop is protected by /// a conditional between LHS and RHS. This is used to help avoid max /// expressions in loop trip counts. @@ -3433,112 +4023,162 @@ bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, LoopEntryPredicate->isUnconditional()) continue; - ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition()); - if (!ICI) continue; + if (isNecessaryCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS, + LoopEntryPredicate->getSuccessor(0) != PredecessorDest)) + return true; + } - // Now that we found a conditional branch that dominates the loop, check to - // see if it is the comparison we are looking for. - Value *PreCondLHS = ICI->getOperand(0); - Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate Cond; - if (LoopEntryPredicate->getSuccessor(0) == PredecessorDest) - Cond = ICI->getPredicate(); - else - Cond = ICI->getInversePredicate(); + return false; +} - if (Cond == Pred) - ; // An exact match. - else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE) - ; // The actual condition is beyond sufficient. - else - // Check a few special cases. - switch (Cond) { - case ICmpInst::ICMP_UGT: - if (Pred == ICmpInst::ICMP_ULT) { - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_ULT; - break; - } - continue; - case ICmpInst::ICMP_SGT: - if (Pred == ICmpInst::ICMP_SLT) { - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_SLT; +/// isNecessaryCond - Test whether the given CondValue value is a condition +/// which is at least as strict as the one described by Pred, LHS, and RHS. +bool ScalarEvolution::isNecessaryCond(Value *CondValue, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + bool Inverse) { + // Recursivly handle And and Or conditions. + if (BinaryOperator *BO = dyn_cast(CondValue)) { + if (BO->getOpcode() == Instruction::And) { + if (!Inverse) + return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || + isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); + } else if (BO->getOpcode() == Instruction::Or) { + if (Inverse) + return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) || + isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse); + } + } + + ICmpInst *ICI = dyn_cast(CondValue); + if (!ICI) return false; + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + Value *PreCondLHS = ICI->getOperand(0); + Value *PreCondRHS = ICI->getOperand(1); + ICmpInst::Predicate Cond; + if (Inverse) + Cond = ICI->getInversePredicate(); + else + Cond = ICI->getPredicate(); + + if (Cond == Pred) + ; // An exact match. + else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE) + ; // The actual condition is beyond sufficient. + else + // Check a few special cases. + switch (Cond) { + case ICmpInst::ICMP_UGT: + if (Pred == ICmpInst::ICMP_ULT) { + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULT; + break; + } + return false; + case ICmpInst::ICMP_SGT: + if (Pred == ICmpInst::ICMP_SLT) { + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLT; + break; + } + return false; + case ICmpInst::ICMP_NE: + // Expressions like (x >u 0) are often canonicalized to (x != 0), + // so check for this case by checking if the NE is comparing against + // a minimum or maximum constant. + if (!ICmpInst::isTrueWhenEqual(Pred)) + if (ConstantInt *CI = dyn_cast(PreCondRHS)) { + const APInt &A = CI->getValue(); + switch (Pred) { + case ICmpInst::ICMP_SLT: + if (A.isMaxSignedValue()) break; + return false; + case ICmpInst::ICMP_SGT: + if (A.isMinSignedValue()) break; + return false; + case ICmpInst::ICMP_ULT: + if (A.isMaxValue()) break; + return false; + case ICmpInst::ICMP_UGT: + if (A.isMinValue()) break; + return false; + default: + return false; + } + Cond = ICmpInst::ICMP_NE; + // NE is symmetric but the original comparison may not be. Swap + // the operands if necessary so that they match below. + if (isa(LHS)) + std::swap(PreCondLHS, PreCondRHS); break; } - continue; - case ICmpInst::ICMP_NE: - // Expressions like (x >u 0) are often canonicalized to (x != 0), - // so check for this case by checking if the NE is comparing against - // a minimum or maximum constant. - if (!ICmpInst::isTrueWhenEqual(Pred)) - if (ConstantInt *CI = dyn_cast(PreCondRHS)) { - const APInt &A = CI->getValue(); - switch (Pred) { - case ICmpInst::ICMP_SLT: - if (A.isMaxSignedValue()) break; - continue; - case ICmpInst::ICMP_SGT: - if (A.isMinSignedValue()) break; - continue; - case ICmpInst::ICMP_ULT: - if (A.isMaxValue()) break; - continue; - case ICmpInst::ICMP_UGT: - if (A.isMinValue()) break; - continue; - default: - continue; - } - Cond = ICmpInst::ICMP_NE; - // NE is symmetric but the original comparison may not be. Swap - // the operands if necessary so that they match below. - if (isa(LHS)) - std::swap(PreCondLHS, PreCondRHS); - break; - } - continue; - default: - // We weren't able to reconcile the condition. - continue; - } + return false; + default: + // We weren't able to reconcile the condition. + return false; + } - if (!PreCondLHS->getType()->isInteger()) continue; + if (!PreCondLHS->getType()->isInteger()) return false; - SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS); - SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS); - if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) || - (LHS == getNotSCEV(PreCondRHSSCEV) && - RHS == getNotSCEV(PreCondLHSSCEV))) - return true; - } + const SCEV *PreCondLHSSCEV = getSCEV(PreCondLHS); + const SCEV *PreCondRHSSCEV = getSCEV(PreCondRHS); + return (HasSameValue(LHS, PreCondLHSSCEV) && + HasSameValue(RHS, PreCondRHSSCEV)) || + (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) && + HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV))); +} - return false; +/// getBECount - Subtract the end and start values and divide by the step, +/// rounding up, to get the number of times the backedge is executed. Return +/// CouldNotCompute if an intermediate computation overflows. +const SCEV* ScalarEvolution::getBECount(const SCEV* Start, + const SCEV* End, + const SCEV* Step) { + const Type *Ty = Start->getType(); + const SCEV* NegOne = getIntegerSCEV(-1, Ty); + const SCEV* Diff = getMinusSCEV(End, Start); + const SCEV* RoundUp = getAddExpr(Step, NegOne); + + // Add an adjustment to the difference between End and Start so that + // the division will effectively round up. + const SCEV* Add = getAddExpr(Diff, RoundUp); + + // Check Add for unsigned overflow. + // TODO: More sophisticated things could be done here. + const Type *WideTy = IntegerType::get(getTypeSizeInBits(Ty) + 1); + const SCEV* OperandExtendedAdd = + getAddExpr(getZeroExtendExpr(Diff, WideTy), + getZeroExtendExpr(RoundUp, WideTy)); + if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd) + return getCouldNotCompute(); + + return getUDivExpr(Add, Step); } /// HowManyLessThans - Return the number of times a backedge containing the /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. -ScalarEvolution::BackedgeTakenInfo ScalarEvolution:: -HowManyLessThans(const SCEV *LHS, const SCEV *RHS, - const Loop *L, bool isSigned) { +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, + const Loop *L, bool isSigned) { // Only handle: "ADDREC < LoopInvariant". - if (!RHS->isLoopInvariant(L)) return CouldNotCompute; + if (!RHS->isLoopInvariant(L)) return getCouldNotCompute(); const SCEVAddRecExpr *AddRec = dyn_cast(LHS); if (!AddRec || AddRec->getLoop() != L) - return CouldNotCompute; + return getCouldNotCompute(); if (AddRec->isAffine()) { // FORNOW: We only support unit strides. unsigned BitWidth = getTypeSizeInBits(AddRec->getType()); - SCEVHandle Step = AddRec->getStepRecurrence(*this); - SCEVHandle NegOne = getIntegerSCEV(-1, AddRec->getType()); + const SCEV* Step = AddRec->getStepRecurrence(*this); // TODO: handle non-constant strides. const SCEVConstant *CStep = dyn_cast(Step); if (!CStep || CStep->isZero()) - return CouldNotCompute; + return getCouldNotCompute(); if (CStep->isOne()) { // With unit stride, the iteration never steps past the limit value. } else if (CStep->getValue()->getValue().isStrictlyPositive()) { @@ -3549,19 +4189,19 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, APInt Max = APInt::getSignedMaxValue(BitWidth); if ((Max - CStep->getValue()->getValue()) .slt(CLimit->getValue()->getValue())) - return CouldNotCompute; + return getCouldNotCompute(); } else { APInt Max = APInt::getMaxValue(BitWidth); if ((Max - CStep->getValue()->getValue()) .ult(CLimit->getValue()->getValue())) - return CouldNotCompute; + return getCouldNotCompute(); } } else // TODO: handle non-constant limit values below. - return CouldNotCompute; + return getCouldNotCompute(); } else // TODO: handle negative strides below. - return CouldNotCompute; + return getCouldNotCompute(); // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant // m. So, we count the number of iterations in which {n,+,s} < m is true. @@ -3569,10 +4209,10 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // treat m-n as signed nor unsigned due to overflow possibility. // First, we get the value of the LHS in the first iteration: n - SCEVHandle Start = AddRec->getOperand(0); + const SCEV* Start = AddRec->getOperand(0); // Determine the minimum constant start value. - SCEVHandle MinStart = isa(Start) ? Start : + const SCEV *MinStart = isa(Start) ? Start : getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) : APInt::getMinValue(BitWidth)); @@ -3580,7 +4220,7 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // then we know that it will run exactly (m-n)/s times. Otherwise, we // only know that it will execute (max(m,n)-n)/s times. In both cases, // the division must round up. - SCEVHandle End = RHS; + const SCEV* End = RHS; if (!isLoopGuardedByCond(L, isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, getMinusSCEV(Start, Step), RHS)) @@ -3588,27 +4228,25 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : getUMaxExpr(RHS, Start); // Determine the maximum constant end value. - SCEVHandle MaxEnd = isa(End) ? End : - getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) : - APInt::getMaxValue(BitWidth)); + const SCEV* MaxEnd = + isa(End) ? End : + getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) + .ashr(GetMinSignBits(End) - 1) : + APInt::getMaxValue(BitWidth) + .lshr(GetMinLeadingZeros(End))); // Finally, we subtract these two values and divide, rounding up, to get // the number of times the backedge is executed. - SCEVHandle BECount = getUDivExpr(getAddExpr(getMinusSCEV(End, Start), - getAddExpr(Step, NegOne)), - Step); + const SCEV* BECount = getBECount(Start, End, Step); // The maximum backedge count is similar, except using the minimum start // value and the maximum end value. - SCEVHandle MaxBECount = getUDivExpr(getAddExpr(getMinusSCEV(MaxEnd, - MinStart), - getAddExpr(Step, NegOne)), - Step); + const SCEV* MaxBECount = getBECount(MinStart, MaxEnd, Step);; return BackedgeTakenInfo(BECount, MaxBECount); } - return CouldNotCompute; + return getCouldNotCompute(); } /// getNumIterationsInRange - Return the number of iterations of this loop that @@ -3616,17 +4254,17 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, /// this is that it returns the first iteration number where the value is not in /// the condition, thus computing the exit count. If the iteration count can't /// be computed, an instance of SCEVCouldNotCompute is returned. -SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, - ScalarEvolution &SE) const { +const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, + ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop. return SE.getCouldNotCompute(); // If the start is a non-zero constant, shift the range to simplify things. if (const SCEVConstant *SC = dyn_cast(getStart())) if (!SC->getValue()->isZero()) { - SmallVector Operands(op_begin(), op_end()); + SmallVector Operands(op_begin(), op_end()); Operands[0] = SE.getIntegerSCEV(0, SC->getType()); - SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop()); + const SCEV* Shifted = SE.getAddRecExpr(Operands, getLoop()); if (const SCEVAddRecExpr *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( @@ -3649,7 +4287,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // iteration exits. unsigned BitWidth = SE.getTypeSizeInBits(getType()); if (!Range.contains(APInt(BitWidth, 0))) - return SE.getConstant(ConstantInt::get(getType(),0)); + return SE.getIntegerSCEV(0, getType()); if (isAffine()) { // If this is an affine expression then we have this situation: @@ -3676,7 +4314,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // Ensure that the previous value is in the range. This is a sanity check. assert(Range.contains( - EvaluateConstantChrecAtConstant(this, + EvaluateConstantChrecAtConstant(this, ConstantInt::get(ExitVal - One), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); @@ -3685,19 +4323,19 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. - SmallVector NewOps(op_begin(), op_end()); + SmallVector NewOps(op_begin(), op_end()); NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); - SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); + const SCEV* NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); // Next, solve the constructed addrec - std::pair Roots = + std::pair Roots = SolveQuadraticEquation(cast(NewAddRec), SE); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -3796,7 +4434,7 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution() - : FunctionPass(&ID), CouldNotCompute(new SCEVCouldNotCompute()) { + : FunctionPass(&ID) { } bool ScalarEvolution::runOnFunction(Function &F) { @@ -3811,6 +4449,8 @@ void ScalarEvolution::releaseMemory() { BackedgeTakenCounts.clear(); ConstantEvolutionLoopExitValue.clear(); ValuesAtScopes.clear(); + UniqueSCEVs.clear(); + SCEVAllocator.Reset(); } void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { @@ -3841,6 +4481,15 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Unpredictable backedge-taken count. "; } + OS << "\n"; + OS << "Loop " << L->getHeader()->getName() << ": "; + + if (!isa(SE->getMaxBackedgeTakenCount(L))) { + OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); + } else { + OS << "Unpredictable max backedge-taken count. "; + } + OS << "\n"; } @@ -3858,13 +4507,20 @@ void ScalarEvolution::print(raw_ostream &OS, const Module* ) const { if (isSCEVable(I->getType())) { OS << *I; OS << " --> "; - SCEVHandle SV = SE.getSCEV(&*I); + const SCEV* SV = SE.getSCEV(&*I); SV->print(OS); - OS << "\t\t"; - if (const Loop *L = LI->getLoopFor((*I).getParent())) { - OS << "Exits: "; - SCEVHandle ExitValue = SE.getSCEVAtScope(&*I, L->getParentLoop()); + const Loop *L = LI->getLoopFor((*I).getParent()); + + const SCEV* AtUse = SE.getSCEVAtScope(SV, L); + if (AtUse != SV) { + OS << " --> "; + AtUse->print(OS); + } + + if (L) { + OS << "\t\t" "Exits: "; + const SCEV* ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); if (!ExitValue->isLoopInvariant(L)) { OS << "<>"; } else {