* @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
* @Description: spiller implementation
*/
#ifndef OMNI_RUNTIME_SPILLER_H
#define OMNI_RUNTIME_SPILLER_H
#include "type/data_types.h"
#include "spill_merger.h"
#include "operator/pages_index.h"
#include "operator/aggregation/group_aggregation_sort.h"
namespace omniruntime {
namespace op {
/**
 * Writer for one spill file.
 *
 * Holds the column types of the batches being spilled, the directory the file is
 * created in, and an optional heap-allocated write buffer. The file-level
 * bookkeeping (path, length in bytes, total rows) is exposed through
 * GetSpillFileInfo().
 *
 * The object owns `writeBuffer`, so it is deliberately non-copyable: an implicit
 * member-wise copy would leave two objects freeing the same buffer.
 */
class SpillWriter {
public:
    /**
     * @param dataTypes       column types of the vector batches to be written
     * @param dirPath         directory path stored for later temp-file creation
     * @param writeBufferSize size in bytes of the write buffer; 0 (the default)
     *                        means no buffer is allocated
     */
    SpillWriter(const type::DataTypes &dataTypes, const std::string &dirPath, uint64_t writeBufferSize = 0)
        : dataTypes(dataTypes), dirPath(dirPath), writeBufferSize(writeBufferSize), writeBufferOffset(0)
    {
        // Only allocate when buffering was requested; writeBuffer stays nullptr otherwise.
        if (writeBufferSize != 0) {
            writeBuffer = new char[writeBufferSize];
        }
    }

    // Copying would duplicate the raw owning pointer (and the fd) and lead to a
    // double delete[] in the destructor, so copy operations are disabled.
    SpillWriter(const SpillWriter &) = delete;
    SpillWriter &operator=(const SpillWriter &) = delete;

    /**
     * Releases the write buffer (delete[] on nullptr is a no-op).
     * NOTE(review): the destructor does not touch `fd`; presumably Close() is
     * responsible for the file descriptor - confirm in the implementation file.
     */
    ~SpillWriter()
    {
        delete[] writeBuffer;
        writeBuffer = nullptr;
    }

    ErrorCode WriteVecBatch(vec::VectorBatch *vectorBatch, uint64_t vectorBatchSize);

    /** @return a snapshot of the current file path, file length and total row count. */
    SpillFileInfo GetSpillFileInfo()
    {
        SpillFileInfo file{ filePath, fileLength, totalRowCount };
        return file;
    }

    ErrorCode Close();

private:
    ErrorCode CreateTempFile();

    ErrorCode WriteVecBatchToBuffer(vec::VectorBatch *vectorBatch);

    template <typename T>
    ErrorCode WriteVectorToBuffer(vec::BaseVector *vector, int32_t rowCount, int32_t &writeOffset);

    ErrorCode WriteVecBatchToFile(vec::VectorBatch *vectorBatch);

    template <typename T> ErrorCode WriteVector(omniruntime::vec::BaseVector *vector, int32_t rowCount);

    ErrorCode Write(void *buf, size_t length);

    type::DataTypes dataTypes;
    std::string dirPath;
    int32_t fd = -1;                // -1 until a file has been opened
    std::string filePath;
    uint64_t writeBufferSize = 0;   // capacity of writeBuffer in bytes (0 = unbuffered)
    uint64_t writeBufferOffset = 0;
    char *writeBuffer = nullptr;    // owned; allocated in the constructor, freed in the destructor
    uint64_t fileLength = 0;
    uint64_t totalRowCount = 0;
};
class Spiller {
public:
Spiller(const type::DataTypes &dataTypes, const std::vector<int32_t> &sortCols,
const std::vector<SortOrder> &sortOrders, const std::string &spillPath, uint64_t maxSpillBytes,
uint64_t writeBufferSize = 0)
: dataTypes(dataTypes), sortCols(sortCols), sortOrders(sortOrders), writeBufferSize(writeBufferSize)
{
dirPaths.emplace_back(spillPath);
int32_t dataTypeCount = dataTypes.GetSize();
for (int32_t i = 0; i < dataTypeCount; i++) {
outputCols.emplace_back(i);
}
maxRowCountPerBatch = OperatorUtil::GetMaxRowCount(dataTypes.Get(), outputCols.data(), dataTypeCount);
spillTracker = GetRootSpillTracker().CreateSpillTracker(maxSpillBytes);
}
~Spiller()
{
for (auto writer : writers) {
delete writer;
}
writers.clear();
}
ErrorCode Spill(PagesIndex *pagesIndex, bool canInplaceSort, bool canRadixSort);
ErrorCode Spill(AggregationSort *aggregationSort);
std::vector<SpillFileInfo> FinishSpill()
{
std::vector<SpillFileInfo> spillFiles;
for (auto writer : writers) {
spillFiles.emplace_back(writer->GetSpillFileInfo());
}
return spillFiles;
}
SpillMerger *CreateSpillMerger(const std::vector<SpillFileInfo> &spillFiles)
{
auto merger = SpillMerger::Create(dataTypes, sortCols, sortOrders, spillTracker, spillFiles);
return merger;
}
uint64_t GetSpilledBytes()
{
return spillTracker->GetSpilledBytes();
}
SpillTracker *GetSpillTracker() const
{
return spillTracker;
}
void RemoveSpillFiles()
{
for (auto writer : writers) {
auto spillFile = writer->GetSpillFileInfo();
remove(spillFile.filePath.c_str());
}
}
private:
uint64_t CollectVecBatchSize(vec::VectorBatch *vectorBatch);
template <typename T> uint64_t CollectVectorSize(vec::BaseVector *vector);
DataTypes dataTypes;
std::vector<int32_t> sortCols;
std::vector<SortOrder> sortOrders;
std::vector<std::string> dirPaths;
std::vector<int32_t> outputCols;
int32_t maxRowCountPerBatch;
std::unique_ptr<vec::VectorBatch> spillVecBatch = nullptr;
std::vector<SpillWriter *> writers;
SpillTracker *spillTracker = nullptr;
uint64_t writeBufferSize = 0;
};
}
}
#endif