}
bool ZSTDCodec::doNeedsUncompressedLength() const {
- return true;
+ // doUncompress() accepts UNKNOWN_UNCOMPRESSED_LENGTH, so callers no longer
+ // have to supply the original size.
+ return false;
+}
+
+// Throws std::runtime_error carrying ZSTD's error name when `rc` is a ZSTD
+// error code; returns normally otherwise.
+// Declared static: it is a helper for this translation unit only and should
+// not be exported with external linkage.
+static void zstdThrowIfError(size_t rc) {
+ if (!ZSTD_isError(rc)) {
+ return;
+ }
+ throw std::runtime_error(
+ to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
}
std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
- size_t rc;
- size_t maxCompressedLength = ZSTD_compressBound(data->length());
- auto out = IOBuf::createCombined(maxCompressedLength);
+ // Support earlier versions of the codec (working with a single IOBuf,
+ // and using ZSTD_decompress which requires ZSTD frame to contain size,
+ // which isn't populated by streaming API).
+ if (!data->isChained()) {
+ auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
+ const auto rc = ZSTD_compress(
+ out->writableData(),
+ out->capacity(),
+ data->data(),
+ data->length(),
+ level_);
+ zstdThrowIfError(rc);
+ out->append(rc);
+ return out;
+ }
- CHECK_EQ(out->length(), 0);
+ auto zcs = ZSTD_createCStream();
+ SCOPE_EXIT {
+ ZSTD_freeCStream(zcs);
+ };
- rc = ZSTD_compress(out->writableTail(),
- out->capacity(),
- data->data(),
- data->length(),
- level_);
+ auto rc = ZSTD_initCStream(zcs, level_);
+ zstdThrowIfError(rc);
- if (ZSTD_isError(rc)) {
- throw std::runtime_error(to<std::string>(
- "ZSTD compression returned an error: ",
- ZSTD_getErrorName(rc)));
+ Cursor cursor(data);
+ auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
+
+ ZSTD_outBuffer out;
+ out.dst = result->writableTail();
+ out.size = result->capacity();
+ out.pos = 0;
+
+ for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
+ ZSTD_inBuffer in;
+ in.src = buffer.data();
+ in.size = buffer.size();
+ for (in.pos = 0; in.pos != in.size;) {
+ rc = ZSTD_compressStream(zcs, &out, &in);
+ zstdThrowIfError(rc);
+ }
+ cursor.skip(in.size);
+ buffer = cursor.peekBytes();
}
- out->append(rc);
- CHECK_EQ(out->length(), rc);
+ rc = ZSTD_endStream(zcs, &out);
+ zstdThrowIfError(rc);
+ CHECK_EQ(rc, 0);
- return out;
+ result->append(out.pos);
+ return result;
}
-std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data,
- uint64_t uncompressedLength) {
- size_t rc;
- auto out = IOBuf::createCombined(uncompressedLength);
+// Decompresses the (possibly chained) `data` with the ZSTD streaming API and
+// returns the output collected in an IOBufQueue.
+// `uncompressedLength` may be UNKNOWN_UNCOMPRESSED_LENGTH; when it is given,
+// the produced length must match it exactly.
+// Throws std::runtime_error on a ZSTD error, on a stream that cannot be
+// created, on incomplete input, on bytes following the end of the frame, and
+// on a length mismatch.
+std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+ const IOBuf* data,
+ uint64_t uncompressedLength) {
+ auto zds = ZSTD_createDStream();
+ // ZSTD_createDStream() yields NULL when it cannot allocate; fail with an
+ // exception instead of handing a null stream to ZSTD_initDStream().
+ if (zds == nullptr) {
+ throw std::runtime_error("ZSTD: failed to create decompression stream");
+ }
+ SCOPE_EXIT {
+ ZSTD_freeDStream(zds);
+ };
- CHECK_GE(out->capacity(), uncompressedLength);
- CHECK_EQ(out->length(), 0);
+ auto rc = ZSTD_initDStream(zds);
+ zstdThrowIfError(rc);
- rc = ZSTD_decompress(
- out->writableTail(), out->capacity(), data->data(), data->length());
+ // Both start zeroed, so the first loop pass fetches input and output space.
+ ZSTD_outBuffer out{};
+ ZSTD_inBuffer in{};
- if (ZSTD_isError(rc)) {
- throw std::runtime_error(to<std::string>(
- "ZSTD decompression returned an error: ",
- ZSTD_getErrorName(rc)));
+ // Size of the first output allocation: the caller's length when known,
+ // otherwise ZSTD's stream output size, reduced to the size reported for the
+ // first buffer's frame header when that is nonzero and smaller.
+ auto outputSize = ZSTD_DStreamOutSize();
+ if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
+ outputSize = uncompressedLength;
+ } else {
+ auto decompressedSize =
+ ZSTD_getDecompressedSize(data->data(), data->length());
+ if (decompressedSize != 0 && decompressedSize < outputSize) {
+ outputSize = decompressedSize;
+ }
}
- out->append(rc);
- CHECK_EQ(out->length(), rc);
+ IOBufQueue queue(IOBufQueue::cacheChainLength());
+
+ Cursor cursor(data);
+ for (rc = 0;;) {
+ // Input span fully consumed: pull the next one from the chain.
+ if (in.pos == in.size) {
+ auto buffer = cursor.peekBytes();
+ in.src = buffer.data();
+ in.size = buffer.size();
+ in.pos = 0;
+ cursor.skip(in.size);
+ // rc is the previous ZSTD_decompressStream() result; running out of
+ // input while it is still above 1 is treated as a truncated stream.
+ if (rc > 1 && in.size == 0) {
+ throw std::runtime_error("ZSTD: incomplete input");
+ }
+ }
+ // Output span full: commit what was written and reserve a fresh span.
+ if (out.pos == out.size) {
+ if (out.pos != 0) {
+ queue.postallocate(out.pos);
+ }
+ auto buffer = queue.preallocate(outputSize, outputSize);
+ out.dst = buffer.first;
+ out.size = buffer.second;
+ out.pos = 0;
+ // Later allocations always use ZSTD's stream output size.
+ outputSize = ZSTD_DStreamOutSize();
+ }
+ rc = ZSTD_decompressStream(zds, &out, &in);
+ zstdThrowIfError(rc);
+ // Zero means the frame has been completely decoded and flushed.
+ if (rc == 0) {
+ break;
+ }
+ }
+ // Commit the bytes written into the last reserved span.
+ if (out.pos != 0) {
+ queue.postallocate(out.pos);
+ }
+ if (in.pos != in.size || !cursor.isAtEnd()) {
+ throw std::runtime_error("ZSTD: junk after end of data");
+ }
+ if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
+ queue.chainLength() != uncompressedLength) {
+ throw std::runtime_error("ZSTD: invalid uncompressed length");
+ }
- return out;
+ return queue.move();
}
#endif // FOLLY_HAVE_LIBZSTD
#include <folly/io/Compression.h>
#include <random>
+#include <set>
#include <thread>
#include <unordered_map>
EXPECT_TRUE(getCodec(CodecType::LZMA2)->needsUncompressedLength());
EXPECT_FALSE(getCodec(CodecType::LZMA2_VARINT_SIZE)
->needsUncompressedLength());
- EXPECT_TRUE(getCodec(CodecType::ZSTD)->needsUncompressedLength());
+ EXPECT_FALSE(getCodec(CodecType::ZSTD)->needsUncompressedLength());
EXPECT_FALSE(getCodec(CodecType::GZIP)->needsUncompressedLength());
}
// Fixture for the compress/uncompress checks in runSimpleTest(); the new
// version is parameterized on (size exponent, chunk count, codec type).
class CompressionTest
- : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
- protected:
- void SetUp() override {
- auto tup = GetParam();
- uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
- codec_ = getCodec(std::tr1::get<1>(tup));
- }
+ : public testing::TestWithParam<std::tr1::tuple<int, int, CodecType>> {
+ protected:
+ void SetUp() override {
+ auto tup = GetParam();
+ uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup); // element 0 is the power-of-two exponent of the input size
+ chunks_ = std::tr1::get<1>(tup); // element 1 is the chunk count used by split()
+ codec_ = getCodec(std::tr1::get<2>(tup)); // element 2 selects the codec under test
+ }
- void runSimpleTest(const DataHolder& dh);
+ void runSimpleTest(const DataHolder& dh);
- uint64_t uncompressedLength_;
- std::unique_ptr<Codec> codec_;
+ private:
+ std::unique_ptr<IOBuf> split(std::unique_ptr<IOBuf> data) const; // copies data into a chain of chunks_ buffers
+
+ uint64_t uncompressedLength_; // bytes of test data to compress, set in SetUp()
+ size_t chunks_; // split() makes chunks_ - 1 random cuts
+ std::unique_ptr<Codec> codec_; // codec under test, created in SetUp()
};
// Compresses the first uncompressedLength_ bytes of `dh` with codec_, then
// checks that uncompressing restores the original length and hash.
void CompressionTest::runSimpleTest(const DataHolder& dh) {
- auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
- auto compressed = codec_->compress(original.get());
+ const auto original = split(IOBuf::wrapBuffer(dh.data(uncompressedLength_))); // input is re-chunked by split() before compression
+ const auto compressed = split(codec_->compress(original.get())); // compressed bytes are re-chunked too, before uncompress sees them
if (!codec_->needsUncompressedLength()) {
auto uncompressed = codec_->uncompress(compressed.get()); // no length hint is passed on this path
-
EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
}
}
}
+// Copies `data` into a chain of buffers by cutting it at chunks_ - 1 offsets
+// drawn from Random::rand64(size). Offsets may coincide, so individual chunks
+// can be empty.
+std::unique_ptr<IOBuf> CompressionTest::split(
+ std::unique_ptr<IOBuf> data) const {
+ // Flatten a chained input so it can be sliced as one contiguous range.
+ if (data->isChained()) {
+ data->coalesce();
+ }
+ const size_t total = data->computeChainDataLength();
+
+ // The multiset keeps the cut points sorted and retains duplicates.
+ std::multiset<size_t> cuts;
+ for (size_t remaining = chunks_; remaining > 1; --remaining) {
+ cuts.insert(Random::rand64(total));
+ }
+
+ folly::IOBufQueue pieces;
+ const auto* base = data->data();
+ size_t begin = 0;
+ for (const size_t end : cuts) {
+ pieces.append(IOBuf::copyBuffer(base + begin, end - begin));
+ begin = end;
+ }
+ // Everything after the last cut forms the final chunk.
+ pieces.append(IOBuf::copyBuffer(base + begin, total - begin));
+
+ return pieces.move();
+}
+
// Runs the compress/uncompress check of runSimpleTest() on randomDataHolder.
TEST_P(CompressionTest, RandomData) {
runSimpleTest(randomDataHolder);
}
// Instantiates CompressionTest over every combination of the value lists below.
INSTANTIATE_TEST_CASE_P(
CompressionTest,
CompressionTest,
- testing::Combine(testing::Values(0, 1, 12, 22, 25, 27),
- testing::Values(CodecType::NO_COMPRESSION,
- CodecType::LZ4,
- CodecType::SNAPPY,
- CodecType::ZLIB,
- CodecType::LZ4_VARINT_SIZE,
- CodecType::LZMA2,
- CodecType::LZMA2_VARINT_SIZE,
- CodecType::ZSTD,
- CodecType::GZIP)));
+ testing::Combine(
+ testing::Values(0, 1, 12, 22, 25, 27), // exponent: SetUp() uses 1 << value as the input size
+ testing::Values(1, 2, 3, 8, 65), // chunk count handed to split(); 1 means no cuts
+ testing::Values( // codec types to exercise
+ CodecType::NO_COMPRESSION,
+ CodecType::LZ4,
+ CodecType::SNAPPY,
+ CodecType::ZLIB,
+ CodecType::LZ4_VARINT_SIZE,
+ CodecType::LZMA2,
+ CodecType::LZMA2_VARINT_SIZE,
+ CodecType::ZSTD,
+ CodecType::GZIP)));
class CompressionVarintTest
: public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
}
-TEST_P(CompressionVarintTest, RandomData) { runSimpleTest(randomDataHolder); }
+TEST_P(CompressionVarintTest, RandomData) {
+ runSimpleTest(randomDataHolder); // same call as the removed one-liner; only the layout changed
+}
TEST_P(CompressionVarintTest, ConstantData) {
runSimpleTest(constantDataHolder);
// Same value lists as the removed lines; only the argument layout changed.
INSTANTIATE_TEST_CASE_P(
CompressionVarintTest,
CompressionVarintTest,
- testing::Combine(testing::Values(0, 1, 12, 22, 25, 27),
- testing::Values(CodecType::LZ4_VARINT_SIZE,
- CodecType::LZMA2_VARINT_SIZE)));
+ testing::Combine(
+ testing::Values(0, 1, 12, 22, 25, 27), // NOTE(review): presumably size exponents as in CompressionTest — confirm in this fixture's SetUp()
+ testing::Values( // codec types to exercise
+ CodecType::LZ4_VARINT_SIZE,
+ CodecType::LZMA2_VARINT_SIZE)));
class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
protected: