From: James Molloy Date: Fri, 11 Dec 2015 10:04:51 +0000 (+0000) Subject: [InstCombine] Make MatchBSwap also match bit reversals X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=e15387964886dcb3d919648218da958f9707d988;p=oota-llvm.git [InstCombine] Make MatchBSwap also match bit reversals MatchBSwap has most of the functionality to match bit reversals already. If we switch it from looking at bytes to individual bits and remove a few early exits, we can extend the main recursive function to match any sequence of ORs, ANDs and shifts that assemble a value from different parts of another, base value. Once we have this bit->bit mapping, we can very simply detect if it is appropriate for a bswap or bitreverse. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@255334 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 2bf6faa47b9..7f01d58c2ff 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1566,157 +1566,190 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return Changed ? &I : nullptr; } + /// Analyze the specified subexpression and see if it is capable of providing -/// pieces of a bswap. The subexpression provides pieces of a bswap if it is -/// proven that each of the non-zero bytes in the output of the expression came -/// from the corresponding "byte swapped" byte in some other value. -/// For example, if the current subexpression is "(shl i32 %X, 24)" then -/// we know that the expression deposits the low byte of %X into the high byte -/// of the bswap result and that all other bytes are zero. This expression is -/// accepted, the high byte of ByteValues is set to X to indicate a correct -/// match. +/// pieces of a bswap or bitreverse. 
The subexpression provides a potential +/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in +/// the output of the expression came from a corresponding bit in some other +/// value. This function is recursive, and the end result is a mapping of +/// (value, bitnumber) to bitnumber. It is the caller's responsibility to +/// validate that all `value`s are identical and that the bitnumber to bitnumber +/// mapping is correct for a bswap or bitreverse. +/// +/// For example, if the current subexpression is "(shl i32 %X, 24)" then we know +/// that the expression deposits the low byte of %X into the high byte of the +/// result and that all other bits are zero. This expression is accepted, +/// BitValues[24-31] are set to %X and BitProvenance[24-31] are set to [0-7]. /// /// This function returns true if the match was unsuccessful and false if so. /// On entry to the function the "OverallLeftShift" is a signed integer value -/// indicating the number of bytes that the subexpression is later shifted. For +/// indicating the number of bits that the subexpression is later shifted. For /// example, if the expression is later right shifted by 16 bits, the -/// OverallLeftShift value would be -2 on entry. This is used to specify which -/// byte of ByteValues is actually being set. +/// OverallLeftShift value would be -16 on entry. This is used to specify which +/// bits of BitValues are actually being set. /// -/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding -/// byte is masked to zero by a user. For example, in (X & 255), X will be -/// processed with a bytemask of 1. Because bytemask is 32-bits, this limits -/// this function to working on up to 32-byte (256 bit) values. ByteMask is -/// always in the local (OverallLeftShift) coordinate space. +/// Similarly, BitMask is a bitmask where a bit is clear if its corresponding +/// bit is masked to zero by a user. 
For example, in (X & 255), X will be +/// processed with a BitMask of 255. BitMask is always in the local +/// (OverallLeftShift) coordinate space. /// -static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask, - SmallVectorImpl &ByteValues) { +static bool CollectBitParts(Value *V, int OverallLeftShift, APInt BitMask, + SmallVectorImpl &BitValues, + SmallVectorImpl &BitProvenance) { if (Instruction *I = dyn_cast(V)) { // If this is an or instruction, it may be an inner node of the bswap. - if (I->getOpcode() == Instruction::Or) { - return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, - ByteValues) || - CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask, - ByteValues); - } - - // If this is a logical shift by a constant multiple of 8, recurse with - // OverallLeftShift and ByteMask adjusted. + if (I->getOpcode() == Instruction::Or) + return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask, + BitValues, BitProvenance) || + CollectBitParts(I->getOperand(1), OverallLeftShift, BitMask, + BitValues, BitProvenance); + + // If this is a logical shift by a constant, recurse with OverallLeftShift + // and BitMask adjusted. if (I->isLogicalShift() && isa(I->getOperand(1))) { unsigned ShAmt = - cast(I->getOperand(1))->getLimitedValue(~0U); - // Ensure the shift amount is defined and of a byte value. - if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size())) + cast(I->getOperand(1))->getLimitedValue(~0U); + // Ensure the shift amount is defined. 
+ if (ShAmt > BitValues.size()) return true; - unsigned ByteShift = ShAmt >> 3; + unsigned BitShift = ShAmt; if (I->getOpcode() == Instruction::Shl) { - // X << 2 -> collect(X, +2) - OverallLeftShift += ByteShift; - ByteMask >>= ByteShift; + // X << C -> collect(X, +C) + OverallLeftShift += BitShift; + BitMask = BitMask.lshr(BitShift); } else { - // X >>u 2 -> collect(X, -2) - OverallLeftShift -= ByteShift; - ByteMask <<= ByteShift; - ByteMask &= (~0U >> (32-ByteValues.size())); + // X >>u C -> collect(X, -C) + OverallLeftShift -= BitShift; + BitMask = BitMask.shl(BitShift); } - if (OverallLeftShift >= (int)ByteValues.size()) return true; - if (OverallLeftShift <= -(int)ByteValues.size()) return true; + if (OverallLeftShift >= (int)BitValues.size()) + return true; + if (OverallLeftShift <= -(int)BitValues.size()) + return true; - return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, - ByteValues); + return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask, + BitValues, BitProvenance); } - // If this is a logical 'and' with a mask that clears bytes, clear the - // corresponding bytes in ByteMask. + // If this is a logical 'and' with a mask that clears bits, clear the + // corresponding bits in BitMask. if (I->getOpcode() == Instruction::And && isa(I->getOperand(1))) { - // Scan every byte of the and mask, seeing if the byte is either 0 or 255. - unsigned NumBytes = ByteValues.size(); - APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255); + unsigned NumBits = BitValues.size(); + APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1); const APInt &AndMask = cast(I->getOperand(1))->getValue(); - for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) { - // If this byte is masked out by a later operation, we don't care what + for (unsigned i = 0; i != NumBits; ++i, Bit <<= 1) { + // If this bit is masked out by a later operation, we don't care what // the and mask is. 
- if ((ByteMask & (1 << i)) == 0) + if (BitMask[i] == 0) continue; - // If the AndMask is all zeros for this byte, clear the bit. - APInt MaskB = AndMask & Byte; + // If the AndMask is zero for this bit, clear the bit. + APInt MaskB = AndMask & Bit; if (MaskB == 0) { - ByteMask &= ~(1U << i); + BitMask.clearBit(i); continue; } - // If the AndMask is not all ones for this byte, it's not a bytezap. - if (MaskB != Byte) - return true; - - // Otherwise, this byte is kept. + // Otherwise, this bit is kept. } - return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, - ByteValues); + return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask, + BitValues, BitProvenance); } } // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be - // the input value to the bswap. Some observations: 1) if more than one byte - // is demanded from this input, then it could not be successfully assembled - // into a byteswap. At least one of the two bytes would not be aligned with - // their ultimate destination. - if (!isPowerOf2_32(ByteMask)) return true; - unsigned InputByteNo = countTrailingZeros(ByteMask); - - // 2) The input and ultimate destinations must line up: if byte 3 of an i32 - // is demanded, it needs to go into byte 0 of the result. This means that the - // byte needs to be shifted until it lands in the right byte bucket. The - // shift amount depends on the position: if the byte is coming from the high - // part of the value (e.g. byte 3) then it must be shifted right. If from the - // low part, it must be shifted left. - unsigned DestByteNo = InputByteNo + OverallLeftShift; - if (ByteValues.size()-1-DestByteNo != InputByteNo) + // the input value to the bswap/bitreverse. To be part of a bswap or + // bitreverse we must be demanding a contiguous range of bits from it. 
+ unsigned InputBitLen = BitMask.countPopulation(); + unsigned InputBitNo = BitMask.countTrailingZeros(); + if (BitMask.getBitWidth() - BitMask.countLeadingZeros() - InputBitNo != + InputBitLen) + // Not a contiguous set range of bits! return true; - // If the destination byte value is already defined, the values are or'd - // together, which isn't a bswap (unless it's an or of the same bits). - if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V) + // We know we're moving a contiguous range of bits from the input to the + // output. Record which bits in the output came from which bits in the input. + unsigned DestBitNo = InputBitNo + OverallLeftShift; + for (unsigned I = 0; I < InputBitLen; ++I) + BitProvenance[DestBitNo + I] = InputBitNo + I; + + // If the destination bit value is already defined, the values are or'd + // together, which isn't a bswap/bitreverse (unless it's an or of the same + // bits). + if (BitValues[DestBitNo] && BitValues[DestBitNo] != V) return true; - ByteValues[DestByteNo] = V; + for (unsigned I = 0; I < InputBitLen; ++I) + BitValues[DestBitNo + I] = V; + return false; } -/// Given an OR instruction, check to see if this is a bswap idiom. -/// If so, insert the new bswap intrinsic and return it. -Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { - IntegerType *ITy = dyn_cast(I.getType()); - if (!ITy || ITy->getBitWidth() % 16 || - // ByteMask only allows up to 32-byte values. - ITy->getBitWidth() > 32*8) - return nullptr; // Can only bswap pairs of bytes. Can't do vectors. +static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To, + unsigned BitWidth) { + if (From % 8 != To % 8) + return false; + // Convert from bit indices to byte indices and check for a byte reversal. + From >>= 3; + To >>= 3; + BitWidth >>= 3; + return From == BitWidth - To - 1; +} - /// ByteValues - For each byte of the result, we keep track of which value - /// defines each byte. 
- SmallVector ByteValues; - ByteValues.resize(ITy->getBitWidth()/8); +static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To, + unsigned BitWidth) { + return From == BitWidth - To - 1; +} +/// Given an OR instruction, check to see if this is a bswap or bitreverse +/// idiom. If so, insert the new intrinsic and return it. +Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { + IntegerType *ITy = dyn_cast(I.getType()); + if (!ITy) + return nullptr; // Can't do vectors. + unsigned BW = ITy->getBitWidth(); + + /// We keep track of which bit (BitProvenance) inside which value (BitValues) + /// defines each bit in the result. + SmallVector BitValues(BW, nullptr); + SmallVector BitProvenance(BW, -1); + // Try to find all the pieces corresponding to the bswap. - uint32_t ByteMask = ~0U >> (32-ByteValues.size()); - if (CollectBSwapParts(&I, 0, ByteMask, ByteValues)) + APInt BitMask = APInt::getAllOnesValue(BitValues.size()); + if (CollectBitParts(&I, 0, BitMask, BitValues, BitProvenance)) return nullptr; - // Check to see if all of the bytes come from the same value. - Value *V = ByteValues[0]; - if (!V) return nullptr; // Didn't find a byte? Must be zero. + // Check to see if all of the bits come from the same value. + Value *V = BitValues[0]; + if (!V) return nullptr; // Didn't find a bit? Must be zero. - // Check to make sure that all of the bytes come from the same value. - for (unsigned i = 1, e = ByteValues.size(); i != e; ++i) - if (ByteValues[i] != V) - return nullptr; + if (!std::all_of(BitValues.begin(), BitValues.end(), + [&](const Value *X) { return X == V; })) + return nullptr; + + // Now, is the bit permutation correct for a bswap or a bitreverse? We can + // only byteswap values with an even number of bytes. 
+ bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;; + for (unsigned i = 0, e = BitValues.size(); i != e; ++i) { + OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW); + OKForBitReverse &= + bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW); + } + + Intrinsic::ID Intrin; + if (OKForBSwap) + Intrin = Intrinsic::bswap; + else if (OKForBitReverse) + Intrin = Intrinsic::bitreverse; + else + return nullptr; + Module *M = I.getParent()->getParent()->getParent(); - Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy); + Function *F = Intrinsic::getDeclaration(M, Intrin, ITy); return CallInst::Create(F, V); } @@ -2265,7 +2298,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { match(Op1, m_And(m_Value(), m_Value())); if (OrOfOrs || OrOfShifts || OrOfAnds) - if (Instruction *BSwap = MatchBSwap(I)) + if (Instruction *BSwap = MatchBSwapOrBitReverse(I)) return BSwap; // (X^C)|Y -> (X|Y)^C iff Y&C == 0 diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 1bb3ad6c534..534f6700815 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -556,7 +556,7 @@ private: Value *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); - Instruction *MatchBSwap(BinaryOperator &I); + Instruction *MatchBSwapOrBitReverse(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); Instruction *SimplifyMemTransfer(MemIntrinsic *MI); Instruction *SimplifyMemSet(MemSetInst *MI); diff --git a/test/Transforms/InstCombine/bitreverse-recognize.ll b/test/Transforms/InstCombine/bitreverse-recognize.ll new file mode 100644 index 00000000000..fbd5cb6d139 --- /dev/null +++ b/test/Transforms/InstCombine/bitreverse-recognize.ll @@ -0,0 +1,114 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +target datalayout = 
"e-m:o-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-apple-macosx10.10.0" + +define zeroext i8 @f_u8(i8 zeroext %a) { +; CHECK-LABEL: @f_u8 +; CHECK-NEXT: %[[A:.*]] = call i8 @llvm.bitreverse.i8(i8 %a) +; CHECK-NEXT: ret i8 %[[A]] + %1 = shl i8 %a, 7 + %2 = shl i8 %a, 5 + %3 = and i8 %2, 64 + %4 = shl i8 %a, 3 + %5 = and i8 %4, 32 + %6 = shl i8 %a, 1 + %7 = and i8 %6, 16 + %8 = lshr i8 %a, 1 + %9 = and i8 %8, 8 + %10 = lshr i8 %a, 3 + %11 = and i8 %10, 4 + %12 = lshr i8 %a, 5 + %13 = and i8 %12, 2 + %14 = lshr i8 %a, 7 + %15 = or i8 %14, %1 + %16 = or i8 %15, %3 + %17 = or i8 %16, %5 + %18 = or i8 %17, %7 + %19 = or i8 %18, %9 + %20 = or i8 %19, %11 + %21 = or i8 %20, %13 + ret i8 %21 +} + +; The ANDs with 32 and 64 have been swapped here, so the sequence does not +; completely match a bitreverse. +define zeroext i8 @f_u8_fail(i8 zeroext %a) { +; CHECK-LABEL: @f_u8_fail +; CHECK-NOT: call +; CHECK: ret i8 + %1 = shl i8 %a, 7 + %2 = shl i8 %a, 5 + %3 = and i8 %2, 32 + %4 = shl i8 %a, 3 + %5 = and i8 %4, 64 + %6 = shl i8 %a, 1 + %7 = and i8 %6, 16 + %8 = lshr i8 %a, 1 + %9 = and i8 %8, 8 + %10 = lshr i8 %a, 3 + %11 = and i8 %10, 4 + %12 = lshr i8 %a, 5 + %13 = and i8 %12, 2 + %14 = lshr i8 %a, 7 + %15 = or i8 %14, %1 + %16 = or i8 %15, %3 + %17 = or i8 %16, %5 + %18 = or i8 %17, %7 + %19 = or i8 %18, %9 + %20 = or i8 %19, %11 + %21 = or i8 %20, %13 + ret i8 %21 +} + +define zeroext i16 @f_u16(i16 zeroext %a) { +; CHECK-LABEL: @f_u16 +; CHECK-NEXT: %[[A:.*]] = call i16 @llvm.bitreverse.i16(i16 %a) +; CHECK-NEXT: ret i16 %[[A]] + %1 = shl i16 %a, 15 + %2 = shl i16 %a, 13 + %3 = and i16 %2, 16384 + %4 = shl i16 %a, 11 + %5 = and i16 %4, 8192 + %6 = shl i16 %a, 9 + %7 = and i16 %6, 4096 + %8 = shl i16 %a, 7 + %9 = and i16 %8, 2048 + %10 = shl i16 %a, 5 + %11 = and i16 %10, 1024 + %12 = shl i16 %a, 3 + %13 = and i16 %12, 512 + %14 = shl i16 %a, 1 + %15 = and i16 %14, 256 + %16 = lshr i16 %a, 1 + %17 = and i16 %16, 128 + %18 = lshr i16 %a, 3 + %19 = and i16 
%18, 64 + %20 = lshr i16 %a, 5 + %21 = and i16 %20, 32 + %22 = lshr i16 %a, 7 + %23 = and i16 %22, 16 + %24 = lshr i16 %a, 9 + %25 = and i16 %24, 8 + %26 = lshr i16 %a, 11 + %27 = and i16 %26, 4 + %28 = lshr i16 %a, 13 + %29 = and i16 %28, 2 + %30 = lshr i16 %a, 15 + %31 = or i16 %30, %1 + %32 = or i16 %31, %3 + %33 = or i16 %32, %5 + %34 = or i16 %33, %7 + %35 = or i16 %34, %9 + %36 = or i16 %35, %11 + %37 = or i16 %36, %13 + %38 = or i16 %37, %15 + %39 = or i16 %38, %17 + %40 = or i16 %39, %19 + %41 = or i16 %40, %21 + %42 = or i16 %41, %23 + %43 = or i16 %42, %25 + %44 = or i16 %43, %27 + %45 = or i16 %44, %29 + ret i16 %45 +} \ No newline at end of file