/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#ifndef FLINK_TNEL_LOCALSLICINGWINDOWAGGOPERATOR_H
#define FLINK_TNEL_LOCALSLICINGWINDOWAGGOPERATOR_H

#include <memory>
#include <regex>
#include "streaming/api/operators/AbstractStreamOperator.h"
#include "table/data/JoinedRowData.h"
#include "table/typeutils/BinaryRowDataSerializer.h"
#include "streaming/api/operators/TimestampedCollector.h"
#include "test/core/operators/OutputTest.h"
#include "table/runtime/operators/window/WindowKey.h"
#include "table/runtime/operators/window/slicing/SliceAssigners.h"
#include "table/runtime/generated/AggsHandleFunction.h"
#include "streaming/api/operators/OneInputStreamOperator.h"
#include "streaming/api/watermark/Watermark.h"
#include "core/include/common.h"
#include "table/runtime/keyselector/KeySelector.h"

// One-input stream operator configured from a JSON operator description.
//
// The constructor reads the following keys from the description:
//   - "inputTypes"          (required) list of Flink type names of the input columns
//   - "outputTypes"         (required) list of Flink type names of the output columns
//   - "grouping"            (required) list of input column indices used as the key
//   - "timeAttributeIndex"  (optional) index of the time attribute column; -1 when absent
// The full description is also handed to AssignerAtt::createSliceAssigner.
//
// NOTE(review): judging by its name this is the local (pre-shuffle) stage of a slicing
// window aggregation; the actual batching/aggregation logic lives in the out-of-line
// member definitions, which are not visible from this header.
//
// NOTE(review): the objects allocated with `new` in the constructor are held through raw
// pointers and no destructor is declared here; confirm that close() (or another owner)
// releases them.
class LocalSlicingWindowAggOperator : public AbstractStreamOperator<long>, public OneInputStreamOperator {
public:
    // Builds the operator from its JSON description.
    //
    // config: operator description, see the class comment for the keys that are read.
    // output: downstream output; forwarded to the base class and wrapped by the collector.
    //
    // Required keys are read with at(), so a missing key raises nlohmann's out_of_range
    // exception instead of hitting the undefined behaviour of operator[] on a const json.
    LocalSlicingWindowAggOperator(const nlohmann::json& config, Output* output)
        : AbstractStreamOperator(output),
          description(config)
    {
        // Read every required field before allocating anything, so a malformed
        // description fails early without leaving half-built helper objects behind.
        inputTypes = config.at("inputTypes").get<std::vector<std::string>>();
        outputTypeStr = config.at("outputTypes").get<std::vector<std::string>>();
        keyedIndex = config.at("grouping").get<std::vector<int32_t>>();

        this->collector = new TimestampedCollector(this->output);
        clock = new ClockService();

        outputTypes.reserve(outputTypeStr.size());
        for (const auto& typeName : outputTypeStr) {
            outputTypes.push_back(LogicalType::flinkTypeToOmniTypeId(typeName));
        }

        // Grouping indices outside the input schema are skipped, so keyedTypes can end up
        // shorter than keyedIndex when the description is inconsistent.
        const auto inputColumnCount = static_cast<int32_t>(inputTypes.size());
        for (int32_t index : keyedIndex) {
            if (index >= 0 && index < inputColumnCount) {
                keyedTypes.push_back(LogicalType::flinkTypeToOmniTypeId(inputTypes[index]));
            }
        }
        keySelector = new KeySelector<std::shared_ptr<RowData>>(keyedTypes, keyedIndex);
        windowRow = new GenericRowData(1);
        accWindowRow = new JoinedRowData();
        resultRow = new JoinedRowData();
        // todo: This only works for fixed length valueType
        sliceAssigner = AssignerAtt::createSliceAssigner(description);

        // The value is parsed as long and stored as int; the narrowing is made explicit.
        // Read in place to avoid copying the json node.
        if (description.contains("timeAttributeIndex")) {
            rowtimeIndexVal = static_cast<int>(description.at("timeAttributeIndex").get<long>());
        } else {
            rowtimeIndexVal = -1;
        }
        windowInterval = sliceAssigner->getSliceEndInterval();
    }

    void open() override;

    const char* getName() override;
    void close() override;
    void processBatch(StreamRecord* input) override;
    std::string getTypeName() override;

    void ProcessWatermark(Watermark* mark) override;
    // Intentionally a no-op in this operator.
    void processElement(StreamRecord* element) override {}
    Output* getOutput();

    // Returns the next trigger watermark strictly greater than `watermark`.
    //
    // The watermark is first rounded down to a multiple of `interval` (rounding toward
    // negative infinity, hence the extra branch for negative remainders). The candidate
    // trigger is the last value of that interval, `start + interval - 1`. If the candidate
    // is not greater than `watermark`, the trigger of the following interval is returned.
    //
    // Examples with interval = 10: watermark 0 -> 9, watermark 9 -> 19.
    //
    // Special cases:
    //   - watermark == INT64_MAX is returned unchanged.
    //   - interval <= 0 skips the modulo (avoiding a division by zero); the result is then
    //     `watermark + interval - 1` or that value plus `interval`.
    //
    // NOTE(review): the additions are plain signed arithmetic and can overflow for
    // watermarks close to the limits of `long`.
    static long getNextTriggerWatermark(long watermark, long interval)
    {
        if (watermark == INT64_MAX) {
            return watermark;
        }
        const long remainder = (interval <= 0) ? 0 : (watermark % interval);
        const long start = remainder < 0L ? watermark - (remainder + interval) : watermark - remainder;
        const long triggerWatermark = start + interval - 1L;

        return triggerWatermark > watermark ? triggerWatermark : triggerWatermark + interval;
    }

    // Forwards the watermark status to the downstream output unchanged.
    void processWatermarkStatus(WatermarkStatus* watermarkStatus) override
    {
        output->emitWatermarkStatus(watermarkStatus);
    }

    // Copies the operator ID from the OneInputStreamOperator base into the
    // AbstractStreamOperator base, then delegates to the base-class state initialization.
    void initializeState(StreamTaskStateInitializerImpl* initializer, TypeSerializer* keySerializer) override
    {
        // First do the shared initialization step
        INFO_RELEASE(
            "LocalSlicingWindowAggOperator initializeState with initializer, operatorID: "
            << OneInputStreamOperator::GetOperatorID().toString());
        AbstractStreamOperator<long>::SetOperatorID(OneInputStreamOperator::GetOperatorID().toString());
        AbstractStreamOperator<long>::initializeState(initializer, keySerializer);
    }

    // Delegates to AbstractStreamOperator.
    void notifyCheckpointComplete(long checkpointId) override
    {
        AbstractStreamOperator<long>::notifyCheckpointComplete(checkpointId);
    }

    // Delegates to AbstractStreamOperator.
    void notifyCheckpointAborted(long checkpointId) override
    {
        AbstractStreamOperator<long>::notifyCheckpointAborted(checkpointId);
    }

    // Finds the first occurrence of MAX, COUNT, SUM, MIN or AVG in `input`
    // (case-insensitive, unanchored substring search).
    //
    // Returns the matched text exactly as it appears in `input`, or "NONE" when none of
    // the names occurs.
    static std::string extractAggFunction(const std::string& input)
    {
        // Compiling a std::regex is expensive; build it once and reuse it for every call.
        static const std::regex aggRegex(R"((?:MAX|COUNT|SUM|MIN|AVG))", std::regex_constants::icase);
        std::smatch match;
        if (std::regex_search(input, match, aggRegex)) {
            return match.str();
        }
        return "NONE";
    }
    void eraseMsg(std::vector<std::unique_ptr<RowData>>& entireRows);

private:
    // Copy of the JSON description passed to the constructor.
    nlohmann::json description;
    std::vector<std::string> accTypes;
    std::vector<std::string> aggValueTypes;
    int accumulatorArity = 0;
    std::vector<AggsHandleFunction*> functions;
    int aggregateCallsCount = 0;
    // Helper rows allocated in the constructor.
    GenericRowData* windowRow = nullptr;
    JoinedRowData* accWindowRow = nullptr;
    JoinedRowData* resultRow = nullptr;
    // Not set by the constructor; default to null so they never hold an indeterminate value.
    BinaryRowData* reUseAggValue = nullptr;
    BinaryRowData* reUseAccumulator = nullptr;
    // Rows collected per window key.
    std::unordered_map<WindowKey, std::vector<std::unique_ptr<RowData>>> bundle;
    // Flink type names of the input columns ("inputTypes" in the description).
    std::vector<std::string> inputTypes;
    std::vector<std::string> outputTypeStr; // todo: remove from variables
    // "outputTypes" converted with LogicalType::flinkTypeToOmniTypeId.
    std::vector<omniruntime::type::DataTypeId> outputTypes;

    // Type ids of the in-range grouping columns.
    std::vector<int32_t> keyedTypes;
    KeySelector<std::shared_ptr<RowData>>* keySelector = nullptr;
    // Grouping column indices ("grouping" in the description).
    std::vector<int32_t> keyedIndex;
    SliceAssigner* sliceAssigner = nullptr;
    long currentWatermark = 0;
    long nextTriggerWatermark = 0;
    // Set from sliceAssigner->getSliceEndInterval() in the constructor.
    long windowInterval = 0;

    TimestampedCollector* collector = nullptr;
    omnistream::VectorBatch* resultBatch = nullptr;

    void AccumulateOrRetract(const std::vector<std::unique_ptr<RowData>>& entireRows);
    bool SendAccResults(Watermark* mark);

    void SetLong(omniruntime::vec::VectorBatch* outputBatch, int rowIndex, int colIndex, RowData* collectedRow);
    void SetInt(omniruntime::vec::VectorBatch* outputBatch, int rowIndex, int colIndex, RowData* collectedRow);
    // NOTE(review): usage is not visible in this header.
    std::vector<WindowKey> invertOrder;
    // Value of "timeAttributeIndex" from the description, or -1 when the key is absent.
    int rowtimeIndexVal = -1;
    ClockService* clock = nullptr;
    void ExtractFunction();
    void SetStringVectorBatch(omnistream::VectorBatch* outputBatch, int rowIndex, int colIndex, RowData* collectedRow);
};

#endif