X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineSelect.cpp;h=69380fc41d3a1b83a95f69c12a7d996dea6a7019;hb=dd27f99713e60722ebe8c66216b5917b237b3303;hp=83a2f2c563f33249a7fd979853379a339a22c2ba;hpb=dcf39d2586da38fe4410d8406a0abe4e3c7860c1;p=oota-llvm.git diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 83a2f2c563f..69380fc41d3 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/PatternMatch.h" @@ -314,8 +314,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const DataLayout *TD, const TargetLibraryInfo *TLI, - DominatorTree *DT, - AssumptionTracker *AT) { + DominatorTree *DT, AssumptionCache *AC) { // Trivial replacement. if (V == Op) return RepOp; @@ -336,10 +335,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (CmpInst *C = dyn_cast(I)) { if (C->getOperand(0) == Op) return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD, - TLI, DT, AT); + TLI, DT, AC); if (C->getOperand(1) == Op) return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD, - TLI, DT, AT); + TLI, DT, AC); } // TODO: We could hand off more cases to instsimplify here. @@ -438,6 +437,66 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, return Builder->CreateOr(V, Y); } +/// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single +/// call to cttz/ctlz with flag 'is_zero_undef' cleared. +/// +/// For example, we can fold the following code sequence: +/// \code +/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 true) +/// %1 = icmp ne i32 %x, 0 +/// %2 = select i1 %1, i32 %0, i32 32 +/// \code +/// +/// into: +/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) +static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, + InstCombiner::BuilderTy *Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + + // Check if the condition value compares a value for equality against zero. + if (!ICI->isEquality() || !match(CmpRHS, m_Zero())) + return nullptr; + + Value *Count = FalseVal; + Value *ValueOnZero = TrueVal; + if (Pred == ICmpInst::ICMP_NE) + std::swap(Count, ValueOnZero); + + // Skip zero extend/truncate. + Value *V = nullptr; + if (match(Count, m_ZExt(m_Value(V))) || + match(Count, m_Trunc(m_Value(V)))) + Count = V; + + // Check if the value propagated on zero is a constant number equal to the + // sizeof in bits of 'Count'. + unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); + if (!match(ValueOnZero, m_SpecificInt(SizeOfInBits))) + return nullptr; + + // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the + // input to the cttz/ctlz is used as LHS for the compare instruction. + if (match(Count, m_Intrinsic(m_Specific(CmpLHS))) || + match(Count, m_Intrinsic(m_Specific(CmpLHS)))) { + IntrinsicInst *II = cast(Count); + IRBuilder<> Builder(II); + if (cast(II->getArgOperand(1))->isOne()) { + // Explicitly clear the 'undef_on_zero' flag. + IntrinsicInst *NewI = cast(II->clone()); + Type *Ty = NewI->getArgOperand(1)->getType(); + NewI->setArgOperand(1, Constant::getNullValue(Ty)); + Builder.Insert(NewI); + Count = NewI; + } + + return Builder.CreateZExtOrTrunc(Count, ValueOnZero->getType()); + } + + return nullptr; +} + /// visitSelectInstWithICmp - Visit a SelectInst that has an /// ICmpInst as its first operand. /// @@ -580,26 +639,26 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, - DT, AT) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, - DT, AT) == TrueVal) + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + TrueVal) return ReplaceInstUsesWith(SI, FalseVal); - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, - DT, AT) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, - DT, AT) == FalseVal) + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + FalseVal) return ReplaceInstUsesWith(SI, FalseVal); } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, - DT, AT) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, - DT, AT) == FalseVal) + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + FalseVal) return ReplaceInstUsesWith(SI, TrueVal); - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, - DT, AT) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, - DT, AT) == TrueVal) + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + TrueVal) return ReplaceInstUsesWith(SI, TrueVal); } @@ -617,26 +676,44 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } - { + if (unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits()) { + APInt MinSignedValue = APInt::getSignBit(BitWidth); Value *X; const APInt *Y, *C; - if (match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) && + bool TrueWhenUnset; + bool IsBitTest = false; + if (ICmpInst::isEquality(Pred) && + match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) && match(CmpRHS, m_Zero())) { + IsBitTest = true; + TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; + } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { + X = CmpLHS; + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = false; + } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { + X = CmpLHS; + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = true; + } + if (IsBitTest) { Value *V = nullptr; // (X & Y) == 0 ? X : X ^ Y --> X & ~Y - if (Pred == ICmpInst::ICMP_EQ && TrueVal == X && + if (TrueWhenUnset && TrueVal == X && match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) V = Builder->CreateAnd(X, ~(*Y)); // (X & Y) != 0 ? X ^ Y : X --> X & ~Y - else if (Pred == ICmpInst::ICMP_NE && FalseVal == X && + else if (!TrueWhenUnset && FalseVal == X && match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) V = Builder->CreateAnd(X, ~(*Y)); // (X & Y) == 0 ? X ^ Y : X --> X | Y - else if (Pred == ICmpInst::ICMP_EQ && FalseVal == X && + else if (TrueWhenUnset && FalseVal == X && match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) V = Builder->CreateOr(X, *Y); // (X & Y) != 0 ? X : X ^ Y --> X | Y - else if (Pred == ICmpInst::ICMP_NE && TrueVal == X && + else if (!TrueWhenUnset && TrueVal == X && match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) V = Builder->CreateOr(X, *Y); @@ -648,6 +725,9 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) return ReplaceInstUsesWith(SI, V); + if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) + return ReplaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } @@ -836,8 +916,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); - if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, - DT, AT)) + if (Value *V = + SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) return ReplaceInstUsesWith(SI, V); if (SI.getType()->isIntegerTy(1)) {