X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2FCompression.cpp;h=d02a6b3e6f1bc741f972ce5c55af8f70c28c269b;hb=fb0a122815ac6ed1c746aefb4246616d83f35c10;hp=985c18dbaa9b1b7c2952669ee1486aaeaab087f8;hpb=1e8e9c94da711e14dad56f944b49cf968861a6b8;p=folly.git diff --git a/folly/io/Compression.cpp b/folly/io/Compression.cpp index 985c18db..d02a6b3e 100644 --- a/folly/io/Compression.cpp +++ b/folly/io/Compression.cpp @@ -27,12 +27,12 @@ #include #if FOLLY_HAVE_LIBSNAPPY -#include #include +#include #endif #if FOLLY_HAVE_LIBZ -#include +#include #endif #if FOLLY_HAVE_LIBLZMA @@ -40,6 +40,7 @@ #endif #if FOLLY_HAVE_LIBZSTD +#define ZSTD_STATIC_LINKING_ONLY #include #endif @@ -54,19 +55,24 @@ #include #include #include +#include #include #include -namespace folly { namespace io { +using folly::io::compression::detail::dataStartsWithLE; +using folly::io::compression::detail::prefixToStringLE; + +namespace folly { +namespace io { Codec::Codec(CodecType type) : type_(type) { } // Ensure consistent behavior in the nullptr case std::unique_ptr Codec::compress(const IOBuf* data) { - uint64_t len = data->computeChainDataLength(); - if (len == 0) { - return IOBuf::create(0); + if (data == nullptr) { + throw std::invalid_argument("Codec: data must not be nullptr"); } + uint64_t len = data->computeChainDataLength(); if (len > maxUncompressedLength()) { throw std::runtime_error("Codec: uncompressed length too large"); } @@ -76,9 +82,6 @@ std::unique_ptr Codec::compress(const IOBuf* data) { std::string Codec::compress(const StringPiece data) { const uint64_t len = data.size(); - if (len == 0) { - return ""; - } if (len > maxUncompressedLength()) { throw std::runtime_error("Codec: uncompressed length too large"); } @@ -89,6 +92,9 @@ std::string Codec::compress(const StringPiece data) { std::unique_ptr Codec::uncompress( const IOBuf* data, Optional uncompressedLength) { + if (data == nullptr) { + throw std::invalid_argument("Codec: data must not be nullptr"); + } if (!uncompressedLength) { if (needsUncompressedLength()) { throw std::invalid_argument("Codec: uncompressed length required"); @@ -176,6 +182,259 @@ std::string Codec::doUncompressString( return output; } +uint64_t Codec::maxCompressedLength(uint64_t uncompressedLength) const { + return doMaxCompressedLength(uncompressedLength); +} + +Optional Codec::getUncompressedLength( + const folly::IOBuf* data, + Optional uncompressedLength) const { + auto const compressedLength = data->computeChainDataLength(); + if (compressedLength == 0) { + if (uncompressedLength.value_or(0) != 0) { + throw std::runtime_error("Invalid uncompressed length"); + } + return 0; + } + return doGetUncompressedLength(data, uncompressedLength); +} + +Optional Codec::doGetUncompressedLength( + const folly::IOBuf*, + Optional uncompressedLength) const { + return uncompressedLength; +} + +bool StreamCodec::needsDataLength() const { + return doNeedsDataLength(); +} + +bool StreamCodec::doNeedsDataLength() const { + return false; +} + +void StreamCodec::assertStateIs(State expected) const { + if (state_ != expected) { + throw std::logic_error(folly::to( + "Codec: state is ", state_, "; expected state ", expected)); + } +} + +void StreamCodec::resetStream(Optional uncompressedLength) { + state_ = State::RESET; + uncompressedLength_ = uncompressedLength; + progressMade_ = true; + doResetStream(); +} + +bool StreamCodec::compressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (state_ == State::RESET && input.empty() && + flushOp == StreamCodec::FlushOp::END && + uncompressedLength().value_or(0) != 0) { + throw std::runtime_error("Codec: invalid uncompressed length"); + } + + if (!uncompressedLength() && needsDataLength()) { + throw std::runtime_error("Codec: uncompressed length required"); + } + if (state_ == State::RESET && !input.empty() && + uncompressedLength() == uint64_t(0)) { + throw std::runtime_error("Codec: invalid uncompressed length"); + } + // Handle input state transitions + switch (flushOp) { + case StreamCodec::FlushOp::NONE: + if (state_ == State::RESET) { + state_ = State::COMPRESS; + } + assertStateIs(State::COMPRESS); + break; + case StreamCodec::FlushOp::FLUSH: + if (state_ == State::RESET || state_ == State::COMPRESS) { + state_ = State::COMPRESS_FLUSH; + } + assertStateIs(State::COMPRESS_FLUSH); + break; + case StreamCodec::FlushOp::END: + if (state_ == State::RESET || state_ == State::COMPRESS) { + state_ = State::COMPRESS_END; + } + assertStateIs(State::COMPRESS_END); + break; + } + size_t const inputSize = input.size(); + size_t const outputSize = output.size(); + bool const done = doCompressStream(input, output, flushOp); + if (!done && inputSize == input.size() && outputSize == output.size()) { + if (!progressMade_) { + throw std::runtime_error("Codec: No forward progress made"); + } + // Throw an exception if there is no progress again next time + progressMade_ = false; + } else { + progressMade_ = true; + } + // Handle output state transitions + if (done) { + if (state_ == State::COMPRESS_FLUSH) { + state_ = State::COMPRESS; + } else if (state_ == State::COMPRESS_END) { + state_ = State::END; + } + // Check internal invariants + DCHECK(input.empty()); + DCHECK(flushOp != StreamCodec::FlushOp::NONE); + } + return done; +} + +bool StreamCodec::uncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (state_ == State::RESET && input.empty()) { + if (uncompressedLength().value_or(0) == 0) { + return true; + } + return false; + } + // Handle input state transitions + if (state_ == State::RESET) { + state_ = State::UNCOMPRESS; + } + assertStateIs(State::UNCOMPRESS); + size_t const inputSize = input.size(); + size_t const outputSize = output.size(); + bool const done = doUncompressStream(input, output, flushOp); + if (!done && inputSize == input.size() && outputSize == output.size()) { + if (!progressMade_) { + throw std::runtime_error("Codec: no forward progress made"); + } + // Throw an exception if there is no progress again next time + progressMade_ = false; + } else { + progressMade_ = true; + } + // Handle output state transitions + if (done) { + state_ = State::END; + } + return done; +} + +static std::unique_ptr addOutputBuffer( + MutableByteRange& output, + uint64_t size) { + DCHECK(output.empty()); + auto buffer = IOBuf::create(size); + buffer->append(buffer->capacity()); + output = {buffer->writableData(), buffer->length()}; + return buffer; +} + +std::unique_ptr StreamCodec::doCompress(IOBuf const* data) { + uint64_t const uncompressedLength = data->computeChainDataLength(); + resetStream(uncompressedLength); + uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength); + + auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB + auto constexpr kDefaultBufferLength = uint64_t(4) << 20; // 4 MB + + MutableByteRange output; + auto buffer = addOutputBuffer( + output, + maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen + : kDefaultBufferLength); + + // Compress the entire IOBuf chain into the IOBuf chain pointed to by buffer + IOBuf const* current = data; + ByteRange input{current->data(), current->length()}; + StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE; + bool done = false; + while (!done) { + while (input.empty() && current->next() != data) { + current = current->next(); + input = {current->data(), current->length()}; + } + if (current->next() == data) { + // This is the last input buffer so end the stream + flushOp = StreamCodec::FlushOp::END; + } + if (output.empty()) { + buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength)); + } + done = compressStream(input, output, flushOp); + if (done) { + DCHECK(input.empty()); + DCHECK(flushOp == StreamCodec::FlushOp::END); + DCHECK_EQ(current->next(), data); + } + } + buffer->prev()->trimEnd(output.size()); + return buffer; +} + +static uint64_t computeBufferLength( + uint64_t const compressedLength, + uint64_t const blockSize) { + uint64_t constexpr kMaxBufferLength = uint64_t(4) << 20; // 4 MiB + uint64_t const goodBufferSize = 4 * std::max(blockSize, compressedLength); + return std::min(goodBufferSize, kMaxBufferLength); +} + +std::unique_ptr StreamCodec::doUncompress( + IOBuf const* data, + Optional uncompressedLength) { + auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB + auto constexpr kBlockSize = uint64_t(128) << 10; + auto const defaultBufferLength = + computeBufferLength(data->computeChainDataLength(), kBlockSize); + + uncompressedLength = getUncompressedLength(data, uncompressedLength); + resetStream(uncompressedLength); + + MutableByteRange output; + auto buffer = addOutputBuffer( + output, + (uncompressedLength && *uncompressedLength <= kMaxSingleStepLength + ? *uncompressedLength + : defaultBufferLength)); + + // Uncompress the entire IOBuf chain into the IOBuf chain pointed to by buffer + IOBuf const* current = data; + ByteRange input{current->data(), current->length()}; + StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE; + bool done = false; + while (!done) { + while (input.empty() && current->next() != data) { + current = current->next(); + input = {current->data(), current->length()}; + } + if (current->next() == data) { + // Tell the uncompressor there is no more input (it may optimize) + flushOp = StreamCodec::FlushOp::END; + } + if (output.empty()) { + buffer->prependChain(addOutputBuffer(output, defaultBufferLength)); + } + done = uncompressStream(input, output, flushOp); + } + if (!input.empty()) { + throw std::runtime_error("Codec: Junk after end of data"); + } + + buffer->prev()->trimEnd(output.size()); + if (uncompressedLength && + *uncompressedLength != buffer->computeChainDataLength()) { + throw std::runtime_error("Codec: invalid uncompressed length"); + } + + return buffer; +} + namespace { /** @@ -187,6 +446,7 @@ class NoCompressionCodec final : public Codec { explicit NoCompressionCodec(int level, CodecType type); private: + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; std::unique_ptr doCompress(const IOBuf* data) override; std::unique_ptr doUncompress( const IOBuf* data, @@ -194,17 +454,17 @@ class NoCompressionCodec final : public Codec { }; std::unique_ptr NoCompressionCodec::create(int level, CodecType type) { - return make_unique(level, type); + return std::make_unique(level, type); } NoCompressionCodec::NoCompressionCodec(int level, CodecType type) - : Codec(type) { + : Codec(type) { DCHECK(type == CodecType::NO_COMPRESSION); switch (level) { - case COMPRESSION_LEVEL_DEFAULT: - case COMPRESSION_LEVEL_FASTEST: - case COMPRESSION_LEVEL_BEST: - level = 0; + case COMPRESSION_LEVEL_DEFAULT: + case COMPRESSION_LEVEL_FASTEST: + case COMPRESSION_LEVEL_BEST: + level = 0; } if (level != 0) { throw std::invalid_argument(to( @@ -212,6 +472,11 @@ NoCompressionCodec::NoCompressionCodec(int level, CodecType type) } } +uint64_t NoCompressionCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { + return uncompressedLength; +} + std::unique_ptr NoCompressionCodec::doCompress( const IOBuf* data) { return data->clone(); @@ -253,51 +518,10 @@ inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) { return val; } -} // namespace +} // namespace #endif // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA -namespace { -/** - * Reads sizeof(T) bytes, and returns false if not enough bytes are available. - * Returns true if the first n bytes are equal to prefix when interpreted as - * a little endian T. - */ -template -typename std::enable_if::value, bool>::type -dataStartsWithLE(const IOBuf* data, T prefix, uint64_t n = sizeof(T)) { - DCHECK_GT(n, 0); - DCHECK_LE(n, sizeof(T)); - T value; - Cursor cursor{data}; - if (!cursor.tryReadLE(value)) { - return false; - } - const T mask = n == sizeof(T) ? T(-1) : (T(1) << (8 * n)) - 1; - return prefix == (value & mask); -} - -template -typename std::enable_if::value, std::string>::type -prefixToStringLE(T prefix, uint64_t n = sizeof(T)) { - DCHECK_GT(n, 0); - DCHECK_LE(n, sizeof(T)); - prefix = Endian::little(prefix); - std::string result; - result.resize(n); - memcpy(&result[0], &prefix, n); - return result; -} - -static uint64_t computeBufferLength( - uint64_t const compressedLength, - uint64_t const blockSize) { - uint64_t constexpr kMaxBufferLength = uint64_t(4) << 20; // 4 MiB - uint64_t const goodBufferSize = 4 * std::max(blockSize, compressedLength); - return std::min(goodBufferSize, kMaxBufferLength); -} -} // namespace - #if FOLLY_HAVE_LIBLZ4 /** @@ -311,6 +535,7 @@ class LZ4Codec final : public Codec { private: bool doNeedsUncompressedLength() const override; uint64_t doMaxUncompressedLength() const override; + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; } @@ -323,20 +548,20 @@ class LZ4Codec final : public Codec { }; std::unique_ptr LZ4Codec::create(int level, CodecType type) { - return make_unique(level, type); + return std::make_unique(level, type); } LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) { DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - case COMPRESSION_LEVEL_DEFAULT: - level = 1; - break; - case COMPRESSION_LEVEL_BEST: - level = 2; - break; + case COMPRESSION_LEVEL_FASTEST: + case COMPRESSION_LEVEL_DEFAULT: + level = 1; + break; + case COMPRESSION_LEVEL_BEST: + level = 2; + break; } if (level < 1 || level > 2) { throw std::invalid_argument(to( @@ -360,6 +585,11 @@ uint64_t LZ4Codec::doMaxUncompressedLength() const { return LZ4_MAX_INPUT_SIZE; } +uint64_t LZ4Codec::doMaxCompressedLength(uint64_t uncompressedLength) const { + return LZ4_compressBound(uncompressedLength) + + (encodeSize() ? kMaxVarintLength64 : 0); +} + std::unique_ptr LZ4Codec::doCompress(const IOBuf* data) { IOBuf clone; if (data->isChained()) { @@ -368,8 +598,7 @@ std::unique_ptr LZ4Codec::doCompress(const IOBuf* data) { data = &clone; } - uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0; - auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length())); + auto out = IOBuf::create(maxCompressedLength(data->length())); if (encodeSize()) { encodeVarintToIOBuf(data->length(), out.get()); } @@ -445,13 +674,15 @@ class LZ4FrameCodec final : public Codec { public: static std::unique_ptr create(int level, CodecType type); explicit LZ4FrameCodec(int level, CodecType type); - ~LZ4FrameCodec(); + ~LZ4FrameCodec() override; std::vector validPrefixes() const override; bool canUncompress(const IOBuf* data, Optional uncompressedLength) const override; private: + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; + std::unique_ptr doCompress(const IOBuf* data) override; std::unique_ptr doUncompress( const IOBuf* data, @@ -468,7 +699,7 @@ class LZ4FrameCodec final : public Codec { /* static */ std::unique_ptr LZ4FrameCodec::create( int level, CodecType type) { - return make_unique(level, type); + return std::make_unique(level, type); } static constexpr uint32_t kLZ4FrameMagicLE = 0x184D2204; @@ -481,6 +712,14 @@ bool LZ4FrameCodec::canUncompress(const IOBuf* data, Optional) const { return dataStartsWithLE(data, kLZ4FrameMagicLE); } +uint64_t LZ4FrameCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { + LZ4F_preferences_t prefs{}; + prefs.compressionLevel = level_; + prefs.frameInfo.contentSize = uncompressedLength; + return LZ4F_compressFrameBound(uncompressedLength, &prefs); +} + static size_t lz4FrameThrowOnError(size_t code) { if (LZ4F_isError(code)) { throw std::runtime_error( @@ -535,7 +774,7 @@ std::unique_ptr LZ4FrameCodec::doCompress(const IOBuf* data) { prefs.compressionLevel = level_; prefs.frameInfo.contentSize = uncompressedLength; // Compress - auto buf = IOBuf::create(LZ4F_compressFrameBound(uncompressedLength, &prefs)); + auto buf = IOBuf::create(maxCompressedLength(uncompressedLength)); const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame( buf->writableTail(), buf->tailroom(), @@ -659,6 +898,7 @@ class SnappyCodec final : public Codec { private: uint64_t doMaxUncompressedLength() const override; + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; std::unique_ptr doCompress(const IOBuf* data) override; std::unique_ptr doUncompress( const IOBuf* data, @@ -666,16 +906,16 @@ class SnappyCodec final : public Codec { }; std::unique_ptr SnappyCodec::create(int level, CodecType type) { - return make_unique(level, type); + return std::make_unique(level, type); } SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) { DCHECK(type == CodecType::SNAPPY); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - case COMPRESSION_LEVEL_DEFAULT: - case COMPRESSION_LEVEL_BEST: - level = 1; + case COMPRESSION_LEVEL_FASTEST: + case COMPRESSION_LEVEL_DEFAULT: + case COMPRESSION_LEVEL_BEST: + level = 1; } if (level != 1) { throw std::invalid_argument(to( @@ -688,10 +928,13 @@ uint64_t SnappyCodec::doMaxUncompressedLength() const { return std::numeric_limits::max(); } +uint64_t SnappyCodec::doMaxCompressedLength(uint64_t uncompressedLength) const { + return snappy::MaxCompressedLength(uncompressedLength); +} + std::unique_ptr SnappyCodec::doCompress(const IOBuf* data) { IOBufSnappySource source(data); - auto out = - IOBuf::create(snappy::MaxCompressedLength(source.Available())); + auto out = IOBuf::create(maxCompressedLength(source.Available())); snappy::UncheckedByteArraySink sink(reinterpret_cast( out->writableTail())); @@ -734,599 +977,361 @@ std::unique_ptr SnappyCodec::doUncompress( #endif // FOLLY_HAVE_LIBSNAPPY -#if FOLLY_HAVE_LIBZ +#if FOLLY_HAVE_LIBLZMA + /** - * Zlib codec + * LZMA2 compression */ -class ZlibCodec final : public Codec { +class LZMA2StreamCodec final : public StreamCodec { public: - static std::unique_ptr create(int level, CodecType type); - explicit ZlibCodec(int level, CodecType type); + static std::unique_ptr createCodec(int level, CodecType type); + static std::unique_ptr createStream(int level, CodecType type); + explicit LZMA2StreamCodec(int level, CodecType type); + ~LZMA2StreamCodec() override; std::vector validPrefixes() const override; bool canUncompress(const IOBuf* data, Optional uncompressedLength) const override; private: - std::unique_ptr doCompress(const IOBuf* data) override; - std::unique_ptr doUncompress( - const IOBuf* data, - Optional uncompressedLength) override; + bool doNeedsDataLength() const override; + uint64_t doMaxUncompressedLength() const override; + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; + + bool encodeSize() const { + return type() == CodecType::LZMA2_VARINT_SIZE; + } + + void doResetStream() override; + bool doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + bool doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + + void resetCStream(); + void resetDStream(); + + bool decodeAndCheckVarint(ByteRange& input); + bool flushVarintBuffer(MutableByteRange& output); + void resetVarintBuffer(); - std::unique_ptr addOutputBuffer(z_stream* stream, uint32_t length); - bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength); + Optional cstream_{}; + Optional dstream_{}; + + std::array varintBuffer_; + ByteRange varintToEncode_; + size_t varintBufferPos_{0}; int level_; + bool needReset_{true}; + bool needDecodeSize_{false}; }; -static constexpr uint16_t kGZIPMagicLE = 0x8B1F; - -std::vector ZlibCodec::validPrefixes() const { - if (type() == CodecType::ZLIB) { - // Zlib streams start with a 2 byte header. - // - // 0 1 - // +---+---+ - // |CMF|FLG| - // +---+---+ - // - // We won't restrict the values of any sub-fields except as described below. - // - // The lowest 4 bits of CMF is the compression method (CM). - // CM == 0x8 is the deflate compression method, which is currently the only - // supported compression method, so any valid prefix must have CM == 0x8. - // - // The lowest 5 bits of FLG is FCHECK. - // FCHECK must be such that the two header bytes are a multiple of 31 when - // interpreted as a big endian 16-bit number. - std::vector result; - // 16 values for the first byte, 8 values for the second byte. - // There are also 4 combinations where both 0x00 and 0x1F work as FCHECK. - result.reserve(132); - // Select all values for the CMF byte that use the deflate algorithm 0x8. - for (uint32_t first = 0x0800; first <= 0xF800; first += 0x1000) { - // Select all values for the FLG, but leave FCHECK as 0 since it's fixed. - for (uint32_t second = 0x00; second <= 0xE0; second += 0x20) { - uint16_t prefix = first | second; - // Compute FCHECK. - prefix += 31 - (prefix % 31); - result.push_back(prefixToStringLE(Endian::big(prefix))); - // zlib won't produce this, but it is a valid prefix. - if ((prefix & 0x1F) == 31) { - prefix -= 31; - result.push_back(prefixToStringLE(Endian::big(prefix))); - } - } - } - return result; - } else { - // The gzip frame starts with 2 magic bytes. - return {prefixToStringLE(kGZIPMagicLE)}; +static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD; +static constexpr unsigned kLZMA2MagicBytes = 6; + +std::vector LZMA2StreamCodec::validPrefixes() const { + if (type() == CodecType::LZMA2_VARINT_SIZE) { + return {}; } + return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)}; } -bool ZlibCodec::canUncompress(const IOBuf* data, Optional) const { - if (type() == CodecType::ZLIB) { - uint16_t value; - Cursor cursor{data}; - if (!cursor.tryReadBE(value)) { - return false; - } - // zlib compressed if using deflate and is a multiple of 31. - return (value & 0x0F00) == 0x0800 && value % 31 == 0; - } else { - return dataStartsWithLE(data, kGZIPMagicLE); +bool LZMA2StreamCodec::doNeedsDataLength() const { + return encodeSize(); +} + +bool LZMA2StreamCodec::canUncompress(const IOBuf* data, Optional) + const { + if (type() == CodecType::LZMA2_VARINT_SIZE) { + return false; } + // Returns false for all inputs less than 8 bytes. + // This is okay, because no valid LZMA2 streams are less than 8 bytes. + return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes); +} + +std::unique_ptr LZMA2StreamCodec::createCodec( + int level, + CodecType type) { + return make_unique(level, type); } -std::unique_ptr ZlibCodec::create(int level, CodecType type) { - return make_unique(level, type); +std::unique_ptr LZMA2StreamCodec::createStream( + int level, + CodecType type) { + return make_unique(level, type); } -ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) { - DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP); +LZMA2StreamCodec::LZMA2StreamCodec(int level, CodecType type) + : StreamCodec(type) { + DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - level = 1; - break; - case COMPRESSION_LEVEL_DEFAULT: - level = Z_DEFAULT_COMPRESSION; - break; - case COMPRESSION_LEVEL_BEST: - level = 9; - break; - } - if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) { - throw std::invalid_argument(to( - "ZlibCodec: invalid level: ", level)); + case COMPRESSION_LEVEL_FASTEST: + level = 0; + break; + case COMPRESSION_LEVEL_DEFAULT: + level = LZMA_PRESET_DEFAULT; + break; + case COMPRESSION_LEVEL_BEST: + level = 9; + break; + } + if (level < 0 || level > 9) { + throw std::invalid_argument( + to("LZMA2Codec: invalid level: ", level)); } level_ = level; } -std::unique_ptr ZlibCodec::addOutputBuffer(z_stream* stream, - uint32_t length) { - CHECK_EQ(stream->avail_out, 0); +LZMA2StreamCodec::~LZMA2StreamCodec() { + if (cstream_) { + lzma_end(cstream_.get_pointer()); + cstream_.clear(); + } + if (dstream_) { + lzma_end(dstream_.get_pointer()); + dstream_.clear(); + } +} - auto buf = IOBuf::create(length); - buf->append(buf->capacity()); +uint64_t LZMA2StreamCodec::doMaxUncompressedLength() const { + // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)" + return uint64_t(1) << 63; +} - stream->next_out = buf->writableData(); - stream->avail_out = buf->length(); +uint64_t LZMA2StreamCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { + return lzma_stream_buffer_bound(uncompressedLength) + + (encodeSize() ? kMaxVarintLength64 : 0); +} - return buf; +void LZMA2StreamCodec::doResetStream() { + needReset_ = true; } -bool ZlibCodec::doInflate(z_stream* stream, - IOBuf* head, - uint32_t bufferLength) { - if (stream->avail_out == 0) { - head->prependChain(addOutputBuffer(stream, bufferLength)); +void LZMA2StreamCodec::resetCStream() { + if (!cstream_) { + cstream_.assign(LZMA_STREAM_INIT); } + lzma_ret const rc = + lzma_easy_encoder(cstream_.get_pointer(), level_, LZMA_CHECK_NONE); + if (rc != LZMA_OK) { + throw std::runtime_error(folly::to( + "LZMA2StreamCodec: lzma_easy_encoder error: ", rc)); + } +} - int rc = inflate(stream, Z_NO_FLUSH); +void LZMA2StreamCodec::resetDStream() { + if (!dstream_) { + dstream_.assign(LZMA_STREAM_INIT); + } + lzma_ret const rc = lzma_auto_decoder( + dstream_.get_pointer(), std::numeric_limits::max(), 0); + if (rc != LZMA_OK) { + throw std::runtime_error(folly::to( + "LZMA2StreamCodec: lzma_auto_decoder error: ", rc)); + } +} +static lzma_ret lzmaThrowOnError(lzma_ret const rc) { switch (rc) { - case Z_OK: - break; - case Z_STREAM_END: - return true; - case Z_BUF_ERROR: - case Z_NEED_DICT: - case Z_DATA_ERROR: - case Z_MEM_ERROR: - throw std::runtime_error(to( - "ZlibCodec: inflate error: ", rc, ": ", stream->msg)); - default: - CHECK(false) << rc << ": " << stream->msg; + case LZMA_OK: + case LZMA_STREAM_END: + case LZMA_BUF_ERROR: // not fatal: returned if no progress was made twice + return rc; + default: + throw std::runtime_error( + to("LZMA2StreamCodec: error: ", rc)); } - - return false; } -std::unique_ptr ZlibCodec::doCompress(const IOBuf* data) { - z_stream stream; - stream.zalloc = nullptr; - stream.zfree = nullptr; - stream.opaque = nullptr; - - // Using deflateInit2() to support gzip. "The windowBits parameter is the - // base two logarithm of the maximum window size (...) The default value is - // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer - // around the compressed data instead of a zlib wrapper. The gzip header - // will have no file name, no extra data, no comment, no modification time - // (set to zero), no header crc, and the operating system will be set to 255 - // (unknown)." - int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); - // All other parameters (method, memLevel, strategy) get default values from - // the zlib manual. - int rc = deflateInit2(&stream, - level_, - Z_DEFLATED, - windowBits, - /* memLevel */ 8, - Z_DEFAULT_STRATEGY); - if (rc != Z_OK) { - throw std::runtime_error(to( - "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg)); +static lzma_action lzmaTranslateFlush(StreamCodec::FlushOp flush) { + switch (flush) { + case StreamCodec::FlushOp::NONE: + return LZMA_RUN; + case StreamCodec::FlushOp::FLUSH: + return LZMA_SYNC_FLUSH; + case StreamCodec::FlushOp::END: + return LZMA_FINISH; + default: + throw std::invalid_argument("LZMA2StreamCodec: Invalid flush"); } +} - stream.next_in = stream.next_out = nullptr; - stream.avail_in = stream.avail_out = 0; - stream.total_in = stream.total_out = 0; +/** + * Flushes the varint buffer. + * Advances output by the number of bytes written. + * Returns true when flushing is complete. + */ +bool LZMA2StreamCodec::flushVarintBuffer(MutableByteRange& output) { + if (varintToEncode_.empty()) { + return true; + } + const size_t numBytesToCopy = std::min(varintToEncode_.size(), output.size()); + if (numBytesToCopy > 0) { + memcpy(output.data(), varintToEncode_.data(), numBytesToCopy); + } + varintToEncode_.advance(numBytesToCopy); + output.advance(numBytesToCopy); + return varintToEncode_.empty(); +} + +bool LZMA2StreamCodec::doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + resetCStream(); + if (encodeSize()) { + varintBufferPos_ = 0; + size_t const varintSize = + encodeVarint(*uncompressedLength(), varintBuffer_.data()); + varintToEncode_ = {varintBuffer_.data(), varintSize}; + } + needReset_ = false; + } - bool success = false; + if (!flushVarintBuffer(output)) { + return false; + } + cstream_->next_in = const_cast(input.data()); + cstream_->avail_in = input.size(); + cstream_->next_out = output.data(); + cstream_->avail_out = output.size(); SCOPE_EXIT { - rc = deflateEnd(&stream); - // If we're here because of an exception, it's okay if some data - // got dropped. - CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR)) - << rc << ": " << stream.msg; + input.uncheckedAdvance(input.size() - cstream_->avail_in); + output.uncheckedAdvance(output.size() - cstream_->avail_out); }; + lzma_ret const rc = lzmaThrowOnError( + lzma_code(cstream_.get_pointer(), lzmaTranslateFlush(flushOp))); + switch (flushOp) { + case StreamCodec::FlushOp::NONE: + return false; + case StreamCodec::FlushOp::FLUSH: + return cstream_->avail_in == 0 && cstream_->avail_out != 0; + case StreamCodec::FlushOp::END: + return rc == LZMA_STREAM_END; + default: + throw std::invalid_argument("LZMA2StreamCodec: invalid FlushOp"); + } +} - uint64_t uncompressedLength = data->computeChainDataLength(); - uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength); +/** + * Attempts to decode a varint from input. + * The function advances input by the number of bytes read. + * + * If there are too many bytes and the varint is not valid, throw a + * runtime_error. + * + * If the uncompressed length was provided and a decoded varint does not match + * the provided length, throw a runtime_error. + * + * Returns true if the varint was successfully decoded and matches the + * uncompressed length if provided, and false if more bytes are needed. + */ +bool LZMA2StreamCodec::decodeAndCheckVarint(ByteRange& input) { + if (input.empty()) { + return false; + } + size_t const numBytesToCopy = + std::min(kMaxVarintLength64 - varintBufferPos_, input.size()); + memcpy(varintBuffer_.data() + varintBufferPos_, input.data(), numBytesToCopy); - // Max 64MiB in one go - constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB - constexpr uint32_t defaultBufferLength = uint32_t(4) << 20; // 4MiB + size_t const rangeSize = varintBufferPos_ + numBytesToCopy; + ByteRange range{varintBuffer_.data(), rangeSize}; + auto const ret = tryDecodeVarint(range); - auto out = addOutputBuffer( - &stream, - (maxCompressedLength <= maxSingleStepLength ? - maxCompressedLength : - defaultBufferLength)); - - for (auto& range : *data) { - uint64_t remaining = range.size(); - uint64_t written = 0; - while (remaining) { - uint32_t step = (remaining > maxSingleStepLength ? - maxSingleStepLength : remaining); - stream.next_in = const_cast(range.data() + written); - stream.avail_in = step; - remaining -= step; - written += step; - - while (stream.avail_in != 0) { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = deflate(&stream, Z_NO_FLUSH); - - CHECK_EQ(rc, Z_OK) << stream.msg; - } + if (ret.hasValue()) { + size_t const varintSize = rangeSize - range.size(); + input.advance(varintSize - varintBufferPos_); + if (uncompressedLength() && *uncompressedLength() != ret.value()) { + throw std::runtime_error("LZMA2StreamCodec: invalid uncompressed length"); } + return true; + } else if (ret.error() == DecodeVarintError::TooManyBytes) { + throw std::runtime_error("LZMA2StreamCodec: invalid uncompressed length"); + } else { + // Too few bytes + input.advance(numBytesToCopy); + varintBufferPos_ += numBytesToCopy; + return false; } +} - do { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); +bool LZMA2StreamCodec::doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + resetDStream(); + needReset_ = false; + needDecodeSize_ = encodeSize(); + if (encodeSize()) { + // Reset buffer + varintBufferPos_ = 0; } + } - rc = deflate(&stream, Z_FINISH); - } while (rc == Z_OK); + if (needDecodeSize_) { + // Try decoding the varint. If the input does not contain the entire varint, + // buffer the input. If the varint can not be decoded, fail. + if (!decodeAndCheckVarint(input)) { + return false; + } + needDecodeSize_ = false; + } - CHECK_EQ(rc, Z_STREAM_END) << stream.msg; + dstream_->next_in = const_cast(input.data()); + dstream_->avail_in = input.size(); + dstream_->next_out = output.data(); + dstream_->avail_out = output.size(); + SCOPE_EXIT { + input.advance(input.size() - dstream_->avail_in); + output.advance(output.size() - dstream_->avail_out); + }; - out->prev()->trimEnd(stream.avail_out); - - success = true; // we survived - - return out; -} - -std::unique_ptr ZlibCodec::doUncompress( - const IOBuf* data, - Optional uncompressedLength) { - z_stream stream; - stream.zalloc = nullptr; - stream.zfree = nullptr; - stream.opaque = nullptr; - - // "The windowBits parameter is the base two logarithm of the maximum window - // size (...) The default value is 15 (...) add 16 to decode only the gzip - // format (the zlib format will return a Z_DATA_ERROR)." - int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); - int rc = inflateInit2(&stream, windowBits); - if (rc != Z_OK) { - throw std::runtime_error(to( - "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg)); - } - - stream.next_in = stream.next_out = nullptr; - stream.avail_in = stream.avail_out = 0; - stream.total_in = stream.total_out = 0; - - bool success = false; - - SCOPE_EXIT { - rc = inflateEnd(&stream); - // If we're here because of an exception, it's okay if some data - // got dropped. - CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR)) - << rc << ": " << stream.msg; - }; - - // Max 64MiB in one go - constexpr uint64_t maxSingleStepLength = uint64_t(64) << 20; // 64MiB - constexpr uint64_t kBlockSize = uint64_t(32) << 10; // 32 KiB - const uint64_t defaultBufferLength = - computeBufferLength(data->computeChainDataLength(), kBlockSize); - - auto out = addOutputBuffer( - &stream, - ((uncompressedLength && *uncompressedLength <= maxSingleStepLength) - ? *uncompressedLength - : defaultBufferLength)); - - bool streamEnd = false; - for (auto& range : *data) { - if (range.empty()) { - continue; - } - - stream.next_in = const_cast(range.data()); - stream.avail_in = range.size(); - - while (stream.avail_in != 0) { - if (streamEnd) { - throw std::runtime_error(to( - "ZlibCodec: junk after end of data")); - } - - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); - } - } - - while (!streamEnd) { - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); - } - - out->prev()->trimEnd(stream.avail_out); - - if (uncompressedLength && *uncompressedLength != stream.total_out) { - throw std::runtime_error( - to("ZlibCodec: invalid uncompressed length")); - } - - success = true; // we survived - - return out; -} - -#endif // FOLLY_HAVE_LIBZ - -#if FOLLY_HAVE_LIBLZMA - -/** - * LZMA2 compression - */ -class LZMA2Codec final : public Codec { - public: - static std::unique_ptr create(int level, CodecType type); - explicit LZMA2Codec(int level, CodecType type); - - std::vector validPrefixes() const override; - bool canUncompress(const IOBuf* data, Optional uncompressedLength) - const override; - - private: - bool doNeedsUncompressedLength() const override; - uint64_t doMaxUncompressedLength() const override; - - bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; } - - std::unique_ptr doCompress(const IOBuf* data) override; - std::unique_ptr doUncompress( - const IOBuf* data, - Optional uncompressedLength) override; - - std::unique_ptr addOutputBuffer(lzma_stream* stream, size_t length); - bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength); - - int level_; -}; - -static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD; -static constexpr unsigned kLZMA2MagicBytes = 6; - -std::vector LZMA2Codec::validPrefixes() const { - if (type() == CodecType::LZMA2_VARINT_SIZE) { - return {}; - } - return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)}; -} - -bool LZMA2Codec::canUncompress(const IOBuf* data, Optional) const { - if (type() == CodecType::LZMA2_VARINT_SIZE) { - return false; - } - // Returns false for all inputs less than 8 bytes. - // This is okay, because no valid LZMA2 streams are less than 8 bytes. - return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes); -} - -std::unique_ptr LZMA2Codec::create(int level, CodecType type) { - return make_unique(level, type); -} - -LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) { - DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE); - switch (level) { - case COMPRESSION_LEVEL_FASTEST: - level = 0; - break; - case COMPRESSION_LEVEL_DEFAULT: - level = LZMA_PRESET_DEFAULT; - break; - case COMPRESSION_LEVEL_BEST: - level = 9; - break; - } - if (level < 0 || level > 9) { - throw std::invalid_argument(to( - "LZMA2Codec: invalid level: ", level)); - } - level_ = level; -} - -bool LZMA2Codec::doNeedsUncompressedLength() const { - return false; -} - -uint64_t LZMA2Codec::doMaxUncompressedLength() const { - // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)" - return uint64_t(1) << 63; -} - -std::unique_ptr LZMA2Codec::addOutputBuffer( - lzma_stream* stream, - size_t length) { - - CHECK_EQ(stream->avail_out, 0); - - auto buf = IOBuf::create(length); - buf->append(buf->capacity()); - - stream->next_out = buf->writableData(); - stream->avail_out = buf->length(); - - return buf; -} - -std::unique_ptr LZMA2Codec::doCompress(const IOBuf* data) { lzma_ret rc; - lzma_stream stream = LZMA_STREAM_INIT; - - rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE); - if (rc != LZMA_OK) { - throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_easy_encoder error: ", rc)); - } - - SCOPE_EXIT { lzma_end(&stream); }; - - uint64_t uncompressedLength = data->computeChainDataLength(); - uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength); - - // Max 64MiB in one go - constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB - constexpr uint32_t defaultBufferLength = uint32_t(4) << 20; // 4MiB - - auto out = addOutputBuffer( - &stream, - (maxCompressedLength <= maxSingleStepLength ? - maxCompressedLength : - defaultBufferLength)); - - if (encodeSize()) { - auto size = IOBuf::createCombined(kMaxVarintLength64); - encodeVarintToIOBuf(uncompressedLength, size.get()); - size->appendChain(std::move(out)); - out = std::move(size); - } - - for (auto& range : *data) { - if (range.empty()) { - continue; - } - - stream.next_in = const_cast(range.data()); - stream.avail_in = range.size(); - - while (stream.avail_in != 0) { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = lzma_code(&stream, LZMA_RUN); - - if (rc != LZMA_OK) { - throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_code error: ", rc)); - } - } - } - - do { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = lzma_code(&stream, LZMA_FINISH); - } while (rc == LZMA_OK); - - if (rc != LZMA_STREAM_END) { - throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_code ended with error: ", rc)); + switch (flushOp) { + case StreamCodec::FlushOp::NONE: + case StreamCodec::FlushOp::FLUSH: + rc = lzmaThrowOnError(lzma_code(dstream_.get_pointer(), LZMA_RUN)); + break; + case StreamCodec::FlushOp::END: + rc = lzmaThrowOnError(lzma_code(dstream_.get_pointer(), LZMA_FINISH)); + break; + default: + throw std::invalid_argument("LZMA2StreamCodec: invalid flush"); } - - out->prev()->trimEnd(stream.avail_out); - - return out; + return rc == LZMA_STREAM_END; } +#endif // FOLLY_HAVE_LIBLZMA -bool LZMA2Codec::doInflate(lzma_stream* stream, - IOBuf* head, - size_t bufferLength) { - if (stream->avail_out == 0) { - head->prependChain(addOutputBuffer(stream, bufferLength)); - } - - lzma_ret rc = lzma_code(stream, LZMA_RUN); - - switch (rc) { - case LZMA_OK: - break; - case LZMA_STREAM_END: - return true; - default: - throw std::runtime_error(to( - "LZMA2Codec: lzma_code error: ", rc)); - } +#ifdef FOLLY_HAVE_LIBZSTD - return false; +namespace { +void zstdFreeCStream(ZSTD_CStream* zcs) { + ZSTD_freeCStream(zcs); } -std::unique_ptr LZMA2Codec::doUncompress( - const IOBuf* data, - Optional uncompressedLength) { - lzma_ret rc; - lzma_stream stream = LZMA_STREAM_INIT; - - rc = lzma_auto_decoder(&stream, std::numeric_limits::max(), 0); - if (rc != LZMA_OK) { - throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_auto_decoder error: ", rc)); - } - - SCOPE_EXIT { lzma_end(&stream); }; - - // Max 64MiB in one go - constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB - constexpr uint32_t defaultBufferLength = uint32_t(256) << 10; // 256 KiB - - folly::io::Cursor cursor(data); - if (encodeSize()) { - const uint64_t actualUncompressedLength = decodeVarintFromCursor(cursor); - if (uncompressedLength && *uncompressedLength != actualUncompressedLength) { - throw std::runtime_error("LZMA2Codec: invalid uncompressed length"); - } - uncompressedLength = actualUncompressedLength; - } - - auto out = addOutputBuffer( - &stream, - ((uncompressedLength && *uncompressedLength <= maxSingleStepLength) - ? *uncompressedLength - : defaultBufferLength)); - - bool streamEnd = false; - auto buf = cursor.peekBytes(); - while (!buf.empty()) { - stream.next_in = const_cast(buf.data()); - stream.avail_in = buf.size(); - - while (stream.avail_in != 0) { - if (streamEnd) { - throw std::runtime_error(to( - "LZMA2Codec: junk after end of data")); - } - - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); - } - - cursor.skip(buf.size()); - buf = cursor.peekBytes(); - } - - while (!streamEnd) { - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); - } - - out->prev()->trimEnd(stream.avail_out); - - if (uncompressedLength && *uncompressedLength != stream.total_out) { - throw std::runtime_error( - to("LZMA2Codec: invalid uncompressed length")); - } - - return out; +void zstdFreeDStream(ZSTD_DStream* zds) { + ZSTD_freeDStream(zds); +} } - -#endif // FOLLY_HAVE_LIBLZMA - -#ifdef FOLLY_HAVE_LIBZSTD /** * ZSTD compression */ -class ZSTDCodec final : public Codec { +class ZSTDStreamCodec final : public StreamCodec { public: - static std::unique_ptr create(int level, CodecType); - explicit ZSTDCodec(int level, CodecType type); + static std::unique_ptr createCodec(int level, CodecType); + static std::unique_ptr createStream(int level, CodecType); + explicit ZSTDStreamCodec(int level, CodecType type); std::vector validPrefixes() const override; bool canUncompress(const IOBuf* data, Optional uncompressedLength) @@ -1334,29 +1339,62 @@ class ZSTDCodec final : public Codec { private: bool doNeedsUncompressedLength() const override; - std::unique_ptr doCompress(const IOBuf* data) override; - std::unique_ptr doUncompress( - const IOBuf* data, - Optional uncompressedLength) override; + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; + Optional doGetUncompressedLength( + IOBuf const* data, + Optional uncompressedLength) const override; + + void doResetStream() override; + bool doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + bool doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + + void resetCStream(); + void resetDStream(); + + bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const; + bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const; int level_; + bool needReset_{true}; + std::unique_ptr< + ZSTD_CStream, + folly::static_function_deleter> + cstream_{nullptr}; + std::unique_ptr< + ZSTD_DStream, + folly::static_function_deleter> + dstream_{nullptr}; }; static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528; -std::vector ZSTDCodec::validPrefixes() const { +std::vector ZSTDStreamCodec::validPrefixes() const { return {prefixToStringLE(kZSTDMagicLE)}; } -bool ZSTDCodec::canUncompress(const IOBuf* data, Optional) const { +bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional) + const { return dataStartsWithLE(data, kZSTDMagicLE); } -std::unique_ptr ZSTDCodec::create(int level, CodecType type) { - return make_unique(level, type); +std::unique_ptr ZSTDStreamCodec::createCodec(int level, CodecType type) { + return make_unique(level, type); +} + +std::unique_ptr ZSTDStreamCodec::createStream( + int level, + CodecType type) { + return make_unique(level, type); } -ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { +ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type) + : StreamCodec(type) { DCHECK(type == CodecType::ZSTD); switch (level) { case COMPRESSION_LEVEL_FASTEST: @@ -1376,10 +1414,15 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { level_ = level; } -bool ZSTDCodec::doNeedsUncompressedLength() const { +bool ZSTDStreamCodec::doNeedsUncompressedLength() const { return false; } +uint64_t ZSTDStreamCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { + return ZSTD_compressBound(uncompressedLength); +} + void zstdThrowIfError(size_t rc) { if (!ZSTD_isError(rc)) { return; @@ -1388,162 +1431,160 @@ void zstdThrowIfError(size_t rc) { to("ZSTD returned an error: ", ZSTD_getErrorName(rc))); } -std::unique_ptr ZSTDCodec::doCompress(const IOBuf* data) { - // Support earlier versions of the codec (working with a single IOBuf, - // and using ZSTD_decompress which requires ZSTD frame to contain size, - // which isn't populated by streaming API). - if (!data->isChained()) { - auto out = IOBuf::createCombined(ZSTD_compressBound(data->length())); - const auto rc = ZSTD_compress( - out->writableData(), - out->capacity(), - data->data(), - data->length(), - level_); - zstdThrowIfError(rc); - out->append(rc); - return out; - } - - auto zcs = ZSTD_createCStream(); - SCOPE_EXIT { - ZSTD_freeCStream(zcs); - }; - - auto rc = ZSTD_initCStream(zcs, level_); - zstdThrowIfError(rc); - - Cursor cursor(data); - auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength())); - - ZSTD_outBuffer out; - out.dst = result->writableTail(); - out.size = result->capacity(); - out.pos = 0; - - for (auto buffer = cursor.peekBytes(); !buffer.empty();) { - ZSTD_inBuffer in; - in.src = buffer.data(); - in.size = buffer.size(); - for (in.pos = 0; in.pos != in.size;) { - rc = ZSTD_compressStream(zcs, &out, &in); - zstdThrowIfError(rc); +Optional ZSTDStreamCodec::doGetUncompressedLength( + IOBuf const* data, + Optional uncompressedLength) const { + // Read decompressed size from frame if available in first IOBuf. + auto const decompressedSize = + ZSTD_getDecompressedSize(data->data(), data->length()); + if (decompressedSize != 0) { + if (uncompressedLength && *uncompressedLength != decompressedSize) { + throw std::runtime_error("ZSTD: invalid uncompressed length"); } - cursor.skip(in.size); - buffer = cursor.peekBytes(); + uncompressedLength = decompressedSize; } + return uncompressedLength; +} - rc = ZSTD_endStream(zcs, &out); - zstdThrowIfError(rc); - CHECK_EQ(rc, 0); - - result->append(out.pos); - return result; +void ZSTDStreamCodec::doResetStream() { + needReset_ = true; } -static std::unique_ptr zstdUncompressBuffer( - const IOBuf* data, - Optional uncompressedLength) { - // Check preconditions - DCHECK(!data->isChained()); - DCHECK(uncompressedLength.hasValue()); - - auto uncompressed = IOBuf::create(*uncompressedLength); - const auto decompressedSize = ZSTD_decompress( - uncompressed->writableTail(), - uncompressed->tailroom(), - data->data(), - data->length()); - zstdThrowIfError(decompressedSize); - if (decompressedSize != uncompressedLength) { - throw std::runtime_error("ZSTD: invalid uncompressed length"); +bool ZSTDStreamCodec::tryBlockCompress( + ByteRange& input, + MutableByteRange& output) const { + DCHECK(needReset_); + // We need to know that we have enough output space to use block compression + if (output.size() < ZSTD_compressBound(input.size())) { + return false; } - uncompressed->append(decompressedSize); - return uncompressed; + size_t const length = ZSTD_compress( + output.data(), output.size(), input.data(), input.size(), level_); + zstdThrowIfError(length); + input.uncheckedAdvance(input.size()); + output.uncheckedAdvance(length); + return true; } -static std::unique_ptr zstdUncompressStream( - const IOBuf* data, - Optional uncompressedLength) { - auto zds = ZSTD_createDStream(); +void ZSTDStreamCodec::resetCStream() { + if (!cstream_) { + cstream_.reset(ZSTD_createCStream()); + if (!cstream_) { + throw std::bad_alloc{}; + } + } + // Advanced API usage works for all supported versions of zstd. + // Required to set contentSizeFlag. + auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0); + params.fParams.contentSizeFlag = uncompressedLength().hasValue(); + zstdThrowIfError(ZSTD_initCStream_advanced( + cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0))); +} + +bool ZSTDStreamCodec::doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + // If we are given all the input in one chunk try to use block compression + if (flushOp == StreamCodec::FlushOp::END && + tryBlockCompress(input, output)) { + return true; + } + resetCStream(); + needReset_ = false; + } + ZSTD_inBuffer in = {input.data(), input.size(), 0}; + ZSTD_outBuffer out = {output.data(), output.size(), 0}; SCOPE_EXIT { - ZSTD_freeDStream(zds); + input.uncheckedAdvance(in.pos); + output.uncheckedAdvance(out.pos); }; - - auto rc = ZSTD_initDStream(zds); - zstdThrowIfError(rc); - - ZSTD_outBuffer out{}; - ZSTD_inBuffer in{}; - - auto outputSize = uncompressedLength.value_or(ZSTD_DStreamOutSize()); - - IOBufQueue queue(IOBufQueue::cacheChainLength()); - - Cursor cursor(data); - for (rc = 0;;) { - if (in.pos == in.size) { - auto buffer = cursor.peekBytes(); - in.src = buffer.data(); - in.size = buffer.size(); - in.pos = 0; - cursor.skip(in.size); - if (rc > 1 && in.size == 0) { - throw std::runtime_error(to("ZSTD: incomplete input")); - } + if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) { + zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in)); + } + if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) { + size_t rc; + switch (flushOp) { + case StreamCodec::FlushOp::FLUSH: + rc = ZSTD_flushStream(cstream_.get(), &out); + break; + case StreamCodec::FlushOp::END: + rc = ZSTD_endStream(cstream_.get(), &out); + break; + default: + throw std::invalid_argument("ZSTD: invalid FlushOp"); } - if (out.pos == out.size) { - if (out.pos != 0) { - queue.postallocate(out.pos); - } - auto buffer = queue.preallocate(outputSize, outputSize); - out.dst = buffer.first; - out.size = buffer.second; - out.pos = 0; - outputSize = ZSTD_DStreamOutSize(); - } - rc = ZSTD_decompressStream(zds, &out, &in); zstdThrowIfError(rc); if (rc == 0) { - break; + return true; } } - if (out.pos != 0) { - queue.postallocate(out.pos); - } - if (in.pos != in.size || !cursor.isAtEnd()) { - throw std::runtime_error("ZSTD: junk after end of data"); - } - if (uncompressedLength && queue.chainLength() != *uncompressedLength) { - throw std::runtime_error("ZSTD: invalid uncompressed length"); - } + return false; +} - return queue.move(); +bool ZSTDStreamCodec::tryBlockUncompress( + ByteRange& input, + MutableByteRange& output) const { + DCHECK(needReset_); +#if ZSTD_VERSION_NUMBER < 10104 + // We require ZSTD_findFrameCompressedSize() to perform this optimization. + return false; +#else + // We need to know the uncompressed length and have enough output space. + if (!uncompressedLength() || output.size() < *uncompressedLength()) { + return false; + } + size_t const compressedLength = + ZSTD_findFrameCompressedSize(input.data(), input.size()); + zstdThrowIfError(compressedLength); + size_t const length = ZSTD_decompress( + output.data(), *uncompressedLength(), input.data(), compressedLength); + zstdThrowIfError(length); + if (length != *uncompressedLength()) { + throw std::runtime_error("ZSTDStreamCodec: Incorrect uncompressed length"); + } + input.uncheckedAdvance(compressedLength); + output.uncheckedAdvance(length); + return true; +#endif } -std::unique_ptr ZSTDCodec::doUncompress( - const IOBuf* data, - Optional uncompressedLength) { - { - // Read decompressed size from frame if available in first IOBuf. - const auto decompressedSize = - ZSTD_getDecompressedSize(data->data(), data->length()); - if (decompressedSize != 0) { - if (uncompressedLength && *uncompressedLength != decompressedSize) { - throw std::runtime_error("ZSTD: invalid uncompressed length"); - } - uncompressedLength = decompressedSize; +void ZSTDStreamCodec::resetDStream() { + if (!dstream_) { + dstream_.reset(ZSTD_createDStream()); + if (!dstream_) { + throw std::bad_alloc{}; } } - // Faster to decompress using ZSTD_decompress() if we can. - if (uncompressedLength && !data->isChained()) { - return zstdUncompressBuffer(data, uncompressedLength); + zstdThrowIfError(ZSTD_initDStream(dstream_.get())); +} + +bool ZSTDStreamCodec::doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + // If we are given all the input in one chunk try to use block uncompression + if (flushOp == StreamCodec::FlushOp::END && + tryBlockUncompress(input, output)) { + return true; + } + resetDStream(); + needReset_ = false; } - // Fall back to slower streaming decompression. - return zstdUncompressStream(data, uncompressedLength); + ZSTD_inBuffer in = {input.data(), input.size(), 0}; + ZSTD_outBuffer out = {output.data(), output.size(), 0}; + SCOPE_EXIT { + input.uncheckedAdvance(in.pos); + output.uncheckedAdvance(out.pos); + }; + size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in); + zstdThrowIfError(rc); + return rc == 0; } -#endif // FOLLY_HAVE_LIBZSTD +#endif // FOLLY_HAVE_LIBZSTD #if FOLLY_HAVE_LIBBZ2 @@ -1557,6 +1598,7 @@ class Bzip2Codec final : public Codec { const override; private: + uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; std::unique_ptr doCompress(IOBuf const* data) override; std::unique_ptr doUncompress( IOBuf const* data, @@ -1568,7 +1610,7 @@ class Bzip2Codec final : public Codec { /* static */ std::unique_ptr Bzip2Codec::create( int level, CodecType type) { - return make_unique(level, type); + return std::make_unique(level, type); } Bzip2Codec::Bzip2Codec(int level, CodecType type) : Codec(type) { @@ -1602,6 +1644,14 @@ bool Bzip2Codec::canUncompress(IOBuf const* data, Optional) const { return dataStartsWithLE(data, kBzip2MagicLE, kBzip2MagicBytes); } +uint64_t Bzip2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const { + // http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html#bzbufftobuffcompress + // To guarantee that the compressed data will fit in its buffer, allocate an + // output buffer of size 1% larger than the uncompressed data, plus six + // hundred extra bytes. + return uncompressedLength + uncompressedLength / 100 + 600; +} + static bz_stream createBzStream() { bz_stream stream; stream.bzalloc = nullptr; @@ -1626,14 +1676,6 @@ static int bzCheck(int const rc) { } } -static uint64_t bzCompressBound(uint64_t const uncompressedLength) { - // http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html#bzbufftobuffcompress - // To guarantee that the compressed data will fit in its buffer, allocate an - // output buffer of size 1% larger than the uncompressed data, plus six - // hundred extra bytes. - return uncompressedLength + uncompressedLength / 100 + 600; -} - static std::unique_ptr addOutputBuffer( bz_stream* stream, uint64_t const bufferLength) { @@ -1657,14 +1699,14 @@ std::unique_ptr Bzip2Codec::doCompress(IOBuf const* data) { }; uint64_t const uncompressedLength = data->computeChainDataLength(); - uint64_t const maxCompressedLength = bzCompressBound(uncompressedLength); + uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength); uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB uint64_t constexpr kDefaultBufferLength = uint64_t(4) << 20; auto out = addOutputBuffer( &stream, - maxCompressedLength <= kMaxSingleStepLength ? maxCompressedLength - : kDefaultBufferLength); + maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen + : kDefaultBufferLength); for (auto range : *data) { while (!range.empty()) { @@ -1732,8 +1774,11 @@ std::unique_ptr Bzip2Codec::doUncompress( if (stream.avail_out == 0) { out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength)); } - + size_t const outputSize = stream.avail_out; rc = bzCheck(BZ2_bzDecompress(&stream)); + if (outputSize == stream.avail_out) { + throw std::runtime_error("Bzip2Codec: Truncated input"); + } } out->prev()->trimEnd(stream.avail_out); @@ -1749,6 +1794,24 @@ std::unique_ptr Bzip2Codec::doUncompress( #endif // FOLLY_HAVE_LIBBZ2 +#if FOLLY_HAVE_LIBZ + +zlib::Options getZlibOptions(CodecType type) { + DCHECK(type == CodecType::GZIP || type == CodecType::ZLIB); + return type == CodecType::GZIP ? zlib::defaultGzipOptions() + : zlib::defaultZlibOptions(); +} + +std::unique_ptr getZlibCodec(int level, CodecType type) { + return zlib::getCodec(getZlibOptions(type), level); +} + +std::unique_ptr getZlibStreamCodec(int level, CodecType type) { + return zlib::getStreamCodec(getZlibOptions(type), level); +} + +#endif // FOLLY_HAVE_LIBZ + /** * Automatic decompression */ @@ -1766,6 +1829,10 @@ class AutomaticCodec final : public Codec { bool doNeedsUncompressedLength() const override; uint64_t doMaxUncompressedLength() const override; + uint64_t doMaxCompressedLength(uint64_t) const override { + throw std::runtime_error( + "AutomaticCodec error: maxCompressedLength() not supported."); + } std::unique_ptr doCompress(const IOBuf*) override { throw std::runtime_error("AutomaticCodec error: compress() not supported."); } @@ -1817,7 +1884,7 @@ void AutomaticCodec::addCodecIfSupported(CodecType type) { /* static */ std::unique_ptr AutomaticCodec::create( std::vector> customCodecs) { - return make_unique(std::move(customCodecs)); + return std::make_unique(std::move(customCodecs)); } AutomaticCodec::AutomaticCodec(std::vector> customCodecs) @@ -1909,93 +1976,112 @@ std::unique_ptr AutomaticCodec::doUncompress( throw std::runtime_error("AutomaticCodec error: Unknown compressed data"); } -} // namespace +using CodecFactory = std::unique_ptr (*)(int, CodecType); +using StreamCodecFactory = std::unique_ptr (*)(int, CodecType); +struct Factory { + CodecFactory codec; + StreamCodecFactory stream; +}; -typedef std::unique_ptr (*CodecFactory)(int, CodecType); -static constexpr CodecFactory +constexpr Factory codecFactories[static_cast(CodecType::NUM_CODEC_TYPES)] = { - nullptr, // USER_DEFINED - NoCompressionCodec::create, + {}, // USER_DEFINED + {NoCompressionCodec::create, nullptr}, #if FOLLY_HAVE_LIBLZ4 - LZ4Codec::create, + {LZ4Codec::create, nullptr}, #else - nullptr, + {}, #endif #if FOLLY_HAVE_LIBSNAPPY - SnappyCodec::create, + {SnappyCodec::create, nullptr}, #else - nullptr, + {}, #endif #if FOLLY_HAVE_LIBZ - ZlibCodec::create, + {getZlibCodec, getZlibStreamCodec}, #else - nullptr, + {}, #endif #if FOLLY_HAVE_LIBLZ4 - LZ4Codec::create, + {LZ4Codec::create, nullptr}, #else - nullptr, + {}, #endif #if FOLLY_HAVE_LIBLZMA - LZMA2Codec::create, - LZMA2Codec::create, + {LZMA2StreamCodec::createCodec, LZMA2StreamCodec::createStream}, + {LZMA2StreamCodec::createCodec, LZMA2StreamCodec::createStream}, #else - nullptr, - nullptr, + {}, + {}, #endif #if FOLLY_HAVE_LIBZSTD - ZSTDCodec::create, + {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream}, #else - nullptr, + {}, #endif #if FOLLY_HAVE_LIBZ - ZlibCodec::create, + {getZlibCodec, getZlibStreamCodec}, #else - nullptr, + {}, #endif #if (FOLLY_HAVE_LIBLZ4 && LZ4_VERSION_NUMBER >= 10301) - LZ4FrameCodec::create, + {LZ4FrameCodec::create, nullptr}, #else - nullptr, + {}, #endif #if FOLLY_HAVE_LIBBZ2 - Bzip2Codec::create, + {Bzip2Codec::create, nullptr}, #else - nullptr + {}, #endif }; -bool hasCodec(CodecType type) { - size_t idx = static_cast(type); +Factory const& getFactory(CodecType type) { + size_t const idx = static_cast(type); if (idx >= static_cast(CodecType::NUM_CODEC_TYPES)) { throw std::invalid_argument( to("Compression type ", idx, " invalid")); } - return codecFactories[idx] != nullptr; + return codecFactories[idx]; +} +} // namespace + +bool hasCodec(CodecType type) { + return getFactory(type).codec != nullptr; } std::unique_ptr getCodec(CodecType type, int level) { - size_t idx = static_cast(type); - if (idx >= static_cast(CodecType::NUM_CODEC_TYPES)) { + auto const factory = getFactory(type).codec; + if (!factory) { throw std::invalid_argument( - to("Compression type ", idx, " invalid")); + to("Compression type ", type, " not supported")); } - auto factory = codecFactories[idx]; + auto codec = (*factory)(level, type); + DCHECK(codec->type() == type); + return codec; +} + +bool hasStreamCodec(CodecType type) { + return getFactory(type).stream != nullptr; +} + +std::unique_ptr getStreamCodec(CodecType type, int level) { + auto const factory = getFactory(type).stream; if (!factory) { - throw std::invalid_argument(to( - "Compression type ", idx, " not supported")); + throw std::invalid_argument( + to("Compression type ", type, " not supported")); } auto codec = (*factory)(level, type); - DCHECK_EQ(static_cast(codec->type()), idx); + DCHECK(codec->type() == type); return codec; } @@ -2003,4 +2089,5 @@ std::unique_ptr getAutoUncompressionCodec( std::vector> customCodecs) { return AutomaticCodec::create(std::move(customCodecs)); } -}} // namespaces +} // namespace io +} // namespace folly