Merging r257875:
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAndOrXor.cpp
index 95c50d32c8207b12f5ec0abfd78301ce0e4da6a4..76cefd97cd8f1a53479aa94ee9ab3e50185c3a18 100644 (file)
@@ -17,6 +17,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/CmpInstAnalysis.h"
+#include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -1565,190 +1566,18 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
   return Changed ? &I : nullptr;
 }
 
-
-/// Analyze the specified subexpression and see if it is capable of providing
-/// pieces of a bswap or bitreverse. The subexpression provides a potential
-/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in
-/// the output of the expression came from a corresponding bit in some other
-/// value. This function is recursive, and the end result is a mapping of
-/// (value, bitnumber) to bitnumber. It is the caller's responsibility to
-/// validate that all `value`s are identical and that the bitnumber to bitnumber
-/// mapping is correct for a bswap or bitreverse.
-///
-/// For example, if the current subexpression if "(shl i32 %X, 24)" then we know
-/// that the expression deposits the low byte of %X into the high byte of the
-/// result and that all other bits are zero. This expression is accepted,
-/// BitValues[24-31] are set to %X and BitProvenance[24-31] are set to [0-7].
-///
-/// This function returns true if the match was unsuccessful and false if so.
-/// On entry to the function the "OverallLeftShift" is a signed integer value
-/// indicating the number of bits that the subexpression is later shifted.  For
-/// example, if the expression is later right shifted by 16 bits, the
-/// OverallLeftShift value would be -16 on entry.  This is used to specify which
-/// bits of BitValues are actually being set.
-///
-/// Similarly, BitMask is a bitmask where a bit is clear if its corresponding
-/// bit is masked to zero by a user.  For example, in (X & 255), X will be
-/// processed with a bytemask of 255. BitMask is always in the local
-/// (OverallLeftShift) coordinate space.
-///
-static bool CollectBitParts(Value *V, int OverallLeftShift, APInt BitMask,
-                            SmallVectorImpl<Value *> &BitValues,
-                            SmallVectorImpl<int> &BitProvenance) {
-  if (Instruction *I = dyn_cast<Instruction>(V)) {
-    // If this is an or instruction, it may be an inner node of the bswap.
-    if (I->getOpcode() == Instruction::Or)
-      return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
-                             BitValues, BitProvenance) ||
-             CollectBitParts(I->getOperand(1), OverallLeftShift, BitMask,
-                             BitValues, BitProvenance);
-
-    // If this is a logical shift by a constant, recurse with OverallLeftShift
-    // and BitMask adjusted.
-    if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
-      unsigned ShAmt =
-          cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
-      // Ensure the shift amount is defined.
-      if (ShAmt > BitValues.size())
-        return true;
-
-      unsigned BitShift = ShAmt;
-      if (I->getOpcode() == Instruction::Shl) {
-        // X << C -> collect(X, +C)
-        OverallLeftShift += BitShift;
-        BitMask = BitMask.lshr(BitShift);
-      } else {
-        // X >>u C -> collect(X, -C)
-        OverallLeftShift -= BitShift;
-        BitMask = BitMask.shl(BitShift);
-      }
-
-      if (OverallLeftShift >= (int)BitValues.size())
-        return true;
-      if (OverallLeftShift <= -(int)BitValues.size())
-        return true;
-
-      return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
-                             BitValues, BitProvenance);
-    }
-
-    // If this is a logical 'and' with a mask that clears bits, clear the
-    // corresponding bits in BitMask.
-    if (I->getOpcode() == Instruction::And &&
-        isa<ConstantInt>(I->getOperand(1))) {
-      unsigned NumBits = BitValues.size();
-      APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1);
-      const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
-
-      for (unsigned i = 0; i != NumBits; ++i, Bit <<= 1) {
-        // If this bit is masked out by a later operation, we don't care what
-        // the and mask is.
-        if (BitMask[i] == 0)
-          continue;
-
-        // If the AndMask is zero for this bit, clear the bit.
-        APInt MaskB = AndMask & Bit;
-        if (MaskB == 0) {
-          BitMask.clearBit(i);
-          continue;
-        }
-
-        // Otherwise, this bit is kept.
-      }
-
-      return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
-                             BitValues, BitProvenance);
-    }
-  }
-
-  // Okay, we got to something that isn't a shift, 'or' or 'and'.  This must be
-  // the input value to the bswap/bitreverse. To be part of a bswap or
-  // bitreverse we must be demanding a contiguous range of bits from it.
-  unsigned InputBitLen = BitMask.countPopulation();
-  unsigned InputBitNo = BitMask.countTrailingZeros();
-  if (BitMask.getBitWidth() - BitMask.countLeadingZeros() - InputBitNo !=
-      InputBitLen)
-    // Not a contiguous set range of bits!
-    return true;
-
-  // We know we're moving a contiguous range of bits from the input to the
-  // output. Record which bits in the output came from which bits in the input.
-  unsigned DestBitNo = InputBitNo + OverallLeftShift;
-  for (unsigned I = 0; I < InputBitLen; ++I)
-    BitProvenance[DestBitNo + I] = InputBitNo + I;
-
-  // If the destination bit value is already defined, the values are or'd
-  // together, which isn't a bswap/bitreverse (unless it's an or of the same
-  // bits).
-  if (BitValues[DestBitNo] && BitValues[DestBitNo] != V)
-    return true;
-  for (unsigned I = 0; I < InputBitLen; ++I)
-    BitValues[DestBitNo + I] = V;
-
-  return false;
-}
-
-static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To,
-                                          unsigned BitWidth) {
-  if (From % 8 != To % 8)
-    return false;
-  // Convert from bit indices to byte indices and check for a byte reversal.
-  From >>= 3;
-  To >>= 3;
-  BitWidth >>= 3;
-  return From == BitWidth - To - 1;
-}
-
-static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To,
-                                               unsigned BitWidth) {
-  return From == BitWidth - To - 1;
-}
-
 /// Given an OR instruction, check to see if this is a bswap or bitreverse
 /// idiom. If so, insert the new intrinsic and return it.
 Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) {
-  IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
-  if (!ITy)
-    return nullptr;   // Can't do vectors.
-  unsigned BW = ITy->getBitWidth();
-  
-  /// We keep track of which bit (BitProvenance) inside which value (BitValues)
-  /// defines each bit in the result.
-  SmallVector<Value *, 8> BitValues(BW, nullptr);
-  SmallVector<int, 8> BitProvenance(BW, -1);
-  
-  // Try to find all the pieces corresponding to the bswap.
-  APInt BitMask = APInt::getAllOnesValue(BitValues.size());
-  if (CollectBitParts(&I, 0, BitMask, BitValues, BitProvenance))
-    return nullptr;
-
-  // Check to see if all of the bits come from the same value.
-  Value *V = BitValues[0];
-  if (!V) return nullptr;  // Didn't find a bit?  Must be zero.
-
-  if (!std::all_of(BitValues.begin(), BitValues.end(),
-                   [&](const Value *X) { return X == V; }))
-    return nullptr;
-
-  // Now, is the bit permutation correct for a bswap or a bitreverse? We can
-  // only byteswap values with an even number of bytes.
-  bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;;
-  for (unsigned i = 0, e = BitValues.size(); i != e; ++i) {
-    OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
-    OKForBitReverse &=
-        bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
-  }
-
-  Intrinsic::ID Intrin;
-  if (OKForBSwap)
-    Intrin = Intrinsic::bswap;
-  else if (OKForBitReverse)
-    Intrin = Intrinsic::bitreverse;
-  else
+  SmallVector<Instruction*, 4> Insts;
+  if (!recognizeBitReverseOrBSwapIdiom(&I, true, false, Insts))
     return nullptr;
+  Instruction *LastInst = Insts.pop_back_val();
+  LastInst->removeFromParent();
 
-  Function *F = Intrinsic::getDeclaration(I.getModule(), Intrin, ITy);
-  return CallInst::Create(F, V);
+  for (auto *Inst : Insts)
+    Worklist.Add(Inst);
+  return LastInst;
 }
 
 /// We have an expression of the form (A&C)|(B&D).  Check if A is (cond?-1:0)