Follow-up fix to r165928: handle memset rewriting for widened integers,
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAndOrXor.cpp
index 4b1e7feef075b2fea34431ae121d9f4e3d8b911c..7d0af0d80226f06f3cc4fd1edabac3e44d79cd98 100644 (file)
@@ -87,16 +87,15 @@ static unsigned getFCmpCode(FCmpInst::Predicate CC, bool &isOrdered) {
   default:
     // Not expecting FCMP_FALSE and FCMP_TRUE;
     llvm_unreachable("Unexpected FCmp predicate!");
-    return 0;
   }
 }
 
-/// getICmpValue - This is the complement of getICmpCode, which turns an
+/// getNewICmpValue - This is the complement of getICmpCode, which turns an
 /// opcode and two operands into either a constant true or false, or a brand 
 /// new ICmp instruction. The sign is passed in to determine which kind
 /// of predicate to use in the new icmp instruction.
-Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS,
-                    InstCombiner::BuilderTy *Builder) {
+static Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS,
+                              InstCombiner::BuilderTy *Builder) {
   ICmpInst::Predicate NewPred;
   if (Value *NewConstant = getICmpValue(Sign, Code, LHS, RHS, NewPred))
     return NewConstant;
@@ -111,7 +110,7 @@ static Value *getFCmpValue(bool isordered, unsigned code,
                            InstCombiner::BuilderTy *Builder) {
   CmpInst::Predicate Pred;
   switch (code) {
-  default: assert(0 && "Illegal FCmp code!");
+  default: llvm_unreachable("Illegal FCmp code!");
   case 0: Pred = isordered ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO; break;
   case 1: Pred = isordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; break;
   case 2: Pred = isordered ? FCmpInst::FCMP_OEQ : FCmpInst::FCMP_UEQ; break;
@@ -496,6 +495,38 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,
   return result;
 }
 
+/// decomposeBitTestICmp - Decompose an icmp into the form ((X & Y) pred Z)
+/// if possible. The returned predicate is either == or !=. Returns false if
+/// decomposition fails.
+static bool decomposeBitTestICmp(const ICmpInst *I, ICmpInst::Predicate &Pred,
+                                 Value *&X, Value *&Y, Value *&Z) {
+  // X < 0 is equivalent to (X & SignBit) != 0.
+  if (I->getPredicate() == ICmpInst::ICMP_SLT)
+    if (ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1)))
+      if (C->isZero()) {
+        X = I->getOperand(0);
+        Y = ConstantInt::get(I->getContext(),
+                             APInt::getSignBit(C->getBitWidth()));
+        Pred = ICmpInst::ICMP_NE;
+        Z = C;
+        return true;
+      }
+
+  // X > -1 is equivalent to (X & SignBit) == 0.
+  if (I->getPredicate() == ICmpInst::ICMP_SGT)
+    if (ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1)))
+      if (C->isAllOnesValue()) {
+        X = I->getOperand(0);
+        Y = ConstantInt::get(I->getContext(),
+                             APInt::getSignBit(C->getBitWidth()));
+        Pred = ICmpInst::ICMP_EQ;
+        Z = ConstantInt::getNullValue(C->getType());
+        return true;
+      }
+
+  return false;
+}
+
 /// foldLogOpOfMaskedICmpsHelper:
 /// handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
 /// return the set of pattern classes (from MaskedICmpType)
@@ -503,10 +534,9 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,
 static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, 
                                              Value*& B, Value*& C,
                                              Value*& D, Value*& E,
-                                             ICmpInst *LHS, ICmpInst *RHS) {
-  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
-  if (LHSCC != ICmpInst::ICMP_EQ && LHSCC != ICmpInst::ICMP_NE) return 0;
-  if (RHSCC != ICmpInst::ICMP_EQ && RHSCC != ICmpInst::ICMP_NE) return 0;
+                                             ICmpInst *LHS, ICmpInst *RHS,
+                                             ICmpInst::Predicate &LHSCC,
+                                             ICmpInst::Predicate &RHSCC) {
   if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0;
   // vectors are not (yet?) supported
   if (LHS->getOperand(0)->getType()->isVectorTy()) return 0;
@@ -520,40 +550,60 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
   Value *L1 = LHS->getOperand(0);
   Value *L2 = LHS->getOperand(1);
   Value *L11,*L12,*L21,*L22;
-  if (match(L1, m_And(m_Value(L11), m_Value(L12)))) {
-    if (!match(L2, m_And(m_Value(L21), m_Value(L22))))
+  // Check whether the icmp can be decomposed into a bit test.
+  if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) {
+    L21 = L22 = L1 = 0;
+  } else {
+    // Look for ANDs in the LHS icmp.
+    if (match(L1, m_And(m_Value(L11), m_Value(L12)))) {
+      if (!match(L2, m_And(m_Value(L21), m_Value(L22))))
+        L21 = L22 = 0;
+    } else {
+      if (!match(L2, m_And(m_Value(L11), m_Value(L12))))
+        return 0;
+      std::swap(L1, L2);
       L21 = L22 = 0;
+    }
   }
-  else {
-    if (!match(L2, m_And(m_Value(L11), m_Value(L12))))
-      return 0;
-    std::swap(L1, L2);
-    L21 = L22 = 0;
-  }
+
+  // Bail if LHS was a icmp that can't be decomposed into an equality.
+  if (!ICmpInst::isEquality(LHSCC))
+    return 0;
 
   Value *R1 = RHS->getOperand(0);
   Value *R2 = RHS->getOperand(1);
   Value *R11,*R12;
   bool ok = false;
-  if (match(R1, m_And(m_Value(R11), m_Value(R12)))) {
-    if (R11 != 0 && (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22)) {
-      A = R11; D = R12; E = R2; ok = true;
+  if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) {
+    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+      A = R11; D = R12;
+    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+      A = R12; D = R11;
+    } else {
+      return 0;
     }
-    else 
-    if (R12 != 0 && (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22)) {
+    E = R2; R1 = 0; ok = true;
+  } else if (match(R1, m_And(m_Value(R11), m_Value(R12)))) {
+    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+      A = R11; D = R12; E = R2; ok = true;
+    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
       A = R12; D = R11; E = R2; ok = true;
     }
   }
+
+  // Bail if RHS was a icmp that can't be decomposed into an equality.
+  if (!ICmpInst::isEquality(RHSCC))
+    return 0;
+
+  // Look for ANDs in on the right side of the RHS icmp.
   if (!ok && match(R2, m_And(m_Value(R11), m_Value(R12)))) {
-    if (R11 != 0 && (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22)) {
-       A = R11; D = R12; E = R1; ok = true;
-    }
-    else 
-    if (R12 != 0 && (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22)) {
+    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+      A = R11; D = R12; E = R1; ok = true;
+    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
       A = R12; D = R11; E = R1; ok = true;
-    }
-    else
+    } else {
       return 0;
+    }
   }
   if (!ok)
     return 0;
@@ -582,8 +632,12 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS,
                                      ICmpInst::Predicate NEWCC,
                                      llvm::InstCombiner::BuilderTy* Builder) {
   Value *A = 0, *B = 0, *C = 0, *D = 0, *E = 0;
-  unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS);
+  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+  unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS,
+                                               LHSCC, RHSCC);
   if (mask == 0) return 0;
+  assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) &&
+         "foldLogOpOfMaskedICmpsHelper must return an equality predicate.");
 
   if (NEWCC == ICmpInst::ICMP_NE)
     mask >>= 1; // treat "Not"-states as normal states
@@ -631,11 +685,11 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS,
 
     ConstantInt *CCst = dyn_cast<ConstantInt>(C);
     if (CCst == 0) return 0;
-    if (LHS->getPredicate() != NEWCC)
+    if (LHSCC != NEWCC)
       CCst = dyn_cast<ConstantInt>( ConstantExpr::getXor(BCst, CCst) );
     ConstantInt *ECst = dyn_cast<ConstantInt>(E);
     if (ECst == 0) return 0;
-    if (RHS->getPredicate() != NEWCC)
+    if (RHSCC != NEWCC)
       ECst = dyn_cast<ConstantInt>( ConstantExpr::getXor(DCst, ECst) );
     ConstantInt* MCst = dyn_cast<ConstantInt>(
       ConstantExpr::getAnd(ConstantExpr::getAnd(BCst, DCst),
@@ -694,24 +748,12 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
       Value *NewOr = Builder->CreateOr(Val, Val2);
       return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
     }
-
-    // (icmp slt A, 0) & (icmp slt B, 0) --> (icmp slt (A&B), 0)
-    if (LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero()) {
-      Value *NewAnd = Builder->CreateAnd(Val, Val2);
-      return Builder->CreateICmp(LHSCC, NewAnd, LHSCst);
-    }
-
-    // (icmp sgt A, -1) & (icmp sgt B, -1) --> (icmp sgt (A|B), -1)
-    if (LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue()) {
-      Value *NewOr = Builder->CreateOr(Val, Val2);
-      return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
-    }
   }
 
   // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2
   // where CMAX is the all ones value for the truncated type,
   // iff the lower bits of C2 and CA are zero.
-  if (LHSCC == RHSCC && ICmpInst::isEquality(LHSCC) &&
+  if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC &&
       LHS->hasOneUse() && RHS->hasOneUse()) {
     Value *V;
     ConstantInt *AndCst, *SmallCst = 0, *BigCst = 0;
@@ -743,7 +785,7 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
       }
     }
   }
-  
+
   // From here on, we only handle:
   //    (icmp1 A, C1) & (icmp2 A, C2) --> something simpler.
   if (Val != Val2) return 0;
@@ -944,19 +986,23 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) {
     bool Op1Ordered;
     unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered);
     unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered);
+    // uno && ord -> false
+    if (Op0Pred == 0 && Op1Pred == 0 && Op0Ordered != Op1Ordered)
+        return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0);
     if (Op1Pred == 0) {
       std::swap(LHS, RHS);
       std::swap(Op0Pred, Op1Pred);
       std::swap(Op0Ordered, Op1Ordered);
     }
     if (Op0Pred == 0) {
-      // uno && ueq -> uno && (uno || eq) -> ueq
+      // uno && ueq -> uno && (uno || eq) -> uno
       // ord && olt -> ord && (ord && lt) -> olt
-      if (Op0Ordered == Op1Ordered)
+      if (!Op0Ordered && (Op0Ordered == Op1Ordered))
+        return LHS;
+      if (Op0Ordered && (Op0Ordered == Op1Ordered))
         return RHS;
       
       // uno && oeq -> uno && (ord && eq) -> false
-      // uno && ord -> false
       if (!Op0Ordered)
         return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0);
       // ord && ueq -> ord && (uno || eq) -> oeq
@@ -1320,13 +1366,8 @@ static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask,
   // part of the value (e.g. byte 3) then it must be shifted right.  If from the
   // low part, it must be shifted left.
   unsigned DestByteNo = InputByteNo + OverallLeftShift;
-  if (InputByteNo < ByteValues.size()/2) {
-    if (ByteValues.size()-1-DestByteNo != InputByteNo)
-      return true;
-  } else {
-    if (ByteValues.size()-1-DestByteNo != InputByteNo)
-      return true;
-  }
+  if (ByteValues.size()-1-DestByteNo != InputByteNo)
+    return true;
   
   // If the destination byte value is already defined, the values are or'd
   // together, which isn't a bswap (unless it's an or of the same bits).
@@ -1428,18 +1469,6 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
       Value *NewOr = Builder->CreateOr(Val, Val2);
       return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
     }
-
-    // (icmp slt A, 0) | (icmp slt B, 0) --> (icmp slt (A|B), 0)
-    if (LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero()) {
-      Value *NewOr = Builder->CreateOr(Val, Val2);
-      return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
-    }
-
-    // (icmp sgt A, -1) | (icmp sgt B, -1) --> (icmp sgt (A&B), -1)
-    if (LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue()) {
-      Value *NewAnd = Builder->CreateAnd(Val, Val2);
-      return Builder->CreateICmp(LHSCC, NewAnd, LHSCst);
-    }
   }
 
   // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1)
@@ -1524,7 +1553,6 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
     case ICmpInst::ICMP_SLT:         // (X != 13 | X s< 15) -> true
       return ConstantInt::getTrue(LHS->getContext());
     }
-    break;
   case ICmpInst::ICMP_ULT:
     switch (RHSCC) {
     default: llvm_unreachable("Unknown integer condition code!");
@@ -1900,15 +1928,23 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
       }
 
   // Canonicalize xor to the RHS.
-  if (match(Op0, m_Xor(m_Value(), m_Value())))
+  bool SwappedForXor = false;
+  if (match(Op0, m_Xor(m_Value(), m_Value()))) {
     std::swap(Op0, Op1);
+    SwappedForXor = true;
+  }
 
   // A | ( A ^ B) -> A |  B
   // A | (~A ^ B) -> A | ~B
+  // (A & B) | (A ^ B)
   if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) {
     if (Op0 == A || Op0 == B)
       return BinaryOperator::CreateOr(A, B);
 
+    if (match(Op0, m_And(m_Specific(A), m_Specific(B))) ||
+        match(Op0, m_And(m_Specific(B), m_Specific(A))))
+      return BinaryOperator::CreateOr(A, B);
+
     if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) {
       Value *Not = Builder->CreateNot(B, B->getName()+".not");
       return BinaryOperator::CreateOr(Not, Op0);
@@ -1932,6 +1968,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
         return BinaryOperator::CreateOr(Not, Op0);
       }
 
+  if (SwappedForXor)
+    std::swap(Op0, Op1);
+
   if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
     if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
       if (Value *Res = FoldOrOfICmps(LHS, RHS))
@@ -2182,7 +2221,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
   if (Op0I && Op1I && Op0I->isShift() && 
       Op0I->getOpcode() == Op1I->getOpcode() && 
       Op0I->getOperand(1) == Op1I->getOperand(1) &&
-      (Op1I->hasOneUse() || Op1I->hasOneUse())) {
+      (Op0I->hasOneUse() || Op1I->hasOneUse())) {
     Value *NewOp =
       Builder->CreateXor(Op0I->getOperand(0), Op1I->getOperand(0),
                          Op0I->getName());