* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <iostream>
#include <vector>
#include <string_view>
#include <memory>
#include "RowTimeDeduplicateFunction.h"
#include "data/util/VectorBatchUtil.h"
/**
 * Allocates the output columns of a result batch.
 *
 * For every column of inputVB whose type id is OMNI_INT, OMNI_LONG, OMNI_TIMESTAMP or OMNI_CHAR a new
 * vector with rowCount rows is created and appended to out. Columns of any other type are skipped,
 * i.e. nothing is appended for them.
 *
 * @param out      batch that receives the newly allocated vectors
 * @param inputVB  batch whose column types are mirrored
 * @param rowCount number of rows per new vector; when it is <= 0 nothing is appended at all
 */
void RowTimeDeduplicateFunction::initOutputVector(omnistream::VectorBatch *out, omnistream::VectorBatch *inputVB,
    int rowCount)
{
    using omniruntime::type::DataTypeId;
    using Int32Vector = omniruntime::vec::Vector<int32_t>;
    using Int64Vector = omniruntime::vec::Vector<int64_t>;
    using StringVector = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;

    if (rowCount <= 0) {
        return;
    }
    for (int col = 0; col < inputVB->GetVectorCount(); ++col) {
        switch (inputVB->Get(col)->GetTypeId()) {
            case DataTypeId::OMNI_INT:
                out->Append(std::make_unique<Int32Vector>(rowCount).release());
                break;
            case DataTypeId::OMNI_LONG:
            case DataTypeId::OMNI_TIMESTAMP:
                // Both types are backed by a 64-bit integer vector.
                out->Append(std::make_unique<Int64Vector>(rowCount).release());
                break;
            case DataTypeId::OMNI_CHAR:
                out->Append(std::make_unique<StringVector>(rowCount).release());
                break;
            default:
                // Unsupported type: no column is created for it.
                break;
        }
    }
}
/**
 * Processes one input batch: registers it with the state, computes the deduplicated output rows and
 * forwards them to the collector.
 *
 * @param inputVB incoming batch
 * @param ctx     keyed context used to switch the current key while reading/updating state
 * @param out     collector that receives the resulting batch
 */
void RowTimeDeduplicateFunction::processBatch(omnistream::VectorBatch *inputVB, Context &ctx, TimestampedCollector &out)
{
    // The statement order below matters: the batch has to be registered before its id is read.
    recordStateVB->addVectorBatch(inputVB);
    delVb.insert(inputVB);

    // The input batch is cached under "current batch id - 1"
    // (presumably addVectorBatch already advanced the id - TODO confirm).
    const auto inputBatchId = getCurrentBatchId() - 1;
    vectorBatchCacheMap.emplace(inputBatchId, inputVB);

    omnistream::VectorBatch *result = ProcessUpdateRecord(inputVB, ctx);

    // Release the per-call cache (and, depending on the backend, the tracked batches) before emitting.
    freeDelBatch();
    out.collect(result);
}
bool RowTimeDeduplicateFunction::CompareRecord(int preRowId, int currentRowId,
omnistream::VectorBatch* previousVB,
omnistream::VectorBatch* currentVB)
{
long previousTime = reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(previousVB->Get(rowtimeIndex_))
->GetValue(preRowId);
long currentTime = reinterpret_cast<omniruntime::vec::Vector<int64_t>*>(currentVB->Get(rowtimeIndex_))
->GetValue(currentRowId);
if (keepLastRow_) {
return previousTime <= currentTime;
}
else {
return currentTime < previousTime;
}
}
void RowTimeDeduplicateFunction::open(const Configuration &config)
{
LOG("RowTimeDeduplicateFunction open")
std::string deduplicateStateName = "recordState";
TypeSerializer *serializer = new LongSerializer();
auto *recordStateDesc = new ValueStateDescriptor<int64_t>(deduplicateStateName, serializer);
this->recordStateVB =
static_cast<StreamingRuntimeContext<RowData *> *>(getRuntimeContext())->getState<int64_t>(recordStateDesc);
if (dynamic_cast<RocksdbValueState<RowData *, VoidNamespace, int64_t> *>(recordStateVB)) {
static_cast<RocksdbValueState<RowData *, VoidNamespace, int64_t> *>(this->recordStateVB)->setDefaultValue(-1);
INFO_RELEASE("RowTimeDeduplicateFunction backend is rocksdb")
backendType_ = omnistream::StateType::ROCKSDB;
} else {
static_cast<HeapValueState<RowData *, VoidNamespace, int64_t> *>(this->recordStateVB)->setDefaultValue(-1);
INFO_RELEASE("RowTimeDeduplicateFunction backend is mem")
backendType_ = omnistream::StateType::HEAP;
}
LOG("RowTimeDeduplicateFunction open finish")
}
/**
 * Maps key column positions to omni type ids.
 *
 * @param keyedIndex positions of the key columns inside inputTypes
 * @param inputTypes Flink type names of all input columns
 * @return the result of LogicalType::flinkTypeToOmniTypeId for every position that lies inside
 *         inputTypes, in the order of keyedIndex; positions outside the range are dropped
 */
std::vector<int32_t> RowTimeDeduplicateFunction::getKeyedTypes(const std::vector<int32_t> keyedIndex,
    const std::vector<std::string> inputTypes)
{
    const auto typeCount = static_cast<int32_t>(inputTypes.size());
    std::vector<int32_t> resolvedTypes;
    for (const int32_t position : keyedIndex) {
        const bool inRange = (position >= 0) && (position < typeCount);
        if (!inRange) {
            continue;
        }
        resolvedTypes.push_back(LogicalType::flinkTypeToOmniTypeId(inputTypes[position]));
    }
    return resolvedTypes;
}
/**
 * Drops the per-call batch cache and the set of tracked batches.
 *
 * With the HEAP backend the tracked batches are only forgotten; with any other backend every
 * tracked batch is deleted before the set is cleared.
 */
void RowTimeDeduplicateFunction::freeDelBatch()
{
    vectorBatchCacheMap.clear();

    const bool deleteTrackedBatches = (backendType_ != omnistream::StateType::HEAP);
    if (deleteTrackedBatches) {
        for (auto *trackedBatch : delVb) {
            delete trackedBatch;
        }
    }
    delVb.clear();
}
/**
 * Deduplicates the rows of one batch by key.
 *
 * For every distinct key the map holds the index of the row that survives the pairwise
 * CompareRecord() checks inside this batch. The key object stored in the map is the one created
 * for the first row of that key; key objects created for later rows of the same key are deleted
 * here. The keys left in the returned map are not freed by this function.
 *
 * @param inputVB batch to scan
 * @return map from key to the surviving row index within inputVB
 */
unordered_map<RowData*, int32_t> RowTimeDeduplicateFunction::GetNeededUpdateRecord(
    omnistream::VectorBatch* inputVB)
{
    unordered_map<RowData *, int32_t> winnerByKey;
    const int totalRows = inputVB->GetRowCount();
    for (int row = 0; row < totalRows; ++row) {
        RowData *candidateKey = groupByKeySelector->getKey(inputVB, row);
        auto existing = winnerByKey.find(candidateKey);
        if (existing == winnerByKey.end()) {
            // First row seen for this key: the map keeps this key object.
            winnerByKey.emplace(candidateKey, row);
            continue;
        }
        if (CompareRecord(existing->second, row, inputVB, inputVB)) {
            existing->second = row;
        }
        // The map already holds an equal key, so this duplicate key object is released.
        delete candidateKey;
    }
    return winnerByKey;
}
/**
 * Compares the per-key winners of inputVB against the stored state and builds the output batch.
 *
 * Steps:
 *  1. GetNeededUpdateRecord() picks one row per key inside inputVB.
 *  2. For every key the stored combo id is read. A value of -1 means nothing is stored yet and the
 *     row is emitted as a first record; otherwise the row is emitted only if CompareRecord() says
 *     it replaces the stored row.
 *  3. The output batch is allocated and filled. Row kinds:
 *       - first record and (generateUpdateBefore_ or generateInsert_): INSERT
 *       - replacement and generateUpdateBefore_: UPDATE_BEFORE (old row) followed by UPDATE_AFTER
 *       - everything else: UPDATE_AFTER
 *  4. The state is updated through UpdateStateBackend(), which also deletes the keys it was given.
 *
 * Keys whose row is not emitted (the stored row wins, or the stored batch could not be loaded) are
 * deleted at the end of this function; previously they were never released.
 *
 * @param inputVB incoming batch
 * @param ctx     keyed context used to select the current key for state access
 * @return newly allocated batch with the rows to emit (it has no columns when there is nothing to emit)
 */
omnistream::VectorBatch* RowTimeDeduplicateFunction::ProcessUpdateRecord(omnistream::VectorBatch* inputVB, Context& ctx)
{
    // Marker stored in the "previous combo id" slot of a record whose key had no state yet.
    constexpr long noPreviousRecord = -99;

    std::vector<std::tuple<long, long, RowData*>> updatedRecords;
    updatedRecords.reserve(inputVB->GetRowCount() / 4);
    // Keys that do not lead to a state update; UpdateStateBackend never sees them, so they are
    // released here once all state access is finished.
    std::vector<RowData*> rejectedKeys;
    int outPutRowCount = 0;
    auto updateState = GetNeededUpdateRecord(inputVB);
    INFO_RELEASE("state size : " << updateState.size())
    int currentBatchId = getCurrentBatchId() - 1;

    for (const auto& entry : updateState) {
        RowData* k = entry.first;
        int currentRowId = entry.second;
        ctx.setCurrentKey(k);
        int64_t preCombineId = recordStateVB->value();
        if (preCombineId == -1) {
            // Nothing stored for this key yet: emit the row as a first record.
            long combineId = VectorBatchUtil::getComboId(currentBatchId, currentRowId);
            updatedRecords.emplace_back(noPreviousRecord, combineId, k);
            outPutRowCount++;
            continue;
        }

        int preBatchId = VectorBatchUtil::getBatchId(preCombineId);
        int preRowId = VectorBatchUtil::getRowId(preCombineId);
        omnistream::VectorBatch* previousVectorBatch = GetVectorBatchById(preBatchId);
        if (previousVectorBatch == nullptr) {
            LOG("previousVectorBatch is nullptr..........why......")
            rejectedKeys.push_back(k);
            continue;
        }
        if (!CompareRecord(preRowId, currentRowId, previousVectorBatch, inputVB)) {
            // The stored row wins; nothing is emitted for this key.
            rejectedKeys.push_back(k);
            continue;
        }
        long combineId = VectorBatchUtil::getComboId(currentBatchId, currentRowId);
        updatedRecords.emplace_back(preCombineId, combineId, k);
        // A replacement produces UPDATE_BEFORE + UPDATE_AFTER when retractions are requested.
        outPutRowCount += generateUpdateBefore_ ? 2 : 1;
    }

    omnistream::VectorBatch* outputVB = new omnistream::VectorBatch(outPutRowCount);
    INFO_RELEASE("outPutRowCount is " << outPutRowCount)
    initOutputVector(outputVB, inputVB, outPutRowCount);

    int rowIndex = 0;
    for (const auto& record : updatedRecords) {
        const long preCombineId = std::get<0>(record);
        const long curCombineId = std::get<1>(record);
        const bool isFirstRecord = (preCombineId == noPreviousRecord);

        if (isFirstRecord && (generateUpdateBefore_ || generateInsert_)) {
            outputVB->setRowKind(rowIndex, RowKind::INSERT);
            CopyTargetVectorBatchToOut(outputVB, curCombineId, rowIndex);
            rowIndex++;
            continue;
        }
        // A first record never reaches this point when generateUpdateBefore_ is set, so
        // preCombineId always refers to a real stored row here.
        if (generateUpdateBefore_) {
            outputVB->setRowKind(rowIndex, RowKind::UPDATE_BEFORE);
            CopyTargetVectorBatchToOut(outputVB, preCombineId, rowIndex);
            rowIndex++;
        }
        outputVB->setRowKind(rowIndex, RowKind::UPDATE_AFTER);
        CopyTargetVectorBatchToOut(outputVB, curCombineId, rowIndex);
        rowIndex++;
    }

    if (!updatedRecords.empty()) {
        // Also deletes the keys stored in updatedRecords.
        UpdateStateBackend(updatedRecords, ctx);
    }
    for (RowData* rejectedKey : rejectedKeys) {
        delete rejectedKey;
    }
    return outputVB;
}
/**
 * Looks up a batch by id, first in the per-call cache and then in the state.
 *
 * A batch fetched from the state is added to the cache and to the set of tracked batches that
 * freeDelBatch() processes later. The fetched pointer is cached and returned as-is, even if it is
 * nullptr.
 *
 * @param batchId id of the requested batch
 * @return the cached or freshly loaded batch
 */
omnistream::VectorBatch* RowTimeDeduplicateFunction::GetVectorBatchById(int32_t batchId)
{
    LOG("batchis is " << batchId)
    const auto cached = vectorBatchCacheMap.find(batchId);
    if (cached != vectorBatchCacheMap.end()) {
        return cached->second;
    }

    // Cache miss: load from the state and remember the batch for later clean-up.
    auto loadedBatch = recordStateVB->getVectorBatch(batchId);
    vectorBatchCacheMap.emplace(batchId, loadedBatch);
    delVb.insert(loadedBatch);
    return loadedBatch;
}
/**
 * Writes the new combo id of every updated key into the state and then deletes the key objects.
 *
 * HEAP backend: each key is made current on ctx and updated individually.
 * Other backends: all (key, combo id) pairs are collected and handed to
 * RocksdbValueState::updateByBatch in one call; if the state is not a RocksdbValueState nothing
 * is written.
 *
 * @param updateRecords tuples of (previous combo id, new combo id, key); every key is deleted
 *                      before this function returns
 * @param ctx           keyed context used to select the current key (HEAP path only)
 */
void RowTimeDeduplicateFunction::UpdateStateBackend(std::vector<std::tuple<long,long,RowData*>> &updateRecords,Context& ctx)
{
    using RocksState = RocksdbValueState<RowData *, VoidNamespace, int64_t>;

    const bool heapBackend = (backendType_ == omnistream::StateType::HEAP);
    if (heapBackend) {
        for (const auto &entry : updateRecords) {
            long newCombineId = std::get<1>(entry);
            RowData *entryKey = std::get<2>(entry);
            ctx.setCurrentKey(entryKey);
            recordStateVB->update(newCombineId);
        }
    } else {
        std::unordered_map<RowData*, int64_t> pendingUpdates;
        for (const auto &entry : updateRecords) {
            long newCombineId = std::get<1>(entry);
            RowData *entryKey = std::get<2>(entry);
            pendingUpdates.emplace(entryKey, newCombineId);
        }
        auto *rocksState = dynamic_cast<RocksState *>(recordStateVB);
        if (rocksState != nullptr) {
            rocksState->updateByBatch(pendingUpdates);
        }
    }

    // The keys are released only after all state writes are done.
    for (const auto &entry : updateRecords) {
        delete std::get<2>(entry);
    }
}
/**
 * Copies one row, addressed by a combo id, into row rowIndex of the output batch.
 *
 * The combo id is split into a batch id and a row id, the batch is resolved through
 * GetVectorBatchById(), and the row timestamp plus the value of every supported column are copied.
 * Supported column types are OMNI_INT, OMNI_LONG, OMNI_TIMESTAMP and OMNI_CHAR (flat or dictionary
 * encoded source), matching the set of types for which initOutputVector() allocates output
 * columns. OMNI_TIMESTAMP was previously missing here, which left timestamp output columns
 * unwritten. Columns of other types are ignored.
 *
 * NOTE(review): only values are copied through GetValue/SetValue; null markers are not
 * transferred - confirm whether the inputs can contain nulls.
 *
 * @param outputVB destination batch, already populated with columns by initOutputVector()
 * @param comboID  combined batch id / row id of the source row
 * @param rowIndex destination row inside outputVB
 */
void RowTimeDeduplicateFunction::CopyTargetVectorBatchToOut(omnistream::VectorBatch *outputVB, long comboID, int rowIndex)
{
    using omniruntime::type::DataTypeId;
    using Int32Vector = omniruntime::vec::Vector<int32_t>;
    using Int64Vector = omniruntime::vec::Vector<int64_t>;
    using FlatStringVector = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
    using DictStringVector = omniruntime::vec::Vector<
        omniruntime::vec::DictionaryContainer<std::string_view, omniruntime::vec::LargeStringContainer>>;

    LOG("comboID is, " << comboID)
    int batchId = VectorBatchUtil::getBatchId(comboID);
    int rowId = VectorBatchUtil::getRowId(comboID);
    auto batch = GetVectorBatchById(batchId);
    outputVB->setTimestamp(rowIndex, batch->getTimestamp(rowId));

    int colCount = batch->GetVectorCount();
    for (int col = 0; col < colCount; col++) {
        auto *sourceColumn = batch->Get(col);
        auto dataType = sourceColumn->GetTypeId();
        // The destination column is fetched only inside a supported branch, because
        // initOutputVector() creates no column for unsupported types.
        if (dataType == DataTypeId::OMNI_INT) {
            auto val = reinterpret_cast<Int32Vector *>(sourceColumn)->GetValue(rowId);
            reinterpret_cast<Int32Vector *>(outputVB->Get(col))->SetValue(rowIndex, val);
        } else if (dataType == DataTypeId::OMNI_LONG || dataType == DataTypeId::OMNI_TIMESTAMP) {
            // Timestamps share the 64-bit integer representation used for OMNI_LONG.
            auto val = reinterpret_cast<Int64Vector *>(sourceColumn)->GetValue(rowId);
            reinterpret_cast<Int64Vector *>(outputVB->Get(col))->SetValue(rowIndex, val);
        } else if (dataType == DataTypeId::OMNI_CHAR) {
            // The output column is always a flat string vector; only the source encoding varies.
            if (sourceColumn->GetEncoding() == omniruntime::vec::OMNI_FLAT) {
                auto casted = reinterpret_cast<FlatStringVector *>(sourceColumn)->GetValue(rowId);
                reinterpret_cast<FlatStringVector *>(outputVB->Get(col))->SetValue(rowIndex, casted);
            } else {
                auto casted = reinterpret_cast<DictStringVector *>(sourceColumn)->GetValue(rowId);
                reinterpret_cast<FlatStringVector *>(outputVB->Get(col))->SetValue(rowIndex, casted);
            }
        }
    }
}