* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "LocalSlicingWindowAggOperator.h"
#include "runtime/generated/function/AverageFunction.h"
#include "runtime/generated/function/CountDistinctFunction.h"
#include "runtime/dataview/PerKeyStateDataViewStore.h"
#include "runtime/generated/function/CountFunction.h"
#include "runtime/generated/function/MinMaxFunction.h"
#include "runtime/generated/function/SumFunction.h"
#include "table/runtime/generated/function/EmptyNamespaceFunction.h"
#include <iostream>
#include "table/data/util/RowDataUtil.h"
#include "streaming/api/operators/TimestampedCollector.h"
/**
 * Prepares the aggregate functions and the reusable row buffers from the operator description.
 *
 * If "aggInfoList.aggregateCalls" is empty, a single EmptyNamespaceFunction is registered,
 * aggregateCallsCount is forced to 1 and the reusable accumulator / aggregate-value rows are
 * created with zero fields. Otherwise the rows are sized from "accTypes" / "aggValueTypes"
 * and ExtractFunction() builds one function per aggregate call.
 */
void LocalSlicingWindowAggOperator::open()
{
    aggregateCallsCount = description["aggInfoList"]["aggregateCalls"].size();
    if (aggregateCallsCount == 0) {
        AggsHandleFunction *function = new EmptyNamespaceFunction();
        functions.push_back(function);
        // AccumulateOrRetract() iterates aggregateCallsCount functions, so count the placeholder.
        aggregateCallsCount = 1;
        // Keep the arity consistent with the zero-field rows below; SendAccResults() allocates
        // its per-key accumulator row from this value.
        accumulatorArity = 0;
        reUseAccumulator = BinaryRowData::createBinaryRowDataWithMem(0);
        reUseAggValue = BinaryRowData::createBinaryRowDataWithMem(0);
        return;
    }

    accTypes = description["aggInfoList"]["accTypes"].get<std::vector<std::string>>();
    aggValueTypes = description["aggInfoList"]["aggValueTypes"].get<std::vector<std::string>>();
    accumulatorArity = accTypes.size();
    reUseAccumulator = BinaryRowData::createBinaryRowDataWithMem(accTypes.size());
    reUseAggValue = BinaryRowData::createBinaryRowDataWithMem(aggValueTypes.size());
    ExtractFunction();
}
/**
 * Builds one AggsHandleFunction per entry of "aggInfoList.aggregateCalls" and appends it to
 * `functions`.
 *
 * For every call the aggregate kind is derived from its "name" via extractAggFunction(), the
 * input column is the first element of "argIndexes" (-1 when the list is empty) and the input
 * type is looked up in "inputTypes". AVG consumes two accumulator slots, every other supported
 * aggregate consumes one; each call produces one aggregate value slot.
 *
 * @throws std::runtime_error if the aggregate kind is not AVG/COUNT/MAX/MIN/SUM, or if the
 *         argument index does not refer to an entry of "inputTypes".
 */
void LocalSlicingWindowAggOperator::ExtractFunction()
{
    int accStartingIndex = 0;
    int aggValueIndex = 0;
    int aggFuncIndex = 0;

    std::vector<std::string> types;
    for (const std::string& inputType : description["inputTypes"]) {
        types.push_back(inputType);
    }

    for (const auto& aggCall : description["aggInfoList"]["aggregateCalls"]) {
        LOG("aggFuncIndex: " << aggFuncIndex)
        std::string aggTypeStr = aggCall["name"];
        std::string aggType = extractAggFunction(aggTypeStr);
        int filterIndex = aggCall["filterArg"];

        // Parse the argument list once instead of materialising the vector twice.
        const std::vector<int> argIndexes = aggCall["argIndexes"].get<std::vector<int>>();
        int aggIndex = argIndexes.empty() ? -1 : argIndexes[0];

        std::string aggDataType = "NULL";
        if (aggIndex != -1) {
            // The index comes from the plan description; reject anything outside inputTypes
            // instead of reading past the vector.
            if (aggIndex < 0 || static_cast<std::size_t>(aggIndex) >= types.size()) {
                throw std::runtime_error("Aggregate argument index out of range: " + std::to_string(aggIndex));
            }
            aggDataType = types[aggIndex];
        }

        AggsHandleFunction *function;
        if (aggType == "AVG") {
            function = new AverageFunction(aggIndex, aggDataType, accStartingIndex, accStartingIndex + 1,
                                           aggValueIndex, filterIndex);
        } else if (aggType == "COUNT") {
            function = new CountFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, filterIndex);
        } else if (aggType == "MAX") {
            function = new MinMaxFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, MAX_FUNC,
                                          filterIndex);
        } else if (aggType == "MIN") {
            function = new MinMaxFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, MIN_FUNC,
                                          filterIndex);
        } else if (aggType == "SUM") {
            function = new SumFunction(aggIndex, aggDataType, accStartingIndex, aggValueIndex, filterIndex);
        } else {
            throw std::runtime_error("Unsupported aggregate type: " + aggTypeStr);
        }
        LOG("filterIndex: " << filterIndex)
        functions.push_back(function);

        // AVG is constructed with two accumulator indexes, the others with one.
        accStartingIndex += aggType == "AVG" ? 2 : 1;
        aggValueIndex++;
        aggFuncIndex++;
    }
}
/**
 * Buffers every row of the incoming batch into `bundle`, grouped by (slice end, key).
 *
 * The record and the VectorBatch it carries are both wrapped in unique_ptr here, so they are
 * destroyed when this function returns; rows are copied out with extractRowData() first.
 * The first time a (slice end, key) pair is seen it is also appended to `invertOrder`, which
 * records the keys in arrival order.
 *
 * @param input record whose value is reinterpreted as an omnistream::VectorBatch; a null record,
 *              a null batch or a batch without rows is ignored.
 */
void LocalSlicingWindowAggOperator::processBatch(StreamRecord *input)
{
    if (input == nullptr) {
        return;
    }
    auto record = std::unique_ptr<StreamRecord>(input);
    auto batch = std::unique_ptr<omnistream::VectorBatch>(
        reinterpret_cast<omnistream::VectorBatch*>(record->getValue()));
    if (!batch) {
        return;
    }
    const auto rowCount = batch->GetRowCount();
    if (rowCount <= 0) {
        return;
    }

    // Slice ends are assigned for the whole batch before any row is extracted.
    std::vector<int64_t> sliceEndArr(rowCount);
    for (int64_t i = 0; i < rowCount; i++) {
        sliceEndArr[i] = sliceAssigner->assignSliceEnd(batch.get(), i, clock);
    }

    for (int64_t row = 0; row < rowCount; ++row) {
        long rowTime = sliceEndArr[row];
        auto keyRow = keySelector->getKey(batch.get(), row);
        auto currentRow = std::unique_ptr<RowData>(batch->extractRowData(row));
        WindowKey windowKey(rowTime, keyRow);

        auto it = bundle.find(windowKey);
        if (it != bundle.end()) {
            it->second.push_back(std::move(currentRow));
            continue;
        }

        // New (slice end, key) pair: remember its arrival position and start its row list.
        invertOrder.push_back(windowKey);
        std::vector<std::unique_ptr<RowData>> vec;
        vec.push_back(std::move(currentRow));
        bundle.emplace(windowKey, std::move(vec));
    }
}
/**
 * Handles an incoming watermark.
 *
 * If the watermark's timestamp is newer than `currentWatermark`, it becomes the current one;
 * when it also reaches `nextTriggerWatermark` and rows are buffered, SendAccResults() is
 * invoked. Should that call report failure, the watermark is forwarded immediately and the
 * timer service is not advanced. Otherwise the timer service (if any) is advanced and the
 * watermark is forwarded downstream.
 *
 * @param mark the watermark to process; it is dereferenced unconditionally.
 */
void LocalSlicingWindowAggOperator::ProcessWatermark(Watermark *mark)
{
    LOG("LocalSlicingWindowAggOperator::processWatermark start: " << mark->getTimestamp())

    const auto timestamp = mark->getTimestamp();
    const bool watermarkAdvanced = timestamp > currentWatermark;
    if (watermarkAdvanced) {
        currentWatermark = timestamp;
    }

    const bool flushDue =
        watermarkAdvanced && currentWatermark >= nextTriggerWatermark && bundle.size() > 0;
    if (flushDue && !SendAccResults(mark)) {
        output->emitWatermark(mark);
        return;
    }

    LOG("LocalSlicingWindowAggOperator::processWatermark end: " << mark->getTimestamp())

    if (timeServiceManager != nullptr) {
        timeServiceManager->advanceWatermark(mark);
    }
    output->emitWatermark(mark);
}
/**
 * Aggregates every buffered (slice end, key) group and emits the results as one VectorBatch.
 *
 * Keys are processed in the order recorded in `invertOrder`. For each key the leading retract
 * rows are dropped via eraseMsg(); keys left without rows produce no output row. The output
 * batch is sized from the number of keys that still have rows, so every row of the emitted
 * batch is populated. If no key has rows left, nothing is emitted.
 *
 * After emission the bundle and `invertOrder` are cleared and `nextTriggerWatermark` is
 * recomputed from `currentWatermark` and `windowInterval`.
 *
 * @param mark the watermark that triggered the flush (not used by this function).
 * @return always true.
 * @throws std::runtime_error if an output column has a type other than
 *         LONG/TIMESTAMP/INT/DOUBLE/BOOLEAN/VARCHAR.
 */
bool LocalSlicingWindowAggOperator::SendAccResults(Watermark *mark)
{
    // First pass: strip leading retract rows and count the keys that will yield an output row.
    int numRows = 0;
    for (const WindowKey& key : invertOrder) {
        auto& rows = bundle[key];
        eraseMsg(rows);
        if (!rows.empty()) {
            ++numRows;
        }
    }

    if (numRows > 0) {
        const int numColumns = static_cast<int>(outputTypes.size());
        // Owned until handed to the collector so it is released if a column type is rejected.
        std::unique_ptr<omnistream::VectorBatch> outputBatch(
            omnistream::VectorBatch::CreateVectorBatch(numRows, outputTypes));
        int currentRowNum = 0;

        for (const WindowKey& currentKey : invertOrder) {
            auto& entireRows = bundle[currentKey];
            if (entireRows.empty()) {
                continue;
            }

            // Fresh accumulator row per key, shared by all aggregate functions.
            std::unique_ptr<RowData> accumulators(BinaryRowData::createBinaryRowDataWithMem(accumulatorArity));
            auto *binaryAccumulators = dynamic_cast<BinaryRowData *>(accumulators.get());
            for (auto& func : functions) {
                func->createAccumulators(binaryAccumulators);
            }
            for (auto& func : functions) {
                func->setAccumulators(accumulators.get());
            }

            // Fills reUseAccumulator with the accumulated state for this key.
            AccumulateOrRetract(entireRows);

            windowRow->setField(0, currentKey.getWindow());
            accWindowRow->replace(reUseAccumulator, windowRow);
            resultRow->replace(currentKey.getKey().get(), accWindowRow);

            for (int colIndex = 0; colIndex < numColumns; ++colIndex) {
                switch (outputTypes[colIndex]) {
                    // NOTE(review): DOUBLE is copied through SetLong and BOOLEAN through SetInt,
                    // exactly as before — confirm the vector element types match.
                    case DataTypeId::OMNI_LONG:
                    case DataTypeId::OMNI_TIMESTAMP:
                    case DataTypeId::OMNI_DOUBLE: {
                        SetLong(outputBatch.get(), currentRowNum, colIndex, resultRow);
                        break;
                    }
                    case DataTypeId::OMNI_INT:
                    case DataTypeId::OMNI_BOOLEAN: {
                        SetInt(outputBatch.get(), currentRowNum, colIndex, resultRow);
                        break;
                    }
                    case DataTypeId::OMNI_VARCHAR: {
                        SetStringVectorBatch(outputBatch.get(), currentRowNum, colIndex, resultRow);
                        break;
                    }
                    default: {
                        throw std::runtime_error("Unsupported column type in inputRow");
                    }
                }
            }
            outputBatch->setRowKind(currentRowNum, resultRow->getRowKind());
            currentRowNum++;
        }

        // The batch is handed over to the collector as a raw pointer, as before.
        collector->collect(outputBatch.release());
    }

    bundle.clear();
    invertOrder.clear();
    nextTriggerWatermark = getNextTriggerWatermark(currentWatermark, windowInterval);
    return true;
}
/**
 * Removes the leading run of retract messages from `entireRows`.
 *
 * Rows are dropped from the front until the first row whose kind is not a retract message;
 * retract rows that appear after that point are kept. The prefix is located first and erased
 * in a single range erase, so the remaining rows are shifted only once.
 *
 * @param entireRows rows buffered for one (slice end, key) group; modified in place.
 */
void LocalSlicingWindowAggOperator::eraseMsg(std::vector<std::unique_ptr<RowData>>& entireRows)
{
    auto firstKept = entireRows.begin();
    while (firstKept != entireRows.end() && RowDataUtil::isRetractMsg((*firstKept)->getRowKind())) {
        ++firstKept;
    }
    entireRows.erase(entireRows.begin(), firstKept);
}
/** Returns the `output` member this operator emits watermarks to. */
Output* LocalSlicingWindowAggOperator::getOutput()
{
    return output;
}
/**
 * Releases the aggregate functions created in open() and the clock.
 *
 * The function list is emptied and the clock pointer reset after deletion, so no dangling
 * pointers remain and a repeated close() does not delete anything twice.
 */
void LocalSlicingWindowAggOperator::close()
{
    for (auto *func : functions) {
        delete func;
    }
    functions.clear();

    delete clock;
    clock = nullptr;
}
/**
 * Copies the string field `colIndex` of `collectedRow` into row `rowIndex` of the same column
 * of `outputBatch`.
 *
 * The column is cast, unchecked, to a string_view-backed LargeStringContainer vector.
 */
void LocalSlicingWindowAggOperator::SetStringVectorBatch(omnistream::VectorBatch* outputBatch, int rowIndex,
                                                         int colIndex, RowData* collectedRow)
{
    using StringColumn = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;

    auto column = static_cast<StringColumn *>(outputBatch->Get(colIndex));
    std::string_view text = collectedRow->getStringView(colIndex);
    column->SetValue(rowIndex, text);
}
/**
 * Copies the 64-bit field `colIndex` of `collectedRow` into row `rowIndex` of the same column
 * of `outputBatch`.
 *
 * The column is cast, unchecked, to a Vector<int64_t>.
 */
void LocalSlicingWindowAggOperator::SetLong(omniruntime::vec::VectorBatch* outputBatch,
                                            int rowIndex, int colIndex, RowData* collectedRow)
{
    using LongColumn = omniruntime::vec::Vector<int64_t>;

    auto column = static_cast<LongColumn *>(outputBatch->Get(colIndex));
    column->SetValue(rowIndex, *collectedRow->getLong(colIndex));
}
/**
 * Copies the int field `colIndex` of `collectedRow` into row `rowIndex` of the same column of
 * `outputBatch`.
 *
 * The column is cast, unchecked, to a Vector<int32_t> so the element width matches the int
 * value read from the row; treating it as a 64-bit vector would write each value at the wrong
 * offset and with the wrong width.
 * NOTE(review): assumes OMNI_INT columns are 32-bit vectors — confirm against the vector
 * library. SendAccResults() also routes OMNI_BOOLEAN columns here; confirm that is intended.
 */
void LocalSlicingWindowAggOperator::SetInt(omniruntime::vec::VectorBatch* outputBatch,
                                           int rowIndex, int colIndex, RowData* collectedRow)
{
    auto vector = static_cast<omniruntime::vec::Vector<int32_t>*>(outputBatch->Get(colIndex));
    vector->SetValue(rowIndex, *collectedRow->getInt(colIndex));
}
/** Returns the fixed operator name, a string literal with static storage duration. */
const char *LocalSlicingWindowAggOperator::getName()
{
    static constexpr const char *kOperatorName = "LocalWindowAggOperator";
    return kOperatorName;
}
/**
 * Returns the operator name immediately followed by this function's compiler-generated
 * pretty signature (no separator in between).
 */
std::string LocalSlicingWindowAggOperator::getTypeName()
{
    return std::string("LocalWindowAggOperator") + __PRETTY_FUNCTION__;
}
void LocalSlicingWindowAggOperator::AccumulateOrRetract(const std::vector<std::unique_ptr<RowData>>& entireRows)
{
for (int i = 0; i < aggregateCallsCount; ++i) {
for (auto& entireRow : entireRows) {
if (RowDataUtil::isAccumulateMsg(entireRow->getRowKind())) {
functions[i]->accumulate(entireRow.get());
} else {
functions[i]->retract(entireRow.get());
}
}
functions[i]->getAccumulators(reUseAccumulator);
}
}