From 9dc004f8d3b22fbffc470cc2890a17caefa32f98 Mon Sep 17 00:00:00 2001 From: Stella Lau Date: Thu, 24 Aug 2017 09:26:59 -0700 Subject: [PATCH] Add LZMA streaming interface Summary: - Replace LZMA2Codec with LZMA2StreamCodec - Update tests to reflect LZMA2_VARINT_SIZE requiring data length Reviewed By: terrelln Differential Revision: D5625388 fbshipit-source-id: 3303c6dda5d41f40615c87504a46923815b0b716 --- folly/io/Compression.cpp | 450 +++++++++++++++++------------- folly/io/test/CompressionTest.cpp | 155 +++++++--- 2 files changed, 377 insertions(+), 228 deletions(-) diff --git a/folly/io/Compression.cpp b/folly/io/Compression.cpp index de065f10..883d0da3 100644 --- a/folly/io/Compression.cpp +++ b/folly/io/Compression.cpp @@ -251,6 +251,9 @@ bool StreamCodec::compressStream( } return true; } + if (!uncompressedLength() && needsDataLength()) { + throw std::runtime_error("Codec: uncompressed length required"); + } if (state_ == State::RESET && !input.empty() && uncompressedLength() == uint64_t(0)) { throw std::runtime_error("Codec: invalid uncompressed length"); @@ -459,13 +462,13 @@ std::unique_ptr NoCompressionCodec::create(int level, CodecType type) { } NoCompressionCodec::NoCompressionCodec(int level, CodecType type) - : Codec(type) { + : Codec(type) { DCHECK(type == CodecType::NO_COMPRESSION); switch (level) { - case COMPRESSION_LEVEL_DEFAULT: - case COMPRESSION_LEVEL_FASTEST: - case COMPRESSION_LEVEL_BEST: - level = 0; + case COMPRESSION_LEVEL_DEFAULT: + case COMPRESSION_LEVEL_FASTEST: + case COMPRESSION_LEVEL_BEST: + level = 0; } if (level != 0) { throw std::invalid_argument(to( @@ -556,13 +559,13 @@ LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) { DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - case COMPRESSION_LEVEL_DEFAULT: - level = 1; - break; - case COMPRESSION_LEVEL_BEST: - level = 2; - break; + case COMPRESSION_LEVEL_FASTEST: + case COMPRESSION_LEVEL_DEFAULT: + level = 1; + break; + case COMPRESSION_LEVEL_BEST: + level = 2; + break; } if (level < 1 || level > 2) { throw std::invalid_argument(to( @@ -913,10 +916,10 @@ std::unique_ptr SnappyCodec::create(int level, CodecType type) { SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) { DCHECK(type == CodecType::SNAPPY); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - case COMPRESSION_LEVEL_DEFAULT: - case COMPRESSION_LEVEL_BEST: - level = 1; + case COMPRESSION_LEVEL_FASTEST: + case COMPRESSION_LEVEL_DEFAULT: + case COMPRESSION_LEVEL_BEST: + level = 1; } if (level != 1) { throw std::invalid_argument(to( @@ -983,44 +986,71 @@ std::unique_ptr SnappyCodec::doUncompress( /** * LZMA2 compression */ -class LZMA2Codec final : public Codec { +class LZMA2StreamCodec final : public StreamCodec { public: - static std::unique_ptr create(int level, CodecType type); - explicit LZMA2Codec(int level, CodecType type); + static std::unique_ptr createCodec(int level, CodecType type); + static std::unique_ptr createStream(int level, CodecType type); + explicit LZMA2StreamCodec(int level, CodecType type); + ~LZMA2StreamCodec() override; std::vector validPrefixes() const override; bool canUncompress(const IOBuf* data, Optional uncompressedLength) const override; private: - bool doNeedsUncompressedLength() const override; + bool doNeedsDataLength() const override; uint64_t doMaxUncompressedLength() const override; uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; - bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; } + bool encodeSize() const { + return type() == CodecType::LZMA2_VARINT_SIZE; + } - std::unique_ptr doCompress(const IOBuf* data) override; - std::unique_ptr doUncompress( - const IOBuf* data, - Optional uncompressedLength) override; + void doResetStream() override; + bool doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + bool doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; - std::unique_ptr addOutputBuffer(lzma_stream* stream, size_t length); - bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength); + void resetCStream(); + void resetDStream(); + + size_t decodeVarint(ByteRange& input); + bool flushVarintBuffer(MutableByteRange& output); + void resetVarintBuffer(); + + Optional cstream_{}; + Optional dstream_{}; + + std::array varintBuffer_; + ByteRange varintToEncode_; + size_t varintBufferPos_{0}; int level_; + bool needReset_{true}; + bool needDecodeSize_{false}; }; static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD; static constexpr unsigned kLZMA2MagicBytes = 6; -std::vector LZMA2Codec::validPrefixes() const { +std::vector LZMA2StreamCodec::validPrefixes() const { if (type() == CodecType::LZMA2_VARINT_SIZE) { return {}; } return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)}; } -bool LZMA2Codec::canUncompress(const IOBuf* data, Optional) const { +bool LZMA2StreamCodec::doNeedsDataLength() const { + return encodeSize(); +} + +bool LZMA2StreamCodec::canUncompress(const IOBuf* data, Optional) + const { if (type() == CodecType::LZMA2_VARINT_SIZE) { return false; } @@ -1029,220 +1059,258 @@ bool LZMA2Codec::canUncompress(const IOBuf* data, Optional) const { return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes); } -std::unique_ptr LZMA2Codec::create(int level, CodecType type) { - return std::make_unique(level, type); +std::unique_ptr LZMA2StreamCodec::createCodec( + int level, + CodecType type) { + return make_unique(level, type); } -LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) { +std::unique_ptr LZMA2StreamCodec::createStream( + int level, + CodecType type) { + return make_unique(level, type); +} + +LZMA2StreamCodec::LZMA2StreamCodec(int level, CodecType type) + : StreamCodec(type) { DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - level = 0; - break; - case COMPRESSION_LEVEL_DEFAULT: - level = LZMA_PRESET_DEFAULT; - break; - case COMPRESSION_LEVEL_BEST: - level = 9; - break; + case COMPRESSION_LEVEL_FASTEST: + level = 0; + break; + case COMPRESSION_LEVEL_DEFAULT: + level = LZMA_PRESET_DEFAULT; + break; + case COMPRESSION_LEVEL_BEST: + level = 9; + break; } if (level < 0 || level > 9) { - throw std::invalid_argument(to( - "LZMA2Codec: invalid level: ", level)); + throw std::invalid_argument( + to("LZMA2Codec: invalid level: ", level)); } level_ = level; } -bool LZMA2Codec::doNeedsUncompressedLength() const { - return false; +LZMA2StreamCodec::~LZMA2StreamCodec() { + if (cstream_) { + lzma_end(cstream_.get_pointer()); + cstream_.clear(); + } + if (dstream_) { + lzma_end(dstream_.get_pointer()); + dstream_.clear(); + } } -uint64_t LZMA2Codec::doMaxUncompressedLength() const { +uint64_t LZMA2StreamCodec::doMaxUncompressedLength() const { // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)" return uint64_t(1) << 63; } -uint64_t LZMA2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const { +uint64_t LZMA2StreamCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { return lzma_stream_buffer_bound(uncompressedLength) + (encodeSize() ? kMaxVarintLength64 : 0); } -std::unique_ptr LZMA2Codec::addOutputBuffer( - lzma_stream* stream, - size_t length) { - - CHECK_EQ(stream->avail_out, 0); - - auto buf = IOBuf::create(length); - buf->append(buf->capacity()); - - stream->next_out = buf->writableData(); - stream->avail_out = buf->length(); - - return buf; +void LZMA2StreamCodec::doResetStream() { + needReset_ = true; } -std::unique_ptr LZMA2Codec::doCompress(const IOBuf* data) { - lzma_ret rc; - lzma_stream stream = LZMA_STREAM_INIT; - - rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE); +void LZMA2StreamCodec::resetCStream() { + if (!cstream_) { + cstream_.assign(LZMA_STREAM_INIT); + } + lzma_ret const rc = + lzma_easy_encoder(cstream_.get_pointer(), level_, LZMA_CHECK_NONE); if (rc != LZMA_OK) { throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_easy_encoder error: ", rc)); - } - - SCOPE_EXIT { lzma_end(&stream); }; - - uint64_t uncompressedLength = data->computeChainDataLength(); - uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength); - - // Max 64MiB in one go - constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB - constexpr uint32_t defaultBufferLength = uint32_t(4) << 20; // 4MiB - - auto out = addOutputBuffer( - &stream, - (maxCompressedLength <= maxSingleStepLength ? - maxCompressedLength : - defaultBufferLength)); - - if (encodeSize()) { - auto size = IOBuf::createCombined(kMaxVarintLength64); - encodeVarintToIOBuf(uncompressedLength, size.get()); - size->appendChain(std::move(out)); - out = std::move(size); + "LZMA2StreamCodec: lzma_easy_encoder error: ", rc)); } +} - for (auto& range : *data) { - if (range.empty()) { - continue; - } - - stream.next_in = const_cast(range.data()); - stream.avail_in = range.size(); - - while (stream.avail_in != 0) { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = lzma_code(&stream, LZMA_RUN); - - if (rc != LZMA_OK) { - throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_code error: ", rc)); - } - } +void LZMA2StreamCodec::resetDStream() { + if (!dstream_) { + dstream_.assign(LZMA_STREAM_INIT); } - - do { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = lzma_code(&stream, LZMA_FINISH); - } while (rc == LZMA_OK); - - if (rc != LZMA_STREAM_END) { + lzma_ret const rc = lzma_auto_decoder( + dstream_.get_pointer(), std::numeric_limits::max(), 0); + if (rc != LZMA_OK) { throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_code ended with error: ", rc)); + "LZMA2StreamCodec: lzma_auto_decoder error: ", rc)); } - - out->prev()->trimEnd(stream.avail_out); - - return out; } -bool LZMA2Codec::doInflate(lzma_stream* stream, - IOBuf* head, - size_t bufferLength) { - if (stream->avail_out == 0) { - head->prependChain(addOutputBuffer(stream, bufferLength)); +static lzma_ret lzmaThrowOnError(lzma_ret const rc) { + switch (rc) { + case LZMA_OK: + case LZMA_STREAM_END: + case LZMA_BUF_ERROR: // not fatal: returned if no progress was made twice + return rc; + default: + throw std::runtime_error( + to("LZMA2StreamCodec: error: ", rc)); } +} - lzma_ret rc = lzma_code(stream, LZMA_RUN); +static lzma_action lzmaTranslateFlush(StreamCodec::FlushOp flush) { + switch (flush) { + case StreamCodec::FlushOp::NONE: + return LZMA_RUN; + case StreamCodec::FlushOp::FLUSH: + return LZMA_SYNC_FLUSH; + case StreamCodec::FlushOp::END: + return LZMA_FINISH; + default: + throw std::invalid_argument("LZMA2StreamCodec: Invalid flush"); + } +} - switch (rc) { - case LZMA_OK: - break; - case LZMA_STREAM_END: +/** + * Flushes the varint buffer. + * Advances output by the number of bytes written. + * Returns true when flushing is complete. + */ +bool LZMA2StreamCodec::flushVarintBuffer(MutableByteRange& output) { + if (varintToEncode_.empty()) { return true; - default: - throw std::runtime_error(to( - "LZMA2Codec: lzma_code error: ", rc)); } - - return false; + const size_t numBytesToCopy = std::min(varintToEncode_.size(), output.size()); + if (numBytesToCopy > 0) { + memcpy(output.data(), varintToEncode_.data(), numBytesToCopy); + } + varintToEncode_.advance(numBytesToCopy); + output.advance(numBytesToCopy); + return varintToEncode_.empty(); } -std::unique_ptr LZMA2Codec::doUncompress( - const IOBuf* data, - Optional uncompressedLength) { - lzma_ret rc; - lzma_stream stream = LZMA_STREAM_INIT; - - rc = lzma_auto_decoder(&stream, std::numeric_limits::max(), 0); - if (rc != LZMA_OK) { - throw std::runtime_error(folly::to( - "LZMA2Codec: lzma_auto_decoder error: ", rc)); +bool LZMA2StreamCodec::doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + resetCStream(); + if (encodeSize()) { + varintBufferPos_ = 0; + size_t const varintSize = + encodeVarint(*uncompressedLength(), varintBuffer_.data()); + varintToEncode_ = {varintBuffer_.data(), varintSize}; + } + needReset_ = false; } - SCOPE_EXIT { lzma_end(&stream); }; - - // Max 64MiB in one go - constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB - constexpr uint32_t defaultBufferLength = uint32_t(256) << 10; // 256 KiB + if (!flushVarintBuffer(output)) { + return false; + } - folly::io::Cursor cursor(data); - if (encodeSize()) { - const uint64_t actualUncompressedLength = decodeVarintFromCursor(cursor); - if (uncompressedLength && *uncompressedLength != actualUncompressedLength) { - throw std::runtime_error("LZMA2Codec: invalid uncompressed length"); - } - uncompressedLength = actualUncompressedLength; + cstream_->next_in = const_cast(input.data()); + cstream_->avail_in = input.size(); + cstream_->next_out = output.data(); + cstream_->avail_out = output.size(); + SCOPE_EXIT { + input.uncheckedAdvance(input.size() - cstream_->avail_in); + output.uncheckedAdvance(output.size() - cstream_->avail_out); + }; + lzma_ret const rc = lzmaThrowOnError( + lzma_code(cstream_.get_pointer(), lzmaTranslateFlush(flushOp))); + switch (flushOp) { + case StreamCodec::FlushOp::NONE: + return false; + case StreamCodec::FlushOp::FLUSH: + return cstream_->avail_in == 0 && cstream_->avail_out != 0; + case StreamCodec::FlushOp::END: + return rc == LZMA_STREAM_END; + default: + throw std::invalid_argument("LZMA2StreamCodec: invalid FlushOp"); } +} - auto out = addOutputBuffer( - &stream, - ((uncompressedLength && *uncompressedLength <= maxSingleStepLength) - ? *uncompressedLength - : defaultBufferLength)); +/** + * Attempts to decode a varint from input. + * The function advances input by the number of bytes read. + * + * If there are too many bytes and the varint is not valid, throw a + * runtime_error. + * Returns the decoded size or 0 if more bytes are needed. + */ +size_t LZMA2StreamCodec::decodeVarint(ByteRange& input) { + if (input.empty()) { + return 0; + } + size_t const numBytesToCopy = + std::min(kMaxVarintLength64 - varintBufferPos_, input.size()); + memcpy(varintBuffer_.data() + varintBufferPos_, input.data(), numBytesToCopy); - bool streamEnd = false; - auto buf = cursor.peekBytes(); - while (!buf.empty()) { - stream.next_in = const_cast(buf.data()); - stream.avail_in = buf.size(); + size_t const rangeSize = varintBufferPos_ + numBytesToCopy; + ByteRange range{varintBuffer_.data(), rangeSize}; + auto const ret = tryDecodeVarint(range); - while (stream.avail_in != 0) { - if (streamEnd) { - throw std::runtime_error(to( - "LZMA2Codec: junk after end of data")); - } + if (ret.hasValue()) { + size_t const varintSize = rangeSize - range.size(); + input.advance(varintSize - varintBufferPos_); + return ret.value(); + } else if (ret.error() == DecodeVarintError::TooManyBytes) { + throw std::runtime_error("LZMA2StreamCodec: invalid uncompressed length"); + } else { + // Too few bytes + input.advance(numBytesToCopy); + varintBufferPos_ += numBytesToCopy; + return 0; + } +} - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); +bool LZMA2StreamCodec::doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + resetDStream(); + needReset_ = false; + needDecodeSize_ = encodeSize(); + if (encodeSize()) { + // Reset buffer + varintBufferPos_ = 0; } - - cursor.skip(buf.size()); - buf = cursor.peekBytes(); } - while (!streamEnd) { - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); + if (needDecodeSize_) { + // Try decoding the varint. If the input does not contain the entire varint, + // buffer the input. If the varint can not be decoded, fail. + size_t const size = decodeVarint(input); + if (!size) { + return false; + } + if (uncompressedLength() && *uncompressedLength() != size) { + throw std::runtime_error("LZMA2StreamCodec: invalid uncompressed length"); + } + needDecodeSize_ = false; } - out->prev()->trimEnd(stream.avail_out); + dstream_->next_in = const_cast(input.data()); + dstream_->avail_in = input.size(); + dstream_->next_out = output.data(); + dstream_->avail_out = output.size(); + SCOPE_EXIT { + input.advance(input.size() - dstream_->avail_in); + output.advance(output.size() - dstream_->avail_out); + }; - if (uncompressedLength && *uncompressedLength != stream.total_out) { - throw std::runtime_error( - to("LZMA2Codec: invalid uncompressed length")); + lzma_ret rc; + switch (flushOp) { + case StreamCodec::FlushOp::NONE: + case StreamCodec::FlushOp::FLUSH: + rc = lzmaThrowOnError(lzma_code(dstream_.get_pointer(), LZMA_RUN)); + break; + case StreamCodec::FlushOp::END: + rc = lzmaThrowOnError(lzma_code(dstream_.get_pointer(), LZMA_FINISH)); + break; + default: + throw std::invalid_argument("LZMA2StreamCodec: invalid flush"); } - - return out; + return rc == LZMA_STREAM_END; } - -#endif // FOLLY_HAVE_LIBLZMA +#endif // FOLLY_HAVE_LIBLZMA #ifdef FOLLY_HAVE_LIBZSTD @@ -1945,8 +2013,8 @@ constexpr Factory #endif #if FOLLY_HAVE_LIBLZMA - {LZMA2Codec::create, nullptr}, - {LZMA2Codec::create, nullptr}, + {LZMA2StreamCodec::createCodec, LZMA2StreamCodec::createStream}, + {LZMA2StreamCodec::createCodec, LZMA2StreamCodec::createStream}, #else {}, {}, diff --git a/folly/io/test/CompressionTest.cpp b/folly/io/test/CompressionTest.cpp index 6beccbdd..c8deed91 100644 --- a/folly/io/test/CompressionTest.cpp +++ b/folly/io/test/CompressionTest.cpp @@ -175,20 +175,22 @@ static std::vector availableStreamCodecs() { } TEST(CompressionTestNeedsUncompressedLength, Simple) { - static const struct { CodecType type; bool needsUncompressedLength; } - expectations[] = { - { CodecType::NO_COMPRESSION, false }, - { CodecType::LZ4, true }, - { CodecType::SNAPPY, false }, - { CodecType::ZLIB, false }, - { CodecType::LZ4_VARINT_SIZE, false }, - { CodecType::LZMA2, false }, - { CodecType::LZMA2_VARINT_SIZE, false }, - { CodecType::ZSTD, false }, - { CodecType::GZIP, false }, - { CodecType::LZ4_FRAME, false }, - { CodecType::BZIP2, false }, - }; + static const struct { + CodecType type; + bool needsUncompressedLength; + } expectations[] = { + {CodecType::NO_COMPRESSION, false}, + {CodecType::LZ4, true}, + {CodecType::SNAPPY, false}, + {CodecType::ZLIB, false}, + {CodecType::LZ4_VARINT_SIZE, false}, + {CodecType::LZMA2, false}, + {CodecType::LZMA2_VARINT_SIZE, false}, + {CodecType::ZSTD, false}, + {CodecType::GZIP, false}, + {CodecType::LZ4_FRAME, false}, + {CodecType::BZIP2, false}, + }; for (auto const& test : expectations) { if (hasCodec(test.type)) { @@ -360,7 +362,19 @@ INSTANTIATE_TEST_CASE_P( testing::ValuesIn(supportedCodecs({ CodecType::LZ4_VARINT_SIZE, CodecType::LZMA2_VARINT_SIZE, - })))); + })))); + +TEST(LZMATest, UncompressBadVarint) { + if (hasStreamCodec(CodecType::LZMA2_VARINT_SIZE)) { + std::string const str(kMaxVarintLength64 * 2, '\xff'); + ByteRange input((folly::StringPiece(str))); + auto codec = getStreamCodec(CodecType::LZMA2_VARINT_SIZE); + auto buffer = IOBuf::create(16); + buffer->append(buffer->capacity()); + MutableByteRange output{buffer->writableData(), buffer->length()}; + EXPECT_THROW(codec->uncompressStream(input, output), std::runtime_error); + } +} class CompressionCorruptionTest : public testing::TestWithParam { protected: @@ -449,6 +463,26 @@ class StreamingUnitTest : public testing::TestWithParam { std::unique_ptr codec_; }; +TEST(StreamingUnitTest, needsDataLength) { + static const struct { + CodecType type; + bool needsDataLength; + } expectations[] = { + {CodecType::ZLIB, false}, + {CodecType::GZIP, false}, + {CodecType::LZMA2, false}, + {CodecType::LZMA2_VARINT_SIZE, true}, + {CodecType::ZSTD, false}, + }; + + for (auto const& test : expectations) { + if (hasStreamCodec(test.type)) { + EXPECT_EQ( + getStreamCodec(test.type)->needsDataLength(), test.needsDataLength); + } + } +} + TEST_P(StreamingUnitTest, maxCompressedLength) { EXPECT_EQ(0, codec_->maxCompressedLength(0)); for (uint64_t const length : {1, 10, 100, 1000, 10000, 100000, 1000000}) { @@ -483,22 +517,24 @@ TEST_P(StreamingUnitTest, emptyData) { MutableByteRange output{}; // Test compressing empty data in one pass - EXPECT_TRUE(codec_->compressStream(input, output, StreamCodec::FlushOp::END)); + if (!codec_->needsDataLength()) { + EXPECT_TRUE( + codec_->compressStream(input, output, StreamCodec::FlushOp::END)); + } codec_->resetStream(0); EXPECT_TRUE(codec_->compressStream(input, output, StreamCodec::FlushOp::END)); - codec_->resetStream(); output = {buffer->writableData(), buffer->length()}; EXPECT_TRUE(codec_->compressStream(input, output, StreamCodec::FlushOp::END)); EXPECT_EQ(buffer->length(), output.size()); // Test compressing empty data with multiple calls to compressStream() - codec_->resetStream(); + codec_->resetStream(0); output = {}; EXPECT_FALSE(codec_->compressStream(input, output)); EXPECT_TRUE( codec_->compressStream(input, output, StreamCodec::FlushOp::FLUSH)); EXPECT_TRUE(codec_->compressStream(input, output, StreamCodec::FlushOp::END)); - codec_->resetStream(); + codec_->resetStream(0); output = {buffer->writableData(), buffer->length()}; EXPECT_FALSE(codec_->compressStream(input, output)); EXPECT_TRUE( @@ -541,7 +577,11 @@ TEST_P(StreamingUnitTest, noForwardProgressOkay) { MutableByteRange emptyOutput; // Compress some data to avoid empty data special casing - codec_->resetStream(); + if (codec_->needsDataLength()) { + codec_->resetStream(inBuffer->computeChainDataLength()); + } else { + codec_->resetStream(); + } while (!input.empty()) { codec_->compressStream(input, output); } @@ -549,7 +589,11 @@ TEST_P(StreamingUnitTest, noForwardProgressOkay) { codec_->compressStream(emptyInput, emptyOutput); codec_->compressStream(emptyInput, emptyOutput, StreamCodec::FlushOp::FLUSH); - codec_->resetStream(); + if (codec_->needsDataLength()) { + codec_->resetStream(inBuffer->computeChainDataLength()); + } else { + codec_->resetStream(); + } input = inBuffer->coalesce(); output = {outBuffer->writableTail(), outBuffer->tailroom()}; while (!input.empty()) { @@ -590,6 +634,20 @@ TEST_P(StreamingUnitTest, stateTransitions) { auto output = empty ? MutableByteRange{} : out; return codec_->compressStream(input, output, flushOp); }; + auto compress_all = [&](bool expect, + StreamCodec::FlushOp flushOp = + StreamCodec::FlushOp::NONE, + bool empty = false) { + auto input = in; + auto output = empty ? MutableByteRange{} : out; + while (!input.empty()) { + if (expect) { + EXPECT_TRUE(codec_->compressStream(input, output, flushOp)); + } else { + EXPECT_FALSE(codec_->compressStream(input, output, flushOp)); + } + } + }; auto uncompress = [&]( StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE, bool empty = false) { @@ -599,12 +657,21 @@ TEST_P(StreamingUnitTest, stateTransitions) { }; // compression flow - codec_->resetStream(); - EXPECT_FALSE(compress()); - EXPECT_FALSE(compress()); - EXPECT_TRUE(compress(StreamCodec::FlushOp::FLUSH)); - EXPECT_FALSE(compress()); - EXPECT_TRUE(compress(StreamCodec::FlushOp::END)); + if (!codec_->needsDataLength()) { + codec_->resetStream(); + EXPECT_FALSE(compress()); + EXPECT_FALSE(compress()); + EXPECT_TRUE(compress(StreamCodec::FlushOp::FLUSH)); + EXPECT_FALSE(compress()); + EXPECT_TRUE(compress(StreamCodec::FlushOp::END)); + } + codec_->resetStream(in.size() * 5); + compress_all(false); + compress_all(false); + compress_all(true, StreamCodec::FlushOp::FLUSH); + compress_all(false); + compress_all(true, StreamCodec::FlushOp::END); + // uncompression flow codec_->resetStream(); EXPECT_FALSE(uncompress(StreamCodec::FlushOp::NONE, true)); @@ -617,34 +684,40 @@ TEST_P(StreamingUnitTest, stateTransitions) { codec_->resetStream(); EXPECT_TRUE(uncompress(StreamCodec::FlushOp::FLUSH)); // compress -> uncompress - codec_->resetStream(); + codec_->resetStream(in.size()); EXPECT_FALSE(compress()); EXPECT_THROW(uncompress(), std::logic_error); // uncompress -> compress - codec_->resetStream(); + codec_->resetStream(inBuffer->computeChainDataLength()); EXPECT_TRUE(uncompress(StreamCodec::FlushOp::FLUSH)); EXPECT_THROW(compress(), std::logic_error); // end -> compress - codec_->resetStream(); - EXPECT_FALSE(compress()); - EXPECT_TRUE(compress(StreamCodec::FlushOp::END)); + if (!codec_->needsDataLength()) { + codec_->resetStream(); + EXPECT_FALSE(compress()); + EXPECT_TRUE(compress(StreamCodec::FlushOp::END)); + EXPECT_THROW(compress(), std::logic_error); + } + codec_->resetStream(in.size() * 2); + compress_all(false); + compress_all(true, StreamCodec::FlushOp::END); EXPECT_THROW(compress(), std::logic_error); // end -> uncompress codec_->resetStream(); EXPECT_TRUE(uncompress(StreamCodec::FlushOp::FLUSH)); EXPECT_THROW(uncompress(), std::logic_error); // flush -> compress - codec_->resetStream(); + codec_->resetStream(in.size()); EXPECT_FALSE(compress(StreamCodec::FlushOp::FLUSH, true)); EXPECT_THROW(compress(), std::logic_error); // flush -> end - codec_->resetStream(); + codec_->resetStream(in.size()); EXPECT_FALSE(compress(StreamCodec::FlushOp::FLUSH, true)); EXPECT_THROW(compress(StreamCodec::FlushOp::END), std::logic_error); // undefined -> compress codec_->compress(inBuffer.get()); EXPECT_THROW(compress(), std::logic_error); - codec_->uncompress(compressed.get()); + codec_->uncompress(compressed.get(), inBuffer->computeChainDataLength()); EXPECT_THROW(compress(), std::logic_error); // undefined -> undefined codec_->uncompress(compressed.get()); @@ -738,7 +811,11 @@ void StreamingCompressionTest::runResetStreamTest(DataHolder const& dh) { codec_->resetStream(uncompressedLength_); compressSome(codec_.get(), input, chunkSize_, StreamCodec::FlushOp::NONE); // Reset stream and compress all - codec_->resetStream(); + if (codec_->needsDataLength()) { + codec_->resetStream(uncompressedLength_); + } else { + codec_->resetStream(); + } auto compressed = compressSome(codec_.get(), input, chunkSize_, StreamCodec::FlushOp::END); auto const uncompressed = codec_->uncompress(compressed.get(), input.size()); @@ -820,7 +897,11 @@ void StreamingCompressionTest::runFlushTest(DataHolder const& dh) { auto const inputs = split(dh.data(uncompressedLength_)); auto uncodec = getStreamCodec(codec_->type()); - codec_->resetStream(); + if (codec_->needsDataLength()) { + codec_->resetStream(uncompressedLength_); + } else { + codec_->resetStream(); + } for (auto input : inputs) { // Compress some data and flush the stream auto compressed = compressSome( -- 2.34.1