* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
*/
#ifndef FLINK_TNEL_TOP1COMPARATOR_H
#define FLINK_TNEL_TOP1COMPARATOR_H
#include <iostream>
#include <vector>
#include <queue>
#include <unordered_map>
#include <cassert>
#include "table/data/vectorbatch/VectorBatch.h"
#include "table/runtime/keyselector/KeySelector.h"
using namespace omnistream;
class CompositeKeyRef;
class CompositeKeyRef {
public:
CompositeKeyRef(const std::vector<int>& columnIds, const std::vector<bool>& ascending, int rowId, omnistream::VectorBatch* vectorBatch)
: columnIds_(columnIds), ascending_(ascending), rowId_(rowId), vectorBatch_(vectorBatch) {}
__attribute__((noinline))
bool operator<(const CompositeKeyRef& other) const
{
ASSERT(columnIds_.size() == ascending_.size());
ASSERT(columnIds_.size() == other.columnIds_.size());
ASSERT(!columnIds_.empty());
ASSERT(!ascending_.empty());
ASSERT(vectorBatch_ != nullptr);
ASSERT(other.vectorBatch_ != nullptr);
for (size_t i = 0; i < columnIds_.size(); ++i) {
int colId = columnIds_[i];
DataTypeId dataTypeId = vectorBatch_->Get(colId)->GetTypeId();
int64_t a = 0;
int64_t b = 0;
switch (dataTypeId) {
case DataTypeId::OMNI_INT: {
auto* intVectorA = static_cast<omniruntime::vec::Vector<int32_t>*>(vectorBatch_->Get(colId));
auto* intVectorB = static_cast<omniruntime::vec::Vector<int32_t>*>(other.vectorBatch_->Get(colId));
a = static_cast<int64_t>(intVectorA->GetValue(rowId_));
b = static_cast<int64_t>(intVectorB->GetValue(other.rowId_));
break;
}
case DataTypeId::OMNI_LONG:
case DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE:
case DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE:
case DataTypeId::OMNI_TIMESTAMP: {
auto* longVectorA = static_cast<omniruntime::vec::Vector<int64_t>*>(vectorBatch_->Get(colId));
auto* longVectorB = static_cast<omniruntime::vec::Vector<int64_t>*>(other.vectorBatch_->Get(colId));
a = longVectorA->GetValue(rowId_);
b = longVectorB->GetValue(other.rowId_);
break;
}
case DataTypeId::OMNI_DOUBLE: {
auto* doubleVectorA = static_cast<omniruntime::vec::Vector<double>*>(vectorBatch_->Get(colId));
auto* doubleVectorB = static_cast<omniruntime::vec::Vector<double>*>
(other.vectorBatch_->Get(colId));
a = static_cast<int64_t>(doubleVectorA->GetValue(rowId_));
b = static_cast<int64_t>(doubleVectorB->GetValue(other.rowId_));
break;
}
default:
throw std::runtime_error("Unsupported DataTypeId for column comparison.");
}
if (a != b) {
return ascending_[i] ? a < b : a > b;
}
}
return false;
}
int getRowId() const
{
return rowId_;
};
private:
std::vector<int> columnIds_;
std::vector<bool> ascending_;
int rowId_;
omnistream::VectorBatch* vectorBatch_;
};
template<typename K>
class Top1Comparator {
public:
Top1Comparator(const std::vector<int>& partitionKeyColTypes, const std::vector<int>& partitionKeyColId,
const std::vector<int>& sortColumnIds, const std::vector<bool>& ascending)
: partitionKeyColType_(partitionKeyColTypes), partitionKeyColId_(partitionKeyColId),
sortColumnIds_(sortColumnIds), ascending_(ascending), keySelector_(partitionKeyColType_, partitionKeyColId_) {}
struct CompositeKeyComparator {
bool operator()(const CompositeKeyRef& a, const CompositeKeyRef& b) const
{
return a < b;
}
};
std::unordered_map<K, int> findTop1RowIdsByPartition(omnistream::VectorBatch* vectorBatch)
{
LOG("findTop1RowIdsByPartition is called.");
ASSERT(vectorBatch != nullptr);
std::unordered_map<K, int> top1RowIds;
if (vectorBatch->GetVectorCount() == 0) {
return top1RowIds;
}
std::unordered_map<K, std::priority_queue<CompositeKeyRef, std::vector<CompositeKeyRef>,
CompositeKeyComparator>> partitionHeaps;
size_t rowCount = vectorBatch->GetRowCount();
for (size_t rowId = 0; rowId < rowCount; ++rowId) {
K key = keySelector_.getKey(vectorBatch, static_cast<int>(rowId));
CompositeKeyRef keyRef(sortColumnIds_, ascending_, static_cast<int>(rowId), vectorBatch);
auto it = partitionHeaps.find(key);
if (it != partitionHeaps.end()) {
it->second.push(keyRef);
if constexpr (std::is_same<K, RowData *>::value) {
delete key;
}
} else {
partitionHeaps[key].push(keyRef);
}
}
for (auto& [k, heap] : partitionHeaps) {
top1RowIds[k] = heap.top().getRowId();
}
return top1RowIds;
}
std::unordered_map<K, int> findTop1RowIdsByPartitionV2(omnistream::VectorBatch* vectorBatch)
{
LOG("findTop1RowIdsByPartition is called.");
ASSERT(vectorBatch != nullptr);
std::unordered_map<K, int> top1RowIds;
if (vectorBatch->GetVectorCount() == 0) {
return top1RowIds;
}
std::unordered_map<K, CompositeKeyRef> partitionHeaps;
size_t rowCount = vectorBatch->GetRowCount();
for (size_t rowId = 0; rowId < rowCount; ++rowId) {
K key = keySelector_.getKey(vectorBatch, static_cast<int>(rowId));
CompositeKeyRef keyRef(sortColumnIds_, ascending_, static_cast<int>(rowId), vectorBatch);
auto it = partitionHeaps.find(key);
if (it != partitionHeaps.end()) {
if (keyRef < it->second)
{
it->second = keyRef;
}
if constexpr (std::is_same<K, RowData *>::value) {
delete key;
}
} else {
partitionHeaps.emplace(key, keyRef);
}
}
for (auto& [k, heap] : partitionHeaps) {
top1RowIds[k] = heap.getRowId();
}
return top1RowIds;
}
private:
std::vector<int> partitionKeyColType_;
std::vector<int> partitionKeyColId_;
std::vector<int> sortColumnIds_;
std::vector<bool> ascending_;
KeySelector<K> keySelector_;
};
#endif