/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#ifndef FLINK_TNEL_SHARED_DISTINCT_COUNT_CONTAINER_FUNCTION_H
#define FLINK_TNEL_SHARED_DISTINCT_COUNT_CONTAINER_FUNCTION_H

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <unordered_map>
#include <vector>
#include "../AggsHandleFunction.h"
#include "../table/runtime/dataview/StateDataViewStore.h"
#include "../runtime/state/VoidNamespace.h"

class SharedDistinctCountContainerFunction : public AggsHandleFunction {
public:
    explicit SharedDistinctCountContainerFunction(std::string stateName = "distinct_acc_shared");

    void addDistinctEntry(int aggFuncIndex, const std::string& aggType, int aggIdx, int filterIndex,
                          const std::string& inputType);
    void finalizeEntries();
    void bindEntryAccValueIndex(int aggFuncIndex, int accIndex, int valueIndex);

    void setWindowSize(int windowSize) override {}
    bool equaliser(BinaryRowData* r1, BinaryRowData* r2) override;
    void open(StateDataViewStore* store);
    void accumulate(RowData* accInput) override;
    void accumulate(omnistream::VectorBatch* input, const std::vector<int>& indices) override;
    void retract(RowData* retractInput) override {}
    void retract(omnistream::VectorBatch* input, const std::vector<int>& indices) override {}
    void merge(RowData* otherAcc) override;
    void setAccumulators(RowData* acc) override;
    void resetAccumulators() override;
    void getAccumulators(BinaryRowData* accumulators) override;
    void createAccumulators(BinaryRowData* accumulators) override;
    void getValue(BinaryRowData* aggValue) override;
    void cleanup() override;
    void close() override;
    void setCurrentGroupKey(RowData* key) override;
    void updateInnerState() override;

    void bindAccValueIndex(int accStartIndex, int valueStartIndex) override;
    int accumulatorSlots() const override { return static_cast<int>(entries_.size()); }
    bool hasAggOutput() const override { return !entries_.empty(); }

private:
    struct DistinctEntry {
        int aggFuncIndex = -1;
        std::string aggType;
        int aggIdx = -1;
        int filterIndex = -1;
        std::string inputType;
        omniruntime::type::DataTypeId typeId = omniruntime::type::DataTypeId::OMNI_LONG;
        std::size_t groupIndex = 0;
        std::uint8_t bitOffset = 0;
        int accIndex = -1;
        int valueIndex = -1;
        long aggCount = -1L;
        bool valueIsNull = true;
    };

    using PendingDistinctUpdates = std::unordered_map<RowData*, std::vector<std::tuple<long, long>>>;
    struct DistinctGroup {
        int aggIdx = -1;
        omniruntime::type::DataTypeId typeId = omniruntime::type::DataTypeId::OMNI_LONG;
        std::string stateName;
        std::vector<DistinctEntry*> groupEntries;
        KeyedStateMapViewWithKeysNullable<VoidNamespace, long, long>* distinctMapView = nullptr;
        PendingDistinctUpdates pendingDistinctUpdates;
    };


    bool shouldAccumulateForEntry(DistinctEntry* entry, RowData* inputRow) const;
    std::uint64_t collectCandidateMask(RowData* inputRow, const DistinctGroup& group) const;
    void applyMaskDelta(std::size_t groupIndex, std::uint64_t deltaMask);
    long getRowFieldValue(RowData* row, int aggIdx, omniruntime::type::DataTypeId typeId, bool& isNull) const;
    void getOrCreateGroup( DistinctEntry* entry);

    std::string stateName_;
    std::vector<DistinctEntry*> entries_;
    std::vector<DistinctGroup> groups_;
    std::unordered_map<int, int> entryIndexToAggFuncIndex;
    std::unordered_map<int, int> distinctGroupMap;
    bool finalized_ = false;
    StateDataViewStore* store_ = nullptr;
    RowData* currentGroupKey_ = nullptr;
    long getRowFieldValueFromVB(omnistream::VectorBatch* input, int columnIdx,int rowIdx ,omniruntime::type::DataTypeId typeId, bool& isNull) const;

};

#endif // FLINK_TNEL_SHARED_DISTINCT_COUNT_CONTAINER_FUNCTION_H