* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "VectorBatch.h"
#include <fstream>
#include "data/binary/BinaryRowData.h"
#include "table/data/rowdata_marshaller.h"
#include "OmniOperatorJIT/core/src/codegen/time_util.h"
namespace omnistream {
/**
 * Creates a batch sized for rowCnt rows.
 *
 * For a non-zero rowCnt, one timestamp slot and one row-kind slot per row are allocated and
 * zero-filled. For rowCnt == 0 both arrays stay nullptr. maxTimestamp starts at INT64_MIN.
 *
 * @param rowCnt number of rows, forwarded to the omniruntime base batch.
 */
VectorBatch::VectorBatch(size_t rowCnt)
    : omniruntime::vec::VectorBatch(rowCnt),
      timestamps(nullptr),
      rowKinds(nullptr),
      maxTimestamp(INT64_MIN)
{
    // Nothing to allocate for an empty batch.
    if (rowCnt == 0) {
        return;
    }

    const size_t timestampBytes = sizeof(int64_t) * rowCnt;
    timestamps = new int64_t[rowCnt];
    memset_s(timestamps, timestampBytes, 0, timestampBytes);

    const size_t rowKindBytes = sizeof(RowKind) * rowCnt;
    rowKinds = new RowKind[rowCnt];
    memset_s(rowKinds, rowKindBytes, 0, rowKindBytes);
}
/** Releases the per-row row-kind and timestamp arrays held by this batch. */
VectorBatch::~VectorBatch()
{
    // delete[] on nullptr is a no-op, so a batch built with zero rows needs no special case.
    delete[] rowKinds;
    delete[] timestamps;
}
/**
 * Wraps the columns of an existing omniruntime batch together with caller-supplied row metadata.
 *
 * The column pointers of baseVecBatch are appended to this batch's vector list (the pointers are
 * copied, not the columns). The timestamps and rowkinds pointers are stored as given and are later
 * passed to delete[] by the destructor.
 * NOTE(review): baseVecBatch itself is not released here — confirm who owns it after this call.
 *
 * @param baseVecBatch source of the row count and the column pointers.
 * @param timestamps   per-row timestamp array adopted by this batch.
 * @param rowkinds     per-row RowKind array adopted by this batch.
 */
VectorBatch::VectorBatch(omniruntime::vec::VectorBatch* baseVecBatch, int64_t* timestamps, RowKind* rowkinds)
    : omniruntime::vec::VectorBatch(baseVecBatch->GetRowCount()),
      timestamps(timestamps),
      rowKinds(rowkinds),
      maxTimestamp(INT64_MIN)
{
    auto firstColumn = baseVecBatch->GetVectors();
    auto pastLastColumn = firstColumn + baseVecBatch->GetVectorCount();
    this->vectors.insert(this->vectors.end(), firstColumn, pastLastColumn);
}
/**
 * Raises maxTimestamp to the largest value found in the given column and returns it.
 *
 * The column is read as a vector of int64_t. maxTimestamp is only ever increased, so the result
 * also reflects any value it held before the call. Null flags are not consulted.
 *
 * @param colIdx index of the column to scan.
 * @return the updated maxTimestamp.
 */
int64_t VectorBatch::setMaxTimestamp(int colIdx)
{
    auto* tsColumn = reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(this->Get(colIdx));
    const auto totalRows = this->GetRowCount();
    for (int row = 0; row < totalRows; ++row) {
        const int64_t candidate = tsColumn->GetValue(row);
        if (maxTimestamp < candidate) {
            maxTimestamp = candidate;
        }
    }
    return maxTimestamp;
}
void VectorBatch::RearrangeColumns(std::vector<int32_t>& inputIndices)
{
LOG("=====>");
std::vector<bool> toKeep(this->vectors.size(), false);
std::vector<omniruntime::vec::BaseVector*> newVectors(inputIndices.size());
for (size_t i = 0; i < inputIndices.size(); i++) {
newVectors[i] = this->vectors[inputIndices[i]];
toKeep[inputIndices[i]] = true;
}
for (size_t i = 0; i < toKeep.size(); i++) {
if (!toKeep[i]) {
delete vectors[i];
}
}
this->vectors = newVectors;
}
/**
 * Materialises one row of the batch as a newly created BinaryRowData.
 *
 * Each non-null cell is written through the serializer registered for the column's type in
 * rowSerializerCenter. Timestamp columns without their own serializer fall back to the OMNI_LONG
 * serializer. The row kind of the source row is copied onto the result.
 *
 * @param rowIndex zero-based row to extract.
 * @return the new row, or nullptr when rowIndex is negative or not below the row count.
 * @throws via THROW_RUNTIME_ERROR when a non-null cell has a type with no serializer.
 */
RowData* VectorBatch::extractRowData(int rowIndex)
{
    // Reject indices on both sides of the valid range; a negative index used to slip through
    // and was then used to index the columns.
    if (rowIndex < 0 || rowIndex >= this->GetRowCount()) {
        return nullptr;
    }

    const int numColumns = this->GetVectorCount();
    BinaryRowData* outRow = BinaryRowData::createBinaryRowDataWithMem(numColumns);
    for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
        auto col = this->Get(colIndex);
        if (col->IsNull(rowIndex)) {
            outRow->setNullAt(colIndex);
            continue;
        }

        auto typeId = col->GetTypeId();
        if (typeId < OMNI_INVALID && rowSerializerCenter[typeId] != nullptr) {
            rowSerializerCenter[typeId](col, rowIndex, outRow, colIndex);
        } else if (
            typeId == DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE ||
            typeId == omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE) {
            // These timestamp types are serialized through the OMNI_LONG serializer.
            rowSerializerCenter[OMNI_LONG](col, rowIndex, outRow, colIndex);
        } else {
            // NOTE(review): outRow is not released on this path — confirm the correct way to free
            // a row obtained from createBinaryRowDataWithMem and release it before throwing.
            THROW_RUNTIME_ERROR("extractRowData Data type not supported: " << typeId);
        }
    }
    outRow->setRowKind(this->getRowKind(rowIndex));
    return outRow;
}
/**
 * Strips trailing '0' characters from the part of num after its first '.'.
 *
 * - No '.' in num: num is returned unchanged.
 * - Nothing left after the '.': the '.' is dropped as well ("1.000" -> "1").
 * - Nothing before the '.': "0" is used as the integer part (".50" -> "0.5").
 *
 * @param num textual number.
 * @return the shortened text.
 */
std::string removeTrailingZeros(std::string num)
{
    const size_t dotAt = num.find('.');
    if (dotAt == std::string::npos) {
        return num;
    }

    // The '.' itself is not '0', so this search always lands at or after dotAt.
    const size_t lastKept = num.find_last_not_of('0');
    const std::string fraction = num.substr(dotAt + 1, lastKept - dotAt);

    std::string whole = num.substr(0, dotAt);
    if (whole.empty()) {
        whole = "0";
    }

    if (fraction.empty()) {
        return whole;
    }
    return whole + "." + fraction;
}
/**
 * Formats an epoch-millisecond cell as "YYYY-MM-DD HH:MM:SS.fff..." in the given time zone.
 *
 * The cell is read as int64_t. The TZ environment variable is overwritten with the value derived
 * from tzStr and tzset() is called, so the process-wide time zone is changed and not restored.
 * NOTE(review): setenv/tzset affect the whole process — confirm this is never called concurrently.
 *
 * The fractional part always carries the three millisecond digits. For precision 4..9 it is padded
 * with (precision - 3) zeros, and for precision above 9 with six zeros.
 *
 * @param vectorID  column index.
 * @param rowID     row index.
 * @param tzStr     time-zone identifier passed to TimeZoneUtil::GetTZ.
 * @param precision requested number of fractional digits.
 * @return the formatted timestamp.
 */
std::string VectorBatch::TransformTimeWithTimeZone(
    int vectorID, int rowID, const std::string& tzStr, int precision) const
{
    auto epochMillis = reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(vectors[vectorID])->GetValue(rowID);

    // Floor division so that pre-epoch instants round towards the earlier second.
    int64_t wholeSeconds;
    if (epochMillis >= 0) {
        wholeSeconds = epochMillis / 1000;
    } else {
        wholeSeconds = (epochMillis - 999) / 1000;
    }
    int fractionMillis = epochMillis % 1000;
    if (fractionMillis < 0) {
        const int millisPerSecond = 1000;
        fractionMillis += millisPerSecond;
    }

    setenv("TZ", omniruntime::codegen::function::TimeZoneUtil::GetTZ(tzStr.c_str()), 1);
    tzset();

    struct tm brokenDown;
    localtime_r(&wholeSeconds, &brokenDown);
    char dateTimeText[80];
    strftime(dateTimeText, sizeof(dateTimeText), "%Y-%m-%d %H:%M:%S", &brokenDown);

    // Zeros appended after the millisecond digits: none up to precision 3, at most six.
    const int extraZeros = (precision <= 3) ? 0 : ((precision <= 9) ? (precision - 3) : 6);

    std::ostringstream formatted;
    formatted << dateTimeText << "." << std::setw(3) << std::setfill('0') << fractionMillis;
    if (extraZeros > 0) {
        formatted << std::string(extraZeros, '0');
    }
    return formatted.str();
}
/**
 * Formats an epoch-millisecond cell as "YYYY-MM-DD HH:MM:SS.fff..." using gmtime_r (UTC).
 *
 * The cell is read as int64_t. The fractional part always carries the three millisecond digits.
 * For precision 4..9 it is padded with (precision - 3) zeros, and for precision above 9 with
 * six zeros.
 *
 * @param vectorID  column index.
 * @param rowID     row index.
 * @param precision requested number of fractional digits.
 * @return the formatted timestamp.
 */
std::string VectorBatch::TransformTime(int vectorID, int rowID, int precision) const
{
    auto epochMillis = reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(vectors[vectorID])->GetValue(rowID);

    // Floor division so that pre-epoch instants round towards the earlier second.
    int64_t wholeSeconds;
    if (epochMillis >= 0) {
        wholeSeconds = epochMillis / 1000;
    } else {
        wholeSeconds = (epochMillis - 999) / 1000;
    }
    int fractionMillis = epochMillis % 1000;
    if (fractionMillis < 0) {
        const int millisPerSecond = 1000;
        fractionMillis += millisPerSecond;
    }

    struct tm brokenDown;
    gmtime_r(&wholeSeconds, &brokenDown);
    char dateTimeText[80];
    strftime(dateTimeText, sizeof(dateTimeText), "%Y-%m-%d %H:%M:%S", &brokenDown);

    // Zeros appended after the millisecond digits: none up to precision 3, at most six.
    const int extraZeros = (precision <= 3) ? 0 : ((precision <= 9) ? (precision - 3) : 6);

    std::ostringstream formatted;
    formatted << dateTimeText << "." << std::setw(3) << std::setfill('0') << fractionMillis;
    if (extraZeros > 0) {
        formatted << std::string(extraZeros, '0');
    }
    return formatted.str();
}
/**
 * Renders a Decimal128 cell as text, inserting the decimal point according to the column's scale.
 *
 * The scale is decimalInfo[vectorID].second. When decimalInfo has no entry for the column or the
 * scale is not positive, the result of Decimal128::ToString() is returned as is. A leading '-' is
 * set aside before the point is placed and put back afterwards, matching transformDecimal64;
 * otherwise a short negative value such as "-5" at scale 2 would come out as "0.-5".
 *
 * @param vectorID    column index.
 * @param rowID       row index.
 * @param decimalInfo per-column pairs; only .second (the scale) is read here.
 * @return the decimal as text.
 */
std::string VectorBatch::transformDecimal128(
    int vectorID, int rowID, std::vector<std::pair<int32_t, int32_t>>& decimalInfo) const
{
    std::string valueStr =
        (reinterpret_cast<omniruntime::vec::Vector<Decimal128>*>(vectors[vectorID])->GetValue(rowID)).ToString();
    if (static_cast<int>(decimalInfo.size()) > vectorID && decimalInfo[vectorID].second > 0) {
        const int32_t scale = decimalInfo[vectorID].second;

        // Work on the bare digits so the sign never ends up behind the decimal point.
        const bool negative = !valueStr.empty() && valueStr[0] == '-';
        if (negative) {
            valueStr.erase(0, 1);
        }

        const int len = static_cast<int>(valueStr.length());
        if (scale >= len) {
            // Fewer digits than the scale: pad with zeros after "0.".
            valueStr = "0." + std::string(scale - len, '0') + valueStr;
        } else {
            valueStr = valueStr.substr(0, len - scale) + "." + valueStr.substr(len - scale);
        }

        if (negative) {
            valueStr = "-" + valueStr;
        }
    }
    return valueStr;
}
/**
 * Renders a 64-bit decimal cell as text, inserting the decimal point according to the column's scale.
 *
 * The cell is read as long and converted with std::to_string. The scale is
 * decimalInfo[vectorID].second. When decimalInfo has no entry for the column or the scale is not
 * positive, the plain integer text is returned. A leading '-' is removed before the point is
 * placed and restored afterwards.
 *
 * @param vectorID    column index.
 * @param rowID       row index.
 * @param decimalInfo per-column pairs; only .second (the scale) is read here.
 * @return the decimal as text.
 */
std::string VectorBatch::transformDecimal64(
    int vectorID, int rowID, std::vector<std::pair<int32_t, int32_t>>& decimalInfo) const
{
    std::string digits =
        std::to_string(reinterpret_cast<omniruntime::vec::Vector<long>*>(vectors[vectorID])->GetValue(rowID));

    const bool hasScale = static_cast<int>(decimalInfo.size()) > vectorID && decimalInfo[vectorID].second > 0;
    if (!hasScale) {
        return digits;
    }
    const int32_t scale = decimalInfo[vectorID].second;

    const bool isNegative = !digits.empty() && digits[0] == '-';
    if (isNegative) {
        digits.erase(0, 1);
    }

    const int digitCount = static_cast<int>(digits.length());
    std::string rendered;
    if (scale >= digitCount) {
        // Fewer digits than the scale: pad with zeros after "0.".
        rendered = "0." + std::string(scale - digitCount, '0') + digits;
    } else {
        const int integerDigits = digitCount - scale;
        rendered = digits.substr(0, integerDigits) + "." + digits.substr(integerDigits);
    }

    if (isNegative) {
        return "-" + rendered;
    }
    return rendered;
}
void VectorBatch::WriteString(std::ofstream& file, int vectorID, int rowID) const
{
if (vectors[vectorID]->GetEncoding() == omniruntime::vec::OMNI_FLAT) {
auto casted =
reinterpret_cast<omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>*>(
vectors[vectorID]);
file << casted->GetValue(rowID);
} else {
auto casted = reinterpret_cast<omniruntime::vec::Vector<
omniruntime::vec::DictionaryContainer<std::string_view, omniruntime::vec::LargeStringContainer>>*>(
vectors[vectorID]);
file << casted->GetValue(rowID);
}
}
void VectorBatch::WriteToFileInternal(
int vectorID,
int rowID,
std::ofstream& file,
std::vector<std::pair<int32_t, int32_t>> decimalInfo,
std::vector<std::string> inputTypes,
const std::string& tzStr) const
{
int dataId = vectors[vectorID]->GetTypeId();
switch (dataId) {
case omniruntime::type::DataTypeId::OMNI_TIMESTAMP:
case omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
case omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
case omniruntime::type::DataTypeId::OMNI_LONG:
LOG("vb writefile inputType is " << inputTypes[vectorID]);
if (inputTypes[vectorID].substr(0, 30) == "TIMESTAMP_WITH_LOCAL_TIME_ZONE") {
auto result = TransformTimeWithTimeZone(vectorID, rowID, tzStr);
file << result;
} else if (inputTypes[vectorID].substr(0, 9) == "TIMESTAMP") {
int precision = 3;
size_t parenPos = inputTypes[vectorID].find('(');
if (parenPos != std::string::npos) {
size_t endParen = inputTypes[vectorID].find(')', parenPos);
if (endParen != std::string::npos) {
std::string precisionStr = inputTypes[vectorID].substr(parenPos + 1, endParen - parenPos - 1);
precision = std::stoi(precisionStr);
}
}
auto result = TransformTime(vectorID, rowID, precision);
file << result;
} else {
file << reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(vectors[vectorID])->GetValue(rowID);
}
break;
case omniruntime::type::DataTypeId::OMNI_VARCHAR:
case omniruntime::type::DataTypeId::OMNI_CHAR: WriteString(file, vectorID, rowID); break;
case omniruntime::type::DataTypeId::OMNI_DOUBLE:
file << reinterpret_cast<omniruntime::vec::Vector<double>*>(vectors[vectorID])->GetValue(rowID);
break;
case omniruntime::type::DataTypeId::OMNI_INT:
file << reinterpret_cast<omniruntime::vec::Vector<int32_t>*>(vectors[vectorID])->GetValue(rowID);
break;
case omniruntime::type::DataTypeId::OMNI_BOOLEAN:
file << reinterpret_cast<omniruntime::vec::Vector<bool>*>(vectors[vectorID])->GetValue(rowID);
break;
case omniruntime::type::DataTypeId::OMNI_DECIMAL64: {
auto valueStr = transformDecimal64(vectorID, rowID, decimalInfo);
file << valueStr;
break;
}
case omniruntime::type::DataTypeId::OMNI_DECIMAL128: {
auto valueStr = transformDecimal128(vectorID, rowID, decimalInfo);
file << valueStr;
break;
}
default: std::runtime_error("WriteToFileInternal data type not supported");
}
}
void VectorBatch::writeToFile(
std::string& filename,
std::ios_base::openmode mode,
std::vector<std::pair<int32_t, int32_t>> decimalInfo,
std::vector<std::string> inputTypes,
const std::string& tzStr) const
{
std::ofstream file;
if (!normalizeAndValidatePath(filename)) {
std::cerr << "Error validating file\n";
return;
}
file.open(filename, mode);
if (!file.is_open()) {
std::cerr << "Error opening file\n";
return;
}
std::vector<std::string> rowKindStr = {"+I", "-U", "+U", "-D"};
for (size_t i = 0; i < rowCnt; ++i) {
file << rowKindStr[(int)rowKinds[i]];
for (size_t j = 0; j < vectors.size(); ++j) {
file << ",";
if (vectors[j]->IsNull(i)) {
file << "NULL";
} else {
WriteToFileInternal(j, i, file, decimalInfo, inputTypes, tzStr);
}
}
file << "\n";
}
file.close();
LOG("write file finish");
}
/**
 * Copies one row of the batch into a JSON object, one key per column.
 *
 * The key for column c is inputFields[c]. 64-bit integer-backed columns are stored as a formatted
 * time string when inputTypes[c] starts with "TIMESTAMP", and as an integer otherwise. Decimals
 * are stored as text with the decimal point placed according to decimalInfo.
 *
 * NOTE(review): unlike writeToFile, null flags are not consulted here — confirm callers filter
 * nulls or that emitting the underlying value is intended.
 *
 * @param j           JSON object that receives the fields.
 * @param rowIndex    row to convert.
 * @param decimalInfo per-column pairs whose .second is the decimal scale.
 * @param inputTypes  per-column declared type names.
 * @param inputFields per-column JSON keys.
 * @throws std::runtime_error when a column's data type is not handled.
 */
void VectorBatch::convertToJson(
    nlohmann::ordered_json& j,
    int rowIndex,
    std::vector<std::pair<int32_t, int32_t>> decimalInfo,
    std::vector<std::string> inputTypes,
    std::vector<std::string> inputFields) const
{
    using FlatStringVec = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
    using DictStringVec = omniruntime::vec::Vector<
        omniruntime::vec::DictionaryContainer<std::string_view, omniruntime::vec::LargeStringContainer>>;

    for (size_t colIndex = 0; colIndex < vectors.size(); ++colIndex) {
        const int dataId = vectors[colIndex]->GetTypeId();
        switch (dataId) {
            case omniruntime::type::DataTypeId::OMNI_TIMESTAMP:
            case omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
            case omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
            case omniruntime::type::DataTypeId::OMNI_LONG: {
                if (inputTypes[colIndex].compare(0, 9, "TIMESTAMP") == 0) {
                    auto result = TransformTime(colIndex, rowIndex);
                    j[inputFields[colIndex]] = result;
                } else {
                    auto result =
                        reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(vectors[colIndex])->GetValue(rowIndex);
                    j[inputFields[colIndex]] = result;
                }
                break;
            }
            case omniruntime::type::DataTypeId::OMNI_VARCHAR:
            case omniruntime::type::DataTypeId::OMNI_CHAR: {
                // Flat and dictionary-encoded string columns need different vector types.
                if (vectors[colIndex]->GetEncoding() == omniruntime::vec::OMNI_FLAT) {
                    auto casted = reinterpret_cast<FlatStringVec*>(vectors[colIndex]);
                    j[inputFields[colIndex]] = casted->GetValue(rowIndex);
                } else {
                    auto casted = reinterpret_cast<DictStringVec*>(vectors[colIndex]);
                    j[inputFields[colIndex]] = casted->GetValue(rowIndex);
                }
                break;
            }
            case omniruntime::type::DataTypeId::OMNI_DOUBLE: {
                auto result =
                    reinterpret_cast<omniruntime::vec::Vector<double>*>(vectors[colIndex])->GetValue(rowIndex);
                j[inputFields[colIndex]] = result;
                break;
            }
            case omniruntime::type::DataTypeId::OMNI_INT: {
                auto result =
                    reinterpret_cast<omniruntime::vec::Vector<int32_t>*>(vectors[colIndex])->GetValue(rowIndex);
                j[inputFields[colIndex]] = result;
                break;
            }
            case omniruntime::type::DataTypeId::OMNI_BOOLEAN: {
                auto result = reinterpret_cast<omniruntime::vec::Vector<bool>*>(vectors[colIndex])->GetValue(rowIndex);
                j[inputFields[colIndex]] = result;
                break;
            }
            case omniruntime::type::DataTypeId::OMNI_DECIMAL64: {
                auto valueStr = transformDecimal64(colIndex, rowIndex, decimalInfo);
                j[inputFields[colIndex]] = valueStr;
                break;
            }
            case omniruntime::type::DataTypeId::OMNI_DECIMAL128: {
                auto valueStr = transformDecimal128(colIndex, rowIndex, decimalInfo);
                j[inputFields[colIndex]] = valueStr;
                break;
            }
            default:
                // The exception used to be constructed and discarded, so unsupported columns were
                // silently left out of the JSON object.
                throw std::runtime_error("convertToJson data type not supported");
        }
    }
    LOG("convertToJson finish");
}
std::vector<XXH128_hash_t> VectorBatch::getXXH128s()
{
std::vector<XXH128_hash_t> hashes(rowCnt);
for (size_t i = 0; i < rowCnt; ++i) {
XXH3_state_t* state = XXH3_createState();
XXH3_128bits_reset(state);
for (auto vec : vectors) {
auto dataTypeId = vec->GetTypeId();
switch (dataTypeId) {
case OMNI_LONG:
case OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
case OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
case OMNI_TIMESTAMP: {
auto casted = reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(vec);
auto val = casted->GetValue(i);
XXH3_128bits_update(state, &val, sizeof(int64_t));
break;
}
case OMNI_VARCHAR:
case OMNI_CHAR: {
if (vec->GetEncoding() == omniruntime::vec::OMNI_FLAT) {
auto casted = reinterpret_cast<
omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>*>(vec);
auto val = casted->GetValue(i);
XXH3_128bits_update(state, val.data(), val.size());
} else {
auto casted = reinterpret_cast<omniruntime::vec::Vector<omniruntime::vec::DictionaryContainer<
std::string_view,
omniruntime::vec::LargeStringContainer>>*>(vec);
auto val = casted->GetValue(i);
XXH3_128bits_update(state, val.data(), val.size());
}
break;
}
default: XXH3_freeState(state); throw std::runtime_error("Type not supported yet");
}
}
hashes[i] = XXH3_128bits_digest(state);
XXH3_freeState(state);
}
return hashes;
}
/**
 * Builds a new flat string vector from selected rows of a dictionary-encoded string vector.
 *
 * Output row i takes the value (or null flag) of input row positions[offset + i].
 *
 * @param input     dictionary-encoded OMNI_VARCHAR or OMNI_CHAR vector.
 * @param positions array of source row indices.
 * @param offset    index into positions of the first row to copy.
 * @param length    number of rows to copy; also the size of the result.
 * @return a newly allocated flat string vector.
 * @throws std::runtime_error when input is not a string vector or is not dictionary-encoded.
 */
omniruntime::vec::BaseVector* VectorBatch::CopyPositionsAndFlatten(
    omniruntime::vec::BaseVector* input, const int* positions, int offset, int length)
{
    using FlatVarcharVecType = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
    using DictVarcharVecType = omniruntime::vec::Vector<
        omniruntime::vec::DictionaryContainer<std::string_view, omniruntime::vec::LargeStringContainer>>;

    const bool isStringColumn =
        input->GetTypeId() == omniruntime::type::OMNI_VARCHAR || input->GetTypeId() == omniruntime::type::OMNI_CHAR;
    if (!isStringColumn) {
        throw std::runtime_error("Type is not Varchar or Char");
    }
    if (input->GetEncoding() != omniruntime::vec::OMNI_DICTIONARY) {
        throw std::runtime_error("not dictionary");
    }

    auto dictInput = reinterpret_cast<DictVarcharVecType*>(input);
    auto flattened = new FlatVarcharVecType(length);
    const int* selected = positions + offset;

    for (int32_t outRow = 0; outRow < length; ++outRow) {
        auto srcRow = selected[outRow];
        if (!input->IsNull(srcRow)) {
            auto value = dictInput->GetValue(srcRow);
            flattened->SetValue(outRow, value);
            continue;
        }
        flattened->SetNull(outRow);
    }
    return flattened;
}
/**
 * Allocates a batch with one freshly created column per entry of dataTypes.
 *
 * OMNI_INT maps to a 32-bit integer vector, OMNI_LONG and the timestamp types to a 64-bit integer
 * vector, and OMNI_CHAR / OMNI_VARCHAR to a flat string vector. Every column is created with
 * rowCount rows.
 *
 * All types are validated before anything is allocated, so an unsupported type no longer leaves a
 * half-built batch behind. The error message carries the numeric type id. The previous
 * `"literal" + enum` expression was pointer arithmetic on the string literal, not concatenation.
 *
 * @param rowCount  number of rows for the batch and each column.
 * @param dataTypes type of each column, in column order.
 * @return the newly allocated batch.
 * @throws std::runtime_error when dataTypes contains a type not listed above.
 */
omnistream::VectorBatch* VectorBatch::CreateVectorBatch(int rowCount, const std::vector<DataTypeId>& dataTypes)
{
    // Pass 1: reject unsupported types up front, before any allocation.
    for (const auto typeId : dataTypes) {
        switch (typeId) {
            case (omniruntime::type::DataTypeId::OMNI_INT):
            case (omniruntime::type::DataTypeId::OMNI_LONG):
            case (omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE):
            case (omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE):
            case (omniruntime::type::DataTypeId::OMNI_TIMESTAMP):
            case (omniruntime::type::DataTypeId::OMNI_CHAR):
            case (omniruntime::type::DataTypeId::OMNI_VARCHAR):
                break;
            default:
                throw std::runtime_error("Unsupported type: " + std::to_string(static_cast<int>(typeId)));
        }
    }

    // Pass 2: build the batch; every type reaching this point is known to be supported.
    auto* vectorBatch = new omnistream::VectorBatch(rowCount);
    for (const auto typeId : dataTypes) {
        switch (typeId) {
            case (omniruntime::type::DataTypeId::OMNI_INT): {
                vectorBatch->Append(new omniruntime::vec::Vector<int32_t>(rowCount));
                break;
            }
            case (omniruntime::type::DataTypeId::OMNI_LONG):
            case (omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE):
            case (omniruntime::type::DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE):
            case (omniruntime::type::DataTypeId::OMNI_TIMESTAMP): {
                vectorBatch->Append(new omniruntime::vec::Vector<int64_t>(rowCount));
                break;
            }
            case (omniruntime::type::DataTypeId::OMNI_CHAR):
            case (omniruntime::type::DataTypeId::OMNI_VARCHAR): {
                vectorBatch->Append(
                    new omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>(rowCount));
                break;
            }
            default:
                // Unreachable: unsupported types were rejected in pass 1.
                break;
        }
    }
    return vectorBatch;
}
}