function names start with a lower case letter ; NFC
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index d2fbcdd39915c8d0cba67be02aa9e66945becba7..51219bcb0b7ba6fdf90b223be9273647b6ace7da 100644 (file)
@@ -38,7 +38,8 @@ getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) {
   }
 }
 
-static CmpInst::Predicate getICmpPredicateForMinMax(SelectPatternFlavor SPF) {
+static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF,
+                                                   bool Ordered=false) {
   switch (SPF) {
   default:
     llvm_unreachable("unhandled!");
@@ -51,17 +52,22 @@ static CmpInst::Predicate getICmpPredicateForMinMax(SelectPatternFlavor SPF) {
     return ICmpInst::ICMP_SGT;
   case SPF_UMAX:
     return ICmpInst::ICMP_UGT;
+  case SPF_FMINNUM:
+    return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
+  case SPF_FMAXNUM:
+    return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
   }
 }
 
 static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder,
                                           SelectPatternFlavor SPF, Value *A,
                                           Value *B) {
-  CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
+  CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF);
+  assert(CmpInst::isIntPredicate(Pred));
   return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B);
 }
 
-/// GetSelectFoldableOperands - We want to turn code that looks like this:
+/// We want to turn code that looks like this:
 ///   %C = or %A, %B
 ///   %D = select %cond, %C, %A
 /// into:
@@ -90,8 +96,8 @@ static unsigned GetSelectFoldableOperands(Instruction *I) {
   }
 }
 
-/// GetSelectFoldableConstant - For the same transformation as the previous
-/// function, return the identity constant that goes into the select.
+/// For the same transformation as the previous function, return the identity
+/// constant that goes into the select.
 static Constant *GetSelectFoldableConstant(Instruction *I) {
   switch (I->getOpcode()) {
   default: llvm_unreachable("This cannot happen!");
@@ -110,7 +116,7 @@ static Constant *GetSelectFoldableConstant(Instruction *I) {
   }
 }
 
-/// FoldSelectOpOp - Here we have (select c, TI, FI), and we know that TI and FI
+/// Here we have (select c, TI, FI), and we know that TI and FI
 /// have the same opcode and only one use each.  Try to simplify this.
 Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
                                           Instruction *FI) {
@@ -197,8 +203,8 @@ static bool isSelect01(Constant *C1, Constant *C2) {
          C2I->isOne() || C2I->isAllOnesValue();
 }
 
-/// FoldSelectIntoOp - Try fold the select into one of the operands to
-/// facilitate further optimization.
+/// Try to fold the select into one of the operands to allow further
+/// optimization.
 Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
                                             Value *FalseVal) {
   // See the comment above GetSelectFoldableOperands for a description of the
@@ -276,73 +282,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
   return nullptr;
 }
 
-/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
-/// replaced with RepOp.
-static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                     const TargetLibraryInfo *TLI,
-                                     const DataLayout &DL, DominatorTree *DT,
-                                     AssumptionCache *AC) {
-  // Trivial replacement.
-  if (V == Op)
-    return RepOp;
-
-  Instruction *I = dyn_cast<Instruction>(V);
-  if (!I)
-    return nullptr;
-
-  // If this is a binary operator, try to simplify it with the replaced op.
-  if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
-    if (B->getOperand(0) == Op)
-      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), DL, TLI);
-    if (B->getOperand(1) == Op)
-      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, DL, TLI);
-  }
-
-  // Same for CmpInsts.
-  if (CmpInst *C = dyn_cast<CmpInst>(I)) {
-    if (C->getOperand(0) == Op)
-      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), DL,
-                             TLI, DT, AC);
-    if (C->getOperand(1) == Op)
-      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, DL,
-                             TLI, DT, AC);
-  }
-
-  // TODO: We could hand off more cases to instsimplify here.
-
-  // If all operands are constant after substituting Op for RepOp then we can
-  // constant fold the instruction.
-  if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) {
-    // Build a list of all constant operands.
-    SmallVector<Constant*, 8> ConstOps;
-    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
-      if (I->getOperand(i) == Op)
-        ConstOps.push_back(CRepOp);
-      else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i)))
-        ConstOps.push_back(COp);
-      else
-        break;
-    }
-
-    // All operands were constants, fold it.
-    if (ConstOps.size() == I->getNumOperands()) {
-      if (CmpInst *C = dyn_cast<CmpInst>(I))
-        return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
-                                               ConstOps[1], DL, TLI);
-
-      if (LoadInst *LI = dyn_cast<LoadInst>(I))
-        if (!LI->isVolatile())
-          return ConstantFoldLoadFromConstPtr(ConstOps[0], DL);
-
-      return ConstantFoldInstOperands(I->getOpcode(), I->getType(), ConstOps,
-                                      DL, TLI);
-    }
-  }
-
-  return nullptr;
-}
-
-/// foldSelectICmpAndOr - We want to turn:
+/// We want to turn:
 ///   (select (icmp eq (and X, C1), 0), Y, (or Y, C2))
 /// into:
 ///   (or (shl (and X, C1), C3), y)
@@ -460,9 +400,7 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
   return nullptr;
 }
 
-/// visitSelectInstWithICmp - Visit a SelectInst that has an
-/// ICmpInst as its first operand.
-///
+/// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
                                                    ICmpInst *ICI) {
   bool Changed = false;
@@ -477,14 +415,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
   // here, so make sure the select is the only user.
   if (ICI->hasOneUse())
     if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) {
-      // X < MIN ? T : F  -->  F
-      if ((Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT)
-          && CI->isMinValue(Pred == ICmpInst::ICMP_SLT))
-        return ReplaceInstUsesWith(SI, FalseVal);
-      // X > MAX ? T : F  -->  F
-      else if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT)
-               && CI->isMaxValue(Pred == ICmpInst::ICMP_SGT))
-        return ReplaceInstUsesWith(SI, FalseVal);
       switch (Pred) {
       default: break;
       case ICmpInst::ICMP_ULT:
@@ -598,33 +528,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
     }
   }
 
-  // If we have an equality comparison then we know the value in one of the
-  // arms of the select. See if substituting this value into the arm and
-  // simplifying the result yields the same value as the other arm.
-  if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
-            TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
-            TrueVal)
-      return ReplaceInstUsesWith(SI, FalseVal);
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
-            FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
-            FalseVal)
-      return ReplaceInstUsesWith(SI, FalseVal);
-  } else if (Pred == ICmpInst::ICMP_NE) {
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
-            FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
-            FalseVal)
-      return ReplaceInstUsesWith(SI, TrueVal);
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
-            TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
-            TrueVal)
-      return ReplaceInstUsesWith(SI, TrueVal);
-  }
-
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
 
   if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
@@ -639,7 +542,8 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
     }
   }
 
-  if (unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits()) {
+  {
+    unsigned BitWidth = DL.getTypeSizeInBits(TrueVal->getType());
     APInt MinSignedValue = APInt::getSignBit(BitWidth);
     Value *X;
     const APInt *Y, *C;
@@ -695,10 +599,9 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
 }
 
 
-/// CanSelectOperandBeMappingIntoPredBlock - SI is a select whose condition is a
-/// PHI node (but the two may be in different blocks).  See if the true/false
-/// values (V) are live in all of the predecessor blocks of the PHI.  For
-/// example, cases like this cannot be mapped:
+/// SI is a select whose condition is a PHI node (but the two may be in
+/// different blocks). See if the true/false values (V) are live in all of the
+/// predecessor blocks of the PHI. For example, cases like this can't be mapped:
 ///
 ///   X = phi [ C1, BB1], [C2, BB2]
 ///   Y = add
@@ -732,7 +635,7 @@ static bool CanSelectOperandBeMappingIntoPredBlock(const Value *V,
   return false;
 }
 
-/// FoldSPFofSPF - We have an SPF (e.g. a min or max) of an SPF of the form:
+/// We have an SPF (e.g. a min or max) of an SPF of the form:
 ///   SPF2(SPF1(A, B), C)
 Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
                                         SelectPatternFlavor SPF1,
@@ -845,10 +748,10 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
   return nullptr;
 }
 
-/// foldSelectICmpAnd - If one of the constants is zero (we know they can't
-/// both be) and we have an icmp instruction with zero, and we have an 'and'
-/// with the non-constant value and a power of two we can turn the select
-/// into a shift on the result of the 'and'.
+/// If one of the constants is zero (we know they can't both be) and we have an
+/// icmp instruction with zero, and we have an 'and' with the non-constant value
+/// and a power of two we can turn the select into a shift on the result of the
+/// 'and'.
 static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal,
                                 ConstantInt *FalseVal,
                                 InstCombiner::BuilderTy *Builder) {
@@ -1026,6 +929,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X
       if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {
         FCmpInst::Predicate InvPred = FCI->getInversePredicate();
+        IRBuilder<>::FastMathFlagGuard FMFG(*Builder);
+        Builder->setFastMathFlags(FCI->getFastMathFlags());
         Value *NewCond = Builder->CreateFCmp(InvPred, TrueVal, FalseVal,
                                              FCI->getName() + ".inv");
 
@@ -1067,6 +972,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y
       if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) {
         FCmpInst::Predicate InvPred = FCI->getInversePredicate();
+        IRBuilder<>::FastMathFlagGuard FMFG(*Builder);
+        Builder->setFastMathFlags(FCI->getFastMathFlags());
         Value *NewCond = Builder->CreateFCmp(InvPred, FalseVal, TrueVal,
                                              FCI->getName() + ".inv");
 
@@ -1154,35 +1061,50 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       }
 
   // See if we can fold the select into one of our operands.
-  if (SI.getType()->isIntOrIntVectorTy()) {
+  if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) {
     if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
       return FoldI;
 
     Value *LHS, *RHS, *LHS2, *RHS2;
     Instruction::CastOps CastOp;
-    SelectPatternFlavor SPF = matchSelectPattern(&SI, LHS, RHS, &CastOp);
+    SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp);
+    auto SPF = SPR.Flavor;
 
-    if (SPF) {
+    if (SelectPatternResult::isMinOrMax(SPF)) {
       // Canonicalize so that type casts are outside select patterns.
       if (LHS->getType()->getPrimitiveSizeInBits() !=
           SI.getType()->getPrimitiveSizeInBits()) {
-        CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
-        Value *Cmp = Builder->CreateICmp(Pred, LHS, RHS);
+        CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered);
+
+        Value *Cmp;
+        if (CmpInst::isIntPredicate(Pred)) {
+          Cmp = Builder->CreateICmp(Pred, LHS, RHS);
+        } else {
+          IRBuilder<>::FastMathFlagGuard FMFG(*Builder);
+          auto FMF = cast<FPMathOperator>(SI.getCondition())->getFastMathFlags();
+          Builder->setFastMathFlags(FMF);
+          Cmp = Builder->CreateFCmp(Pred, LHS, RHS);
+        }
+
         Value *NewSI = Builder->CreateCast(CastOp,
                                            Builder->CreateSelect(Cmp, LHS, RHS),
                                            SI.getType());
         return ReplaceInstUsesWith(SI, NewSI);
       }
+    }
 
+    if (SPF) {
       // MAX(MAX(a, b), a) -> MAX(a, b)
       // MIN(MIN(a, b), a) -> MIN(a, b)
       // MAX(MIN(a, b), a) -> a
       // MIN(MAX(a, b), a) -> a
-      if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2))
+      // ABS(ABS(a)) -> ABS(a)
+      // NABS(NABS(a)) -> NABS(a)
+      if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor)
         if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2,
                                           SI, SPF, RHS))
           return R;
-      if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2))
+      if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor)
         if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2,
                                           SI, SPF, LHS))
           return R;