From 1975d03183966698650042e7a2bbd7198d276cfb Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Thu, 30 Oct 2008 20:40:10 +0000 Subject: [PATCH] Canonicalize sext(i1) to i1?-1:0, and update various instcombine optimizations accordingly. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@58457 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Support/PatternMatch.h | 46 ++++++ .../Scalar/InstructionCombining.cpp | 131 ++++++++++++------ test/Transforms/InstCombine/logical-select.ll | 26 +++- 3 files changed, 162 insertions(+), 41 deletions(-) diff --git a/include/llvm/Support/PatternMatch.h b/include/llvm/Support/PatternMatch.h index a3951e2dd39..2408103cb94 100644 --- a/include/llvm/Support/PatternMatch.h +++ b/include/llvm/Support/PatternMatch.h @@ -51,6 +51,22 @@ inline leaf_ty m_Value() { return leaf_ty(); } /// m_ConstantInt() - Match an arbitrary ConstantInt and ignore it. inline leaf_ty m_ConstantInt() { return leaf_ty(); } +struct constantint_ty { + int64_t Val; + explicit constantint_ty(int64_t val) : Val(val) {} + + template + bool match(ITy *V) { + return isa(V) && cast(V)->getSExtValue() == Val; + } +}; + +/// m_ConstantInt(int64_t) - Match a ConstantInt with a specific value +/// and ignore it. +inline constantint_ty m_ConstantInt(int64_t Val) { + return constantint_ty(Val); +} + struct zero_ty { template bool match(ITy *V) { @@ -321,6 +337,36 @@ m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) { FCmpInst, FCmpInst::Predicate>(Pred, L, R); } +//===----------------------------------------------------------------------===// +// Matchers for SelectInst classes +// + +template +struct SelectClass_match { + Cond_t C; + LHS_t L; + RHS_t R; + + SelectClass_match(const Cond_t &Cond, const LHS_t &LHS, + const RHS_t &RHS) + : C(Cond), L(LHS), R(RHS) {} + + template + bool match(OpTy *V) { + if (SelectInst *I = dyn_cast(V)) + return C.match(I->getOperand(0)) && + L.match(I->getOperand(1)) && + R.match(I->getOperand(2)); + return false; + } +}; + +template +inline SelectClass_match +m_Select(const Cond &C, const LHS &L, const RHS &R) { + return SelectClass_match(C, L, R); +} + //===----------------------------------------------------------------------===// // Matchers for CastInst classes // diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 4ec36ad1514..70b5aefa23c 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -2012,6 +2012,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { KnownZero, KnownOne)) return &I; } + + // zext(i1) - 1 -> select i1, 0, -1 + if (ZExtInst *ZI = dyn_cast(LHS)) + if (CI->isAllOnesValue() && + ZI->getOperand(0)->getType() == Type::Int1Ty) + return SelectInst::Create(ZI->getOperand(0), + Constant::getNullValue(I.getType()), + ConstantInt::getAllOnesValue(I.getType())); } if (isa(LHS)) @@ -4338,24 +4346,55 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // (A & sext(C0)) | (B & ~sext(C0) -> C0 ? A : B - if (isa(C) && - cast(C)->getOperand(0)->getType() == Type::Int1Ty) { + // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) -> C0 ? A : B, and commuted variants + if (match(A, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { + if (match(D, m_Not(m_Value(A)))) + return SelectInst::Create(cast(A)->getOperand(0), C, B); + if (match(B, m_Not(m_Value(A)))) + return SelectInst::Create(cast(A)->getOperand(0), C, D); + } + if (match(B, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { + if (match(C, m_Not(m_Value(B)))) + return SelectInst::Create(cast(B)->getOperand(0), A, D); + if (match(A, m_Not(m_Value(B)))) + return SelectInst::Create(cast(B)->getOperand(0), C, D); + } + if (match(C, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { if (match(D, m_Not(m_Value(C)))) return SelectInst::Create(cast(C)->getOperand(0), A, B); - // And commutes, try both ways. if (match(B, m_Not(m_Value(C)))) return SelectInst::Create(cast(C)->getOperand(0), A, D); } - // Or commutes, try both ways. - if (isa(D) && - cast(D)->getOperand(0)->getType() == Type::Int1Ty) { + if (match(D, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { if (match(C, m_Not(m_Value(D)))) return SelectInst::Create(cast(D)->getOperand(0), A, B); - // And commutes, try both ways. if (match(A, m_Not(m_Value(D)))) return SelectInst::Create(cast(D)->getOperand(0), C, B); } + if (match(A, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(D, m_Not(m_Value(A)))) + return SelectInst::Create(cast(A)->getOperand(0), B, C); + if (match(B, m_Not(m_Value(A)))) + return SelectInst::Create(cast(A)->getOperand(0), D, C); + } + if (match(B, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(C, m_Not(m_Value(B)))) + return SelectInst::Create(cast(B)->getOperand(0), D, A); + if (match(A, m_Not(m_Value(B)))) + return SelectInst::Create(cast(B)->getOperand(0), D, C); + } + if (match(C, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(D, m_Not(m_Value(C)))) + return SelectInst::Create(cast(C)->getOperand(0), B, A); + if (match(B, m_Not(m_Value(C)))) + return SelectInst::Create(cast(C)->getOperand(0), D, A); + } + if (match(D, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(C, m_Not(m_Value(D)))) + return SelectInst::Create(cast(D)->getOperand(0), B, A); + if (match(A, m_Not(m_Value(D)))) + return SelectInst::Create(cast(D)->getOperand(0), B, C); + } } // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. @@ -7965,37 +8004,11 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { Value *Src = CI.getOperand(0); - // sext (x ashr x, 31 -> all ones if signed - // sext (x >s -1) -> ashr x, 31 -> all ones if not signed - if (ICmpInst *ICI = dyn_cast(Src)) { - // If we are just checking for a icmp eq of a single bit and zext'ing it - // to an integer, then shift the bit to the appropriate place and then - // cast to integer to avoid the comparison. - if (ConstantInt *Op1C = dyn_cast(ICI->getOperand(1))) { - const APInt &Op1CV = Op1C->getValue(); - - // sext (x x>>s31 true if signbit set. - // sext (x >s -1) to i32 --> (x>>s31)^-1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())){ - Value *In = ICI->getOperand(0); - Value *Sh = ConstantInt::get(In->getType(), - In->getType()->getPrimitiveSizeInBits()-1); - In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh, - In->getName()+".lobit"), - CI); - if (In->getType() != CI.getType()) - In = CastInst::CreateIntegerCast(In, CI.getType(), - true/*SExt*/, "tmp", &CI); - - if (ICI->getPredicate() == ICmpInst::ICMP_SGT) - In = InsertNewInstBefore(BinaryOperator::CreateNot(In, - In->getName()+".not"), CI); - - return ReplaceInstUsesWith(CI, In); - } - } - } + // Canonicalize sign-extend from i1 to a select. + if (Src->getType() == Type::Int1Ty) + return SelectInst::Create(Src, + ConstantInt::getAllOnesValue(CI.getType()), + Constant::getNullValue(CI.getType())); // See if the value being truncated is already sign extended. If so, just // eliminate the trunc/sext pair. @@ -8468,7 +8481,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // can be adjusted to fit the min/max idiom. We may edit ICI in // place here, so make sure the select is the only user. if (ICI->hasOneUse()) - if (ConstantInt *CI = dyn_cast(CmpRHS)) + if (ConstantInt *CI = dyn_cast(CmpRHS)) { switch (Pred) { default: break; case ICmpInst::ICMP_ULT: @@ -8513,6 +8526,44 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } + // (x ashr x, 31 -> all ones if signed + // (x >s -1) ? -1 : 0 -> ashr x, 31 -> all ones if not signed + CmpInst::Predicate Pred = ICI->getPredicate(); + if (match(TrueVal, m_ConstantInt(0)) && + match(FalseVal, m_ConstantInt(-1))) + Pred = CmpInst::getInversePredicate(Pred); + else if (!match(TrueVal, m_ConstantInt(-1)) || + !match(FalseVal, m_ConstantInt(0))) + Pred = CmpInst::BAD_ICMP_PREDICATE; + if (Pred != CmpInst::BAD_ICMP_PREDICATE) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + const APInt &Op1CV = CI->getValue(); + + // sext (x x>>s31 true if signbit set. + // sext (x >s -1) to i32 --> (x>>s31)^-1 true if signbit clear. + if ((Pred == ICmpInst::ICMP_SLT && Op1CV == 0) || + (Pred == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())) { + Value *In = ICI->getOperand(0); + Value *Sh = ConstantInt::get(In->getType(), + In->getType()->getPrimitiveSizeInBits()-1); + In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh, + In->getName()+".lobit"), + *ICI); + if (In->getType() != CI->getType()) + In = CastInst::CreateIntegerCast(In, CI->getType(), + true/*SExt*/, "tmp", ICI); + + if (Pred == ICmpInst::ICMP_SGT) + In = InsertNewInstBefore(BinaryOperator::CreateNot(In, + In->getName()+".not"), *ICI); + + return ReplaceInstUsesWith(SI, In); + } + } + } + if (CmpLHS == TrueVal && CmpRHS == FalseVal) { // Transform (X == Y) ? X : Y -> Y if (Pred == ICmpInst::ICMP_EQ) diff --git a/test/Transforms/InstCombine/logical-select.ll b/test/Transforms/InstCombine/logical-select.ll index 6369badee6c..39702d390ae 100644 --- a/test/Transforms/InstCombine/logical-select.ll +++ b/test/Transforms/InstCombine/logical-select.ll @@ -1,4 +1,7 @@ -; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep select | count 2 +; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t +; RUN grep select %t | count 4 +; RUN not grep and %t +; RUN not grep or %t define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { %e = icmp slt i32 %a, %b @@ -18,3 +21,24 @@ define i32 @bar(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { %j = or i32 %i, %g ret i32 %j } +define i32 @goo(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { +entry: + %0 = icmp slt i32 %a, %b + %iftmp.0.0 = select i1 %0, i32 -1, i32 0 + %1 = and i32 %iftmp.0.0, %c + %not = xor i32 %iftmp.0.0, -1 + %2 = and i32 %not, %d + %3 = or i32 %1, %2 + ret i32 %3 +} + +define i32 @par(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { +entry: + %0 = icmp slt i32 %a, %b + %iftmp.1.0 = select i1 %0, i32 -1, i32 0 + %1 = and i32 %iftmp.1.0, %c + %not = xor i32 %iftmp.1.0, -1 + %2 = and i32 %not, %d + %3 = or i32 %1, %2 + ret i32 %3 +} -- 2.34.1