* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef FLINK_TNEL_LOCALSLICINGWINDOWAGGOPERATOR_H
#define FLINK_TNEL_LOCALSLICINGWINDOWAGGOPERATOR_H
#include <memory>
#include <regex>
#include "streaming/api/operators/AbstractStreamOperator.h"
#include "table/data/JoinedRowData.h"
#include "table/typeutils/BinaryRowDataSerializer.h"
#include "streaming/api/operators/TimestampedCollector.h"
#include "test/core/operators/OutputTest.h"
#include "table/runtime/operators/window/WindowKey.h"
#include "table/runtime/operators/window/slicing/SliceAssigners.h"
#include "table/runtime/generated/AggsHandleFunction.h"
#include "streaming/api/operators/OneInputStreamOperator.h"
#include "streaming/api/watermark/Watermark.h"
#include "core/include/common.h"
#include "table/runtime/keyselector/KeySelector.h"
/**
 * One-input stream operator for slicing-window aggregation, configured from a
 * JSON operator description.
 *
 * Only the constructor and a handful of small helpers are defined inline here;
 * the processing logic (open/close/processBatch/ProcessWatermark and the
 * private helpers) is defined out of line.
 *
 * NOTE(review): the raw-pointer members allocated in the constructor are not
 * released in this header; presumably close() or an out-of-line routine owns
 * their cleanup -- confirm.
 */
class LocalSlicingWindowAggOperator : public AbstractStreamOperator<long>, public OneInputStreamOperator {
public:
    /**
     * Builds the operator from its JSON description.
     *
     * @param config Operator description. Must contain the keys "inputTypes",
     *               "outputTypes" (arrays of type-name strings) and "grouping"
     *               (array of input column indices). "timeAttributeIndex" is
     *               optional; when absent the rowtime index is -1.
     * @param output Downstream output, forwarded to the base operator and
     *               wrapped by the timestamped collector.
     * @throws nlohmann::json::out_of_range if a required key is missing,
     *         nlohmann::json::type_error if a value has the wrong JSON type.
     */
    LocalSlicingWindowAggOperator(const nlohmann::json& config, Output* output)
        : AbstractStreamOperator(output),
          description(config)
    {
        // Parse the configuration before allocating anything, so a malformed
        // description throws without leaking the raw-pointer members below.
        // at() is used instead of operator[]: on a const json, operator[] with
        // a missing key is undefined behaviour rather than an exception.
        inputTypes = config.at("inputTypes").get<std::vector<std::string>>();
        outputTypeStr = config.at("outputTypes").get<std::vector<std::string>>();
        outputTypes.reserve(outputTypeStr.size());
        for (const auto& typeName : outputTypeStr) {
            outputTypes.push_back(LogicalType::flinkTypeToOmniTypeId(typeName));
        }

        keyedIndex = config.at("grouping").get<std::vector<int32_t>>();
        keyedTypes.reserve(keyedIndex.size());
        const auto inputColumnCount = static_cast<int32_t>(inputTypes.size());
        for (int32_t index : keyedIndex) {
            // NOTE(review): out-of-range grouping indices are skipped here but
            // still passed to KeySelector via keyedIndex, so the two vectors
            // can differ in length -- confirm this is intended.
            if (index >= 0 && index < inputColumnCount) {
                keyedTypes.push_back(LogicalType::flinkTypeToOmniTypeId(inputTypes[index]));
            }
        }

        // rowtimeIndexVal keeps its default of -1 when the key is absent.
        if (description.contains("timeAttributeIndex")) {
            rowtimeIndexVal = static_cast<int>(description.at("timeAttributeIndex").get<long>());
        }

        sliceAssigner = AssignerAtt::createSliceAssigner(description);
        windowInterval = sliceAssigner->getSliceEndInterval();

        this->collector = new TimestampedCollector(this->output);
        clock = new ClockService();
        keySelector = new KeySelector<std::shared_ptr<RowData>>(keyedTypes, keyedIndex);
        windowRow = new GenericRowData(1);
        accWindowRow = new JoinedRowData();
        resultRow = new JoinedRowData();
    }

    // Lifecycle and processing entry points; defined out of line.
    void open() override;
    const char* getName() override;
    void close() override;
    void processBatch(StreamRecord* input) override;
    std::string getTypeName() override;
    void ProcessWatermark(Watermark* mark) override;

    /** Per-element processing is a no-op in this operator. */
    void processElement(StreamRecord* element) override {}

    Output* getOutput();

    /**
     * Computes the next watermark at which a window boundary fires.
     *
     * For a positive interval this is the smallest value of the form
     * k * interval - 1 that is strictly greater than `watermark`.
     *
     * @param watermark Current watermark. INT64_MAX is returned unchanged.
     * @param interval  Window interval. A non-positive interval skips the
     *                  modulo (avoiding division by zero); the result is then
     *                  watermark + 2 * interval - 1.
     * @return The next trigger watermark.
     *
     * NOTE(review): for watermarks close to INT64_MAX the additions below can
     * overflow a signed long; callers are assumed not to pass such values.
     */
    static long getNextTriggerWatermark(long watermark, long interval)
    {
        if (watermark == INT64_MAX) {
            return watermark;
        }
        const long remainder = (interval <= 0) ? 0 : (watermark % interval);
        // Floor to the window start; a negative remainder needs one extra
        // interval subtracted because % truncates toward zero.
        const long start = remainder < 0L ? watermark - (remainder + interval) : watermark - remainder;
        const long triggerWatermark = start + interval - 1L;
        return triggerWatermark > watermark ? triggerWatermark : triggerWatermark + interval;
    }

    /** Forwards the watermark status to the downstream output unchanged. */
    void processWatermarkStatus(WatermarkStatus* watermarkStatus) override
    {
        output->emitWatermarkStatus(watermarkStatus);
    }

    /**
     * Copies the operator ID from the OneInputStreamOperator base into the
     * AbstractStreamOperator base, then delegates state initialization to it.
     */
    void initializeState(StreamTaskStateInitializerImpl* initializer, TypeSerializer* keySerializer) override
    {
        INFO_RELEASE(
            "LocalSlicingWindowAggOperator initializeState with initializer, operatorID: "
            << OneInputStreamOperator::GetOperatorID().toString());
        AbstractStreamOperator<long>::SetOperatorID(OneInputStreamOperator::GetOperatorID().toString());
        AbstractStreamOperator<long>::initializeState(initializer, keySerializer);
    }

    /** Delegates to AbstractStreamOperator. */
    void notifyCheckpointComplete(long checkpointId) override
    {
        AbstractStreamOperator<long>::notifyCheckpointComplete(checkpointId);
    }

    /** Delegates to AbstractStreamOperator. */
    void notifyCheckpointAborted(long checkpointId) override
    {
        AbstractStreamOperator<long>::notifyCheckpointAborted(checkpointId);
    }

    /**
     * Finds the first aggregate-function name (MAX, COUNT, SUM, MIN or AVG,
     * matched case-insensitively) occurring anywhere in `input`.
     *
     * @return The matched text exactly as it appears in `input` (original
     *         casing preserved), or "NONE" when nothing matches.
     */
    static std::string extractAggFunction(const std::string& input)
    {
        // Compiled once: constructing a std::regex is expensive, and searching
        // with a const regex does not modify it.
        static const std::regex aggRegex(R"((?:MAX|COUNT|SUM|MIN|AVG))", std::regex_constants::icase);
        std::smatch match;
        if (std::regex_search(input, match, aggRegex)) {
            return match.str();
        }
        return "NONE";
    }

    void eraseMsg(std::vector<std::unique_ptr<RowData>>& entireRows);

private:
    // Copy of the JSON description passed to the constructor.
    nlohmann::json description;
    std::vector<std::string> accTypes;
    std::vector<std::string> aggValueTypes;
    int accumulatorArity = 0;
    std::vector<AggsHandleFunction*> functions;
    int aggregateCallsCount = 0;

    // Allocated in the constructor.
    GenericRowData* windowRow = nullptr;
    JoinedRowData* accWindowRow = nullptr;
    JoinedRowData* resultRow = nullptr;
    // Not assigned in the constructor; null until set by out-of-line code.
    BinaryRowData* reUseAggValue = nullptr;
    BinaryRowData* reUseAccumulator = nullptr;

    std::unordered_map<WindowKey, std::vector<std::unique_ptr<RowData>>> bundle;

    // "inputTypes" / "outputTypes" from the description, and the output types
    // converted through LogicalType::flinkTypeToOmniTypeId.
    std::vector<std::string> inputTypes;
    std::vector<std::string> outputTypeStr;
    std::vector<omniruntime::type::DataTypeId> outputTypes;

    // Converted types of the in-range grouping columns, and the raw "grouping"
    // indices from the description.
    std::vector<int32_t> keyedTypes;
    KeySelector<std::shared_ptr<RowData>>* keySelector = nullptr;
    std::vector<int32_t> keyedIndex;

    SliceAssigner* sliceAssigner = nullptr;
    long currentWatermark = 0;
    long nextTriggerWatermark = 0;
    // sliceAssigner->getSliceEndInterval(), captured at construction.
    long windowInterval = 0;
    TimestampedCollector* collector = nullptr;
    omnistream::VectorBatch* resultBatch = nullptr;

    void AccumulateOrRetract(const std::vector<std::unique_ptr<RowData>>& entireRows);
    bool SendAccResults(Watermark* mark);
    void SetLong(omniruntime::vec::VectorBatch* outputBatch, int rowIndex, int colIndex, RowData* collectedRow);
    void SetInt(omniruntime::vec::VectorBatch* outputBatch, int rowIndex, int colIndex, RowData* collectedRow);
    std::vector<WindowKey> invertOrder;
    // "timeAttributeIndex" from the description, or -1 when it is absent.
    int rowtimeIndexVal = -1;
    ClockService* clock = nullptr;
    void ExtractFunction();
    void SetStringVectorBatch(omnistream::VectorBatch* outputBatch, int rowIndex, int colIndex, RowData* collectedRow);
};
#endif