const SCEV *
ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
- return getConstant(
- ConstantInt::get(cast<IntegerType>(Ty), V, isSigned));
+ const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
+ return getConstant(ConstantInt::get(ITy, V, isSigned));
}
const Type *SCEVConstant::getType() const { return V->getType(); }
}
LargeOps.push_back(T->getOperand());
} else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
- // This could be either sign or zero extension, but sign extension
- // is much more likely to be foldable here.
- LargeOps.push_back(getSignExtendExpr(C, SrcType));
+ LargeOps.push_back(getAnyExtendExpr(C, SrcType));
} else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
SmallVector<const SCEV *, 8> LargeMulOps;
for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
LargeMulOps.push_back(T->getOperand());
} else if (const SCEVConstant *C =
dyn_cast<SCEVConstant>(M->getOperand(j))) {
- // This could be either sign or zero extension, but sign extension
- // is much more likely to be foldable here.
- LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
+ LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
} else {
Ok = false;
break;
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->equalsInt(1))
return LHS; // X udiv 1 --> x
- if (RHSC->getValue()->isZero())
- return getIntegerSCEV(0, LHS->getType()); // value is undefined
-
- // Determine if the division can be folded into the operands of
- // its operands.
- // TODO: Generalize this to non-constants by using known-bits information.
- const Type *Ty = LHS->getType();
- unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
- unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
- // For non-power-of-two values, effectively round the value up to the
- // nearest power of two.
- if (!RHSC->getValue()->getValue().isPowerOf2())
- ++MaxShiftAmt;
- const IntegerType *ExtTy =
- IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
- // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
- if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
- if (const SCEVConstant *Step =
- dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
- if (!Step->getValue()->getValue()
- .urem(RHSC->getValue()->getValue()) &&
- getZeroExtendExpr(AR, ExtTy) ==
- getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
- getZeroExtendExpr(Step, ExtTy),
- AR->getLoop())) {
- SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
- Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
- return getAddRecExpr(Operands, AR->getLoop());
- }
- // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
- if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
- SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
- Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
- if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
- // Find an operand that's safely divisible.
- for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
- const SCEV *Op = M->getOperand(i);
- const SCEV *Div = getUDivExpr(Op, RHSC);
- if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
- Operands = SmallVector<const SCEV *, 4>(M->op_begin(), M->op_end());
- Operands[i] = Div;
- return getMulExpr(Operands);
+ // If the denominator is zero, the result of the udiv is undefined. Don't
+ // try to analyze it, because the resolution chosen here may differ from
+ // the resolution chosen in other parts of the compiler.
+ if (!RHSC->getValue()->isZero()) {
+ // Determine if the division can be folded into the operands of
+ // its operands.
+ // TODO: Generalize this to non-constants by using known-bits information.
+ const Type *Ty = LHS->getType();
+ unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
+ unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
+ // For non-power-of-two values, effectively round the value up to the
+ // nearest power of two.
+ if (!RHSC->getValue()->getValue().isPowerOf2())
+ ++MaxShiftAmt;
+ const IntegerType *ExtTy =
+ IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
+ // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
+ if (const SCEVConstant *Step =
+ dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
+ if (!Step->getValue()->getValue()
+ .urem(RHSC->getValue()->getValue()) &&
+ getZeroExtendExpr(AR, ExtTy) ==
+ getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
+ getZeroExtendExpr(Step, ExtTy),
+ AR->getLoop())) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
+ Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
+ return getAddRecExpr(Operands, AR->getLoop());
}
+ // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
+ Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
+ if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
+ // Find an operand that's safely divisible.
+ for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
+ const SCEV *Op = M->getOperand(i);
+ const SCEV *Div = getUDivExpr(Op, RHSC);
+ if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
+ Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
+ M->op_end());
+ Operands[i] = Div;
+ return getMulExpr(Operands);
+ }
+ }
+ }
+ // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
+ if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) {
+ SmallVector<const SCEV *, 4> Operands;
+ for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
+ Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
+ if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
+ Operands.clear();
+ for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
+ const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
+ if (isa<SCEVUDivExpr>(Op) ||
+ getMulExpr(Op, RHS) != A->getOperand(i))
+ break;
+ Operands.push_back(Op);
+ }
+ if (Operands.size() == A->getNumOperands())
+ return getAddExpr(Operands);
}
- }
- // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
- if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) {
- SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
- Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
- if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
- Operands.clear();
- for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
- const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
- if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
- break;
- Operands.push_back(Op);
- }
- if (Operands.size() == A->getNumOperands())
- return getAddExpr(Operands);
}
- }
- // Fold if both operands are constant.
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
- Constant *LHSCV = LHSC->getValue();
- Constant *RHSCV = RHSC->getValue();
- return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
- RHSCV)));
+ // Fold if both operands are constant.
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
+ Constant *LHSCV = LHSC->getValue();
+ Constant *RHSCV = RHSC->getValue();
+ return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
+ RHSCV)));
+ }
}
}
// Turn shift left of a constant amount into a multiply.
if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (SA->getValue().uge(BitWidth))
+ break;
+
Constant *X = ConstantInt::get(getContext(),
- APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
+ APInt(BitWidth, 1).shl(SA->getZExtValue()));
return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
}
break;
// Turn logical shift right of a constant into a unsigned divide.
if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (SA->getValue().uge(BitWidth))
+ break;
+
Constant *X = ConstantInt::get(getContext(),
- APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
+ APInt(BitWidth, 1).shl(SA->getZExtValue()));
return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
}
break;
case Instruction::AShr:
// For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
- if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0)))
+ if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
if (L->getOpcode() == Instruction::Shl &&
L->getOperand(1) == U->getOperand(1)) {
- unsigned BitWidth = getTypeSizeInBits(U->getType());
+ uint64_t BitWidth = getTypeSizeInBits(U->getType());
+
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (CI->getValue().uge(BitWidth))
+ break;
+
uint64_t Amt = BitWidth - CI->getZExtValue();
if (Amt == BitWidth)
return getSCEV(L->getOperand(0)); // shift by zero --> noop
- if (Amt > BitWidth)
- return getIntegerSCEV(0, U->getType()); // value is undefined
return
getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
- IntegerType::get(getContext(), Amt)),
- U->getType());
+ IntegerType::get(getContext(),
+ Amt)),
+ U->getType());
}
break;
// fall through
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
- if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
- return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
- else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
- return getSMinExpr(getSCEV(LHS), getSCEV(RHS));
+ // a >s b ? a+x : b+x -> smax(a, b)+x
+ // a >s b ? b+x : a+x -> smin(a, b)+x
+ if (LHS->getType() == U->getType()) {
+ const SCEV *LS = getSCEV(LHS);
+ const SCEV *RS = getSCEV(RHS);
+ const SCEV *LA = getSCEV(U->getOperand(1));
+ const SCEV *RA = getSCEV(U->getOperand(2));
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS);
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMinExpr(LS, RS), LDiff);
+ }
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
// fall through
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
- if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
- return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
- else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
- return getUMinExpr(getSCEV(LHS), getSCEV(RHS));
+ // a >u b ? a+x : b+x -> umax(a, b)+x
+ // a >u b ? b+x : a+x -> umin(a, b)+x
+ if (LHS->getType() == U->getType()) {
+ const SCEV *LS = getSCEV(LHS);
+ const SCEV *RS = getSCEV(RHS);
+ const SCEV *LA = getSCEV(U->getOperand(1));
+ const SCEV *RA = getSCEV(U->getOperand(2));
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS);
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMinExpr(LS, RS), LDiff);
+ }
break;
case ICmpInst::ICMP_NE:
- // n != 0 ? n : 1 -> umax(n, 1)
- if (LHS == U->getOperand(1) &&
- isa<ConstantInt>(U->getOperand(2)) &&
- cast<ConstantInt>(U->getOperand(2))->isOne() &&
+ // n != 0 ? n+x : 1+x -> umax(n, 1)+x
+ if (LHS->getType() == U->getType() &&
isa<ConstantInt>(RHS) &&
- cast<ConstantInt>(RHS)->isZero())
- return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2)));
+ cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getConstant(LHS->getType(), 1);
+ const SCEV *LS = getSCEV(LHS);
+ const SCEV *LA = getSCEV(U->getOperand(1));
+ const SCEV *RA = getSCEV(U->getOperand(2));
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, One);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(LS, One), LDiff);
+ }
break;
case ICmpInst::ICMP_EQ:
- // n == 0 ? 1 : n -> umax(n, 1)
- if (LHS == U->getOperand(2) &&
- isa<ConstantInt>(U->getOperand(1)) &&
- cast<ConstantInt>(U->getOperand(1))->isOne() &&
+ // n == 0 ? 1+x : n+x -> umax(n, 1)+x
+ if (LHS->getType() == U->getType() &&
isa<ConstantInt>(RHS) &&
- cast<ConstantInt>(RHS)->isZero())
- return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1)));
+ cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getConstant(LHS->getType(), 1);
+ const SCEV *LS = getSCEV(LHS);
+ const SCEV *LA = getSCEV(U->getOperand(1));
+ const SCEV *RA = getSCEV(U->getOperand(2));
+ const SCEV *LDiff = getMinusSCEV(LA, One);
+ const SCEV *RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(LS, One), LDiff);
+ }
break;
default:
break;
if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
}
+ // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
+ // adding or subtracting 1 from one of the operands.
+ switch (Cond) {
+ case ICmpInst::ICMP_SLE:
+ if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
+ /*HasNUW=*/false, /*HasNSW=*/true);
+ Cond = ICmpInst::ICMP_SLT;
+ } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
+ /*HasNUW=*/false, /*HasNSW=*/true);
+ Cond = ICmpInst::ICMP_SLT;
+ }
+ break;
+ case ICmpInst::ICMP_SGE:
+ if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
+ /*HasNUW=*/false, /*HasNSW=*/true);
+ Cond = ICmpInst::ICMP_SGT;
+ } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
+ /*HasNUW=*/false, /*HasNSW=*/true);
+ Cond = ICmpInst::ICMP_SGT;
+ }
+ break;
+ case ICmpInst::ICMP_ULE:
+ if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
+ /*HasNUW=*/true, /*HasNSW=*/false);
+ Cond = ICmpInst::ICMP_ULT;
+ } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
+ /*HasNUW=*/true, /*HasNSW=*/false);
+ Cond = ICmpInst::ICMP_ULT;
+ }
+ break;
+ case ICmpInst::ICMP_UGE:
+ if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
+ /*HasNUW=*/true, /*HasNSW=*/false);
+ Cond = ICmpInst::ICMP_UGT;
+ } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
+ LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
+ /*HasNUW=*/true, /*HasNSW=*/false);
+ Cond = ICmpInst::ICMP_UGT;
+ }
+ break;
+ default:
+ break;
+ }
+
switch (Cond) {
case ICmpInst::ICMP_NE: { // while (X != Y)
// Convert to: while (X-Y != 0)
return false;
}
+/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
+/// predicate Pred. Return true iff any changes were made.
+///
+bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
+ const SCEV *&LHS, const SCEV *&RHS) {
+ bool Changed = false;
+
+ // Canonicalize a constant to the right side.
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
+ // Check for both operands constant.
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+ if (ConstantExpr::getICmp(Pred,
+ LHSC->getValue(),
+ RHSC->getValue())->isNullValue())
+ goto trivially_false;
+ else
+ goto trivially_true;
+ }
+ // Otherwise swap the operands to put the constant on the right.
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ Changed = true;
+ }
+
+ // If we're comparing an addrec with a value which is loop-invariant in the
+ // addrec's loop, put the addrec on the left.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
+ if (LHS->isLoopInvariant(AR->getLoop())) {
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ Changed = true;
+ }
+
+ // If there's a constant operand, canonicalize comparisons with boundary
+ // cases, and canonicalize *-or-equal comparisons to regular comparisons.
+ if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
+ const APInt &RA = RC->getValue()->getValue();
+ switch (Pred) {
+ default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE:
+ break;
+ case ICmpInst::ICMP_UGE:
+ if ((RA - 1).isMinValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMaxValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ Changed = true;
+ break;
+ }
+ if (RA.isMinValue()) goto trivially_true;
+
+ Pred = ICmpInst::ICMP_UGT;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_ULE:
+ if ((RA + 1).isMaxValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMinValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ Changed = true;
+ break;
+ }
+ if (RA.isMaxValue()) goto trivially_true;
+
+ Pred = ICmpInst::ICMP_ULT;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_SGE:
+ if ((RA - 1).isMinSignedValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMaxSignedValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ Changed = true;
+ break;
+ }
+ if (RA.isMinSignedValue()) goto trivially_true;
+
+ Pred = ICmpInst::ICMP_SGT;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_SLE:
+ if ((RA + 1).isMaxSignedValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMinSignedValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ Changed = true;
+ break;
+ }
+ if (RA.isMaxSignedValue()) goto trivially_true;
+
+ Pred = ICmpInst::ICMP_SLT;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ case ICmpInst::ICMP_UGT:
+ if (RA.isMinValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ Changed = true;
+ break;
+ }
+ if ((RA + 1).isMaxValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMaxValue()) goto trivially_false;
+ break;
+ case ICmpInst::ICMP_ULT:
+ if (RA.isMaxValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ Changed = true;
+ break;
+ }
+ if ((RA - 1).isMinValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMinValue()) goto trivially_false;
+ break;
+ case ICmpInst::ICMP_SGT:
+ if (RA.isMinSignedValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ Changed = true;
+ break;
+ }
+ if ((RA + 1).isMaxSignedValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ RHS = getConstant(RA + 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMaxSignedValue()) goto trivially_false;
+ break;
+ case ICmpInst::ICMP_SLT:
+ if (RA.isMaxSignedValue()) {
+ Pred = ICmpInst::ICMP_NE;
+ Changed = true;
+ break;
+ }
+ if ((RA - 1).isMinSignedValue()) {
+ Pred = ICmpInst::ICMP_EQ;
+ RHS = getConstant(RA - 1);
+ Changed = true;
+ break;
+ }
+ if (RA.isMinSignedValue()) goto trivially_false;
+ break;
+ }
+ }
+
+ // Check for obvious equality.
+ if (HasSameValue(LHS, RHS)) {
+ if (ICmpInst::isTrueWhenEqual(Pred))
+ goto trivially_true;
+ if (ICmpInst::isFalseWhenEqual(Pred))
+ goto trivially_false;
+ }
+
+ // TODO: More simplifications are possible here.
+
+ return Changed;
+
+trivially_true:
+ // Return 0 == 0.
+ LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0);
+ Pred = ICmpInst::ICMP_EQ;
+ return true;
+
+trivially_false:
+ // Return 0 != 0.
+ LHS = RHS = getConstant(Type::getInt1Ty(getContext()), 0);
+ Pred = ICmpInst::ICMP_NE;
+ return true;
+}
+
bool ScalarEvolution::isKnownNegative(const SCEV *S) {
return getSignedRange(S).getSignedMax().isNegative();
}
bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
+ // Canonicalize the inputs first.
+ (void)SimplifyICmpOperands(Pred, LHS, RHS);
+
// If LHS or RHS is an addrec, check to see if the condition is true in
// every iteration of the loop.
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
// Canonicalize the query to match the way instcombine will have
// canonicalized the comparison.
- // First, put a constant operand on the right.
- if (isa<SCEVConstant>(LHS)) {
- std::swap(LHS, RHS);
- Pred = ICmpInst::getSwappedPredicate(Pred);
- }
- // Then, canonicalize comparisons with boundary cases.
- if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
- const APInt &RA = RC->getValue()->getValue();
- switch (Pred) {
- default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
- case ICmpInst::ICMP_EQ:
- case ICmpInst::ICMP_NE:
- break;
- case ICmpInst::ICMP_UGE:
- if ((RA - 1).isMinValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA - 1);
- break;
- }
- if (RA.isMaxValue()) {
- Pred = ICmpInst::ICMP_EQ;
- break;
- }
- if (RA.isMinValue()) return true;
- break;
- case ICmpInst::ICMP_ULE:
- if ((RA + 1).isMaxValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA + 1);
- break;
- }
- if (RA.isMinValue()) {
- Pred = ICmpInst::ICMP_EQ;
- break;
- }
- if (RA.isMaxValue()) return true;
- break;
- case ICmpInst::ICMP_SGE:
- if ((RA - 1).isMinSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA - 1);
- break;
- }
- if (RA.isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- break;
- }
- if (RA.isMinSignedValue()) return true;
- break;
- case ICmpInst::ICMP_SLE:
- if ((RA + 1).isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- RHS = getConstant(RA + 1);
- break;
- }
- if (RA.isMinSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- break;
- }
- if (RA.isMaxSignedValue()) return true;
- break;
- case ICmpInst::ICMP_UGT:
- if (RA.isMinValue()) {
- Pred = ICmpInst::ICMP_NE;
- break;
- }
- if ((RA + 1).isMaxValue()) {
- Pred = ICmpInst::ICMP_EQ;
- RHS = getConstant(RA + 1);
- break;
- }
- if (RA.isMaxValue()) return false;
- break;
- case ICmpInst::ICMP_ULT:
- if (RA.isMaxValue()) {
- Pred = ICmpInst::ICMP_NE;
- break;
- }
- if ((RA - 1).isMinValue()) {
- Pred = ICmpInst::ICMP_EQ;
- RHS = getConstant(RA - 1);
- break;
- }
- if (RA.isMinValue()) return false;
- break;
- case ICmpInst::ICMP_SGT:
- if (RA.isMinSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- break;
- }
- if ((RA + 1).isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- RHS = getConstant(RA + 1);
- break;
- }
- if (RA.isMaxSignedValue()) return false;
- break;
- case ICmpInst::ICMP_SLT:
- if (RA.isMaxSignedValue()) {
- Pred = ICmpInst::ICMP_NE;
- break;
- }
- if ((RA - 1).isMinSignedValue()) {
- Pred = ICmpInst::ICMP_EQ;
- RHS = getConstant(RA - 1);
- break;
- }
- if (RA.isMinSignedValue()) return false;
- break;
- }
- }
+ if (SimplifyICmpOperands(Pred, LHS, RHS))
+ if (LHS == RHS)
+ return Pred == ICmpInst::ICMP_EQ;
+ if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
+ if (FoundLHS == FoundRHS)
+ return Pred == ICmpInst::ICMP_NE;
// Check to see if we can make the LHS or RHS match.
if (LHS == FoundRHS || RHS == FoundLHS) {