X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=d615c752b0444f4c91ec9284fe109ddab3736fad;hb=73b43b9b549a75fb0015c825df68abd95705a67c;hp=cc6cde2ba212bc2985fe940d8762e4ae53f13443;hpb=83bb0055fdac3c6234c4178cd429e6a917d06c4e;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index cc6cde2ba21..d615c752b04 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -2,8 +2,8 @@ // // The LLVM Compiler Infrastructure // -// This file was developed by the LLVM research group and is distributed under -// the University of Illinois Open Source License. See LICENSE.TXT for details. +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // @@ -95,16 +95,14 @@ STATISTIC(NumTripCountsNotComputed, STATISTIC(NumBruteForceTripCountsComputed, "Number of loops with trip counts computed by force"); -cl::opt +static cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant derived loop"), cl::init(100)); -namespace { - RegisterPass - R("scalar-evolution", "Scalar Evolution Analysis"); -} +static RegisterPass +R("scalar-evolution", "Scalar Evolution Analysis", false, true); char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// @@ -134,6 +132,12 @@ uint32_t SCEV::getBitWidth() const { return 0; } +bool SCEV::isZero() const { + if (const SCEVConstant *SC = dyn_cast(this)) + return SC->getValue()->isZero(); + return false; +} + SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} @@ -318,6 +322,10 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, return SE.getAddExpr(NewOps); else if (isa(this)) return SE.getMulExpr(NewOps); + else if (isa(this)) + return SE.getSMaxExpr(NewOps); + else if (isa(this)) + return SE.getUMaxExpr(NewOps); else assert(0 && "Unknown commutative expr!"); } @@ -326,21 +334,21 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, } -// SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular +// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular // input. Don't use a SCEVHandle here, or else the object will never be // deleted! static ManagedStatic, - SCEVSDivExpr*> > SCEVSDivs; + SCEVUDivExpr*> > SCEVUDivs; -SCEVSDivExpr::~SCEVSDivExpr() { - SCEVSDivs->erase(std::make_pair(LHS, RHS)); +SCEVUDivExpr::~SCEVUDivExpr() { + SCEVUDivs->erase(std::make_pair(LHS, RHS)); } -void SCEVSDivExpr::print(std::ostream &OS) const { - OS << "(" << *LHS << " /s " << *RHS << ")"; +void SCEVUDivExpr::print(std::ostream &OS) const { + OS << "(" << *LHS << " /u " << *RHS << ")"; } -const Type *SCEVSDivExpr::getType() const { +const Type *SCEVUDivExpr::getType() const { return LHS->getType(); } @@ -427,7 +435,7 @@ namespace { /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. struct VISIBILITY_HIDDEN SCEVComplexityCompare { - bool operator()(SCEV *LHS, SCEV *RHS) { + bool operator()(const SCEV *LHS, const SCEV *RHS) const { return LHS->getSCEVType() < RHS->getSCEVType(); } }; @@ -448,7 +456,7 @@ static void GroupByComplexity(std::vector &Ops) { if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - if (Ops[0]->getSCEVType() > Ops[1]->getSCEVType()) + if (SCEVComplexityCompare()(Ops[1], Ops[0])) std::swap(Ops[0], Ops[1]); return; } @@ -490,35 +498,29 @@ SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { if (Val == 0) C = Constant::getNullValue(Ty); else if (Ty->isFloatingPoint()) - C = ConstantFP::get(Ty, APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : - APFloat::IEEEdouble, Val)); + C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); else C = ConstantInt::get(Ty, Val); return getUnknown(C); } -/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. -static SCEVHandle getTruncateOrZeroExtend(const SCEVHandle &V, const Type *Ty, - ScalarEvolution &SE) { - const Type *SrcTy = V->getType(); - assert(SrcTy->isInteger() && Ty->isInteger() && - "Cannot truncate or zero extend with non-integer arguments!"); - if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) - return V; // No conversion - if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) - return SE.getTruncateExpr(V, Ty); - return SE.getZeroExtendExpr(V, Ty); -} - /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V /// SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { if (SCEVConstant *VC = dyn_cast(V)) return getUnknown(ConstantExpr::getNeg(VC->getValue())); - return getMulExpr(V, getIntegerSCEV(-1, V->getType())); + return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(V->getType()))); +} + +/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V +SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) { + if (SCEVConstant *VC = dyn_cast(V)) + return getUnknown(ConstantExpr::getNot(VC->getValue())); + + SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(V->getType())); + return getMinusSCEV(AllOnes, V); } /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. @@ -530,57 +532,116 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, } -/// PartialFact - Compute V!/(V-NumSteps)! -static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps, - ScalarEvolution &SE) { +/// BinomialCoefficient - Compute BC(It, K). The result is of the same type as +/// It. Assume, K > 0. +static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, + ScalarEvolution &SE) { + // We are using the following formula for BC(It, K): + // + // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! + // + // Suppose, W is the bitwidth of It (and of the return value as well). We + // must be prepared for overflow. Hence, we must assure that the result of + // our computation is equal to the accurate one modulo 2^W. Unfortunately, + // division isn't safe in modular arithmetic. This means we must perform the + // whole computation accurately and then truncate the result to W bits. + // + // The dividend of the formula is a multiplication of K integers of bitwidth + // W. K*W bits suffice to compute it accurately. + // + // FIXME: We assume the divisor can be accurately computed using 16-bit + // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In + // future we may use APInt to use the minimum number of bits necessary to + // compute it accurately. + // + // It is safe to use unsigned division here: the dividend is nonnegative and + // the divisor is positive. + + // Handle the simplest case efficiently. + if (K == 1) + return It; + + assert(K < 9 && "We cannot handle such long AddRecs yet."); + + // FIXME: A temporary hack to remove in future. Arbitrary precision integers + // aren't supported by the code generator yet. For the dividend, the bitwidth + // we use is the smallest power of 2 greater or equal to K*W and less or equal + // to 64. Note that setting the upper bound for bitwidth may still lead to + // miscompilation in some cases. + unsigned DividendBits = 1U << Log2_32_Ceil(K * It->getBitWidth()); + if (DividendBits > 64) + DividendBits = 64; +#if 0 // Waiting for the APInt support in the code generator... + unsigned DividendBits = K * It->getBitWidth(); +#endif + + const IntegerType *DividendTy = IntegerType::get(DividendBits); + const SCEVHandle ExIt = SE.getTruncateOrZeroExtend(It, DividendTy); + + // The final number of bits we need to perform the division is the maximum of + // dividend and divisor bitwidths. + const IntegerType *DivisionTy = + IntegerType::get(std::max(DividendBits, 16U)); + + // Compute K! We know K >= 2 here. + unsigned F = 2; + for (unsigned i = 3; i <= K; ++i) + F *= i; + APInt Divisor(DivisionTy->getBitWidth(), F); + // Handle this case efficiently, it is common to have constant iteration // counts while computing loop exit values. - if (SCEVConstant *SC = dyn_cast(V)) { - const APInt& Val = SC->getValue()->getValue(); - APInt Result(Val.getBitWidth(), 1); - for (; NumSteps; --NumSteps) - Result *= Val-(NumSteps-1); + if (SCEVConstant *SC = dyn_cast(ExIt)) { + const APInt& N = SC->getValue()->getValue(); + APInt Dividend(N.getBitWidth(), 1); + for (; K; --K) + Dividend *= N-(K-1); + if (DividendTy != DivisionTy) + Dividend = Dividend.zext(DivisionTy->getBitWidth()); + + APInt Result = Dividend.udiv(Divisor); + if (Result.getBitWidth() != It->getBitWidth()) + Result = Result.trunc(It->getBitWidth()); + return SE.getConstant(Result); } + + SCEVHandle Dividend = ExIt; + for (unsigned i = 1; i != K; ++i) + Dividend = + SE.getMulExpr(Dividend, + SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy))); - const Type *Ty = V->getType(); - if (NumSteps == 0) - return SE.getIntegerSCEV(1, Ty); - - SCEVHandle Result = V; - for (unsigned i = 1; i != NumSteps; ++i) - Result = SE.getMulExpr(Result, SE.getMinusSCEV(V, - SE.getIntegerSCEV(i, Ty))); - return Result; + return SE.getTruncateOrZeroExtend( + SE.getUDivExpr( + SE.getTruncateOrZeroExtend(Dividend, DivisionTy), + SE.getConstant(Divisor) + ), It->getType()); } - /// evaluateAtIteration - Return the value of this chain of recurrences at /// the specified iteration number. We can evaluate this recurrence by /// multiplying each element in the chain by the binomial coefficient /// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as: /// -/// A*choose(It, 0) + B*choose(It, 1) + C*choose(It, 2) + D*choose(It, 3) +/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// -/// FIXME/VERIFY: I don't trust that this is correct in the face of overflow. -/// Is the binomial equation safe using modular arithmetic?? +/// where BC(It, k) stands for binomial coefficient. /// SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, ScalarEvolution &SE) const { SCEVHandle Result = getStart(); - int Divisor = 1; - const Type *Ty = It->getType(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { - SCEVHandle BC = PartialFact(It, i, SE); - Divisor *= i; - SCEVHandle Val = SE.getSDivExpr(SE.getMulExpr(BC, getOperand(i)), - SE.getIntegerSCEV(Divisor,Ty)); + // The computation is correct in the face of overflow provided that the + // multiplication is performed _after_ the evaluation of the binomial + // coefficient. + SCEVHandle Val = SE.getMulExpr(getOperand(i), + BinomialCoefficient(It, i, SE)); Result = SE.getAddExpr(Result, Val); } return Result; } - //===----------------------------------------------------------------------===// // SCEV Expression folder implementations //===----------------------------------------------------------------------===// @@ -639,6 +700,21 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type * return Result; } +/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion +/// of the input value to the specified type. If the type must be +/// extended, it is zero extended. +SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert(SrcTy->isInteger() && Ty->isInteger() && + "Cannot truncate or zero extend with non-integer arguments!"); + if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) + return V; // No conversion + if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) + return getTruncateExpr(V, Ty); + return getZeroExtendExpr(V, Ty); +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -654,19 +730,12 @@ SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(Idx < Ops.size()); while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() + - RHSC->getValue()->getValue()); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = getConstant(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant zero being added, strip it off. @@ -895,19 +964,12 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { ++Idx; while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Constant *Fold = ConstantInt::get(LHSC->getValue()->getValue() * - RHSC->getValue()->getValue()); - if (ConstantInt *CI = dyn_cast(Fold)) { - Ops[0] = getConstant(CI); - Ops.erase(Ops.begin()+1); // Erase the folded element - if (Ops.size() == 1) return Ops[0]; - LHSC = cast(Ops[0]); - } else { - // If we couldn't fold the expression, move to the next constant. Note - // that this is impossible to happen in practice because we always - // constant fold constant ints to constant ints. - ++Idx; - } + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); } // If we are left with a constant one being multiplied, strip it off. @@ -1037,24 +1099,22 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector &Ops) { return Result; } -SCEVHandle ScalarEvolution::getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { +SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) { if (SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X sdiv 1 --> x - if (RHSC->getValue()->isAllOnesValue()) - return getNegativeSCEV(LHS); // X sdiv -1 --> -x + return LHS; // X udiv 1 --> x if (SCEVConstant *LHSC = dyn_cast(LHS)) { Constant *LHSCV = LHSC->getValue(); Constant *RHSCV = RHSC->getValue(); - return getUnknown(ConstantExpr::getSDiv(LHSCV, RHSCV)); + return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV)); } } // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow. - SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS); + SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; + if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); return Result; } @@ -1082,11 +1142,10 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, const Loop *L) { if (Operands.size() == 1) return Operands[0]; - if (SCEVConstant *StepC = dyn_cast(Operands.back())) - if (StepC->getValue()->isZero()) { - Operands.pop_back(); - return getAddRecExpr(Operands, L); // { X,+,0 } --> X - } + if (Operands.back()->isZero()) { + Operands.pop_back(); + return getAddRecExpr(Operands, L); // { X,+,0 } --> X + } SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, std::vector(Operands.begin(), @@ -1095,6 +1154,166 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, return Result; } +SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getSMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getSMaxExpr(std::vector Ops) { + assert(!Ops.empty() && "Cannot get empty smax!"); + if (Ops.size() == 1) return Ops[0]; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (SCEVConstant *LHSC = dyn_cast(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + APIntOps::smax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); + } + + // If we are left with a constant -inf, strip it off. + if (cast(Ops[0])->getValue()->isMinValue(true)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first SMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) + ++Idx; + + // Check to see if one of the operands is an SMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedSMax = false; + while (SCEVSMaxExpr *SMax = dyn_cast(Ops[Idx])) { + Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedSMax = true; + } + + if (DeletedSMax) + return getSMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X smax Y smax Y --> X smax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced smax down to nothing!"); + + // Okay, it looks like we really DO need an smax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVSMaxExpr(Ops); + return Result; +} + +SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getUMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getUMaxExpr(std::vector Ops) { + assert(!Ops.empty() && "Cannot get empty umax!"); + if (Ops.size() == 1) return Ops[0]; + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (SCEVConstant *LHSC = dyn_cast(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + APIntOps::umax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast(Ops[0]); + } + + // If we are left with a constant zero, strip it off. + if (cast(Ops[0])->getValue()->isMinValue(false)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first UMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) + ++Idx; + + // Check to see if one of the operands is a UMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedUMax = false; + while (SCEVUMaxExpr *UMax = dyn_cast(Ops[Idx])) { + Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedUMax = true; + } + + if (DeletedUMax) + return getUMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X umax Y umax Y --> X umax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced umax down to nothing!"); + + // Okay, it looks like we really DO need a umax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVUMaxExpr(Ops); + return Result; +} + SCEVHandle ScalarEvolution::getUnknown(Value *V) { if (ConstantInt *CI = dyn_cast(V)) return getConstant(CI); @@ -1416,11 +1635,7 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { /// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. static uint32_t GetMinTrailingZeros(SCEVHandle S) { if (SCEVConstant *C = dyn_cast(S)) - // APInt::countTrailingZeros() returns the number of trailing zeros in its - // internal representation, which length may be greater than the represented - // value bitwidth. This is why we use a min operation here. - return std::min(C->getValue()->getValue().countTrailingZeros(), - C->getBitWidth()); + return C->getValue()->getValue().countTrailingZeros(); if (SCEVTruncateExpr *T = dyn_cast(S)) return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth()); @@ -1462,7 +1677,23 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) { return MinOpRes; } - // SCEVSDivExpr, SCEVUnknown + if (SCEVSMaxExpr *M = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + + if (SCEVUMaxExpr *M = dyn_cast(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); + return MinOpRes; + } + + // SCEVUDivExpr, SCEVUnknown return 0; } @@ -1470,77 +1701,128 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) { /// Analyze the expression. /// SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { - if (Instruction *I = dyn_cast(V)) { - switch (I->getOpcode()) { - case Instruction::Add: - return SE.getAddExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Mul: - return SE.getMulExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::SDiv: - return SE.getSDivExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Sub: - return SE.getMinusSCEV(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an Add so loop - // optimizations will transparently handle this case. - // - // In order for this transformation to be safe, the LHS must be of the - // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - SCEVHandle LHS = getSCEV(I->getOperand(0)); - const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) - return SE.getAddExpr(LHS, getSCEV(I->getOperand(1))); - } - break; - case Instruction::Xor: - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - if (CI->getValue().isSignBit()) - return SE.getAddExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - } - break; - - case Instruction::Shl: - // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { - uint32_t BitWidth = cast(V->getType())->getBitWidth(); - Constant *X = ConstantInt::get( - APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); - return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X)); - } - break; + if (!isa(V->getType())) + return SE.getUnknown(V); + + unsigned Opcode = Instruction::UserOp1; + if (Instruction *I = dyn_cast(V)) + Opcode = I->getOpcode(); + else if (ConstantExpr *CE = dyn_cast(V)) + Opcode = CE->getOpcode(); + else + return SE.getUnknown(V); + + User *U = cast(V); + switch (Opcode) { + case Instruction::Add: + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Mul: + return SE.getMulExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::UDiv: + return SE.getUDivExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Sub: + return SE.getMinusSCEV(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + SCEVHandle LHS = getSCEV(U->getOperand(0)); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) + return SE.getAddExpr(LHS, getSCEV(U->getOperand(1))); + } + break; + case Instruction::Xor: + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + if (CI->getValue().isSignBit()) + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + else if (CI->isAllOnesValue()) + return SE.getNotSCEV(getSCEV(U->getOperand(0))); + } + break; + + case Instruction::Shl: + // Turn shift left of a constant amount into a multiply. + if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; - case Instruction::Trunc: - return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::Trunc: + return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::ZExt: - return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::ZExt: + return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::SExt: - return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType()); + case Instruction::SExt: + return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); - case Instruction::BitCast: - // BitCasts are no-op casts so we just eliminate the cast. - if (I->getType()->isInteger() && - I->getOperand(0)->getType()->isInteger()) - return getSCEV(I->getOperand(0)); - break; + case Instruction::BitCast: + // BitCasts are no-op casts so we just eliminate the cast. + if (U->getType()->isInteger() && + U->getOperand(0)->getType()->isInteger()) + return getSCEV(U->getOperand(0)); + break; - case Instruction::PHI: - return createNodeForPHI(cast(I)); + case Instruction::PHI: + return createNodeForPHI(cast(U)); - default: // We cannot analyze this expression. - break; + case Instruction::Select: + // This could be a smax or umax that was lowered earlier. + // Try to recover it. + if (ICmpInst *ICI = dyn_cast(U->getOperand(0))) { + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // -smax(-x, -y) == smin(x, y). + return SE.getNegativeSCEV(SE.getSMaxExpr( + SE.getNegativeSCEV(getSCEV(LHS)), + SE.getNegativeSCEV(getSCEV(RHS)))); + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~umax(~x, ~y) == umin(x, y) + return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; + default: + break; + } } + + default: // We cannot analyze this expression. + break; } return SE.getUnknown(V); @@ -1620,7 +1902,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { ICmpInst *ExitCond = dyn_cast(ExitBr->getCondition()); - // If its not an integer comparison then compute it the hard way. + // If it's not an integer comparison then compute it the hard way. // Note that ICmpInst deals with pointer comparisons too so we must check // the type of the operand. if (ExitCond == 0 || isa(ExitCond->getOperand(0)->getType())) @@ -1714,8 +1996,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { break; } case ICmpInst::ICMP_UGT: { - SCEVHandle TC = HowManyLessThans(SE.getNegativeSCEV(LHS), - SE.getNegativeSCEV(RHS), L, false); + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, false); if (!isa(TC)) return TC; break; } @@ -1779,7 +2061,7 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, } /// ComputeLoadConstantCompareIterationCount - Given an exit condition of -/// 'icmp op load X, cst', try to se if we can compute the trip count. +/// 'icmp op load X, cst', try to see if we can compute the trip count. SCEVHandle ScalarEvolutionsImpl:: ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, const Loop *L, @@ -1864,7 +2146,7 @@ static bool CanConstantFold(const Instruction *I) { if (const CallInst *CI = dyn_cast(I)) if (const Function *F = CI->getCalledFunction()) - return canConstantFoldCallTo((Function*)F); // FIXME: elim cast + return canConstantFoldCallTo(F); return false; } @@ -1879,13 +2161,14 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); if (I == 0 || !L->contains(I->getParent())) return 0; - if (PHINode *PN = dyn_cast(I)) + if (PHINode *PN = dyn_cast(I)) { if (L->getHeader() == I->getParent()) return PN; else // We don't currently keep track of the control flow needed to evaluate // PHIs, so we cannot handle PHIs inside of loops. return 0; + } // If we won't be able to constant fold this expression even if the operands // are constants, return early. @@ -1915,8 +2198,6 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { /// reason, return null. static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { if (isa(V)) return PHIVal; - if (GlobalValue *GV = dyn_cast(V)) - return GV; if (Constant *C = dyn_cast(V)) return C; Instruction *I = cast(V); @@ -1928,7 +2209,12 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) { if (Operands[i] == 0) return 0; } - return ConstantFoldInstOperands(I, &Operands[0], Operands.size()); + if (const CmpInst *CI = dyn_cast(I)) + return ConstantFoldCompareInstOperands(CI->getPredicate(), + &Operands[0], Operands.size()); + else + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), + &Operands[0], Operands.size()); } /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is @@ -2041,7 +2327,7 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (isa(V)) return V; - // If this instruction is evolves from a constant-evolving PHI, compute the + // If this instruction is evolved from a constant-evolving PHI, compute the // exit value from the loop without using SCEVs. if (SCEVUnknown *SU = dyn_cast(V)) { if (Instruction *I = dyn_cast(SU->getValue())) { @@ -2076,6 +2362,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { if (Constant *C = dyn_cast(Op)) { Operands.push_back(C); } else { + // If any of the operands is non-constant and if they are + // non-integer, don't even try to analyze them with scev techniques. + if (!isa(Op->getType())) + return V; + SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L); if (SCEVConstant *SC = dyn_cast(OpV)) Operands.push_back(ConstantExpr::getIntegerCast(SC->getValue(), @@ -2093,7 +2384,14 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { } } } - Constant *C =ConstantFoldInstOperands(I, &Operands[0], Operands.size()); + + Constant *C; + if (const CmpInst *CI = dyn_cast(I)) + C = ConstantFoldCompareInstOperands(CI->getPredicate(), + &Operands[0], Operands.size()); + else + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), + &Operands[0], Operands.size()); return SE.getUnknown(C); } } @@ -2121,22 +2419,27 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { } if (isa(Comm)) return SE.getAddExpr(NewOps); - assert(isa(Comm) && "Only know about add and mul!"); - return SE.getMulExpr(NewOps); + if (isa(Comm)) + return SE.getMulExpr(NewOps); + if (isa(Comm)) + return SE.getSMaxExpr(NewOps); + if (isa(Comm)) + return SE.getUMaxExpr(NewOps); + assert(0 && "Unknown commutative SCEV type!"); } } // If we got here, all operands are loop invariant. return Comm; } - if (SCEVSDivExpr *Div = dyn_cast(V)) { + if (SCEVUDivExpr *Div = dyn_cast(V)) { SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); if (LHS == UnknownValue) return LHS; SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); if (RHS == UnknownValue) return RHS; if (LHS == Div->getLHS() && RHS == Div->getRHS()) return Div; // must be loop invariant - return SE.getSDivExpr(LHS, RHS); + return SE.getUDivExpr(LHS, RHS); } // If this is a loop recurrence for a loop that does not contain L, then we @@ -2147,8 +2450,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // loop iterates. Compute this now. SCEVHandle IterationCount = getIterationCount(AddRec->getLoop()); if (IterationCount == UnknownValue) return UnknownValue; - IterationCount = getTruncateOrZeroExtend(IterationCount, - AddRec->getType(), SE); + IterationCount = SE.getTruncateOrZeroExtend(IterationCount, + AddRec->getType()); // If the value is affine, simplify the expression evaluation to just // Start + Step*IterationCount. @@ -2263,9 +2566,9 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { if (SCEVConstant *StartC = dyn_cast(Start)) { ConstantInt *StartCC = StartC->getValue(); Constant *StartNegC = ConstantExpr::getNeg(StartCC); - Constant *Rem = ConstantExpr::getSRem(StartNegC, StepC->getValue()); + Constant *Rem = ConstantExpr::getURem(StartNegC, StepC->getValue()); if (Rem->isNullValue()) { - Constant *Result =ConstantExpr::getSDiv(StartNegC,StepC->getValue()); + Constant *Result = ConstantExpr::getUDiv(StartNegC,StepC->getValue()); return SE.getUnknown(Result); } } @@ -2292,9 +2595,8 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE); - if (SCEVConstant *EvalVal = dyn_cast(Val)) - if (EvalVal->getValue()->isZero()) - return R1; // We found a quadratic root! + if (Val->isZero()) + return R1; // We found a quadratic root! } } } @@ -2313,11 +2615,8 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) { // If the value is a constant, check to see if it is known to be non-zero // already. If so, the backedge will execute zero times. if (SCEVConstant *C = dyn_cast(V)) { - Constant *Zero = Constant::getNullValue(C->getValue()->getType()); - Constant *NonZero = - ConstantExpr::getICmp(ICmpInst::ICMP_NE, C->getValue(), Zero); - if (NonZero == ConstantInt::getTrue()) - return getSCEV(Zero); + if (!C->getValue()->isNullValue()) + return SE.getIntegerSCEV(0, C->getType()); return UnknownValue; // Otherwise it will loop infinitely. } @@ -2340,89 +2639,26 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) { if (AddRec->isAffine()) { // FORNOW: We only support unit strides. - SCEVHandle Zero = SE.getIntegerSCEV(0, RHS->getType()); SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType()); if (AddRec->getOperand(1) != One) return UnknownValue; - // The number of iterations for "{n,+,1} < m", is m-n. However, we don't - // know that m is >= n on input to the loop. If it is, the condition return - // true zero times. What we really should return, for full generality, is - // SMAX(0, m-n). Since we cannot check this, we will instead check for a - // canonical loop form: most do-loops will have a check that dominates the - // loop, that only enters the loop if (n-1)= n. - - // Search for the check. - BasicBlock *Preheader = L->getLoopPreheader(); - BasicBlock *PreheaderDest = L->getHeader(); - if (Preheader == 0) return UnknownValue; - - BranchInst *LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return UnknownValue; - - // This might be a critical edge broken out. If the loop preheader ends in - // an unconditional branch to the loop, check to see if the preheader has a - // single predecessor, and if so, look for its terminator. - while (LoopEntryPredicate->isUnconditional()) { - PreheaderDest = Preheader; - Preheader = Preheader->getSinglePredecessor(); - if (!Preheader) return UnknownValue; // Multiple preds. - - LoopEntryPredicate = - dyn_cast(Preheader->getTerminator()); - if (!LoopEntryPredicate) return UnknownValue; - } - - // Now that we found a conditional branch that dominates the loop, check to - // see if it is the comparison we are looking for. - if (ICmpInst *ICI = dyn_cast(LoopEntryPredicate->getCondition())){ - Value *PreCondLHS = ICI->getOperand(0); - Value *PreCondRHS = ICI->getOperand(1); - ICmpInst::Predicate Cond; - if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest) - Cond = ICI->getPredicate(); - else - Cond = ICI->getInversePredicate(); - - switch (Cond) { - case ICmpInst::ICMP_UGT: - if (isSigned) return UnknownValue; - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_ULT; - break; - case ICmpInst::ICMP_SGT: - if (!isSigned) return UnknownValue; - std::swap(PreCondLHS, PreCondRHS); - Cond = ICmpInst::ICMP_SLT; - break; - case ICmpInst::ICMP_ULT: - if (isSigned) return UnknownValue; - break; - case ICmpInst::ICMP_SLT: - if (!isSigned) return UnknownValue; - break; - default: - return UnknownValue; - } + // We know the LHS is of the form {n,+,1} and the RHS is some loop-invariant + // m. So, we count the number of iterations in which {n,+,1} < m is true. + // Note that we cannot simply return max(m-n,0) because it's not safe to + // treat m-n as signed nor unsigned due to overflow possibility. - if (PreCondLHS->getType()->isInteger()) { - if (RHS != getSCEV(PreCondRHS)) - return UnknownValue; // Not a comparison against 'm'. + // First, we get the value of the LHS in the first iteration: n + SCEVHandle Start = AddRec->getOperand(0); - if (SE.getMinusSCEV(AddRec->getOperand(0), One) - != getSCEV(PreCondLHS)) - return UnknownValue; // Not a comparison against 'n-1'. - } - else return UnknownValue; + // Then, we get the value of the LHS in the first iteration in which the + // above condition doesn't hold. This equals to max(m,n). + SCEVHandle End = isSigned ? SE.getSMaxExpr(RHS, Start) + : SE.getUMaxExpr(RHS, Start); - // cerr << "Computed Loop Trip Count as: " - // << // *SE.getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n"; - return SE.getMinusSCEV(RHS, AddRec->getOperand(0)); - } - else - return UnknownValue; + // Finally, we subtract these two values to get the number of times the + // backedge is executed: max(m,n)-n. + return SE.getMinusSCEV(End, Start); } return UnknownValue; @@ -2629,20 +2865,20 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE, for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) PrintLoopInfo(OS, SE, *I); - cerr << "Loop " << L->getHeader()->getName() << ": "; + OS << "Loop " << L->getHeader()->getName() << ": "; SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) - cerr << " "; + OS << " "; if (SE->hasLoopInvariantIterationCount(L)) { - cerr << *SE->getIterationCount(L) << " iterations! "; + OS << *SE->getIterationCount(L) << " iterations! "; } else { - cerr << "Unpredictable iteration count. "; + OS << "Unpredictable iteration count. "; } - cerr << "\n"; + OS << "\n"; } void ScalarEvolution::print(std::ostream &OS, const Module* ) const {