c7c4c8f8dd05d22ce743094a9404db0b69eaddde
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   } else if (len > maxUncompressedLength()) {
60     throw std::runtime_error("Codec: uncompressed length too large");
61   }
62
63   return doCompress(data);
64 }
65
66 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
67                                          uint64_t uncompressedLength) {
68   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
69     if (needsUncompressedLength()) {
70       throw std::invalid_argument("Codec: uncompressed length required");
71     }
72   } else if (uncompressedLength > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   if (data->empty()) {
77     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
78         uncompressedLength != 0) {
79       throw std::runtime_error("Codec: invalid uncompressed length");
80     }
81     return IOBuf::create(0);
82   }
83
84   return doUncompress(data, uncompressedLength);
85 }
86
87 bool Codec::needsUncompressedLength() const {
88   return doNeedsUncompressedLength();
89 }
90
91 uint64_t Codec::maxUncompressedLength() const {
92   return doMaxUncompressedLength();
93 }
94
95 bool Codec::doNeedsUncompressedLength() const {
96   return false;
97 }
98
99 uint64_t Codec::doMaxUncompressedLength() const {
100   return UNLIMITED_UNCOMPRESSED_LENGTH;
101 }
102
103 namespace {
104
105 /**
106  * No compression
107  */
108 class NoCompressionCodec final : public Codec {
109  public:
110   static std::unique_ptr<Codec> create(int level, CodecType type);
111   explicit NoCompressionCodec(int level, CodecType type);
112
113  private:
114   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
115   std::unique_ptr<IOBuf> doUncompress(
116       const IOBuf* data,
117       uint64_t uncompressedLength) override;
118 };
119
120 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
121   return make_unique<NoCompressionCodec>(level, type);
122 }
123
124 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
125   : Codec(type) {
126   DCHECK(type == CodecType::NO_COMPRESSION);
127   switch (level) {
128   case COMPRESSION_LEVEL_DEFAULT:
129   case COMPRESSION_LEVEL_FASTEST:
130   case COMPRESSION_LEVEL_BEST:
131     level = 0;
132   }
133   if (level != 0) {
134     throw std::invalid_argument(to<std::string>(
135         "NoCompressionCodec: invalid level ", level));
136   }
137 }
138
139 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
140     const IOBuf* data) {
141   return data->clone();
142 }
143
144 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
145     const IOBuf* data,
146     uint64_t uncompressedLength) {
147   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
148       data->computeChainDataLength() != uncompressedLength) {
149     throw std::runtime_error(to<std::string>(
150         "NoCompressionCodec: invalid uncompressed length"));
151   }
152   return data->clone();
153 }
154
155 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
156
157 namespace {
158
159 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
160   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
161   out->append(encodeVarint(val, out->writableTail()));
162 }
163
164 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
165   uint64_t val = 0;
166   int8_t b = 0;
167   for (int shift = 0; shift <= 63; shift += 7) {
168     b = cursor.read<int8_t>();
169     val |= static_cast<uint64_t>(b & 0x7f) << shift;
170     if (b >= 0) {
171       break;
172     }
173   }
174   if (b < 0) {
175     throw std::invalid_argument("Invalid varint value. Too big.");
176   }
177   return val;
178 }
179
180 }  // namespace
181
182 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
183
184 #if FOLLY_HAVE_LIBLZ4
185
186 /**
187  * LZ4 compression
188  */
189 class LZ4Codec final : public Codec {
190  public:
191   static std::unique_ptr<Codec> create(int level, CodecType type);
192   explicit LZ4Codec(int level, CodecType type);
193
194  private:
195   bool doNeedsUncompressedLength() const override;
196   uint64_t doMaxUncompressedLength() const override;
197
198   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
199
200   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
201   std::unique_ptr<IOBuf> doUncompress(
202       const IOBuf* data,
203       uint64_t uncompressedLength) override;
204
205   bool highCompression_;
206 };
207
208 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
209   return make_unique<LZ4Codec>(level, type);
210 }
211
212 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
213   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
214
215   switch (level) {
216   case COMPRESSION_LEVEL_FASTEST:
217   case COMPRESSION_LEVEL_DEFAULT:
218     level = 1;
219     break;
220   case COMPRESSION_LEVEL_BEST:
221     level = 2;
222     break;
223   }
224   if (level < 1 || level > 2) {
225     throw std::invalid_argument(to<std::string>(
226         "LZ4Codec: invalid level: ", level));
227   }
228   highCompression_ = (level > 1);
229 }
230
231 bool LZ4Codec::doNeedsUncompressedLength() const {
232   return !encodeSize();
233 }
234
235 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
236 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
237 // here.
238 #ifndef LZ4_MAX_INPUT_SIZE
239 # define LZ4_MAX_INPUT_SIZE 0x7E000000
240 #endif
241
242 uint64_t LZ4Codec::doMaxUncompressedLength() const {
243   return LZ4_MAX_INPUT_SIZE;
244 }
245
246 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
247   std::unique_ptr<IOBuf> clone;
248   if (data->isChained()) {
249     // LZ4 doesn't support streaming, so we have to coalesce
250     clone = data->clone();
251     clone->coalesce();
252     data = clone.get();
253   }
254
255   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
256   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
257   if (encodeSize()) {
258     encodeVarintToIOBuf(data->length(), out.get());
259   }
260
261   int n;
262   if (highCompression_) {
263     n = LZ4_compress_HC(
264         reinterpret_cast<const char*>(data->data()),
265         reinterpret_cast<char*>(out->writableTail()),
266         data->length(),
267         out->tailroom(),
268         0);
269   } else {
270     n = LZ4_compress_default(
271         reinterpret_cast<const char*>(data->data()),
272         reinterpret_cast<char*>(out->writableTail()),
273         data->length(),
274         out->tailroom());
275   }
276
277   CHECK_GE(n, 0);
278   CHECK_LE(n, out->capacity());
279
280   out->append(n);
281   return out;
282 }
283
284 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
285     const IOBuf* data,
286     uint64_t uncompressedLength) {
287   std::unique_ptr<IOBuf> clone;
288   if (data->isChained()) {
289     // LZ4 doesn't support streaming, so we have to coalesce
290     clone = data->clone();
291     clone->coalesce();
292     data = clone.get();
293   }
294
295   folly::io::Cursor cursor(data);
296   uint64_t actualUncompressedLength;
297   if (encodeSize()) {
298     actualUncompressedLength = decodeVarintFromCursor(cursor);
299     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
300         uncompressedLength != actualUncompressedLength) {
301       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
302     }
303   } else {
304     actualUncompressedLength = uncompressedLength;
305     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
306         actualUncompressedLength > maxUncompressedLength()) {
307       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
308     }
309   }
310
311   auto sp = StringPiece{cursor.peekBytes()};
312   auto out = IOBuf::create(actualUncompressedLength);
313   int n = LZ4_decompress_safe(
314       sp.data(),
315       reinterpret_cast<char*>(out->writableTail()),
316       sp.size(),
317       actualUncompressedLength);
318
319   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
320     throw std::runtime_error(to<std::string>(
321         "LZ4 decompression returned invalid value ", n));
322   }
323   out->append(actualUncompressedLength);
324   return out;
325 }
326
327 #endif  // FOLLY_HAVE_LIBLZ4
328
329 #if FOLLY_HAVE_LIBSNAPPY
330
331 /**
332  * Snappy compression
333  */
334
335 /**
336  * Implementation of snappy::Source that reads from a IOBuf chain.
337  */
338 class IOBufSnappySource final : public snappy::Source {
339  public:
340   explicit IOBufSnappySource(const IOBuf* data);
341   size_t Available() const override;
342   const char* Peek(size_t* len) override;
343   void Skip(size_t n) override;
344  private:
345   size_t available_;
346   io::Cursor cursor_;
347 };
348
349 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
350   : available_(data->computeChainDataLength()),
351     cursor_(data) {
352 }
353
354 size_t IOBufSnappySource::Available() const {
355   return available_;
356 }
357
358 const char* IOBufSnappySource::Peek(size_t* len) {
359   auto sp = StringPiece{cursor_.peekBytes()};
360   *len = sp.size();
361   return sp.data();
362 }
363
364 void IOBufSnappySource::Skip(size_t n) {
365   CHECK_LE(n, available_);
366   cursor_.skip(n);
367   available_ -= n;
368 }
369
370 class SnappyCodec final : public Codec {
371  public:
372   static std::unique_ptr<Codec> create(int level, CodecType type);
373   explicit SnappyCodec(int level, CodecType type);
374
375  private:
376   uint64_t doMaxUncompressedLength() const override;
377   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
378   std::unique_ptr<IOBuf> doUncompress(
379       const IOBuf* data,
380       uint64_t uncompressedLength) override;
381 };
382
383 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
384   return make_unique<SnappyCodec>(level, type);
385 }
386
387 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
388   DCHECK(type == CodecType::SNAPPY);
389   switch (level) {
390   case COMPRESSION_LEVEL_FASTEST:
391   case COMPRESSION_LEVEL_DEFAULT:
392   case COMPRESSION_LEVEL_BEST:
393     level = 1;
394   }
395   if (level != 1) {
396     throw std::invalid_argument(to<std::string>(
397         "SnappyCodec: invalid level: ", level));
398   }
399 }
400
401 uint64_t SnappyCodec::doMaxUncompressedLength() const {
402   // snappy.h uses uint32_t for lengths, so there's that.
403   return std::numeric_limits<uint32_t>::max();
404 }
405
406 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
407   IOBufSnappySource source(data);
408   auto out =
409     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
410
411   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
412       out->writableTail()));
413
414   size_t n = snappy::Compress(&source, &sink);
415
416   CHECK_LE(n, out->capacity());
417   out->append(n);
418   return out;
419 }
420
421 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
422                                                  uint64_t uncompressedLength) {
423   uint32_t actualUncompressedLength = 0;
424
425   {
426     IOBufSnappySource source(data);
427     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
428       throw std::runtime_error("snappy::GetUncompressedLength failed");
429     }
430     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
431         uncompressedLength != actualUncompressedLength) {
432       throw std::runtime_error("snappy: invalid uncompressed length");
433     }
434   }
435
436   auto out = IOBuf::create(actualUncompressedLength);
437
438   {
439     IOBufSnappySource source(data);
440     if (!snappy::RawUncompress(&source,
441                                reinterpret_cast<char*>(out->writableTail()))) {
442       throw std::runtime_error("snappy::RawUncompress failed");
443     }
444   }
445
446   out->append(actualUncompressedLength);
447   return out;
448 }
449
450 #endif  // FOLLY_HAVE_LIBSNAPPY
451
452 #if FOLLY_HAVE_LIBZ
453 /**
454  * Zlib codec
455  */
456 class ZlibCodec final : public Codec {
457  public:
458   static std::unique_ptr<Codec> create(int level, CodecType type);
459   explicit ZlibCodec(int level, CodecType type);
460
461  private:
462   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
463   std::unique_ptr<IOBuf> doUncompress(
464       const IOBuf* data,
465       uint64_t uncompressedLength) override;
466
467   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
468   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
469
470   int level_;
471 };
472
473 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
474   return make_unique<ZlibCodec>(level, type);
475 }
476
477 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
478   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
479   switch (level) {
480   case COMPRESSION_LEVEL_FASTEST:
481     level = 1;
482     break;
483   case COMPRESSION_LEVEL_DEFAULT:
484     level = Z_DEFAULT_COMPRESSION;
485     break;
486   case COMPRESSION_LEVEL_BEST:
487     level = 9;
488     break;
489   }
490   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
491     throw std::invalid_argument(to<std::string>(
492         "ZlibCodec: invalid level: ", level));
493   }
494   level_ = level;
495 }
496
497 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
498                                                   uint32_t length) {
499   CHECK_EQ(stream->avail_out, 0);
500
501   auto buf = IOBuf::create(length);
502   buf->append(length);
503
504   stream->next_out = buf->writableData();
505   stream->avail_out = buf->length();
506
507   return buf;
508 }
509
510 bool ZlibCodec::doInflate(z_stream* stream,
511                           IOBuf* head,
512                           uint32_t bufferLength) {
513   if (stream->avail_out == 0) {
514     head->prependChain(addOutputBuffer(stream, bufferLength));
515   }
516
517   int rc = inflate(stream, Z_NO_FLUSH);
518
519   switch (rc) {
520   case Z_OK:
521     break;
522   case Z_STREAM_END:
523     return true;
524   case Z_BUF_ERROR:
525   case Z_NEED_DICT:
526   case Z_DATA_ERROR:
527   case Z_MEM_ERROR:
528     throw std::runtime_error(to<std::string>(
529         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
530   default:
531     CHECK(false) << rc << ": " << stream->msg;
532   }
533
534   return false;
535 }
536
537 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
538   z_stream stream;
539   stream.zalloc = nullptr;
540   stream.zfree = nullptr;
541   stream.opaque = nullptr;
542
543   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
544   // base two logarithm of the maximum window size (...) The default value is
545   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
546   // around the compressed data instead of a zlib wrapper. The gzip header
547   // will have no file name, no extra data, no comment, no modification time
548   // (set to zero), no header crc, and the operating system will be set to 255
549   // (unknown)."
550   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
551   // All other parameters (method, memLevel, strategy) get default values from
552   // the zlib manual.
553   int rc = deflateInit2(&stream,
554                         level_,
555                         Z_DEFLATED,
556                         windowBits,
557                         /* memLevel */ 8,
558                         Z_DEFAULT_STRATEGY);
559   if (rc != Z_OK) {
560     throw std::runtime_error(to<std::string>(
561         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
562   }
563
564   stream.next_in = stream.next_out = nullptr;
565   stream.avail_in = stream.avail_out = 0;
566   stream.total_in = stream.total_out = 0;
567
568   bool success = false;
569
570   SCOPE_EXIT {
571     rc = deflateEnd(&stream);
572     // If we're here because of an exception, it's okay if some data
573     // got dropped.
574     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
575       << rc << ": " << stream.msg;
576   };
577
578   uint64_t uncompressedLength = data->computeChainDataLength();
579   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
580
581   // Max 64MiB in one go
582   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
583   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
584
585   auto out = addOutputBuffer(
586       &stream,
587       (maxCompressedLength <= maxSingleStepLength ?
588        maxCompressedLength :
589        defaultBufferLength));
590
591   for (auto& range : *data) {
592     uint64_t remaining = range.size();
593     uint64_t written = 0;
594     while (remaining) {
595       uint32_t step = (remaining > maxSingleStepLength ?
596                        maxSingleStepLength : remaining);
597       stream.next_in = const_cast<uint8_t*>(range.data() + written);
598       stream.avail_in = step;
599       remaining -= step;
600       written += step;
601
602       while (stream.avail_in != 0) {
603         if (stream.avail_out == 0) {
604           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
605         }
606
607         rc = deflate(&stream, Z_NO_FLUSH);
608
609         CHECK_EQ(rc, Z_OK) << stream.msg;
610       }
611     }
612   }
613
614   do {
615     if (stream.avail_out == 0) {
616       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
617     }
618
619     rc = deflate(&stream, Z_FINISH);
620   } while (rc == Z_OK);
621
622   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
623
624   out->prev()->trimEnd(stream.avail_out);
625
626   success = true;  // we survived
627
628   return out;
629 }
630
631 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
632                                                uint64_t uncompressedLength) {
633   z_stream stream;
634   stream.zalloc = nullptr;
635   stream.zfree = nullptr;
636   stream.opaque = nullptr;
637
638   // "The windowBits parameter is the base two logarithm of the maximum window
639   // size (...) The default value is 15 (...) add 16 to decode only the gzip
640   // format (the zlib format will return a Z_DATA_ERROR)."
641   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
642   int rc = inflateInit2(&stream, windowBits);
643   if (rc != Z_OK) {
644     throw std::runtime_error(to<std::string>(
645         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
646   }
647
648   stream.next_in = stream.next_out = nullptr;
649   stream.avail_in = stream.avail_out = 0;
650   stream.total_in = stream.total_out = 0;
651
652   bool success = false;
653
654   SCOPE_EXIT {
655     rc = inflateEnd(&stream);
656     // If we're here because of an exception, it's okay if some data
657     // got dropped.
658     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
659       << rc << ": " << stream.msg;
660   };
661
662   // Max 64MiB in one go
663   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
664   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
665
666   auto out = addOutputBuffer(
667       &stream,
668       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
669         uncompressedLength <= maxSingleStepLength) ?
670        uncompressedLength :
671        defaultBufferLength));
672
673   bool streamEnd = false;
674   for (auto& range : *data) {
675     if (range.empty()) {
676       continue;
677     }
678
679     stream.next_in = const_cast<uint8_t*>(range.data());
680     stream.avail_in = range.size();
681
682     while (stream.avail_in != 0) {
683       if (streamEnd) {
684         throw std::runtime_error(to<std::string>(
685             "ZlibCodec: junk after end of data"));
686       }
687
688       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
689     }
690   }
691
692   while (!streamEnd) {
693     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
694   }
695
696   out->prev()->trimEnd(stream.avail_out);
697
698   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
699       uncompressedLength != stream.total_out) {
700     throw std::runtime_error(to<std::string>(
701         "ZlibCodec: invalid uncompressed length"));
702   }
703
704   success = true;  // we survived
705
706   return out;
707 }
708
709 #endif  // FOLLY_HAVE_LIBZ
710
711 #if FOLLY_HAVE_LIBLZMA
712
713 /**
714  * LZMA2 compression
715  */
716 class LZMA2Codec final : public Codec {
717  public:
718   static std::unique_ptr<Codec> create(int level, CodecType type);
719   explicit LZMA2Codec(int level, CodecType type);
720
721  private:
722   bool doNeedsUncompressedLength() const override;
723   uint64_t doMaxUncompressedLength() const override;
724
725   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
726
727   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
728   std::unique_ptr<IOBuf> doUncompress(
729       const IOBuf* data,
730       uint64_t uncompressedLength) override;
731
732   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
733   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
734
735   int level_;
736 };
737
738 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
739   return make_unique<LZMA2Codec>(level, type);
740 }
741
742 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
743   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
744   switch (level) {
745   case COMPRESSION_LEVEL_FASTEST:
746     level = 0;
747     break;
748   case COMPRESSION_LEVEL_DEFAULT:
749     level = LZMA_PRESET_DEFAULT;
750     break;
751   case COMPRESSION_LEVEL_BEST:
752     level = 9;
753     break;
754   }
755   if (level < 0 || level > 9) {
756     throw std::invalid_argument(to<std::string>(
757         "LZMA2Codec: invalid level: ", level));
758   }
759   level_ = level;
760 }
761
762 bool LZMA2Codec::doNeedsUncompressedLength() const {
763   return !encodeSize();
764 }
765
766 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
767   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
768   return uint64_t(1) << 63;
769 }
770
771 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
772     lzma_stream* stream,
773     size_t length) {
774
775   CHECK_EQ(stream->avail_out, 0);
776
777   auto buf = IOBuf::create(length);
778   buf->append(length);
779
780   stream->next_out = buf->writableData();
781   stream->avail_out = buf->length();
782
783   return buf;
784 }
785
786 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
787   lzma_ret rc;
788   lzma_stream stream = LZMA_STREAM_INIT;
789
790   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
791   if (rc != LZMA_OK) {
792     throw std::runtime_error(folly::to<std::string>(
793       "LZMA2Codec: lzma_easy_encoder error: ", rc));
794   }
795
796   SCOPE_EXIT { lzma_end(&stream); };
797
798   uint64_t uncompressedLength = data->computeChainDataLength();
799   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
800
801   // Max 64MiB in one go
802   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
803   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
804
805   auto out = addOutputBuffer(
806     &stream,
807     (maxCompressedLength <= maxSingleStepLength ?
808      maxCompressedLength :
809      defaultBufferLength));
810
811   if (encodeSize()) {
812     auto size = IOBuf::createCombined(kMaxVarintLength64);
813     encodeVarintToIOBuf(uncompressedLength, size.get());
814     size->appendChain(std::move(out));
815     out = std::move(size);
816   }
817
818   for (auto& range : *data) {
819     if (range.empty()) {
820       continue;
821     }
822
823     stream.next_in = const_cast<uint8_t*>(range.data());
824     stream.avail_in = range.size();
825
826     while (stream.avail_in != 0) {
827       if (stream.avail_out == 0) {
828         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
829       }
830
831       rc = lzma_code(&stream, LZMA_RUN);
832
833       if (rc != LZMA_OK) {
834         throw std::runtime_error(folly::to<std::string>(
835           "LZMA2Codec: lzma_code error: ", rc));
836       }
837     }
838   }
839
840   do {
841     if (stream.avail_out == 0) {
842       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
843     }
844
845     rc = lzma_code(&stream, LZMA_FINISH);
846   } while (rc == LZMA_OK);
847
848   if (rc != LZMA_STREAM_END) {
849     throw std::runtime_error(folly::to<std::string>(
850       "LZMA2Codec: lzma_code ended with error: ", rc));
851   }
852
853   out->prev()->trimEnd(stream.avail_out);
854
855   return out;
856 }
857
858 bool LZMA2Codec::doInflate(lzma_stream* stream,
859                           IOBuf* head,
860                           size_t bufferLength) {
861   if (stream->avail_out == 0) {
862     head->prependChain(addOutputBuffer(stream, bufferLength));
863   }
864
865   lzma_ret rc = lzma_code(stream, LZMA_RUN);
866
867   switch (rc) {
868   case LZMA_OK:
869     break;
870   case LZMA_STREAM_END:
871     return true;
872   default:
873     throw std::runtime_error(to<std::string>(
874         "LZMA2Codec: lzma_code error: ", rc));
875   }
876
877   return false;
878 }
879
880 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
881                                                uint64_t uncompressedLength) {
882   lzma_ret rc;
883   lzma_stream stream = LZMA_STREAM_INIT;
884
885   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
886   if (rc != LZMA_OK) {
887     throw std::runtime_error(folly::to<std::string>(
888       "LZMA2Codec: lzma_auto_decoder error: ", rc));
889   }
890
891   SCOPE_EXIT { lzma_end(&stream); };
892
893   // Max 64MiB in one go
894   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
895   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
896
897   folly::io::Cursor cursor(data);
898   uint64_t actualUncompressedLength;
899   if (encodeSize()) {
900     actualUncompressedLength = decodeVarintFromCursor(cursor);
901     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
902         uncompressedLength != actualUncompressedLength) {
903       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
904     }
905   } else {
906     actualUncompressedLength = uncompressedLength;
907     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
908   }
909
910   auto out = addOutputBuffer(
911       &stream,
912       (actualUncompressedLength <= maxSingleStepLength ?
913        actualUncompressedLength :
914        defaultBufferLength));
915
916   bool streamEnd = false;
917   auto buf = cursor.peekBytes();
918   while (!buf.empty()) {
919     stream.next_in = const_cast<uint8_t*>(buf.data());
920     stream.avail_in = buf.size();
921
922     while (stream.avail_in != 0) {
923       if (streamEnd) {
924         throw std::runtime_error(to<std::string>(
925             "LZMA2Codec: junk after end of data"));
926       }
927
928       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
929     }
930
931     cursor.skip(buf.size());
932     buf = cursor.peekBytes();
933   }
934
935   while (!streamEnd) {
936     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
937   }
938
939   out->prev()->trimEnd(stream.avail_out);
940
941   if (actualUncompressedLength != stream.total_out) {
942     throw std::runtime_error(to<std::string>(
943         "LZMA2Codec: invalid uncompressed length"));
944   }
945
946   return out;
947 }
948
949 #endif  // FOLLY_HAVE_LIBLZMA
950
951 #ifdef FOLLY_HAVE_LIBZSTD
952
953 /**
954  * ZSTD compression
955  */
956 class ZSTDCodec final : public Codec {
957  public:
958   static std::unique_ptr<Codec> create(int level, CodecType);
959   explicit ZSTDCodec(int level, CodecType type);
960
961  private:
962   bool doNeedsUncompressedLength() const override;
963   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
964   std::unique_ptr<IOBuf> doUncompress(
965       const IOBuf* data,
966       uint64_t uncompressedLength) override;
967
968   int level_;
969 };
970
971 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
972   return make_unique<ZSTDCodec>(level, type);
973 }
974
975 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
976   DCHECK(type == CodecType::ZSTD);
977   switch (level) {
978     case COMPRESSION_LEVEL_FASTEST:
979       level = 1;
980       break;
981     case COMPRESSION_LEVEL_DEFAULT:
982       level = 1;
983       break;
984     case COMPRESSION_LEVEL_BEST:
985       level = 19;
986       break;
987   }
988   if (level < 1 || level > ZSTD_maxCLevel()) {
989     throw std::invalid_argument(
990         to<std::string>("ZSTD: invalid level: ", level));
991   }
992   level_ = level;
993 }
994
995 bool ZSTDCodec::doNeedsUncompressedLength() const {
996   return false;
997 }
998
999 void zstdThrowIfError(size_t rc) {
1000   if (!ZSTD_isError(rc)) {
1001     return;
1002   }
1003   throw std::runtime_error(
1004       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1005 }
1006
1007 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
1008   // Support earlier versions of the codec (working with a single IOBuf,
1009   // and using ZSTD_decompress which requires ZSTD frame to contain size,
1010   // which isn't populated by streaming API).
1011   if (!data->isChained()) {
1012     auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
1013     const auto rc = ZSTD_compress(
1014         out->writableData(),
1015         out->capacity(),
1016         data->data(),
1017         data->length(),
1018         level_);
1019     zstdThrowIfError(rc);
1020     out->append(rc);
1021     return out;
1022   }
1023
1024   auto zcs = ZSTD_createCStream();
1025   SCOPE_EXIT {
1026     ZSTD_freeCStream(zcs);
1027   };
1028
1029   auto rc = ZSTD_initCStream(zcs, level_);
1030   zstdThrowIfError(rc);
1031
1032   Cursor cursor(data);
1033   auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
1034
1035   ZSTD_outBuffer out;
1036   out.dst = result->writableTail();
1037   out.size = result->capacity();
1038   out.pos = 0;
1039
1040   for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
1041     ZSTD_inBuffer in;
1042     in.src = buffer.data();
1043     in.size = buffer.size();
1044     for (in.pos = 0; in.pos != in.size;) {
1045       rc = ZSTD_compressStream(zcs, &out, &in);
1046       zstdThrowIfError(rc);
1047     }
1048     cursor.skip(in.size);
1049     buffer = cursor.peekBytes();
1050   }
1051
1052   rc = ZSTD_endStream(zcs, &out);
1053   zstdThrowIfError(rc);
1054   CHECK_EQ(rc, 0);
1055
1056   result->append(out.pos);
1057   return result;
1058 }
1059
1060 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
1061     const IOBuf* data,
1062     uint64_t uncompressedLength) {
1063   auto zds = ZSTD_createDStream();
1064   SCOPE_EXIT {
1065     ZSTD_freeDStream(zds);
1066   };
1067
1068   auto rc = ZSTD_initDStream(zds);
1069   zstdThrowIfError(rc);
1070
1071   ZSTD_outBuffer out{};
1072   ZSTD_inBuffer in{};
1073
1074   auto outputSize = ZSTD_DStreamOutSize();
1075   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
1076     outputSize = uncompressedLength;
1077   } else {
1078     auto decompressedSize =
1079         ZSTD_getDecompressedSize(data->data(), data->length());
1080     if (decompressedSize != 0 && decompressedSize < outputSize) {
1081       outputSize = decompressedSize;
1082     }
1083   }
1084
1085   IOBufQueue queue(IOBufQueue::cacheChainLength());
1086
1087   Cursor cursor(data);
1088   for (rc = 0;;) {
1089     if (in.pos == in.size) {
1090       auto buffer = cursor.peekBytes();
1091       in.src = buffer.data();
1092       in.size = buffer.size();
1093       in.pos = 0;
1094       cursor.skip(in.size);
1095       if (rc > 1 && in.size == 0) {
1096         throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
1097       }
1098     }
1099     if (out.pos == out.size) {
1100       if (out.pos != 0) {
1101         queue.postallocate(out.pos);
1102       }
1103       auto buffer = queue.preallocate(outputSize, outputSize);
1104       out.dst = buffer.first;
1105       out.size = buffer.second;
1106       out.pos = 0;
1107       outputSize = ZSTD_DStreamOutSize();
1108     }
1109     rc = ZSTD_decompressStream(zds, &out, &in);
1110     zstdThrowIfError(rc);
1111     if (rc == 0) {
1112       break;
1113     }
1114   }
1115   if (out.pos != 0) {
1116     queue.postallocate(out.pos);
1117   }
1118   if (in.pos != in.size || !cursor.isAtEnd()) {
1119     throw std::runtime_error("ZSTD: junk after end of data");
1120   }
1121   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1122       queue.chainLength() != uncompressedLength) {
1123     throw std::runtime_error("ZSTD: invalid uncompressed length");
1124   }
1125
1126   return queue.move();
1127 }
1128
1129 #endif  // FOLLY_HAVE_LIBZSTD
1130
1131 }  // namespace
1132
1133 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1134   typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1135
1136   static CodecFactory codecFactories[
1137     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1138     nullptr,  // USER_DEFINED
1139     NoCompressionCodec::create,
1140
1141 #if FOLLY_HAVE_LIBLZ4
1142     LZ4Codec::create,
1143 #else
1144     nullptr,
1145 #endif
1146
1147 #if FOLLY_HAVE_LIBSNAPPY
1148     SnappyCodec::create,
1149 #else
1150     nullptr,
1151 #endif
1152
1153 #if FOLLY_HAVE_LIBZ
1154     ZlibCodec::create,
1155 #else
1156     nullptr,
1157 #endif
1158
1159 #if FOLLY_HAVE_LIBLZ4
1160     LZ4Codec::create,
1161 #else
1162     nullptr,
1163 #endif
1164
1165 #if FOLLY_HAVE_LIBLZMA
1166     LZMA2Codec::create,
1167     LZMA2Codec::create,
1168 #else
1169     nullptr,
1170     nullptr,
1171 #endif
1172
1173 #if FOLLY_HAVE_LIBZSTD
1174     ZSTDCodec::create,
1175 #else
1176     nullptr,
1177 #endif
1178
1179 #if FOLLY_HAVE_LIBZ
1180     ZlibCodec::create,
1181 #else
1182     nullptr,
1183 #endif
1184   };
1185
1186   size_t idx = static_cast<size_t>(type);
1187   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1188     throw std::invalid_argument(to<std::string>(
1189         "Compression type ", idx, " not supported"));
1190   }
1191   auto factory = codecFactories[idx];
1192   if (!factory) {
1193     throw std::invalid_argument(to<std::string>(
1194         "Compression type ", idx, " not supported"));
1195   }
1196   auto codec = (*factory)(level, type);
1197   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1198   return codec;
1199 }
1200
1201 }}  // namespaces