* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#pragma once
#include <cstdint>
#include <memory>
#include <stdexcept>
#include "OmniOperatorJIT/core/src/type/data_type.h"
#include "OmniOperatorJIT/core/src/vector/unsafe_vector.h"
#include "OmniOperatorJIT/core/src/vector/vector.h"
#include "OmniOperatorJIT/core/src/vector/vector_helper.h"
#include "OmniOperatorJIT/core/src/vector/nulls_buffer.h"
#include "table/data/vectorbatch/VectorBatch.h"
using namespace omniruntime::vec;
using namespace omniruntime::type;
using namespace omniruntime::vec::unsafe;
namespace omnistream {
class VectorBatchDeserializationUtils {
public:
static VectorBatch* deserializeVectorBatch(uint8_t *&buffer)
{
LOG("----deserializeVectorBatch start:: " << buffer)
int32_t batchSize;
memcpy_s(&batchSize, sizeof(batchSize), buffer, sizeof(batchSize));
buffer += sizeof(batchSize);
int32_t vectorCount;
memcpy_s(&vectorCount, sizeof(vectorCount), buffer,
sizeof(vectorCount));
buffer += sizeof(vectorCount);
int32_t rowCnt = 0;
memcpy_s(&rowCnt, sizeof(rowCnt), buffer, sizeof(rowCnt));
buffer += sizeof(rowCnt);
VectorBatch* batch = new VectorBatch(rowCnt);
memcpy_s(batch->getTimestamps(), sizeof(int64_t) * rowCnt, buffer,
sizeof(int64_t) * rowCnt);
buffer += sizeof(int64_t) * rowCnt;
memcpy_s(batch->getRowKinds(), sizeof(RowKind) * rowCnt, buffer,
sizeof(RowKind) * rowCnt);
buffer += sizeof(RowKind) * rowCnt;
for (int idx = 0; idx < vectorCount; idx++) {
int32_t vectorSize;
memcpy_s(&vectorSize, sizeof(int32_t), buffer, sizeof(int32_t));
buffer += sizeof(int32_t);
int8_t encodingNum;
memcpy_s(&encodingNum, sizeof(int8_t), buffer, sizeof(int8_t));
buffer += sizeof(int8_t);
int8_t dataTypeNum;
memcpy_s(&dataTypeNum, sizeof(int8_t), buffer, sizeof(int8_t));
buffer += sizeof(int8_t);
DataTypeId dataType = static_cast<DataTypeId>(dataTypeNum);
Encoding encoding = static_cast<Encoding>(encodingNum);
if (encoding == OMNI_FLAT) {
if (dataType == OMNI_CHAR || dataType == OMNI_VARCHAR) {
batch->Append(deserializeCharVector(rowCnt, buffer));
} else {
batch->Append(
deserializePrimitiveVector(rowCnt, buffer, dataType));
}
} else if (encoding == OMNI_DICTIONARY) {
if (dataType == OMNI_CHAR || dataType == OMNI_VARCHAR) {
batch->Append(deserializeStringDictionaryContainerVector(
rowCnt, buffer));
} else {
throw std::runtime_error("Unsupported data type");
}
}
}
LOG("----deserializeVectorBatch END:: " << reinterpret_cast<long>(buffer))
return batch;
}
static BaseVector *deserializePrimitiveVector(int32_t vectorSize, uint8_t *&buffer, DataTypeId dataType)
{
BaseVector *baseVector = nullptr;
switch (dataType) {
case OMNI_LONG:
case OMNI_DATE64:
case OMNI_TIME64:
case OMNI_TIMESTAMP:
case OMNI_DECIMAL64: {
Vector<int64_t> *vector64 = new Vector<int64_t>(vectorSize);
baseVector = vector64;
deserializeInt64(vector64, buffer);
break;
}
case OMNI_INT:
case OMNI_DATE32: {
Vector<int32_t> *vector32 = new Vector<int32_t>(vectorSize);
baseVector = vector32;
deserializeInt32(vector32, buffer);
break;
}
case OMNI_SHORT: {
Vector<int16_t> *vector16 = new Vector<int16_t>(vectorSize);
baseVector = vector16;
deserializeInt16(vector16, buffer);
break;
}
case OMNI_DOUBLE: {
Vector<double> *vectorDouble = new Vector<double>(vectorSize);
baseVector = vectorDouble;
deserializeDouble(vectorDouble, buffer);
break;
}
case OMNI_BOOLEAN: {
Vector<bool> *vectorBool = new Vector<bool>(vectorSize);
baseVector = vectorBool;
deserializeBool(vectorBool, buffer);
break;
}
case OMNI_DECIMAL128: {
Vector<Decimal128> *vectorDecimal128 =
new Vector<Decimal128>(vectorSize);
baseVector = vectorDecimal128;
deserializeDecimal128(vectorDecimal128, buffer);
break;
}
default:
throw std::runtime_error("Unsupported data type");
}
return baseVector;
}
static Vector<LargeStringContainer<std::string_view>> *
deserializeCharVector(int32_t size, uint8_t *&buffer)
{
int32_t stringBodySize;
memcpy_s(&stringBodySize, sizeof(int32_t), buffer, sizeof(int32_t));
buffer += sizeof(int32_t);
Vector<LargeStringContainer<std::string_view>> *charVector =
new Vector<LargeStringContainer<std::string_view>>(size,
stringBodySize);
deserializeNulls(charVector, buffer, size);
int32_t *offsetArr = UnsafeStringVector::GetOffsets(charVector);
memcpy_s(offsetArr, sizeof(int32_t) * (size + 1), buffer,
sizeof(int32_t) * (size + 1));
buffer += sizeof(int32_t) * (size + 1);
size_t copySize = offsetArr[size] * sizeof(char);
memcpy_s(UnsafeStringVector::GetValues(charVector), copySize, buffer,
copySize);
buffer += offsetArr[size] * sizeof(char);
return charVector;
}
static void deserializeNulls(BaseVector *baseVector, uint8_t *&buffer,
int32_t size)
{
auto nullData = UnsafeBaseVector::GetNulls(baseVector);
auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(size);
memcpy_s(nullData, nullByteSize, buffer, nullByteSize);
buffer += nullByteSize;
}
static void deserializeInt64(Vector<int64_t> *vector64, uint8_t *&buffer)
{
int32_t size = vector64->GetSize();
deserializeNulls(vector64, buffer, size);
int64_t *data = UnsafeVector::GetRawValues(vector64);
memcpy_s(data, sizeof(int64_t) * size, buffer, sizeof(int64_t) * size);
buffer += sizeof(int64_t) * size;
}
static void deserializeInt32(Vector<int32_t> *vector32, uint8_t *&buffer)
{
int32_t size = vector32->GetSize();
deserializeNulls(vector32, buffer, size);
int32_t *data = UnsafeVector::GetRawValues(vector32);
memcpy_s(data, sizeof(int32_t) * size, buffer, sizeof(int32_t) * size);
buffer += sizeof(int32_t) * size;
}
static void deserializeInt16(Vector<int16_t> *vector16, uint8_t *&buffer)
{
int32_t size = vector16->GetSize();
deserializeNulls(vector16, buffer, size);
int16_t *data = UnsafeVector::GetRawValues(vector16);
memcpy_s(data, sizeof(int16_t) * size, buffer, sizeof(int16_t) * size);
buffer += sizeof(int16_t) * size;
}
static void deserializeDouble(Vector<double> *vectorDouble,
uint8_t *&buffer)
{
int32_t size = vectorDouble->GetSize();
deserializeNulls(vectorDouble, buffer, size);
double *data = UnsafeVector::GetRawValues(vectorDouble);
memcpy_s(data, sizeof(double) * size, buffer, sizeof(double) * size);
buffer += sizeof(double) * size;
}
static void deserializeBool(Vector<bool> *vectorBool, uint8_t *&buffer)
{
int32_t size = vectorBool->GetSize();
deserializeNulls(vectorBool, buffer, size);
bool *data = UnsafeVector::GetRawValues(vectorBool);
memcpy_s(data, sizeof(bool) * size, buffer, sizeof(bool) * size);
buffer += sizeof(bool) * size;
}
static void deserializeDecimal128(Vector<Decimal128> *vectorDecimal128,
uint8_t *&buffer)
{
int32_t size = vectorDecimal128->GetSize();
deserializeNulls(vectorDecimal128, buffer, size);
Decimal128 *data = UnsafeVector::GetRawValues(vectorDecimal128);
memcpy_s(data, sizeof(Decimal128) * size, buffer,
sizeof(Decimal128) * size);
buffer += sizeof(Decimal128) * size;
}
static Vector<DictionaryContainer<std::string_view, LargeStringContainer>> *
deserializeStringDictionaryContainerVector(int32_t rowCnt,
uint8_t *&buffer)
{
int32_t *values = new int32_t[rowCnt];
memcpy_s(values, sizeof(int32_t) * rowCnt, buffer,
sizeof(int32_t) * rowCnt);
buffer += sizeof(int32_t) * rowCnt;
int32_t dictOffset = 0;
memcpy_s(&dictOffset, sizeof(int32_t), buffer, sizeof(int32_t));
buffer += sizeof(int32_t);
int32_t dictSize = 0;
memcpy_s(&dictSize, sizeof(int32_t), buffer, sizeof(int32_t));
buffer += sizeof(int32_t);
int32_t stringBodySize = 0;
memcpy_s(&stringBodySize, sizeof(int32_t), buffer, sizeof(int32_t));
buffer += sizeof(int32_t);
auto nullByteSize = omniruntime::vec::NullsBuffer::CalculateNbytes(rowCnt);
std::shared_ptr<AlignedBuffer<uint8_t>> nullsBuffer =
std::make_shared<AlignedBuffer<uint8_t>>(nullByteSize);
memcpy(nullsBuffer->GetBuffer(), buffer,
nullByteSize);
buffer += nullByteSize;
std::shared_ptr<LargeStringContainer<std::string_view>>
stringContainer =
std::make_shared<LargeStringContainer<std::string_view>>(
dictSize, stringBodySize);
int32_t *offset =
UnsafeStringContainer::GetOffsets(stringContainer.get());
auto ret = memcpy_s(offset, sizeof(int32_t) * (dictSize + 1), buffer,
sizeof(int32_t) * (dictSize + 1));
if (ret != EOK) {
LOG("memcpy_s failed for string offset, ret = " << ret)
}
buffer += sizeof(int32_t) * (dictSize + 1);
ret = memcpy_s(UnsafeStringContainer::GetValues(stringContainer.get()),
offset[dictSize] * sizeof(char), buffer,
offset[dictSize] * sizeof(char));
if (ret != EOK) {
LOG("memcpy_s failed for string offset, ret = " << ret)
}
buffer += offset[dictSize] * sizeof(char);
auto dictionary =
std::make_shared<DictionaryContainer<std::string_view>>(
values, rowCnt, stringContainer, dictSize, dictOffset);
auto newNullsBuffer = new NullsBuffer(rowCnt, nullsBuffer);
auto stringDictionaryVector =
new Vector<DictionaryContainer<std::string_view>>(
rowCnt, dictionary, newNullsBuffer, false, OMNI_CHAR);
return stringDictionaryVector;
}
static long derializeWatermark(uint8_t*& buffer)
{
long timestamp;
memcpy_s(×tamp, sizeof(long), buffer, sizeof(long));
buffer += sizeof(long);
return timestamp;
}
};
}