From 83bb0055fdac3c6234c4178cd429e6a917d06c4e Mon Sep 17 00:00:00 2001 From: Nick Lewycky Date: Thu, 22 Nov 2007 07:59:40 +0000 Subject: [PATCH] Instead of calculating constant factors, calculate the number of trailing bits. Patch from Wojciech Matyjewicz. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@44268 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Analysis/ScalarEvolution.cpp | 101 ++++++++++++++----------------- 1 file changed, 47 insertions(+), 54 deletions(-) diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index fed57f9d917..cc6cde2ba21 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1410,62 +1410,60 @@ SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) { return SE.getUnknown(PN); } -/// GetConstantFactor - Determine the largest constant factor that S has. For -/// example, turn {4,+,8} -> 4. (S umod result) should always equal zero. -static APInt GetConstantFactor(SCEVHandle S) { - if (SCEVConstant *C = dyn_cast(S)) { - const APInt& V = C->getValue()->getValue(); - if (!V.isMinValue()) - return V; - else // Zero is a multiple of everything. - return APInt::getHighBitsSet(C->getBitWidth(), 1); - } +/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is +/// guaranteed to end in (at every loop iteration). It is, at the same time, +/// the minimum number of times S is divisible by 2. For example, given {4,+,8} +/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. +static uint32_t GetMinTrailingZeros(SCEVHandle S) { + if (SCEVConstant *C = dyn_cast(S)) + // APInt::countTrailingZeros() returns the number of trailing zeros in its + // internal representation, which length may be greater than the represented + // value bitwidth. This is why we use a min operation here. + return std::min(C->getValue()->getValue().countTrailingZeros(), + C->getBitWidth()); if (SCEVTruncateExpr *T = dyn_cast(S)) - return GetConstantFactor(T->getOperand()).trunc( - cast(T->getType())->getBitWidth()); - if (SCEVZeroExtendExpr *E = dyn_cast(S)) - return GetConstantFactor(E->getOperand()).zext( - cast(E->getType())->getBitWidth()); - if (SCEVSignExtendExpr *E = dyn_cast(S)) - return GetConstantFactor(E->getOperand()).sext( - cast(E->getType())->getBitWidth()); - + return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth()); + + if (SCEVZeroExtendExpr *E = dyn_cast(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes; + } + + if (SCEVSignExtendExpr *E = dyn_cast(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); + return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes; + } + if (SCEVAddExpr *A = dyn_cast(S)) { - // The result is the min of all operands. - APInt Res(GetConstantFactor(A->getOperand(0))); - for (unsigned i = 1, e = A->getNumOperands(); - i != e && Res.ugt(APInt(Res.getBitWidth(),1)); ++i) { - APInt Tmp(GetConstantFactor(A->getOperand(i))); - Res = APIntOps::umin(Res, Tmp); - } - return Res; + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; } if (SCEVMulExpr *M = dyn_cast(S)) { - // The result is the product of all the operands. - APInt Res(GetConstantFactor(M->getOperand(0))); - for (unsigned i = 1, e = M->getNumOperands(); i != e; ++i) { - APInt Tmp(GetConstantFactor(M->getOperand(i))); - Res *= Tmp; - } - return Res; + // The result is the sum of all operands results. + uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); + uint32_t BitWidth = M->getBitWidth(); + for (unsigned i = 1, e = M->getNumOperands(); + SumOpRes != BitWidth && i != e; ++i) + SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), + BitWidth); + return SumOpRes; } - + if (SCEVAddRecExpr *A = dyn_cast(S)) { - // For now, we just handle linear expressions. - if (A->getNumOperands() == 2) { - // We want the GCD between the start and the stride value. - APInt Start(GetConstantFactor(A->getOperand(0))); - if (Start == 1) - return Start; - APInt Stride(GetConstantFactor(A->getOperand(1))); - return APIntOps::GreatestCommonDivisor(Start, Stride); - } + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); + return MinOpRes; } - - // SCEVSDivExpr, SCEVUnknown. - return APInt(S->getBitWidth(), 1); + + // SCEVSDivExpr, SCEVUnknown + return 0; } /// createSCEV - We know that there is no SCEV for the specified value. @@ -1493,17 +1491,12 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { // // In order for this transformation to be safe, the LHS must be of the // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast(I->getOperand(1))) { SCEVHandle LHS = getSCEV(I->getOperand(0)); - APInt CommonFact(GetConstantFactor(LHS)); - assert(!CommonFact.isMinValue() && - "Common factor should at least be 1!"); const APInt &CIVal = CI->getValue(); - if (CommonFact.countTrailingZeros() >= + if (GetMinTrailingZeros(LHS) >= (CIVal.getBitWidth() - CIVal.countLeadingZeros())) - return SE.getAddExpr(LHS, - getSCEV(I->getOperand(1))); + return SE.getAddExpr(LHS, getSCEV(I->getOperand(1))); } break; case Instruction::Xor: -- 2.34.1