/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "GroupAggFunction.h"

#include <algorithm>
#include <cctype>
#include <iostream>
#include <memory>
#include <regex>
#include <stdexcept>
#include <string>
#include <vector>

#include "table/typeutils/BinaryRowDataSerializer.h"
#include "runtime/generated/function/AverageFunction.h"
#include "runtime/generated/function/CountDistinctFunction.h"
#include "runtime/dataview/PerKeyStateDataViewStore.h"
#include "runtime/generated/function/CountFunction.h"
#include "runtime/generated/function/MinMaxFunction.h"
#include "runtime/generated/function/SumFunction.h"
#include "runtime/generated/function/udf/LastStringValueFunction.h"
/**
 * Builds the operator from its JSON description.
 *
 * Reads the "aggInfoList" section (count-star index, accumulator types, aggregate value types),
 * the input/output type lists, the grouping column indexes and the distinct-aggregate infos.
 * Accumulator types whose name contains "RAW" are dropped before the accumulator arity is computed.
 *
 * @param stateRetentionTime stored as-is; consulted when deciding whether unchanged results are emitted.
 * @param config             JSON description; a private copy is kept in `description`.
 */
GroupAggFunction::GroupAggFunction(long stateRetentionTime, const nlohmann::json& config)
    : stateRetentionTime(stateRetentionTime),
      description(config)
{
    const nlohmann::json& aggInfo = config["aggInfoList"];

    indexOfCountStar = aggInfo["indexOfCountStar"];
    recordCounter = RecordCounter::of(indexOfCountStar);

    // Keep every declared accumulator type except the "RAW" ones.
    const auto declaredAccTypes = aggInfo["accTypes"].get<std::vector<std::string>>();
    accTypes.clear();
    for (const std::string& declared : declaredAccTypes) {
        if (declared.find("RAW") == std::string::npos) {
            accTypes.push_back(declared);
        }
    }

    aggValueTypes = aggInfo["aggValueTypes"].get<std::vector<std::string>>();
    accumulatorArity = accTypes.size();
    generateUpdateBefore = config.value("generateUpdateBefore", false);

    inputTypes = config["inputTypes"].get<std::vector<std::string>>();
    outputTypes = config["outputTypes"].get<std::vector<std::string>>();

    keyedIndex = config["grouping"].get<std::vector<int32_t>>();
    keyedTypes = getKeyedTypes(keyedIndex, inputTypes);
    groupByKeySelector = new KeySelector<RowData*>(keyedTypes, keyedIndex);

    distinctInfos = config["distinctInfos"].get<std::vector<DistinctInfo>>();
}
/**
 * Releases the aggregate handlers created by open()/InitAggFunctions().
 *
 * NOTE(review): groupByKeySelector, resultRow and the reusable rows are allocated in this file but
 * not released here — confirm whether they are owned elsewhere.
 */
GroupAggFunction::~GroupAggFunction()
{
    for (auto it = functions.begin(); it != functions.end(); ++it) {
        delete *it;
    }
}
/** Returns true when both rows hold the same int value in column `colIdx`. Neither value is null-checked. */
bool IntEqualiser(RowData* r1, RowData* r2, int colIdx)
{
    const auto& left = *r1->getInt(colIdx);
    const auto& right = *r2->getInt(colIdx);
    return left == right;
}
/** Returns true when both rows hold the same long value in column `colIdx`. Neither value is null-checked. */
bool LongEqualiser(RowData* r1, RowData* r2, int colIdx)
{
    const auto& left = *r1->getLong(colIdx);
    const auto& right = *r2->getLong(colIdx);
    return left == right;
}
/** Compares the timestamp column `colIdx` of both rows through its long accessor. */
bool TimestampEqualiser(RowData* r1, RowData* r2, int colIdx)
{
    const auto& left = *r1->getLong(colIdx);
    const auto& right = *r2->getLong(colIdx);
    const bool same = (left == right);
    return same;
}
/**
 * Extracts the aggregate-function keyword from an aggregate call name.
 *
 * The search is case-insensitive and returns the left-most keyword found anywhere in `input`
 * (no word boundaries, so "$SUM0($1)" yields "SUM"). The result is normalised to the canonical
 * spelling callers compare against: upper case for MAX/COUNT/SUM/MIN/AVG and lower case for
 * last_string_value_without_retract.
 *
 * @param input aggregate call name, e.g. "MAX($1)".
 * @return the canonical keyword, or "NONE" when no supported keyword occurs in `input`.
 */
std::string extractAggFunction(const std::string& input)
{
    // Compiling a std::regex is expensive; build it once instead of on every call.
    static const std::regex aggRegex(
        R"((?:MAX|COUNT|SUM|MIN|AVG|last_string_value_without_retract))", std::regex_constants::icase);
    static const std::string lastStringValue = "last_string_value_without_retract";

    std::smatch match;
    if (!std::regex_search(input, match, aggRegex)) {
        return "NONE";
    }

    // The match keeps the casing of the input; normalise it so exact comparisons succeed.
    std::string token = match.str();
    std::transform(token.begin(), token.end(), token.begin(), [](unsigned char ch) {
        return static_cast<char>(std::tolower(ch));
    });
    if (token == lastStringValue) {
        return token;
    }
    std::transform(token.begin(), token.end(), token.begin(), [](unsigned char ch) {
        return static_cast<char>(std::toupper(ch));
    });
    return token;
}
/**
 * Collects the input type names from the description and registers one equaliser per input column.
 *
 * INT, LONG and the two timestamp type ids get a dedicated equaliser; every other type gets a
 * nullptr placeholder and a warning is logged.
 *
 * @return the input type names in declaration order.
 */
std::vector<std::string> GroupAggFunction::handleInputTypes()
{
    std::vector<std::string> typeNames;
    for (const std::string& inputType : description["inputTypes"]) {
        typeNames.push_back(inputType);

        const auto omniType = LogicalType::flinkTypeToOmniTypeId(inputType);
        if (omniType == DataTypeId::OMNI_INT) {
            equalisers.push_back(IntEqualiser);
            continue;
        }
        if (omniType == DataTypeId::OMNI_LONG) {
            equalisers.push_back(LongEqualiser);
            continue;
        }
        const bool isTimestamp = omniType == DataTypeId::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE ||
                                 omniType == DataTypeId::OMNI_TIMESTAMP_WITH_LOCAL_TIME_ZONE;
        if (isTimestamp) {
            equalisers.push_back(TimestampEqualiser);
            continue;
        }

        // No equaliser for this type yet: keep the slot so indexes stay aligned with the columns.
        equalisers.push_back(nullptr);
        LOG("Warning equaliser for type " + inputType + " is not supported yet");
    }
    return typeNames;
}
std::map<int, int> GroupAggFunction::handleDistinctInfo()
{
std::map<int, int> distinctInfoMap;
for (DistinctInfo info : distinctInfos) {
if (info.filterArgs.size() != info.aggIndexes.size()) {
std::cerr << "Error: filterArgs and aggIndexes size mismatch!" << std::endl;
continue;
}
for (size_t i = 0; i < info.filterArgs.size(); i++) {
distinctInfoMap[info.aggIndexes[i]] = info.filterArgs[i];
}
}
return distinctInfoMap;
}
void GroupAggFunction::open(const Configuration& parameters)
{
LOG("GroupAggFunction open() running");
omnistream::RowType accRowType(true, this->accTypes);
auto accRowTypeInfo = InternalTypeInfo::ofRowType(&accRowType);
std::string accStateName = "accState";
ValueStateDescriptor<RowData*>* accDesc = new ValueStateDescriptor<RowData*>(accStateName, accRowTypeInfo);
accDesc->SetStateSerializer(accRowTypeInfo->getTypeSerializer());
accState = static_cast<StreamingRuntimeContext<RowData*>*>(getRuntimeContext())->getState<RowData*>(accDesc);
if (dynamic_cast<RocksdbValueState<RowData*, VoidNamespace, RowData*>*>(accState)) {
this->backend = 2;
}
int accStartingIndex = 0;
int aggValueIndex = 0;
InitAggFunctions(accStartingIndex, aggValueIndex);
if (indexOfCountStar != -1) {
auto* function = new CountFunction(-1, "BIGINT", accStartingIndex, -1, -1);
function->setCountStart(true);
functions.push_back(function);
accStartingIndex++;
}
LOG("group agg accStartingIndex: " << accStartingIndex);
LOG("group agg accumulatorArity: " << accumulatorArity);
if (accStartingIndex != accumulatorArity) {
throw std::runtime_error("GroupAggFunction open: accStartingIndex does not match accumulatorArity");
}
if (aggValueIndex != static_cast<int>(aggValueTypes.size())) {
throw std::runtime_error("GroupAggFunction open: aggValueIndex does not match aggValueTypes size");
}
aggregateCallsCount = description["aggInfoList"]["aggregateCalls"].size();
resultRow = new JoinedRowData();
reUsePrevAggValue = BinaryRowData::createBinaryRowDataWithMem(functions.size());
LOG("init reUsePrevAggValue getArity : " << reUsePrevAggValue->getArity());
reUseNewAggValue = BinaryRowData::createBinaryRowDataWithMem(functions.size());
sharedAccmulators = BinaryRowData::createBinaryRowDataWithMem(accTypes.size());
}
void GroupAggFunction::InitAggFunctions(int& accStartingIndex, int& aggValueIndex)
{
vector<string> types = handleInputTypes();
map<int, int> distinctInfoMap = handleDistinctInfo();
int aggFuncIndex = 0;
for (const auto& aggCall : description["aggInfoList"]["aggregateCalls"]) {
string aggTypeStr = aggCall["name"];
string aggregationFunction = aggCall["aggregationFunction"];
string aggType = extractAggFunction(aggTypeStr);
int filterIndex = aggCall["filterArg"];
int aggIndex =
aggCall["argIndexes"].get<vector<int>>().empty() ? -1 : aggCall["argIndexes"].get<vector<int>>()[0];
string aggDataType = aggIndex == -1 ? "BIGINT" : types[aggIndex];
AggsHandleFunction* function = nullptr;
bool shouldDoRetract = false;
if (aggType == "AVG") {
function = new AverageFunction(
aggIndex, aggDataType, accStartingIndex, accStartingIndex + 1, aggValueIndex, filterIndex);
} else if (aggType == "COUNT") {
if (distinctInfoMap.find(aggFuncIndex) != distinctInfoMap.end()) {
filterIndex = distinctInfoMap[aggFuncIndex];
auto* distinctFunction = new CountDistinctFunction(
aggIndex, aggDataType, accStartingIndex, aggValueIndex, aggFuncIndex, filterIndex);
distinctFunction->open(new PerKeyStateDataViewStore(
dynamic_cast<StreamingRuntimeContext<RowData*>*>(getRuntimeContext())));
function = distinctFunction;
} else {
function = new CountFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, filterIndex);
}
} else if (aggType == "MAX") {
function =
new MinMaxFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, MAX_FUNC, filterIndex);
} else if (aggType == "MIN") {
function =
new MinMaxFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, MIN_FUNC, filterIndex);
} else if (aggType == "SUM") {
shouldDoRetract = (aggregationFunction.find("WithRetract") != std::string::npos) ? true : shouldDoRetract;
int count0Index = shouldDoRetract ? accStartingIndex + 1 : -1;
SumFunction* sumFunction =
new SumFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, filterIndex);
sumFunction->setRetraction(count0Index);
function = sumFunction;
} else if (aggType == "last_string_value_without_retract") {
function = new LastStringValueFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex);
} else {
throw runtime_error("Unsupported aggregate type: " + aggTypeStr);
}
functions.push_back(function);
accStartingIndex += ((aggType == "AVG") || (aggType == "SUM" && shouldDoRetract)) ? 2 : 1;
aggValueIndex++;
aggFuncIndex++;
}
}
/** Returns the reusable joined row that results are written into (allocated in open()). */
JoinedRowData* GroupAggFunction::getResultRow()
{
    return this->resultRow;
}
/**
 * Processes a single input row for the current key.
 *
 * When no accumulator exists for the key, the key is copied and installed as the current key; a
 * retraction is then ignored, otherwise a fresh all-null accumulator row is stored in the state.
 * Every aggregate handler then accumulates or retracts the row, and the result is emitted through
 * AssembleResultForElement, which also clears the state when the record count drops to zero.
 *
 * @param input the incoming row; its row kind selects accumulate versus retract.
 * @param ctx   keyed context providing and receiving the current key.
 * @param out   collector receiving the emitted result rows.
 * @throws std::runtime_error when a new key is needed but the current key is null.
 */
void GroupAggFunction::processElement(RowData* input, Context* ctx, TimestampedCollector* out)
{
    RowData* currentKey = ctx->getCurrentKey();
    RowData* accumulators = accState->value();
    const bool firstRow = (accumulators == nullptr);
    if (firstRow) {
        if (!currentKey) {
            LOG("current key is nullptr");
            throw std::runtime_error("current key is nullptr");
        }
        RowData* updatedKey = currentKey->copy();
        ctx->setCurrentKey(updatedKey);
        currentKey = updatedKey;
        // NOTE(review): the key copy above is made before this early return — confirm who owns it.
        if (RowDataUtil::isRetractMsg(input->getRowKind())) {
            return;
        }
        accumulators = BinaryRowData::createBinaryRowDataWithMem(accumulatorArity);
        auto binRowAcc = static_cast<BinaryRowData*>(accumulators);
        for (int i = 0; i < accumulatorArity; i++) {
            binRowAcc->setNullAt(i);
        }
        accState->update(accumulators);
    }

    for (auto& func : functions) {
        func->setAccumulators(accumulators);
    }

    // The row kind is the same for every handler, so evaluate it once.
    const bool isAccumulate = RowDataUtil::isAccumulateMsg(input->getRowKind());
    auto* binaryAccumulators = static_cast<BinaryRowData*>(accumulators);
    for (int i = 0; i < aggregateCallsCount; ++i) {
        functions[i]->getValue(reUsePrevAggValue);
        if (isAccumulate) {
            functions[i]->accumulate(input);
        } else {
            functions[i]->retract(input);
        }
        functions[i]->getValue(reUseNewAggValue);
        functions[i]->getAccumulators(binaryAccumulators);
    }

    // Shared with the other element path so the emission rules cannot drift apart.
    AssembleResultForElement(accumulators, true, firstRow, currentKey, *out);
}
/**
 * Applies all rows of one key group to every aggregate handler.
 *
 * The group's row indexes are split by row kind; each handler records its previous value, accumulates
 * the accumulate rows, retracts the remaining rows, records its new value and writes its accumulator
 * back into `accumulators`.
 *
 * @param input        batch the row indexes refer to.
 * @param groupInfo    rows (index and kind) belonging to the current key.
 * @param accumulators accumulator row of the current key.
 */
void GroupAggFunction::processBatchColumnar(
    omnistream::VectorBatch* input, const std::vector<RowInfo>& groupInfo, RowData* accumulators)
{
    std::vector<int> rowsToAccumulate;
    std::vector<int> rowsToRetract;
    for (const RowInfo& row : groupInfo) {
        std::vector<int>& bucket = RowDataUtil::isAccumulateMsg(row.rowKind) ? rowsToAccumulate : rowsToRetract;
        bucket.push_back(row.rowIndex);
    }

    const bool hasAccumulateRows = !rowsToAccumulate.empty();
    const bool hasRetractRows = !rowsToRetract.empty();
    for (auto& handler : functions) {
        handler->getValue(reUsePrevAggValue);
        if (hasAccumulateRows) {
            handler->accumulate(input, rowsToAccumulate);
        }
        if (hasRetractRows) {
            handler->retract(input, rowsToRetract);
        }
        handler->getValue(reUseNewAggValue);
        handler->getAccumulators(reinterpret_cast<BinaryRowData*>(accumulators));
    }
}
/**
 * Processes a whole input batch.
 *
 * Rows are grouped by key; for each group the accumulator is loaded (or created for a new key), all
 * aggregate handlers are applied, and the resulting rows are collected. With the RocksDB backend
 * (backend == 2) accumulator updates are written in one batch afterwards. ClearEnv finally emits the
 * output batch and releases the input.
 *
 * @param input batch to process; a negative row count returns immediately.
 * @param ctx   keyed context whose current key is set per group.
 * @param out   collector receiving the output batch.
 */
void GroupAggFunction::processBatch(
    omnistream::VectorBatch* input,
    KeyedProcessFunction<RowData*, RowData*, RowData*>::Context& ctx,
    TimestampedCollector& out)
{
    auto rowCount = input->GetRowCount();
    if (rowCount < 0) {
        return;
    }

    std::unordered_map<RowData*, std::vector<RowInfo>> keyToRowIndices;
    LOG("getEntireRow rowCount :" << rowCount);
    FillRowIndices(input, keyToRowIndices, rowCount);

    std::vector<RowData*> resultKeys;
    std::vector<RowData*> resultValues;
    std::vector<RowKind> resultRowKinds;
    for (auto& group : keyToRowIndices) {
        const bool assumeEqual = true;
        RowData* groupKey = group.first;
        ctx.setCurrentKey(groupKey);
        std::vector<RowInfo>& groupRows = group.second;

        RowData* accumulators = accState->value();
        const bool firstRow = (accumulators == nullptr);
        // A new key whose rows are all retractions produces nothing.
        if (firstRow && !FirstRowAccumulate(groupRows, accumulators)) {
            continue;
        }

        for (auto& handler : functions) {
            handler->setAccumulators(accumulators);
            handler->setCurrentGroupKey(groupKey);
            handler->setBackend(backend);
        }
        processBatchColumnar(input, groupRows, accumulators);
        LOG("functions loop aggregateCallsCount end");
        AssembleResultForBatch(
            accumulators, assumeEqual, firstRow, groupKey, resultKeys, resultValues, resultRowKinds);
    }

    if (backend == 2) {
        UpdateAccumulatorsInRocksDB(pendingUpdates);
        for (auto& pending : pendingUpdates) {
            delete pending.second;
        }
        pendingUpdates.clear();
        for (auto& handler : functions) {
            handler->updateInnerState();
        }
    }

    ClearEnv(input, resultKeys, resultValues, resultRowKinds, out, keyToRowIndices);
    LOG("GroupAggFunction processBatch end");
}
/** Deletes every row referenced by `rowVector` and leaves the vector empty. */
void GroupAggFunction::deleteRowData(vector<RowData*>& rowVector)
{
    for (size_t pos = 0; pos < rowVector.size(); ++pos) {
        delete rowVector[pos];
    }
    rowVector.clear();
}
/**
 * Appends a new column to `outputBatch`, filled from the int field `colIndex` of each row in `vec`.
 * Null fields become null entries.
 *
 * NOTE(review): the column is created as Vector<int64_t> although values are read with getInt —
 * confirm this element width is what consumers of INT/BOOLEAN columns expect.
 */
void GroupAggFunction::setInt(
    omniruntime::vec::VectorBatch* outputBatch, int numRows, int colIndex, std::vector<RowData*> vec)
{
    auto* column = new omniruntime::vec::Vector<int64_t>(numRows);
    for (int rowIndex = 0; rowIndex < numRows; ++rowIndex) {
        RowData* row = vec[rowIndex];
        if (!row->isNullAt(colIndex)) {
            column->SetValue(rowIndex, *row->getInt(colIndex));
            continue;
        }
        column->SetNull(rowIndex);
    }
    outputBatch->Append(column);
}
/**
 * Appends a new int64 column to `outputBatch`, filled from the long field `colIndex` of each row in
 * `vec`. Null fields become null entries.
 */
void GroupAggFunction::setLong(
    omniruntime::vec::VectorBatch* outputBatch, int numRows, int colIndex, std::vector<RowData*> vec)
{
    auto* column = new omniruntime::vec::Vector<int64_t>(numRows);
    for (int rowIndex = 0; rowIndex < numRows; ++rowIndex) {
        RowData* row = vec[rowIndex];
        if (!row->isNullAt(colIndex)) {
            column->SetValue(rowIndex, *row->getLong(colIndex));
            continue;
        }
        column->SetNull(rowIndex);
    }
    outputBatch->Append(column);
}
/**
 * Appends a new string column to `outputBatch`, filled from the string view of field `colIndex` of
 * each row in `vec`. Null fields become null entries.
 */
void GroupAggFunction::setString(
    omniruntime::vec::VectorBatch* outputBatch, int numRows, int colIndex, std::vector<RowData*> vec)
{
    using StringColumn = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
    auto* column = new StringColumn(numRows);
    for (int rowIndex = 0; rowIndex < numRows; rowIndex++) {
        RowData* row = vec[rowIndex];
        if (row->isNullAt(colIndex)) {
            column->SetNull(rowIndex);
            continue;
        }
        std::string_view text = row->getStringView(colIndex);
        column->SetValue(rowIndex, text);
    }
    outputBatch->Append(column);
}
/**
 * Builds the output batch from the collected result rows.
 *
 * Output column i is read from the key row while i is smaller than the key arity, and from the value
 * row (at i minus the key arity) afterwards. LONG, TIMESTAMP and DOUBLE columns are copied through
 * setLong, INT and BOOLEAN through setInt, VARCHAR through setString.
 *
 * @param collectedKeys   key row of every result row; must not be empty.
 * @param collectedValues aggregate-value row of every result row, same length as collectedKeys.
 * @param rowKinds        row kind of every result row, same length as collectedKeys.
 * @return a newly allocated batch; the caller takes ownership.
 * @throws std::runtime_error on empty or mismatched inputs, or on an unsupported output column type.
 */
omnistream::VectorBatch* GroupAggFunction::createOutputBatch(
    std::vector<RowData*> collectedKeys, std::vector<RowData*> collectedValues, std::vector<RowKind> rowKinds)
{
    if (collectedKeys.empty()) {
        throw std::runtime_error("createOutputBatch: no result rows to assemble");
    }
    if (collectedValues.size() != collectedKeys.size() || rowKinds.size() != collectedKeys.size()) {
        throw std::runtime_error("createOutputBatch: keys, values and row kinds differ in size");
    }

    std::vector<omniruntime::type::DataTypeId> outputRowType;
    outputRowType.reserve(outputTypes.size());
    for (const auto& typeStr : outputTypes) {
        outputRowType.push_back(LogicalType::flinkTypeToOmniTypeId(typeStr));
    }

    const int numColumns = static_cast<int>(outputRowType.size());
    const int numRows = static_cast<int>(collectedKeys.size());
    const int keySize = collectedKeys[0]->getArity();

    // Owned locally until fully built so it is released if a column setter throws.
    std::unique_ptr<omnistream::VectorBatch> outputBatch(new omnistream::VectorBatch(numRows));
    for (int colIndex = 0; colIndex < numColumns; colIndex++) {
        const bool isKeyColumn = colIndex < keySize;
        // Bind by reference: the source vector does not need to be copied for every column.
        const std::vector<RowData*>& rows = isKeyColumn ? collectedKeys : collectedValues;
        const int fieldIndex = isKeyColumn ? colIndex : colIndex - keySize;
        switch (outputRowType[colIndex]) {
            case DataTypeId::OMNI_LONG:
            case DataTypeId::OMNI_TIMESTAMP:
            case DataTypeId::OMNI_DOUBLE: {
                setLong(outputBatch.get(), numRows, fieldIndex, rows);
                break;
            }
            case DataTypeId::OMNI_INT:
            case DataTypeId::OMNI_BOOLEAN: {
                setInt(outputBatch.get(), numRows, fieldIndex, rows);
                break;
            }
            case DataTypeId::OMNI_VARCHAR: {
                setString(outputBatch.get(), numRows, fieldIndex, rows);
                break;
            }
            default: {
                LOG("Unsupported column type in inputRow (createOutputBatch). colIndex : " << colIndex);
                throw std::runtime_error("Unsupported column type in inputRow");
            }
        }
    }
    for (int rowIndex = 0; rowIndex < numRows; ++rowIndex) {
        outputBatch->setRowKind(rowIndex, rowKinds[rowIndex]);
    }
    return outputBatch.release();
}
/**
 * Resolves the omni type id of every grouping column.
 *
 * Indexes outside [0, inputTypes.size()) are skipped without any report, so the result can be
 * shorter than `keyedIndex`.
 *
 * @param keyedIndex grouping column indexes.
 * @param inputTypes type names of the input columns.
 * @return the type id of each in-range grouping column, in order.
 */
std::vector<int32_t> GroupAggFunction::getKeyedTypes(
    const std::vector<int32_t> keyedIndex, const std::vector<std::string> inputTypes)
{
    std::vector<int32_t> resolved;
    const int32_t typeCount = static_cast<int32_t>(inputTypes.size());
    for (size_t pos = 0; pos < keyedIndex.size(); ++pos) {
        const int32_t column = keyedIndex[pos];
        if (column < 0 || column >= typeCount) {
            continue;
        }
        resolved.push_back(LogicalType::flinkTypeToOmniTypeId(inputTypes[column]));
    }
    return resolved;
}
/** Hands `outputBatch` to the collector. Note that the collector is received by value. */
void GroupAggFunction::collectOutputBatch(TimestampedCollector out, omnistream::VectorBatch* outputBatch)
{
    omnistream::VectorBatch* const batch = outputBatch;
    out.collect(batch);
}
/** Lifecycle hook; this implementation does nothing. */
void GroupAggFunction::close()
{
    // Intentionally empty.
}
/** Returns the accumulator value state obtained in open(). */
ValueState<RowData*>* GroupAggFunction::getValueState()
{
    return this->accState;
}
bool GroupAggFunction::FirstRowAccumulate(std::vector<RowInfo>& groupInfo, RowData*& accumulators)
{
auto entireIter = groupInfo.begin();
while (entireIter != groupInfo.end()) {
if (RowDataUtil::isRetractMsg(entireIter->rowKind)) {
entireIter = groupInfo.erase(entireIter);
} else {
break;
}
}
if (groupInfo.empty()) {
return false;
}
accumulators = BinaryRowData::createBinaryRowDataWithMem(accumulatorArity);
for (auto& func : functions) {
func->createAccumulators(dynamic_cast<BinaryRowData*>(accumulators));
}
return true;
}
/**
 * Finishes a batch: releases the input, emits the output batch when there are result rows, and
 * frees the per-batch group keys and copied result values.
 *
 * The group keys are released whether or not any result row was produced; previously they were only
 * freed when the batch yielded at least one result, leaking them otherwise.
 *
 * All containers are received by value, so the caller's containers still hold the released pointers
 * afterwards and must not be dereferenced again.
 *
 * @param input           processed input batch; deleted here.
 * @param resultKeys      key row per result row (pointers shared with keyToRowIndices).
 * @param resultValues    copied aggregate-value row per result row; deleted here.
 * @param resultRowKinds  row kind per result row.
 * @param out             collector receiving the output batch.
 * @param keyToRowIndices per-batch grouping whose key rows are deleted here.
 */
void GroupAggFunction::ClearEnv(
    omnistream::VectorBatch* input,
    std::vector<RowData*> resultKeys,
    std::vector<RowData*> resultValues,
    std::vector<RowKind> resultRowKinds,
    TimestampedCollector& out,
    std::unordered_map<RowData*, std::vector<RowInfo>> keyToRowIndices)
{
    delete input;

    if (!resultKeys.empty()) {
        resultBatch = createOutputBatch(resultKeys, resultValues, resultRowKinds);
        collectOutputBatch(out, resultBatch);
    }

    // resultKeys aliases these pointers (possibly more than once), so free them through the map only.
    for (auto& pair : keyToRowIndices) {
        delete pair.first;
    }
    keyToRowIndices.clear();
    resultKeys.clear();
    deleteRowData(resultValues);
    resultRowKinds.clear();
}
/**
 * Records the result rows of one key group and persists or clears its accumulator.
 *
 * Non-zero record count: the accumulator is queued in pendingUpdates (backend == 2) or written to
 * the state; a new key yields INSERT, an existing key yields UPDATE_AFTER (preceded by UPDATE_BEFORE
 * when generateUpdateBefore is set) unless EndAssemble reports nothing to emit.
 * Zero record count: an existing key yields DELETE and the state is cleared.
 *
 * @param accumulators   accumulator row of the group.
 * @param isEqual        initial equality flag forwarded to EndAssemble.
 * @param firstRow       true when the key had no accumulator before this batch.
 * @param currentKey     key row of the group.
 * @param resultKeys     out: key per emitted row.
 * @param resultValues   out: copy of the aggregate values per emitted row.
 * @param resultRowKinds out: row kind per emitted row.
 */
void GroupAggFunction::AssembleResultForBatch(
    RowData* accumulators,
    bool isEqual,
    bool firstRow,
    RowData* currentKey,
    std::vector<RowData*>& resultKeys,
    std::vector<RowData*>& resultValues,
    std::vector<RowKind>& resultRowKinds)
{
    // Appends one result row; the aggregate values are copied because the source rows are reused.
    auto emit = [&](auto aggValues, RowKind kind) {
        resultKeys.push_back(currentKey);
        resultValues.push_back(aggValues->copy());
        resultRowKinds.push_back(kind);
    };

    if (recordCounter->recordCountIsZero(accumulators)) {
        if (!firstRow) {
            emit(reUsePrevAggValue, RowKind::DELETE);
        }
        accState->clear();
        return;
    }

    if (backend == 2) {
        pendingUpdates.emplace(currentKey, accumulators);
    } else {
        accState->update(accumulators);
    }

    if (firstRow) {
        emit(reUseNewAggValue, RowKind::INSERT);
        return;
    }
    if (EndAssemble(isEqual)) {
        return;
    }
    if (generateUpdateBefore) {
        emit(reUsePrevAggValue, RowKind::UPDATE_BEFORE);
    }
    emit(reUseNewAggValue, RowKind::UPDATE_AFTER);
}
/**
 * Groups the first `rowCount` rows of `input` by their grouping key.
 *
 * A key row is created per input row; when an equal key is already present in the map the new key
 * row is deleted and the row is appended to the existing group.
 *
 * @param input           batch to group.
 * @param keyToRowIndices out: rows (index and kind) per key.
 * @param rowCount        number of rows to read from `input`.
 */
void GroupAggFunction::FillRowIndices(
    omnistream::VectorBatch* input, std::unordered_map<RowData*, std::vector<RowInfo>>& keyToRowIndices, int rowCount)
{
    for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
        auto key = groupByKeySelector->getKey(input, rowIndex);
        RowKind rowKind = input->getRowKind(rowIndex);

        // One lookup: inserts an empty group for a new key, otherwise returns the existing group.
        auto slot = keyToRowIndices.emplace(key, std::vector<RowInfo>());
        slot.first->second.push_back(RowInfo{rowIndex, rowKind});
        if (!slot.second) {
            delete key;
        }
    }
}
/**
 * Emits the result rows for a single processed element through the reusable resultRow.
 *
 * Non-zero record count: a new key yields INSERT, an existing key yields UPDATE_AFTER (preceded by
 * UPDATE_BEFORE when generateUpdateBefore is set) unless EndAssemble reports nothing to emit.
 * Zero record count: an existing key yields DELETE and the state is cleared.
 *
 * @param accumulators accumulator row of the current key.
 * @param isEqual      initial equality flag forwarded to EndAssemble.
 * @param firstRow     true when the key had no accumulator before this element.
 * @param currentKey   key row joined in front of the aggregate values.
 * @param out          collector receiving the rows.
 */
void GroupAggFunction::AssembleResultForElement(
    RowData* accumulators, bool isEqual, bool firstRow, RowData* currentKey, TimestampedCollector& out)
{
    // Joins the key with the given aggregate values, tags the row kind and emits the row.
    auto emit = [&](auto aggValues, RowKind kind) {
        resultRow->replace(currentKey, aggValues)->setRowKind(kind);
        out.collect(resultRow);
    };

    if (recordCounter->recordCountIsZero(accumulators)) {
        if (!firstRow) {
            emit(reUsePrevAggValue, RowKind::DELETE);
        }
        accState->clear();
        return;
    }

    if (firstRow) {
        emit(reUseNewAggValue, RowKind::INSERT);
        return;
    }
    if (EndAssemble(isEqual)) {
        return;
    }
    if (generateUpdateBefore) {
        emit(reUsePrevAggValue, RowKind::UPDATE_BEFORE);
    }
    emit(reUseNewAggValue, RowKind::UPDATE_AFTER);
}
/**
 * Decides whether emission for an existing key can be skipped.
 *
 * Compares the previous and new aggregate values with each handler's equaliser, stopping at the
 * first difference.
 *
 * @param isEqual starting value of the equality flag; a differing handler sets it to false.
 * @return true when stateRetentionTime <= 0 and the flag is still true, false otherwise.
 */
bool GroupAggFunction::EndAssemble(bool isEqual)
{
    int idx = 0;
    while (idx < aggregateCallsCount) {
        if (!functions[idx]->equaliser(reUsePrevAggValue, reUseNewAggValue)) {
            isEqual = false;
            break;
        }
        idx++;
    }
    return stateRetentionTime <= 0 && isEqual;
}
/** Forwards the queued key-to-accumulator updates to the state's batch update. */
void GroupAggFunction::UpdateAccumulatorsInRocksDB(std::unordered_map<RowData*, RowData*>& pendingUpdates)
{
    auto& queuedUpdates = pendingUpdates;
    accState->updateByBatch(queuedUpdates);
}