Don't duplicate the work done by a gep into a "bitcast" if the gep has
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index 71a286ef69c47de6fbb230065d12b2c7998be8c0..5733c20828c67c4e3c83607792c3c33345d63ce4 100644 (file)
@@ -133,9 +133,8 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
     }
 
     // Fold this by inserting a select from the input values.
-    SelectInst *NewSI = SelectInst::Create(SI.getCondition(), TI->getOperand(0),
-                                          FI->getOperand(0), SI.getName()+".v");
-    InsertNewInstBefore(NewSI, SI);
+    Value *NewSI = Builder->CreateSelect(SI.getCondition(), TI->getOperand(0),
+                                         FI->getOperand(0), SI.getName()+".v");
     return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI,
                             TI->getType());
   }
@@ -174,9 +173,8 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
   }
 
   // If we reach here, they do have operations in common.
-  SelectInst *NewSI = SelectInst::Create(SI.getCondition(), OtherOpT,
-                                         OtherOpF, SI.getName()+".v");
-  InsertNewInstBefore(NewSI, SI);
+  Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT,
+                                       OtherOpF, SI.getName()+".v");
 
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TI)) {
     if (MatchIsOpZero)
@@ -214,7 +212,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
         unsigned OpToFold = 0;
         if ((SFO & 1) && FalseVal == TVI->getOperand(0)) {
           OpToFold = 1;
-        } else  if ((SFO & 2) && FalseVal == TVI->getOperand(1)) {
+        } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) {
           OpToFold = 2;
         }
 
@@ -224,12 +222,18 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
           // Avoid creating select between 2 constants unless it's selecting
           // between 0, 1 and -1.
           if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) {
-            Instruction *NewSel = SelectInst::Create(SI.getCondition(), OOp, C);
-            InsertNewInstBefore(NewSel, SI);
+            Value *NewSel = Builder->CreateSelect(SI.getCondition(), OOp, C);
             NewSel->takeName(TVI);
-            if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TVI))
-              return BinaryOperator::Create(BO->getOpcode(), FalseVal, NewSel);
-            llvm_unreachable("Unknown instruction!!");
+            BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI);
+            BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(),
+                                                        FalseVal, NewSel);
+            if (isa<PossiblyExactOperator>(BO))
+              BO->setIsExact(TVI_BO->isExact());
+            if (isa<OverflowingBinaryOperator>(BO)) {
+              BO->setHasNoUnsignedWrap(TVI_BO->hasNoUnsignedWrap());
+              BO->setHasNoSignedWrap(TVI_BO->hasNoSignedWrap());
+            }
+            return BO;
           }
         }
       }
@@ -243,7 +247,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
         unsigned OpToFold = 0;
         if ((SFO & 1) && TrueVal == FVI->getOperand(0)) {
           OpToFold = 1;
-        } else  if ((SFO & 2) && TrueVal == FVI->getOperand(1)) {
+        } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) {
           OpToFold = 2;
         }
 
@@ -253,12 +257,18 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
           // Avoid creating select between 2 constants unless it's selecting
           // between 0, 1 and -1.
           if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) {
-            Instruction *NewSel = SelectInst::Create(SI.getCondition(), C, OOp);
-            InsertNewInstBefore(NewSel, SI);
+            Value *NewSel = Builder->CreateSelect(SI.getCondition(), C, OOp);
             NewSel->takeName(FVI);
-            if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FVI))
-              return BinaryOperator::Create(BO->getOpcode(), TrueVal, NewSel);
-            llvm_unreachable("Unknown instruction!!");
+            BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI);
+            BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(),
+                                                        TrueVal, NewSel);
+            if (isa<PossiblyExactOperator>(BO))
+              BO->setIsExact(FVI_BO->isExact());
+            if (isa<OverflowingBinaryOperator>(BO)) {
+              BO->setHasNoUnsignedWrap(FVI_BO->hasNoUnsignedWrap());
+              BO->setHasNoSignedWrap(FVI_BO->hasNoSignedWrap());
+            }
+            return BO;
           }
         }
       }
@@ -268,6 +278,59 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
   return 0;
 }
 
+/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
+/// replaced with RepOp.
+static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const TargetData *TD) {
+  // Trivial replacement.
+  if (V == Op)
+    return RepOp;
+
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return 0;
+
+  // If this is a binary operator, try to simplify it with the replaced op.
+  if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
+    if (B->getOperand(0) == Op)
+      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD);
+    if (B->getOperand(1) == Op)
+      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD);
+  }
+
+  // Same for CmpInsts.
+  if (CmpInst *C = dyn_cast<CmpInst>(I)) {
+    if (C->getOperand(0) == Op)
+      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD);
+    if (C->getOperand(1) == Op)
+      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD);
+  }
+
+  // TODO: We could hand off more cases to instsimplify here.
+
+  // If all operands are constant after substituting Op for RepOp then we can
+  // constant fold the instruction.
+  if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) {
+    // Build a list of all constant operands.
+    SmallVector<Constant*, 8> ConstOps;
+    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+      if (I->getOperand(i) == Op)
+        ConstOps.push_back(CRepOp);
+      else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i)))
+        ConstOps.push_back(COp);
+      else
+        break;
+    }
+
+    // All operands were constants, fold it.
+    if (ConstOps.size() == I->getNumOperands())
+      return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
+                                      ConstOps.data(), ConstOps.size(), TD);
+  }
+
+  return 0;
+}
+
 /// visitSelectInstWithICmp - Visit a SelectInst that has an
 /// ICmpInst as its first operand.
 ///
@@ -281,8 +344,8 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
   Value *FalseVal = SI.getFalseValue();
 
   // Check cases where the comparison is with a constant that
-  // can be adjusted to fit the min/max idiom. We may edit ICI in
-  // place here, so make sure the select is the only user.
+  // can be adjusted to fit the min/max idiom. We may move or edit ICI
+  // here, so make sure the select is the only user.
   if (ICI->hasOneUse())
     if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) {
       // X < MIN ? T : F  -->  F
@@ -364,6 +427,11 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
         ICI->setOperand(1, CmpRHS);
         SI.setOperand(1, TrueVal);
         SI.setOperand(2, FalseVal);
+
+        // Move ICI instruction right before the select instruction. Otherwise
+        // the sext/zext value may be defined after the ICI instruction uses it.
+        ICI->moveBefore(&SI);
+
         Changed = true;
         break;
       }
@@ -401,24 +469,33 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
     }
   }
 
-  if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
-    // Transform (X == Y) ? X : Y  -> Y
-    if (Pred == ICmpInst::ICMP_EQ)
+  // If we have an equality comparison then we know the value in one of the
+  // arms of the select. See if substituting this value into the arm and
+  // simplifying the result yields the same value as the other arm.
+  if (Pred == ICmpInst::ICMP_EQ) {
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD) == TrueVal)
       return ReplaceInstUsesWith(SI, FalseVal);
-    // Transform (X != Y) ? X : Y  -> X
-    if (Pred == ICmpInst::ICMP_NE)
+  } else if (Pred == ICmpInst::ICMP_NE) {
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD) == FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD) == FalseVal)
       return ReplaceInstUsesWith(SI, TrueVal);
-    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+  }
 
-  } else if (CmpLHS == FalseVal && CmpRHS == TrueVal) {
-    // Transform (X == Y) ? Y : X  -> X
-    if (Pred == ICmpInst::ICMP_EQ)
-      return ReplaceInstUsesWith(SI, FalseVal);
-    // Transform (X != Y) ? Y : X  -> Y
-    if (Pred == ICmpInst::ICMP_NE)
-      return ReplaceInstUsesWith(SI, TrueVal);
-    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+  // NOTE: if we wanted to, this is where to detect integer MIN/MAX
+
+  if (isa<Constant>(CmpRHS)) {
+    if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
+      // Transform (X == C) ? X : Y -> (X == C) ? C : Y
+      SI.setOperand(1, CmpRHS);
+      Changed = true;
+    } else if (CmpLHS == FalseVal && Pred == ICmpInst::ICMP_NE) {
+      // Transform (X != C) ? Y : X -> (X != C) ? Y : C
+      SI.setOperand(2, CmpRHS);
+      Changed = true;
+    }
   }
+
   return Changed ? &SI : 0;
 }
 
@@ -498,9 +575,8 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal,
   if (!IC || !IC->isEquality())
     return 0;
 
-  if (ConstantInt *C = dyn_cast<ConstantInt>(IC->getOperand(1)))
-    if (!C->isZero())
-      return 0;
+  if (!match(IC->getOperand(1), m_Zero()))
+    return 0;
 
   ConstantInt *AndRHS;
   Value *LHS = IC->getOperand(0);
@@ -573,9 +649,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
         return BinaryOperator::CreateOr(CondVal, FalseVal);
       }
       // Change: A = select B, false, C --> A = and !B, C
-      Value *NotCond =
-        InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                           "not."+CondVal->getName()), SI);
+      Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
       return BinaryOperator::CreateAnd(NotCond, FalseVal);
     } else if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
       if (C->getZExtValue() == false) {
@@ -583,9 +657,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
         return BinaryOperator::CreateAnd(CondVal, TrueVal);
       }
       // Change: A = select B, C, true --> A = or !B, C
-      Value *NotCond =
-        InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                           "not."+CondVal->getName()), SI);
+      Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
       return BinaryOperator::CreateOr(NotCond, TrueVal);
     }
 
@@ -724,28 +796,21 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
             // So at this point we know we have (Y -> OtherAddOp):
             //        select C, (add X, Y), (sub X, Z)
             Value *NegVal;  // Compute -Z
-            if (Constant *C = dyn_cast<Constant>(SubOp->getOperand(1))) {
-              NegVal = ConstantExpr::getNeg(C);
-            } else if (SI.getType()->isFloatingPointTy()) {
-              NegVal = InsertNewInstBefore(
-                    BinaryOperator::CreateFNeg(SubOp->getOperand(1),
-                                              "tmp"), SI);
+            if (SI.getType()->isFPOrFPVectorTy()) {
+              NegVal = Builder->CreateFNeg(SubOp->getOperand(1));
             } else {
-              NegVal = InsertNewInstBefore(
-                    BinaryOperator::CreateNeg(SubOp->getOperand(1),
-                                              "tmp"), SI);
+              NegVal = Builder->CreateNeg(SubOp->getOperand(1));
             }
 
             Value *NewTrueOp = OtherAddOp;
             Value *NewFalseOp = NegVal;
             if (AddOp != TI)
               std::swap(NewTrueOp, NewFalseOp);
-            Instruction *NewSel =
-              SelectInst::Create(CondVal, NewTrueOp,
-                                 NewFalseOp, SI.getName() + ".p");
+            Value *NewSel = 
+              Builder->CreateSelect(CondVal, NewTrueOp,
+                                    NewFalseOp, SI.getName() + ".p");
 
-            NewSel = InsertNewInstBefore(NewSel, SI);
-            if (SI.getType()->isFloatingPointTy())
+            if (SI.getType()->isFPOrFPVectorTy())
               return BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel);
             else
               return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel);
@@ -787,6 +852,19 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       if (Instruction *NV = FoldOpIntoPhi(SI))
         return NV;
 
+  if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
+    if (TrueSI->getCondition() == CondVal) {
+      SI.setOperand(1, TrueSI->getTrueValue());
+      return &SI;
+    }
+  }
+  if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
+    if (FalseSI->getCondition() == CondVal) {
+      SI.setOperand(2, FalseSI->getFalseValue());
+      return &SI;
+    }
+  }
+
   if (BinaryOperator::isNot(CondVal)) {
     SI.setOperand(0, BinaryOperator::getNotArgument(CondVal));
     SI.setOperand(1, FalseVal);