/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "Adaptor.hh"
#include "Compression.hh"
#include "lz4.h"
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <sstream>
#include "securec.h"
#include "zlib.h"
#include "zstd.h"
#include "wrap/snappy_wrapper.h"
#ifndef ZSTD_CLEVEL_DEFAULT
#define ZSTD_CLEVEL_DEFAULT 3
#endif
#ifndef LZ4_ACCELERATION_DEFAULT
#define LZ4_ACCELERATION_DEFAULT 1
#endif
#ifndef LZ4_ACCELERATION_MAX
#define LZ4_ACCELERATION_MAX 65537
#endif
namespace spark {
class CompressionStreamBase: public BufferedOutputStream {
public:
CompressionStreamBase(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool);
virtual bool Next(void** data, int*size) override = 0;
virtual void BackUp(int count) override;
virtual std::string getName() const override = 0;
virtual uint64_t flush() override;
virtual bool isCompressed() const override { return true; }
virtual uint64_t getSize() const override;
protected:
void writeHeader(char * buffer, size_t compressedSize, bool original) {
buffer[0] = static_cast<char>((compressedSize << 1) + (original ? 1 : 0));
buffer[1] = static_cast<char>(compressedSize >> 7);
buffer[2] = static_cast<char>(compressedSize >> 15);
}
void ensureHeader();
DataBuffer<unsigned char> rawInputBuffer;
int level;
char * outputBuffer;
int bufferSize;
int outputPosition;
int outputSize;
};
/**
 * Sets up the underlying buffered stream and a raw input buffer of
 * `blockSize` elements; no output chunk is held yet, so the output
 * pointer/position/size all start out empty.
 *
 * @param outStream        destination stream forwarded to BufferedOutputStream
 * @param compressionLevel stored in `level` for use by the concrete codec
 * @param capacity         forwarded to BufferedOutputStream
 * @param blockSize        forwarded to BufferedOutputStream and used to size
 *                         rawInputBuffer
 * @param pool             memory pool used for both buffers
 */
CompressionStreamBase::CompressionStreamBase(OutputStream* outStream,
                                             int compressionLevel,
                                             uint64_t capacity,
                                             uint64_t blockSize,
                                             MemoryPool& pool)
    : BufferedOutputStream(pool, outStream, capacity, blockSize),
      rawInputBuffer(pool, blockSize),
      level(compressionLevel),
      outputBuffer(nullptr),
      bufferSize(0),
      outputPosition(0),
      outputSize(0) {
  // Everything is initialised in the member initializer list.
}
/**
 * Returns the last `count` bytes of the raw input buffer to the stream, so
 * they are not treated as pending input.
 *
 * @param count number of bytes to give back
 * @throws std::logic_error if `count` exceeds the pending byte count
 */
void CompressionStreamBase::BackUp(int count) {
  const bool fitsInPending = (count <= bufferSize);
  if (!fitsInPending) {
    throw std::logic_error("Can't backup that much!");
  }
  bufferSize = bufferSize - count;
}
/**
 * Compresses whatever input is pending, returns the unused tail of the
 * current output chunk to the underlying stream, resets the local bookkeeping
 * and flushes the underlying stream.
 *
 * @return the value returned by BufferedOutputStream::flush()
 * @throws std::runtime_error if Next() reports failure
 */
uint64_t CompressionStreamBase::flush() {
  // Next() is called only for its side effect of compressing pending input;
  // the buffer it hands out is not used here.
  void* ignoredData;
  int ignoredSize;
  const bool compressedOk = Next(&ignoredData, &ignoredSize);
  if (!compressedOk) {
    throw std::runtime_error("Failed to flush compression buffer.");
  }

  // Give back the part of the current chunk that was never written to.
  const int unusedTail = outputSize - outputPosition;
  BufferedOutputStream::BackUp(unusedTail);

  outputPosition = 0;
  outputSize = 0;
  bufferSize = 0;

  return BufferedOutputStream::flush();
}
/**
 * Reports the underlying stream's size minus the portion of the current
 * output chunk that has been reserved but not yet written.
 */
uint64_t CompressionStreamBase::getSize() const {
  const int unwrittenTail = outputSize - outputPosition;
  const uint64_t reservedSize = BufferedOutputStream::getSize();
  return reservedSize - static_cast<uint64_t>(unwrittenTail);
}
/**
 * Moves the output position past a 3-byte block header.
 *
 * If skipping 3 bytes would reach or pass the end of the current chunk, a new
 * chunk is requested and the position is set to however many header bytes
 * spill into it (0, 1, 2 or 3).
 * NOTE(review): callers then address the header relative to the new chunk
 * pointer, which presumably relies on consecutive chunks being contiguous in
 * memory - confirm against BufferedOutputStream::Next.
 *
 * @throws std::runtime_error if the underlying stream cannot supply a chunk
 */
void CompressionStreamBase::ensureHeader() {
  const int positionAfterHeader = outputPosition + 3;

  // Common case: the header fits strictly inside the current chunk.
  if (positionAfterHeader < outputSize) {
    outputPosition = positionAfterHeader;
    return;
  }

  // Must be computed before Next() overwrites outputSize.
  const int spillIntoNextChunk = positionAfterHeader - outputSize;
  const bool gotChunk = BufferedOutputStream::Next(
      reinterpret_cast<void **>(&outputBuffer), &outputSize);
  if (!gotChunk) {
    throw std::runtime_error(
        "Failed to get next output buffer from output stream.");
  }
  outputPosition = spillIntoNextChunk;
}
/**
 * Streaming compression base class
 */
class CompressionStream: public CompressionStreamBase {
public:
CompressionStream(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool);
virtual bool Next(void** data, int*size) override;
virtual std::string getName() const override = 0;
protected:
virtual uint64_t doStreamingCompression() = 0;
};
/**
 * Forwards every argument unchanged to CompressionStreamBase; this class adds
 * no state of its own.
 */
CompressionStream::CompressionStream(OutputStream* outStream,
                                     int compressionLevel,
                                     uint64_t capacity,
                                     uint64_t blockSize,
                                     MemoryPool& pool)
    : CompressionStreamBase(outStream, compressionLevel, capacity, blockSize, pool) {
  // Nothing else to initialise.
}
/**
 * Compresses the pending input block, if there is one, and then hands the raw
 * input buffer back to the caller to be filled with the next block.
 *
 * When the compressed form is not smaller than the input, the input bytes are
 * stored verbatim instead, the header is flagged as "original", and the
 * surplus output space is returned to the underlying stream.
 *
 * @param data receives a pointer to the raw input buffer
 * @param size receives the size of the raw input buffer
 * @return always true
 * @throws std::runtime_error if the raw block cannot be copied into the
 *         output (memcpy_s reports an error), or from ensureHeader() /
 *         doStreamingCompression()
 */
bool CompressionStream::Next(void** data, int* size) {
  if (bufferSize != 0) {
    // Reserve the 3 header bytes before the codec writes the payload.
    ensureHeader();

    const uint64_t totalCompressedSize = doStreamingCompression();

    // The header sits immediately before the payload just produced.
    // NOTE(review): stepping back from the current chunk pointer presumably
    // relies on the output chunks being contiguous in memory - confirm.
    char* header = outputBuffer + outputPosition - totalCompressedSize - 3;

    if (totalCompressedSize >= static_cast<uint64_t>(bufferSize)) {
      // Compression did not help: overwrite the payload with the raw bytes.
      const size_t rawSize = static_cast<size_t>(bufferSize);
      writeHeader(header, rawSize, true);

      // memcpy_s reports failures through its return value; a non-zero result
      // means the output does not hold the block, so it must not be ignored.
      if (memcpy_s(header + 3, rawSize, rawInputBuffer.data(), rawSize) != 0) {
        throw std::runtime_error(
            "Failed to copy uncompressed block into output buffer.");
      }

      // Hand back the bytes the compressed form used beyond the raw size.
      const int backup = static_cast<int>(totalCompressedSize) - bufferSize;
      BufferedOutputStream::BackUp(backup);
      outputPosition -= backup;
      outputSize -= backup;
    } else {
      writeHeader(header, totalCompressedSize, false);
    }
  }

  *data = rawInputBuffer.data();
  *size = static_cast<int>(rawInputBuffer.size());
  bufferSize = *size;
  return true;
}
class ZlibCompressionStream: public CompressionStream {
public:
ZlibCompressionStream(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool);
virtual ~ZlibCompressionStream() override {
end();
}
virtual std::string getName() const override;
protected:
virtual uint64_t doStreamingCompression() override;
private:
void init();
void end();
z_stream strm;
};
/**
 * Forwards all arguments to CompressionStream and then initialises the zlib
 * deflate state; init() throws if zlib cannot be set up.
 */
ZlibCompressionStream::ZlibCompressionStream(OutputStream* outStream,
                                             int compressionLevel,
                                             uint64_t capacity,
                                             uint64_t blockSize,
                                             MemoryPool& pool)
    : CompressionStream(outStream, compressionLevel, capacity, blockSize, pool) {
  init();
}
/**
 * Deflates the pending bytes of rawInputBuffer into the output chunks.
 *
 * The zlib stream is reset for every block, so each block is compressed
 * independently. Output is written starting at outputPosition; whenever the
 * current chunk is full a new one is requested from the underlying stream.
 *
 * @return total number of compressed bytes produced for this block
 * @throws std::runtime_error if the stream cannot be reset, no output chunk
 *         can be obtained, or deflate() reports an error
 */
uint64_t ZlibCompressionStream::doStreamingCompression() {
  if (deflateReset(&strm) != Z_OK) {
    // This is the deflate (compression) side; the message used to say
    // "inflate", which pointed at the wrong operation.
    throw std::runtime_error("Failed to reset deflate.");
  }

  strm.avail_in = static_cast<unsigned int>(bufferSize);
  strm.next_in = rawInputBuffer.data();

  do {
    // Current chunk exhausted: fetch the next one and start at its beginning.
    if (outputPosition >= outputSize) {
      if (!BufferedOutputStream::Next(
            reinterpret_cast<void **>(&outputBuffer),
            &outputSize)) {
        throw std::runtime_error(
          "Failed to get next output buffer from output stream.");
      }
      outputPosition = 0;
    }

    strm.next_out = reinterpret_cast<unsigned char *>
      (outputBuffer + outputPosition);
    strm.avail_out = static_cast<unsigned int>
      (outputSize - outputPosition);

    // Z_FINISH: all input for this block is already available.
    int ret = deflate(&strm, Z_FINISH);

    // Whatever zlib left unused marks the new write position in the chunk.
    outputPosition = outputSize - static_cast<int>(strm.avail_out);

    if (ret == Z_STREAM_END) {
      break;
    } else if (ret == Z_OK) {
      // More output space is needed; the loop condition fetches a new chunk.
    } else {
      throw std::runtime_error("Failed to deflate input data.");
    }
  } while (strm.avail_out == 0);

  return strm.total_out;
}
// Human-readable name of this stream implementation.
std::string ZlibCompressionStream::getName() const {
  return std::string("ZlibCompressionStream");
}
#if defined(__GNUC__) || defined(__clang__)
DIAGNOSTIC_IGNORE("-Wold-style-cast")
#endif
/**
 * Initialises the zlib deflate state with the configured level.
 *
 * Allocator hooks are left null so zlib uses its defaults. Per the zlib
 * manual, a negative windowBits (-15) selects raw deflate output without a
 * zlib header/trailer, and 8 is the memLevel.
 *
 * @throws std::runtime_error if deflateInit2() does not return Z_OK
 */
void ZlibCompressionStream::init() {
  strm.zalloc = nullptr;
  strm.zfree = nullptr;
  strm.opaque = nullptr;
  strm.next_in = nullptr;

  const int status =
      deflateInit2(&strm, level, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
  if (status != Z_OK) {
    throw std::runtime_error("Error while calling deflateInit2() for zlib.");
  }
}
// Releases the zlib deflate state; the result is deliberately discarded
// because this runs from the destructor.
void ZlibCompressionStream::end() {
  static_cast<void>(deflateEnd(&strm));
}
// States of a block decompressor. The code that consumes these values is not
// part of this section, so the per-state notes are inferred from the names
// only - NOTE(review): confirm against the decompression streams.
enum DecompressState {
  DECOMPRESS_HEADER,    // presumably: reading a block header
  DECOMPRESS_START,     // presumably: about to start a compressed block
  DECOMPRESS_CONTINUE,  // presumably: in the middle of a compressed block
  DECOMPRESS_ORIGINAL,  // presumably: passing through an uncompressed block
  DECOMPRESS_EOF        // presumably: no more input
};
#if defined(__GNUC__) || defined(__clang__)
DIAGNOSTIC_IGNORE("-Wold-style-cast")
#endif
/**
 * Block compression base class
 */
class BlockCompressionStream: public CompressionStreamBase {
public:
BlockCompressionStream(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool)
: CompressionStreamBase(outStream,
compressionLevel,
capacity,
blockSize,
pool)
, compressorBuffer(pool) {
}
virtual bool Next(void** data, int*size) override;
virtual std::string getName() const override = 0;
protected:
virtual uint64_t doBlockCompression() = 0;
virtual uint64_t estimateMaxCompressionSize() = 0;
DataBuffer<unsigned char> compressorBuffer;
};
/**
 * Compresses the pending input block, if there is one, writes it to the
 * output, and then hands the raw input buffer back to the caller.
 *
 * The block is written as a 3-byte header followed by either the compressed
 * bytes or, when compression does not make the block smaller, the raw input
 * flagged as "original". The payload may span several output chunks.
 * Afterwards compressorBuffer is resized to the codec's worst-case bound.
 *
 * @param data receives a pointer to the raw input buffer
 * @param size receives the size of the raw input buffer
 * @return always true
 * @throws std::logic_error if no output chunk can be obtained or the write
 *         position is past the end of the current chunk
 * @throws std::runtime_error if copying into the output fails (memcpy_s
 *         reports an error)
 */
bool BlockCompressionStream::Next(void** data, int* size) {
  if (bufferSize != 0) {
    ensureHeader();

    // Compress the whole block into compressorBuffer.
    const size_t totalCompressedSize = doBlockCompression();

    const unsigned char* dataToWrite = nullptr;
    int totalSizeToWrite = 0;

    // ensureHeader() left outputPosition just past the 3 header bytes.
    // NOTE(review): when the header straddles two chunks this steps back
    // before outputBuffer, which presumably relies on chunks being contiguous
    // in memory - confirm.
    char* header = outputBuffer + outputPosition - 3;
    if (totalCompressedSize >= static_cast<size_t>(bufferSize)) {
      // Compression did not help: store the raw bytes.
      writeHeader(header, static_cast<size_t>(bufferSize), true);
      dataToWrite = rawInputBuffer.data();
      totalSizeToWrite = bufferSize;
    } else {
      writeHeader(header, totalCompressedSize, false);
      dataToWrite = compressorBuffer.data();
      totalSizeToWrite = static_cast<int>(totalCompressedSize);
    }

    char* dst = header + 3;
    while (totalSizeToWrite > 0) {
      if (outputPosition == outputSize) {
        // Current chunk is full: continue at the start of a new one.
        if (!BufferedOutputStream::Next(reinterpret_cast<void **>(&outputBuffer),
                                        &outputSize)) {
          throw std::logic_error(
            "Failed to get next output buffer from output stream.");
        }
        outputPosition = 0;
        dst = outputBuffer;
      } else if (outputPosition > outputSize) {
        throw std::logic_error("Write to an out-of-bound place!");
      }

      const int sizeToWrite = std::min(totalSizeToWrite, outputSize - outputPosition);
      const size_t chunkBytes = static_cast<size_t>(sizeToWrite);

      // memcpy_s reports failures through its return value; a non-zero result
      // means the output does not hold the data, so it must not be ignored.
      if (memcpy_s(dst, chunkBytes, dataToWrite, chunkBytes) != 0) {
        throw std::runtime_error("Failed to copy block data into output buffer.");
      }

      outputPosition += sizeToWrite;
      dataToWrite += sizeToWrite;
      totalSizeToWrite -= sizeToWrite;
      dst += sizeToWrite;
    }
  }

  *data = rawInputBuffer.data();
  *size = static_cast<int>(rawInputBuffer.size());
  bufferSize = *size;

  // Make sure the scratch buffer can hold the worst case for the next block.
  compressorBuffer.resize(estimateMaxCompressionSize());
  return true;
}
/**
 * LZ4 block compression
 */
class Lz4CompressionSteam: public BlockCompressionStream {
public:
Lz4CompressionSteam(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool)
: BlockCompressionStream(outStream,
compressionLevel,
capacity,
blockSize,
pool) {
this->init();
}
virtual std::string getName() const override {
return "Lz4CompressionStream";
}
virtual ~Lz4CompressionSteam() override {
this->end();
}
protected:
virtual uint64_t doBlockCompression() override;
virtual uint64_t estimateMaxCompressionSize() override {
return static_cast<uint64_t>(LZ4_compressBound(bufferSize));
}
private:
void init();
void end();
LZ4_stream_t *state;
};
/**
 * Compresses the pending bytes of rawInputBuffer into compressorBuffer with
 * LZ4, using `level` as the acceleration factor.
 *
 * @return number of compressed bytes written to compressorBuffer
 * @throws std::runtime_error if LZ4 returns 0
 */
uint64_t Lz4CompressionSteam::doBlockCompression() {
  const char* source = reinterpret_cast<const char*>(rawInputBuffer.data());
  char* destination = reinterpret_cast<char*>(compressorBuffer.data());
  const int destinationCapacity = static_cast<int>(compressorBuffer.size());

  const int written = LZ4_compress_fast_extState(static_cast<void*>(state),
                                                 source,
                                                 destination,
                                                 bufferSize,
                                                 destinationCapacity,
                                                 level);
  if (written == 0) {
    throw std::runtime_error("Error during block compression using lz4.");
  }
  return static_cast<uint64_t>(written);
}
/**
 * Allocates the LZ4 stream state.
 *
 * @throws std::runtime_error if LZ4_createStream() returns null
 */
void Lz4CompressionSteam::init() {
  state = LZ4_createStream();
  if (state == nullptr) {
    throw std::runtime_error("Error while allocating state for lz4.");
  }
}
// Frees the LZ4 stream state and clears the pointer; the return value of
// LZ4_freeStream() is deliberately ignored.
void Lz4CompressionSteam::end() {
  static_cast<void>(LZ4_freeStream(state));
  state = nullptr;
}
/**
 * Snappy block compression
 */
class SnappyCompressionStream: public BlockCompressionStream {
public:
SnappyCompressionStream(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool)
: BlockCompressionStream(outStream,
compressionLevel,
capacity,
blockSize,
pool) {
}
virtual std::string getName() const override {
return "SnappyCompressionStream";
}
virtual ~SnappyCompressionStream() override {
}
protected:
virtual uint64_t doBlockCompression() override;
virtual uint64_t estimateMaxCompressionSize() override {
return static_cast<uint64_t>
(snappy::MaxCompressedLength(static_cast<size_t>(bufferSize)));
}
};
/**
 * Compresses the pending bytes of rawInputBuffer into compressorBuffer with
 * Snappy.
 *
 * @return the compressed length reported by snappy::RawCompress()
 */
uint64_t SnappyCompressionStream::doBlockCompression() {
  const char* source = reinterpret_cast<const char*>(rawInputBuffer.data());
  char* destination = reinterpret_cast<char*>(compressorBuffer.data());

  // Filled in by RawCompress().
  size_t compressedLength;
  snappy::RawCompress(source,
                      static_cast<size_t>(bufferSize),
                      destination,
                      &compressedLength);
  return static_cast<uint64_t>(compressedLength);
}
/**
 * ZSTD block compression
 */
class ZSTDCompressionStream: public BlockCompressionStream{
public:
ZSTDCompressionStream(OutputStream * outStream,
int compressionLevel,
uint64_t capacity,
uint64_t blockSize,
MemoryPool& pool)
: BlockCompressionStream(outStream,
compressionLevel,
capacity,
blockSize,
pool) {
this->init();
}
virtual std::string getName() const override {
return "ZstdCompressionStream";
}
virtual ~ZSTDCompressionStream() override {
this->end();
}
protected:
virtual uint64_t doBlockCompression() override;
virtual uint64_t estimateMaxCompressionSize() override {
return ZSTD_compressBound(static_cast<size_t>(bufferSize));
}
private:
void init();
void end();
ZSTD_CCtx *cctx;
};
/**
 * Compresses the pending bytes of rawInputBuffer into compressorBuffer with
 * zstd at the configured level.
 *
 * NOTE(review): the result is returned without a ZSTD_isError() check. zstd
 * error codes are presumably very large size_t values, in which case
 * BlockCompressionStream::Next treats them as "not smaller than the input"
 * and stores the block uncompressed - confirm this fallback is intended.
 *
 * @return the value returned by ZSTD_compressCCtx()
 */
uint64_t ZSTDCompressionStream::doBlockCompression() {
  const size_t result = ZSTD_compressCCtx(cctx,
                                          compressorBuffer.data(),
                                          compressorBuffer.size(),
                                          rawInputBuffer.data(),
                                          static_cast<size_t>(bufferSize),
                                          level);
  return result;
}
#if defined(__GNUC__) || defined(__clang__)
DIAGNOSTIC_IGNORE("-Wold-style-cast")
#endif
/**
 * Creates the zstd compression context.
 *
 * @throws std::runtime_error if ZSTD_createCCtx() returns null
 */
void ZSTDCompressionStream::init() {
  cctx = ZSTD_createCCtx();
  if (cctx == nullptr) {
    throw std::runtime_error("Error while calling ZSTD_createCCtx() for zstd.");
  }
}
// Frees the zstd compression context and clears the pointer; the return
// value of ZSTD_freeCCtx() is deliberately ignored.
void ZSTDCompressionStream::end() {
  static_cast<void>(ZSTD_freeCCtx(cctx));
  cctx = nullptr;
}
#if defined(__GNUC__) || defined(__clang__)
DIAGNOSTIC_IGNORE("-Wold-style-cast")
#endif
/**
 * Creates the output stream that implements the requested compression kind.
 *
 * The strategy only selects the codec level: for zlib and zstd a lower level
 * is chosen for CompressionStrategy_SPEED, for LZ4 the value is an
 * acceleration factor (maximum for SPEED), and Snappy always gets 0.
 * CompressionKind_NONE yields a plain BufferedOutputStream.
 *
 * @param kind                 compression codec to use
 * @param outStream            destination stream
 * @param strategy             speed-versus-default trade-off
 * @param bufferCapacity       forwarded as the stream capacity
 * @param compressionBlockSize forwarded as the block size
 * @param pool                 memory pool for the stream buffers
 * @return the newly created stream
 * @throws std::logic_error for LZO and any unrecognised kind
 */
std::unique_ptr<BufferedOutputStream>
createCompressor(
    CompressionKind kind,
    OutputStream* outStream,
    CompressionStrategy strategy,
    uint64_t bufferCapacity,
    uint64_t compressionBlockSize,
    MemoryPool& pool) {
  const bool preferSpeed = (strategy == CompressionStrategy_SPEED);
  std::unique_ptr<BufferedOutputStream> compressor;

  switch (static_cast<int64_t>(kind)) {
    case CompressionKind_NONE: {
      compressor.reset(new BufferedOutputStream(
          pool, outStream, bufferCapacity, compressionBlockSize));
      break;
    }
    case CompressionKind_ZLIB: {
      const int zlibLevel = preferSpeed ? Z_BEST_SPEED + 1 : Z_DEFAULT_COMPRESSION;
      compressor.reset(new ZlibCompressionStream(
          outStream, zlibLevel, bufferCapacity, compressionBlockSize, pool));
      break;
    }
    case CompressionKind_ZSTD: {
      const int zstdLevel = preferSpeed ? 1 : ZSTD_CLEVEL_DEFAULT;
      compressor.reset(new ZSTDCompressionStream(
          outStream, zstdLevel, bufferCapacity, compressionBlockSize, pool));
      break;
    }
    case CompressionKind_LZ4: {
      // For LZ4 the "level" is an acceleration factor.
      const int lz4Acceleration =
          preferSpeed ? LZ4_ACCELERATION_MAX : LZ4_ACCELERATION_DEFAULT;
      compressor.reset(new Lz4CompressionSteam(
          outStream, lz4Acceleration, bufferCapacity, compressionBlockSize, pool));
      break;
    }
    case CompressionKind_SNAPPY: {
      const int snappyLevel = 0;
      compressor.reset(new SnappyCompressionStream(
          outStream, snappyLevel, bufferCapacity, compressionBlockSize, pool));
      break;
    }
    case CompressionKind_LZO:
    default:
      throw std::logic_error("compression codec not supported");
  }

  return compressor;
}
}