Merging r257875:
[oota-llvm.git] / lib / Transforms / Utils / Local.cpp
index 1cb65fccb9d97da93da4426c81e4bc296144139c..abc9b65f7a394b367f7f9394c9e5228d3b2c8fb1 100644 (file)
@@ -1592,3 +1592,205 @@ bool llvm::callsGCLeafFunction(ImmutableCallSite CS) {
 
   return false;
 }
+
+/// A potential constituent of a bitreverse or bswap expression. See
+/// collectBitParts for a fuller explanation.
+struct BitPart {
+  BitPart(Value *P, unsigned BW) : Provider(P) {
+    Provenance.resize(BW);
+  }
+
+  /// The Value that this is a bitreverse/bswap of.
+  Value *Provider;
+  /// The "provenance" of each bit. Provenance[A] = B means that bit A
+  /// in Provider becomes bit B in the result of this expression.
+  SmallVector<int8_t, 32> Provenance; // int8_t means max size is i128.
+
+  enum { Unset = -1 };
+};
+
+/// Analyze the specified subexpression and see if it is capable of providing
+/// pieces of a bswap or bitreverse. The subexpression provides a potential
+/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in
+/// the output of the expression came from a corresponding bit in some other
+/// value. This function is recursive, and the end result is a mapping of
+/// bitnumber to bitnumber. It is the caller's responsibility to validate that
+/// the bitnumber to bitnumber mapping is correct for a bswap or bitreverse.
+///
+/// For example, if the current subexpression if "(shl i32 %X, 24)" then we know
+/// that the expression deposits the low byte of %X into the high byte of the
+/// result and that all other bits are zero. This expression is accepted and a
+/// BitPart is returned with Provider set to %X and Provenance[24-31] set to
+/// [0-7].
+///
+/// To avoid revisiting values, the BitPart results are memoized into the
+/// provided map. To avoid unnecessary copying of BitParts, BitParts are
+/// constructed in-place in the \c BPS map. Because of this \c BPS needs to
+/// store BitParts objects, not pointers. As we need the concept of a nullptr
+/// BitParts (Value has been analyzed and the analysis failed), we an Optional
+/// type instead to provide the same functionality.
+///
+/// Because we pass around references into \c BPS, we must use a container that
+/// does not invalidate internal references (std::map instead of DenseMap).
+///
+static const Optional<BitPart> &
+collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
+                std::map<Value *, Optional<BitPart>> &BPS) {
+  auto I = BPS.find(V);
+  if (I != BPS.end())
+    return I->second;
+
+  auto &Result = BPS[V] = None;
+  auto BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+
+  if (Instruction *I = dyn_cast<Instruction>(V)) {
+    // If this is an or instruction, it may be an inner node of the bswap.
+    if (I->getOpcode() == Instruction::Or) {
+      auto &A = collectBitParts(I->getOperand(0), MatchBSwaps,
+                                MatchBitReversals, BPS);
+      auto &B = collectBitParts(I->getOperand(1), MatchBSwaps,
+                                MatchBitReversals, BPS);
+      if (!A || !B)
+        return Result;
+
+      // Try and merge the two together.
+      if (!A->Provider || A->Provider != B->Provider)
+        return Result;
+
+      Result = BitPart(A->Provider, BitWidth);
+      for (unsigned i = 0; i < A->Provenance.size(); ++i) {
+        if (A->Provenance[i] != BitPart::Unset &&
+            B->Provenance[i] != BitPart::Unset &&
+            A->Provenance[i] != B->Provenance[i])
+          return Result = None;
+
+        if (A->Provenance[i] == BitPart::Unset)
+          Result->Provenance[i] = B->Provenance[i];
+        else
+          Result->Provenance[i] = A->Provenance[i];
+      }
+
+      return Result;
+    }
+
+    // If this is a logical shift by a constant, recurse then shift the result.
+    if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
+      unsigned BitShift =
+          cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
+      // Ensure the shift amount is defined.
+      if (BitShift > BitWidth)
+        return Result;
+
+      auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
+                                  MatchBitReversals, BPS);
+      if (!Res)
+        return Result;
+      Result = Res;
+
+      // Perform the "shift" on BitProvenance.
+      auto &P = Result->Provenance;
+      if (I->getOpcode() == Instruction::Shl) {
+        P.erase(std::prev(P.end(), BitShift), P.end());
+        P.insert(P.begin(), BitShift, BitPart::Unset);
+      } else {
+        P.erase(P.begin(), std::next(P.begin(), BitShift));
+        P.insert(P.end(), BitShift, BitPart::Unset);
+      }
+
+      return Result;
+    }
+
+    // If this is a logical 'and' with a mask that clears bits, recurse then
+    // unset the appropriate bits.
+    if (I->getOpcode() == Instruction::And &&
+        isa<ConstantInt>(I->getOperand(1))) {
+      APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1);
+      const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
+
+      // Check that the mask allows a multiple of 8 bits for a bswap, for an
+      // early exit.
+      unsigned NumMaskedBits = AndMask.countPopulation();
+      if (!MatchBitReversals && NumMaskedBits % 8 != 0)
+        return Result;
+      
+      auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
+                                  MatchBitReversals, BPS);
+      if (!Res)
+        return Result;
+      Result = Res;
+
+      for (unsigned i = 0; i < BitWidth; ++i, Bit <<= 1)
+        // If the AndMask is zero for this bit, clear the bit.
+        if ((AndMask & Bit) == 0)
+          Result->Provenance[i] = BitPart::Unset;
+
+      return Result;
+    }
+  }
+
+  // Okay, we got to something that isn't a shift, 'or' or 'and'.  This must be
+  // the input value to the bswap/bitreverse.
+  Result = BitPart(V, BitWidth);
+  for (unsigned i = 0; i < BitWidth; ++i)
+    Result->Provenance[i] = i;
+  return Result;
+}
+
+static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To,
+                                          unsigned BitWidth) {
+  if (From % 8 != To % 8)
+    return false;
+  // Convert from bit indices to byte indices and check for a byte reversal.
+  From >>= 3;
+  To >>= 3;
+  BitWidth >>= 3;
+  return From == BitWidth - To - 1;
+}
+
+static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To,
+                                               unsigned BitWidth) {
+  return From == BitWidth - To - 1;
+}
+
+/// Given an OR instruction, check to see if this is a bitreverse
+/// idiom. If so, insert the new intrinsic and return true.
+bool llvm::recognizeBitReverseOrBSwapIdiom(
+    Instruction *I, bool MatchBSwaps, bool MatchBitReversals,
+    SmallVectorImpl<Instruction *> &InsertedInsts) {
+  if (Operator::getOpcode(I) != Instruction::Or)
+    return false;
+  if (!MatchBSwaps && !MatchBitReversals)
+    return false;
+  IntegerType *ITy = dyn_cast<IntegerType>(I->getType());
+  if (!ITy || ITy->getBitWidth() > 128)
+    return false;   // Can't do vectors or integers > 128 bits.
+  unsigned BW = ITy->getBitWidth();
+
+  // Try to find all the pieces corresponding to the bswap.
+  std::map<Value *, Optional<BitPart>> BPS;
+  auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS);
+  if (!Res)
+    return false;
+  auto &BitProvenance = Res->Provenance;
+
+  // Now, is the bit permutation correct for a bswap or a bitreverse? We can
+  // only byteswap values with an even number of bytes.
+  bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;
+  for (unsigned i = 0; i < BW; ++i) {
+    OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
+    OKForBitReverse &=
+        bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
+  }
+
+  Intrinsic::ID Intrin;
+  if (OKForBSwap && MatchBSwaps)
+    Intrin = Intrinsic::bswap;
+  else if (OKForBitReverse && MatchBitReversals)
+    Intrin = Intrinsic::bitreverse;
+  else
+    return false;
+
+  Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, ITy);
+  InsertedInsts.push_back(CallInst::Create(F, Res->Provider, "rev", I));
+  return true;
+}