return result;
}
-std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+static std::unique_ptr<IOBuf> zstdUncompressBuffer(
+ const IOBuf* data,
+ uint64_t uncompressedLength) {
+ // Check preconditions
+ DCHECK(!data->isChained());
+ DCHECK(uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH);
+
+ auto uncompressed = IOBuf::create(uncompressedLength);
+ const auto decompressedSize = ZSTD_decompress(
+ uncompressed->writableTail(),
+ uncompressed->tailroom(),
+ data->data(),
+ data->length());
+ zstdThrowIfError(decompressedSize);
+ if (decompressedSize != uncompressedLength) {
+ throw std::runtime_error("ZSTD: invalid uncompressed length");
+ }
+ uncompressed->append(decompressedSize);
+ return uncompressed;
+}
+
+static std::unique_ptr<IOBuf> zstdUncompressStream(
const IOBuf* data,
uint64_t uncompressedLength) {
auto zds = ZSTD_createDStream();
ZSTD_inBuffer in{};
auto outputSize = ZSTD_DStreamOutSize();
- if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
+ if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH) {
outputSize = uncompressedLength;
- } else {
- auto decompressedSize =
- ZSTD_getDecompressedSize(data->data(), data->length());
- if (decompressedSize != 0 && decompressedSize < outputSize) {
- outputSize = decompressedSize;
- }
}
IOBufQueue queue(IOBufQueue::cacheChainLength());
if (in.pos != in.size || !cursor.isAtEnd()) {
throw std::runtime_error("ZSTD: junk after end of data");
}
- if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
+ if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
queue.chainLength() != uncompressedLength) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
}
return queue.move();
}
+std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+ const IOBuf* data,
+ uint64_t uncompressedLength) {
+ {
+ // Read decompressed size from frame if available in first IOBuf.
+ const auto decompressedSize =
+ ZSTD_getDecompressedSize(data->data(), data->length());
+ if (decompressedSize != 0) {
+ if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
+ uncompressedLength != decompressedSize) {
+ throw std::runtime_error("ZSTD: invalid uncompressed length");
+ }
+ uncompressedLength = decompressedSize;
+ }
+ }
+ // Faster to decompress using ZSTD_decompress() if we can.
+ if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && !data->isChained()) {
+ return zstdUncompressBuffer(data, uncompressedLength);
+ }
+ // Fall back to slower streaming decompression.
+ return zstdUncompressStream(data, uncompressedLength);
+}
+
#endif // FOLLY_HAVE_LIBZSTD
} // namespace