X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FInstructionSimplify.cpp;h=9d78f8bf40441fafd023f08eebb9b126931459b7;hb=491a13691d3b30b8288dfc6e01ad6a58f69a4ce6;hp=fb51fa5315b219495b9f2ef25d130d074931cacc;hpb=1cd05bb605e3c3eee9197d3f10b628c60d0cc07a;p=oota-llvm.git diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index fb51fa5315b..9d78f8bf404 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -18,17 +18,20 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "instsimplify" +#include "llvm/Operator.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Support/ConstantRange.h" #include "llvm/Support/PatternMatch.h" #include "llvm/Support/ValueHandle.h" #include "llvm/Target/TargetData.h" using namespace llvm; using namespace llvm::PatternMatch; -#define RecursionLimit 3 +enum { RecursionLimit = 3 }; STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); @@ -71,8 +74,9 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { /// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". /// Returns the simplified value, or null if no simplification was performed. static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, - unsigned OpcodeToExpand, const TargetData *TD, + unsigned OpcToExpand, const TargetData *TD, const DominatorTree *DT, unsigned MaxRecurse) { + Instruction::BinaryOps OpcodeToExpand = (Instruction::BinaryOps)OpcToExpand; // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return 0; @@ -133,8 +137,9 @@ static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// OpCodeToExtract is Mul then this tries to turn "(A*B)+(A*C)" into "A*(B+C)". /// Returns the simplified value, or null if no simplification was performed. static Value *FactorizeBinOp(unsigned Opcode, Value *LHS, Value *RHS, - unsigned OpcodeToExtract, const TargetData *TD, + unsigned OpcToExtract, const TargetData *TD, const DominatorTree *DT, unsigned MaxRecurse) { + Instruction::BinaryOps OpcodeToExtract = (Instruction::BinaryOps)OpcToExtract; // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return 0; @@ -201,10 +206,11 @@ static Value *FactorizeBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// SimplifyAssociativeBinOp - Generic simplifications for associative binary /// operations. Returns the simpler value, or null if none was found. -static Value *SimplifyAssociativeBinOp(unsigned Opcode, Value *LHS, Value *RHS, +static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, const TargetData *TD, const DominatorTree *DT, unsigned MaxRecurse) { + Instruction::BinaryOps Opcode = (Instruction::BinaryOps)Opc; assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. @@ -391,17 +397,39 @@ static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, assert(isa(LHS) && "Not comparing with a select instruction!"); SelectInst *SI = cast(LHS); - // Now that we have "cmp select(cond, TV, FV), RHS", analyse it. + // Now that we have "cmp select(Cond, TV, FV), RHS", analyse it. // Does "cmp TV, RHS" simplify? if (Value *TCmp = SimplifyCmpInst(Pred, SI->getTrueValue(), RHS, TD, DT, - MaxRecurse)) + MaxRecurse)) { // It does! Does "cmp FV, RHS" simplify? if (Value *FCmp = SimplifyCmpInst(Pred, SI->getFalseValue(), RHS, TD, DT, - MaxRecurse)) + MaxRecurse)) { // It does! If they simplified to the same value, then use it as the // result of the original comparison. if (TCmp == FCmp) return TCmp; + Value *Cond = SI->getCondition(); + // If the false value simplified to false, then the result of the compare + // is equal to "Cond && TCmp". This also catches the case when the false + // value simplified to false and the true value to true, returning "Cond". + if (match(FCmp, m_Zero())) + if (Value *V = SimplifyAndInst(Cond, TCmp, TD, DT, MaxRecurse)) + return V; + // If the true value simplified to true, then the result of the compare + // is equal to "Cond || FCmp". + if (match(TCmp, m_One())) + if (Value *V = SimplifyOrInst(Cond, FCmp, TD, DT, MaxRecurse)) + return V; + // Finally, if the false value simplified to true and the true value to + // false, then the result of the compare is equal to "!Cond". + if (match(FCmp, m_One()) && match(TCmp, m_Zero())) + if (Value *V = + SimplifyXorInst(Cond, Constant::getAllOnesValue(Cond->getType()), + TD, DT, MaxRecurse)) + return V; + } + } + return 0; } @@ -506,7 +534,7 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } // X + undef -> undef - if (isa(Op1)) + if (match(Op1, m_Undef())) return Op1; // X + 0 -> X @@ -572,7 +600,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // X - undef -> undef // undef - X -> undef - if (isa(Op0) || isa(Op1)) + if (match(Op0, m_Undef()) || match(Op1, m_Undef())) return UndefValue::get(Op0->getType()); // X - 0 -> X @@ -583,23 +611,85 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); - // (X + Y) - Y -> X - // (Y + X) - Y -> X + // (X*2) - X -> X + // (X<<1) - X -> X Value *X = 0; - if (match(Op0, m_Add(m_Value(X), m_Specific(Op1))) || - match(Op0, m_Add(m_Specific(Op1), m_Value(X)))) - return X; + if (match(Op0, m_Mul(m_Specific(Op1), m_ConstantInt<2>())) || + match(Op0, m_Shl(m_Specific(Op1), m_One()))) + return Op1; - /// i1 sub -> xor. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) - if (Value *V = SimplifyXorInst(Op0, Op1, TD, DT, MaxRecurse-1)) - return V; + // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. + // For example, (X + Y) - Y -> X; (Y + X) - Y -> X + Value *Y = 0, *Z = Op1; + if (MaxRecurse && match(Op0, m_Add(m_Value(X), m_Value(Y)))) { // (X + Y) - Z + // See if "V === Y - Z" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, TD, DT, MaxRecurse-1)) + // It does! Now see if "X + V" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Add, X, V, TD, DT, + MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + // See if "V === X - Z" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, TD, DT, MaxRecurse-1)) + // It does! Now see if "Y + V" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, TD, DT, + MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + } + + // X - (Y + Z) -> (X - Y) - Z or (X - Z) - Y if everything simplifies. + // For example, X - (X + 1) -> -1 + X = Op0; + if (MaxRecurse && match(Op1, m_Add(m_Value(Y), m_Value(Z)))) { // X - (Y + Z) + // See if "V === X - Y" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, TD, DT, MaxRecurse-1)) + // It does! Now see if "V - Z" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, TD, DT, + MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + // See if "V === X - Z" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, TD, DT, MaxRecurse-1)) + // It does! Now see if "V - Y" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, TD, DT, + MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } + } + + // Z - (X - Y) -> (Z - X) + Y if everything simplifies. + // For example, X - (X - Y) -> Y. + Z = Op0; + if (MaxRecurse && match(Op1, m_Sub(m_Value(X), m_Value(Y)))) // Z - (X - Y) + // See if "V === Z - X" simplifies. + if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, TD, DT, MaxRecurse-1)) + // It does! Now see if "V + Y" simplifies. + if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, TD, DT, + MaxRecurse-1)) { + // It does, we successfully reassociated! + ++NumReassoc; + return W; + } // Mul distributes over Sub. Try some generic simplifications based on this. if (Value *V = FactorizeBinOp(Instruction::Sub, Op0, Op1, Instruction::Mul, TD, DT, MaxRecurse)) return V; + // i1 sub -> xor. + if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (Value *V = SimplifyXorInst(Op0, Op1, TD, DT, MaxRecurse-1)) + return V; + // Threading Sub over selects and phi nodes is pointless, so don't bother. // Threading over the select in "A - select(cond, B, C)" means evaluating // "A-B" and "A-C" and seeing if they are equal; but they are equal if and @@ -633,7 +723,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const TargetData *TD, } // X * undef -> 0 - if (isa(Op1)) + if (match(Op1, m_Undef())) return Constant::getNullValue(Op0->getType()); // X * 0 -> 0 @@ -644,7 +734,16 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const TargetData *TD, if (match(Op1, m_One())) return Op0; - /// i1 mul -> and. + // (X / Y) * Y -> X if the division is exact. + Value *X = 0, *Y = 0; + if ((match(Op0, m_IDiv(m_Value(X), m_Value(Y))) && Y == Op1) || // (X / Y) * Y + (match(Op1, m_IDiv(m_Value(X), m_Value(Y))) && Y == Op0)) { // Y * (X / Y) + BinaryOperator *Div = cast(Y == Op1 ? Op0 : Op1); + if (Div->isExact()) + return X; + } + + // i1 mul -> and. if (MaxRecurse && Op0->getType()->isIntegerTy(1)) if (Value *V = SimplifyAndInst(Op0, Op1, TD, DT, MaxRecurse-1)) return V; @@ -681,6 +780,356 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const TargetData *TD, return ::SimplifyMulInst(Op0, Op1, TD, DT, RecursionLimit); } +/// SimplifyDiv - Given operands for an SDiv or UDiv, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const TargetData *TD, const DominatorTree *DT, + unsigned MaxRecurse) { + if (Constant *C0 = dyn_cast(Op0)) { + if (Constant *C1 = dyn_cast(Op1)) { + Constant *Ops[] = { C0, C1 }; + return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD); + } + } + + bool isSigned = Opcode == Instruction::SDiv; + + // X / undef -> undef + if (match(Op1, m_Undef())) + return Op1; + + // undef / X -> 0 + if (match(Op0, m_Undef())) + return Constant::getNullValue(Op0->getType()); + + // 0 / X -> 0, we don't need to preserve faults! + if (match(Op0, m_Zero())) + return Op0; + + // X / 1 -> X + if (match(Op1, m_One())) + return Op0; + + if (Op0->getType()->isIntegerTy(1)) + // It can't be division by zero, hence it must be division by one. + return Op0; + + // X / X -> 1 + if (Op0 == Op1) + return ConstantInt::get(Op0->getType(), 1); + + // (X * Y) / Y -> X if the multiplication does not overflow. + Value *X = 0, *Y = 0; + if (match(Op0, m_Mul(m_Value(X), m_Value(Y))) && (X == Op1 || Y == Op1)) { + if (Y != Op1) std::swap(X, Y); // Ensure expression is (X * Y) / Y, Y = Op1 + BinaryOperator *Mul = cast(Op0); + // If the Mul knows it does not overflow, then we are good to go. + if ((isSigned && Mul->hasNoSignedWrap()) || + (!isSigned && Mul->hasNoUnsignedWrap())) + return X; + // If X has the form X = A / Y then X * Y cannot overflow. + if (BinaryOperator *Div = dyn_cast(X)) + if (Div->getOpcode() == Opcode && Div->getOperand(1) == Y) + return X; + } + + // (X rem Y) / Y -> 0 + if ((isSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || + (!isSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) + return Constant::getNullValue(Op0->getType()); + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa(Op0) || isa(Op1)) + if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa(Op0) || isa(Op1)) + if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +/// SimplifySDivInst - Given operands for an SDiv, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifySDivInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT, unsigned MaxRecurse) { + if (Value *V = SimplifyDiv(Instruction::SDiv, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT) { + return ::SimplifySDivInst(Op0, Op1, TD, DT, RecursionLimit); +} + +/// SimplifyUDivInst - Given operands for a UDiv, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT, unsigned MaxRecurse) { + if (Value *V = SimplifyDiv(Instruction::UDiv, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT) { + return ::SimplifyUDivInst(Op0, Op1, TD, DT, RecursionLimit); +} + +static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const TargetData *, + const DominatorTree *, unsigned) { + // undef / X -> undef (the undef could be a snan). + if (match(Op0, m_Undef())) + return Op0; + + // X / undef -> undef + if (match(Op1, m_Undef())) + return Op1; + + return 0; +} + +Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT) { + return ::SimplifyFDivInst(Op0, Op1, TD, DT, RecursionLimit); +} + +/// SimplifyRem - Given operands for an SRem or URem, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const TargetData *TD, const DominatorTree *DT, + unsigned MaxRecurse) { + if (Constant *C0 = dyn_cast(Op0)) { + if (Constant *C1 = dyn_cast(Op1)) { + Constant *Ops[] = { C0, C1 }; + return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD); + } + } + + // X % undef -> undef + if (match(Op1, m_Undef())) + return Op1; + + // undef % X -> 0 + if (match(Op0, m_Undef())) + return Constant::getNullValue(Op0->getType()); + + // 0 % X -> 0, we don't need to preserve faults! + if (match(Op0, m_Zero())) + return Op0; + + // X % 0 -> undef, we don't need to preserve faults! + if (match(Op1, m_Zero())) + return UndefValue::get(Op0->getType()); + + // X % 1 -> 0 + if (match(Op1, m_One())) + return Constant::getNullValue(Op0->getType()); + + if (Op0->getType()->isIntegerTy(1)) + // It can't be remainder by zero, hence it must be remainder by one. + return Constant::getNullValue(Op0->getType()); + + // X % X -> 0 + if (Op0 == Op1) + return Constant::getNullValue(Op0->getType()); + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa(Op0) || isa(Op1)) + if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa(Op0) || isa(Op1)) + if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +/// SimplifySRemInst - Given operands for an SRem, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifySRemInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT, unsigned MaxRecurse) { + if (Value *V = SimplifyRem(Instruction::SRem, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT) { + return ::SimplifySRemInst(Op0, Op1, TD, DT, RecursionLimit); +} + +/// SimplifyURemInst - Given operands for a URem, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyURemInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT, unsigned MaxRecurse) { + if (Value *V = SimplifyRem(Instruction::URem, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT) { + return ::SimplifyURemInst(Op0, Op1, TD, DT, RecursionLimit); +} + +static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const TargetData *, + const DominatorTree *, unsigned) { + // undef % X -> undef (the undef could be a snan). + if (match(Op0, m_Undef())) + return Op0; + + // X % undef -> undef + if (match(Op1, m_Undef())) + return Op1; + + return 0; +} + +Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, const TargetData *TD, + const DominatorTree *DT) { + return ::SimplifyFRemInst(Op0, Op1, TD, DT, RecursionLimit); +} + +/// SimplifyShift - Given operands for an Shl, LShr or AShr, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, + const TargetData *TD, const DominatorTree *DT, + unsigned MaxRecurse) { + if (Constant *C0 = dyn_cast(Op0)) { + if (Constant *C1 = dyn_cast(Op1)) { + Constant *Ops[] = { C0, C1 }; + return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD); + } + } + + // 0 shift by X -> 0 + if (match(Op0, m_Zero())) + return Op0; + + // X shift by 0 -> X + if (match(Op1, m_Zero())) + return Op0; + + // X shift by undef -> undef because it may shift by the bitwidth. + if (match(Op1, m_Undef())) + return Op1; + + // Shifting by the bitwidth or more is undefined. + if (ConstantInt *CI = dyn_cast(Op1)) + if (CI->getValue().getLimitedValue() >= + Op0->getType()->getScalarSizeInBits()) + return UndefValue::get(Op0->getType()); + + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa(Op0) || isa(Op1)) + if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa(Op0) || isa(Op1)) + if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + return 0; +} + +/// SimplifyShlInst - Given operands for an Shl, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, + const TargetData *TD, const DominatorTree *DT, + unsigned MaxRecurse) { + if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + // undef << X -> 0 + if (match(Op0, m_Undef())) + return Constant::getNullValue(Op0->getType()); + + // (X >> A) << A -> X + Value *X; + if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1))) && + cast(Op0)->isExact()) + return X; + return 0; +} + +Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, + const TargetData *TD, const DominatorTree *DT) { + return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, TD, DT, RecursionLimit); +} + +/// SimplifyLShrInst - Given operands for an LShr, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, + const TargetData *TD, const DominatorTree *DT, + unsigned MaxRecurse) { + if (Value *V = SimplifyShift(Instruction::LShr, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + // undef >>l X -> 0 + if (match(Op0, m_Undef())) + return Constant::getNullValue(Op0->getType()); + + // (X << A) >> A -> X + Value *X; + if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1))) && + cast(Op0)->hasNoUnsignedWrap()) + return X; + + return 0; +} + +Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, + const TargetData *TD, const DominatorTree *DT) { + return ::SimplifyLShrInst(Op0, Op1, isExact, TD, DT, RecursionLimit); +} + +/// SimplifyAShrInst - Given operands for an AShr, see if we can +/// fold the result. If not, this returns null. +static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, + const TargetData *TD, const DominatorTree *DT, + unsigned MaxRecurse) { + if (Value *V = SimplifyShift(Instruction::AShr, Op0, Op1, TD, DT, MaxRecurse)) + return V; + + // all ones >>a X -> all ones + if (match(Op0, m_AllOnes())) + return Op0; + + // undef >>a X -> all ones + if (match(Op0, m_Undef())) + return Constant::getAllOnesValue(Op0->getType()); + + // (X << A) >> A -> X + Value *X; + if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1))) && + cast(Op0)->hasNoSignedWrap()) + return X; + + return 0; +} + +Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, + const TargetData *TD, const DominatorTree *DT) { + return ::SimplifyAShrInst(Op0, Op1, isExact, TD, DT, RecursionLimit); +} + /// SimplifyAndInst - Given operands for an And, see if we can /// fold the result. If not, this returns null. static Value *SimplifyAndInst(Value *Op0, Value *Op1, const TargetData *TD, @@ -697,7 +1146,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const TargetData *TD, } // X & undef -> 0 - if (isa(Op1)) + if (match(Op1, m_Undef())) return Constant::getNullValue(Op0->getType()); // X & X = X @@ -713,12 +1162,12 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const TargetData *TD, return Op0; // A & ~A = ~A & A = 0 - Value *A = 0, *B = 0; - if ((match(Op0, m_Not(m_Value(A))) && A == Op1) || - (match(Op1, m_Not(m_Value(A))) && A == Op0)) + if (match(Op0, m_Not(m_Specific(Op1))) || + match(Op1, m_Not(m_Specific(Op0)))) return Constant::getNullValue(Op0->getType()); // (A | ?) & A = A + Value *A = 0, *B = 0; if (match(Op0, m_Or(m_Value(A), m_Value(B))) && (A == Op1 || B == Op1)) return Op1; @@ -786,7 +1235,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const TargetData *TD, } // X | undef -> -1 - if (isa(Op1)) + if (match(Op1, m_Undef())) return Constant::getAllOnesValue(Op0->getType()); // X | X = X @@ -802,12 +1251,12 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const TargetData *TD, return Op1; // A | ~A = ~A | A = -1 - Value *A = 0, *B = 0; - if ((match(Op0, m_Not(m_Value(A))) && A == Op1) || - (match(Op1, m_Not(m_Value(A))) && A == Op0)) + if (match(Op0, m_Not(m_Specific(Op1))) || + match(Op1, m_Not(m_Specific(Op0)))) return Constant::getAllOnesValue(Op0->getType()); // (A & ?) | A = A + Value *A = 0, *B = 0; if (match(Op0, m_And(m_Value(A), m_Value(B))) && (A == Op1 || B == Op1)) return Op1; @@ -817,6 +1266,16 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const TargetData *TD, (A == Op0 || B == Op0)) return Op0; + // ~(A & ?) | A = -1 + if (match(Op0, m_Not(m_And(m_Value(A), m_Value(B)))) && + (A == Op1 || B == Op1)) + return Constant::getAllOnesValue(Op1->getType()); + + // A | ~(A & ?) = -1 + if (match(Op1, m_Not(m_And(m_Value(A), m_Value(B)))) && + (A == Op0 || B == Op0)) + return Constant::getAllOnesValue(Op0->getType()); + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, TD, DT, MaxRecurse)) @@ -870,7 +1329,7 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const TargetData *TD, } // A ^ undef -> undef - if (isa(Op1)) + if (match(Op1, m_Undef())) return Op1; // A ^ 0 = A @@ -882,9 +1341,8 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const TargetData *TD, return Constant::getNullValue(Op0->getType()); // A ^ ~A = ~A ^ A = -1 - Value *A = 0; - if ((match(Op0, m_Not(m_Value(A))) && A == Op1) || - (match(Op1, m_Not(m_Value(A))) && A == Op0)) + if (match(Op0, m_Not(m_Specific(Op1))) || + match(Op1, m_Not(m_Specific(Op0)))) return Constant::getAllOnesValue(Op0->getType()); // Try some generic simplifications for associative operations. @@ -918,6 +1376,26 @@ static const Type *GetCompareTy(Value *Op) { return CmpInst::makeCmpResultType(Op->getType()); } +/// ExtractEquivalentCondition - Rummage around inside V looking for something +/// equivalent to the comparison "LHS Pred RHS". Return such a value if found, +/// otherwise return null. Helper function for analyzing max/min idioms. +static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred, + Value *LHS, Value *RHS) { + SelectInst *SI = dyn_cast(V); + if (!SI) + return 0; + CmpInst *Cmp = dyn_cast(SI->getCondition()); + if (!Cmp) + return 0; + Value *CmpLHS = Cmp->getOperand(0), *CmpRHS = Cmp->getOperand(1); + if (Pred == Cmp->getPredicate() && LHS == CmpLHS && RHS == CmpRHS) + return Cmp; + if (Pred == CmpInst::getSwappedPredicate(Cmp->getPredicate()) && + LHS == CmpRHS && RHS == CmpLHS) + return Cmp; + return 0; +} + /// SimplifyICmpInst - Given operands for an ICmpInst, see if we can /// fold the result. If not, this returns null. static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -935,8 +1413,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Pred = CmpInst::getSwappedPredicate(Pred); } - // ITy - This is the return type of the compare we're considering. - const Type *ITy = GetCompareTy(LHS); + const Type *ITy = GetCompareTy(LHS); // The return type. + const Type *OpTy = LHS->getType(); // The operand type. // icmp X, X -> true/false // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false @@ -944,38 +1422,658 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (LHS == RHS || isa(RHS)) return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); - // icmp , - Global/Stack value - // addresses never equal each other! We already know that Op0 != Op1. - if ((isa(LHS) || isa(LHS) || - isa(LHS)) && - (isa(RHS) || isa(RHS) || - isa(RHS))) + // Special case logic when the operands have i1 type. + if (OpTy->isIntegerTy(1) || (OpTy->isVectorTy() && + cast(OpTy)->getElementType()->isIntegerTy(1))) { + switch (Pred) { + default: break; + case ICmpInst::ICMP_EQ: + // X == 1 -> X + if (match(RHS, m_One())) + return LHS; + break; + case ICmpInst::ICMP_NE: + // X != 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_UGT: + // X >u 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_UGE: + // X >=u 1 -> X + if (match(RHS, m_One())) + return LHS; + break; + case ICmpInst::ICMP_SLT: + // X X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_SLE: + // X <=s -1 -> X + if (match(RHS, m_One())) + return LHS; + break; + } + } + + // icmp , - Different stack variables have + // different addresses, and what's more the address of a stack variable is + // never null or equal to the address of a global. Note that generalizing + // to the case where LHS is a global variable address or null is pointless, + // since if both LHS and RHS are constants then we already constant folded + // the compare, and if only one of them is then we moved it to RHS already. + if (isa(LHS) && (isa(RHS) || isa(RHS) || + isa(RHS))) + // We already know that LHS != RHS. return ConstantInt::get(ITy, CmpInst::isFalseWhenEqual(Pred)); - // See if we are doing a comparison with a constant. - if (ConstantInt *CI = dyn_cast(RHS)) { - // If we have an icmp le or icmp ge instruction, turn it into the - // appropriate icmp lt or icmp gt instruction. This allows us to rely on - // them being folded in the code below. + // If we are comparing with zero then try hard since this is a common case. + if (match(RHS, m_Zero())) { + bool LHSKnownNonNegative, LHSKnownNegative; switch (Pred) { - default: break; + default: + assert(false && "Unknown ICmp predicate!"); + case ICmpInst::ICMP_ULT: + // getNullValue also works for vectors, unlike getFalse. + return Constant::getNullValue(ITy); + case ICmpInst::ICMP_UGE: + // getAllOnesValue also works for vectors, unlike getTrue. + return ConstantInt::getAllOnesValue(ITy); + case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE: - if (CI->isMaxValue(false)) // A <=u MAX -> TRUE - return ConstantInt::getTrue(CI->getContext()); + if (isKnownNonZero(LHS, TD)) + return Constant::getNullValue(ITy); + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + if (isKnownNonZero(LHS, TD)) + return ConstantInt::getAllOnesValue(ITy); + break; + case ICmpInst::ICMP_SLT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return ConstantInt::getAllOnesValue(ITy); + if (LHSKnownNonNegative) + return Constant::getNullValue(ITy); break; case ICmpInst::ICMP_SLE: - if (CI->isMaxValue(true)) // A <=s MAX -> TRUE - return ConstantInt::getTrue(CI->getContext()); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return ConstantInt::getAllOnesValue(ITy); + if (LHSKnownNonNegative && isKnownNonZero(LHS, TD)) + return Constant::getNullValue(ITy); + break; + case ICmpInst::ICMP_SGE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return Constant::getNullValue(ITy); + if (LHSKnownNonNegative) + return ConstantInt::getAllOnesValue(ITy); break; + case ICmpInst::ICMP_SGT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD); + if (LHSKnownNegative) + return Constant::getNullValue(ITy); + if (LHSKnownNonNegative && isKnownNonZero(LHS, TD)) + return ConstantInt::getAllOnesValue(ITy); + break; + } + } + + // See if we are doing a comparison with a constant integer. + if (ConstantInt *CI = dyn_cast(RHS)) { + // Rule out tautological comparisons (eg., ult 0 or uge 0). + ConstantRange RHS_CR = ICmpInst::makeConstantRange(Pred, CI->getValue()); + if (RHS_CR.isEmptySet()) + return ConstantInt::getFalse(CI->getContext()); + if (RHS_CR.isFullSet()) + return ConstantInt::getTrue(CI->getContext()); + + // Many binary operators with constant RHS have easy to compute constant + // range. Use them to check whether the comparison is a tautology. + uint32_t Width = CI->getBitWidth(); + APInt Lower = APInt(Width, 0); + APInt Upper = APInt(Width, 0); + ConstantInt *CI2; + if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) { + // 'urem x, CI2' produces [0, CI2). + Upper = CI2->getValue(); + } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) { + // 'srem x, CI2' produces (-|CI2|, |CI2|). + Upper = CI2->getValue().abs(); + Lower = (-Upper) + 1; + } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) { + // 'udiv x, CI2' produces [0, UINT_MAX / CI2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (!CI2->isZero()) + Upper = NegOne.udiv(CI2->getValue()) + 1; + } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { + // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2]. + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + APInt Val = CI2->getValue().abs(); + if (!Val.isMinValue()) { + Lower = IntMin.sdiv(Val); + Upper = IntMax.sdiv(Val) + 1; + } + } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { + // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (CI2->getValue().ult(Width)) + Upper = NegOne.lshr(CI2->getValue()) + 1; + } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) { + // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2]. + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (CI2->getValue().ult(Width)) { + Lower = IntMin.ashr(CI2->getValue()); + Upper = IntMax.ashr(CI2->getValue()) + 1; + } + } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) { + // 'or x, CI2' produces [CI2, UINT_MAX]. + Lower = CI2->getValue(); + } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) { + // 'and x, CI2' produces [0, CI2]. + Upper = CI2->getValue() + 1; + } + if (Lower != Upper) { + ConstantRange LHS_CR = ConstantRange(Lower, Upper); + if (RHS_CR.contains(LHS_CR)) + return ConstantInt::getTrue(RHS->getContext()); + if (RHS_CR.inverse().contains(LHS_CR)) + return ConstantInt::getFalse(RHS->getContext()); + } + } + + // Compare of cast, for example (zext X) != 0 -> X != 0 + if (isa(LHS) && (isa(RHS) || isa(RHS))) { + Instruction *LI = cast(LHS); + Value *SrcOp = LI->getOperand(0); + const Type *SrcTy = SrcOp->getType(); + const Type *DstTy = LI->getType(); + + // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input + // if the integer type is the same size as the pointer type. + if (MaxRecurse && TD && isa(LI) && + TD->getPointerSizeInBits() == DstTy->getPrimitiveSizeInBits()) { + if (Constant *RHSC = dyn_cast(RHS)) { + // Transfer the cast to the constant. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, + ConstantExpr::getIntToPtr(RHSC, SrcTy), + TD, DT, MaxRecurse-1)) + return V; + } else if (PtrToIntInst *RI = dyn_cast(RHS)) { + if (RI->getOperand(0)->getType() == SrcTy) + // Compare without the cast. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + TD, DT, MaxRecurse-1)) + return V; + } + } + + if (isa(LHS)) { + // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the + // same type. + if (ZExtInst *RI = dyn_cast(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that signed predicates become unsigned. + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, RI->getOperand(0), TD, DT, + MaxRecurse-1)) + return V; + } + // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. + Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); + + // If the re-extended constant didn't change then this is effectively + // also a case of comparing two zero-extended values. + if (RExt == CI && MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, Trunc, TD, DT, MaxRecurse-1)) + return V; + + // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit + // there. Use this to work out the result of the comparison. + if (RExt != CI) { + switch (Pred) { + default: + assert(false && "Unknown ICmp predicate!"); + // LHS getContext()); + + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + return ConstantInt::getTrue(CI->getContext()); + + // LHS is non-negative. If RHS is negative then LHS >s LHS. If RHS + // is non-negative then LHS getValue().isNegative() ? + ConstantInt::getTrue(CI->getContext()) : + ConstantInt::getFalse(CI->getContext()); + + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return CI->getValue().isNegative() ? + ConstantInt::getFalse(CI->getContext()) : + ConstantInt::getTrue(CI->getContext()); + } + } + } + } + + if (isa(LHS)) { + // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the + // same type. + if (SExtInst *RI = dyn_cast(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that the predicate does not change. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + TD, DT, MaxRecurse-1)) + return V; + } + // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. + Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); + + // If the re-extended constant didn't change then this is effectively + // also a case of comparing two sign-extended values. + if (RExt == CI && MaxRecurse) + if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, TD, DT, + MaxRecurse-1)) + return V; + + // Otherwise the upper bits of LHS are all equal, while RHS has varying + // bits there. Use this to work out the result of the comparison. + if (RExt != CI) { + switch (Pred) { + default: + assert(false && "Unknown ICmp predicate!"); + case ICmpInst::ICMP_EQ: + return ConstantInt::getFalse(CI->getContext()); + case ICmpInst::ICMP_NE: + return ConstantInt::getTrue(CI->getContext()); + + // If RHS is non-negative then LHS s RHS. + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return CI->getValue().isNegative() ? + ConstantInt::getTrue(CI->getContext()) : + ConstantInt::getFalse(CI->getContext()); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return CI->getValue().isNegative() ? + ConstantInt::getFalse(CI->getContext()) : + ConstantInt::getTrue(CI->getContext()); + + // If LHS is non-negative then LHS u RHS. + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + // Comparison is true iff the LHS =s 0. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, + Constant::getNullValue(SrcTy), + TD, DT, MaxRecurse-1)) + return V; + break; + } + } + } + } + } + + // Special logic for binary operators. + BinaryOperator *LBO = dyn_cast(LHS); + BinaryOperator *RBO = dyn_cast(RHS); + if (MaxRecurse && (LBO || RBO)) { + // Analyze the case when either LHS or RHS is an add instruction. + Value *A = 0, *B = 0, *C = 0, *D = 0; + // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null). + bool NoLHSWrapProblem = false, NoRHSWrapProblem = false; + if (LBO && LBO->getOpcode() == Instruction::Add) { + A = LBO->getOperand(0); B = LBO->getOperand(1); + NoLHSWrapProblem = ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); + } + if (RBO && RBO->getOpcode() == Instruction::Add) { + C = RBO->getOperand(0); D = RBO->getOperand(1); + NoRHSWrapProblem = ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); + } + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == RHS || B == RHS) && NoLHSWrapProblem) + if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A, + Constant::getNullValue(RHS->getType()), + TD, DT, MaxRecurse-1)) + return V; + + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == LHS || D == LHS) && NoRHSWrapProblem) + if (Value *V = SimplifyICmpInst(Pred, + Constant::getNullValue(LHS->getType()), + C == LHS ? D : C, TD, DT, MaxRecurse-1)) + return V; + + // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && + NoLHSWrapProblem && NoRHSWrapProblem) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). + Value *Y = (A == C || A == D) ? B : A; + Value *Z = (C == A || C == B) ? D : C; + if (Value *V = SimplifyICmpInst(Pred, Y, Z, TD, DT, MaxRecurse-1)) + return V; + } + } + + if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { + bool KnownNonNegative, KnownNegative; + switch (Pred) { + default: + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, TD); + if (!KnownNonNegative) + break; + // fall-through + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (CI->isMinValue(false)) // A >=u MIN -> TRUE - return ConstantInt::getTrue(CI->getContext()); + // getNullValue also works for vectors, unlike getFalse. + return Constant::getNullValue(ITy); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, TD); + if (!KnownNonNegative) + break; + // fall-through + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + // getAllOnesValue also works for vectors, unlike getTrue. + return Constant::getAllOnesValue(ITy); + } + } + if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { + bool KnownNonNegative, KnownNegative; + switch (Pred) { + default: break; + case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (CI->isMinValue(true)) // A >=s MIN -> TRUE - return ConstantInt::getTrue(CI->getContext()); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, TD); + if (!KnownNonNegative) + break; + // fall-through + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + // getAllOnesValue also works for vectors, unlike getTrue. + return Constant::getAllOnesValue(ITy); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, TD); + if (!KnownNonNegative) + break; + // fall-through + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + // getNullValue also works for vectors, unlike getFalse. + return Constant::getNullValue(ITy); + } + } + + if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && + LBO->getOperand(1) == RBO->getOperand(1)) { + switch (LBO->getOpcode()) { + default: break; + case Instruction::UDiv: + case Instruction::LShr: + if (ICmpInst::isSigned(Pred)) + break; + // fall-through + case Instruction::SDiv: + case Instruction::AShr: + if (!LBO->isExact() || !RBO->isExact()) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), TD, DT, MaxRecurse-1)) + return V; + break; + case Instruction::Shl: { + bool NUW = LBO->hasNoUnsignedWrap() && LBO->hasNoUnsignedWrap(); + bool NSW = LBO->hasNoSignedWrap() && RBO->hasNoSignedWrap(); + if (!NUW && !NSW) + break; + if (!NSW && ICmpInst::isSigned(Pred)) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), TD, DT, MaxRecurse-1)) + return V; + break; + } + } + } + + // Simplify comparisons involving max/min. + Value *A, *B; + CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE; + CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B". + + // Signed variants on "max(a,b)>=a -> true". + if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { + if (A != RHS) std::swap(A, B); // smax(A, B) pred A. + EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". + // We analyze this as smax(A, B) pred A. + P = Pred; + } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) std::swap(A, B); // A pred smax(A, B). + EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". + // We analyze this as smax(A, B) swapped-pred A. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && + (A == RHS || B == RHS)) { + if (A != RHS) std::swap(A, B); // smin(A, B) pred A. + EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". + // We analyze this as smax(-A, -B) swapped-pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) std::swap(A, B); // A pred smin(A, B). + EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". + // We analyze this as smax(-A, -B) pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = Pred; + } + if (P != CmpInst::BAD_ICMP_PREDICATE) { + // Cases correspond to "max(A, B) p A". + switch (P) { + default: + break; + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_SLE: + // Equivalent to "A EqP B". This may be the same as the condition tested + // in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) + return V; + // Otherwise, see if "A EqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(EqP, A, B, TD, DT, MaxRecurse-1)) + return V; + break; + case CmpInst::ICMP_NE: + case CmpInst::ICMP_SGT: { + CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); + // Equivalent to "A InvEqP B". This may be the same as the condition + // tested in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) + return V; + // Otherwise, see if "A InvEqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, TD, DT, MaxRecurse-1)) + return V; + break; + } + case CmpInst::ICMP_SGE: + // Always true. + return Constant::getAllOnesValue(ITy); + case CmpInst::ICMP_SLT: + // Always false. + return Constant::getNullValue(ITy); + } + } + + // Unsigned variants on "max(a,b)>=a -> true". + P = CmpInst::BAD_ICMP_PREDICATE; + if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { + if (A != RHS) std::swap(A, B); // umax(A, B) pred A. + EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". + // We analyze this as umax(A, B) pred A. + P = Pred; + } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) std::swap(A, B); // A pred umax(A, B). + EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". + // We analyze this as umax(A, B) swapped-pred A. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && + (A == RHS || B == RHS)) { + if (A != RHS) std::swap(A, B); // umin(A, B) pred A. + EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". + // We analyze this as umax(-A, -B) swapped-pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = CmpInst::getSwappedPredicate(Pred); + } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) && + (A == LHS || B == LHS)) { + if (A != LHS) std::swap(A, B); // A pred umin(A, B). + EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". + // We analyze this as umax(-A, -B) pred -A. + // Note that we do not need to actually form -A or -B thanks to EqP. + P = Pred; + } + if (P != CmpInst::BAD_ICMP_PREDICATE) { + // Cases correspond to "max(A, B) p A". + switch (P) { + default: + break; + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_ULE: + // Equivalent to "A EqP B". This may be the same as the condition tested + // in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) + return V; + // Otherwise, see if "A EqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(EqP, A, B, TD, DT, MaxRecurse-1)) + return V; + break; + case CmpInst::ICMP_NE: + case CmpInst::ICMP_UGT: { + CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); + // Equivalent to "A InvEqP B". This may be the same as the condition + // tested in the max/min; if so, we can just return that. + if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) + return V; + if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) + return V; + // Otherwise, see if "A InvEqP B" simplifies. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, TD, DT, MaxRecurse-1)) + return V; break; } + case CmpInst::ICMP_UGE: + // Always true. + return Constant::getAllOnesValue(ITy); + case CmpInst::ICMP_ULT: + // Always false. + return Constant::getNullValue(ITy); + } + } + + // Variants on "max(x,y) >= min(x,z)". + Value *C, *D; + if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && + match(RHS, m_SMin(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // max(x, ?) pred min(x, ?). + if (Pred == CmpInst::ICMP_SGE) + // Always true. + return Constant::getAllOnesValue(ITy); + if (Pred == CmpInst::ICMP_SLT) + // Always false. + return Constant::getNullValue(ITy); + } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && + match(RHS, m_SMax(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // min(x, ?) pred max(x, ?). + if (Pred == CmpInst::ICMP_SLE) + // Always true. + return Constant::getAllOnesValue(ITy); + if (Pred == CmpInst::ICMP_SGT) + // Always false. + return Constant::getNullValue(ITy); + } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && + match(RHS, m_UMin(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // max(x, ?) pred min(x, ?). + if (Pred == CmpInst::ICMP_UGE) + // Always true. + return Constant::getAllOnesValue(ITy); + if (Pred == CmpInst::ICMP_ULT) + // Always false. + return Constant::getNullValue(ITy); + } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && + match(RHS, m_UMax(m_Value(C), m_Value(D))) && + (A == C || A == D || B == C || B == D)) { + // min(x, ?) pred max(x, ?). + if (Pred == CmpInst::ICMP_ULE) + // Always true. + return Constant::getAllOnesValue(ITy); + if (Pred == CmpInst::ICMP_UGT) + // Always false. + return Constant::getNullValue(ITy); } // If the comparison is with the result of a select instruction, check whether @@ -1203,15 +2301,28 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const TargetData *TD, const DominatorTree *DT, unsigned MaxRecurse) { switch (Opcode) { - case Instruction::Add: return SimplifyAddInst(LHS, RHS, /* isNSW */ false, - /* isNUW */ false, TD, DT, - MaxRecurse); - case Instruction::Sub: return SimplifySubInst(LHS, RHS, /* isNSW */ false, - /* isNUW */ false, TD, DT, - MaxRecurse); - case Instruction::Mul: return SimplifyMulInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::Add: + return SimplifyAddInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, + TD, DT, MaxRecurse); + case Instruction::Sub: + return SimplifySubInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, + TD, DT, MaxRecurse); + case Instruction::Mul: return SimplifyMulInst (LHS, RHS, TD, DT, MaxRecurse); + case Instruction::SDiv: return SimplifySDivInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::UDiv: return SimplifyUDivInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::FDiv: return SimplifyFDivInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::SRem: return SimplifySRemInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::URem: return SimplifyURemInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::FRem: return SimplifyFRemInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::Shl: + return SimplifyShlInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, + TD, DT, MaxRecurse); + case Instruction::LShr: + return SimplifyLShrInst(LHS, RHS, /*isExact*/false, TD, DT, MaxRecurse); + case Instruction::AShr: + return SimplifyAShrInst(LHS, RHS, /*isExact*/false, TD, DT, MaxRecurse); case Instruction::And: return SimplifyAndInst(LHS, RHS, TD, DT, MaxRecurse); - case Instruction::Or: return SimplifyOrInst(LHS, RHS, TD, DT, MaxRecurse); + case Instruction::Or: return SimplifyOrInst (LHS, RHS, TD, DT, MaxRecurse); case Instruction::Xor: return SimplifyXorInst(LHS, RHS, TD, DT, MaxRecurse); default: if (Constant *CLHS = dyn_cast(LHS)) @@ -1288,6 +2399,40 @@ Value *llvm::SimplifyInstruction(Instruction *I, const TargetData *TD, case Instruction::Mul: Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), TD, DT); break; + case Instruction::SDiv: + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), TD, DT); + break; + case Instruction::UDiv: + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), TD, DT); + break; + case Instruction::FDiv: + Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), TD, DT); + break; + case Instruction::SRem: + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), TD, DT); + break; + case Instruction::URem: + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), TD, DT); + break; + case Instruction::FRem: + Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), TD, DT); + break; + case Instruction::Shl: + Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), + cast(I)->hasNoSignedWrap(), + cast(I)->hasNoUnsignedWrap(), + TD, DT); + break; + case Instruction::LShr: + Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), + cast(I)->isExact(), + TD, DT); + break; + case Instruction::AShr: + Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), + cast(I)->isExact(), + TD, DT); + break; case Instruction::And: Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), TD, DT); break;