From aa03649af255fbbb049f393a2cf7d533da86d951 Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Sat, 14 Feb 2009 02:31:09 +0000 Subject: [PATCH] Extend the IndVarSimplify support for promoting induction variables: - Test for signed and unsigned wrapping conditions, instead of just testing for non-negative induction ranges. - Handle loops with GT comparisons, in addition to LT comparisons. - Support more cases of induction variables that don't start at 0. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@64532 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Scalar/IndVarSimplify.cpp | 178 +++++++++++++----- .../promote-iv-to-eliminate-casts.ll | 38 ++++ 2 files changed, 172 insertions(+), 44 deletions(-) diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index b4b6ea35577..92a501fda8e 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -458,33 +458,98 @@ static const Type *getEffectiveIndvarType(const PHINode *Phi) { return Ty; } -/// isOrigIVAlwaysNonNegative - Analyze the original induction variable -/// in the loop to determine whether it would ever have a negative -/// value. +/// TestOrigIVForWrap - Analyze the original induction variable +/// in the loop to determine whether it would ever undergo signed +/// or unsigned overflow. /// /// TODO: This duplicates a fair amount of ScalarEvolution logic. -/// Perhaps this can be merged with ScalarEvolution::getIterationCount. +/// Perhaps this can be merged with ScalarEvolution::getIterationCount +/// and/or ScalarEvolution::get{Sign,Zero}ExtendExpr. /// -static bool isOrigIVAlwaysNonNegative(const Loop *L, - const Instruction *OrigCond) { +static void TestOrigIVForWrap(const Loop *L, + const BranchInst *BI, + const Instruction *OrigCond, + bool &NoSignedWrap, + bool &NoUnsignedWrap) { // Verify that the loop is sane and find the exit condition. const ICmpInst *Cmp = dyn_cast(OrigCond); - if (!Cmp) return false; + if (!Cmp) return; + + const Value *CmpLHS = Cmp->getOperand(0); + const Value *CmpRHS = Cmp->getOperand(1); + const BasicBlock *TrueBB = BI->getSuccessor(0); + const BasicBlock *FalseBB = BI->getSuccessor(1); + ICmpInst::Predicate Pred = Cmp->getPredicate(); + + // Canonicalize a constant to the RHS. + if (isa(CmpLHS)) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + // Canonicalize SLE to SLT. + if (Pred == ICmpInst::ICMP_SLE) + if (const ConstantInt *CI = dyn_cast(CmpRHS)) + if (!CI->getValue().isMaxSignedValue()) { + CmpRHS = ConstantInt::get(CI->getValue() + 1); + Pred = ICmpInst::ICMP_SLT; + } + // Canonicalize SGT to SGE. + if (Pred == ICmpInst::ICMP_SGT) + if (const ConstantInt *CI = dyn_cast(CmpRHS)) + if (!CI->getValue().isMaxSignedValue()) { + CmpRHS = ConstantInt::get(CI->getValue() + 1); + Pred = ICmpInst::ICMP_SGE; + } + // Canonicalize SGE to SLT. + if (Pred == ICmpInst::ICMP_SGE) { + std::swap(TrueBB, FalseBB); + Pred = ICmpInst::ICMP_SLT; + } + // Canonicalize ULE to ULT. + if (Pred == ICmpInst::ICMP_ULE) + if (const ConstantInt *CI = dyn_cast(CmpRHS)) + if (!CI->getValue().isMaxValue()) { + CmpRHS = ConstantInt::get(CI->getValue() + 1); + Pred = ICmpInst::ICMP_ULT; + } + // Canonicalize UGT to UGE. + if (Pred == ICmpInst::ICMP_UGT) + if (const ConstantInt *CI = dyn_cast(CmpRHS)) + if (!CI->getValue().isMaxValue()) { + CmpRHS = ConstantInt::get(CI->getValue() + 1); + Pred = ICmpInst::ICMP_UGE; + } + // Canonicalize UGE to ULT. + if (Pred == ICmpInst::ICMP_UGE) { + std::swap(TrueBB, FalseBB); + Pred = ICmpInst::ICMP_ULT; + } + // For now, analyze only LT loops for signed overflow. + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_ULT) + return; - // For now, analyze only SLT loops for signed overflow. - if (Cmp->getPredicate() != ICmpInst::ICMP_SLT) return false; + bool isSigned = Pred == ICmpInst::ICMP_SLT; - // Get the increment instruction. Look past SExtInsts if we will + // Get the increment instruction. Look past casts if we will // be able to prove that the original induction variable doesn't - // undergo signed overflow. - const Value *OrigIncrVal = Cmp->getOperand(0); - const Value *IncrVal = OrigIncrVal; - if (SExtInst *SI = dyn_cast(Cmp->getOperand(0))) { - if (!isa(Cmp->getOperand(1)) || - !cast(Cmp->getOperand(1))->getValue() - .isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits())) - return false; - IncrVal = SI->getOperand(0); + // undergo signed or unsigned overflow, respectively. + const Value *IncrVal = CmpLHS; + if (isSigned) { + if (const SExtInst *SI = dyn_cast(CmpLHS)) { + if (!isa(CmpRHS) || + !cast(CmpRHS)->getValue() + .isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits())) + return; + IncrVal = SI->getOperand(0); + } + } else { + if (const ZExtInst *ZI = dyn_cast(CmpLHS)) { + if (!isa(CmpRHS) || + !cast(CmpRHS)->getValue() + .isIntN(IncrVal->getType()->getPrimitiveSizeInBits())) + return; + IncrVal = ZI->getOperand(0); + } } // For now, only analyze induction variables that have simple increments. @@ -493,32 +558,36 @@ static bool isOrigIVAlwaysNonNegative(const Loop *L, IncrOp->getOpcode() != Instruction::Add || !isa(IncrOp->getOperand(1)) || !cast(IncrOp->getOperand(1))->equalsInt(1)) - return false; + return; // Make sure the PHI looks like a normal IV. const PHINode *PN = dyn_cast(IncrOp->getOperand(0)); if (!PN || PN->getNumIncomingValues() != 2) - return false; + return; unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); unsigned BackEdge = !IncomingEdge; if (!L->contains(PN->getIncomingBlock(BackEdge)) || PN->getIncomingValue(BackEdge) != IncrOp) - return false; + return; + if (!L->contains(TrueBB)) + return; // For now, only analyze loops with a constant start value, so that - // we can easily determine if the start value is non-negative and - // not a maximum value which would wrap on the first iteration. + // we can easily determine if the start value is not a maximum value + // which would wrap on the first iteration. const Value *InitialVal = PN->getIncomingValue(IncomingEdge); - if (!isa(InitialVal) || - cast(InitialVal)->getValue().isNegative() || - cast(InitialVal)->getValue().isMaxSignedValue()) - return false; + if (!isa(InitialVal)) + return; - // The original induction variable will start at some non-negative - // non-max value, it counts up by one, and the loop iterates only - // while it remans less than (signed) some value in the same type. - // As such, it will always be non-negative. - return true; + // The original induction variable will start at some non-max value, + // it counts up by one, and the loop iterates only while it remans + // less than some value in the same type. As such, it will never wrap. + if (isSigned && + !cast(InitialVal)->getValue().isMaxSignedValue()) + NoSignedWrap = true; + else if (!isSigned && + !cast(InitialVal)->getValue().isMaxValue()) + NoUnsignedWrap = true; } bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { @@ -596,13 +665,15 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { // If we have a trip count expression, rewrite the loop's exit condition // using it. We can currently only handle loops with a single exit. - bool OrigIVAlwaysNonNegative = false; + bool NoSignedWrap = false; + bool NoUnsignedWrap = false; if (!isa(IterationCount) && ExitingBlock) // Can't rewrite non-branch yet. if (BranchInst *BI = dyn_cast(ExitingBlock->getTerminator())) { if (Instruction *OrigCond = dyn_cast(BI->getCondition())) { - // Determine if the OrigIV will ever have a non-zero sign bit. - OrigIVAlwaysNonNegative = isOrigIVAlwaysNonNegative(L, OrigCond); + // Determine if the OrigIV will ever undergo overflow. + TestOrigIVForWrap(L, BI, OrigCond, + NoSignedWrap, NoUnsignedWrap); // We'll be replacing the original condition, so it'll be dead. DeadInsts.insert(OrigCond); @@ -642,19 +713,38 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { /// If the new canonical induction variable is wider than the original, /// and the original has uses that are casts to wider types, see if the /// truncate and extend can be omitted. - if (isa(NewVal)) + if (PN->getType() != LargestType) for (Value::use_iterator UI = PN->use_begin(), UE = PN->use_end(); - UI != UE; ++UI) - if (isa(UI) || - (isa(UI) && OrigIVAlwaysNonNegative)) { - Value *TruncIndVar = IndVar; - if (TruncIndVar->getType() != UI->getType()) - TruncIndVar = new TruncInst(IndVar, UI->getType(), "truncindvar", - InsertPt); + UI != UE; ++UI) { + if (isa(UI) && NoSignedWrap) { + SCEVHandle ExtendedStart = + SE->getSignExtendExpr(cast(IndVars.back().second)->getStart(), LargestType); + SCEVHandle ExtendedStep = + SE->getSignExtendExpr(cast(IndVars.back().second)->getStepRecurrence(*SE), LargestType); + SCEVHandle ExtendedAddRec = + SE->getAddRecExpr(ExtendedStart, ExtendedStep, L); + if (LargestType != UI->getType()) + ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType()); + Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt); + UI->replaceAllUsesWith(TruncIndVar); + if (Instruction *DeadUse = dyn_cast(*UI)) + DeadInsts.insert(DeadUse); + } + if (isa(UI) && NoUnsignedWrap) { + SCEVHandle ExtendedStart = + SE->getZeroExtendExpr(cast(IndVars.back().second)->getStart(), LargestType); + SCEVHandle ExtendedStep = + SE->getZeroExtendExpr(cast(IndVars.back().second)->getStepRecurrence(*SE), LargestType); + SCEVHandle ExtendedAddRec = + SE->getAddRecExpr(ExtendedStart, ExtendedStep, L); + if (LargestType != UI->getType()) + ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType()); + Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt); UI->replaceAllUsesWith(TruncIndVar); if (Instruction *DeadUse = dyn_cast(*UI)) DeadInsts.insert(DeadUse); } + } // Replace the old PHI Node with the inserted computation. PN->replaceAllUsesWith(NewVal); diff --git a/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll b/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll index 703fce4e292..08b08f200ec 100644 --- a/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll +++ b/test/Transforms/IndVarsSimplify/promote-iv-to-eliminate-casts.ll @@ -60,3 +60,41 @@ bb1.return_crit_edge: ; preds = %bb1 return: ; preds = %bb1.return_crit_edge, %entry ret void } + +; Test cases from PR1301: + +define void @kinds__srangezero([21 x i32]* nocapture %a) nounwind { +bb.thread: + br label %bb + +bb: ; preds = %bb, %bb.thread + %i.0.reg2mem.0 = phi i8 [ -10, %bb.thread ], [ %tmp7, %bb ] ; [#uses=2] + %tmp12 = sext i8 %i.0.reg2mem.0 to i32 ; [#uses=1] + %tmp4 = add i32 %tmp12, 10 ; [#uses=1] + %tmp5 = getelementptr [21 x i32]* %a, i32 0, i32 %tmp4 ; [#uses=1] + store i32 0, i32* %tmp5 + %tmp7 = add i8 %i.0.reg2mem.0, 1 ; [#uses=2] + %0 = icmp sgt i8 %tmp7, 10 ; [#uses=1] + br i1 %0, label %return, label %bb + +return: ; preds = %bb + ret void +} + +define void @kinds__urangezero([21 x i32]* nocapture %a) nounwind { +bb.thread: + br label %bb + +bb: ; preds = %bb, %bb.thread + %i.0.reg2mem.0 = phi i8 [ 10, %bb.thread ], [ %tmp7, %bb ] ; [#uses=2] + %tmp12 = sext i8 %i.0.reg2mem.0 to i32 ; [#uses=1] + %tmp4 = add i32 %tmp12, -10 ; [#uses=1] + %tmp5 = getelementptr [21 x i32]* %a, i32 0, i32 %tmp4 ; [#uses=1] + store i32 0, i32* %tmp5 + %tmp7 = add i8 %i.0.reg2mem.0, 1 ; [#uses=2] + %0 = icmp sgt i8 %tmp7, 30 ; [#uses=1] + br i1 %0, label %return, label %bb + +return: ; preds = %bb + ret void +} -- 2.34.1