X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=067b83e466dd77dd0665a3c86d92ffa4062f1db1;hb=b169426272b85ce28a9a56d13154e61b158fc47a;hp=dc1129469d8b157d0b5cf331b1a20ae48a0ee415;hpb=2ceb40f3da2290d37e9a4faa35bd5199e5dc90d5;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index dc1129469d8..067b83e466d 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -66,6 +66,7 @@ #include "llvm/GlobalVariable.h" #include "llvm/Instructions.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Assembly/Writer.h" #include "llvm/Transforms/Scalar.h" @@ -83,9 +84,6 @@ #include using namespace llvm; -STATISTIC(NumBruteForceEvaluations, - "Number of brute force evaluations needed to " - "calculate high-order polynomial exit values"); STATISTIC(NumArrayLenItCounts, "Number of trip counts computed with array length"); STATISTIC(NumTripCountsComputed, @@ -115,6 +113,7 @@ char ScalarEvolution::ID = 0; SCEV::~SCEV() {} void SCEV::dump() const { print(cerr); + cerr << '\n'; } uint32_t SCEV::getBitWidth() const { @@ -207,6 +206,10 @@ SCEVTruncateExpr::~SCEVTruncateExpr() { SCEVTruncates->erase(std::make_pair(Op, Ty)); } +bool SCEVTruncateExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + return Op->dominates(BB, DT); +} + void SCEVTruncateExpr::print(std::ostream &OS) const { OS << "(truncate " << *Op << " to " << *Ty << ")"; } @@ -229,6 +232,10 @@ SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { SCEVZeroExtends->erase(std::make_pair(Op, Ty)); } +bool SCEVZeroExtendExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + return Op->dominates(BB, DT); +} + void SCEVZeroExtendExpr::print(std::ostream &OS) const { OS << "(zeroextend " << *Op << " to " << *Ty << ")"; } @@ -251,6 +258,10 @@ SCEVSignExtendExpr::~SCEVSignExtendExpr() { SCEVSignExtends->erase(std::make_pair(Op, Ty)); } +bool SCEVSignExtendExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + return Op->dominates(BB, DT); +} + void SCEVSignExtendExpr::print(std::ostream &OS) const { OS << "(signextend " << *Op << " to " << *Ty << ")"; } @@ -308,6 +319,14 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, return this; } +bool SCEVCommutativeExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + if (!getOperand(i)->dominates(BB, DT)) + return false; + } + return true; +} + // SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular // input. Don't use a SCEVHandle here, or else the object will never be @@ -319,6 +338,10 @@ SCEVUDivExpr::~SCEVUDivExpr() { SCEVUDivs->erase(std::make_pair(LHS, RHS)); } +bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + return LHS->dominates(BB, DT) && RHS->dominates(BB, DT); +} + void SCEVUDivExpr::print(std::ostream &OS) const { OS << "(" << *LHS << " /u " << *RHS << ")"; } @@ -339,6 +362,15 @@ SCEVAddRecExpr::~SCEVAddRecExpr() { Operands.end()))); } +bool SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + if (!getOperand(i)->dominates(BB, DT)) + return false; + } + return true; +} + + SCEVHandle SCEVAddRecExpr:: replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, const SCEVHandle &Conc, @@ -393,6 +425,12 @@ bool SCEVUnknown::isLoopInvariant(const Loop *L) const { return true; } +bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const { + if (Instruction *I = dyn_cast(getValue())) + return DT->dominates(I->getParent(), BB); + return true; +} + const Type *SCEVUnknown::getType() const { return V->getType(); } @@ -507,77 +545,115 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, } -/// BinomialCoefficient - Compute BC(It, K). The result is of the same type as -/// It. Assume, K > 0. +/// BinomialCoefficient - Compute BC(It, K). The result has width W. +// Assume, K > 0. static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, - ScalarEvolution &SE) { + ScalarEvolution &SE, + const IntegerType* ResultTy) { + // Handle the simplest case efficiently. + if (K == 1) + return SE.getTruncateOrZeroExtend(It, ResultTy); + // We are using the following formula for BC(It, K): // // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! // - // Suppose, W is the bitwidth of It (and of the return value as well). We - // must be prepared for overflow. Hence, we must assure that the result of - // our computation is equal to the accurate one modulo 2^W. Unfortunately, - // division isn't safe in modular arithmetic. This means we must perform the - // whole computation accurately and then truncate the result to W bits. + // Suppose, W is the bitwidth of the return value. We must be prepared for + // overflow. Hence, we must assure that the result of our computation is + // equal to the accurate one modulo 2^W. Unfortunately, division isn't + // safe in modular arithmetic. // - // The dividend of the formula is a multiplication of K integers of bitwidth - // W. K*W bits suffice to compute it accurately. + // However, this code doesn't use exactly that formula; the formula it uses + // is something like the following, where T is the number of factors of 2 in + // K! (i.e. trailing zeros in the binary representation of K!), and ^ is + // exponentiation: // - // FIXME: We assume the divisor can be accurately computed using 16-bit - // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In - // future we may use APInt to use the minimum number of bits necessary to - // compute it accurately. + // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T) // - // It is safe to use unsigned division here: the dividend is nonnegative and - // the divisor is positive. - - // Handle the simplest case efficiently. - if (K == 1) - return It; - - assert(K < 9 && "We cannot handle such long AddRecs yet."); - - unsigned DividendBits = K * It->getBitWidth(); - if (DividendBits > 256) + // This formula is trivially equivalent to the previous formula. However, + // this formula can be implemented much more efficiently. The trick is that + // K! / 2^T is odd, and exact division by an odd number *is* safe in modular + // arithmetic. To do exact division in modular arithmetic, all we have + // to do is multiply by the inverse. Therefore, this step can be done at + // width W. + // + // The next issue is how to safely do the division by 2^T. The way this + // is done is by doing the multiplication step at a width of at least W + T + // bits. This way, the bottom W+T bits of the product are accurate. Then, + // when we perform the division by 2^T (which is equivalent to a right shift + // by T), the bottom W bits are accurate. Extra bits are okay; they'll get + // truncated out after the division by 2^T. + // + // In comparison to just directly using the first formula, this technique + // is much more efficient; using the first formula requires W * K bits, + // but this formula less than W + K bits. Also, the first formula requires + // a division step, whereas this formula only requires multiplies and shifts. + // + // It doesn't matter whether the subtraction step is done in the calculation + // width or the input iteration count's width; if the subtraction overflows, + // the result must be zero anyway. We prefer here to do it in the width of + // the induction variable because it helps a lot for certain cases; CodeGen + // isn't smart enough to ignore the overflow, which leads to much less + // efficient code if the width of the subtraction is wider than the native + // register width. + // + // (It's possible to not widen at all by pulling out factors of 2 before + // the multiplication; for example, K=2 can be calculated as + // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires + // extra arithmetic, so it's not an obvious win, and it gets + // much more complicated for K > 3.) + + // Protection from insane SCEVs; this bound is conservative, + // but it probably doesn't matter. + if (K > 1000) return new SCEVCouldNotCompute(); - const IntegerType *DividendTy = IntegerType::get(DividendBits); - const SCEVHandle ExIt = SE.getZeroExtendExpr(It, DividendTy); - - // The final number of bits we need to perform the division is the maximum of - // dividend and divisor bitwidths. - const IntegerType *DivisionTy = - IntegerType::get(std::max(DividendBits, 16U)); - - // Compute K! We know K >= 2 here. - unsigned F = 2; - for (unsigned i = 3; i <= K; ++i) - F *= i; - APInt Divisor(DivisionTy->getBitWidth(), F); - - // Handle this case efficiently, it is common to have constant iteration - // counts while computing loop exit values. - if (SCEVConstant *SC = dyn_cast(ExIt)) { - const APInt& N = SC->getValue()->getValue(); - APInt Dividend(N.getBitWidth(), 1); - for (; K; --K) - Dividend *= N-(K-1); - if (DividendTy != DivisionTy) - Dividend = Dividend.zext(DivisionTy->getBitWidth()); - return SE.getConstant(Dividend.udiv(Divisor).trunc(It->getBitWidth())); + unsigned W = ResultTy->getBitWidth(); + + // Calculate K! / 2^T and T; we divide out the factors of two before + // multiplying for calculating K! / 2^T to avoid overflow. + // Other overflow doesn't matter because we only care about the bottom + // W bits of the result. + APInt OddFactorial(W, 1); + unsigned T = 1; + for (unsigned i = 3; i <= K; ++i) { + APInt Mult(W, i); + unsigned TwoFactors = Mult.countTrailingZeros(); + T += TwoFactors; + Mult = Mult.lshr(TwoFactors); + OddFactorial *= Mult; } - - SCEVHandle Dividend = ExIt; - for (unsigned i = 1; i != K; ++i) - Dividend = - SE.getMulExpr(Dividend, - SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy))); - if (DividendTy != DivisionTy) - Dividend = SE.getZeroExtendExpr(Dividend, DivisionTy); - return SE.getTruncateExpr(SE.getUDivExpr(Dividend, SE.getConstant(Divisor)), - It->getType()); + // We need at least W + T bits for the multiplication step + unsigned CalculationBits = W + T; + + // Calcuate 2^T, at width T+W. + APInt DivFactor = APInt(CalculationBits, 1).shl(T); + + // Calculate the multiplicative inverse of K! / 2^T; + // this multiplication factor will perform the exact division by + // K! / 2^T. + APInt Mod = APInt::getSignedMinValue(W+1); + APInt MultiplyFactor = OddFactorial.zext(W+1); + MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); + MultiplyFactor = MultiplyFactor.trunc(W); + + // Calculate the product, at width T+W + const IntegerType *CalculationTy = IntegerType::get(CalculationBits); + SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + for (unsigned i = 1; i != K; ++i) { + SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + Dividend = SE.getMulExpr(Dividend, + SE.getTruncateOrZeroExtend(S, CalculationTy)); + } + + // Divide by 2^T + SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + + // Truncate the result, and divide by K! / 2^T. + + return SE.getMulExpr(SE.getConstant(MultiplyFactor), + SE.getTruncateOrZeroExtend(DivResult, ResultTy)); } /// evaluateAtIteration - Return the value of this chain of recurrences at @@ -596,9 +672,12 @@ SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient. - SCEVHandle Val = SE.getMulExpr(getOperand(i), - BinomialCoefficient(It, i, SE)); - Result = SE.getAddExpr(Result, Val); + SCEVHandle Coeff = BinomialCoefficient(It, i, SE, + cast(getType())); + if (isa(Coeff)) + return Coeff; + + Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); } return Result; } @@ -676,6 +755,21 @@ SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, return getZeroExtendExpr(V, Ty); } +/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion +/// of the input value to the specified type. If the type must be +/// extended, it is sign extended. +SCEVHandle ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert(SrcTy->isInteger() && Ty->isInteger() && + "Cannot truncate or sign extend with non-integer arguments!"); + if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) + return V; // No conversion + if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) + return getTruncateExpr(V, Ty); + return getSignExtendExpr(V, Ty); +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -841,7 +935,7 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // NLI + LI + { Start,+,Step} --> NLI + { LI+Start,+,Step } + // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); std::vector AddRecOps(AddRec->op_begin(), AddRec->op_end()); @@ -989,7 +1083,7 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // NLI * LI * { Start,+,Step} --> NLI * { LI*Start,+,LI*Step } + // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} std::vector NewOps; NewOps.reserve(AddRec->getNumOperands()); if (LIOps.size() == 1) { @@ -1105,7 +1199,20 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, if (Operands.back()->isZero()) { Operands.pop_back(); - return getAddRecExpr(Operands, L); // { X,+,0 } --> X + return getAddRecExpr(Operands, L); // {X,+,0} --> X + } + + // Canonicalize nested AddRecs in by nesting them in order of loop depth. + if (SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { + const Loop* NestedLoop = NestedAR->getLoop(); + if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { + std::vector NestedOperands(NestedAR->op_begin(), + NestedAR->op_end()); + SCEVHandle NestedARHandle(NestedAR); + Operands[0] = NestedAR->getStart(); + NestedOperands[0] = getAddRecExpr(Operands, L); + return getAddRecExpr(NestedOperands, NestedLoop); + } } SCEVAddRecExpr *&Result = @@ -1312,9 +1419,9 @@ namespace { /// std::map Scalars; - /// IterationCounts - Cache the iteration count of the loops for this - /// function as they are computed. - std::map IterationCounts; + /// BackedgeTakenCounts - Cache the backedge-taken count of the loops for + /// this function as they are computed. + std::map BackedgeTakenCounts; /// ConstantEvolutionLoopExitValue - This map contains entries for all of /// the PHI instructions that we attempt to compute constant evolutions for. @@ -1342,6 +1449,7 @@ namespace { void setSCEV(Value *V, const SCEVHandle &H) { bool isNew = Scalars.insert(std::make_pair(V, H)).second; assert(isNew && "This entry already existed!"); + isNew = false; } @@ -1351,14 +1459,33 @@ namespace { SCEVHandle getSCEVAtScope(SCEV *V, const Loop *L); - /// hasLoopInvariantIterationCount - Return true if the specified loop has - /// an analyzable loop-invariant iteration count. - bool hasLoopInvariantIterationCount(const Loop *L); - - /// getIterationCount - If the specified loop has a predictable iteration - /// count, return it. Note that it is not valid to call this method on a - /// loop without a loop-invariant iteration count. - SCEVHandle getIterationCount(const Loop *L); + /// isLoopGuardedByCond - Test whether entry to the loop is protected by + /// a conditional between LHS and RHS. + bool isLoopGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, + SCEV *LHS, SCEV *RHS); + + /// hasLoopInvariantBackedgeTakenCount - Return true if the specified loop + /// has an analyzable loop-invariant backedge-taken count. + bool hasLoopInvariantBackedgeTakenCount(const Loop *L); + + /// forgetLoopBackedgeTakenCount - This method should be called by the + /// client when it has changed a loop in a way that may effect + /// ScalarEvolution's ability to compute a trip count, or if the loop + /// is deleted. + void forgetLoopBackedgeTakenCount(const Loop *L); + + /// getBackedgeTakenCount - If the specified loop has a predictable + /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute + /// object. The backedge-taken count is the number of times the loop header + /// will be branched to from within the loop. This is one less than the + /// trip count of the loop, since it doesn't count the first iteration, + /// when the header is branched to from outside the loop. + /// + /// Note that it is not valid to call this method on a loop without a + /// loop-invariant backedge-taken count (see + /// hasLoopInvariantBackedgeTakenCount). + /// + SCEVHandle getBackedgeTakenCount(const Loop *L); /// deleteValueFromRecords - This method should be called by the /// client before it removes a value from the program, to make sure @@ -1382,24 +1509,25 @@ namespace { const SCEVHandle &SymName, const SCEVHandle &NewVal); - /// ComputeIterationCount - Compute the number of times the specified loop - /// will iterate. - SCEVHandle ComputeIterationCount(const Loop *L); + /// ComputeBackedgeTakenCount - Compute the number of times the specified + /// loop will iterate. + SCEVHandle ComputeBackedgeTakenCount(const Loop *L); - /// ComputeLoadConstantCompareIterationCount - Given an exit condition of - /// 'icmp op load X, cst', try to see if we can compute the trip count. - SCEVHandle ComputeLoadConstantCompareIterationCount(LoadInst *LI, - Constant *RHS, - const Loop *L, - ICmpInst::Predicate p); + /// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition + /// of 'icmp op load X, cst', try to see if we can compute the trip count. + SCEVHandle + ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, + Constant *RHS, + const Loop *L, + ICmpInst::Predicate p); - /// ComputeIterationCountExhaustively - If the trip is known to execute a - /// constant number of times (the condition evolves only from constants), + /// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute + /// a constant number of times (the condition evolves only from constants), /// try to evaluate a few iterations of the loop until we get the exit /// condition gets a value of ExitWhen (true or false). If we cannot /// evaluate the trip count of the loop, return UnknownValue. - SCEVHandle ComputeIterationCountExhaustively(const Loop *L, Value *Cond, - bool ExitWhen); + SCEVHandle ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, + bool ExitWhen); /// HowFarToZero - Return the number of times a backedge comparing the /// specified value to zero will execute. If not computable, return @@ -1417,15 +1545,17 @@ namespace { SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned); - /// executesAtLeastOnce - Test whether entry to the loop is protected by - /// a conditional between LHS and RHS. - bool executesAtLeastOnce(const Loop *L, bool isSigned, SCEV *LHS, SCEV *RHS); + /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB + /// (which may not be an immediate predecessor) which has exactly one + /// successor from which BB is reachable, or null if no such block is + /// found. + BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB); /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. - Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, + Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L); }; } @@ -1775,10 +1905,10 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) - // -smax(-x, -y) == smin(x, y). - return SE.getNegativeSCEV(SE.getSMaxExpr( - SE.getNegativeSCEV(getSCEV(LHS)), - SE.getNegativeSCEV(getSCEV(RHS)))); + // ~smax(~x, ~y) == smin(x, y). + return SE.getNotSCEV(SE.getSMaxExpr( + SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -1811,14 +1941,22 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { // Iteration Count Computation Code // -/// getIterationCount - If the specified loop has a predictable iteration -/// count, return it. Note that it is not valid to call this method on a -/// loop without a loop-invariant iteration count. -SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) { - std::map::iterator I = IterationCounts.find(L); - if (I == IterationCounts.end()) { - SCEVHandle ItCount = ComputeIterationCount(L); - I = IterationCounts.insert(std::make_pair(L, ItCount)).first; +/// getBackedgeTakenCount - If the specified loop has a predictable +/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute +/// object. The backedge-taken count is the number of times the loop header +/// will be branched to from within the loop. This is one less than the +/// trip count of the loop, since it doesn't count the first iteration, +/// when the header is branched to from outside the loop. +/// +/// Note that it is not valid to call this method on a loop without a +/// loop-invariant backedge-taken count (see +/// hasLoopInvariantBackedgeTakenCount). +/// +SCEVHandle ScalarEvolutionsImpl::getBackedgeTakenCount(const Loop *L) { + std::map::iterator I = BackedgeTakenCounts.find(L); + if (I == BackedgeTakenCounts.end()) { + SCEVHandle ItCount = ComputeBackedgeTakenCount(L); + I = BackedgeTakenCounts.insert(std::make_pair(L, ItCount)).first; if (ItCount != UnknownValue) { assert(ItCount->isLoopInvariant(L) && "Computed trip count isn't loop invariant for loop!"); @@ -1831,9 +1969,17 @@ SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) { return I->second; } -/// ComputeIterationCount - Compute the number of times the specified loop -/// will iterate. -SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { +/// forgetLoopBackedgeTakenCount - This method should be called by the +/// client when it has changed a loop in a way that may effect +/// ScalarEvolution's ability to compute a trip count, or if the loop +/// is deleted. +void ScalarEvolutionsImpl::forgetLoopBackedgeTakenCount(const Loop *L) { + BackedgeTakenCounts.erase(L); +} + +/// ComputeBackedgeTakenCount - Compute the number of times the backedge +/// of the specified loop will execute. +SCEVHandle ScalarEvolutionsImpl::ComputeBackedgeTakenCount(const Loop *L) { // If the loop has a non-one exit block count, we can't analyze it. SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); @@ -1883,7 +2029,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // Note that ICmpInst deals with pointer comparisons too so we must check // the type of the operand. if (ExitCond == 0 || isa(ExitCond->getOperand(0)->getType())) - return ComputeIterationCountExhaustively(L, ExitBr->getCondition(), + return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(), ExitBr->getSuccessor(0) == ExitBlock); // If the condition was exit on true, convert the condition to exit on false @@ -1897,7 +2043,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { SCEVHandle ItCnt = - ComputeLoadConstantCompareIterationCount(LI, RHS, L, Cond); + ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond); if (!isa(ItCnt)) return ItCnt; } @@ -1912,8 +2058,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { // At this point, we would like to compute how many iterations of the // loop the predicate will return true for these inputs. - if (isa(LHS) && !isa(RHS)) { - // If there is a constant, force it into the RHS. + if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) { + // If there is a loop-invariant, force it into the RHS. std::swap(LHS, RHS); Cond = ICmpInst::getSwappedPredicate(Cond); } @@ -1962,8 +2108,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { break; } case ICmpInst::ICMP_SGT: { - SCEVHandle TC = HowManyLessThans(SE.getNegativeSCEV(LHS), - SE.getNegativeSCEV(RHS), L, true); + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, true); if (!isa(TC)) return TC; break; } @@ -1980,7 +2126,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { } default: #if 0 - cerr << "ComputeIterationCount "; + cerr << "ComputeBackedgeTakenCount "; if (ExitCond->getOperand(0)->getType()->isUnsigned()) cerr << "[unsigned] "; cerr << *LHS << " " @@ -1989,8 +2135,9 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { #endif break; } - return ComputeIterationCountExhaustively(L, ExitCond, - ExitBr->getSuccessor(0) == ExitBlock); + return + ComputeBackedgeTakenCountExhaustively(L, ExitCond, + ExitBr->getSuccessor(0) == ExitBlock); } static ConstantInt * @@ -2037,12 +2184,13 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, return Init; } -/// ComputeLoadConstantCompareIterationCount - Given an exit condition of -/// 'icmp op load X, cst', try to see if we can compute the trip count. +/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of +/// 'icmp op load X, cst', try to see if we can compute the backedge +/// execution count. SCEVHandle ScalarEvolutionsImpl:: -ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, - const Loop *L, - ICmpInst::Predicate predicate) { +ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, + const Loop *L, + ICmpInst::Predicate predicate) { if (LI->isVolatile()) return UnknownValue; // Check to see if the loaded pointer is a getelementptr of a global. @@ -2199,13 +2347,13 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { /// constant number of times, and the PHI node is just a recurrence /// involving constants, fold it. Constant *ScalarEvolutionsImpl:: -getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){ +getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){ std::map::iterator I = ConstantEvolutionLoopExitValue.find(PN); if (I != ConstantEvolutionLoopExitValue.end()) return I->second; - if (Its.ugt(APInt(Its.getBitWidth(),MaxBruteForceIterations))) + if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations))) return ConstantEvolutionLoopExitValue[PN] = 0; // Not going to evaluate it. Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; @@ -2225,10 +2373,10 @@ getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){ return RetVal = 0; // Not derived from same PHI. // Execute the loop symbolically to determine the exit value. - if (Its.getActiveBits() >= 32) + if (BEs.getActiveBits() >= 32) return RetVal = 0; // More than 2^32-1 iterations?? Not doing it! - unsigned NumIterations = Its.getZExtValue(); // must be in range + unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; for (Constant *PHIVal = StartCST; ; ++IterationNum) { if (IterationNum == NumIterations) @@ -2244,13 +2392,13 @@ getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){ } } -/// ComputeIterationCountExhaustively - If the trip is known to execute a +/// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute a /// constant number of times (the condition evolves only from constants), /// try to evaluate a few iterations of the loop until we get the exit /// condition gets a value of ExitWhen (true or false). If we cannot /// evaluate the trip count of the loop, return UnknownValue. SCEVHandle ScalarEvolutionsImpl:: -ComputeIterationCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { +ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); if (PN == 0) return UnknownValue; @@ -2313,15 +2461,17 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (PHINode *PN = dyn_cast(I)) if (PN->getParent() == LI->getHeader()) { // Okay, there is no closed form solution for the PHI node. Check - // to see if the loop that contains it has a known iteration count. - // If so, we may be able to force computation of the exit value. - SCEVHandle IterationCount = getIterationCount(LI); - if (SCEVConstant *ICC = dyn_cast(IterationCount)) { + // to see if the loop that contains it has a known backedge-taken + // count. If so, we may be able to force computation of the exit + // value. + SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(LI); + if (SCEVConstant *BTCC = + dyn_cast(BackedgeTakenCount)) { // Okay, we know how many times the containing loop executes. If // this is a constant evolving PHI node, get the final value at // the specified iteration number. Constant *RV = getConstantEvolutionLoopExitValue(PN, - ICC->getValue()->getValue(), + BTCC->getValue()->getValue(), LI); if (RV) return SE.getUnknown(RV); } @@ -2425,20 +2575,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (!L || !AddRec->getLoop()->contains(L->getHeader())) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. - SCEVHandle IterationCount = getIterationCount(AddRec->getLoop()); - if (IterationCount == UnknownValue) return UnknownValue; - IterationCount = SE.getTruncateOrZeroExtend(IterationCount, - AddRec->getType()); - - // If the value is affine, simplify the expression evaluation to just - // Start + Step*IterationCount. - if (AddRec->isAffine()) - return SE.getAddExpr(AddRec->getStart(), - SE.getMulExpr(IterationCount, - AddRec->getOperand(1))); - - // Otherwise, evaluate it the hard way. - return AddRec->evaluateAtIteration(IterationCount, SE); + SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + if (BackedgeTakenCount == UnknownValue) return UnknownValue; + + // Then, evaluate the AddRec. + return AddRec->evaluateAtIteration(BackedgeTakenCount, SE); } return UnknownValue; } @@ -2543,6 +2684,11 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // The divisions must be performed as signed divisions. APInt NegB(-B); APInt TwoA( A << 1 ); + if (TwoA.isMinValue()) { + SCEV *CNC = new SCEVCouldNotCompute(); + return std::make_pair(CNC, CNC); + } + ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); @@ -2649,68 +2795,130 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { return UnknownValue; } -/// executesAtLeastOnce - Test whether entry to the loop is protected by +/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB +/// (which may not be an immediate predecessor) which has exactly one +/// successor from which BB is reachable, or null if no such block is +/// found. +/// +BasicBlock * +ScalarEvolutionsImpl::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { + // If the block has a unique predecessor, the predecessor must have + // no other successors from which BB is reachable. + if (BasicBlock *Pred = BB->getSinglePredecessor()) + return Pred; + + // A loop's header is defined to be a block that dominates the loop. + // If the loop has a preheader, it must be a block that has exactly + // one successor that can reach BB. This is slightly more strict + // than necessary, but works if critical edges are split. + if (Loop *L = LI.getLoopFor(BB)) + return L->getLoopPreheader(); + + return 0; +} + +/// isLoopGuardedByCond - Test whether entry to the loop is protected by /// a conditional between LHS and RHS. -bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned, +bool ScalarEvolutionsImpl::isLoopGuardedByCond(const Loop *L, + ICmpInst::Predicate Pred, SCEV *LHS, SCEV *RHS) { BasicBlock *Preheader = L->getLoopPreheader(); BasicBlock *PreheaderDest = L->getHeader(); - if (Preheader == 0) return false; - - BranchInst *LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return false; - - // This might be a critical edge broken out. If the loop preheader ends in - // an unconditional branch to the loop, check to see if the preheader has a - // single predecessor, and if so, look for its terminator. - while (LoopEntryPredicate->isUnconditional()) { - PreheaderDest = Preheader; - Preheader = Preheader->getSinglePredecessor(); - if (!Preheader) return false; // Multiple preds. - - LoopEntryPredicate = + + // Starting at the preheader, climb up the predecessor chain, as long as + // there are predecessors that can be found that have unique successors + // leading to the original header. + for (; Preheader; + PreheaderDest = Preheader, + Preheader = getPredecessorWithUniqueSuccessorForBB(Preheader)) { + + BranchInst *LoopEntryPredicate = dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return false; - } + if (!LoopEntryPredicate || + LoopEntryPredicate->isUnconditional()) + continue; + + ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition()); + if (!ICI) continue; + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + Value *PreCondLHS = ICI->getOperand(0); + Value *PreCondRHS = ICI->getOperand(1); + ICmpInst::Predicate Cond; + if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) + Cond = ICI->getPredicate(); + else + Cond = ICI->getInversePredicate(); - ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition()); - if (!ICI) return false; + if (Cond == Pred) + ; // An exact match. + else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE) + ; // The actual condition is beyond sufficient. + else + // Check a few special cases. + switch (Cond) { + case ICmpInst::ICMP_UGT: + if (Pred == ICmpInst::ICMP_ULT) { + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULT; + break; + } + continue; + case ICmpInst::ICMP_SGT: + if (Pred == ICmpInst::ICMP_SLT) { + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLT; + break; + } + continue; + case ICmpInst::ICMP_NE: + // Expressions like (x >u 0) are often canonicalized to (x != 0), + // so check for this case by checking if the NE is comparing against + // a minimum or maximum constant. + if (!ICmpInst::isTrueWhenEqual(Pred)) + if (ConstantInt *CI = dyn_cast(PreCondRHS)) { + const APInt &A = CI->getValue(); + switch (Pred) { + case ICmpInst::ICMP_SLT: + if (A.isMaxSignedValue()) break; + continue; + case ICmpInst::ICMP_SGT: + if (A.isMinSignedValue()) break; + continue; + case ICmpInst::ICMP_ULT: + if (A.isMaxValue()) break; + continue; + case ICmpInst::ICMP_UGT: + if (A.isMinValue()) break; + continue; + default: + continue; + } + Cond = ICmpInst::ICMP_NE; + // NE is symmetric but the original comparison may not be. Swap + // the operands if necessary so that they match below. + if (isa(LHS)) + std::swap(PreCondLHS, PreCondRHS); + break; + } + continue; + default: + // We weren't able to reconcile the condition. + continue; + } - // Now that we found a conditional branch that dominates the loop, check to - // see if it is the comparison we are looking for. - Value *PreCondLHS = ICI->getOperand(0); - Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate Cond; - if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) - Cond = ICI->getPredicate(); - else - Cond = ICI->getInversePredicate(); + if (!PreCondLHS->getType()->isInteger()) continue; - switch (Cond) { - case ICmpInst::ICMP_UGT: - if (isSigned) return false; - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_ULT; - break; - case ICmpInst::ICMP_SGT: - if (!isSigned) return false; - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_SLT; - break; - case ICmpInst::ICMP_ULT: - if (isSigned) return false; - break; - case ICmpInst::ICMP_SLT: - if (!isSigned) return false; - break; - default: - return false; + SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS); + SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS); + if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) || + (LHS == SE.getNotSCEV(PreCondRHSSCEV) && + RHS == SE.getNotSCEV(PreCondLHSSCEV))) + return true; } - if (!PreCondLHS->getType()->isInteger()) return false; - - return LHS == getSCEV(PreCondLHS) && RHS == getSCEV(PreCondRHS); + return false; } /// HowManyLessThans - Return the number of times a backedge containing the @@ -2739,7 +2947,8 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { // First, we get the value of the LHS in the first iteration: n SCEVHandle Start = AddRec->getOperand(0); - if (executesAtLeastOnce(L, isSigned, + if (isLoopGuardedByCond(L, + isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, SE.getMinusSCEV(AddRec->getOperand(0), One), RHS)) { // Since we know that the condition is true in order to enter the loop, // we know that it will run exactly m-n times. @@ -2875,27 +3084,6 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, } } - // Fallback, if this is a general polynomial, figure out the progression - // through brute force: evaluate until we find an iteration that fails the - // test. This is likely to be slow, but getting an accurate trip count is - // incredibly important, we will be able to simplify the exit test a lot, and - // we are almost guaranteed to get a trip count in this case. - ConstantInt *TestVal = ConstantInt::get(getType(), 0); - ConstantInt *EndVal = TestVal; // Stop when we wrap around. - do { - ++NumBruteForceEvaluations; - SCEVHandle Val = evaluateAtIteration(SE.getConstant(TestVal), SE); - if (!isa(Val)) // This shouldn't happen. - return new SCEVCouldNotCompute(); - - // Check to see if we found the value! - if (!Range.contains(cast(Val)->getValue()->getValue())) - return SE.getConstant(TestVal); - - // Increment to test the next index. - TestVal = ConstantInt::get(TestVal->getValue()+1); - } while (TestVal != EndVal); - return new SCEVCouldNotCompute(); } @@ -2938,12 +3126,23 @@ void ScalarEvolution::setSCEV(Value *V, const SCEVHandle &H) { } -SCEVHandle ScalarEvolution::getIterationCount(const Loop *L) const { - return ((ScalarEvolutionsImpl*)Impl)->getIterationCount(L); +bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, + ICmpInst::Predicate Pred, + SCEV *LHS, SCEV *RHS) { + return ((ScalarEvolutionsImpl*)Impl)->isLoopGuardedByCond(L, Pred, + LHS, RHS); +} + +SCEVHandle ScalarEvolution::getBackedgeTakenCount(const Loop *L) const { + return ((ScalarEvolutionsImpl*)Impl)->getBackedgeTakenCount(L); +} + +bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) const { + return !isa(getBackedgeTakenCount(L)); } -bool ScalarEvolution::hasLoopInvariantIterationCount(const Loop *L) const { - return !isa(getIterationCount(L)); +void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) { + return ((ScalarEvolutionsImpl*)Impl)->forgetLoopBackedgeTakenCount(L); } SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const { @@ -2967,10 +3166,10 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, if (ExitBlocks.size() != 1) OS << " "; - if (SE->hasLoopInvariantIterationCount(L)) { - OS << *SE->getIterationCount(L) << " iterations! "; + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L); } else { - OS << "Unpredictable iteration count. "; + OS << "Unpredictable backedge-taken count. "; } OS << "\n"; @@ -2984,7 +3183,7 @@ void ScalarEvolution::print(std::ostream &OS, const Module* ) const { for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) if (I->getType()->isInteger()) { OS << *I; - OS << " --> "; + OS << " --> "; SCEVHandle SV = getSCEV(&*I); SV->print(OS); OS << "\t\t";