* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <arrow/record_batch.h>
#include <arrow/result.h>
#include <arrow/util/compression.h>
#include <gtest/gtest.h>
#include "LocalRssClient.h"
#include "memory/VeloxColumnarBatch.h"
#include "shuffle/PartitionWriter.h"
#include "shuffle/VeloxShuffleReader.h"
#include "utils/Compression.h"
#include "velox/type/Type.h"
#include "velox/vector/tests/VectorTestUtils.h"
namespace gluten {
namespace {
std::string makeString(uint32_t length) {
static const std::string kLargeStringOf128Bytes =
"thisisalaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaargestringlengthmorethan16bytes";
std::string res{};
auto repeats = length / kLargeStringOf128Bytes.length();
while (repeats--) {
res.append(kLargeStringOf128Bytes);
}
if (auto remains = length % kLargeStringOf128Bytes.length()) {
res.append(kLargeStringOf128Bytes.substr(0, remains));
}
return res;
}
std::unique_ptr<PartitionWriter> createPartitionWriter(
PartitionWriterType partitionWriterType,
uint32_t numPartitions,
const std::string& dataFile,
const std::vector<std::string>& localDirs,
const PartitionWriterOptions& options,
arrow::MemoryPool* pool) {
if (partitionWriterType == PartitionWriterType::kRss) {
auto rssClient = std::make_unique<LocalRssClient>(dataFile);
return std::make_unique<RssPartitionWriter>(numPartitions, options, pool, std::move(rssClient));
}
return std::make_unique<LocalPartitionWriter>(numPartitions, options, pool, dataFile, localDirs);
}
}
struct ShuffleTestParams {
ShuffleWriterType shuffleWriterType;
PartitionWriterType partitionWriterType;
arrow::Compression::type compressionType;
int32_t compressionThreshold{0};
int32_t mergeBufferSize{0};
int32_t compressionBufferSize{0};
bool useRadixSort{false};
std::string toString() const {
std::ostringstream out;
out << "shuffleWriterType = " << shuffleWriterType << ", partitionWriterType = " << partitionWriterType
<< ", compressionType = " << compressionType << ", compressionThreshold = " << compressionThreshold
<< ", mergeBufferSize = " << mergeBufferSize << ", compressionBufferSize = " << compressionBufferSize
<< ", useRadixSort = " << (useRadixSort ? "true" : "false");
return out.str();
}
};
class VeloxShuffleWriterTestBase : public facebook::velox::test::VectorTestBase {
public:
virtual arrow::Status initShuffleWriterOptions() {
RETURN_NOT_OK(setLocalDirsAndDataFile());
return arrow::Status::OK();
}
protected:
void setUp() {
if (!isRegisteredNamedVectorSerde(facebook::velox::VectorSerde::Kind::kPresto)) {
facebook::velox::serializer::presto::PrestoVectorSerde::registerNamedVectorSerde();
}
children1_ = {
makeNullableFlatVector<int8_t>({1, 2, 3, std::nullopt, 4, std::nullopt, 5, 6, std::nullopt, 7}),
makeNullableFlatVector<int8_t>({1, -1, std::nullopt, std::nullopt, -2, 2, std::nullopt, std::nullopt, 3, -3}),
makeNullableFlatVector<int32_t>({1, 2, 3, 4, std::nullopt, 5, 6, 7, 8, std::nullopt}),
makeNullableFlatVector<int64_t>(
{std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt}),
makeNullableFlatVector<float>(
{-0.1234567,
std::nullopt,
0.1234567,
std::nullopt,
-0.142857,
std::nullopt,
0.142857,
0.285714,
0.428617,
std::nullopt}),
makeNullableFlatVector<bool>(
{std::nullopt, true, false, std::nullopt, true, true, false, true, std::nullopt, std::nullopt}),
makeFlatVector<facebook::velox::StringView>(
{"alice0", "bob1", "alice2", "bob3", "Alice4", "Bob5", "AlicE6", "boB7", "ALICE8", "BOB9"}),
makeNullableFlatVector<facebook::velox::StringView>(
{"alice", "bob", std::nullopt, std::nullopt, "Alice", "Bob", std::nullopt, "alicE", std::nullopt, "boB"}),
facebook::velox::BaseVector::create(facebook::velox::UNKNOWN(), 10, pool())};
children2_ = {
makeNullableFlatVector<int8_t>({std::nullopt, std::nullopt}),
makeFlatVector<int8_t>({1, -1}),
makeNullableFlatVector<int32_t>({100, std::nullopt}),
makeFlatVector<int64_t>({1, 1}),
makeFlatVector<float>({0.142857, -0.142857}),
makeFlatVector<bool>({true, false}),
makeFlatVector<facebook::velox::StringView>(
{"bob",
"alicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealicealice"}),
makeNullableFlatVector<facebook::velox::StringView>({std::nullopt, std::nullopt}),
facebook::velox::BaseVector::create(facebook::velox::UNKNOWN(), 2, pool())};
childrenNoNull_ = {
makeFlatVector<int8_t>({0, 1}),
makeFlatVector<int8_t>({0, -1}),
makeFlatVector<int32_t>({0, 100}),
makeFlatVector<int64_t>({0, 1}),
makeFlatVector<float>({0, 0.142857}),
makeFlatVector<bool>({false, true}),
makeFlatVector<facebook::velox::StringView>({"", "alice"}),
makeFlatVector<facebook::velox::StringView>({"alice", ""}),
};
largeString1_ = makeString(1024);
childrenLargeBinary1_ = {
makeFlatVector<int8_t>(std::vector<int8_t>(4096, 0)),
makeFlatVector<int8_t>(std::vector<int8_t>(4096, 0)),
makeFlatVector<int32_t>(std::vector<int32_t>(4096, 0)),
makeFlatVector<int64_t>(std::vector<int64_t>(4096, 0)),
makeFlatVector<float>(std::vector<float>(4096, 0)),
makeFlatVector<bool>(std::vector<bool>(4096, true)),
makeNullableFlatVector<facebook::velox::StringView>(
std::vector<std::optional<facebook::velox::StringView>>(4096, largeString1_.c_str())),
makeNullableFlatVector<facebook::velox::StringView>(
std::vector<std::optional<facebook::velox::StringView>>(4096, std::nullopt)),
};
largeString2_ = makeString(4096);
auto vectorToSpill = childrenLargeBinary2_ = {
makeFlatVector<int8_t>(std::vector<int8_t>(2048, 0)),
makeFlatVector<int8_t>(std::vector<int8_t>(2048, 0)),
makeFlatVector<int32_t>(std::vector<int32_t>(2048, 0)),
makeFlatVector<int64_t>(std::vector<int64_t>(2048, 0)),
makeFlatVector<float>(std::vector<float>(2048, 0)),
makeFlatVector<bool>(std::vector<bool>(2048, true)),
makeNullableFlatVector<facebook::velox::StringView>(
std::vector<std::optional<facebook::velox::StringView>>(2048, largeString2_.c_str())),
makeNullableFlatVector<facebook::velox::StringView>(
std::vector<std::optional<facebook::velox::StringView>>(2048, std::nullopt)),
};
childrenComplex_ = {
makeNullableFlatVector<int32_t>({std::nullopt, 1}),
makeRowVector({
makeFlatVector<int32_t>({1, 3}),
makeNullableFlatVector<facebook::velox::StringView>({std::nullopt, "de"}),
}),
makeNullableFlatVector<facebook::velox::StringView>({std::nullopt, "10 I'm not inline string"}),
makeArrayVector<int64_t>({
{1, 2, 3, 4, 5},
{1, 2, 3},
}),
makeMapVector<int32_t, facebook::velox::StringView>(
{{{1, "str1000"}, {2, "str2000"}}, {{3, "str3000"}, {4, "str4000"}}}),
};
inputVector1_ = makeRowVector(children1_);
inputVector2_ = makeRowVector(children2_);
inputVectorNoNull_ = makeRowVector(childrenNoNull_);
inputVectorLargeBinary1_ = makeRowVector(childrenLargeBinary1_);
inputVectorLargeBinary2_ = makeRowVector(childrenLargeBinary2_);
inputVectorComplex_ = makeRowVector(childrenComplex_);
}
arrow::Status splitRowVector(
VeloxShuffleWriter& shuffleWriter,
facebook::velox::RowVectorPtr vector,
int64_t memLimit = ShuffleWriter::kMinMemLimit) {
std::shared_ptr<ColumnarBatch> cb = std::make_shared<VeloxColumnarBatch>(vector);
return shuffleWriter.write(cb, memLimit);
}
arrow::Status setLocalDirsAndDataFile() {
static const std::string kTestLocalDirsPrefix = "columnar-shuffle-test-";
tmpDirs_.emplace_back();
ARROW_ASSIGN_OR_RAISE(tmpDirs_.back(), arrow::internal::TemporaryDir::Make(kTestLocalDirsPrefix))
ARROW_ASSIGN_OR_RAISE(dataFile_, createTempShuffleFile(tmpDirs_.back()->path().ToString()));
localDirs_.push_back(tmpDirs_.back()->path().ToString());
tmpDirs_.emplace_back();
ARROW_ASSIGN_OR_RAISE(tmpDirs_.back(), arrow::internal::TemporaryDir::Make(kTestLocalDirsPrefix))
localDirs_.push_back(tmpDirs_.back()->path().ToString());
return arrow::Status::OK();
}
virtual std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* arrowPool) = 0;
ShuffleWriterOptions shuffleWriterOptions_{};
PartitionWriterOptions partitionWriterOptions_{};
std::vector<std::unique_ptr<arrow::internal::TemporaryDir>> tmpDirs_;
std::string dataFile_;
std::vector<std::string> localDirs_;
std::vector<facebook::velox::VectorPtr> children1_;
std::vector<facebook::velox::VectorPtr> children2_;
std::vector<facebook::velox::VectorPtr> childrenNoNull_;
std::vector<facebook::velox::VectorPtr> childrenLargeBinary1_;
std::vector<facebook::velox::VectorPtr> childrenLargeBinary2_;
std::vector<facebook::velox::VectorPtr> childrenComplex_;
facebook::velox::RowVectorPtr inputVector1_;
facebook::velox::RowVectorPtr inputVector2_;
facebook::velox::RowVectorPtr inputVectorNoNull_;
std::string largeString1_;
std::string largeString2_;
facebook::velox::RowVectorPtr inputVectorLargeBinary1_;
facebook::velox::RowVectorPtr inputVectorLargeBinary2_;
facebook::velox::RowVectorPtr inputVectorComplex_;
};
class VeloxShuffleWriterTest : public ::testing::TestWithParam<ShuffleTestParams>, public VeloxShuffleWriterTestBase {
public:
arrow::Status initShuffleWriterOptions() override {
RETURN_NOT_OK(VeloxShuffleWriterTestBase::initShuffleWriterOptions());
ShuffleTestParams params = GetParam();
shuffleWriterOptions_.useRadixSort = params.useRadixSort;
shuffleWriterOptions_.sortEvictBufferSize = params.compressionBufferSize;
partitionWriterOptions_.compressionType = params.compressionType;
switch (partitionWriterOptions_.compressionType) {
case arrow::Compression::UNCOMPRESSED:
partitionWriterOptions_.compressionTypeStr = "none";
break;
case arrow::Compression::LZ4_FRAME:
partitionWriterOptions_.compressionTypeStr = "lz4";
break;
case arrow::Compression::ZSTD:
partitionWriterOptions_.compressionTypeStr = "zstd";
break;
default:
break;
};
partitionWriterOptions_.compressionThreshold = params.compressionThreshold;
partitionWriterOptions_.mergeBufferSize = params.mergeBufferSize;
return arrow::Status::OK();
}
std::shared_ptr<VeloxShuffleWriter> createSpecificShuffleWriter(
arrow::MemoryPool* arrowPool,
std::unique_ptr<PartitionWriter> partitionWriter,
ShuffleWriterOptions shuffleWriterOptions,
uint32_t numPartitions,
int32_t bufferSize) {
if (shuffleWriterOptions.shuffleWriterType == ShuffleWriterType::kHashShuffle) {
shuffleWriterOptions.bufferSize = bufferSize;
}
GLUTEN_ASSIGN_OR_THROW(
auto shuffleWriter,
VeloxShuffleWriter::create(
GetParam().shuffleWriterType,
numPartitions,
std::move(partitionWriter),
std::move(shuffleWriterOptions),
pool_,
arrowPool));
return shuffleWriter;
}
protected:
static void SetUpTestCase() {
facebook::velox::memory::MemoryManager::testingSetInstance({});
}
virtual void SetUp() override {
std::cout << "Running test with param: " << GetParam().toString() << std::endl;
VeloxShuffleWriterTestBase::setUp();
}
void TearDown() override {
if (file_ != nullptr && !file_->closed()) {
GLUTEN_THROW_NOT_OK(file_->Close());
}
}
static void checkFileExists(const std::string& fileName) {
ASSERT_EQ(*arrow::internal::FileExists(*arrow::internal::PlatformFilename::FromString(fileName)), true);
}
std::shared_ptr<arrow::Schema> getArrowSchema(facebook::velox::RowVectorPtr& rowVector) {
return toArrowSchema(rowVector->type(), pool());
}
void setReadableFile(const std::string& fileName) {
if (file_ != nullptr && !file_->closed()) {
GLUTEN_THROW_NOT_OK(file_->Close());
}
GLUTEN_ASSIGN_OR_THROW(file_, arrow::io::ReadableFile::Open(fileName))
}
void getRowVectors(
arrow::Compression::type compressionType,
std::shared_ptr<arrow::Schema> schema,
std::vector<facebook::velox::RowVectorPtr>& vectors,
std::shared_ptr<arrow::io::InputStream> in) {
ShuffleReaderOptions options;
options.compressionType = compressionType;
auto codec = createArrowIpcCodec(options.compressionType, CodecBackend::NONE);
auto rowType = facebook::velox::asRowType(gluten::fromArrowSchema(schema));
switch (options.compressionType) {
case arrow::Compression::type::UNCOMPRESSED:
options.compressionTypeStr = "none";
break;
case arrow::Compression::type::LZ4_FRAME:
options.compressionTypeStr = "lz4";
break;
case arrow::Compression::type::ZSTD:
options.compressionTypeStr = "zstd";
break;
default:
break;
};
auto veloxCompressionType = facebook::velox::common::stringToCompressionKind(options.compressionTypeStr);
if (!facebook::velox::isRegisteredVectorSerde()) {
facebook::velox::serializer::presto::PrestoVectorSerde::registerVectorSerde();
}
auto deserializerFactory = std::make_unique<gluten::VeloxColumnarBatchDeserializerFactory>(
schema,
std::move(codec),
veloxCompressionType,
rowType,
std::numeric_limits<int32_t>::max(),
kDefaultReadBufferSize,
defaultArrowMemoryPool().get(),
pool_,
GetParam().shuffleWriterType);
auto reader = std::make_shared<VeloxShuffleReader>(std::move(deserializerFactory));
auto iter = reader->readStream(in);
while (iter->hasNext()) {
auto vector = std::dynamic_pointer_cast<VeloxColumnarBatch>(iter->next())->getRowVector();
vectors.emplace_back(vector);
}
}
std::shared_ptr<arrow::io::ReadableFile> file_;
};
class SinglePartitioningShuffleWriter : public VeloxShuffleWriterTest {
protected:
void testShuffleWrite(VeloxShuffleWriter& shuffleWriter, std::vector<facebook::velox::RowVectorPtr> vectors) {
for (auto& vector : vectors) {
ASSERT_NOT_OK(splitRowVector(shuffleWriter, vector));
}
ASSERT_NOT_OK(shuffleWriter.stop());
checkFileExists(dataFile_);
const auto& lengths = shuffleWriter.partitionLengths();
ASSERT_EQ(lengths.size(), 1);
auto schema = getArrowSchema(vectors[0]);
std::vector<facebook::velox::RowVectorPtr> deserializedVectors;
setReadableFile(dataFile_);
GLUTEN_ASSIGN_OR_THROW(auto in, arrow::io::RandomAccessFile::GetStream(file_, 0, lengths[0]));
getRowVectors(partitionWriterOptions_.compressionType, schema, deserializedVectors, in);
ASSERT_EQ(deserializedVectors.size(), vectors.size());
for (int32_t i = 0; i < deserializedVectors.size(); i++) {
facebook::velox::test::assertEqualVectors(vectors[i], deserializedVectors[i]);
}
}
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* arrowPool) override {
shuffleWriterOptions_.partitioning = Partitioning::kSingle;
static const uint32_t kNumPartitions = 1;
auto partitionWriter = createPartitionWriter(
GetParam().partitionWriterType, kNumPartitions, dataFile_, localDirs_, partitionWriterOptions_, arrowPool);
std::shared_ptr<VeloxShuffleWriter> shuffleWriter = createSpecificShuffleWriter(
arrowPool, std::move(partitionWriter), std::move(shuffleWriterOptions_), kNumPartitions, 10);
return shuffleWriter;
}
};
class MultiplePartitioningShuffleWriter : public VeloxShuffleWriterTest {
protected:
void shuffleWriteReadMultiBlocks(
VeloxShuffleWriter& shuffleWriter,
int32_t expectPartitionLength,
facebook::velox::TypePtr dataType,
std::vector<std::vector<facebook::velox::RowVectorPtr>> expectedVectors) {
ASSERT_NOT_OK(shuffleWriter.stop());
checkFileExists(dataFile_);
const auto& lengths = shuffleWriter.partitionLengths();
ASSERT_EQ(lengths.size(), expectPartitionLength);
int64_t lengthSum = std::accumulate(lengths.begin(), lengths.end(), 0);
auto schema = toArrowSchema(dataType, pool());
setReadableFile(dataFile_);
ASSERT_EQ(*file_->GetSize(), lengthSum);
for (int32_t i = 0; i < expectPartitionLength; i++) {
if (expectedVectors[i].size() == 0) {
ASSERT_EQ(lengths[i], 0);
} else {
std::vector<facebook::velox::RowVectorPtr> deserializedVectors;
GLUTEN_ASSIGN_OR_THROW(
auto in, arrow::io::RandomAccessFile::GetStream(file_, i == 0 ? 0 : lengths[i - 1], lengths[i]));
getRowVectors(partitionWriterOptions_.compressionType, schema, deserializedVectors, in);
ASSERT_EQ(expectedVectors[i].size(), deserializedVectors.size());
for (int32_t j = 0; j < expectedVectors[i].size(); j++) {
facebook::velox::test::assertEqualVectors(expectedVectors[i][j], deserializedVectors[j]);
}
}
}
}
void testShuffleWriteMultiBlocks(
VeloxShuffleWriter& shuffleWriter,
std::vector<facebook::velox::RowVectorPtr> vectors,
int32_t expectPartitionLength,
facebook::velox::TypePtr dataType,
std::vector<std::vector<facebook::velox::RowVectorPtr>> expectedVectors) {
for (auto& vector : vectors) {
ASSERT_NOT_OK(splitRowVector(shuffleWriter, vector));
}
shuffleWriteReadMultiBlocks(shuffleWriter, expectPartitionLength, dataType, expectedVectors);
}
};
class HashPartitioningShuffleWriter : public MultiplePartitioningShuffleWriter {
protected:
void SetUp() override {
MultiplePartitioningShuffleWriter::SetUp();
children1_.insert((children1_.begin()), makeFlatVector<int32_t>({1, 2, 2, 2, 2, 1, 1, 1, 2, 1}));
hashInputVector1_ = makeRowVector(children1_);
children2_.insert((children2_.begin()), makeFlatVector<int32_t>({2, 2}));
hashInputVector2_ = makeRowVector(children2_);
}
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* arrowPool) override {
shuffleWriterOptions_.partitioning = Partitioning::kHash;
static const uint32_t kNumPartitions = 2;
auto partitionWriter = createPartitionWriter(
GetParam().partitionWriterType, kNumPartitions, dataFile_, localDirs_, partitionWriterOptions_, arrowPool);
std::shared_ptr<VeloxShuffleWriter> shuffleWriter = createSpecificShuffleWriter(
arrowPool, std::move(partitionWriter), std::move(shuffleWriterOptions_), kNumPartitions, 4);
return shuffleWriter;
}
std::vector<uint32_t> hashPartitionIds_{1, 2};
facebook::velox::RowVectorPtr hashInputVector1_;
facebook::velox::RowVectorPtr hashInputVector2_;
};
class RangePartitioningShuffleWriter : public MultiplePartitioningShuffleWriter {
protected:
void SetUp() override {
MultiplePartitioningShuffleWriter::SetUp();
auto pid1 = makeRowVector({makeFlatVector<int32_t>({0, 1, 0, 1, 0, 1, 0, 1, 0, 1})});
auto rangeVector1 = makeRowVector(inputVector1_->children());
compositeBatch1_ = VeloxColumnarBatch::compose(
pool(), {std::make_shared<VeloxColumnarBatch>(pid1), std::make_shared<VeloxColumnarBatch>(rangeVector1)});
auto pid2 = makeRowVector({makeFlatVector<int32_t>({0, 1})});
auto rangeVector2 = makeRowVector(inputVector2_->children());
compositeBatch2_ = VeloxColumnarBatch::compose(
pool(), {std::make_shared<VeloxColumnarBatch>(pid2), std::make_shared<VeloxColumnarBatch>(rangeVector2)});
}
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* arrowPool) override {
shuffleWriterOptions_.partitioning = Partitioning::kRange;
static const uint32_t kNumPartitions = 2;
auto partitionWriter = createPartitionWriter(
GetParam().partitionWriterType, kNumPartitions, dataFile_, localDirs_, partitionWriterOptions_, arrowPool);
std::shared_ptr<VeloxShuffleWriter> shuffleWriter = createSpecificShuffleWriter(
arrowPool, std::move(partitionWriter), std::move(shuffleWriterOptions_), kNumPartitions, 4);
return shuffleWriter;
}
void testShuffleWriteMultiBlocks(
VeloxShuffleWriter& shuffleWriter,
std::vector<std::shared_ptr<ColumnarBatch>> batches,
int32_t expectPartitionLength,
facebook::velox::TypePtr dataType,
std::vector<std::vector<facebook::velox::RowVectorPtr>> expectedVectors) {
for (auto& batch : batches) {
ASSERT_NOT_OK(shuffleWriter.write(batch, ShuffleWriter::kMinMemLimit));
}
shuffleWriteReadMultiBlocks(shuffleWriter, expectPartitionLength, dataType, expectedVectors);
}
std::shared_ptr<ColumnarBatch> compositeBatch1_;
std::shared_ptr<ColumnarBatch> compositeBatch2_;
};
class RoundRobinPartitioningShuffleWriter : public MultiplePartitioningShuffleWriter {
protected:
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* arrowPool) override {
static const uint32_t kNumPartitions = 2;
auto partitionWriter = createPartitionWriter(
GetParam().partitionWriterType, kNumPartitions, dataFile_, localDirs_, partitionWriterOptions_, arrowPool);
std::shared_ptr<VeloxShuffleWriter> shuffleWriter = createSpecificShuffleWriter(
arrowPool, std::move(partitionWriter), std::move(shuffleWriterOptions_), kNumPartitions, 4);
return shuffleWriter;
}
};
class VeloxHashShuffleWriterMemoryTest : public VeloxShuffleWriterTestBase, public testing::Test {
protected:
static void SetUpTestCase() {
facebook::velox::memory::MemoryManager::testingSetInstance({});
}
void SetUp() override {
VeloxShuffleWriterTestBase::setUp();
}
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(uint32_t numPartitions, arrow::MemoryPool* arrowPool) {
auto partitionWriter = createPartitionWriter(
PartitionWriterType::kLocal, numPartitions, dataFile_, localDirs_, partitionWriterOptions_, arrowPool);
GLUTEN_ASSIGN_OR_THROW(
auto shuffleWriter,
VeloxHashShuffleWriter::create(
numPartitions, std::move(partitionWriter), std::move(shuffleWriterOptions_), pool_, arrowPool));
return shuffleWriter;
}
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* arrowPool) override {
return createShuffleWriter(kDefaultShufflePartitions, arrowPool);
}
int64_t splitRowVectorAndSpill(
VeloxShuffleWriter& shuffleWriter,
std::vector<facebook::velox::RowVectorPtr> vectors,
bool shrink) {
for (auto vector : vectors) {
ASSERT_NOT_OK(splitRowVector(shuffleWriter, vector));
}
auto targetEvicted = shuffleWriter.cachedPayloadSize();
if (shrink) {
targetEvicted += shuffleWriter.partitionBufferSize();
}
int64_t evicted;
ASSERT_NOT_OK(shuffleWriter.reclaimFixedSize(targetEvicted, &evicted));
return evicted;
};
static constexpr uint32_t kDefaultShufflePartitions = 2;
};
}