* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "SharedDistinctCountContainerFunction.h"
#include <algorithm>
#include <memory>
#include <optional>
#include <stdexcept>
#include "core/typeutils/LongSerializer.h"
#include "runtime/dataview/PerKeyStateDataViewStore.h"
using namespace omniruntime::type;
namespace {
/**
 * Builds the key serializer for a shared DISTINCT state map.
 *
 * INT keys are widened to long before they reach state, so INT and LONG columns both use
 * LongSerializer. The serializer is heap-allocated and returned to the caller; open() hands it
 * to getStateMapView(), which is assumed to take ownership (NOTE(review): confirm).
 *
 * @throws std::runtime_error for any key type other than INT/LONG.
 */
TypeSerializer* createOwnedSharedDistinctSerializer(DataTypeId typeId)
{
    const bool isSupportedKeyType = (typeId == DataTypeId::OMNI_INT) || (typeId == DataTypeId::OMNI_LONG);
    if (!isSupportedKeyType) {
        throw std::runtime_error("Shared DISTINCT only supports INT/LONG key types.");
    }
    return new LongSerializer();
}
}  // namespace
SharedDistinctCountContainerFunction::SharedDistinctCountContainerFunction(std::string stateName)
: stateName_(std::move(stateName))
{
}
void SharedDistinctCountContainerFunction::addDistinctEntry(
int aggFuncIndex, const std::string& aggType, int aggIdx, int filterIndex, const std::string& inputType)
{
auto* entry = new DistinctEntry();
entry->aggFuncIndex = aggFuncIndex;
entry->aggType = aggType;
entry->aggIdx = aggIdx;
entry->filterIndex = filterIndex;
entry->inputType = inputType;
entry->typeId = LogicalType::flinkTypeToOmniTypeId(inputType);
entries_.push_back(entry);
finalized_ = false;
}
/**
 * Adds entry to the DISTINCT group for its key column (entry->aggIdx).
 *
 * If no group exists yet for that column, a new one is appended to groups_ with the entry's
 * key type, and distinctGroupMap records its position. The first group ever created uses
 * stateName_ unchanged; each later group uses stateName_ + "_" + <its position in groups_>.
 */
void SharedDistinctCountContainerFunction::getOrCreateGroup(DistinctEntry* entry)
{
    const int columnIdx = entry->aggIdx;
    const auto existing = distinctGroupMap.find(columnIdx);
    if (existing != distinctGroupMap.end()) {
        groups_[existing->second].groupEntries.push_back(entry);
        return;
    }

    const std::size_t newGroupIndex = groups_.size();
    DistinctGroup freshGroup;
    freshGroup.aggIdx = columnIdx;
    freshGroup.typeId = entry->typeId;
    if (newGroupIndex == 0U) {
        freshGroup.stateName = stateName_;
    } else {
        freshGroup.stateName = stateName_ + "_" + std::to_string(newGroupIndex);
    }
    groups_.push_back(std::move(freshGroup));
    groups_[newGroupIndex].groupEntries.push_back(entry);
    distinctGroupMap.emplace(columnIdx, newGroupIndex);
}
/**
 * Validates the registered entries and groups them by DISTINCT key column.
 *
 * Does nothing once finalized. addDistinctEntry() clears the flag, so the next call rebuilds
 * groups_, distinctGroupMap and the aggFuncIndex -> position-in-entries_ lookup
 * (entryIndexToAggFuncIndex) from scratch. Within each group, every entry is assigned its own
 * bit (bitOffset) of a 64-bit mask.
 *
 * @throws std::runtime_error if no entries are registered, an aggFuncIndex appears twice, an
 *         entry's key type is not INT/LONG, or a group holds more than 64 entries.
 */
void SharedDistinctCountContainerFunction::finalizeEntries()
{
    // One bit per entry in the std::uint64_t mask stored per distinct key.
    constexpr std::size_t kMaxEntriesPerGroup = 64;
    // Initial capacity of each group's deferred-update buffer.
    constexpr std::size_t kPendingUpdatesReserve = 500;

    if (finalized_) {
        return;
    }
    if (entries_.empty()) {
        throw std::runtime_error("SharedDistinctCountContainerFunction has no entries.");
    }
    entryIndexToAggFuncIndex.clear();
    entryIndexToAggFuncIndex.reserve(entries_.size());
    groups_.clear();
    distinctGroupMap.clear();

    // Compare against a precomputed int count to avoid a signed/unsigned comparison with size().
    const int entryCount = static_cast<int>(entries_.size());
    for (int i = 0; i < entryCount; ++i) {
        const auto& entry = entries_[i];
        if (entryIndexToAggFuncIndex.find(entry->aggFuncIndex) != entryIndexToAggFuncIndex.end()) {
            throw std::runtime_error("Duplicated aggFuncIndex in shared DISTINCT container.");
        }
        if (entry->typeId != DataTypeId::OMNI_INT && entry->typeId != DataTypeId::OMNI_LONG) {
            throw std::runtime_error("Shared DISTINCT only supports INT/LONG key types.");
        }
        getOrCreateGroup(entry);
        entryIndexToAggFuncIndex.emplace(entry->aggFuncIndex, i);
    }

    for (auto& group : groups_) {
        if (group.groupEntries.size() > kMaxEntriesPerGroup) {
            throw std::runtime_error(
                "SharedDistinctCountContainerFunction supports at most 64 DISTINCT entries per key column.");
        }
        group.pendingDistinctUpdates.reserve(kPendingUpdatesReserve);
        for (std::size_t bit = 0; bit < group.groupEntries.size(); ++bit) {
            group.groupEntries[bit]->bitOffset = static_cast<std::uint8_t>(bit);
        }
    }
    finalized_ = true;
}
/**
 * Sets the accumulator slot and output-value slot of the entry registered under aggFuncIndex.
 *
 * Unlike bindAccValueIndex(), this does not finalize on its own; finalizeEntries() must have
 * run first.
 *
 * @throws std::runtime_error if entries are not finalized yet or aggFuncIndex is unknown.
 */
void SharedDistinctCountContainerFunction::bindEntryAccValueIndex(int aggFuncIndex, int accIndex, int valueIndex)
{
    if (!finalized_) {
        // Name the real entry point so the diagnostic leads to the right call site.
        throw std::runtime_error("bindEntryAccValueIndex called before finalizeEntries.");
    }
    // Despite its name, this map is keyed by aggFuncIndex and yields the position in entries_.
    const auto it = entryIndexToAggFuncIndex.find(aggFuncIndex);
    if (it == entryIndexToAggFuncIndex.end()) {
        throw std::runtime_error("Unknown aggFuncIndex for shared DISTINCT entry binding.");
    }
    auto& entry = entries_[it->second];
    entry->accIndex = accIndex;
    entry->valueIndex = valueIndex;
}
/**
 * Finalizes the entries, then gives them consecutive slots in registration order.
 *
 * The k-th entry gets accumulator slot accStartIndex + k and value slot valueStartIndex + k.
 */
void SharedDistinctCountContainerFunction::bindAccValueIndex(int accStartIndex, int valueStartIndex)
{
    finalizeEntries();
    int offset = 0;
    for (auto& entry : entries_) {
        entry->accIndex = accStartIndex + offset;
        entry->valueIndex = valueStartIndex + offset;
        ++offset;
    }
}
/**
 * Returns true when r1 and r2 hold the same DISTINCT counts in every bound value slot.
 *
 * Entries without a value slot (valueIndex < 0) are skipped. A slot that is null in both rows
 * counts as equal, matching Flink's record-equaliser semantics. Null vs. non-null counts as
 * unequal. Otherwise the long values are compared.
 */
bool SharedDistinctCountContainerFunction::equaliser(BinaryRowData* r1, BinaryRowData* r2)
{
    for (const auto& entry : entries_) {
        const int slot = entry->valueIndex;
        if (slot < 0) {
            continue;
        }
        const bool leftIsNull = r1->isNullAt(slot);
        const bool rightIsNull = r2->isNullAt(slot);
        if (leftIsNull != rightIsNull) {
            return false;
        }
        if (leftIsNull) {
            // Both null: unchanged, so the caller must not emit a spurious update.
            continue;
        }
        if (*r1->getLong(slot) != *r2->getLong(slot)) {
            return false;
        }
    }
    return true;
}
/**
 * Finalizes the entries, remembers store, and obtains one keyed map view per DISTINCT group.
 *
 * Each view maps a distinct key value to the bit mask of entries that have already counted it.
 * Views are named after the group's stateName. store is assumed to be a
 * PerKeyStateDataViewStore<RowData*>; the cast is unchecked (NOTE(review): confirm with callers).
 */
void SharedDistinctCountContainerFunction::open(StateDataViewStore* store)
{
    using DistinctMapView = KeyedStateMapViewWithKeysNullable<VoidNamespace, long, long>;

    finalizeEntries();
    store_ = store;
    auto* viewStore = reinterpret_cast<PerKeyStateDataViewStore<RowData*>*>(store_);
    for (auto& group : groups_) {
        auto rawView = viewStore->getStateMapView<VoidNamespace, long, long>(
            group.stateName, true, createOwnedSharedDistinctSerializer(group.typeId), new LongSerializer());
        group.distinctMapView = reinterpret_cast<DistinctMapView*>(rawView);
    }
}
/**
 * Decides whether inputRow passes entry's FILTER clause.
 *
 * An entry without a FILTER (filterIndex < 0) accepts every row. Otherwise the row is accepted
 * only if the filter column is non-null and true.
 */
bool SharedDistinctCountContainerFunction::shouldAccumulateForEntry(DistinctEntry* entry, RowData* inputRow) const
{
    const int filterCol = entry->filterIndex;
    if (filterCol >= 0) {
        return !inputRow->isNullAt(filterCol) && *inputRow->getBool(filterCol);
    }
    return true;
}
/**
 * Builds the bit mask of group entries whose FILTER accepts inputRow.
 *
 * Bit k is set when the entry with bitOffset k should count this row.
 */
std::uint64_t SharedDistinctCountContainerFunction::collectCandidateMask(
    RowData* inputRow, const DistinctGroup& group) const
{
    std::uint64_t mask = 0ULL;
    for (const auto& member : group.groupEntries) {
        if (!shouldAccumulateForEntry(member, inputRow)) {
            continue;
        }
        mask |= 1ULL << member->bitOffset;
    }
    return mask;
}
void SharedDistinctCountContainerFunction::applyMaskDelta(std::size_t groupIndex, std::uint64_t deltaMask)
{
auto& group = groups_[groupIndex];
for (auto& entry : group.groupEntries) {
if ((deltaMask & (1ULL << entry->bitOffset)) == 0ULL) {
continue;
}
if (entry->valueIsNull) {
entry->aggCount = 1L;
entry->valueIsNull = false;
} else {
entry->aggCount++;
}
}
}
/**
 * Reads column aggIdx of row as a long key.
 *
 * @param isNull set to whether the field is null; 0 is returned in that case.
 * @return the field value, with INT widened to long.
 * @throws std::runtime_error for a non-null field whose type is neither INT nor LONG.
 */
long SharedDistinctCountContainerFunction::getRowFieldValue(
    RowData* row, int aggIdx, DataTypeId typeId, bool& isNull) const
{
    isNull = row->isNullAt(aggIdx);
    if (!isNull) {
        if (typeId == DataTypeId::OMNI_INT) {
            return static_cast<long>(*row->getInt(aggIdx));
        }
        if (typeId == DataTypeId::OMNI_LONG) {
            return *row->getLong(aggIdx);
        }
        throw std::runtime_error("Unsupported shared DISTINCT key type.");
    }
    return 0L;
}
/**
 * Row-at-a-time accumulate.
 *
 * For each group with an open map view:
 * - compute which entries accept the row (FILTER);
 * - read the key column, skipping null keys;
 * - count the key for entries whose bit is not yet in the stored mask;
 * - write the widened mask back immediately.
 */
void SharedDistinctCountContainerFunction::accumulate(RowData* accInput)
{
    for (std::size_t gi = 0; gi < groups_.size(); ++gi) {
        auto& group = groups_[gi];
        if (group.distinctMapView == nullptr) {
            continue;
        }
        const std::uint64_t wanted = collectCandidateMask(accInput, group);
        if (wanted == 0ULL) {
            continue;
        }
        bool keyIsNull = false;
        const long key = getRowFieldValue(accInput, group.aggIdx, group.typeId, keyIsNull);
        if (keyIsNull) {
            continue;
        }
        const auto stored = group.distinctMapView->get(std::optional<long>{key});
        const std::uint64_t seenMask = stored.has_value() ? static_cast<std::uint64_t>(*stored) : 0ULL;
        // Only bits not already recorded for this key represent new distinct values.
        const std::uint64_t addedBits = wanted & ~seenMask;
        if (addedBits == 0ULL) {
            continue;
        }
        applyMaskDelta(gi, addedBits);
        group.distinctMapView->put(std::optional<long>{key}, static_cast<long>(seenMask | addedBits));
    }
}
/**
 * Vectorized accumulate over the rows of input listed in indices.
 *
 * For each DISTINCT group with an open map view:
 * - rows with a null key are skipped;
 * - for the remaining rows, each entry's FILTER column is evaluated into a candidate bit mask;
 * - masks are OR-ed together per distinct key value, so each key is looked up in state at most
 *   once per batch;
 * - bits not yet stored for a key are counted via applyMaskDelta();
 * - the merged mask is written back, either immediately or, when backend == 2 and a current
 *   group key is set, into pendingDistinctUpdates for updateInnerState() to flush.
 *
 * @throws std::runtime_error if the key column or a FILTER column has an unexpected vector type.
 */
void SharedDistinctCountContainerFunction::accumulate(omnistream::VectorBatch* input, const std::vector<int>& indices)
{
    if (indices.empty()) {
        return;
    }
    for (std::size_t groupIndex = 0; groupIndex < groups_.size(); ++groupIndex) {
        auto& group = groups_[groupIndex];
        if (group.distinctMapView == nullptr) {
            continue;
        }
        auto* columnData = input->Get(group.aggIdx);
        auto* intColumn =
            (group.typeId == DataTypeId::OMNI_INT) ? dynamic_cast<omniruntime::vec::Vector<int>*>(columnData) : nullptr;
        auto* longColumn = (group.typeId == DataTypeId::OMNI_LONG)
            ? dynamic_cast<omniruntime::vec::Vector<long>*>(columnData)
            : nullptr;
        if ((group.typeId == DataTypeId::OMNI_INT && intColumn == nullptr) ||
            (group.typeId == DataTypeId::OMNI_LONG && longColumn == nullptr)) {
            throw std::runtime_error("Input column type mismatch for shared DISTINCT.");
        }

        // One FILTER column per entry, aligned with groupEntries; nullptr means "no FILTER".
        std::vector<omniruntime::vec::Vector<bool>*> filterColumns;
        filterColumns.reserve(group.groupEntries.size());
        for (const auto& entry : group.groupEntries) {
            if (entry->filterIndex < 0) {
                filterColumns.push_back(nullptr);
                continue;
            }
            // Checked downcast, like the key column above: a FILTER column of an unexpected vector
            // type must fail loudly instead of being reinterpreted as Vector<bool>.
            auto* filterColumn = dynamic_cast<omniruntime::vec::Vector<bool>*>(input->Get(entry->filterIndex));
            if (filterColumn == nullptr) {
                throw std::runtime_error("Filter column type mismatch for shared DISTINCT.");
            }
            filterColumns.push_back(filterColumn);
        }

        // Merge candidate masks per distinct key value within this batch.
        std::unordered_map<long, std::uint64_t> batchRequestMasks;
        batchRequestMasks.reserve(indices.size());
        for (int rowIndex : indices) {
            if (columnData->IsNull(rowIndex)) {
                continue;
            }
            std::uint64_t candidateMask = 0ULL;
            for (std::size_t i = 0; i < group.groupEntries.size(); ++i) {
                auto* filterData = filterColumns[i];
                const bool passesFilter =
                    (filterData == nullptr) || (!filterData->IsNull(rowIndex) && filterData->GetValue(rowIndex));
                if (passesFilter) {
                    candidateMask |= (1ULL << group.groupEntries[i]->bitOffset);
                }
            }
            if (candidateMask == 0ULL) {
                continue;
            }
            const long fieldValue = (group.typeId == DataTypeId::OMNI_INT)
                ? static_cast<long>(intColumn->GetValue(rowIndex))
                : longColumn->GetValue(rowIndex);
            batchRequestMasks[fieldValue] |= candidateMask;
        }

        for (const auto& [fieldValue, candidateMask] : batchRequestMasks) {
            const auto existingValue = group.distinctMapView->get(std::optional<long>{fieldValue});
            const std::uint64_t existingMask =
                existingValue.has_value() ? static_cast<std::uint64_t>(*existingValue) : 0ULL;
            // Bits requested in this batch but not yet stored are the newly seen (entry, key) pairs.
            const std::uint64_t deltaMask = candidateMask & ~existingMask;
            if (deltaMask == 0ULL) {
                continue;
            }
            const std::uint64_t newMask = existingMask | candidateMask;
            applyMaskDelta(groupIndex, deltaMask);
            if (backend == 2 && currentGroupKey_ != nullptr) {
                // Deferred write; updateInnerState() flushes it with putByBatch().
                group.pendingDistinctUpdates[currentGroupKey_].emplace_back(fieldValue, static_cast<long>(newMask));
            } else {
                group.distinctMapView->put(std::optional<long>{fieldValue}, static_cast<long>(newMask));
            }
        }
    }
}
/**
 * Accumulator merging is not supported by this container.
 *
 * @throws std::runtime_error always.
 */
void SharedDistinctCountContainerFunction::merge(RowData* otherAcc)
{
    static_cast<void>(otherAcc);  // intentionally unused
    throw std::runtime_error("SharedDistinctCountContainerFunction does not support merge.");
}
/**
 * Loads each bound entry's count from its accumulator slot in acc.
 *
 * A null slot marks the entry as null and stores -1 as a placeholder count. Entries without an
 * accumulator slot (accIndex < 0) are left untouched.
 */
void SharedDistinctCountContainerFunction::setAccumulators(RowData* acc)
{
    for (auto& entry : entries_) {
        const int slot = entry->accIndex;
        if (slot >= 0) {
            const bool slotIsNull = acc->isNullAt(slot);
            entry->valueIsNull = slotIsNull;
            entry->aggCount = slotIsNull ? -1L : *acc->getLong(slot);
        }
    }
}
/**
 * Resets the container for the current key.
 *
 * Zeroes every entry's count, clears each group's deferred updates, and removes every key
 * stored in each group's map view.
 */
void SharedDistinctCountContainerFunction::resetAccumulators()
{
    for (auto& entry : entries_) {
        entry->valueIsNull = false;
        entry->aggCount = 0L;
    }
    for (auto& group : groups_) {
        group.pendingDistinctUpdates.clear();
        auto* view = group.distinctMapView;
        if (view == nullptr) {
            continue;
        }
        const auto snapshot = view->entries();
        if (snapshot == nullptr) {
            continue;
        }
        // Collect keys first: remove() may modify the container returned by entries(), so it is
        // not iterated while removing.
        std::vector<long> staleKeys;
        staleKeys.reserve(snapshot->size());
        for (const auto& kv : *snapshot) {
            staleKeys.push_back(kv.first);
        }
        for (const long key : staleKeys) {
            view->remove(std::optional<long>{key});
        }
    }
}
/**
 * Writes each bound entry's count into its accumulator slot.
 *
 * Null counts are written as null. Entries with accIndex < 0 are skipped.
 */
void SharedDistinctCountContainerFunction::getAccumulators(BinaryRowData* accumulators)
{
    for (const auto& entry : entries_) {
        const int slot = entry->accIndex;
        if (slot < 0) {
            continue;
        }
        if (!entry->valueIsNull) {
            accumulators->setLong(slot, entry->aggCount);
        } else {
            accumulators->setNullAt(slot);
        }
    }
}
/**
 * Initializes every bound accumulator slot to 0.
 *
 * Entry counters themselves are not modified.
 */
void SharedDistinctCountContainerFunction::createAccumulators(BinaryRowData* accumulators)
{
    for (std::size_t i = 0; i < entries_.size(); ++i) {
        const int slot = entries_[i]->accIndex;
        if (slot < 0) {
            continue;
        }
        accumulators->setLong(slot, 0L);
    }
}
/**
 * Writes each bound entry's count into its output-value slot of aggValue.
 *
 * Null counts are written as null. Entries with valueIndex < 0 are skipped.
 */
void SharedDistinctCountContainerFunction::getValue(BinaryRowData* aggValue)
{
    for (const auto& entry : entries_) {
        const int slot = entry->valueIndex;
        if (slot < 0) {
            continue;
        }
        if (!entry->valueIsNull) {
            aggValue->setLong(slot, entry->aggCount);
        } else {
            aggValue->setNullAt(slot);
        }
    }
}
void SharedDistinctCountContainerFunction::cleanup()
{
for (auto& group : groups_) {
group.pendingDistinctUpdates.clear();
}
}
void SharedDistinctCountContainerFunction::close()
{
for (auto& group : groups_) {
group.pendingDistinctUpdates.clear();
}
}
/**
 * Remembers the grouping key that deferred distinct-state updates are filed under.
 *
 * The pointer is stored without copying; the caller keeps it alive.
 */
void SharedDistinctCountContainerFunction::setCurrentGroupKey(RowData* key)
{
    this->currentGroupKey_ = key;
}
/**
 * Flushes deferred distinct-state updates and lets each map view clean up.
 *
 * For every group, any buffered updates are written with putByBatch(), the buffer is cleared,
 * and cleanup() is called on the view. Groups without a view only have their buffer cleared.
 */
void SharedDistinctCountContainerFunction::updateInnerState()
{
    for (auto& group : groups_) {
        auto* view = group.distinctMapView;
        if (view == nullptr) {
            group.pendingDistinctUpdates.clear();
            continue;
        }
        if (!group.pendingDistinctUpdates.empty()) {
            view->putByBatch(group.pendingDistinctUpdates);
        }
        group.pendingDistinctUpdates.clear();
        view->cleanup();
    }
}
/**
 * Reads row rowIdx of column columnIdx in input as a long key.
 *
 * @param isNull set to whether the cell is null; 0 is returned in that case.
 * @return the cell value, with INT widened to long.
 * @throws std::runtime_error if the column's vector type does not match typeId, or typeId is
 *         neither INT nor LONG.
 */
long SharedDistinctCountContainerFunction::getRowFieldValueFromVB(
    omnistream::VectorBatch* input, int columnIdx, int rowIdx, DataTypeId typeId, bool& isNull) const
{
    auto* column = input->Get(columnIdx);
    // Shared read path for both typed vectors; a failed downcast arrives as nullptr.
    auto readTyped = [&](auto* typedColumn) -> long {
        if (typedColumn == nullptr) {
            throw std::runtime_error("Input column type mismatch for shared DISTINCT.");
        }
        isNull = typedColumn->IsNull(rowIdx);
        return isNull ? 0L : static_cast<long>(typedColumn->GetValue(rowIdx));
    };
    if (typeId == DataTypeId::OMNI_INT) {
        return readTyped(dynamic_cast<omniruntime::vec::Vector<int>*>(column));
    }
    if (typeId == DataTypeId::OMNI_LONG) {
        return readTyped(dynamic_cast<omniruntime::vec::Vector<long>*>(column));
    }
    throw std::runtime_error("Unsupported shared DISTINCT key type.");
}