X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=d615c752b0444f4c91ec9284fe109ddab3736fad;hb=73b43b9b549a75fb0015c825df68abd95705a67c;hp=4b4b97e28624bb9adb706ec585a026f8d45603a2;hpb=f7b37b2d0e28cd2f2ecc03e3e6e470353dca5725;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 4b4b97e2862..d615c752b04 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -95,16 +95,14 @@ STATISTIC(NumTripCountsNotComputed, STATISTIC(NumBruteForceTripCountsComputed, "Number of loops with trip counts computed by force"); -cl::opt +static cl::opt MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant derived loop"), cl::init(100)); -namespace { - RegisterPass - R("scalar-evolution", "Scalar Evolution Analysis", false, true); -} +static RegisterPass +R("scalar-evolution", "Scalar Evolution Analysis", false, true); char ScalarEvolution::ID = 0; //===----------------------------------------------------------------------===// @@ -134,6 +132,12 @@ uint32_t SCEV::getBitWidth() const { return 0; } +bool SCEV::isZero() const { + if (const SCEVConstant *SC = dyn_cast(this)) + return SC->getValue()->isZero(); + return false; +} + SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} @@ -494,28 +498,13 @@ SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { if (Val == 0) C = Constant::getNullValue(Ty); else if (Ty->isFloatingPoint()) - C = ConstantFP::get(Ty, APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : - APFloat::IEEEdouble, Val)); + C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); else C = ConstantInt::get(Ty, Val); return getUnknown(C); } -/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. -static SCEVHandle getTruncateOrZeroExtend(const SCEVHandle &V, const Type *Ty, - ScalarEvolution &SE) { - const Type *SrcTy = V->getType(); - assert(SrcTy->isInteger() && Ty->isInteger() && - "Cannot truncate or zero extend with non-integer arguments!"); - if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) - return V; // No conversion - if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) - return SE.getTruncateExpr(V, Ty); - return SE.getZeroExtendExpr(V, Ty); -} - /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V /// SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { @@ -587,7 +576,7 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, #endif const IntegerType *DividendTy = IntegerType::get(DividendBits); - const SCEVHandle ExIt = SE.getZeroExtendExpr(It, DividendTy); + const SCEVHandle ExIt = SE.getTruncateOrZeroExtend(It, DividendTy); // The final number of bits we need to perform the division is the maximum of // dividend and divisor bitwidths. @@ -609,7 +598,12 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, Dividend *= N-(K-1); if (DividendTy != DivisionTy) Dividend = Dividend.zext(DivisionTy->getBitWidth()); - return SE.getConstant(Dividend.udiv(Divisor).trunc(It->getBitWidth())); + + APInt Result = Dividend.udiv(Divisor); + if (Result.getBitWidth() != It->getBitWidth()) + Result = Result.trunc(It->getBitWidth()); + + return SE.getConstant(Result); } SCEVHandle Dividend = ExIt; @@ -617,11 +611,12 @@ static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K, Dividend = SE.getMulExpr(Dividend, SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy))); - if (DividendTy != DivisionTy) - Dividend = SE.getZeroExtendExpr(Dividend, DivisionTy); - return - SE.getTruncateExpr(SE.getUDivExpr(Dividend, SE.getConstant(Divisor)), - It->getType()); + + return SE.getTruncateOrZeroExtend( + SE.getUDivExpr( + SE.getTruncateOrZeroExtend(Dividend, DivisionTy), + SE.getConstant(Divisor) + ), It->getType()); } /// evaluateAtIteration - Return the value of this chain of recurrences at @@ -705,6 +700,21 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type * return Result; } +/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion +/// of the input value to the specified type. If the type must be +/// extended, it is zero extended. +SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert(SrcTy->isInteger() && Ty->isInteger() && + "Cannot truncate or zero extend with non-integer arguments!"); + if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits()) + return V; // No conversion + if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()) + return getTruncateExpr(V, Ty); + return getZeroExtendExpr(V, Ty); +} + // get - Get a canonical add expression, or something simpler if possible. SCEVHandle ScalarEvolution::getAddExpr(std::vector &Ops) { assert(!Ops.empty() && "Cannot get empty add!"); @@ -1132,11 +1142,10 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector &Operands, const Loop *L) { if (Operands.size() == 1) return Operands[0]; - if (SCEVConstant *StepC = dyn_cast(Operands.back())) - if (StepC->getValue()->isZero()) { - Operands.pop_back(); - return getAddRecExpr(Operands, L); // { X,+,0 } --> X - } + if (Operands.back()->isZero()) { + Operands.pop_back(); + return getAddRecExpr(Operands, L); // { X,+,0 } --> X + } SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, std::vector(Operands.begin(), @@ -1695,118 +1704,125 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { if (!isa(V->getType())) return SE.getUnknown(V); - if (Instruction *I = dyn_cast(V)) { - switch (I->getOpcode()) { - case Instruction::Add: - return SE.getAddExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Mul: - return SE.getMulExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::UDiv: - return SE.getUDivExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Sub: - return SE.getMinusSCEV(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an Add so loop - // optimizations will transparently handle this case. - // - // In order for this transformation to be safe, the LHS must be of the - // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - SCEVHandle LHS = getSCEV(I->getOperand(0)); - const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) - return SE.getAddExpr(LHS, getSCEV(I->getOperand(1))); - } - break; - case Instruction::Xor: - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { - if (CI->getValue().isSignBit()) - return SE.getAddExpr(getSCEV(I->getOperand(0)), - getSCEV(I->getOperand(1))); - else if (CI->isAllOnesValue()) - return SE.getNotSCEV(getSCEV(I->getOperand(0))); - } - break; - - case Instruction::Shl: - // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast(I->getOperand(1))) { - uint32_t BitWidth = cast(V->getType())->getBitWidth(); - Constant *X = ConstantInt::get( - APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); - return SE.getMulExpr(getSCEV(I->getOperand(0)), getSCEV(X)); - } - break; - - case Instruction::Trunc: - return SE.getTruncateExpr(getSCEV(I->getOperand(0)), I->getType()); - - case Instruction::ZExt: - return SE.getZeroExtendExpr(getSCEV(I->getOperand(0)), I->getType()); - - case Instruction::SExt: - return SE.getSignExtendExpr(getSCEV(I->getOperand(0)), I->getType()); - - case Instruction::BitCast: - // BitCasts are no-op casts so we just eliminate the cast. - if (I->getType()->isInteger() && - I->getOperand(0)->getType()->isInteger()) - return getSCEV(I->getOperand(0)); - break; - - case Instruction::PHI: - return createNodeForPHI(cast(I)); - - case Instruction::Select: - // This could be a smax or umax that was lowered earlier. - // Try to recover it. - if (ICmpInst *ICI = dyn_cast(I->getOperand(0))) { - Value *LHS = ICI->getOperand(0); - Value *RHS = ICI->getOperand(1); - switch (ICI->getPredicate()) { - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - std::swap(LHS, RHS); - // fall through - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) - return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); - else if (LHS == I->getOperand(2) && RHS == I->getOperand(1)) - // -smax(-x, -y) == smin(x, y). - return SE.getNegativeSCEV(SE.getSMaxExpr( - SE.getNegativeSCEV(getSCEV(LHS)), - SE.getNegativeSCEV(getSCEV(RHS)))); - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - std::swap(LHS, RHS); - // fall through - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - if (LHS == I->getOperand(1) && RHS == I->getOperand(2)) - return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); - else if (LHS == I->getOperand(2) && RHS == I->getOperand(1)) - // ~umax(~x, ~y) == umin(x, y) - return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), - SE.getNotSCEV(getSCEV(RHS)))); - break; - default: - break; - } - } + unsigned Opcode = Instruction::UserOp1; + if (Instruction *I = dyn_cast(V)) + Opcode = I->getOpcode(); + else if (ConstantExpr *CE = dyn_cast(V)) + Opcode = CE->getOpcode(); + else + return SE.getUnknown(V); - default: // We cannot analyze this expression. - break; + User *U = cast(V); + switch (Opcode) { + case Instruction::Add: + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Mul: + return SE.getMulExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::UDiv: + return SE.getUDivExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Sub: + return SE.getMinusSCEV(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + SCEVHandle LHS = getSCEV(U->getOperand(0)); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) + return SE.getAddExpr(LHS, getSCEV(U->getOperand(1))); + } + break; + case Instruction::Xor: + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. + if (ConstantInt *CI = dyn_cast(U->getOperand(1))) { + if (CI->getValue().isSignBit()) + return SE.getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + else if (CI->isAllOnesValue()) + return SE.getNotSCEV(getSCEV(U->getOperand(0))); } + break; + + case Instruction::Shl: + // Turn shift left of a constant amount into a multiply. + if (ConstantInt *SA = dyn_cast(U->getOperand(1))) { + uint32_t BitWidth = cast(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; + + case Instruction::Trunc: + return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); + + case Instruction::ZExt: + return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); + + case Instruction::SExt: + return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); + + case Instruction::BitCast: + // BitCasts are no-op casts so we just eliminate the cast. + if (U->getType()->isInteger() && + U->getOperand(0)->getType()->isInteger()) + return getSCEV(U->getOperand(0)); + break; + + case Instruction::PHI: + return createNodeForPHI(cast(U)); + + case Instruction::Select: + // This could be a smax or umax that was lowered earlier. + // Try to recover it. + if (ICmpInst *ICI = dyn_cast(U->getOperand(0))) { + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // -smax(-x, -y) == smin(x, y). + return SE.getNegativeSCEV(SE.getSMaxExpr( + SE.getNegativeSCEV(getSCEV(LHS)), + SE.getNegativeSCEV(getSCEV(RHS)))); + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~umax(~x, ~y) == umin(x, y) + return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)), + SE.getNotSCEV(getSCEV(RHS)))); + break; + default: + break; + } + } + + default: // We cannot analyze this expression. + break; } return SE.getUnknown(V); @@ -1980,8 +1996,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) { break; } case ICmpInst::ICMP_UGT: { - SCEVHandle TC = HowManyLessThans(SE.getNegativeSCEV(LHS), - SE.getNegativeSCEV(RHS), L, false); + SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS), + SE.getNotSCEV(RHS), L, false); if (!isa(TC)) return TC; break; } @@ -2045,7 +2061,7 @@ GetAddressedElementFromGlobal(GlobalVariable *GV, } /// ComputeLoadConstantCompareIterationCount - Given an exit condition of -/// 'icmp op load X, cst', try to se if we can compute the trip count. +/// 'icmp op load X, cst', try to see if we can compute the trip count. SCEVHandle ScalarEvolutionsImpl:: ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS, const Loop *L, @@ -2434,8 +2450,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) { // loop iterates. Compute this now. SCEVHandle IterationCount = getIterationCount(AddRec->getLoop()); if (IterationCount == UnknownValue) return UnknownValue; - IterationCount = getTruncateOrZeroExtend(IterationCount, - AddRec->getType(), SE); + IterationCount = SE.getTruncateOrZeroExtend(IterationCount, + AddRec->getType()); // If the value is affine, simplify the expression evaluation to just // Start + Step*IterationCount. @@ -2550,9 +2566,9 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { if (SCEVConstant *StartC = dyn_cast(Start)) { ConstantInt *StartCC = StartC->getValue(); Constant *StartNegC = ConstantExpr::getNeg(StartCC); - Constant *Rem = ConstantExpr::getSRem(StartNegC, StepC->getValue()); + Constant *Rem = ConstantExpr::getURem(StartNegC, StepC->getValue()); if (Rem->isNullValue()) { - Constant *Result =ConstantExpr::getSDiv(StartNegC,StepC->getValue()); + Constant *Result = ConstantExpr::getUDiv(StartNegC,StepC->getValue()); return SE.getUnknown(Result); } } @@ -2579,9 +2595,8 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) { // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE); - if (SCEVConstant *EvalVal = dyn_cast(Val)) - if (EvalVal->getValue()->isZero()) - return R1; // We found a quadratic root! + if (Val->isZero()) + return R1; // We found a quadratic root! } } }