diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 50137fad674..3a09f2f9d7b 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -14,9 +14,8 @@
 // There are several aspects to this library. First is the representation of
 // scalar expressions, which are represented as subclasses of the SCEV class.
 // These classes are used to represent certain types of subexpressions that we
-// can handle. These classes are reference counted, managed by the const SCEV*
-// class. We only create one SCEV of a particular shape, so pointer-comparisons
-// for equality are legal.
+// can handle. We only create one SCEV of a particular shape, so
+// pointer-comparisons for equality are legal.
 //
 // One important aspect of the SCEV objects is that they are never cyclic, even
 // if there is a cycle in the dataflow for an expression (ie, a PHI node). If
@@ -64,7 +63,10 @@
 #include "llvm/Constants.h"
 #include "llvm/DerivedTypes.h"
 #include "llvm/GlobalVariable.h"
+#include "llvm/GlobalAlias.h"
 #include "llvm/Instructions.h"
+#include "llvm/LLVMContext.h"
+#include "llvm/Operator.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/Dominators.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -72,14 +74,15 @@
 #include "llvm/Assembly/Writer.h"
 #include "llvm/Target/TargetData.h"
 #include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Compiler.h"
 #include "llvm/Support/ConstantRange.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/InstIterator.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include <algorithm>
 using namespace llvm;
@@ -95,7 +98,8 @@ STATISTIC(NumBruteForceTripCountsComputed,
 static cl::opt<unsigned>
 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                         cl::desc("Maximum number of iterations SCEV will "
-                                 "symbolically execute a constant derived loop"),
+                                 "symbolically execute a constant "
+                                 "derived loop"),
                         cl::init(100));

 static RegisterPass<ScalarEvolution>
@@ -109,17 +113,14 @@ char ScalarEvolution::ID = 0;
 //===----------------------------------------------------------------------===//
 // Implementation of the SCEV class.
// + SCEV::~SCEV() {} + void SCEV::dump() const { print(errs()); errs() << '\n'; } -void SCEV::print(std::ostream &o) const { - raw_os_ostream OS(o); - print(OS); -} - bool SCEV::isZero() const { if (const SCEVConstant *SC = dyn_cast(this)) return SC->getValue()->isZero(); @@ -139,28 +140,26 @@ bool SCEV::isAllOnesValue() const { } SCEVCouldNotCompute::SCEVCouldNotCompute() : - SCEV(scCouldNotCompute) {} + SCEV(FoldingSetNodeID(), scCouldNotCompute) {} bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const { - assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); return false; } const Type *SCEVCouldNotCompute::getType() const { - assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); return 0; } bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const { - assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); return false; } -const SCEV* SCEVCouldNotCompute:: -replaceSymbolicValuesWithConcrete(const SCEV* Sym, - const SCEV* Conc, - ScalarEvolution &SE) const { - return this; +bool SCEVCouldNotCompute::hasOperand(const SCEV *) const { + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + return false; } void SCEVCouldNotCompute::print(raw_ostream &OS) const { @@ -171,24 +170,26 @@ bool SCEVCouldNotCompute::classof(const SCEV *S) { return S->getSCEVType() == scCouldNotCompute; } - -// SCEVConstants - Only allow the creation of one SCEVConstant for any -// particular value. Don't use a const SCEV* here, or else the object will -// never be deleted! - -const SCEV* ScalarEvolution::getConstant(ConstantInt *V) { - SCEVConstant *&R = SCEVConstants[V]; - if (R == 0) R = new SCEVConstant(V); - return R; +const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { + FoldingSetNodeID ID; + ID.AddInteger(scConstant); + ID.AddPointer(V); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVConstant(ID, V); + UniqueSCEVs.InsertNode(S, IP); + return S; } -const SCEV* ScalarEvolution::getConstant(const APInt& Val) { - return getConstant(ConstantInt::get(Val)); +const SCEV *ScalarEvolution::getConstant(const APInt& Val) { + return getConstant(ConstantInt::get(getContext(), Val)); } -const SCEV* +const SCEV * ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) { - return getConstant(ConstantInt::get(cast(Ty), V, isSigned)); + return getConstant( + ConstantInt::get(cast(Ty), V, isSigned)); } const Type *SCEVConstant::getType() const { return V->getType(); } @@ -197,36 +198,33 @@ void SCEVConstant::print(raw_ostream &OS) const { WriteAsOperand(OS, V, false); } -SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy, - const SCEV* op, const Type *ty) - : SCEV(SCEVTy), Op(op), Ty(ty) {} +SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeID &ID, + unsigned SCEVTy, const SCEV *op, const Type *ty) + : SCEV(ID, SCEVTy), Op(op), Ty(ty) {} bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { return Op->dominates(BB, DT); } -// SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any -// particular input. Don't use a const SCEV* here, or else the object will -// never be deleted! 
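The page extraction that produced this diff has eaten the angle-bracketed template arguments throughout; for example, dyn_cast(this) above was dyn_cast<SCEVConstant>(this), and SCEVAllocator.Allocate() was SCEVAllocator.Allocate<SCEVConstant>(). To keep the new uniquing scheme readable, here is a minimal sketch of getConstant with the arguments restored, assuming UniqueSCEVs is a FoldingSet<SCEV> and SCEVAllocator a BumpPtrAllocator, as ScalarEvolution.h declares at this revision:

// Sketch only: the same code as the hunk above, with the template
// arguments that the extraction dropped restored.
const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);          // Discriminate by SCEV kind...
  ID.AddPointer(V);                   // ...and by operand identity.
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;                         // An equal node exists; reuse it.
  SCEV *S = SCEVAllocator.Allocate<SCEVConstant>();
  new (S) SCEVConstant(ID, V);        // Placement-new into the arena.
  UniqueSCEVs.InsertNode(S, IP);
  return S;                           // Pointer equality now implies
                                      // value equality.
}

The same find-or-insert sequence replaces each of the per-kind side tables (SCEVConstants, SCEVTruncates, SCEVCommExprs, and so on) whose deleted "never be deleted!" comments recur throughout this patch.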
+bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { + return Op->properlyDominates(BB, DT); +} -SCEVTruncateExpr::SCEVTruncateExpr(const SCEV* op, const Type *ty) - : SCEVCastExpr(scTruncate, op, ty) { +SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeID &ID, + const SCEV *op, const Type *ty) + : SCEVCastExpr(ID, scTruncate, op, ty) { assert((Op->getType()->isInteger() || isa(Op->getType())) && (Ty->isInteger() || isa(Ty)) && "Cannot truncate non-integer value!"); } - void SCEVTruncateExpr::print(raw_ostream &OS) const { OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; } -// SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any -// particular input. Don't use a const SCEV* here, or else the object will never -// be deleted! - -SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEV* op, const Type *ty) - : SCEVCastExpr(scZeroExtend, op, ty) { +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeID &ID, + const SCEV *op, const Type *ty) + : SCEVCastExpr(ID, scZeroExtend, op, ty) { assert((Op->getType()->isInteger() || isa(Op->getType())) && (Ty->isInteger() || isa(Ty)) && "Cannot zero extend non-integer value!"); @@ -236,12 +234,9 @@ void SCEVZeroExtendExpr::print(raw_ostream &OS) const { OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; } -// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any -// particular input. Don't use a const SCEV* here, or else the object will never -// be deleted! - -SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEV* op, const Type *ty) - : SCEVCastExpr(scSignExtend, op, ty) { +SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeID &ID, + const SCEV *op, const Type *ty) + : SCEVCastExpr(ID, scSignExtend, op, ty) { assert((Op->getType()->isInteger() || isa(Op->getType())) && (Ty->isInteger() || isa(Ty)) && "Cannot sign extend non-integer value!"); @@ -251,10 +246,6 @@ void SCEVSignExtendExpr::print(raw_ostream &OS) const { OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; } -// SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any -// particular input. Don't use a const SCEV* here, or else the object will never -// be deleted! 
- void SCEVCommutativeExpr::print(raw_ostream &OS) const { assert(Operands.size() > 1 && "This plus expr shouldn't exist!"); const char *OpStr = getOperationStr(); @@ -264,55 +255,30 @@ void SCEVCommutativeExpr::print(raw_ostream &OS) const { OS << ")"; } -const SCEV* SCEVCommutativeExpr:: -replaceSymbolicValuesWithConcrete(const SCEV* Sym, - const SCEV* Conc, - ScalarEvolution &SE) const { +bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - const SCEV* H = - getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); - if (H != getOperand(i)) { - SmallVector NewOps; - NewOps.reserve(getNumOperands()); - for (unsigned j = 0; j != i; ++j) - NewOps.push_back(getOperand(j)); - NewOps.push_back(H); - for (++i; i != e; ++i) - NewOps.push_back(getOperand(i)-> - replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); - - if (isa(this)) - return SE.getAddExpr(NewOps); - else if (isa(this)) - return SE.getMulExpr(NewOps); - else if (isa(this)) - return SE.getSMaxExpr(NewOps); - else if (isa(this)) - return SE.getUMaxExpr(NewOps); - else - assert(0 && "Unknown commutative expr!"); - } + if (!getOperand(i)->dominates(BB, DT)) + return false; } - return this; + return true; } -bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { +bool SCEVNAryExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - if (!getOperand(i)->dominates(BB, DT)) + if (!getOperand(i)->properlyDominates(BB, DT)) return false; } return true; } - -// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular -// input. Don't use a const SCEV* here, or else the object will never be -// deleted! - bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { return LHS->dominates(BB, DT) && RHS->dominates(BB, DT); } +bool SCEVUDivExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { + return LHS->properlyDominates(BB, DT) && RHS->properlyDominates(BB, DT); +} + void SCEVUDivExpr::print(raw_ostream &OS) const { OS << "(" << *LHS << " /u " << *RHS << ")"; } @@ -326,44 +292,25 @@ const Type *SCEVUDivExpr::getType() const { return RHS->getType(); } -// SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any -// particular input. Don't use a const SCEV* here, or else the object will never -// be deleted! +bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { + // Add recurrences are never invariant in the function-body (null loop). + if (!QueryLoop) + return false; -const SCEV* SCEVAddRecExpr:: -replaceSymbolicValuesWithConcrete(const SCEV* Sym, - const SCEV* Conc, - ScalarEvolution &SE) const { - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - const SCEV* H = - getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); - if (H != getOperand(i)) { - SmallVector NewOps; - NewOps.reserve(getNumOperands()); - for (unsigned j = 0; j != i; ++j) - NewOps.push_back(getOperand(j)); - NewOps.push_back(H); - for (++i; i != e; ++i) - NewOps.push_back(getOperand(i)-> - replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); - - return SE.getAddRecExpr(NewOps, L); - } - } - return this; -} + // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L. + if (QueryLoop->contains(L)) + return false; + // This recurrence is variant w.r.t. QueryLoop if any of its operands + // are variant. 
+ for (unsigned i = 0, e = getNumOperands(); i != e; ++i) + if (!getOperand(i)->isLoopInvariant(QueryLoop)) + return false; -bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { - // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't - // contain L and if the start is invariant. - // Add recurrences are never invariant in the function-body (null loop). - return QueryLoop && - !QueryLoop->contains(L->getHeader()) && - getOperand(0)->isLoopInvariant(QueryLoop); + // Otherwise it's loop-invariant. + return true; } - void SCEVAddRecExpr::print(raw_ostream &OS) const { OS << "{" << *Operands[0]; for (unsigned i = 1, e = Operands.size(); i != e; ++i) @@ -371,9 +318,14 @@ void SCEVAddRecExpr::print(raw_ostream &OS) const { OS << "}<" << L->getHeader()->getName() + ">"; } -// SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular -// value. Don't use a const SCEV* here, or else the object will never be -// deleted! +void SCEVFieldOffsetExpr::print(raw_ostream &OS) const { + // LLVM struct fields don't have names, so just print the field number. + OS << "offsetof(" << *STy << ", " << FieldNo << ")"; +} + +void SCEVAllocSizeExpr::print(raw_ostream &OS) const { + OS << "sizeof(" << *AllocTy << ")"; +} bool SCEVUnknown::isLoopInvariant(const Loop *L) const { // All non-instruction values are loop invariant. All instructions are loop @@ -381,7 +333,7 @@ bool SCEVUnknown::isLoopInvariant(const Loop *L) const { // Instructions are never considered invariant in the function body // (null loop) because they are defined within the "loop". if (Instruction *I = dyn_cast(V)) - return L && !L->contains(I->getParent()); + return L && !L->contains(I); return true; } @@ -391,6 +343,12 @@ bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const { return true; } +bool SCEVUnknown::properlyDominates(BasicBlock *BB, DominatorTree *DT) const { + if (Instruction *I = dyn_cast(getValue())) + return DT->properlyDominates(I->getParent(), BB); + return true; +} + const Type *SCEVUnknown::getType() const { return V->getType(); } @@ -403,16 +361,55 @@ void SCEVUnknown::print(raw_ostream &OS) const { // SCEV Utilities //===----------------------------------------------------------------------===// +static bool CompareTypes(const Type *A, const Type *B) { + if (A->getTypeID() != B->getTypeID()) + return A->getTypeID() < B->getTypeID(); + if (const IntegerType *AI = dyn_cast(A)) { + const IntegerType *BI = cast(B); + return AI->getBitWidth() < BI->getBitWidth(); + } + if (const PointerType *AI = dyn_cast(A)) { + const PointerType *BI = cast(B); + return CompareTypes(AI->getElementType(), BI->getElementType()); + } + if (const ArrayType *AI = dyn_cast(A)) { + const ArrayType *BI = cast(B); + if (AI->getNumElements() != BI->getNumElements()) + return AI->getNumElements() < BI->getNumElements(); + return CompareTypes(AI->getElementType(), BI->getElementType()); + } + if (const VectorType *AI = dyn_cast(A)) { + const VectorType *BI = cast(B); + if (AI->getNumElements() != BI->getNumElements()) + return AI->getNumElements() < BI->getNumElements(); + return CompareTypes(AI->getElementType(), BI->getElementType()); + } + if (const StructType *AI = dyn_cast(A)) { + const StructType *BI = cast(B); + if (AI->getNumElements() != BI->getNumElements()) + return AI->getNumElements() < BI->getNumElements(); + for (unsigned i = 0, e = AI->getNumElements(); i != e; ++i) + if (CompareTypes(AI->getElementType(i), BI->getElementType(i)) || + CompareTypes(BI->getElementType(i), 
AI->getElementType(i))) + return CompareTypes(AI->getElementType(i), BI->getElementType(i)); + } + return false; +} + namespace { /// SCEVComplexityCompare - Return true if the complexity of the LHS is less /// than the complexity of the RHS. This comparator is used to canonicalize /// expressions. - class VISIBILITY_HIDDEN SCEVComplexityCompare { + class SCEVComplexityCompare { LoopInfo *LI; public: explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {} bool operator()(const SCEV *LHS, const SCEV *RHS) const { + // Fast-path: SCEVs are uniqued so we can do a quick equality check. + if (LHS == RHS) + return false; + // Primarily, sort the SCEVs by their getSCEVType(). if (LHS->getSCEVType() != RHS->getSCEVType()) return LHS->getSCEVType() < RHS->getSCEVType(); @@ -469,6 +466,8 @@ namespace { // Compare constant values. if (const SCEVConstant *LC = dyn_cast(LHS)) { const SCEVConstant *RC = cast(RHS); + if (LC->getValue()->getBitWidth() != RC->getValue()->getBitWidth()) + return LC->getValue()->getBitWidth() < RC->getValue()->getBitWidth(); return LC->getValue()->getValue().ult(RC->getValue()->getValue()); } @@ -513,7 +512,22 @@ namespace { return operator()(LC->getOperand(), RC->getOperand()); } - assert(0 && "Unknown SCEV kind!"); + // Compare offsetof expressions. + if (const SCEVFieldOffsetExpr *LA = dyn_cast(LHS)) { + const SCEVFieldOffsetExpr *RA = cast(RHS); + if (CompareTypes(LA->getStructType(), RA->getStructType()) || + CompareTypes(RA->getStructType(), LA->getStructType())) + return CompareTypes(LA->getStructType(), RA->getStructType()); + return LA->getFieldNo() < RA->getFieldNo(); + } + + // Compare sizeof expressions by the allocation type. + if (const SCEVAllocSizeExpr *LA = dyn_cast(LHS)) { + const SCEVAllocSizeExpr *RA = cast(RHS); + return CompareTypes(LA->getAllocType(), RA->getAllocType()); + } + + llvm_unreachable("Unknown SCEV kind!"); return false; } }; @@ -529,7 +543,7 @@ namespace { /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. /// -static void GroupByComplexity(SmallVectorImpl &Ops, +static void GroupByComplexity(SmallVectorImpl &Ops, LoopInfo *LI) { if (Ops.size() < 2) return; // Noop if (Ops.size() == 2) { @@ -572,9 +586,9 @@ static void GroupByComplexity(SmallVectorImpl &Ops, /// BinomialCoefficient - Compute BC(It, K). The result has width W. /// Assume, K > 0. -static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K, - ScalarEvolution &SE, - const Type* ResultTy) { +static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, + ScalarEvolution &SE, + const Type* ResultTy) { // Handle the simplest case efficiently. if (K == 1) return SE.getTruncateOrZeroExtend(It, ResultTy); @@ -589,7 +603,7 @@ static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K, // safe in modular arithmetic. // // However, this code doesn't use exactly that formula; the formula it uses - // is something like the following, where T is the number of factors of 2 in + // is something like the following, where T is the number of factors of 2 in // K! (i.e. trailing zeros in the binary representation of K!), and ^ is // exponentiation: // @@ -601,7 +615,7 @@ static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K, // arithmetic. To do exact division in modular arithmetic, all we have // to do is multiply by the inverse. Therefore, this step can be done at // width W. - // + // // The next issue is how to safely do the division by 2^T. 
The way this // is done is by doing the multiplication step at a width of at least W + T // bits. This way, the bottom W+T bits of the product are accurate. Then, @@ -664,16 +678,17 @@ static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K, MultiplyFactor = MultiplyFactor.trunc(W); // Calculate the product, at width T+W - const IntegerType *CalculationTy = IntegerType::get(CalculationBits); - const SCEV* Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + const IntegerType *CalculationTy = IntegerType::get(SE.getContext(), + CalculationBits); + const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { - const SCEV* S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + const SCEV *S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); Dividend = SE.getMulExpr(Dividend, SE.getTruncateOrZeroExtend(S, CalculationTy)); } // Divide by 2^T - const SCEV* DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); // Truncate the result, and divide by K! / 2^T. @@ -690,14 +705,14 @@ static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K, /// /// where BC(It, k) stands for binomial coefficient. /// -const SCEV* SCEVAddRecExpr::evaluateAtIteration(const SCEV* It, - ScalarEvolution &SE) const { - const SCEV* Result = getStart(); +const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, + ScalarEvolution &SE) const { + const SCEV *Result = getStart(); for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient. - const SCEV* Coeff = BinomialCoefficient(It, i, SE, getType()); + const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType()); if (isa(Coeff)) return Coeff; @@ -710,14 +725,22 @@ const SCEV* SCEVAddRecExpr::evaluateAtIteration(const SCEV* It, // SCEV Expression folder implementations //===----------------------------------------------------------------------===// -const SCEV* ScalarEvolution::getTruncateExpr(const SCEV* Op, - const Type *Ty) { +const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, + const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && "This is not a truncating conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); + FoldingSetNodeID ID; + ID.AddInteger(scTruncate); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + + // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) return getConstant( cast(ConstantExpr::getTrunc(SC->getValue(), Ty))); @@ -736,25 +759,30 @@ const SCEV* ScalarEvolution::getTruncateExpr(const SCEV* Op, // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); return getAddRecExpr(Operands, AddRec->getLoop()); } - SCEVTruncateExpr *&Result = SCEVTruncates[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty); - return Result; + // The cast wasn't folded; create an explicit cast node. + // Recompute the insert position, as it may have been invalidated. 
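// (Why the recomputation below is needed: the folding above may call back
// into getAddRecExpr/getTruncateExpr, which can insert new nodes into
// UniqueSCEVs. Growing the FoldingSet can rehash its buckets, so the
// insertion position IP returned by the first FindNodeOrInsertPos call
// may be stale and must not be reused blindly.)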
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVTruncateExpr(ID, Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; } -const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op, - const Type *Ty) { +const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, + const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); + // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) { const Type *IntTy = getEffectiveSCEVType(Ty); Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy); @@ -766,12 +794,33 @@ const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op, if (const SCEVZeroExtendExpr *SZ = dyn_cast(Op)) return getZeroExtendExpr(SZ->getOperand(), Ty); + // Before doing any expensive analysis, check to see if we've already + // computed a SCEV for this Op and Ty. + FoldingSetNodeID ID; + ID.AddInteger(scZeroExtend); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can zero extend all of the // operands (often constants). This allows analysis of something like // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*this); + unsigned BitWidth = getTypeSizeInBits(AR->getType()); + const Loop *L = AR->getLoop(); + + // If we have special knowledge that this addrec won't overflow, + // we don't need to do any further analysis. + if (AR->hasNoUnsignedWrap()) + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + L); + // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is @@ -780,28 +829,25 @@ const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV* MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. - const SCEV* Start = AR->getStart(); - const SCEV* Step = AR->getStepRecurrence(*this); // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. - const SCEV* CastedMaxBECount = + const SCEV *CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType()); - const SCEV* RecastedMaxBECount = + const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = - IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); + const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. 
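// (A concrete instance of the check below, in i8 with MaxBECount = 1:
// for {250,+,10}, Start + Step*MaxBECount = 260, which wraps to 4 in i8,
// while the same sum rebuilt from zero-extended operands in the doubled
// width is 260; zext(4) != 260, so the transformation is rejected. For
// {0,+,1} with MaxBECount = 100 both computations yield 100, proving the
// recurrence never wraps, so the zext may be pushed inside the addrec.)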
- const SCEV* ZMul = + const SCEV *ZMul = getMulExpr(CastedMaxBECount, getTruncateOrZeroExtend(Step, Start->getType())); - const SCEV* Add = getAddExpr(Start, ZMul); - const SCEV* OperandExtendedAdd = + const SCEV *Add = getAddExpr(Start, ZMul); + const SCEV *OperandExtendedAdd = getAddExpr(getZeroExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getZeroExtendExpr(Step, WideTy))); @@ -809,11 +855,11 @@ const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op, // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getZeroExtendExpr(Step, Ty), - AR->getLoop()); + L); // Similar to above, only this time treat the step value as signed. // This covers loops that count down. - const SCEV* SMul = + const SCEV *SMul = getMulExpr(CastedMaxBECount, getTruncateOrSignExtend(Step, Start->getType())); Add = getAddExpr(Start, SMul); @@ -825,24 +871,57 @@ const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op, // Return the expression with the addrec on the outside. return getAddRecExpr(getZeroExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), - AR->getLoop()); + L); + } + + // If the backedge is guarded by a comparison with the pre-inc value + // the addrec is safe. Also, if the entry is guarded by a comparison + // with the start value and the backedge is guarded by a comparison + // with the post-inc value, the addrec is safe. + if (isKnownPositive(Step)) { + const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - + getUnsignedRange(Step).getUnsignedMax()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || + (isLoopGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + L); + } else if (isKnownNegative(Step)) { + const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - + getSignedRange(Step).getSignedMin()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) && + (isLoopGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) || + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); } } } - SCEVZeroExtendExpr *&Result = SCEVZeroExtends[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty); - return Result; + // The cast wasn't folded; create an explicit cast node. + // Recompute the insert position, as it may have been invalidated. + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVZeroExtendExpr(ID, Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; } -const SCEV* ScalarEvolution::getSignExtendExpr(const SCEV* Op, - const Type *Ty) { +const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, + const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); Ty = getEffectiveSCEVType(Ty); + // Fold if the operand is constant. 
if (const SCEVConstant *SC = dyn_cast(Op)) { const Type *IntTy = getEffectiveSCEVType(Ty); Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy); @@ -854,12 +933,33 @@ const SCEV* ScalarEvolution::getSignExtendExpr(const SCEV* Op, if (const SCEVSignExtendExpr *SS = dyn_cast(Op)) return getSignExtendExpr(SS->getOperand(), Ty); + // Before doing any expensive analysis, check to see if we've already + // computed a SCEV for this Op and Ty. + FoldingSetNodeID ID; + ID.AddInteger(scSignExtend); + ID.AddPointer(Op); + ID.AddPointer(Ty); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can sign extend all of the // operands (often constants). This allows analysis of something like // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast(Op)) if (AR->isAffine()) { + const SCEV *Start = AR->getStart(); + const SCEV *Step = AR->getStepRecurrence(*this); + unsigned BitWidth = getTypeSizeInBits(AR->getType()); + const Loop *L = AR->getLoop(); + + // If we have special knowledge that this addrec won't overflow, + // we don't need to do any further analysis. + if (AR->hasNoSignedWrap()) + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); + // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are // simply not analyzable, and it covers the case where this code is @@ -868,28 +968,25 @@ const SCEV* ScalarEvolution::getSignExtendExpr(const SCEV* Op, // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV* MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); if (!isa(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. - const SCEV* Start = AR->getStart(); - const SCEV* Step = AR->getStepRecurrence(*this); // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. - const SCEV* CastedMaxBECount = + const SCEV *CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType()); - const SCEV* RecastedMaxBECount = + const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); if (MaxBECount == RecastedMaxBECount) { - const Type *WideTy = - IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); + const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. - const SCEV* SMul = + const SCEV *SMul = getMulExpr(CastedMaxBECount, getTruncateOrSignExtend(Step, Start->getType())); - const SCEV* Add = getAddExpr(Start, SMul); - const SCEV* OperandExtendedAdd = + const SCEV *Add = getAddExpr(Start, SMul); + const SCEV *OperandExtendedAdd = getAddExpr(getSignExtendExpr(Start, WideTy), getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), getSignExtendExpr(Step, WideTy))); @@ -897,21 +994,69 @@ const SCEV* ScalarEvolution::getSignExtendExpr(const SCEV* Op, // Return the expression with the addrec on the outside. return getAddRecExpr(getSignExtendExpr(Start, Ty), getSignExtendExpr(Step, Ty), - AR->getLoop()); + L); + + // Similar to above, only this time treat the step value as unsigned. 
+ // This covers loops that count up with an unsigned step. + const SCEV *UMul = + getMulExpr(CastedMaxBECount, + getTruncateOrZeroExtend(Step, Start->getType())); + Add = getAddExpr(Start, UMul); + OperandExtendedAdd = + getAddExpr(getSignExtendExpr(Start, WideTy), + getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), + getZeroExtendExpr(Step, WideTy))); + if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + L); + } + + // If the backedge is guarded by a comparison with the pre-inc value + // the addrec is safe. Also, if the entry is guarded by a comparison + // with the start value and the backedge is guarded by a comparison + // with the post-inc value, the addrec is safe. + if (isKnownPositive(Step)) { + const SCEV *N = getConstant(APInt::getSignedMinValue(BitWidth) - + getSignedRange(Step).getSignedMax()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, AR, N) || + (isLoopGuardedByCond(L, ICmpInst::ICMP_SLT, Start, N) && + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); + } else if (isKnownNegative(Step)) { + const SCEV *N = getConstant(APInt::getSignedMaxValue(BitWidth) - + getSignedRange(Step).getSignedMin()); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, AR, N) || + (isLoopGuardedByCond(L, ICmpInst::ICMP_SGT, Start, N) && + isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, + AR->getPostIncExpr(*this), N))) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + L); } } } - SCEVSignExtendExpr *&Result = SCEVSignExtends[std::make_pair(Op, Ty)]; - if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); - return Result; + // The cast wasn't folded; create an explicit cast node. + // Recompute the insert position, as it may have been invalidated. + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVSignExtendExpr(ID, Op, Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; } /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. /// -const SCEV* ScalarEvolution::getAnyExtendExpr(const SCEV* Op, - const Type *Ty) { +const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, + const Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -925,19 +1070,19 @@ const SCEV* ScalarEvolution::getAnyExtendExpr(const SCEV* Op, // Peel off a truncate cast. if (const SCEVTruncateExpr *T = dyn_cast(Op)) { - const SCEV* NewOp = T->getOperand(); + const SCEV *NewOp = T->getOperand(); if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) return getAnyExtendExpr(NewOp, Ty); return getTruncateOrNoop(NewOp, Ty); } // Next try a zext cast. If the cast is folded, use it. - const SCEV* ZExt = getZeroExtendExpr(Op, Ty); + const SCEV *ZExt = getZeroExtendExpr(Op, Ty); if (!isa(ZExt)) return ZExt; // Next try a sext cast. If the cast is folded, use it. 
- const SCEV* SExt = getSignExtendExpr(Op, Ty); + const SCEV *SExt = getSignExtendExpr(Op, Ty); if (!isa(SExt)) return SExt; @@ -975,10 +1120,10 @@ const SCEV* ScalarEvolution::getAnyExtendExpr(const SCEV* Op, /// is also used as a check to avoid infinite recursion. /// static bool -CollectAddOperandsWithScales(DenseMap &M, - SmallVector &NewOps, +CollectAddOperandsWithScales(DenseMap &M, + SmallVector &NewOps, APInt &AccumulatedConstant, - const SmallVectorImpl &Ops, + const SmallVectorImpl &Ops, const APInt &Scale, ScalarEvolution &SE) { bool Interesting = false; @@ -999,12 +1144,11 @@ CollectAddOperandsWithScales(DenseMap &M, } else { // A multiplication of a constant with some other value. Update // the map. - SmallVector MulOps(Mul->op_begin()+1, Mul->op_end()); - const SCEV* Key = SE.getMulExpr(MulOps); - std::pair::iterator, bool> Pair = - M.insert(std::make_pair(Key, APInt())); + SmallVector MulOps(Mul->op_begin()+1, Mul->op_end()); + const SCEV *Key = SE.getMulExpr(MulOps); + std::pair::iterator, bool> Pair = + M.insert(std::make_pair(Key, NewScale)); if (Pair.second) { - Pair.first->second = NewScale; NewOps.push_back(Pair.first->first); } else { Pair.first->second += NewScale; @@ -1020,10 +1164,9 @@ CollectAddOperandsWithScales(DenseMap &M, AccumulatedConstant += Scale * C->getValue()->getValue(); } else { // An ordinary operand. Update the map. - std::pair::iterator, bool> Pair = - M.insert(std::make_pair(Ops[i], APInt())); + std::pair::iterator, bool> Pair = + M.insert(std::make_pair(Ops[i], Scale)); if (Pair.second) { - Pair.first->second = Scale; NewOps.push_back(Pair.first->first); } else { Pair.first->second += Scale; @@ -1047,7 +1190,8 @@ namespace { /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. -const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { +const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, + bool HasNUW, bool HasNSW) { assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; #ifndef NDEBUG @@ -1091,13 +1235,13 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 // Found a match, merge the two values into a multiply, and add any // remaining values to the result. - const SCEV* Two = getIntegerSCEV(2, Ty); - const SCEV* Mul = getMulExpr(Ops[i], Two); + const SCEV *Two = getIntegerSCEV(2, Ty); + const SCEV *Mul = getMulExpr(Ops[i], Two); if (Ops.size() == 2) return Mul; Ops.erase(Ops.begin()+i, Ops.begin()+i+2); Ops.push_back(Mul); - return getAddExpr(Ops); + return getAddExpr(Ops, HasNUW, HasNSW); } // Check for truncates. If all the operands are truncated from the same @@ -1108,7 +1252,7 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { const SCEVTruncateExpr *Trunc = cast(Ops[Idx]); const Type *DstType = Trunc->getType(); const Type *SrcType = Trunc->getOperand()->getType(); - SmallVector LargeOps; + SmallVector LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the // source type of the truncate. @@ -1124,7 +1268,7 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // is much more likely to be foldable here. 
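// (Sign extension is safe here because truncating back to DstType
// recovers the constant either way: trunc(sext(C)) == C == trunc(zext(C))
// at C's own width. Sign extension is preferred since common small
// negative constants such as -1 stay "small" in the wider type and are
// therefore more likely to fold with the other operands.)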
LargeOps.push_back(getSignExtendExpr(C, SrcType)); } else if (const SCEVMulExpr *M = dyn_cast(Ops[i])) { - SmallVector LargeMulOps; + SmallVector LargeMulOps; for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { if (const SCEVTruncateExpr *T = dyn_cast(M->getOperand(j))) { @@ -1152,7 +1296,7 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { } if (Ok) { // Evaluate the expression in the larger type. - const SCEV* Fold = getAddExpr(LargeOps); + const SCEV *Fold = getAddExpr(LargeOps, HasNUW, HasNSW); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, DstType); @@ -1189,26 +1333,27 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // operands multiplied by constant values. if (Idx < Ops.size() && isa(Ops[Idx])) { uint64_t BitWidth = getTypeSizeInBits(Ty); - DenseMap M; - SmallVector NewOps; + DenseMap M; + SmallVector NewOps; APInt AccumulatedConstant(BitWidth, 0); if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Ops, APInt(BitWidth, 1), *this)) { // Some interesting folding opportunity is present, so its worthwhile to // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. - std::map, APIntCompare> MulOpLists; - for (SmallVector::iterator I = NewOps.begin(), + std::map, APIntCompare> MulOpLists; + for (SmallVector::iterator I = NewOps.begin(), E = NewOps.end(); I != E; ++I) MulOpLists[M.find(*I)->second].push_back(*I); // Re-generate the operands list. Ops.clear(); if (AccumulatedConstant != 0) Ops.push_back(getConstant(AccumulatedConstant)); - for (std::map, APIntCompare>::iterator I = - MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) + for (std::map, APIntCompare>::iterator + I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) if (I->first != 0) - Ops.push_back(getMulExpr(getConstant(I->first), getAddExpr(I->second))); + Ops.push_back(getMulExpr(getConstant(I->first), + getAddExpr(I->second))); if (Ops.empty()) return getIntegerSCEV(0, Ty); if (Ops.size() == 1) @@ -1227,17 +1372,17 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) if (MulOpSCEV == Ops[AddOp] && !isa(Ops[AddOp])) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) - const SCEV* InnerMul = Mul->getOperand(MulOp == 0); + const SCEV *InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. 
- SmallVector MulOps(Mul->op_begin(), Mul->op_end()); + SmallVector MulOps(Mul->op_begin(), Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); InnerMul = getMulExpr(MulOps); } - const SCEV* One = getIntegerSCEV(1, Ty); - const SCEV* AddOne = getAddExpr(InnerMul, One); - const SCEV* OuterMul = getMulExpr(AddOne, Ops[AddOp]); + const SCEV *One = getIntegerSCEV(1, Ty); + const SCEV *AddOne = getAddExpr(InnerMul, One); + const SCEV *OuterMul = getMulExpr(AddOne, Ops[AddOp]); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -1261,21 +1406,22 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { OMulOp != e; ++OMulOp) if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) - const SCEV* InnerMul1 = Mul->getOperand(MulOp == 0); + const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { - SmallVector MulOps(Mul->op_begin(), Mul->op_end()); + SmallVector MulOps(Mul->op_begin(), + Mul->op_end()); MulOps.erase(MulOps.begin()+MulOp); InnerMul1 = getMulExpr(MulOps); } - const SCEV* InnerMul2 = OtherMul->getOperand(OMulOp == 0); + const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { - SmallVector MulOps(OtherMul->op_begin(), - OtherMul->op_end()); + SmallVector MulOps(OtherMul->op_begin(), + OtherMul->op_end()); MulOps.erase(MulOps.begin()+OMulOp); InnerMul2 = getMulExpr(MulOps); } - const SCEV* InnerMulSum = getAddExpr(InnerMul1,InnerMul2); - const SCEV* OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); + const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); @@ -1296,7 +1442,7 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this add and add them to the vector if // they are loop invariant w.r.t. the recurrence. - SmallVector LIOps; + SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { @@ -1310,11 +1456,14 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); - SmallVector AddRecOps(AddRec->op_begin(), - AddRec->op_end()); + SmallVector AddRecOps(AddRec->op_begin(), + AddRec->op_end()); AddRecOps[0] = getAddExpr(LIOps); - const SCEV* NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); + // It's tempting to propagate NUW/NSW flags here, but nuw/nsw addition + // is not associative so this isn't necessarily safe. + const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); + // If all of the other operands were loop invariant, we are done. 
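// (Re the nuw/nsw note above: propagating the flags would amount to
// reassociating the addition, and reassociation can introduce a wrap the
// original grouping avoided. In i8, (-100 + 100) + 100 stays inside
// [-128,127] at every step, but -100 + (100 + 100) overflows in the
// inner sum, so nsw on the original expression says nothing about the
// regrouped one.)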
if (Ops.size() == 1) return NewRec; @@ -1336,7 +1485,8 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { const SCEVAddRecExpr *OtherAddRec = cast(Ops[OtherIdx]); if (AddRec->getLoop() == OtherAddRec->getLoop()) { // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} - SmallVector NewOps(AddRec->op_begin(), AddRec->op_end()); + SmallVector NewOps(AddRec->op_begin(), + AddRec->op_end()); for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { if (i >= NewOps.size()) { NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i, @@ -1345,7 +1495,7 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { } NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i)); } - const SCEV* NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop()); + const SCEV *NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop()); if (Ops.size() == 2) return NewAddRec; @@ -1362,17 +1512,26 @@ const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. - std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = SCEVCommExprs[std::make_pair(scAddExpr, - SCEVOps)]; - if (Result == 0) Result = new SCEVAddExpr(Ops); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scAddExpr); + ID.AddInteger(Ops.size()); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEVAddExpr *S = SCEVAllocator.Allocate(); + new (S) SCEVAddExpr(ID, Ops); + UniqueSCEVs.InsertNode(S, IP); + if (HasNUW) S->setHasNoUnsignedWrap(true); + if (HasNSW) S->setHasNoSignedWrap(true); + return S; } /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. -const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { +const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, + bool HasNUW, bool HasNSW) { assert(!Ops.empty() && "Cannot get empty mul!"); #ifndef NDEBUG for (unsigned i = 1, e = Ops.size(); i != e; ++i) @@ -1400,7 +1559,8 @@ const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + ConstantInt *Fold = ConstantInt::get(getContext(), + LHSC->getValue()->getValue() * RHSC->getValue()->getValue()); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element @@ -1453,7 +1613,7 @@ const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this mul and add them to the vector if // they are loop invariant w.r.t. the recurrence. - SmallVector LIOps; + SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { @@ -1465,7 +1625,7 @@ const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { // If we found some loop invariants, fold them into the recurrence. 
if (!LIOps.empty()) { // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); if (LIOps.size() == 1) { const SCEV *Scale = LIOps[0]; @@ -1473,13 +1633,15 @@ const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); } else { for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - SmallVector MulOps(LIOps.begin(), LIOps.end()); + SmallVector MulOps(LIOps.begin(), LIOps.end()); MulOps.push_back(AddRec->getOperand(i)); NewOps.push_back(getMulExpr(MulOps)); } } - const SCEV* NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); + // It's tempting to propagate the NSW flag here, but nsw multiplication + // is not associative so this isn't necessarily safe. + const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -1503,14 +1665,14 @@ const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { if (AddRec->getLoop() == OtherAddRec->getLoop()) { // F * G --> {A,+,B} * {C,+,D} --> {A*C,+,F*D + G*B + B*D} const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec; - const SCEV* NewStart = getMulExpr(F->getStart(), + const SCEV *NewStart = getMulExpr(F->getStart(), G->getStart()); - const SCEV* B = F->getStepRecurrence(*this); - const SCEV* D = G->getStepRecurrence(*this); - const SCEV* NewStep = getAddExpr(getMulExpr(F, D), + const SCEV *B = F->getStepRecurrence(*this); + const SCEV *D = G->getStepRecurrence(*this); + const SCEV *NewStep = getAddExpr(getMulExpr(F, D), getMulExpr(G, B), getMulExpr(B, D)); - const SCEV* NewAddRec = getAddRecExpr(NewStart, NewStep, + const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep, F->getLoop()); if (Ops.size() == 2) return NewAddRec; @@ -1527,25 +1689,32 @@ const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl &Ops) { // Okay, it looks like we really DO need an mul expr. Check to see if we // already have one, otherwise create a new one. - std::vector SCEVOps(Ops.begin(), Ops.end()); - SCEVCommutativeExpr *&Result = SCEVCommExprs[std::make_pair(scMulExpr, - SCEVOps)]; - if (Result == 0) - Result = new SCEVMulExpr(Ops); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scMulExpr); + ID.AddInteger(Ops.size()); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEVMulExpr *S = SCEVAllocator.Allocate(); + new (S) SCEVMulExpr(ID, Ops); + UniqueSCEVs.InsertNode(S, IP); + if (HasNUW) S->setHasNoUnsignedWrap(true); + if (HasNSW) S->setHasNoSignedWrap(true); + return S; } -/// getUDivExpr - Get a canonical multiply expression, or something simpler if -/// possible. -const SCEV* ScalarEvolution::getUDivExpr(const SCEV* LHS, - const SCEV* RHS) { +/// getUDivExpr - Get a canonical unsigned division expression, or something +/// simpler if possible. 
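// The folding below reasons about division by the constant C in a type
// ExtTy that is wider than Ty by roughly log2(C) bits (one more when C
// is not a power of two). A candidate transformation is accepted only
// when zero-extending the original expression to ExtTy equals rebuilding
// it from zero-extended operands, which shows no significant bits are
// lost and the division distributes safely.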
+const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, + const SCEV *RHS) { assert(getEffectiveSCEVType(LHS->getType()) == getEffectiveSCEVType(RHS->getType()) && "SCEVUDivExpr operand types don't match!"); if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->equalsInt(1)) - return LHS; // X udiv 1 --> x + return LHS; // X udiv 1 --> x if (RHSC->isZero()) return getIntegerSCEV(0, LHS->getType()); // value is undefined @@ -1560,7 +1729,7 @@ const SCEV* ScalarEvolution::getUDivExpr(const SCEV* LHS, if (!RHSC->getValue()->getValue().isPowerOf2()) ++MaxShiftAmt; const IntegerType *ExtTy = - IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt); + IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. if (const SCEVAddRecExpr *AR = dyn_cast(LHS)) if (const SCEVConstant *Step = @@ -1571,24 +1740,24 @@ const SCEV* ScalarEvolution::getUDivExpr(const SCEV* LHS, getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop())) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i) Operands.push_back(getUDivExpr(AR->getOperand(i), RHS)); return getAddRecExpr(Operands, AR->getLoop()); } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) // Find an operand that's safely divisible. for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { - const SCEV* Op = M->getOperand(i); - const SCEV* Div = getUDivExpr(Op, RHSC); + const SCEV *Op = M->getOperand(i); + const SCEV *Div = getUDivExpr(Op, RHSC); if (!isa(Div) && getMulExpr(Div, RHSC) == Op) { - const SmallVectorImpl &MOperands = M->getOperands(); - Operands = SmallVector(MOperands.begin(), + const SmallVectorImpl &MOperands = M->getOperands(); + Operands = SmallVector(MOperands.begin(), MOperands.end()); Operands[i] = Div; return getMulExpr(Operands); @@ -1597,13 +1766,13 @@ const SCEV* ScalarEvolution::getUDivExpr(const SCEV* LHS, } // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. if (const SCEVAddRecExpr *A = dyn_cast(LHS)) { - SmallVector Operands; + SmallVector Operands; for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { Operands.clear(); for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { - const SCEV* Op = getUDivExpr(A->getOperand(i), RHS); + const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); if (isa(Op) || getMulExpr(Op, RHS) != A->getOperand(i)) break; Operands.push_back(Op); @@ -1622,17 +1791,25 @@ const SCEV* ScalarEvolution::getUDivExpr(const SCEV* LHS, } } - SCEVUDivExpr *&Result = SCEVUDivs[std::make_pair(LHS, RHS)]; - if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); - return Result; + FoldingSetNodeID ID; + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = 0; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + SCEV *S = SCEVAllocator.Allocate(); + new (S) SCEVUDivExpr(ID, LHS, RHS); + UniqueSCEVs.InsertNode(S, IP); + return S; } /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. 
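// When the step is itself an add recurrence over the same loop, the two
// recurrences flatten into one higher-order chrec: for example,
// getAddRecExpr(X, {Y,+,Z}<L>, L) yields the single recurrence
// {X,+,Y,+,Z}<L>, which is what the StepChrec handling below implements.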
-const SCEV* ScalarEvolution::getAddRecExpr(const SCEV* Start, - const SCEV* Step, const Loop *L) { - SmallVector Operands; +const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, + const SCEV *Step, const Loop *L, + bool HasNUW, bool HasNSW) { + SmallVector Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) if (StepChrec->getLoop() == L) { @@ -1642,13 +1819,15 @@ const SCEV* ScalarEvolution::getAddRecExpr(const SCEV* Start, } Operands.push_back(Step); - return getAddRecExpr(Operands, L); + return getAddRecExpr(Operands, L, HasNUW, HasNSW); } /// getAddRecExpr - Get an add recurrence expression for the specified loop. /// Simplify the expression as much as possible. -const SCEV* ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, - const Loop *L) { +const SCEV * +ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, + const Loop *L, + bool HasNUW, bool HasNSW) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG for (unsigned i = 1, e = Operands.size(); i != e; ++i) @@ -1659,37 +1838,68 @@ const SCEV* ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operand if (Operands.back()->isZero()) { Operands.pop_back(); - return getAddRecExpr(Operands, L); // {X,+,0} --> X + return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0} --> X } // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast(Operands[0])) { - const Loop* NestedLoop = NestedAR->getLoop(); + const Loop *NestedLoop = NestedAR->getLoop(); if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { - SmallVector NestedOperands(NestedAR->op_begin(), - NestedAR->op_end()); + SmallVector NestedOperands(NestedAR->op_begin(), + NestedAR->op_end()); Operands[0] = NestedAR->getStart(); - NestedOperands[0] = getAddRecExpr(Operands, L); - return getAddRecExpr(NestedOperands, NestedLoop); + // AddRecs require their operands be loop-invariant with respect to their + // loops. Don't perform this transformation if it would break this + // requirement. + bool AllInvariant = true; + for (unsigned i = 0, e = Operands.size(); i != e; ++i) + if (!Operands[i]->isLoopInvariant(L)) { + AllInvariant = false; + break; + } + if (AllInvariant) { + NestedOperands[0] = getAddRecExpr(Operands, L); + AllInvariant = true; + for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i) + if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) { + AllInvariant = false; + break; + } + if (AllInvariant) + // Ok, both add recurrences are valid after the transformation. + return getAddRecExpr(NestedOperands, NestedLoop, HasNUW, HasNSW); + } + // Reset Operands to its original state. 
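// (Operands[0] was overwritten with NestedAR->getStart() for the
// attempted canonicalization above; if either invariance check failed,
// the original nested recurrence must be restored before falling through
// to build the addrec unchanged.)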
+      Operands[0] = NestedAR;
     }
   }

-  std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
-  SCEVAddRecExpr *&Result = SCEVAddRecExprs[std::make_pair(L, SCEVOps)];
-  if (Result == 0) Result = new SCEVAddRecExpr(Operands, L);
-  return Result;
+  FoldingSetNodeID ID;
+  ID.AddInteger(scAddRecExpr);
+  ID.AddInteger(Operands.size());
+  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
+    ID.AddPointer(Operands[i]);
+  ID.AddPointer(L);
+  void *IP = 0;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEVAddRecExpr *S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
+  new (S) SCEVAddRecExpr(ID, Operands, L);
+  UniqueSCEVs.InsertNode(S, IP);
+  if (HasNUW) S->setHasNoUnsignedWrap(true);
+  if (HasNSW) S->setHasNoSignedWrap(true);
+  return S;
 }
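// [Editor's note -- illustrative sketch, not part of this patch] The new
// HasNUW and HasNSW parameters default to false, so existing callers are
// unchanged. A caller that independently knows its recurrence cannot wrap
// could write, assuming a valid ScalarEvolution &SE, Loop *L, and integer
// type Ty:
//
//   const SCEV *IV = SE.getAddRecExpr(SE.getIntegerSCEV(0, Ty),
//                                     SE.getIntegerSCEV(1, Ty), L,
//                                     /*HasNUW=*/true, /*HasNSW=*/true);
//
// which yields {0,+,1}<L> with both no-wrap flags set on the node.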
-const SCEV* ScalarEvolution::getSMaxExpr(const SCEV* LHS,
-                                         const SCEV* RHS) {
-  SmallVector<const SCEV*, 2> Ops;
+const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
+                                         const SCEV *RHS) {
+  SmallVector<const SCEV *, 2> Ops;
   Ops.push_back(LHS);
   Ops.push_back(RHS);
   return getSMaxExpr(Ops);
 }

-const SCEV*
-ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
+const SCEV *
+ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   assert(!Ops.empty() && "Cannot get empty smax!");
   if (Ops.size() == 1) return Ops[0];
 #ifndef NDEBUG
@@ -1709,7 +1919,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
     assert(Idx < Ops.size());
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
       // We found two constants, fold them together!
-      ConstantInt *Fold = ConstantInt::get(
+      ConstantInt *Fold = ConstantInt::get(getContext(),
                               APIntOps::smax(LHSC->getValue()->getValue(),
                                              RHSC->getValue()->getValue()));
       Ops[0] = getConstant(Fold);
@@ -1718,10 +1928,14 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
       LHSC = cast<SCEVConstant>(Ops[0]);
     }

-    // If we are left with a constant -inf, strip it off.
+    // If we are left with a constant minimum-int, strip it off.
     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
       Ops.erase(Ops.begin());
       --Idx;
+    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
+      // If we have an smax with a constant maximum-int, it will always be
+      // maximum-int.
+      return Ops[0];
     }
   }

@@ -1760,23 +1974,29 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {

   // Okay, it looks like we really DO need an smax expr.  Check to see if we
   // already have one, otherwise create a new one.
-  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
-  SCEVCommutativeExpr *&Result = SCEVCommExprs[std::make_pair(scSMaxExpr,
-                                                              SCEVOps)];
-  if (Result == 0) Result = new SCEVSMaxExpr(Ops);
-  return Result;
+  FoldingSetNodeID ID;
+  ID.AddInteger(scSMaxExpr);
+  ID.AddInteger(Ops.size());
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+    ID.AddPointer(Ops[i]);
+  void *IP = 0;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEV *S = SCEVAllocator.Allocate<SCEVSMaxExpr>();
+  new (S) SCEVSMaxExpr(ID, Ops);
+  UniqueSCEVs.InsertNode(S, IP);
+  return S;
 }

-const SCEV* ScalarEvolution::getUMaxExpr(const SCEV* LHS,
-                                         const SCEV* RHS) {
-  SmallVector<const SCEV*, 2> Ops;
+const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
+                                         const SCEV *RHS) {
+  SmallVector<const SCEV *, 2> Ops;
   Ops.push_back(LHS);
   Ops.push_back(RHS);
   return getUMaxExpr(Ops);
 }

-const SCEV*
-ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
+const SCEV *
+ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   assert(!Ops.empty() && "Cannot get empty umax!");
   if (Ops.size() == 1) return Ops[0];
 #ifndef NDEBUG
@@ -1796,7 +2016,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
     assert(Idx < Ops.size());
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
       // We found two constants, fold them together!
-      ConstantInt *Fold = ConstantInt::get(
+      ConstantInt *Fold = ConstantInt::get(getContext(),
                               APIntOps::umax(LHSC->getValue()->getValue(),
                                              RHSC->getValue()->getValue()));
       Ops[0] = getConstant(Fold);
@@ -1805,10 +2025,14 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
       LHSC = cast<SCEVConstant>(Ops[0]);
     }

-    // If we are left with a constant zero, strip it off.
+    // If we are left with a constant minimum-int, strip it off.
     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
       Ops.erase(Ops.begin());
       --Idx;
+    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
+      // If we have an umax with a constant maximum-int, it will always be
+      // maximum-int.
+      return Ops[0];
     }
   }

@@ -1847,33 +2071,116 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {

   // Okay, it looks like we really DO need a umax expr.  Check to see if we
   // already have one, otherwise create a new one.
-  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
-  SCEVCommutativeExpr *&Result = SCEVCommExprs[std::make_pair(scUMaxExpr,
-                                                              SCEVOps)];
-  if (Result == 0) Result = new SCEVUMaxExpr(Ops);
-  return Result;
+  FoldingSetNodeID ID;
+  ID.AddInteger(scUMaxExpr);
+  ID.AddInteger(Ops.size());
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+    ID.AddPointer(Ops[i]);
+  void *IP = 0;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEV *S = SCEVAllocator.Allocate<SCEVUMaxExpr>();
+  new (S) SCEVUMaxExpr(ID, Ops);
+  UniqueSCEVs.InsertNode(S, IP);
+  return S;
 }

-const SCEV* ScalarEvolution::getSMinExpr(const SCEV* LHS,
-                                         const SCEV* RHS) {
+const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
+                                         const SCEV *RHS) {
   // ~smax(~x, ~y) == smin(x, y).
   return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
 }

-const SCEV* ScalarEvolution::getUMinExpr(const SCEV* LHS,
-                                         const SCEV* RHS) {
+const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
+                                         const SCEV *RHS) {
   // ~umax(~x, ~y) == umin(x, y)
   return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
 }
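// [Editor's note -- illustrative sketch, not part of this patch] smin and
// umin are not materialized as their own node kinds; they are rewritten via
// the identity ~max(~x, ~y) == min(x, y), which holds because ~v == -v-1 in
// two's complement and bitwise-not reverses both the signed and the unsigned
// order. Together with uniquing, a hypothetical sanity check would pass:
//
//   assert(SE.getSMinExpr(X, Y) ==
//          SE.getNotSCEV(SE.getSMaxExpr(SE.getNotSCEV(X), SE.getNotSCEV(Y))));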
-const SCEV* ScalarEvolution::getUnknown(Value *V) {
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
-    return getConstant(CI);
-  if (isa<ConstantPointerNull>(V))
-    return getIntegerSCEV(0, V->getType());
-  SCEVUnknown *&Result = SCEVUnknowns[V];
-  if (Result == 0) Result = new SCEVUnknown(V);
-  return Result;
+const SCEV *ScalarEvolution::getFieldOffsetExpr(const StructType *STy,
+                                                unsigned FieldNo) {
+  // If we have TargetData we can determine the constant offset.
+  if (TD) {
+    const Type *IntPtrTy = TD->getIntPtrType(getContext());
+    const StructLayout &SL = *TD->getStructLayout(STy);
+    uint64_t Offset = SL.getElementOffset(FieldNo);
+    return getIntegerSCEV(Offset, IntPtrTy);
+  }
+
+  // Field 0 is always at offset 0.
+  if (FieldNo == 0) {
+    const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
+    return getIntegerSCEV(0, Ty);
+  }
+
+  // Okay, it looks like we really DO need an offsetof expr.  Check to see if
+  // we already have one, otherwise create a new one.
+  FoldingSetNodeID ID;
+  ID.AddInteger(scFieldOffset);
+  ID.AddPointer(STy);
+  ID.AddInteger(FieldNo);
+  void *IP = 0;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEV *S = SCEVAllocator.Allocate<SCEVFieldOffsetExpr>();
+  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
+  new (S) SCEVFieldOffsetExpr(ID, Ty, STy, FieldNo);
+  UniqueSCEVs.InsertNode(S, IP);
+  return S;
+}
+
+const SCEV *ScalarEvolution::getAllocSizeExpr(const Type *AllocTy) {
+  // If we have TargetData we can determine the constant size.
+  if (TD && AllocTy->isSized()) {
+    const Type *IntPtrTy = TD->getIntPtrType(getContext());
+    return getIntegerSCEV(TD->getTypeAllocSize(AllocTy), IntPtrTy);
+  }
+
+  // Expand an array size into the element size times the number
+  // of elements.
+  if (const ArrayType *ATy = dyn_cast<ArrayType>(AllocTy)) {
+    const SCEV *E = getAllocSizeExpr(ATy->getElementType());
+    return getMulExpr(
+        E, getConstant(ConstantInt::get(cast<IntegerType>(E->getType()),
                                         ATy->getNumElements())));
+  }
+
+  // Expand a vector size into the element size times the number
+  // of elements.
+  if (const VectorType *VTy = dyn_cast<VectorType>(AllocTy)) {
+    const SCEV *E = getAllocSizeExpr(VTy->getElementType());
+    return getMulExpr(
+        E, getConstant(ConstantInt::get(cast<IntegerType>(E->getType()),
                                         VTy->getNumElements())));
+  }
+
+  // Okay, it looks like we really DO need a sizeof expr.  Check to see if we
+  // already have one, otherwise create a new one.
+  FoldingSetNodeID ID;
+  ID.AddInteger(scAllocSize);
+  ID.AddPointer(AllocTy);
+  void *IP = 0;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEV *S = SCEVAllocator.Allocate<SCEVAllocSizeExpr>();
+  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
+  new (S) SCEVAllocSizeExpr(ID, Ty, AllocTy);
+  UniqueSCEVs.InsertNode(S, IP);
+  return S;
+}
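// [Editor's note -- illustrative sketch, not part of this patch] Without
// TargetData, getAllocSizeExpr keeps sizes symbolic rather than guessing a
// layout. A hypothetical query, assuming an LLVMContext &Ctx and a
// ScalarEvolution &SE:
//
//   const ArrayType *ATy = ArrayType::get(Type::getInt32Ty(Ctx), 8);
//   const SCEV *Size = SE.getAllocSizeExpr(ATy);
//   // With TargetData: the constant 32. Without: 8 * sizeof(i32), where
//   // sizeof(i32) remains a SCEVAllocSizeExpr until a target is known.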
+const SCEV *ScalarEvolution::getUnknown(Value *V) {
+  // Don't attempt to do anything other than create a SCEVUnknown object
+  // here.  createSCEV only calls getUnknown after checking for all other
+  // interesting possibilities, and any other code that calls getUnknown
+  // is doing so in order to hide a value from SCEV canonicalization.
+
+  FoldingSetNodeID ID;
+  ID.AddInteger(scUnknown);
+  ID.AddPointer(V);
+  void *IP = 0;
+  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
+  SCEV *S = SCEVAllocator.Allocate<SCEVUnknown>();
+  new (S) SCEVUnknown(ID, V);
+  UniqueSCEVs.InsertNode(S, IP);
+  return S;
 }

 //===----------------------------------------------------------------------===//
@@ -1885,17 +2192,8 @@ const SCEV* ScalarEvolution::getUnknown(Value *V) {
 /// can optionally include pointer types if the ScalarEvolution class
 /// has access to target-specific information.
 bool ScalarEvolution::isSCEVable(const Type *Ty) const {
-  // Integers are always SCEVable.
-  if (Ty->isInteger())
-    return true;
-
-  // Pointers are SCEVable if TargetData information is available
-  // to provide pointer size information.
-  if (isa<PointerType>(Ty))
-    return TD != NULL;
-
-  // Otherwise it's not SCEVable.
-  return false;
+  // Integers and pointers are always SCEVable.
+  return Ty->isInteger() || isa<PointerType>(Ty);
 }

 /// getTypeSizeInBits - Return the size in bits of the specified type,
@@ -1907,9 +2205,14 @@ uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
   if (TD)
     return TD->getTypeSizeInBits(Ty);

-  // Otherwise, we support only integer types.
-  assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!");
-  return Ty->getPrimitiveSizeInBits();
+  // Integer types have fixed sizes.
+  if (Ty->isInteger())
+    return Ty->getPrimitiveSizeInBits();
+
+  // The only other supported type is pointer. Without TargetData,
+  // conservatively assume pointers are 64-bit.
+  assert(isa<PointerType>(Ty) && "isSCEVable permitted a non-SCEVable type!");
+  return 64;
 }

 /// getEffectiveSCEVType - Return a type with the same bitwidth as
@@ -1922,73 +2225,67 @@ const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
   if (Ty->isInteger())
     return Ty;

+  // The only other supported type is pointer.
   assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!");
-  return TD->getIntPtrType();
-}
-
-const SCEV* ScalarEvolution::getCouldNotCompute() {
-  return CouldNotCompute;
+  if (TD) return TD->getIntPtrType(getContext());

-/// hasSCEV - Return true if the SCEV for this value has already been
-/// computed.
-bool ScalarEvolution::hasSCEV(Value *V) const {
-  return Scalars.count(V);
+  // Without TargetData, conservatively assume pointers are 64-bit.
+  return Type::getInt64Ty(getContext());
 }

+const SCEV *ScalarEvolution::getCouldNotCompute() {
+  return &CouldNotCompute;
+}

 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
 /// expression and create a new one.
-const SCEV* ScalarEvolution::getSCEV(Value *V) {
+const SCEV *ScalarEvolution::getSCEV(Value *V) {
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");

-  std::map<SCEVCallbackVH, const SCEV*>::iterator I = Scalars.find(V);
+  std::map<SCEVCallbackVH, const SCEV *>::iterator I = Scalars.find(V);
   if (I != Scalars.end()) return I->second;
-  const SCEV* S = createSCEV(V);
+  const SCEV *S = createSCEV(V);
   Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
   return S;
 }
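// [Editor's note -- illustrative sketch, not part of this patch] getSCEV
// memoizes its result per Value in the Scalars map, so repeated queries do
// not re-run createSCEV and return the identical pointer:
//
//   const SCEV *S1 = SE.getSCEV(V);
//   const SCEV *S2 = SE.getSCEV(V);  // map hit, no re-analysis
//   assert(S1 == S2);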
 /// getIntegerSCEV - Given a SCEVable type, create a constant for the
 /// specified signed integer value and return a SCEV for the constant.
-/// getIntegerSCEV - Given an integer or FP type, create a constant for the
-const SCEV* ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
-  Ty = getEffectiveSCEVType(Ty);
-  Constant *C;
-  if (Val == 0)
-    C = Constant::getNullValue(Ty);
-  else if (Ty->isFloatingPoint())
-    C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle :
-                                APFloat::IEEEdouble, Val));
-  else
-    C = ConstantInt::get(Ty, Val);
-  return getUnknown(C);
+const SCEV *ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
+  const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
+  return getConstant(ConstantInt::get(ITy, Val));
 }

 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
 ///
-const SCEV* ScalarEvolution::getNegativeSCEV(const SCEV* V) {
+const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
-    return getConstant(cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
+    return getConstant(
+               cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));

   const Type *Ty = V->getType();
   Ty = getEffectiveSCEVType(Ty);
-  return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(Ty)));
+  return getMulExpr(V,
+                  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
 }

 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
-const SCEV* ScalarEvolution::getNotSCEV(const SCEV* V) {
+const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
-    return getConstant(cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
+    return getConstant(
+                cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));

   const Type *Ty = V->getType();
   Ty = getEffectiveSCEVType(Ty);
-  const SCEV* AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty));
+  const SCEV *AllOnes =
+                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
   return getMinusSCEV(AllOnes, V);
 }

 /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
 ///
-const SCEV* ScalarEvolution::getMinusSCEV(const SCEV* LHS,
-                                          const SCEV* RHS) {
+const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
+                                          const SCEV *RHS) {
   // X - Y --> X + -Y
   return getAddExpr(LHS, getNegativeSCEV(RHS));
 }

@@ -1996,12 +2293,12 @@ const SCEV* ScalarEvolution::getMinusSCEV(const SCEV* LHS,
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type. If the type must be extended, it is zero
 /// extended.
-const SCEV*
-ScalarEvolution::getTruncateOrZeroExtend(const SCEV* V,
+const SCEV *
+ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
                                          const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
-         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
+  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
+         (Ty->isInteger() || isa<PointerType>(Ty)) &&
          "Cannot truncate or zero extend with non-integer arguments!");
   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
     return V;  // No conversion
@@ -2013,12 +2310,12 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEV* V,
 /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type. If the type must be extended, it is sign
 /// extended.
-const SCEV*
-ScalarEvolution::getTruncateOrSignExtend(const SCEV* V,
+const SCEV *
+ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
                                          const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
-         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
+  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
+         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot truncate or zero extend with non-integer arguments!");
   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
     return V;  // No conversion
@@ -2030,11 +2327,11 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV* V,
 /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type. If the type must be extended, it is zero
 /// extended. The conversion must not be narrowing.
-const SCEV*
-ScalarEvolution::getNoopOrZeroExtend(const SCEV* V, const Type *Ty) {
+const SCEV *
+ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
-         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
+  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
+         (Ty->isInteger() || isa<PointerType>(Ty)) &&
          "Cannot noop or zero extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrZeroExtend cannot truncate!");
@@ -2046,11 +2343,11 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV* V, const Type *Ty) {
 /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type. If the type must be extended, it is sign
 /// extended. The conversion must not be narrowing.
-const SCEV*
-ScalarEvolution::getNoopOrSignExtend(const SCEV* V, const Type *Ty) {
+const SCEV *
+ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
-         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
+  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
+         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot noop or sign extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
         "getNoopOrSignExtend cannot truncate!");
@@ -2063,11 +2360,11 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV* V, const Type *Ty) {
 /// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
 /// the input value to the specified type. If the type must be extended,
 /// it is extended with unspecified bits. The conversion must not be
 /// narrowing.
-const SCEV*
-ScalarEvolution::getNoopOrAnyExtend(const SCEV* V, const Type *Ty) {
+const SCEV *
+ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
-         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
+  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
+         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot noop or any extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
         "getNoopOrAnyExtend cannot truncate!");
@@ -2078,11 +2375,11 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV* V, const Type *Ty) {

 /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type. The conversion must not be widening.
-const SCEV*
-ScalarEvolution::getTruncateOrNoop(const SCEV* V, const Type *Ty) {
+const SCEV *
+ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
-         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
+  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
+         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot truncate or noop with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
         "getTruncateOrNoop cannot extend!");
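// [Editor's note -- illustrative sketch, not part of this patch] The
// conversion helpers above differ only in which direction they may change
// the bitwidth; the "NoopOr*" forms assert rather than silently truncate.
// A hypothetical promotion of B to A's (equal or wider) type:
//
//   if (SE.getTypeSizeInBits(B->getType()) <=
//       SE.getTypeSizeInBits(A->getType()))
//     B = SE.getNoopOrZeroExtend(B, A->getType());  // never narrows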
@@ -2094,10 +2391,10 @@ ScalarEvolution::getTruncateOrNoop(const SCEV* V, const Type *Ty) {
 /// getUMaxFromMismatchedTypes - Promote the operands to the wider of
 /// the types using zero-extension, and then perform a umax operation
 /// with them.
-const SCEV* ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV* LHS,
-                                                        const SCEV* RHS) {
-  const SCEV* PromotedLHS = LHS;
-  const SCEV* PromotedRHS = RHS;
+const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
+                                                        const SCEV *RHS) {
+  const SCEV *PromotedLHS = LHS;
+  const SCEV *PromotedRHS = RHS;

   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
@@ -2110,10 +2407,10 @@ const SCEV* ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV* LHS,
 /// getUMinFromMismatchedTypes - Promote the operands to the wider of
 /// the types using zero-extension, and then perform a umin operation
 /// with them.
-const SCEV* ScalarEvolution::getUMinFromMismatchedTypes(const SCEV* LHS,
-                                                        const SCEV* RHS) {
-  const SCEV* PromotedLHS = LHS;
-  const SCEV* PromotedRHS = RHS;
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
+                                                        const SCEV *RHS) {
+  const SCEV *PromotedLHS = LHS;
+  const SCEV *PromotedRHS = RHS;

   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
@@ -2123,33 +2420,60 @@ const SCEV* ScalarEvolution::getUMinFromMismatchedTypes(const SCEV* LHS,
   return getUMinExpr(PromotedLHS, PromotedRHS);
 }

-/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
-/// the specified instruction and replaces any references to the symbolic value
-/// SymName with the specified value. This is used during PHI resolution.
-void ScalarEvolution::
-ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEV* SymName,
-                                 const SCEV* NewVal) {
-  std::map<SCEVCallbackVH, const SCEV*>::iterator SI =
-    Scalars.find(SCEVCallbackVH(I, this));
-  if (SI == Scalars.end()) return;
+/// PushDefUseChildren - Push users of the given Instruction
+/// onto the given Worklist.
+static void
+PushDefUseChildren(Instruction *I,
+                   SmallVectorImpl<Instruction *> &Worklist) {
+  // Push the def-use children onto the Worklist stack.
+  for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
+       UI != UE; ++UI)
+    Worklist.push_back(cast<Instruction>(UI));
+}
+
+/// ForgetSymbolicName - This looks up computed SCEV values for all
+/// instructions that depend on the given instruction and removes them from
+/// the Scalars map if they reference SymName. This is used during PHI
+/// resolution.
+void
+ScalarEvolution::ForgetSymbolicName(Instruction *I, const SCEV *SymName) {
+  SmallVector<Instruction *, 16> Worklist;
+  PushDefUseChildren(I, Worklist);

-  const SCEV* NV =
-    SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this);
-  if (NV == SI->second) return;  // No change.
+  SmallPtrSet<Instruction *, 8> Visited;
+  Visited.insert(I);
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.pop_back_val();
+    if (!Visited.insert(I)) continue;
+
+    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
+      Scalars.find(static_cast<Value *>(I));
+    if (It != Scalars.end()) {
+      // Short-circuit the def-use traversal if the symbolic name
+      // ceases to appear in expressions.
+      if (!It->second->hasOperand(SymName))
+        continue;

-  SI->second = NV;            // Update the scalars map!
+      // SCEVUnknown for a PHI either means that it has an unrecognized
+      // structure, or it's a PHI that's in the process of being computed
+      // by createNodeForPHI.  In the former case, additional loop trip
+      // count information isn't going to change anything. In the latter
+      // case, createNodeForPHI will perform the necessary updates on its
+      // own when it gets to that point.
+      if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second)) {
+        ValuesAtScopes.erase(It->second);
+        Scalars.erase(It);
+      }
+    }

-  // Any instruction values that use this instruction might also need to be
-  // updated!
-  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
-       UI != E; ++UI)
-    ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
+    PushDefUseChildren(I, Worklist);
+  }
 }

 /// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in
 /// a loop header, making it a potential recurrence, or it doesn't.
 ///
-const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) {
+const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
   if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
     if (const Loop *L = LI->getLoopFor(PN->getParent()))
       if (L->getHeader() == PN->getParent()) {
@@ -2159,14 +2483,15 @@ const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) {
         unsigned BackEdge     = IncomingEdge^1;

         // While we are analyzing this PHI node, handle its value symbolically.
-        const SCEV* SymbolicName = getUnknown(PN);
+        const SCEV *SymbolicName = getUnknown(PN);
         assert(Scalars.find(PN) == Scalars.end() &&
                "PHI node already processed?");
         Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));

         // Using this symbolic name for the PHI, analyze the value coming around
         // the back-edge.
-        const SCEV* BEValue = getSCEV(PN->getIncomingValue(BackEdge));
+        Value *BEValueV = PN->getIncomingValue(BackEdge);
+        const SCEV *BEValue = getSCEV(BEValueV);

         // NOTE: If BEValue is loop invariant, we know that the PHI node just
         // has a special value for the first iteration of the loop.
@@ -2186,26 +2511,48 @@ const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) {
           if (FoundIndex != Add->getNumOperands()) {
             // Create an add with everything but the specified operand.
-            SmallVector<const SCEV*, 8> Ops;
+            SmallVector<const SCEV *, 8> Ops;
             for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
               if (i != FoundIndex)
                 Ops.push_back(Add->getOperand(i));
-            const SCEV* Accum = getAddExpr(Ops);
+            const SCEV *Accum = getAddExpr(Ops);

             // This is not a valid addrec if the step amount is varying each
             // loop iteration, but is not itself an addrec in this loop.
             if (Accum->isLoopInvariant(L) ||
                 (isa<SCEVAddRecExpr>(Accum) &&
                  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
-              const SCEV* StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
-              const SCEV* PHISCEV = getAddRecExpr(StartVal, Accum, L);
+              const SCEV *StartVal =
+                getSCEV(PN->getIncomingValue(IncomingEdge));
+              const SCEVAddRecExpr *PHISCEV =
+                cast<SCEVAddRecExpr>(getAddRecExpr(StartVal, Accum, L));
+
+              // If the increment doesn't overflow, then neither the addrec nor
+              // the post-increment will overflow.
+              if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV))
+                if (OBO->getOperand(0) == PN &&
+                    getSCEV(OBO->getOperand(1)) ==
+                      PHISCEV->getStepRecurrence(*this)) {
+                  const SCEVAddRecExpr *PostInc = PHISCEV->getPostIncExpr(*this);
+                  if (OBO->hasNoUnsignedWrap()) {
+                    const_cast<SCEVAddRecExpr *>(PHISCEV)
+                      ->setHasNoUnsignedWrap(true);
+                    const_cast<SCEVAddRecExpr *>(PostInc)
+                      ->setHasNoUnsignedWrap(true);
+                  }
+                  if (OBO->hasNoSignedWrap()) {
+                    const_cast<SCEVAddRecExpr *>(PHISCEV)
+                      ->setHasNoSignedWrap(true);
+                    const_cast<SCEVAddRecExpr *>(PostInc)
+                      ->setHasNoSignedWrap(true);
+                  }
+                }

               // Okay, for the entire analysis of this edge we assumed the PHI
-              // to be symbolic.  We now need to go back and update all of the
-              // entries for the scalars that use the PHI (except for the PHI
-              // itself) to use the new analyzed value instead of the "symbolic"
-              // value.
-              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
+              // to be symbolic.  We now need to go back and purge all of the
+              // entries for the scalars that use the symbolic expression.
+              ForgetSymbolicName(PN, SymbolicName);
+              Scalars[SCEVCallbackVH(PN, this)] = PHISCEV;
               return PHISCEV;
             }
           }
@@ -2217,21 +2564,20 @@ const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) {
           // Because the other in-value of i (0) fits the evolution of BEValue
           // i really is an addrec evolution.
           if (AddRec->getLoop() == L && AddRec->isAffine()) {
-            const SCEV* StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
+            const SCEV *StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));

             // If StartVal = j.start - j.stride, we can use StartVal as the
             // initial step of the addrec evolution.
             if (StartVal == getMinusSCEV(AddRec->getOperand(0),
                                          AddRec->getOperand(1))) {
-              const SCEV* PHISCEV =
+              const SCEV *PHISCEV =
                  getAddRecExpr(StartVal, AddRec->getOperand(1), L);

               // Okay, for the entire analysis of this edge we assumed the PHI
-              // to be symbolic.  We now need to go back and update all of the
-              // entries for the scalars that use the PHI (except for the PHI
-              // itself) to use the new analyzed value instead of the "symbolic"
-              // value.
-              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
+              // to be symbolic.  We now need to go back and purge all of the
+              // entries for the scalars that use the symbolic expression.
+              ForgetSymbolicName(PN, SymbolicName);
+              Scalars[SCEVCallbackVH(PN, this)] = PHISCEV;
               return PHISCEV;
             }
           }
@@ -2240,6 +2586,10 @@ const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) {
         return SymbolicName;
       }

+  // It's tempting to recognize PHIs with a unique incoming value, however
+  // this leads passes like indvars to break LCSSA form. Fortunately, such
+  // PHIs are rare, as instcombine zaps them.
+
   // If it's not a loop phi, we can't handle it yet.
   return getUnknown(PN);
 }
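// [Editor's note -- illustrative sketch, not part of this patch] The kind of
// header PHI the code above recognizes, shown as IR:
//
//   loop:
//     %i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
//     %i.next = add nuw nsw i32 %i, 1
//
// getSCEV(%i) becomes the addrec {0,+,1}<loop>, and because the back-edge
// value is an add of the PHI and the step, the nuw/nsw flags on %i.next are
// transferred to both the addrec and its post-increment form.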
 /// createNodeForGEP - Expand GEP instructions into add and multiply
 /// operations. This allows them to be analyzed by regular SCEV code.
 ///
-const SCEV* ScalarEvolution::createNodeForGEP(User *GEP) {
+const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {

-  const Type *IntPtrTy = TD->getIntPtrType();
+  bool InBounds = GEP->isInBounds();
+  const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
   Value *Base = GEP->getOperand(0);
   // Don't attempt to analyze GEPs over unsized objects.
   if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
     return getUnknown(GEP);
-  const SCEV* TotalOffset = getIntegerSCEV(0, IntPtrTy);
+  const SCEV *TotalOffset = getIntegerSCEV(0, IntPtrTy);
   gep_type_iterator GTI = gep_type_begin(GEP);
   for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
                                       E = GEP->op_end();
@@ -2263,26 +2614,25 @@ const SCEV* ScalarEvolution::createNodeForGEP(User *GEP) {
     // Compute the (potentially symbolic) offset in bytes for this index.
     if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
       // For a struct, add the member offset.
-      const StructLayout &SL = *TD->getStructLayout(STy);
       unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
-      uint64_t Offset = SL.getElementOffset(FieldNo);
       TotalOffset = getAddExpr(TotalOffset,
-                               getIntegerSCEV(Offset, IntPtrTy));
+                               getFieldOffsetExpr(STy, FieldNo),
+                               /*HasNUW=*/false, /*HasNSW=*/InBounds);
     } else {
       // For an array, add the element offset, explicitly scaled.
-      const SCEV* LocalOffset = getSCEV(Index);
+      const SCEV *LocalOffset = getSCEV(Index);
       if (!isa<PointerType>(LocalOffset->getType()))
         // Getelementptr indices are signed.
-        LocalOffset = getTruncateOrSignExtend(LocalOffset,
-                                              IntPtrTy);
-      LocalOffset =
-        getMulExpr(LocalOffset,
-                   getIntegerSCEV(TD->getTypeAllocSize(*GTI),
-                                  IntPtrTy));
-      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
+        LocalOffset = getTruncateOrSignExtend(LocalOffset, IntPtrTy);
+      // Lower "inbounds" GEPs to NSW arithmetic.
+      LocalOffset = getMulExpr(LocalOffset, getAllocSizeExpr(*GTI),
+                               /*HasNUW=*/false, /*HasNSW=*/InBounds);
+      TotalOffset = getAddExpr(TotalOffset, LocalOffset,
+                               /*HasNUW=*/false, /*HasNSW=*/InBounds);
     }
   }

-  return getAddExpr(getSCEV(Base), TotalOffset);
+  return getAddExpr(getSCEV(Base), TotalOffset,
+                    /*HasNUW=*/false, /*HasNSW=*/InBounds);
 }
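// [Editor's note -- illustrative sketch, not part of this patch] For a
// hypothetical inbounds GEP such as
//
//   %p = getelementptr inbounds {i32, [4 x i16]}* %base, i64 0, i32 1, i64 %i
//
// the loop above produces, in effect,
//   getSCEV(%base) + offsetof(struct, field 1) + %i * sizeof(i16)
// with the add and mul marked no-signed-wrap, since "inbounds" promises the
// address arithmetic stays within the object.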
 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
@@ -2290,7 +2640,7 @@ const SCEV* ScalarEvolution::createNodeForGEP(User *GEP) {
 /// the minimum number of times S is divisible by 2. For example, given {4,+,8}
 /// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S.
 uint32_t
-ScalarEvolution::GetMinTrailingZeros(const SCEV* S) {
+ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
     return C->getValue()->getValue().countTrailingZeros();

@@ -2366,18 +2716,100 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV* S) {
   return 0;
 }

-uint32_t
-ScalarEvolution::GetMinLeadingZeros(const SCEV* S) {
-  // TODO: Handle other SCEV expression types here.
+/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
+///
+ConstantRange
+ScalarEvolution::getUnsignedRange(const SCEV *S) {

   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
-    return C->getValue()->getValue().countLeadingZeros();
+    return ConstantRange(C->getValue()->getValue());
+
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+    ConstantRange X = getUnsignedRange(Add->getOperand(0));
+    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
+      X = X.add(getUnsignedRange(Add->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+    ConstantRange X = getUnsignedRange(Mul->getOperand(0));
+    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
+      X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
+    ConstantRange X = getUnsignedRange(SMax->getOperand(0));
+    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
+      X = X.smax(getUnsignedRange(SMax->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
+    ConstantRange X = getUnsignedRange(UMax->getOperand(0));
+    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
+      X = X.umax(getUnsignedRange(UMax->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
+    ConstantRange X = getUnsignedRange(UDiv->getLHS());
+    ConstantRange Y = getUnsignedRange(UDiv->getRHS());
+    return X.udiv(Y);
+  }
+
+  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
+    ConstantRange X = getUnsignedRange(ZExt->getOperand());
+    return X.zeroExtend(cast<IntegerType>(ZExt->getType())->getBitWidth());
+  }
+
+  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
+    ConstantRange X = getUnsignedRange(SExt->getOperand());
+    return X.signExtend(cast<IntegerType>(SExt->getType())->getBitWidth());
+  }
+
+  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
+    ConstantRange X = getUnsignedRange(Trunc->getOperand());
+    return X.truncate(cast<IntegerType>(Trunc->getType())->getBitWidth());
+  }

-  if (const SCEVZeroExtendExpr *C = dyn_cast<SCEVZeroExtendExpr>(S)) {
-    // A zero-extension cast adds zero bits.
-    return GetMinLeadingZeros(C->getOperand()) +
-           (getTypeSizeInBits(C->getType()) -
-            getTypeSizeInBits(C->getOperand()->getType()));
+  ConstantRange FullSet(getTypeSizeInBits(S->getType()), true);
+
+  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+    const SCEV *T = getBackedgeTakenCount(AddRec->getLoop());
+    const SCEVConstant *Trip = dyn_cast<SCEVConstant>(T);
+    if (!Trip) return FullSet;
+
+    // TODO: non-affine addrec
+    if (AddRec->isAffine()) {
+      const Type *Ty = AddRec->getType();
+      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
+      if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) {
+        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
+
+        const SCEV *Start = AddRec->getStart();
+        const SCEV *Step = AddRec->getStepRecurrence(*this);
+        const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
+
+        // Check for overflow.
+        // TODO: This is very conservative.
+        if (!(Step->isOne() &&
+              isKnownPredicate(ICmpInst::ICMP_ULT, Start, End)) &&
+            !(Step->isAllOnesValue() &&
+              isKnownPredicate(ICmpInst::ICMP_UGT, Start, End)))
+          return FullSet;
+
+        ConstantRange StartRange = getUnsignedRange(Start);
+        ConstantRange EndRange = getUnsignedRange(End);
+        APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
                                    EndRange.getUnsignedMin());
+        APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
                                    EndRange.getUnsignedMax());
+        if (Min.isMinValue() && Max.isMaxValue())
+          return FullSet;
+        return ConstantRange(Min, Max+1);
+      }
+    }
   }

   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
@@ -2386,41 +2818,128 @@ ScalarEvolution::GetMinLeadingZeros(const SCEV* S) {
     APInt Mask = APInt::getAllOnesValue(BitWidth);
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
-    return Zeros.countLeadingOnes();
+    if (Ones == ~Zeros + 1)
+      return FullSet;
+    return ConstantRange(Ones, ~Zeros + 1);
   }

-  return 1;
+  return FullSet;
 }
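// [Editor's note -- illustrative sketch, not part of this patch] A worked
// instance of the addrec case above: for an i8 recurrence {0,+,1} in a loop
// whose max backedge-taken count is the constant 9, Start is 0 and End
// evaluates to 9, the Step->isOne()/ICMP_ULT overflow check succeeds, and
// the result is the unsigned range [0, 10).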
-uint32_t
-ScalarEvolution::GetMinSignBits(const SCEV* S) {
-  // TODO: Handle other SCEV expression types here.
+/// getSignedRange - Determine the signed range for a particular SCEV.
+///
+ConstantRange
+ScalarEvolution::getSignedRange(const SCEV *S) {

-  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-    const APInt &A = C->getValue()->getValue();
-    return A.isNegative() ? A.countLeadingOnes() :
-                            A.countLeadingZeros();
+  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+    return ConstantRange(C->getValue()->getValue());
+
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+    ConstantRange X = getSignedRange(Add->getOperand(0));
+    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
+      X = X.add(getSignedRange(Add->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+    ConstantRange X = getSignedRange(Mul->getOperand(0));
+    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
+      X = X.multiply(getSignedRange(Mul->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
+    ConstantRange X = getSignedRange(SMax->getOperand(0));
+    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
+      X = X.smax(getSignedRange(SMax->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
+    ConstantRange X = getSignedRange(UMax->getOperand(0));
+    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
+      X = X.umax(getSignedRange(UMax->getOperand(i)));
+    return X;
+  }
+
+  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
+    ConstantRange X = getSignedRange(UDiv->getLHS());
+    ConstantRange Y = getSignedRange(UDiv->getRHS());
+    return X.udiv(Y);
+  }
+
+  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
+    ConstantRange X = getSignedRange(ZExt->getOperand());
+    return X.zeroExtend(cast<IntegerType>(ZExt->getType())->getBitWidth());
+  }
+
+  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
+    ConstantRange X = getSignedRange(SExt->getOperand());
+    return X.signExtend(cast<IntegerType>(SExt->getType())->getBitWidth());
   }

-  if (const SCEVSignExtendExpr *C = dyn_cast<SCEVSignExtendExpr>(S)) {
-    // A sign-extension cast adds sign bits.
-    return GetMinSignBits(C->getOperand()) +
-           (getTypeSizeInBits(C->getType()) -
-            getTypeSizeInBits(C->getOperand()->getType()));
+  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
+    ConstantRange X = getSignedRange(Trunc->getOperand());
+    return X.truncate(cast<IntegerType>(Trunc->getType())->getBitWidth());
+  }
+
+  ConstantRange FullSet(getTypeSizeInBits(S->getType()), true);
+
+  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+    const SCEV *T = getBackedgeTakenCount(AddRec->getLoop());
+    const SCEVConstant *Trip = dyn_cast<SCEVConstant>(T);
+    if (!Trip) return FullSet;
+
+    // TODO: non-affine addrec
+    if (AddRec->isAffine()) {
+      const Type *Ty = AddRec->getType();
+      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
+      if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) {
+        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
+
+        const SCEV *Start = AddRec->getStart();
+        const SCEV *Step = AddRec->getStepRecurrence(*this);
+        const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
+
+        // Check for overflow.
+        // TODO: This is very conservative.
+        if (!(Step->isOne() &&
+              isKnownPredicate(ICmpInst::ICMP_SLT, Start, End)) &&
+            !(Step->isAllOnesValue() &&
+              isKnownPredicate(ICmpInst::ICMP_SGT, Start, End)))
+          return FullSet;
+
+        ConstantRange StartRange = getSignedRange(Start);
+        ConstantRange EndRange = getSignedRange(End);
+        APInt Min = APIntOps::smin(StartRange.getSignedMin(),
                                    EndRange.getSignedMin());
+        APInt Max = APIntOps::smax(StartRange.getSignedMax(),
                                    EndRange.getSignedMax());
+        if (Min.isMinSignedValue() && Max.isMaxSignedValue())
+          return FullSet;
+        return ConstantRange(Min, Max+1);
+      }
+    }
   }

   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
     // For a SCEVUnknown, ask ValueTracking.
-    return ComputeNumSignBits(U->getValue(), TD);
+    unsigned BitWidth = getTypeSizeInBits(U->getType());
+    unsigned NS = ComputeNumSignBits(U->getValue(), TD);
+    if (NS == 1)
+      return FullSet;
+    return
+      ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
+                    APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1);
   }

-  return 1;
+  return FullSet;
 }
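// [Editor's note -- illustrative sketch, not part of this patch] The
// SCEVUnknown case above turns a known-sign-bits fact into a range: with
// NS = 16 sign bits known in an i32 value, every such value fits in 17
// signed bits, so the result is [-2^16, 2^16), i.e.
// ConstantRange(0xFFFF0000, 0x00010000).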
 /// createSCEV - We know that there is no SCEV for the specified value.
 /// Analyze the expression.
 ///
-const SCEV* ScalarEvolution::createSCEV(Value *V) {
+const SCEV *ScalarEvolution::createSCEV(Value *V) {
   if (!isSCEVable(V->getType()))
     return getUnknown(V);

@@ -2429,15 +2948,29 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
     Opcode = I->getOpcode();
   else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
     Opcode = CE->getOpcode();
+  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+    return getConstant(CI);
+  else if (isa<ConstantPointerNull>(V))
+    return getIntegerSCEV(0, V->getType());
+  else if (isa<UndefValue>(V))
+    return getIntegerSCEV(0, V->getType());
+  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
+    return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
   else
     return getUnknown(V);

-  User *U = cast<User>(V);
+  Operator *U = cast<Operator>(V);
   switch (Opcode) {
   case Instruction::Add:
+    // Don't transfer the NSW and NUW bits from the Add instruction to the
+    // Add expression, because the Instruction may be guarded by control
+    // flow and the no-overflow bits may not be valid for the expression in
+    // any context.
     return getAddExpr(getSCEV(U->getOperand(0)),
                       getSCEV(U->getOperand(1)));
   case Instruction::Mul:
+    // Don't transfer the NSW and NUW bits from the Mul instruction to the
+    // Mul expression, as with Add.
     return getMulExpr(getSCEV(U->getOperand(0)),
                       getSCEV(U->getOperand(1)));
   case Instruction::UDiv:
@@ -2471,7 +3004,7 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
       if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
         return
           getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
-                                            IntegerType::get(BitWidth - LZ)),
+                              IntegerType::get(getContext(), BitWidth - LZ)),
                             U->getType());
     }
     break;
@@ -2484,11 +3017,23 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
     // In order for this transformation to be safe, the LHS must be of the
     // form X*(2^n) and the Or constant must be less than 2^n.
     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      const SCEV* LHS = getSCEV(U->getOperand(0));
+      const SCEV *LHS = getSCEV(U->getOperand(0));
       const APInt &CIVal = CI->getValue();
       if (GetMinTrailingZeros(LHS) >=
-          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
-        return getAddExpr(LHS, getSCEV(U->getOperand(1)));
+          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
+        // Build a plain add SCEV.
+        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
+        // If the LHS of the add was an addrec and it has no-wrap flags,
+        // transfer the no-wrap flags, since an or won't introduce a wrap.
+        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
+          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
+          if (OldAR->hasNoUnsignedWrap())
+            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoUnsignedWrap(true);
+          if (OldAR->hasNoSignedWrap())
+            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoSignedWrap(true);
+        }
+        return S;
+      }
     }
     break;
   case Instruction::Xor:
@@ -2514,7 +3059,7 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
       if (const SCEVZeroExtendExpr *Z =
             dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
         const Type *UTy = U->getType();
-        const SCEV* Z0 = Z->getOperand();
+        const SCEV *Z0 = Z->getOperand();
         const Type *Z0Ty = Z0->getType();
         unsigned Z0TySize = getTypeSizeInBits(Z0Ty);

@@ -2540,7 +3085,7 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
     // Turn shift left of a constant amount into a multiply.
     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
-      Constant *X = ConstantInt::get(
+      Constant *X = ConstantInt::get(getContext(),
         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
       return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
     }
@@ -2550,7 +3095,7 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
    // Turn logical shift right of a constant into a unsigned divide.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
-      Constant *X = ConstantInt::get(
+      Constant *X = ConstantInt::get(getContext(),
         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
     }
@@ -2570,7 +3115,7 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
           return getIntegerSCEV(0, U->getType()); // value is undefined

         return
           getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
-                                            IntegerType::get(Amt)),
+                                            IntegerType::get(getContext(), Amt)),
                             U->getType());
       }
     break;
@@ -2590,19 +3135,13 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
       return getSCEV(U->getOperand(0));
     break;

-  case Instruction::IntToPtr:
-    if (!TD) break; // Without TD we can't analyze pointers.
-    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
-                                   TD->getIntPtrType());
-
-  case Instruction::PtrToInt:
-    if (!TD) break; // Without TD we can't analyze pointers.
-    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
-                                   U->getType());
+    // It's tempting to handle inttoptr and ptrtoint, however this can
+    // lead to pointer expressions which cannot be expanded to GEPs
+    // (because they may overflow). For now, the only pointer-typed
+    // expressions we handle are GEPs and address literals.

   case Instruction::GetElementPtr:
-    if (!TD) break; // Without TD we can't analyze pointers.
-    return createNodeForGEP(U);
+    return createNodeForGEP(cast<GEPOperator>(U));

   case Instruction::PHI:
     return createNodeForPHI(cast<PHINode>(U));

@@ -2683,17 +3222,29 @@ const SCEV* ScalarEvolution::createSCEV(Value *V) {
 /// loop-invariant backedge-taken count (see
 /// hasLoopInvariantBackedgeTakenCount).
 ///
-const SCEV* ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
+const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
   return getBackedgeTakenInfo(L).Exact;
 }

 /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
 /// return the least SCEV value that is known never to be less than the
 /// actual backedge taken count.
-const SCEV* ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
+const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
   return getBackedgeTakenInfo(L).Max;
 }
+/// PushLoopPHIs - Push PHI nodes in the header of the given loop
+/// onto the given Worklist.
+static void
+PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
+  BasicBlock *Header = L->getHeader();
+
+  // Push all Loop-header PHIs onto the Worklist stack.
+  for (BasicBlock::iterator I = Header->begin();
+       PHINode *PN = dyn_cast<PHINode>(I); ++I)
+    Worklist.push_back(PN);
+}
+
 const ScalarEvolution::BackedgeTakenInfo &
 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
   // Initially insert a CouldNotCompute for this loop. If the insertion
   // succeeds, proceed to actually compute a backedge-taken count and
   // update the value. The temporary CouldNotCompute value tells SCEV
   // code elsewhere that it shouldn't attempt to request a new
   // backedge-taken count, which could result in infinite recursion.
-  std::pair<std::map<const Loop*, BackedgeTakenInfo>::iterator, bool> Pair =
+  std::pair<std::map<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
     BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
   if (Pair.second) {
     BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L);
-    if (ItCount.Exact != CouldNotCompute) {
+    if (ItCount.Exact != getCouldNotCompute()) {
       assert(ItCount.Exact->isLoopInvariant(L) &&
              ItCount.Max->isLoopInvariant(L) &&
              "Computed trip count isn't loop invariant for loop!");
@@ -2714,7 +3265,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
       // Update the value in the map.
       Pair.first->second = ItCount;
     } else {
-      if (ItCount.Max != CouldNotCompute)
+      if (ItCount.Max != getCouldNotCompute())
         // Update the value in the map.
         Pair.first->second = ItCount;
       if (isa<PHINode>(L->getHeader()->begin()))
@@ -2724,50 +3275,68 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {

     // Now that we know more about the trip count for this loop, forget any
     // existing SCEV values for PHI nodes in this loop since they are only
-    // conservative estimates made without the benefit
-    // of trip count information.
-    if (ItCount.hasAnyInfo())
-      forgetLoopPHIs(L);
+    // conservative estimates made without the benefit of trip count
+    // information. This is similar to the code in forgetLoop, except that
+    // it handles SCEVUnknown PHI nodes specially.
+    if (ItCount.hasAnyInfo()) {
+      SmallVector<Instruction *, 16> Worklist;
+      PushLoopPHIs(L, Worklist);
+
+      SmallPtrSet<Instruction *, 8> Visited;
+      while (!Worklist.empty()) {
+        Instruction *I = Worklist.pop_back_val();
+        if (!Visited.insert(I)) continue;
+
+        std::map<SCEVCallbackVH, const SCEV *>::iterator It =
+          Scalars.find(static_cast<Value *>(I));
+        if (It != Scalars.end()) {
+          // SCEVUnknown for a PHI either means that it has an unrecognized
+          // structure, or it's a PHI that's in the process of being computed
+          // by createNodeForPHI.  In the former case, additional loop trip
+          // count information isn't going to change anything. In the latter
+          // case, createNodeForPHI will perform the necessary updates on
+          // its own when it gets to that point.
+          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second)) {
+            ValuesAtScopes.erase(It->second);
+            Scalars.erase(It);
+          }
+          if (PHINode *PN = dyn_cast<PHINode>(I))
+            ConstantEvolutionLoopExitValue.erase(PN);
+        }
+
+        PushDefUseChildren(I, Worklist);
+      }
+    }
   }
   return Pair.first->second;
 }
-/// forgetLoopBackedgeTakenCount - This method should be called by the
-/// client when it has changed a loop in a way that may effect
-/// ScalarEvolution's ability to compute a trip count, or if the loop
-/// is deleted.
-void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
+/// forgetLoop - This method should be called by the client when it has
+/// changed a loop in a way that may affect ScalarEvolution's ability to
+/// compute a trip count, or if the loop is deleted.
+void ScalarEvolution::forgetLoop(const Loop *L) {
+  // Drop any stored trip count value.
   BackedgeTakenCounts.erase(L);
-  forgetLoopPHIs(L);
-}

-/// forgetLoopPHIs - Delete the memoized SCEVs associated with the
-/// PHI nodes in the given loop. This is used when the trip count of
-/// the loop may have changed.
-void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
-  BasicBlock *Header = L->getHeader();
-
-  // Push all Loop-header PHIs onto the Worklist stack, except those
-  // that are presently represented via a SCEVUnknown. SCEVUnknown for
-  // a PHI either means that it has an unrecognized structure, or it's
-  // a PHI that's in the progress of being computed by createNodeForPHI.
-  // In the former case, additional loop trip count information isn't
-  // going to change anything. In the later case, createNodeForPHI will
-  // perform the necessary updates on its own when it gets to that point.
+  // Drop information about expressions based on loop-header PHIs.
   SmallVector<Instruction *, 16> Worklist;
-  for (BasicBlock::iterator I = Header->begin();
-       PHINode *PN = dyn_cast<PHINode>(I); ++I) {
-    std::map<SCEVCallbackVH, const SCEV*>::iterator It =
-      Scalars.find((Value*)I);
-    if (It != Scalars.end() && !isa<SCEVUnknown>(It->second))
-      Worklist.push_back(PN);
-  }
+  PushLoopPHIs(L, Worklist);

+  SmallPtrSet<Instruction *, 8> Visited;
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
-    if (Scalars.erase(I))
-      for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
-           UI != UE; ++UI)
-        Worklist.push_back(cast<Instruction>(UI));
+    if (!Visited.insert(I)) continue;
+
+    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
+      Scalars.find(static_cast<Value *>(I));
+    if (It != Scalars.end()) {
+      ValuesAtScopes.erase(It->second);
+      Scalars.erase(It);
+      if (PHINode *PN = dyn_cast<PHINode>(I))
+        ConstantEvolutionLoopExitValue.erase(PN);
+    }
+
+    PushDefUseChildren(I, Worklist);
   }
 }
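// [Editor's note -- illustrative sketch, not part of this patch] forgetLoop
// replaces the old forgetLoopBackedgeTakenCount/forgetLoopPHIs pair with a
// single entry point. A transformation pass that rewrites or deletes a loop
// is expected to call it:
//
//   SE.forgetLoop(L);  // drops the cached trip count and every SCEV
//                      // reachable from L's header PHIs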
@@ -2775,45 +3344,32 @@ void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
 /// of the specified loop will execute.
 ScalarEvolution::BackedgeTakenInfo
 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
-  SmallVector<BasicBlock*, 8> ExitingBlocks;
+  SmallVector<BasicBlock *, 8> ExitingBlocks;
   L->getExitingBlocks(ExitingBlocks);

   // Examine all exits and pick the most conservative values.
-  const SCEV* BECount = CouldNotCompute;
-  const SCEV* MaxBECount = CouldNotCompute;
+  const SCEV *BECount = getCouldNotCompute();
+  const SCEV *MaxBECount = getCouldNotCompute();
   bool CouldNotComputeBECount = false;
-  bool CouldNotComputeMaxBECount = false;
   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
     BackedgeTakenInfo NewBTI =
       ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);

-    if (NewBTI.Exact == CouldNotCompute) {
+    if (NewBTI.Exact == getCouldNotCompute()) {
       // We couldn't compute an exact value for this exit, so
       // we won't be able to compute an exact value for the loop.
       CouldNotComputeBECount = true;
-      BECount = CouldNotCompute;
+      BECount = getCouldNotCompute();
     } else if (!CouldNotComputeBECount) {
-      if (BECount == CouldNotCompute)
+      if (BECount == getCouldNotCompute())
         BECount = NewBTI.Exact;
-      else {
-        // TODO: More analysis could be done here. For example, a
-        // loop with a short-circuiting && operator has an exact count
-        // of the min of both sides.
-        CouldNotComputeBECount = true;
-        BECount = CouldNotCompute;
-      }
-    }
-    if (NewBTI.Max == CouldNotCompute) {
-      // We couldn't compute an maximum value for this exit, so
-      // we won't be able to compute an maximum value for the loop.
-      CouldNotComputeMaxBECount = true;
-      MaxBECount = CouldNotCompute;
-    } else if (!CouldNotComputeMaxBECount) {
-      if (MaxBECount == CouldNotCompute)
-        MaxBECount = NewBTI.Max;
       else
-        MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, NewBTI.Max);
+        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
     }
+    if (MaxBECount == getCouldNotCompute())
+      MaxBECount = NewBTI.Max;
+    else if (NewBTI.Max != getCouldNotCompute())
+      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
   }

   return BackedgeTakenInfo(BECount, MaxBECount);
 }

@@ -2830,9 +3386,9 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
   //
   // FIXME: we should be able to handle switch instructions (with a single exit)
   BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
-  if (ExitBr == 0) return CouldNotCompute;
+  if (ExitBr == 0) return getCouldNotCompute();
   assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
-
+
   // At this point, we know we have a conditional branch that determines whether
   // the loop is exited. However, we don't know if the branch is executed each
   // time through the loop. If not, then the execution count of the branch will
   // not be equal to the trip count of the loop.
@@ -2859,7 +3415,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
     for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
       BasicBlock *Pred = BB->getUniquePredecessor();
       if (!Pred)
-        return CouldNotCompute;
+        return getCouldNotCompute();
       TerminatorInst *PredTerm = Pred->getTerminator();
       for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
         BasicBlock *PredSucc = PredTerm->getSuccessor(i);
@@ -2868,7 +3424,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
         // If the predecessor has a successor that isn't BB and isn't
         // outside the loop, assume the worst.
         if (L->contains(PredSucc))
-          return CouldNotCompute;
+          return getCouldNotCompute();
       }
       if (Pred == L->getHeader()) {
         Ok = true;
@@ -2877,7 +3433,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
       BB = Pred;
     }
     if (!Ok)
-      return CouldNotCompute;
+      return getCouldNotCompute();
   }

   // Proceed to the next level to examine the exit condition expression.
@@ -2894,9 +3450,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
                                                        Value *ExitCond,
                                                        BasicBlock *TBB,
                                                        BasicBlock *FBB) {
-  // Check if the controlling expression for this loop is an and or or. In
-  // such cases, an exact backedge-taken count may be infeasible, but a
-  // maximum count may still be feasible.
+  // Check if the controlling expression for this loop is an And or Or.
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
     if (BO->getOpcode() == Instruction::And) {
       // Recurse on the operands of the and.
      BackedgeTakenInfo BTI0 =
        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
      BackedgeTakenInfo BTI1 =
        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
-      const SCEV* BECount = CouldNotCompute;
-      const SCEV* MaxBECount = CouldNotCompute;
+      const SCEV *BECount = getCouldNotCompute();
+      const SCEV *MaxBECount = getCouldNotCompute();
       if (L->contains(TBB)) {
         // Both conditions must be true for the loop to continue executing.
         // Choose the less conservative count.
-        if (BTI0.Exact == CouldNotCompute || BTI1.Exact == CouldNotCompute)
-          BECount = CouldNotCompute;
+        if (BTI0.Exact == getCouldNotCompute() ||
+            BTI1.Exact == getCouldNotCompute())
+          BECount = getCouldNotCompute();
         else
           BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
-        if (BTI0.Max == CouldNotCompute)
+        if (BTI0.Max == getCouldNotCompute())
           MaxBECount = BTI1.Max;
-        else if (BTI1.Max == CouldNotCompute)
+        else if (BTI1.Max == getCouldNotCompute())
           MaxBECount = BTI0.Max;
         else
           MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
       } else {
         // Both conditions must be true for the loop to exit.
         assert(L->contains(FBB) && "Loop block has no successor in loop!");
-        if (BTI0.Exact != CouldNotCompute && BTI1.Exact != CouldNotCompute)
+        if (BTI0.Exact != getCouldNotCompute() &&
+            BTI1.Exact != getCouldNotCompute())
           BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
-        if (BTI0.Max != CouldNotCompute && BTI1.Max != CouldNotCompute)
+        if (BTI0.Max != getCouldNotCompute() &&
+            BTI1.Max != getCouldNotCompute())
           MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
       }

@@ -2936,27 +3493,30 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
       return BackedgeTakenInfo(BECount, MaxBECount);
     }
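// [Editor's note -- illustrative, not part of this patch] In the And case
// above, when the loop continues only while both conditions are true, it
// exits as soon as either condition fails, so the exact backedge-taken count
// is the umin of the operands' counts: for "while (a != m && b != n)" with
// individual counts 10 and 7, the combined exact count is 7.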
assert(L->contains(TBB) && "Loop block has no successor in loop!"); - if (BTI0.Exact != CouldNotCompute && BTI1.Exact != CouldNotCompute) + if (BTI0.Exact != getCouldNotCompute() && + BTI1.Exact != getCouldNotCompute()) BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact); - if (BTI0.Max != CouldNotCompute && BTI1.Max != CouldNotCompute) + if (BTI0.Max != getCouldNotCompute() && + BTI1.Max != getCouldNotCompute()) MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max); } @@ -2992,7 +3552,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, // Handle common loops like: for (X = "string"; *X; ++X) if (LoadInst *LI = dyn_cast(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast(ExitCond->getOperand(1))) { - const SCEV* ItCnt = + const SCEV *ItCnt = ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond); if (!isa(ItCnt)) { unsigned BitWidth = getTypeSizeInBits(ItCnt->getType()); @@ -3002,14 +3562,14 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, } } - const SCEV* LHS = getSCEV(ExitCond->getOperand(0)); - const SCEV* RHS = getSCEV(ExitCond->getOperand(1)); + const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); + const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); RHS = getSCEVAtScope(RHS, L); - // At this point, we would like to compute how many iterations of the + // At this point, we would like to compute how many iterations of the // loop the predicate will return true for these inputs. if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) { // If there is a loop-invariant, force it into the RHS. @@ -3026,20 +3586,20 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, ConstantRange CompRange( ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue())); - const SCEV* Ret = AddRec->getNumIterationsInRange(CompRange, *this); + const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; } switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - const SCEV* TC = HowFarToZero(getMinusSCEV(LHS, RHS), L); + const SCEV *TC = HowFarToZero(getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; } - case ICmpInst::ICMP_EQ: { - // Convert to: while (X-Y == 0) // while (X == Y) - const SCEV* TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + case ICmpInst::ICMP_EQ: { // while (X == Y) + // Convert to: while (X-Y == 0) + const SCEV *TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); if (!isa(TC)) return TC; break; } @@ -3071,7 +3631,7 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, if (ExitCond->getOperand(0)->getType()->isUnsigned()) errs() << "[unsigned] "; errs() << *LHS << " " - << Instruction::getOpcodeName(Instruction::ICmp) + << Instruction::getOpcodeName(Instruction::ICmp) << " " << *RHS << "\n"; #endif break; @@ -3083,8 +3643,8 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L, static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE) { - const SCEV* InVal = SE.getConstant(C); - const SCEV* Val = AddRec->evaluateAtIteration(InVal, SE); + const SCEV *InVal = SE.getConstant(C); + const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); return cast(Val)->getValue(); @@ -3114,7 +3674,7 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, if (Idx 
>= ATy->getNumElements()) return 0; // Bogus program Init = Constant::getNullValue(ATy->getElementType()); } else { - assert(0 && "Unknown constant aggregate type!"); + llvm_unreachable("Unknown constant aggregate type!"); } return 0; } else { @@ -3127,23 +3687,25 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, /// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of /// 'icmp op load X, cst', try to see if we can compute the backedge /// execution count. -const SCEV* ScalarEvolution:: -ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, - const Loop *L, - ICmpInst::Predicate predicate) { - if (LI->isVolatile()) return CouldNotCompute; +const SCEV * +ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount( + LoadInst *LI, + Constant *RHS, + const Loop *L, + ICmpInst::Predicate predicate) { + if (LI->isVolatile()) return getCouldNotCompute(); // Check to see if the loaded pointer is a getelementptr of a global. GetElementPtrInst *GEP = dyn_cast(LI->getOperand(0)); - if (!GEP) return CouldNotCompute; + if (!GEP) return getCouldNotCompute(); // Make sure that it is really a constant global we are gepping, with an // initializer, and make sure the first IDX is really 0. GlobalVariable *GV = dyn_cast(GEP->getOperand(0)); - if (!GV || !GV->isConstant() || !GV->hasInitializer() || + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || GEP->getNumOperands() < 3 || !isa(GEP->getOperand(1)) || !cast(GEP->getOperand(1))->isNullValue()) - return CouldNotCompute; + return getCouldNotCompute(); // Okay, we allow one non-constant index into the GEP instruction. Value *VarIdx = 0; @@ -3153,7 +3715,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, if (ConstantInt *CI = dyn_cast(GEP->getOperand(i))) { Indexes.push_back(CI); } else if (!isa(GEP->getOperand(i))) { - if (VarIdx) return CouldNotCompute; // Multiple non-constant idx's. + if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's. VarIdx = GEP->getOperand(i); VarIdxNum = i-2; Indexes.push_back(0); @@ -3161,7 +3723,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. // Check to see if X is a loop variant variable value now. - const SCEV* Idx = getSCEV(VarIdx); + const SCEV *Idx = getSCEV(VarIdx); Idx = getSCEVAtScope(Idx, L); // We can only recognize very limited forms of loop index expressions, in @@ -3170,12 +3732,12 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) || !isa(IdxExpr->getOperand(0)) || !isa(IdxExpr->getOperand(1))) - return CouldNotCompute; + return getCouldNotCompute(); unsigned MaxSteps = MaxBruteForceIterations; for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { - ConstantInt *ItCst = - ConstantInt::get(cast(IdxExpr->getType()), IterationNum); + ConstantInt *ItCst = ConstantInt::get( + cast(IdxExpr->getType()), IterationNum); ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); // Form the GEP offset. @@ -3197,7 +3759,7 @@ ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, return getConstant(ItCst); // Found terminating iteration! 
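// The brute-force search above in miniature: for a loop shaped like
// "for (i = 0; Table[i] != 0; ++i)" over a constant global with a
// definitive initializer, successive iteration numbers are plugged in
// until the comparison settles. A hypothetical plain-C++ sketch, not
// the routine itself:
#include <cstddef>
#include <optional>

static const int Table[] = {4, 5, 6, 0, 7}; // stands in for a constant global

// Backedge-taken count of: for (i = 0; Table[i] != 0; ++i)
static std::optional<std::size_t> bruteForceTripCount(std::size_t MaxSteps) {
  const std::size_t Len = sizeof(Table) / sizeof(Table[0]);
  for (std::size_t It = 0; It != MaxSteps && It != Len; ++It)
    if (Table[It] == 0)   // the folded exit comparison fired
      return It;          // found terminating iteration
  return std::nullopt;    // gave up, analogous to getCouldNotCompute()
}
// bruteForceTripCount(100) yields 3 for the table above.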
      }
  }
-  return CouldNotCompute;
+  return getCouldNotCompute();
}


@@ -3223,7 +3785,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
   // If this is not an instruction, or if this is an instruction outside of the
   // loop, it can't be derived from a loop PHI.
   Instruction *I = dyn_cast<Instruction>(V);
-  if (I == 0 || !L->contains(I->getParent())) return 0;
+  if (I == 0 || !L->contains(I)) return 0;

   if (PHINode *PN = dyn_cast<PHINode>(I)) {
     if (L->getHeader() == I->getParent())
@@ -3260,7 +3822,8 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
 /// in the loop has the value PHIVal.  If we can't fold this expression for some
 /// reason, return null.
-static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
+static Constant *EvaluateExpression(Value *V, Constant *PHIVal,
+                                    const TargetData *TD) {
   if (isa<PHINode>(V)) return PHIVal;
   if (Constant *C = dyn_cast<Constant>(V)) return C;
   if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
@@ -3270,24 +3833,25 @@ static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
   Operands.resize(I->getNumOperands());

   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
-    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
+    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal, TD);
     if (Operands[i] == 0) return 0;
   }

   if (const CmpInst *CI = dyn_cast<CmpInst>(I))
-    return ConstantFoldCompareInstOperands(CI->getPredicate(),
-                                           &Operands[0], Operands.size());
-  else
-    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
-                                    &Operands[0], Operands.size());
+    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
+                                           Operands[1], TD);
+  return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
+                                  &Operands[0], Operands.size(), TD);
 }

 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
 /// in the header of its containing loop, we know the loop executes a
 /// constant number of times, and the PHI node is just a recurrence
 /// involving constants, fold it.
-Constant *ScalarEvolution::
-getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
+Constant *
+ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
+                                                   const APInt &BEs,
+                                                   const Loop *L) {
   std::map<PHINode*, Constant*>::iterator I =
     ConstantEvolutionLoopExitValue.find(PN);
   if (I != ConstantEvolutionLoopExitValue.end())
@@ -3323,7 +3887,7 @@ getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
       return RetVal = PHIVal;  // Got exit value!

     // Compute the value of the PHI node for the next iteration.
-    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
+    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
     if (NextPHI == PHIVal)
       return RetVal = NextPHI;  // Stopped evolving!
     if (NextPHI == 0)
@@ -3332,15 +3896,17 @@ getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
   }
 }

-/// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute a
+/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
 /// constant number of times (the condition evolves only from constants),
 /// try to evaluate a few iterations of the loop until the exit condition
 /// gets a value of ExitWhen (true or false).  If we cannot
-/// evaluate the trip count of the loop, return CouldNotCompute.
+/// evaluate the trip count of the loop, return getCouldNotCompute().
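// The evaluation strategy just described, modeled with plain integers:
// the PHI starts at a constant, each backedge applies a constant-foldable
// step, and iterations are simply run until the exit condition takes the
// awaited value. All names here are illustrative.
#include <cstdint>
#include <optional>

static std::optional<unsigned>
exhaustiveTripCount(int Start, int Step, int ExitAt, unsigned MaxIterations) {
  int Phi = Start;                  // constant incoming value of the PHI
  for (unsigned It = 0; It != MaxIterations; ++It) {
    if (Phi == ExitAt)              // exit condition reached ExitWhen
      return It;
    Phi += Step;                    // constant-fold the backedge value
  }
  return std::nullopt;              // too many iterations were needed
}
// exhaustiveTripCount(0, 2, 10, 100) == 5: five constant steps of 2.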
+const SCEV * +ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L, + Value *Cond, + bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); - if (PN == 0) return CouldNotCompute; + if (PN == 0) return getCouldNotCompute(); // Since the loop is canonicalized, the PHI node must have two entries. One // entry must be a constant (coming in from outside of the loop), and the @@ -3348,11 +3914,11 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); Constant *StartCST = dyn_cast(PN->getIncomingValue(!SecondIsBackedge)); - if (StartCST == 0) return CouldNotCompute; // Must be a constant. + if (StartCST == 0) return getCouldNotCompute(); // Must be a constant. Value *BEValue = PN->getIncomingValue(SecondIsBackedge); PHINode *PN2 = getConstantEvolvingPHI(BEValue, L); - if (PN2 != PN) return CouldNotCompute; // Not derived from same PHI. + if (PN2 != PN) return getCouldNotCompute(); // Not derived from same PHI. // Okay, we find a PHI node that defines the trip count of this loop. Execute // the loop symbolically to determine when the condition gets a value of @@ -3362,29 +3928,28 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) for (Constant *PHIVal = StartCST; IterationNum != MaxIterations; ++IterationNum) { ConstantInt *CondVal = - dyn_cast_or_null(EvaluateExpression(Cond, PHIVal)); + dyn_cast_or_null(EvaluateExpression(Cond, PHIVal, TD)); // Couldn't symbolically evaluate. - if (!CondVal) return CouldNotCompute; + if (!CondVal) return getCouldNotCompute(); if (CondVal->getValue() == uint64_t(ExitWhen)) { - ConstantEvolutionLoopExitValue[PN] = PHIVal; ++NumBruteForceTripCountsComputed; - return getConstant(Type::Int32Ty, IterationNum); + return getConstant(Type::getInt32Ty(getContext()), IterationNum); } // Compute the value of the PHI node for the next iteration. - Constant *NextPHI = EvaluateExpression(BEValue, PHIVal); + Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD); if (NextPHI == 0 || NextPHI == PHIVal) - return CouldNotCompute; // Couldn't evaluate or not making progress... + return getCouldNotCompute();// Couldn't evaluate or not making progress... PHIVal = NextPHI; } // Too many iterations were needed to evaluate. - return CouldNotCompute; + return getCouldNotCompute(); } -/// getSCEVAtScope - Return a SCEV expression handle for the specified value +/// getSCEVAtScope - Return a SCEV expression for the specified value /// at the specified scope in the program. The L value specifies a loop /// nest to evaluate the expression at, where null is the top-level or a /// specified loop is immediately inside of the loop. @@ -3394,9 +3959,21 @@ ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) /// /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. -const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { - // FIXME: this should be turned into a virtual method on SCEV! +const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { + // Check to see if we've folded this expression at this loop before. + std::map &Values = ValuesAtScopes[V]; + std::pair::iterator, bool> Pair = + Values.insert(std::make_pair(L, static_cast(0))); + if (!Pair.second) + return Pair.first->second ? Pair.first->second : V; + + // Otherwise compute it. 
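// The insert-then-fill caching idiom used above, reduced to a toy: the
// first insert plants a placeholder, so a repeated query sees
// Pair.second == false and returns early instead of recomputing. The
// int keys below are stand-ins for the (SCEV*, Loop*) pair.
#include <map>
#include <utility>

struct ScopedMemo {
  std::map<std::pair<int, int>, int> Values; // 0 plays the "null" role

  int compute(int V, int L) { return V + L; } // stand-in for real work

  int getAtScope(int V, int L) {
    std::pair<std::map<std::pair<int, int>, int>::iterator, bool> P =
        Values.insert(std::make_pair(std::make_pair(V, L), 0));
    if (!P.second)                      // asked before (or in progress)
      return P.first->second ? P.first->second : V;
    int C = compute(V, L);              // otherwise compute it
    P.first->second = C;
    return C;
  }
};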
+ const SCEV *C = computeSCEVAtScope(V, L); + ValuesAtScopes[V][L] = C; + return C; +} +const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (isa(V)) return V; // If this instruction is evolved from a constant-evolving PHI, compute the @@ -3411,7 +3988,7 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // to see if the loop that contains it has a known backedge-taken // count. If so, we may be able to force computation of the exit // value. - const SCEV* BackedgeTakenCount = getBackedgeTakenCount(LI); + const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI); if (const SCEVConstant *BTCC = dyn_cast(BackedgeTakenCount)) { // Okay, we know how many times the containing loop executes. If @@ -3420,7 +3997,7 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { Constant *RV = getConstantEvolutionLoopExitValue(PN, BTCC->getValue()->getValue(), LI); - if (RV) return getUnknown(RV); + if (RV) return getSCEV(RV); } } @@ -3429,13 +4006,6 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // the arguments into constants, and if so, try to constant propagate the // result. This is particularly useful for computing loop exit values. if (CanConstantFold(I)) { - // Check to see if we've folded this instruction at this loop before. - std::map &Values = ValuesAtScopes[I]; - std::pair::iterator, bool> Pair = - Values.insert(std::make_pair(L, static_cast(0))); - if (!Pair.second) - return Pair.first->second ? &*getUnknown(Pair.first->second) : V; - std::vector Operands; Operands.reserve(I->getNumOperands()); for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { @@ -3449,7 +4019,7 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { if (!isSCEVable(Op->getType())) return V; - const SCEV* OpV = getSCEVAtScope(getSCEV(Op), L); + const SCEV *OpV = getSCEVAtScope(Op, L); if (const SCEVConstant *SC = dyn_cast(OpV)) { Constant *C = SC->getValue(); if (C->getType() != Op->getType()) @@ -3474,16 +4044,15 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { } } } - + Constant *C; if (const CmpInst *CI = dyn_cast(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), - &Operands[0], Operands.size()); + Operands[0], Operands[1], TD); else C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), - &Operands[0], Operands.size()); - Pair.first->second = C; - return getUnknown(C); + &Operands[0], Operands.size(), TD); + return getSCEV(C); } } @@ -3495,11 +4064,12 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { - const SCEV* OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); if (OpAtScope != Comm->getOperand(i)) { // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. 
- SmallVector NewOps(Comm->op_begin(), Comm->op_begin()+i); + SmallVector NewOps(Comm->op_begin(), + Comm->op_begin()+i); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) { @@ -3514,7 +4084,7 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { return getSMaxExpr(NewOps); if (isa(Comm)) return getUMaxExpr(NewOps); - assert(0 && "Unknown commutative SCEV type!"); + llvm_unreachable("Unknown commutative SCEV type!"); } } // If we got here, all operands are loop invariant. @@ -3522,8 +4092,8 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { } if (const SCEVUDivExpr *Div = dyn_cast(V)) { - const SCEV* LHS = getSCEVAtScope(Div->getLHS(), L); - const SCEV* RHS = getSCEVAtScope(Div->getRHS(), L); + const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); + const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); if (LHS == Div->getLHS() && RHS == Div->getRHS()) return Div; // must be loop invariant return getUDivExpr(LHS, RHS); @@ -3532,11 +4102,11 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { // If this is a loop recurrence for a loop that does not contain L, then we // are dealing with the final value computed by the loop. if (const SCEVAddRecExpr *AddRec = dyn_cast(V)) { - if (!L || !AddRec->getLoop()->contains(L->getHeader())) { + if (!L || !AddRec->getLoop()->contains(L)) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. - const SCEV* BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); - if (BackedgeTakenCount == CouldNotCompute) return AddRec; + const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; // Then, evaluate the AddRec. return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); @@ -3545,33 +4115,36 @@ const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { } if (const SCEVZeroExtendExpr *Cast = dyn_cast(V)) { - const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L); + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getZeroExtendExpr(Op, Cast->getType()); } if (const SCEVSignExtendExpr *Cast = dyn_cast(V)) { - const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L); + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getSignExtendExpr(Op, Cast->getType()); } if (const SCEVTruncateExpr *Cast = dyn_cast(V)) { - const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L); + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); if (Op == Cast->getOperand()) return Cast; // must be loop invariant return getTruncateExpr(Op, Cast->getType()); } - assert(0 && "Unknown SCEV type!"); + if (isa(V)) + return V; + + llvm_unreachable("Unknown SCEV type!"); return 0; } /// getSCEVAtScope - This is a convenience function which does /// getSCEVAtScope(getSCEV(V), L). -const SCEV* ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { +const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } @@ -3584,7 +4157,7 @@ const SCEV* ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { /// A and B isn't important. /// /// If the equation does not have a solution, SCEVCouldNotCompute is returned. 
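// One concrete instance of the equation the function below solves,
// worked at BW = 4, i.e. modulo 16: solve 6*X == 10 (mod 16). Both 6
// and 10 have one trailing zero bit, so a solution exists; dividing by
// 2 gives 3*X == 5 (mod 8), and since 3*3 == 9 == 1 (mod 8) the inverse
// of 3 is 3, hence X == 3*5 == 15 == 7 (mod 8). Check: 6*7 == 42 == 10
// (mod 16). The sketch merely verifies this by brute force; it is not
// the algorithm used below.
#include <cstdint>

static uint32_t solveMod16(uint32_t A, uint32_t B) {
  for (uint32_t X = 0; X != 16; ++X)
    if ((A * X) % 16 == B % 16)
      return X;         // smallest solution in [0, 16)
  return ~0u;           // no solution exists
}
// solveMod16(6, 10) == 7, matching the hand computation above.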
-static const SCEV* SolveLinEquationWithOverflow(const APInt &A, const APInt &B, +static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, ScalarEvolution &SE) { uint32_t BW = A.getBitWidth(); assert(BW == B.getBitWidth() && "Bit widths must be the same."); @@ -3627,7 +4200,7 @@ static const SCEV* SolveLinEquationWithOverflow(const APInt &A, const APInt &B, /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which /// might be the same) or two SCEVCouldNotCompute objects. /// -static std::pair +static std::pair SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast(AddRec->getOperand(0)); @@ -3647,7 +4220,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { APInt Two(BitWidth, 2); APInt Four(BitWidth, 4); - { + { using namespace APIntOps; const APInt& C = L; // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C @@ -3667,7 +4240,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // integer value or else APInt::sqrt() will assert. APInt SqrtVal(SqrtTerm.sqrt()); - // Compute the two solutions for the quadratic formula. + // Compute the two solutions for the quadratic formula. // The divisions must be performed as signed divisions. APInt NegB(-B); APInt TwoA( A << 1 ); @@ -3676,27 +4249,31 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { return std::make_pair(CNC, CNC); } - ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); - ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); + LLVMContext &Context = SE.getContext(); - return std::make_pair(SE.getConstant(Solution1), + ConstantInt *Solution1 = + ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); + ConstantInt *Solution2 = + ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + + return std::make_pair(SE.getConstant(Solution1), SE.getConstant(Solution2)); } // end APIntOps namespace } /// HowFarToZero - Return the number of times a backedge comparing the specified /// value to zero will execute. If not computable, return CouldNotCompute. -const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { +const SCEV *ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. if (C->getValue()->isZero()) return C; - return CouldNotCompute; // Otherwise it will loop infinitely. + return getCouldNotCompute(); // Otherwise it will loop infinitely. } const SCEVAddRecExpr *AddRec = dyn_cast(V); if (!AddRec || AddRec->getLoop() != L) - return CouldNotCompute; + return getCouldNotCompute(); if (AddRec->isAffine()) { // If this is an affine expression, the execution count of this branch is @@ -3711,15 +4288,17 @@ const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // where BW is the common bit width of Start and Step. // Get the initial value for the loop. 
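// Numeric instances of the unitary-step cases handled just below: for
// {Start,+,-1} the recurrence reaches zero after exactly Start
// iterations, and for {Start,+,1} after -Start iterations once reduced
// modulo 2^BW. An 8-bit sketch, valid for steps of +1 and -1 only:
#include <cstdint>

static uint8_t stepsToZero(uint8_t Start, int8_t Step) {
  uint8_t X = Start, N = 0;
  while (X != 0) {
    X = static_cast<uint8_t>(X + Step); // wraps modulo 2^8, like the chrec
    ++N;
  }
  return N;
}
// stepsToZero(5, -1) == 5    (N = Start)
// stepsToZero(5, +1) == 251  (N = -Start mod 2^8, since 256 - 5 == 251)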
- const SCEV* Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); - const SCEV* Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); + const SCEV *Start = getSCEVAtScope(AddRec->getStart(), + L->getParentLoop()); + const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), + L->getParentLoop()); if (const SCEVConstant *StepC = dyn_cast(Step)) { // For now we handle only constant steps. // First, handle unitary steps. if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: - return getNegativeSCEV(Start); // N = -Start (as unsigned) + return getNegativeSCEV(Start); // N = -Start (as unsigned) if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: return Start; // N = Start (as unsigned) @@ -3732,7 +4311,7 @@ const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. - std::pair Roots = SolveQuadraticEquation(AddRec, + std::pair Roots = SolveQuadraticEquation(AddRec, *this); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); @@ -3743,7 +4322,7 @@ const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { #endif // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -3751,20 +4330,20 @@ const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { // We can only use this value if the chrec ends up with an exact zero // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. - const SCEV* Val = AddRec->evaluateAtIteration(R1, *this); + const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) return R1; // We found a quadratic root! } } } - return CouldNotCompute; + return getCouldNotCompute(); } /// HowFarToNonZero - Return the number of times a backedge checking the /// specified value for nonzero will execute. If not computable, return /// CouldNotCompute -const SCEV* ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { +const SCEV *ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. @@ -3774,12 +4353,12 @@ const SCEV* ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { if (const SCEVConstant *C = dyn_cast(V)) { if (!C->getValue()->isNullValue()) return getIntegerSCEV(0, C->getType()); - return CouldNotCompute; // Otherwise it will loop infinitely. + return getCouldNotCompute(); // Otherwise it will loop infinitely. } // We could implement others, but I really doubt anyone writes loops like // this, and if they did, they would already be constant folded. - return CouldNotCompute; + return getCouldNotCompute(); } /// getLoopPredecessor - If the given loop's header has exactly one unique @@ -3825,7 +4404,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { /// more general, since a front-end may have replicated the controlling /// expression. 
/// -static bool HasSameValue(const SCEV* A, const SCEV* B) { +static bool HasSameValue(const SCEV *A, const SCEV *B) { // Quick check to see if they are the same SCEV. if (A == B) return true; @@ -3835,19 +4414,142 @@ static bool HasSameValue(const SCEV* A, const SCEV* B) { if (const SCEVUnknown *BU = dyn_cast(B)) if (const Instruction *AI = dyn_cast(AU->getValue())) if (const Instruction *BI = dyn_cast(BU->getValue())) - if (AI->isIdenticalTo(BI)) + if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory()) return true; // Otherwise assume they may have a different value. return false; } -/// isLoopGuardedByCond - Test whether entry to the loop is protected by -/// a conditional between LHS and RHS. This is used to help avoid max -/// expressions in loop trip counts. -bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, - ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { +bool ScalarEvolution::isKnownNegative(const SCEV *S) { + return getSignedRange(S).getSignedMax().isNegative(); +} + +bool ScalarEvolution::isKnownPositive(const SCEV *S) { + return getSignedRange(S).getSignedMin().isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { + return !getSignedRange(S).getSignedMin().isNegative(); +} + +bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { + return !getSignedRange(S).getSignedMax().isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonZero(const SCEV *S) { + return isKnownNegative(S) || isKnownPositive(S); +} + +bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + + if (HasSameValue(LHS, RHS)) + return ICmpInst::isTrueWhenEqual(Pred); + + switch (Pred) { + default: + llvm_unreachable("Unexpected ICmpInst::Predicate value!"); + break; + case ICmpInst::ICMP_SGT: + Pred = ICmpInst::ICMP_SLT; + std::swap(LHS, RHS); + case ICmpInst::ICMP_SLT: { + ConstantRange LHSRange = getSignedRange(LHS); + ConstantRange RHSRange = getSignedRange(RHS); + if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin())) + return true; + if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax())) + return false; + break; + } + case ICmpInst::ICMP_SGE: + Pred = ICmpInst::ICMP_SLE; + std::swap(LHS, RHS); + case ICmpInst::ICMP_SLE: { + ConstantRange LHSRange = getSignedRange(LHS); + ConstantRange RHSRange = getSignedRange(RHS); + if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin())) + return true; + if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax())) + return false; + break; + } + case ICmpInst::ICMP_UGT: + Pred = ICmpInst::ICMP_ULT; + std::swap(LHS, RHS); + case ICmpInst::ICMP_ULT: { + ConstantRange LHSRange = getUnsignedRange(LHS); + ConstantRange RHSRange = getUnsignedRange(RHS); + if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax())) + return false; + break; + } + case ICmpInst::ICMP_UGE: + Pred = ICmpInst::ICMP_ULE; + std::swap(LHS, RHS); + case ICmpInst::ICMP_ULE: { + ConstantRange LHSRange = getUnsignedRange(LHS); + ConstantRange RHSRange = getUnsignedRange(RHS); + if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin())) + return true; + if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax())) + return false; + break; + } + case ICmpInst::ICMP_NE: { + if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet()) + return true; + if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet()) + return true; + + const SCEV *Diff = getMinusSCEV(LHS, RHS); + if 
(isKnownNonZero(Diff))
+      return true;
+    break;
+  }
+  case ICmpInst::ICMP_EQ:
+    // The check at the top of the function catches the case where
+    // the values are known to be equal.
+    break;
+  }
+  return false;
+}
+
+/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
+/// protected by a conditional between LHS and RHS.  This is used
+/// to eliminate casts.
+bool
+ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
+                                             ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS) {
+  // Interpret a null as meaning no loop, where there is obviously no guard
+  // (interprocedural conditions notwithstanding).
+  if (!L) return true;
+
+  BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return false;
+
+  BranchInst *LoopContinuePredicate =
+    dyn_cast<BranchInst>(Latch->getTerminator());
+  if (!LoopContinuePredicate ||
+      LoopContinuePredicate->isUnconditional())
+    return false;
+
+  return isImpliedCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS,
+                       LoopContinuePredicate->getSuccessor(0) != L->getHeader());
+}
+
+/// isLoopGuardedByCond - Test whether entry to the loop is protected
+/// by a conditional between LHS and RHS.  This is used to help avoid max
+/// expressions in loop trip counts, and to eliminate casts.
+bool
+ScalarEvolution::isLoopGuardedByCond(const Loop *L,
+                                     ICmpInst::Predicate Pred,
+                                     const SCEV *LHS, const SCEV *RHS) {
   // Interpret a null as meaning no loop, where there is obviously no guard
   // (interprocedural conditions notwithstanding).
   if (!L) return false;
@@ -3868,85 +4570,276 @@ bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
         LoopEntryPredicate->isUnconditional())
       continue;

-    ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
-    if (!ICI) continue;
+    if (isImpliedCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
+                      LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
+      return true;
+  }

-    // Now that we found a conditional branch that dominates the loop, check to
-    // see if it is the comparison we are looking for.
-    Value *PreCondLHS = ICI->getOperand(0);
-    Value *PreCondRHS = ICI->getOperand(1);
-    ICmpInst::Predicate Cond;
-    if (LoopEntryPredicate->getSuccessor(0) == PredecessorDest)
-      Cond = ICI->getPredicate();
-    else
-      Cond = ICI->getInversePredicate();
+  return false;
+}

-    if (Cond == Pred)
-      ; // An exact match.
-    else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
-      ; // The actual condition is beyond sufficient.
-    else
-      // Check a few special cases.
-      switch (Cond) {
-      case ICmpInst::ICMP_UGT:
-        if (Pred == ICmpInst::ICMP_ULT) {
-          std::swap(PreCondLHS, PreCondRHS);
-          Cond = ICmpInst::ICMP_ULT;
-          break;
-        }
-        continue;
-      case ICmpInst::ICMP_SGT:
-        if (Pred == ICmpInst::ICMP_SLT) {
-          std::swap(PreCondLHS, PreCondRHS);
-          Cond = ICmpInst::ICMP_SLT;
-          break;
-        }
-        continue;
-      case ICmpInst::ICMP_NE:
-        // Expressions like (x >u 0) are often canonicalized to (x != 0),
-        // so check for this case by checking if the NE is comparing against
-        // a minimum or maximum constant.
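// The interval reasoning behind isKnownPredicate above, restated with
// plain inclusive ranges: if every value the left side can take lies
// below every value of the right side, LHS <u RHS is known true; if the
// left minimum is at least the right maximum it is known false; any
// overlap stays undecided. URange is a hypothetical stand-in for
// ConstantRange.
#include <cstdint>
#include <optional>

struct URange { uint32_t Min, Max; }; // inclusive unsigned bounds

static std::optional<bool> knownULT(URange L, URange R) {
  if (L.Max < R.Min) return true;    // always strictly less
  if (L.Min >= R.Max) return false;  // can never be strictly less
  return std::nullopt;               // ranges overlap: unknown
}
// knownULT({0,9}, {10,20}) == true, knownULT({5,9}, {0,5}) == false,
// and knownULT({0,9}, {5,20}) stays undecided.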
-        if (!ICmpInst::isTrueWhenEqual(Pred))
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
-            const APInt &A = CI->getValue();
-            switch (Pred) {
-            case ICmpInst::ICMP_SLT:
-              if (A.isMaxSignedValue()) break;
-              continue;
-            case ICmpInst::ICMP_SGT:
-              if (A.isMinSignedValue()) break;
-              continue;
-            case ICmpInst::ICMP_ULT:
-              if (A.isMaxValue()) break;
-              continue;
-            case ICmpInst::ICMP_UGT:
-              if (A.isMinValue()) break;
-              continue;
-            default:
-              continue;
-            }
-            Cond = ICmpInst::ICMP_NE;
-            // NE is symmetric but the original comparison may not be. Swap
-            // the operands if necessary so that they match below.
-            if (isa<SCEVConstant>(LHS))
-              std::swap(PreCondLHS, PreCondRHS);
-            break;
-          }
-        continue;
-      default:
-        // We weren't able to reconcile the condition.
-        continue;
+/// isImpliedCond - Test whether the condition described by Pred, LHS,
+/// and RHS is true whenever the given Cond value evaluates to true.
+bool ScalarEvolution::isImpliedCond(Value *CondValue,
+                                    ICmpInst::Predicate Pred,
+                                    const SCEV *LHS, const SCEV *RHS,
+                                    bool Inverse) {
+  // Recursively handle And and Or conditions.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
+    if (BO->getOpcode() == Instruction::And) {
+      if (!Inverse)
+        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
+               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
+    } else if (BO->getOpcode() == Instruction::Or) {
+      if (Inverse)
+        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
+               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
+    }
+  }
+
+  ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
+  if (!ICI) return false;
+
+  // Bail if the ICmp's operands' types are wider than the needed type
+  // before attempting to call getSCEV on them. This avoids infinite
+  // recursion, since the analysis of widening casts can require loop
+  // exit condition information for overflow checking, which would
+  // lead back here.
+  if (getTypeSizeInBits(LHS->getType()) <
+      getTypeSizeInBits(ICI->getOperand(0)->getType()))
+    return false;
+
+  // Now that we found a conditional branch that dominates the loop, check to
+  // see if it is the comparison we are looking for.
+  ICmpInst::Predicate FoundPred;
+  if (Inverse)
+    FoundPred = ICI->getInversePredicate();
+  else
+    FoundPred = ICI->getPredicate();
+
+  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
+  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
+
+  // Balance the types. The case where FoundLHS' type is wider than
+  // LHS' type is checked for above.
+  if (getTypeSizeInBits(LHS->getType()) >
+      getTypeSizeInBits(FoundLHS->getType())) {
+    if (CmpInst::isSigned(Pred)) {
+      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
+      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
+    } else {
+      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
+      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
+    }
+  }
+
+  // Canonicalize the query to match the way instcombine will have
+  // canonicalized the comparison.
+  // First, put a constant operand on the right.
+  if (isa<SCEVConstant>(LHS)) {
+    std::swap(LHS, RHS);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+  // Then, canonicalize comparisons with boundary cases.
+ if (const SCEVConstant *RC = dyn_cast(RHS)) { + const APInt &RA = RC->getValue()->getValue(); + switch (Pred) { + default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + break; + case ICmpInst::ICMP_UGE: + if ((RA - 1).isMinValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMaxValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMinValue()) return true; + break; + case ICmpInst::ICMP_ULE: + if ((RA + 1).isMaxValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMinValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMaxValue()) return true; + break; + case ICmpInst::ICMP_SGE: + if ((RA - 1).isMinSignedValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMaxSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + break; } + if (RA.isMinSignedValue()) return true; + break; + case ICmpInst::ICMP_SLE: + if ((RA + 1).isMaxSignedValue()) { + Pred = ICmpInst::ICMP_NE; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMinSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + break; + } + if (RA.isMaxSignedValue()) return true; + break; + case ICmpInst::ICMP_UGT: + if (RA.isMinValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA + 1).isMaxValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMaxValue()) return false; + break; + case ICmpInst::ICMP_ULT: + if (RA.isMaxValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA - 1).isMinValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMinValue()) return false; + break; + case ICmpInst::ICMP_SGT: + if (RA.isMinSignedValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA + 1).isMaxSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA + 1); + break; + } + if (RA.isMaxSignedValue()) return false; + break; + case ICmpInst::ICMP_SLT: + if (RA.isMaxSignedValue()) { + Pred = ICmpInst::ICMP_NE; + break; + } + if ((RA - 1).isMinSignedValue()) { + Pred = ICmpInst::ICMP_EQ; + RHS = getConstant(RA - 1); + break; + } + if (RA.isMinSignedValue()) return false; + break; + } + } + + // Check to see if we can make the LHS or RHS match. + if (LHS == FoundRHS || RHS == FoundLHS) { + if (isa(RHS)) { + std::swap(FoundLHS, FoundRHS); + FoundPred = ICmpInst::getSwappedPredicate(FoundPred); + } else { + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + } - if (!PreCondLHS->getType()->isInteger()) continue; + // Check whether the found predicate is the same as the desired predicate. + if (FoundPred == Pred) + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); - const SCEV* PreCondLHSSCEV = getSCEV(PreCondLHS); - const SCEV* PreCondRHSSCEV = getSCEV(PreCondRHS); - if ((HasSameValue(LHS, PreCondLHSSCEV) && - HasSameValue(RHS, PreCondRHSSCEV)) || - (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) && - HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV)))) + // Check whether swapping the found predicate makes it the same as the + // desired predicate. + if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { + if (isa(RHS)) + return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS); + else + return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), + RHS, LHS, FoundLHS, FoundRHS); + } + + // Check whether the actual condition is beyond sufficient. 
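// Two of the boundary rewrites above, checked concretely: with RA == 1,
// (RA - 1).isMinValue() holds, so "x >=u 1" becomes "x != 0"; with
// RA == max-1, (RA + 1).isMaxValue() holds, so "x <=u max-1" becomes
// "x != max". Brute-force verification over a 4-bit domain (a sketch,
// not the code path above):
#include <cassert>
#include <cstdint>

static void checkBoundaryRewrites() {
  for (uint8_t X = 0; X != 16; ++X) {
    assert((X >= 1) == (X != 0));   // ICMP_UGE with RA == 1
    assert((X <= 14) == (X != 15)); // ICMP_ULE with RA == 14 == max - 1
  }
}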
+  if (FoundPred == ICmpInst::ICMP_EQ)
+    if (ICmpInst::isTrueWhenEqual(Pred))
+      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
+        return true;
+  if (Pred == ICmpInst::ICMP_NE)
+    if (!ICmpInst::isTrueWhenEqual(FoundPred))
+      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
+        return true;
+
+  // Otherwise assume the worst.
+  return false;
+}
+
+/// isImpliedCondOperands - Test whether the condition described by Pred,
+/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
+/// and FoundRHS is true.
+bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS, const SCEV *RHS,
+                                            const SCEV *FoundLHS,
+                                            const SCEV *FoundRHS) {
+  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
+                                     FoundLHS, FoundRHS) ||
+         // ~x < ~y --> x > y
+         isImpliedCondOperandsHelper(Pred, LHS, RHS,
+                                     getNotSCEV(FoundRHS),
+                                     getNotSCEV(FoundLHS));
+}
+
+/// isImpliedCondOperandsHelper - Test whether the condition described by
+/// Pred, LHS, and RHS is true whenever the condition described by Pred,
+/// FoundLHS, and FoundRHS is true.
+bool
+ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS) {
+  switch (Pred) {
+  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE:
+    if (isKnownPredicate(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+        isKnownPredicate(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
+    if (isKnownPredicate(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+        isKnownPredicate(ICmpInst::ICMP_SLE, RHS, FoundRHS))
      return true;
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    if (isKnownPredicate(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+        isKnownPredicate(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    if (isKnownPredicate(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+        isKnownPredicate(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+      return true;
+    break;
   }

   return false;
@@ -3955,26 +4848,30 @@ bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
 /// getBECount - Subtract the end and start values and divide by the step,
 /// rounding up, to get the number of times the backedge is executed. Return
 /// CouldNotCompute if an intermediate computation overflows.
-const SCEV* ScalarEvolution::getBECount(const SCEV* Start,
-                                       const SCEV* End,
-                                       const SCEV* Step) {
+const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
+                                        const SCEV *End,
+                                        const SCEV *Step,
+                                        bool NoWrap) {
   const Type *Ty = Start->getType();
-  const SCEV* NegOne = getIntegerSCEV(-1, Ty);
-  const SCEV* Diff = getMinusSCEV(End, Start);
-  const SCEV* RoundUp = getAddExpr(Step, NegOne);
+  const SCEV *NegOne = getIntegerSCEV(-1, Ty);
+  const SCEV *Diff = getMinusSCEV(End, Start);
+  const SCEV *RoundUp = getAddExpr(Step, NegOne);

   // Add an adjustment to the difference between End and Start so that
   // the division will effectively round up.
-  const SCEV* Add = getAddExpr(Diff, RoundUp);
-
-  // Check Add for unsigned overflow.
-  // TODO: More sophisticated things could be done here.
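// The round-up division getBECount performs, spelled out with numbers:
// for Start = 0, End = 10, Step = 3 it computes
// (10 - 0 + (3 - 1)) /u 3 == 4, exactly the number of iterations for
// which {0,+,3} < 10 holds (namely 0, 3, 6 and 9). Plain arithmetic,
// with the zero-extension overflow check omitted:
#include <cassert>
#include <cstdint>

static uint64_t roundedUpCount(uint64_t Start, uint64_t End, uint64_t Step) {
  return (End - Start + (Step - 1)) / Step; // Diff plus RoundUp, then udiv
}

static void checkRoundedUpCount() {
  uint64_t N = 0;
  for (uint64_t I = 0; I < 10; I += 3) ++N; // the loop being modeled
  assert(roundedUpCount(0, 10, 3) == N && N == 4);
}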
-  const Type *WideTy = IntegerType::get(getTypeSizeInBits(Ty) + 1);
-  const SCEV* OperandExtendedAdd =
-    getAddExpr(getZeroExtendExpr(Diff, WideTy),
-               getZeroExtendExpr(RoundUp, WideTy));
-  if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
-    return CouldNotCompute;
+  const SCEV *Add = getAddExpr(Diff, RoundUp);
+
+  if (!NoWrap) {
+    // Check Add for unsigned overflow.
+    // TODO: More sophisticated things could be done here.
+    const Type *WideTy = IntegerType::get(getContext(),
+                                          getTypeSizeInBits(Ty) + 1);
+    const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
+    const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
+    const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
+    if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
+      return getCouldNotCompute();
+  }

   return getUDivExpr(Add, Step);
 }
@@ -3982,48 +4879,55 @@ const SCEV* ScalarEvolution::getBECount(const SCEV* Start,
 /// HowManyLessThans - Return the number of times a backedge containing the
 /// specified less-than comparison will execute.  If not computable, return
 /// CouldNotCompute.
-ScalarEvolution::BackedgeTakenInfo ScalarEvolution::
-HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
-                 const Loop *L, bool isSigned) {
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool isSigned) {
   // Only handle:  "ADDREC < LoopInvariant".
-  if (!RHS->isLoopInvariant(L)) return CouldNotCompute;
+  if (!RHS->isLoopInvariant(L)) return getCouldNotCompute();

   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
   if (!AddRec || AddRec->getLoop() != L)
-    return CouldNotCompute;
+    return getCouldNotCompute();
+
+  // Check to see if we have a flag which makes analysis easy.
+  bool NoWrap = isSigned ? AddRec->hasNoSignedWrap() :
+                           AddRec->hasNoUnsignedWrap();

   if (AddRec->isAffine()) {
     // FORNOW: We only support unit strides.
     unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
-    const SCEV* Step = AddRec->getStepRecurrence(*this);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);

     // TODO: handle non-constant strides.
     const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step);
     if (!CStep || CStep->isZero())
-      return CouldNotCompute;
+      return getCouldNotCompute();
     if (CStep->isOne()) {
       // With unit stride, the iteration never steps past the limit value.
     } else if (CStep->getValue()->getValue().isStrictlyPositive()) {
-      if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
+      if (NoWrap) {
+        // We know the iteration won't step past the maximum value for its type.
+        ;
+      } else if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
         // Test whether a positive iteration can step past the limit
         // value and past the maximum value for its type in a single step.
         if (isSigned) {
           APInt Max = APInt::getSignedMaxValue(BitWidth);
           if ((Max - CStep->getValue()->getValue())
                 .slt(CLimit->getValue()->getValue()))
-            return CouldNotCompute;
+            return getCouldNotCompute();
         } else {
           APInt Max = APInt::getMaxValue(BitWidth);
           if ((Max - CStep->getValue()->getValue())
                 .ult(CLimit->getValue()->getValue()))
-            return CouldNotCompute;
+            return getCouldNotCompute();
         }
       } else
         // TODO: handle non-constant limit values below.
-        return CouldNotCompute;
+        return getCouldNotCompute();
     } else
       // TODO: handle negative strides below.
-      return CouldNotCompute;
+      return getCouldNotCompute();

     // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
     // m.  So, we count the number of iterations in which {n,+,s} < m is true.
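// The End = max(RHS, Start) adjustment made just below, in ordinary
// arithmetic: when the entry guard Start < End cannot be proven, a loop
// such as "for (i = n; i < m; ++i)" must be credited with zero
// iterations when n >= m, not with a huge wrapped-around difference.
// Illustrative only:
#include <algorithm>
#include <cstdint>

static uint64_t unitStrideCount(uint64_t Start, uint64_t End) {
  return std::max(End, Start) - Start; // clamps to 0 when Start >= End
}
// unitStrideCount(3, 10) == 7, while unitStrideCount(10, 3) == 0 rather
// than the wrapped value (3 - 10) mod 2^64.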
@@ -4031,44 +4935,42 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // treat m-n as signed nor unsigned due to overflow possibility. // First, we get the value of the LHS in the first iteration: n - const SCEV* Start = AddRec->getOperand(0); + const SCEV *Start = AddRec->getOperand(0); // Determine the minimum constant start value. - const SCEV* MinStart = isa(Start) ? Start : - getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) : - APInt::getMinValue(BitWidth)); + const SCEV *MinStart = getConstant(isSigned ? + getSignedRange(Start).getSignedMin() : + getUnsignedRange(Start).getUnsignedMin()); // If we know that the condition is true in order to enter the loop, // then we know that it will run exactly (m-n)/s times. Otherwise, we // only know that it will execute (max(m,n)-n)/s times. In both cases, // the division must round up. - const SCEV* End = RHS; + const SCEV *End = RHS; if (!isLoopGuardedByCond(L, - isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + isSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, getMinusSCEV(Start, Step), RHS)) End = isSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); // Determine the maximum constant end value. - const SCEV* MaxEnd = - isa(End) ? End : - getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) - .ashr(GetMinSignBits(End) - 1) : - APInt::getMaxValue(BitWidth) - .lshr(GetMinLeadingZeros(End))); + const SCEV *MaxEnd = getConstant(isSigned ? + getSignedRange(End).getSignedMax() : + getUnsignedRange(End).getUnsignedMax()); // Finally, we subtract these two values and divide, rounding up, to get // the number of times the backedge is executed. - const SCEV* BECount = getBECount(Start, End, Step); + const SCEV *BECount = getBECount(Start, End, Step, NoWrap); // The maximum backedge count is similar, except using the minimum start // value and the maximum end value. - const SCEV* MaxBECount = getBECount(MinStart, MaxEnd, Step);; + const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step, NoWrap); return BackedgeTakenInfo(BECount, MaxBECount); } - return CouldNotCompute; + return getCouldNotCompute(); } /// getNumIterationsInRange - Return the number of iterations of this loop that @@ -4076,17 +4978,17 @@ HowManyLessThans(const SCEV *LHS, const SCEV *RHS, /// this is that it returns the first iteration number where the value is not in /// the condition, thus computing the exit count. If the iteration count can't /// be computed, an instance of SCEVCouldNotCompute is returned. -const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, - ScalarEvolution &SE) const { +const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, + ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop. return SE.getCouldNotCompute(); // If the start is a non-zero constant, shift the range to simplify things. if (const SCEVConstant *SC = dyn_cast(getStart())) if (!SC->getValue()->isZero()) { - SmallVector Operands(op_begin(), op_end()); + SmallVector Operands(op_begin(), op_end()); Operands[0] = SE.getIntegerSCEV(0, SC->getType()); - const SCEV* Shifted = SE.getAddRecExpr(Operands, getLoop()); + const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop()); if (const SCEVAddRecExpr *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( @@ -4125,7 +5027,7 @@ const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // The exit value should be (End+A)/A. 
APInt ExitVal = (End + A).udiv(A); - ConstantInt *ExitValue = ConstantInt::get(ExitVal); + ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal); // Evaluate at the exit value. If we really did fall out of the valid // range, then we computed our trip count, otherwise wrap around or other @@ -4136,8 +5038,8 @@ const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // Ensure that the previous value is in the range. This is a sanity check. assert(Range.contains( - EvaluateConstantChrecAtConstant(this, - ConstantInt::get(ExitVal - One), SE)->getValue()) && + EvaluateConstantChrecAtConstant(this, + ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); } else if (isQuadratic()) { @@ -4145,20 +5047,20 @@ const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // quadratic equation to solve it. To do this, we must frame our problem in // terms of figuring out when zero is crossed, instead of when // Range.getUpper() is crossed. - SmallVector NewOps(op_begin(), op_end()); + SmallVector NewOps(op_begin(), op_end()); NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); - const SCEV* NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); + const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); // Next, solve the constructed addrec - std::pair Roots = + std::pair Roots = SolveQuadraticEquation(cast(NewAddRec), SE); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { // Pick the smallest positive root value. if (ConstantInt *CB = - dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, - R1->getValue(), R2->getValue()))) { + dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + R1->getValue(), R2->getValue()))) { if (CB->getZExtValue() == false) std::swap(R1, R2); // R1 is the minimum root now. @@ -4170,7 +5072,8 @@ const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, SE); if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... - ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1); + ConstantInt *NextVal = + ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (!Range.contains(R1Val->getValue())) @@ -4180,7 +5083,8 @@ const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // If R1 was not in the range, then it is a good return value. Make // sure that R1-1 WAS in the range though, just in case. - ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1); + ConstantInt *NextVal = + ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1); R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (Range.contains(R1Val->getValue())) return R1; @@ -4199,22 +5103,21 @@ const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, //===----------------------------------------------------------------------===// void ScalarEvolution::SCEVCallbackVH::deleted() { - assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!"); + assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); if (PHINode *PN = dyn_cast(getValPtr())) SE->ConstantEvolutionLoopExitValue.erase(PN); - if (Instruction *I = dyn_cast(getValPtr())) - SE->ValuesAtScopes.erase(I); SE->Scalars.erase(getValPtr()); // this now dangles! 
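// The invalidation walk that allUsesReplacedWith performs below, shrunk
// to a toy graph: every transitive user of the changed value is erased
// exactly once, and the Visited set is what keeps a dense use graph from
// being re-walked repeatedly. Hypothetical types, not LLVM's use lists.
#include <map>
#include <set>
#include <vector>

static void forgetTransitiveUsers(int Old,
                                  const std::map<int, std::vector<int> > &Users,
                                  std::set<int> &Erased) {
  std::vector<int> Worklist;
  std::set<int> Visited;
  std::map<int, std::vector<int> >::const_iterator It = Users.find(Old);
  if (It != Users.end())
    Worklist.assign(It->second.begin(), It->second.end());
  while (!Worklist.empty()) {
    int U = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(U).second)
      continue;                   // already forgot this user
    Erased.insert(U);             // plays the role of Scalars.erase(U)
    It = Users.find(U);
    if (It != Users.end())
      Worklist.insert(Worklist.end(), It->second.begin(), It->second.end());
  }
}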
} void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { - assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!"); + assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); // Forget all the expressions associated with users of the old value, // so that future queries will recompute the expressions using the new // value. SmallVector Worklist; + SmallPtrSet Visited; Value *Old = getValPtr(); bool DeleteOld = false; for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end(); @@ -4228,20 +5131,19 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { DeleteOld = true; continue; } + if (!Visited.insert(U)) + continue; if (PHINode *PN = dyn_cast(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); - if (Instruction *I = dyn_cast(U)) - SE->ValuesAtScopes.erase(I); - if (SE->Scalars.erase(U)) - for (Value::use_iterator UI = U->use_begin(), UE = U->use_end(); - UI != UE; ++UI) - Worklist.push_back(*UI); + SE->Scalars.erase(U); + for (Value::use_iterator UI = U->use_begin(), UE = U->use_end(); + UI != UE; ++UI) + Worklist.push_back(*UI); } + // Delete the Old value if it (indirectly) references itself. if (DeleteOld) { if (PHINode *PN = dyn_cast(Old)) SE->ConstantEvolutionLoopExitValue.erase(PN); - if (Instruction *I = dyn_cast(Old)) - SE->ValuesAtScopes.erase(I); SE->Scalars.erase(Old); // this now dangles! } @@ -4256,7 +5158,7 @@ ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) //===----------------------------------------------------------------------===// ScalarEvolution::ScalarEvolution() - : FunctionPass(&ID), CouldNotCompute(new SCEVCouldNotCompute()) { + : FunctionPass(&ID) { } bool ScalarEvolution::runOnFunction(Function &F) { @@ -4271,45 +5173,8 @@ void ScalarEvolution::releaseMemory() { BackedgeTakenCounts.clear(); ConstantEvolutionLoopExitValue.clear(); ValuesAtScopes.clear(); - - for (std::map::iterator - I = SCEVConstants.begin(), E = SCEVConstants.end(); I != E; ++I) - delete I->second; - for (std::map, - SCEVTruncateExpr*>::iterator I = SCEVTruncates.begin(), - E = SCEVTruncates.end(); I != E; ++I) - delete I->second; - for (std::map, - SCEVZeroExtendExpr*>::iterator I = SCEVZeroExtends.begin(), - E = SCEVZeroExtends.end(); I != E; ++I) - delete I->second; - for (std::map >, - SCEVCommutativeExpr*>::iterator I = SCEVCommExprs.begin(), - E = SCEVCommExprs.end(); I != E; ++I) - delete I->second; - for (std::map, SCEVUDivExpr*>::iterator - I = SCEVUDivs.begin(), E = SCEVUDivs.end(); I != E; ++I) - delete I->second; - for (std::map, - SCEVSignExtendExpr*>::iterator I = SCEVSignExtends.begin(), - E = SCEVSignExtends.end(); I != E; ++I) - delete I->second; - for (std::map >, - SCEVAddRecExpr*>::iterator I = SCEVAddRecExprs.begin(), - E = SCEVAddRecExprs.end(); I != E; ++I) - delete I->second; - for (std::map::iterator I = SCEVUnknowns.begin(), - E = SCEVUnknowns.end(); I != E; ++I) - delete I->second; - - SCEVConstants.clear(); - SCEVTruncates.clear(); - SCEVZeroExtends.clear(); - SCEVCommExprs.clear(); - SCEVUDivs.clear(); - SCEVSignExtends.clear(); - SCEVAddRecExprs.clear(); - SCEVUnknowns.clear(); + UniqueSCEVs.clear(); + SCEVAllocator.Reset(); } void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { @@ -4329,7 +5194,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Loop " << L->getHeader()->getName() << ": "; - SmallVector ExitBlocks; + SmallVector ExitBlocks; L->getExitBlocks(ExitBlocks); if (ExitBlocks.size() != 1) OS << " "; @@ -4352,26 +5217,26 @@ 
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   OS << "\n";
 }

-void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {
+void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
   // ScalarEvolution's implementation of the print method is to print
   // out SCEV values of all instructions that are interesting. Doing
   // this potentially causes it to create new SCEV objects though,
   // which technically conflicts with the const qualifier. This isn't
-  // observable from outside the class though (the hasSCEV function
-  // notwithstanding), so casting away the const isn't dangerous.
-  ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this);
+  // observable from outside the class though, so casting away the
+  // const isn't dangerous.
+  ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this);

   OS << "Classifying expressions for: " << F->getName() << "\n";
   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
     if (isSCEVable(I->getType())) {
-      OS << *I;
+      OS << *I << '\n';
       OS << "  -->  ";
-      const SCEV* SV = SE.getSCEV(&*I);
+      const SCEV *SV = SE.getSCEV(&*I);
       SV->print(OS);

       const Loop *L = LI->getLoopFor((*I).getParent());

-      const SCEV* AtUse = SE.getSCEVAtScope(SV, L);
+      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
       if (AtUse != SV) {
         OS << "  -->  ";
         AtUse->print(OS);
@@ -4379,7 +5244,7 @@ void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {

       if (L) {
         OS << "\t\t" "Exits: ";
-        const SCEV* ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
+        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
         if (!ExitValue->isLoopInvariant(L)) {
           OS << "<<Unknown>>";
         } else {
@@ -4395,7 +5260,3 @@ void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {
     PrintLoopInfo(OS, &SE, *I);
 }

-void ScalarEvolution::print(std::ostream &o, const Module *M) const {
-  raw_os_ostream OS(o);
-  print(OS, M);
-}
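// The logically-const pattern print() relies on, as a standalone toy: a
// const query interface whose implementation may populate an internal
// cache, so the mutation (here via mutable, above via const_cast) is
// invisible to callers. Illustrative sketch only.
#include <iostream>
#include <map>

class CachedAnalysis {
  mutable std::map<int, int> Cache; // analogue of new SCEVs being created
public:
  int query(int K) const {
    std::map<int, int>::iterator I = Cache.find(K);
    if (I != Cache.end()) return I->second;
    int R = K * K;                // stand-in for the real analysis work
    Cache[K] = R;
    return R;
  }
};

int main() {
  const CachedAnalysis A;
  std::cout << A.query(7) << '\n'; // prints 49; a second call hits the cache
  return 0;
}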