//===----------------------------------------------------------------------===//
#include "InstCombine.h"
-#include "llvm/Support/PatternMatch.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Support/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;
// If this is a non-volatile load or a cast from the same type,
// merge.
if (TI->isCast()) {
- if (TI->getOperand(0)->getType() != FI->getOperand(0)->getType())
+ Type *FIOpndTy = FI->getOperand(0)->getType();
+ if (TI->getOperand(0)->getType() != FIOpndTy)
+ return 0;
+ // The select condition may be a vector. We may only change the operand
+ // type if the vector width remains the same (and matches the condition).
+ Type *CondTy = SI.getCondition()->getType();
+ if (CondTy->isVectorTy() && (!FIOpndTy->isVectorTy() ||
+ CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements()))
return 0;
} else {
return 0; // unknown unary op.
return BinaryOperator::Create(BO->getOpcode(), NewSI, MatchOp);
}
llvm_unreachable("Shouldn't get here");
- return 0;
}
static bool isSelect01(Constant *C1, Constant *C2) {
/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
/// replaced with RepOp.
static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
- const TargetData *TD) {
+ const DataLayout *TD,
+ const TargetLibraryInfo *TLI) {
// Trivial replacement.
if (V == Op)
return RepOp;
// If this is a binary operator, try to simplify it with the replaced op.
if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
if (B->getOperand(0) == Op)
- return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD);
+ return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI);
if (B->getOperand(1) == Op)
- return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD);
+ return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI);
}
// Same for CmpInsts.
if (CmpInst *C = dyn_cast<CmpInst>(I)) {
if (C->getOperand(0) == Op)
- return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD);
+ return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD,
+ TLI);
if (C->getOperand(1) == Op)
- return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD);
+ return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD,
+ TLI);
}
// TODO: We could hand off more cases to instsimplify here.
// All operands were constants, fold it.
if (ConstOps.size() == I->getNumOperands()) {
+ if (CmpInst *C = dyn_cast<CmpInst>(I))
+ return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
+ ConstOps[1], TD, TLI);
+
if (LoadInst *LI = dyn_cast<LoadInst>(I))
if (!LI->isVolatile())
return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
- ConstOps, TD);
+ ConstOps, TD, TLI);
}
}
return 0;
}
+/// foldSelectICmpAndOr - We want to turn:
+/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2))
+/// into:
+/// (or (shl (and X, C1), C3), y)
+/// iff:
+/// C1 and C2 are both powers of 2
+/// where:
+/// C3 = Log(C2) - Log(C1)
+///
+/// This transform handles cases where:
+/// 1. The icmp predicate is inverted
+/// 2. The select operands are reversed
+/// 3. The magnitude of C2 and C1 are flipped
+static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
+ Value *FalseVal,
+ InstCombiner::BuilderTy *Builder) {
+ const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition());
+ if (!IC || !IC->isEquality())
+ return 0;
+
+ Value *CmpLHS = IC->getOperand(0);
+ Value *CmpRHS = IC->getOperand(1);
+
+ if (!match(CmpRHS, m_Zero()))
+ return 0;
+
+ Value *X;
+ const APInt *C1;
+ if (!match(CmpLHS, m_And(m_Value(X), m_Power2(C1))))
+ return 0;
+
+ const APInt *C2;
+ bool OrOnTrueVal = false;
+ bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)));
+ if (!OrOnFalseVal)
+ OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)));
+
+ if (!OrOnFalseVal && !OrOnTrueVal)
+ return 0;
+
+ Value *V = CmpLHS;
+ Value *Y = OrOnFalseVal ? TrueVal : FalseVal;
+
+ unsigned C1Log = C1->logBase2();
+ unsigned C2Log = C2->logBase2();
+ if (C2Log > C1Log) {
+ V = Builder->CreateZExtOrTrunc(V, Y->getType());
+ V = Builder->CreateShl(V, C2Log - C1Log);
+ } else if (C1Log > C2Log) {
+ V = Builder->CreateLShr(V, C1Log - C2Log);
+ V = Builder->CreateZExtOrTrunc(V, Y->getType());
+ } else
+ V = Builder->CreateZExtOrTrunc(V, Y->getType());
+
+ ICmpInst::Predicate Pred = IC->getPredicate();
+ if ((Pred == ICmpInst::ICMP_NE && OrOnFalseVal) ||
+ (Pred == ICmpInst::ICMP_EQ && OrOnTrueVal))
+ V = Builder->CreateXor(V, *C2);
+
+ return Builder->CreateOr(V, Y);
+}
+
/// visitSelectInstWithICmp - Visit a SelectInst that has an
/// ICmpInst as its first operand.
///
// arms of the select. See if substituting this value into the arm and
// simplifying the result yields the same value as the other arm.
if (Pred == ICmpInst::ICMP_EQ) {
- if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD) == TrueVal ||
- SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD) == TrueVal)
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD, TLI) == TrueVal ||
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD, TLI) == TrueVal)
+ return ReplaceInstUsesWith(SI, FalseVal);
+ if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD, TLI) == FalseVal ||
+ SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD, TLI) == FalseVal)
return ReplaceInstUsesWith(SI, FalseVal);
} else if (Pred == ICmpInst::ICMP_NE) {
- if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD) == FalseVal ||
- SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD) == FalseVal)
+ if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD, TLI) == FalseVal ||
+ SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD, TLI) == FalseVal)
+ return ReplaceInstUsesWith(SI, TrueVal);
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD, TLI) == TrueVal ||
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD, TLI) == TrueVal)
return ReplaceInstUsesWith(SI, TrueVal);
}
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
- if (isa<Constant>(CmpRHS)) {
+ if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
// Transform (X == C) ? X : Y -> (X == C) ? C : Y
SI.setOperand(1, CmpRHS);
}
}
+ if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder))
+ return ReplaceInstUsesWith(SI, V);
+
return Changed ? &SI : 0;
}
ConstantInt *FalseVal,
InstCombiner::BuilderTy *Builder) {
const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition());
- if (!IC || !IC->isEquality())
+ if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy())
return 0;
if (!match(IC->getOperand(1), m_Zero()))
ConstantInt *AndRHS;
Value *LHS = IC->getOperand(0);
- if (LHS->getType() != SI.getType() ||
- !match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS))))
+ if (!match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS))))
return 0;
// If both select arms are non-zero see if we have a select of the form
unsigned ValZeros = ValC->getValue().logBase2();
unsigned AndZeros = AndRHS->getValue().logBase2();
- Value *V = LHS;
+ // If types don't match we can still convert the select by introducing a zext
+ // or a trunc of the 'and'. The trunc case requires that all of the truncated
+ // bits are zero, we can figure that out by looking at the 'and' mask.
+ if (AndZeros >= ValC->getBitWidth())
+ return 0;
+
+ Value *V = Builder->CreateZExtOrTrunc(LHS, SI.getType());
if (ValZeros > AndZeros)
V = Builder->CreateShl(V, ValZeros - AndZeros);
else if (ValZeros < AndZeros)
// Change: A = select B, false, C --> A = and !B, C
Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
return BinaryOperator::CreateAnd(NotCond, FalseVal);
- } else if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
+ }
+ if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
if (C->getZExtValue() == false) {
// Change: A = select B, C, false --> A = and B, C
return BinaryOperator::CreateAnd(CondVal, TrueVal);
// select a, a, b -> a|b
if (CondVal == TrueVal)
return BinaryOperator::CreateOr(CondVal, FalseVal);
- else if (CondVal == FalseVal)
+ if (CondVal == FalseVal)
return BinaryOperator::CreateAnd(CondVal, TrueVal);
+
+ // select a, ~a, b -> (~a)&b
+ // select a, b, ~a -> (~a)|b
+ if (match(TrueVal, m_Not(m_Specific(CondVal))))
+ return BinaryOperator::CreateAnd(TrueVal, FalseVal);
+ if (match(FalseVal, m_Not(m_Specific(CondVal))))
+ return BinaryOperator::CreateOr(TrueVal, FalseVal);
}
// Selecting between two integer constants?
Value *NewFalseOp = NegVal;
if (AddOp != TI)
std::swap(NewTrueOp, NewFalseOp);
- Value *NewSel =
+ Value *NewSel =
Builder->CreateSelect(CondVal, NewTrueOp,
NewFalseOp, SI.getName() + ".p");
Value *LHS, *RHS, *LHS2, *RHS2;
if (SelectPatternFlavor SPF = MatchSelectPattern(&SI, LHS, RHS)) {
if (SelectPatternFlavor SPF2 = MatchSelectPattern(LHS, LHS2, RHS2))
- if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2,
+ if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2,
SI, SPF, RHS))
return R;
if (SelectPatternFlavor SPF2 = MatchSelectPattern(RHS, LHS2, RHS2))
if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
if (TrueSI->getCondition() == CondVal) {
+ if (SI.getTrueValue() == TrueSI->getTrueValue())
+ return 0;
SI.setOperand(1, TrueSI->getTrueValue());
return &SI;
}
}
if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
if (FalseSI->getCondition() == CondVal) {
+ if (SI.getFalseValue() == FalseSI->getFalseValue())
+ return 0;
SI.setOperand(2, FalseSI->getFalseValue());
return &SI;
}
return &SI;
}
+ if (VectorType* VecTy = dyn_cast<VectorType>(SI.getType())) {
+ unsigned VWidth = VecTy->getNumElements();
+ APInt UndefElts(VWidth, 0);
+ APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+ if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) {
+ if (V != &SI)
+ return ReplaceInstUsesWith(SI, V);
+ return &SI;
+ }
+
+ if (isa<ConstantAggregateZero>(CondVal)) {
+ return ReplaceInstUsesWith(SI, FalseVal);
+ }
+ }
+
return 0;
}