/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#ifndef OMNISTREAM_DEFAULTOPERATORSTATEBACKEND_H
#define OMNISTREAM_DEFAULTOPERATORSTATEBACKEND_H

#include <future>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "core/typeutils/TypeSerializer.h"
#include "core/api/common/state/ListState.h"
#include "core/api/common/state/ListStateDescriptor.h"
#include "core/api/common/state/MapStateDescriptor.h"
#include "runtime/state/SnapshotExecutionType.h"
#include "runtime/state/bridge/TaskStateManagerBridge.h"
#include "runtime/state/bridge/OmniTaskBridge.h"

#include "HeapBroadcastState.h"
#include "PartitionableListState.h"
#include "CheckpointStreamFactory.h"
#include "DefaultOperatorStateBackendSnapshotStrategy.h"
#include "OperatorStateBackend.h"
#include "OperatorStateHandle.h"
#include "RegisteredBroadcastStateBackendMetaInfo.h"
#include "RegisteredOperatorStateBackendMetaInfo.h"
#include "SnapshotStrategyRunner.h"
#include "SnapshotResult.h"

class DefaultOperatorStateBackend : public OperatorStateBackend {
public:
    DefaultOperatorStateBackend(bool asynchronousSnapshots,
                                std::shared_ptr<TaskStateManagerBridge> bridge,
                                std::shared_ptr<OmniTaskBridge> omniTaskBridge,
                                DefaultOperatorStateBackendSnapshotStrategy* snapshotStrategy,
                                std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredOperatorStates,
                                std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredBroadcastStates,
                                std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> accessedStatesByName,
                                std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> accessedBroadcastStatesByName)
        : asynchronousSnapshots_(asynchronousSnapshots),
          bridge_(bridge),
          omniTaskBridge_(omniTaskBridge),
          snapshotStrategy_(snapshotStrategy),
          registeredOperatorStates_(registeredOperatorStates),
          registeredBroadcastStates_(registeredBroadcastStates),
          accessedStatesByName_(accessedStatesByName),
          accessedBroadcastStatesByName_(accessedBroadcastStatesByName) {
    }

    std::unordered_set<std::string> getRegisteredStateNames() override {
        std::unordered_set<std::string> nameSet;
        for (const auto& pair : *registeredOperatorStates_) {
            nameSet.insert(pair.first);
        }
        return nameSet;
    }

    std::unordered_set<std::string> getRegisteredBroadcastStateNames() override {
        std::unordered_set<std::string> nameSet;
        for (const auto& pair : *registeredBroadcastStates_) {
            nameSet.insert(pair.first);
        }
        return nameSet;
    }

    void close() {}

    void dispose() override {
        LOG(" DefaultOperatorStateBackend dispose");
        if (!registeredOperatorStates_->empty()) {
            registeredOperatorStates_->clear();
        }
        if (!registeredBroadcastStates_->empty()) {
            registeredBroadcastStates_->clear();
        }
        if (!accessedStatesByName_->empty()) {
            accessedStatesByName_->clear();
        }
        if (!accessedBroadcastStatesByName_->empty()) {
            accessedBroadcastStatesByName_->clear();
        }
    }

    template<typename K, typename V>
    std::shared_ptr<HeapBroadcastState<K, V>> getBroadcastState(MapStateDescriptor<K, V>* stateDescriptor) {
        std::string name = stateDescriptor->getName();
        auto accessedIterator = accessedBroadcastStatesByName_->find(name);
        if (accessedIterator != accessedBroadcastStatesByName_->end()) {
            std::shared_ptr<HeapBroadcastState<K, V>> state = std::dynamic_pointer_cast<HeapBroadcastState<K, V>>(accessedIterator->second);
            return std::dynamic_pointer_cast<HeapBroadcastState<K, V>>(accessedIterator->second);
        }

        std::shared_ptr<HeapBroadcastState<K, V>> resultState = nullptr;
        auto* broadcastStateKeySerializer = stateDescriptor->GetUserKeySerializer();
        auto* broadcastStateValueSerializer = stateDescriptor->GetValueSerializer();

        auto registeredIterator = registeredBroadcastStates_->find(name);

        if (registeredIterator == registeredBroadcastStates_->end()) {
            auto stateMetaInfo = std::make_shared<RegisteredBroadcastStateBackendMetaInfo>(
                    name,
                    OperatorStateHandle::Mode::BROADCAST,
                    broadcastStateKeySerializer,
                    broadcastStateValueSerializer);
            auto internalMap = std::shared_ptr<std::map<K, V>>();
            resultState = std::make_shared<HeapBroadcastState<K, V>>(stateMetaInfo, internalMap);
            registeredBroadcastStates_->emplace(name, resultState);
        } else {
            resultState = std::dynamic_pointer_cast<HeapBroadcastState<K, V>>(registeredIterator->second);
            auto stateMetaInfo = resultState->getStateMetaInfo();
            stateMetaInfo->updateKeySerializer(broadcastStateKeySerializer);
            stateMetaInfo->updateValueSerializer(broadcastStateValueSerializer);
            resultState->setStateMetaInfo(stateMetaInfo);
        }

        accessedBroadcastStatesByName_->emplace(name, (*registeredBroadcastStates_)[name]);

        return resultState;
    }

    template<typename S>
    std::shared_ptr<ListState<S>> getListState(ListStateDescriptor<S>* stateDescriptor)  {
        return getListState(stateDescriptor, OperatorStateHandle::Mode::SPLIT_DISTRIBUTE);
    }

    template<typename S>
    std::shared_ptr<ListState<S>> getUnionListState(ListStateDescriptor<S>* stateDescriptor){
        return getListState(stateDescriptor, OperatorStateHandle::Mode::UNION);
    }

    std::shared_ptr<std::packaged_task<std::shared_ptr<SnapshotResult<OperatorStateHandle>>()>> snapshot(
        long checkpointId,
        long timestamp,
        CheckpointStreamFactory* streamFactory,
        CheckpointOptions* checkpointOptions) override {
        if (registeredOperatorStates_->empty() && registeredBroadcastStates_->empty()) {
            return nullptr;
        }

        auto snapshotStrategyRunner = std::make_unique<SnapshotStrategyRunner<OperatorStateHandle, SnapshotResources>>(
                "DefaultOperatorStateBackend snapshot",
                snapshotStrategy_,
                asynchronousSnapshots_ ? ASYNCHRONOUS : SYNCHRONOUS);

        return snapshotStrategyRunner->snapshot(checkpointId,
                                                timestamp,
                                                streamFactory,
                                                checkpointOptions,
                                                omniTaskBridge_,
                                                "");
    }

private:
    bool asynchronousSnapshots_;
    std::shared_ptr<TaskStateManagerBridge> bridge_;
    std::shared_ptr<OmniTaskBridge> omniTaskBridge_;
    DefaultOperatorStateBackendSnapshotStrategy* snapshotStrategy_;
    std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredOperatorStates_;
    std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredBroadcastStates_;
    std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> accessedStatesByName_;
    std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> accessedBroadcastStatesByName_;

    template<typename S>
    std::shared_ptr<ListState<S>> getListState(ListStateDescriptor<S>* stateDescriptor, OperatorStateHandle::Mode mode) {

        std::string name = stateDescriptor->getName();

        auto accessedIterator = accessedStatesByName_->find(name);
        if (accessedIterator != accessedStatesByName_->end()) {
            auto state = std::dynamic_pointer_cast<PartitionableListState<S>>(accessedIterator->second);
            return std::dynamic_pointer_cast<PartitionableListState<S>>(accessedIterator->second);
        }

        std::shared_ptr<PartitionableListState<S>> resultState = nullptr;
        auto* operatorStateSerializer = stateDescriptor->getStateSerializer();
        auto registeredIterator = registeredOperatorStates_->find(name);

        // 打印当前 registeredOperatorStates_ 中所有的 key,以及本次查询的 name
        std::string registeredKeys;
        for (auto& entry : *registeredOperatorStates_) {
            if (!registeredKeys.empty()) registeredKeys += ", ";
            registeredKeys += entry.first;
        }

        if (registeredIterator == registeredOperatorStates_->end()) {
            auto stateMetaInfo = std::make_shared<RegisteredOperatorStateBackendMetaInfo>(name, mode, operatorStateSerializer);
            auto internalList = std::make_shared<std::vector<S>>();
            resultState = std::make_shared<PartitionableListState<S>>(stateMetaInfo, internalList);
            registeredOperatorStates_->emplace(name, resultState);
        } else {
            resultState = std::dynamic_pointer_cast<PartitionableListState<S>>(registeredIterator->second);
            auto stateMetaInfo = resultState->getStateMetaInfo();
            stateMetaInfo->updateStateSerializer(operatorStateSerializer);
            resultState->setStateMetaInfo(stateMetaInfo);
        }

        accessedStatesByName_->emplace(name, (*registeredOperatorStates_)[name]);
        return resultState;
    }
};

#endif //OMNISTREAM_DEFAULTOPERATORSTATEBACKEND_H