* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef FLINK_TNEL_CSVTABLESOURCE_H
#define FLINK_TNEL_CSVTABLESOURCE_H
#include <cassert>
#include <cstdint>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include <emhash7.hpp>
#include <nlohmann/json.hpp>
#include "core/typeinfo/TypeInformation.h"
#include "table/data/vectorbatch/VectorBatch.h"
#include "OmniOperatorJIT/core/src/type/data_type.h"
#include "functions/Collector.h"
/**
 * Describes a CSV-backed table: the path of the CSV file plus the type string
 * of every table field.
 */
class CsvTableSource {
public:
    /**
     * @param filepath      path of the CSV file.
     * @param fieldTypeStrs type string of each table field (e.g. "BIGINT").
     *
     * Both arguments are taken by value and moved into the members, so an
     * rvalue argument is never copied.
     */
    CsvTableSource(std::string filepath, std::vector<std::string> fieldTypeStrs)
        : fieldTypeStrs(std::move(fieldTypeStrs)), filepath(std::move(filepath)) {}

    // Defined outside this header; presumably counts the rows of the file at
    // `filepath` - confirm against the implementation.
    size_t countCsvRows();

    /** @return a copy of the CSV file path. */
    std::string getFilePath() const
    {
        return filepath;
    }

    /** @return a copy of the table field type strings. */
    std::vector<std::string> getTableFieldTypes() const
    {
        return fieldTypeStrs;
    }

private:
    std::vector<std::string> fieldTypeStrs;
    std::string filepath;
};
/**
 * Parses one CSV token and stores it at `rowIndex` of `vec`.
 *
 * @tparam T        element type selecting the conversion: int64_t, int32_t or
 *                  std::string_view. Any other T compiles to a no-op.
 * @param inStr     token text read from the CSV line.
 * @param vec       destination vector; it is downcast (unchecked) to the
 *                  concrete vector type matching T.
 * @param rowIndex  row to write.
 *
 * The numeric branches propagate std::invalid_argument / std::out_of_range
 * from std::stoll / std::stoi when the token is not a valid number.
 */
template<typename T>
inline void CsvStrConverterFunc(const std::string &inStr, omniruntime::vec::BaseVector *vec, int rowIndex)
{
    if constexpr (std::is_same_v<int64_t, T>) {
        // std::stoll rather than std::stol: `long` is only 32 bits wide on LLP64 targets.
        static_cast<omniruntime::vec::Vector<int64_t> *>(vec)->SetValue(
            rowIndex, static_cast<int64_t>(std::stoll(inStr)));
    } else if constexpr (std::is_same_v<int32_t, T>) {
        static_cast<omniruntime::vec::Vector<int32_t> *>(vec)->SetValue(rowIndex, std::stoi(inStr));
    } else if constexpr (std::is_same_v<std::string_view, T>) {
        using VarcharVecType = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
        // Kept as a named, non-const lvalue: that is how the value was handed to SetValue before.
        std::string_view inStrView(inStr.data(), inStr.size());
        static_cast<VarcharVecType *>(vec)->SetValue(rowIndex, inStrView);
    }
}
template<typename K>
class CsvLookupFunction {
public:
~CsvLookupFunction()
{
delete csvDataVecBatch;
}
CsvLookupFunction(const nlohmann::json& description, CsvTableSource *src) : description(description), src (src)
{
int hashRowCnt = src->countCsvRows();
csvDataVecBatch = new omniruntime::vec::VectorBatch(hashRowCnt);
auto lookupTypeStrs = description["lookupInputTypes"].get<std::vector<std::string>>();
for (size_t i = 0; i < lookupTypeStrs.size(); i++) {
switch (LogicalType::flinkTypeToOmniTypeId(lookupTypeStrs[i])) {
case omniruntime::type::OMNI_LONG: {
csvStrConverters.push_back(CsvStrConverterFunc<int64_t>);
auto vec = new omniruntime::vec::Vector<int64_t>(hashRowCnt);
csvDataVecBatch->Append(vec);
break;
}
case omniruntime::type::OMNI_CHAR:
case omniruntime::type::OMNI_VARCHAR: {
csvStrConverters.push_back(CsvStrConverterFunc<std::string_view>);
using VarcharVecType = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
auto vec = new VarcharVecType(hashRowCnt);
csvDataVecBatch->Append(vec);
break;
}
case omniruntime::type::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
case omniruntime::type::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
case omniruntime::type::OMNI_TIMESTAMP: {
csvStrConverters.push_back(CsvStrConverterFunc<int64_t>);
auto vec = new omniruntime::vec::Vector<int64_t>(hashRowCnt);
csvDataVecBatch->Append(vec);
break;
}
default:
std::runtime_error("not supported type");
}
}
}
void open()
{
auto lookupKeys = description["lookupKeys"];
auto selectedFields = description["selectedFields"].get<std::vector<int>>();
for (const auto &[key, value]: lookupKeys.items()) {
sourceKeys.push_back(value["index"]);
int targetIdx = std::stoi(key);
if (targetIdx == -1) {
throw std::runtime_error("CsvLookupFunction open: targetIdx is -1");
}
targetKeys.push_back(targetIdx);
}
if (sourceKeys.size() != 1 || targetKeys.size() != 1) {
NOT_IMPL_EXCEPTION;
}
int32_t hashKeyIndex = targetKeys[0];
if (src->getTableFieldTypes()[hashKeyIndex] != "BIGINT"
&& src->getTableFieldTypes()[hashKeyIndex] != "BIGINT NOT NULL") {
NOT_IMPL_EXCEPTION
}
std::ifstream file(src->getFilePath());
int32_t irow = 0;
std::string line;
std::string keyStr;
auto hashVectors = csvDataVecBatch->GetVectors();
while (std::getline(file, line)) {
std::stringstream ss(line);
std::string token;
int icol = 0;
size_t colIndex = 0;
while (std::getline(ss, token, ',') && colIndex < selectedFields.size()) {
if (icol == selectedFields[colIndex]) {
csvStrConverters[colIndex](token, hashVectors[colIndex], irow);
colIndex++;
}
if (icol == hashKeyIndex) {
keyStr = token;
}
icol++;
}
K key = std::stol(keyStr);
dataMap[key].push_back(irow);
irow++;
}
file.close();
}
std::vector<int32_t> &getTestFunc(K key)
{
return dataMap[key];
}
void eval(omnistream::VectorBatch *vb, Collector *collector)
{
int32_t totRowCnt = 0;
auto probekeyCol = reinterpret_cast<omniruntime::vec::Vector<K>* > (vb->GetVectors()[sourceKeys[0]]);
int32_t rowCount = vb->GetRowCount();
std::vector<std::tuple<int32_t*, int32_t>> matchedRows(rowCount, {nullptr, 0});
for (int i = 0; i < rowCount; i++) {
auto key = probekeyCol->GetValue(i);
auto it = dataMap.find(key);
if (it != dataMap.end()) {
matchedRows[i] = {it->second.data(), it->second.size()};
totRowCnt += it->second.size();
}
}
auto output = buildOutput(vb, matchedRows, totRowCnt);
collector->collect(output);
delete vb;
}
private:
omnistream::VectorBatch *buildOutput(omnistream::VectorBatch *vb,
std::vector<std::tuple<int32_t*, int32_t>> &matchedRows,
int32_t totRowCnt)
{
auto output = new omnistream::VectorBatch(totRowCnt);
omniruntime::vec::BaseVector *outCol = nullptr;
BuildProbeOutput(vb, matchedRows, totRowCnt, output, outCol);
BuildHashOutput(vb, matchedRows, totRowCnt, output, outCol);
int32_t outputRowIndex = 0;
int32_t rowIndex = 0;
for (size_t probeRow = 0; probeRow < matchedRows.size(); probeRow++) {
int32_t targetRowsCnt = std::get<1>(matchedRows[probeRow]);
if (targetRowsCnt == 0) {
continue;
}
int64_t timestamp = vb->getTimestamp(probeRow);
RowKind rowKind = vb->getRowKind(probeRow);
for (int i = 0; i < targetRowsCnt; i++) {
output->setRowKind(outputRowIndex, rowKind);
output->setTimestamp(outputRowIndex++, timestamp);
}
rowIndex += targetRowsCnt;
}
return output;
};
void BuildProbeOutput(omnistream::VectorBatch *vb, std::vector<std::tuple<int32_t*, int32_t>> &matchedRows,
int32_t totRowCnt, omnistream::VectorBatch* output, omniruntime::vec::BaseVector * outCol)
{
for (int icol = 0; icol < vb->GetVectorCount(); icol++) {
auto typeId = vb->Get(icol)->GetTypeId();
switch (typeId) {
case omniruntime::type::OMNI_LONG:
outCol = buildProbeOutputColumn<int64_t>(vb, matchedRows, totRowCnt, icol);
break;
case omniruntime::type::OMNI_TIMESTAMP:
case omniruntime::type::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
case omniruntime::type::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
outCol = buildProbeOutputColumn<int64_t>(vb, matchedRows, totRowCnt, icol);
break;
case omniruntime::type::OMNI_VARCHAR:
case omniruntime::type::OMNI_CHAR:
outCol = buildProbeOutputColumn
<omniruntime::vec::LargeStringContainer<std::string_view>>(vb, matchedRows, totRowCnt, icol);
break;
default:
std::runtime_error("Type not supported!");
}
output->Append(outCol);
}
}
void BuildHashOutput(omnistream::VectorBatch *vb, std::vector<std::tuple<int32_t*, int32_t>> &matchedRows,
int32_t totRowCnt, omnistream::VectorBatch* output, omniruntime::vec::BaseVector * outCol)
{
for (size_t icol = 0; icol < src->getTableFieldTypes().size(); icol++) {
auto typeId = LogicalType::flinkTypeToOmniTypeId(src->getTableFieldTypes()[icol]);
switch (typeId) {
case omniruntime::type::OMNI_LONG:
outCol = buildHashOutputColumn<int64_t>(matchedRows, totRowCnt, icol);
break;
case omniruntime::type::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
case omniruntime::type::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
case omniruntime::type::OMNI_TIMESTAMP:
outCol = buildHashOutputColumn<int64_t>(matchedRows, totRowCnt, icol);
break;
case omniruntime::type::OMNI_VARCHAR:
case omniruntime::type::OMNI_CHAR:
outCol = buildHashOutputColumn<omniruntime::vec::LargeStringContainer<std::string_view>>(
matchedRows, totRowCnt, icol);
break;
default:
std::runtime_error("Type not supported!");
}
output->Append(outCol);
}
}
template<typename T>
omniruntime::vec::Vector<T> *buildHashOutputColumn(std::vector<std::tuple<int32_t *, int32_t> > &matchedRows,
int32_t totRowCnt, int colIndex)
{
auto outVec = new omniruntime::vec::Vector<T>(totRowCnt);
auto inputCol = static_cast<omniruntime::vec::Vector<T> *> (csvDataVecBatch->Get(colIndex));
int32_t outRowIndex = 0;
for (size_t probeRow = 0; probeRow < matchedRows.size(); probeRow++) {
int32_t targetRowsCnt = std::get<1>(matchedRows[probeRow]);
if (targetRowsCnt == 0) {
continue;
}
int32_t *targetRowsIndices = std::get<0>(matchedRows[probeRow]);
for (int32_t i = 0; i < targetRowsCnt; i++) {
auto value = inputCol->GetValue(targetRowsIndices[i]);
outVec->SetValue(outRowIndex++, value);
}
}
return outVec;
}
template<typename T>
omniruntime::vec::Vector<T> *buildProbeOutputColumn(omnistream::VectorBatch *vb,
std::vector<std::tuple<int32_t*, int32_t>> &matchedRows,
int32_t totRowCnt, int colIndex)
{
auto outVec = new omniruntime::vec::Vector<T>(totRowCnt);
auto inputCol = static_cast<omniruntime::vec::Vector<T> *> (vb->Get(colIndex));
int32_t rowIndex = 0;
for (size_t probeRow = 0; probeRow < matchedRows.size(); probeRow++) {
int32_t targetRowsCnt = std::get<1>(matchedRows[probeRow]);
if (targetRowsCnt == 0) {
continue;
}
auto val = inputCol->GetValue(probeRow);
for (int i = 0; i < targetRowsCnt; i++) {
outVec->SetValue(rowIndex + i, val);
}
rowIndex += targetRowsCnt;
}
return outVec;
}
nlohmann::json description;
CsvTableSource* src;
std::vector<int32_t> sourceKeys;
std::vector<int32_t> targetKeys;
emhash7::HashMap<K, std::vector<int32_t>> dataMap;
omniruntime::vec::VectorBatch *csvDataVecBatch;
using GetFromStrAndSetToVB = void (*)(const std::string &, omniruntime::vec::BaseVector *, int);
std::vector<GetFromStrAndSetToVB> csvStrConverters;
};
#endif