/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#pragma once

#include "PriorityQueueSetFactory.h"
#include "rocksdb/db.h"
#include "rocksdb/options.h"

#include "core/fs/CloseableRegistry.h"

#include "state/RocksDBWriteBatchWrapper.h"
#include "core/typeutils/TypeSerializer.h"
#include "runtime/state/rocksdb/util/ResourceGuard.h"

#include "runtime/state/InternalKeyContextImpl.h"
#include "runtime/state/KeyGroupRange.h"
#include "runtime/state/LocalRecoveryConfig.h"
#include "runtime/state/CompositeKeySerializationUtils.h"
#include "runtime/state/RocksDbKvStateInfo.h"
#include "runtime/state/RocksdbKeyedStateBackend.h"
#include "runtime/state/RocksDBResourceContainer.h"
#include "runtime/state/restore/RocksDBRestoreOperation.h"
#include "runtime/state/restore/RocksDBNoneRestoreOperation.h"
#include "runtime/state/restore/RocksDBIncrementalRestoreOperation.h"
#include "runtime/state/restore/RocksDBFullRestoreOperation.h"
#include "runtime/state/restore/RocksDBHeapTimersFullRestoreOperation.h"
#include "runtime/snapshot/RocksDBSnapshotStrategyBase.h"
#include "runtime/snapshot/RocksNativeFullSnapshotStrategy.h"
#include "runtime/snapshot/RocksIncrementalSnapshotStrategy.h"
#include "runtime/state/rocksdb/RocksDBStateUploader.h"


namespace fs = std::filesystem;
template <typename K>
class RocksDBKeyedStateBackendBuilder {
public:
    enum class PriorityQueueStateType {
        HEAP,
        ROCKSDB
    };

    RocksDBKeyedStateBackendBuilder(
        std::string operatorIdentifier_,
        fs::path instanceBasePath_,
        std::shared_ptr<RocksDBResourceContainer> optionsContainer_,
        TypeSerializer *keySerializer_,
        int numberOfKeyGroups_,
        KeyGroupRange *keyGroupRange_,
        std::shared_ptr<LocalRecoveryConfig> localRecoveryConfig_,
        PriorityQueueStateType priorityQueueStateType,
        std::vector<std::shared_ptr<KeyedStateHandle>> stateHandles_,
        std::shared_ptr<TaskStateManagerBridge> bridge_,
        std::shared_ptr<OmniTaskBridge> omniTaskBridge_,
        std::shared_ptr<OperatorID> operatorId,
        int alternativeIdx)
        : keySerializer(keySerializer_),
        operatorIdentifier(operatorIdentifier_),
        instanceBasePath(instanceBasePath_),
        optionsContainer(optionsContainer_),
        numberOfKeyGroups(numberOfKeyGroups_),
        keyGroupRange(keyGroupRange_),
        localRecoveryConfig(localRecoveryConfig_),
        priorityQueueStateType_(priorityQueueStateType),
        restoreStateHandles(stateHandles_),
        bridge(bridge_),
        omniTaskBridge(omniTaskBridge_),
        operatorId_(operatorId),
        alternativeIdx_(alternativeIdx)
    {
        instanceRocksDBPath = getInstanceRocksDBPath(instanceBasePath);
    }

    RocksdbKeyedStateBackend<K> *build();

    static std::filesystem::path getInstanceRocksDBPath(const std::filesystem::path& instanceBasePath)
    {
        return instanceBasePath / DB_INSTANCE_DIR_STRING;
    }

    std::shared_ptr<TypeSerializer> keySerializer;

    RocksDBKeyedStateBackendBuilder<K>& setEnableIncrementalCheckpointing(bool enableIncrementalCheckpointing_)
    {
        this->enableIncrementalCheckpointing = enableIncrementalCheckpointing_;
        return *this;
    }

    RocksDBKeyedStateBackendBuilder<K>& setNumberOfTransferringThreads(int numberOfTransferingThreads_)
    {
        this->numberOfTransferingThreads = numberOfTransferingThreads_;
        return *this;
    }

    RocksDBKeyedStateBackendBuilder<K>& setWriteBatchSize(long writeBatchSize_)
    {
        this->writeBatchSize = writeBatchSize_;
        return *this;
    }

private:
    const std::shared_ptr<CloseableRegistry> cancelStreamRegistry;
    static constexpr const char* DB_INSTANCE_DIR_STRING = "db";
    std::string operatorIdentifier;
    std::filesystem::path instanceBasePath;
    std::shared_ptr<RocksDBResourceContainer> optionsContainer;
    int numberOfKeyGroups;
    KeyGroupRange* keyGroupRange;
    std::shared_ptr<LocalRecoveryConfig> localRecoveryConfig;
    PriorityQueueStateType priorityQueueStateType_ = PriorityQueueStateType::ROCKSDB;
    std::vector<std::shared_ptr<KeyedStateHandle>> restoreStateHandles;
    std::function<rocksdb::ColumnFamilyOptions(const std::string&)> columnFamilyOptionsFactory;
    std::filesystem::path instanceRocksDBPath;
    bool enableIncrementalCheckpointing;
    int numberOfTransferingThreads;
    long writeBatchSize = 2097152;
    std::shared_ptr<rocksdb::DB> injectedTestDB;
    std::shared_ptr<TaskStateManagerBridge> bridge;
    std::shared_ptr<OmniTaskBridge> omniTaskBridge;
    std::shared_ptr<OperatorID> operatorId_;
    int alternativeIdx_;

    static void checkAndCreateDirectory(const fs::path& directory)
    {
        if (fs::exists(directory)) {
            if (!fs::is_directory(directory)) {
                throw std::runtime_error("Not a directory: " + directory.string());
            }
        } else {
            if (!fs::create_directories(directory)) {
                throw std::runtime_error(
                    "Could not create RocksDB data directory at " + directory.string());
            }
        }
    }

    void prepareDirectories()
    {
        try {
            checkAndCreateDirectory(instanceBasePath);

            if (fs::exists(instanceRocksDBPath)) {
                fs::remove_all(instanceBasePath);
            }
        } catch (const fs::filesystem_error& e) {
            throw std::runtime_error("Failed to prepare directories: " + std::string(e.what()));
        }
    }

    std::shared_ptr<RocksDBRestoreOperation> getRocksDBRestoreOperation(
            int keyGroupPrefixBytes,
            std::unordered_map<std::string, std::shared_ptr<RocksDbKvStateInfo>> *kvStateInformation,
            std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<HeapPriorityQueueSnapshotRestoreWrapperBase>>> registeredPQStates)
    {
        auto dbOptions = optionsContainer->getDbOptions();
        if (restoreStateHandles.empty()) {
            return std::make_shared<RocksDBNoneRestoreOperation>(
                    kvStateInformation,
                    instanceRocksDBPath,
                    dbOptions,
                    columnFamilyOptionsFactory);
        }
        std::shared_ptr<KeyedStateHandle> firstStateHandle = restoreStateHandles[0];
        if (auto incrementalHandle = std::dynamic_pointer_cast<IncrementalKeyedStateHandle>(firstStateHandle)) {
            return std::make_shared<RocksDBIncrementalRestoreOperation<K>>(
                    operatorIdentifier,
                    keyGroupRange,
                    keyGroupPrefixBytes,
                    numberOfTransferingThreads,
                    kvStateInformation,
                    keySerializer,
                    instanceBasePath,
                    instanceRocksDBPath,
                    dbOptions,
                    columnFamilyOptionsFactory,
                    restoreStateHandles,
                    writeBatchSize,
                    omniTaskBridge,
                    operatorId_,
                    alternativeIdx_);
        } else {
            if (priorityQueueStateType_ == PriorityQueueStateType::HEAP) {
                return std::make_shared<RocksDBHeapTimersFullRestoreOperation<K>>(
                        keyGroupRange,
                        numberOfKeyGroups,
                        keySerializer,
                        kvStateInformation,
                        registeredPQStates,
                        instanceRocksDBPath,
                        dbOptions,
                        columnFamilyOptionsFactory,
                        restoreStateHandles,
                        writeBatchSize,
                        omniTaskBridge);
            }

            return std::make_shared<RocksDBFullRestoreOperation<K>>(
                    keyGroupRange,
                    keySerializer,
                    kvStateInformation,
                    instanceRocksDBPath,
                    dbOptions,
                    columnFamilyOptionsFactory,
                    restoreStateHandles,
                    writeBatchSize,
                    omniTaskBridge);
        }
    }

    RocksDBSnapshotStrategyBase* initializeSavepointAndCheckpointStrategies(
            std::shared_ptr<ResourceGuard> rocksDBResourceGuard,
            std::unordered_map<std::string, std::shared_ptr<RocksDbKvStateInfo>> *kvStateInformation,
            int keyGroupPrefixBytes,
            rocksdb::DB* db,
            UUID backendUID,
            std::map<long, std::vector<HandleAndLocalPath>> materializedSstFiles,
            long lastCompletedCheckpointId)
    {
        RocksDBSnapshotStrategyBase* checkpointSnapshotStrategy;

        auto stateUploader = std::make_shared<RocksDBStateUploader>(numberOfTransferingThreads);

        if (enableIncrementalCheckpointing) {
            checkpointSnapshotStrategy = new RocksIncrementalSnapshotStrategy(
                                                                        db,
                                                                        rocksDBResourceGuard,
                                                                        keySerializer,
                                                                        kvStateInformation,
                                                                        *keyGroupRange,
                                                                        keyGroupPrefixBytes,
                                                                        localRecoveryConfig,
                                                                        instanceBasePath.string(),
                                                                        backendUID,
                                                                        materializedSstFiles,
                                                                        stateUploader,
                                                                        lastCompletedCheckpointId);
        } else {
            checkpointSnapshotStrategy = new RocksNativeFullSnapshotStrategy(
                                                                        db,
                                                                        rocksDBResourceGuard,
                                                                        keySerializer,
                                                                        kvStateInformation,
                                                                        *keyGroupRange,
                                                                        keyGroupPrefixBytes,
                                                                        localRecoveryConfig,
                                                                        instanceBasePath.string(),
                                                                        backendUID,
                                                                        stateUploader);
        }

        return checkpointSnapshotStrategy;
    }

    std::shared_ptr<PriorityQueueSetFactory> initPriorityQueueFactory(
            int32_t keyGroupPrefixBytes,
            std::unordered_map<std::string, std::shared_ptr<RocksDbKvStateInfo>> *kvStateInformation,
            rocksdb::DB* db,
            std::shared_ptr<RocksDBWriteBatchWrapper> writeBatchWrapper
    );
};

template <typename K>
RocksdbKeyedStateBackend<K> *RocksDBKeyedStateBackendBuilder<K>::build() {
    // todo: Memory Management?
    auto kvStateInformation = new std::unordered_map<std::string, std::shared_ptr<RocksDbKvStateInfo>>();
    auto registeredPQStates = std::make_shared<std::unordered_map<std::string, std::shared_ptr<HeapPriorityQueueSnapshotRestoreWrapperBase>>>();
    rocksdb::DB* db;
    std::shared_ptr<RocksDBRestoreOperation> restoreOperation;

    int keyGroupPrefixBytes =
        CompositeKeySerializationUtils::computeRequiredBytesInKeyGroupPrefix(numberOfKeyGroups);

    auto capturedOptionsContainer = this->optionsContainer;
    columnFamilyOptionsFactory = [capturedOptionsContainer](const std::string& name) -> rocksdb::ColumnFamilyOptions {
        auto optPtr = capturedOptionsContainer->getColumnOptions();
        return *optPtr;
    };

    try {
        UUID backendUID = UUID::randomUUID();
        std::map<long, std::vector<IncrementalKeyedStateHandle::HandleAndLocalPath>> materializedSstFiles;
        long lastCompletedCheckpointId = -1L;
        auto rocksDBResourceGuard = std::make_shared<ResourceGuard>();

        prepareDirectories();
        restoreOperation = getRocksDBRestoreOperation(keyGroupPrefixBytes, kvStateInformation, registeredPQStates);
        auto restoreResult = restoreOperation->restore();
        db = restoreResult->getDb();

        if (std::dynamic_pointer_cast<RocksDBIncrementalRestoreOperation<K>>(restoreOperation)) {
            backendUID = restoreResult->getBackendUID();
            materializedSstFiles = restoreResult->getRestoredSstFiles();
            lastCompletedCheckpointId = restoreResult->getLastCompletedCheckpointId();
        }

        auto strategy = initializeSavepointAndCheckpointStrategies(
            rocksDBResourceGuard,
            kvStateInformation,
            keyGroupPrefixBytes,
            db,
            backendUID,
            materializedSstFiles,
            lastCompletedCheckpointId);

        auto writeBatchWrapper = std::make_shared<RocksDBWriteBatchWrapper>(
            db,
            optionsContainer->getWriteOptions(),
            writeBatchSize);

        auto priorityQueueSetFactory = initPriorityQueueFactory(
                keyGroupPrefixBytes,
                kvStateInformation,
                db,
                writeBatchWrapper);

        auto keyContext = new InternalKeyContextImpl<K>(keyGroupRange, numberOfKeyGroups);

        return new RocksdbKeyedStateBackend<K>(
            keySerializer.get(),
            keyContext,
            db,
            strategy,
            keyGroupRange,
            kvStateInformation,
            registeredPQStates,
            rocksDBResourceGuard,
            keyGroupPrefixBytes,
            writeBatchWrapper,
            priorityQueueSetFactory,
            bridge,
            omniTaskBridge);
    } catch (const std::exception& e) {
        throw std::runtime_error("build failed." + std::string(e.what()));
    }
}

template <typename K>
std::shared_ptr<PriorityQueueSetFactory> RocksDBKeyedStateBackendBuilder<K>::initPriorityQueueFactory(
        int32_t keyGroupPrefixBytes,
        std::unordered_map<std::string, std::shared_ptr<RocksDbKvStateInfo>>* kvStateInformation,
        rocksdb::DB* db,
        std::shared_ptr<RocksDBWriteBatchWrapper> writeBatchWrapper) {
    if (priorityQueueStateType_ == PriorityQueueStateType::HEAP) {
        INFO_RELEASE("The priority queue type of rocksDB backend is HEAP.");
        return std::make_shared<HeapPriorityQueueSetFactory>(keyGroupRange, numberOfKeyGroups, 128);
    }

    if (priorityQueueStateType_ == PriorityQueueStateType::ROCKSDB) {
        INFO_RELEASE("The priority queue type of rocksDB backend is ROCKSDB.");
        return std::make_shared<RocksDBPriorityQueueSetFactory>(
            keyGroupRange,
            keyGroupPrefixBytes,
            numberOfKeyGroups,
            kvStateInformation,
            db,
            optionsContainer->getReadOptions(),
            writeBatchWrapper,
            columnFamilyOptionsFactory,
            optionsContainer->getWriteBufferManagerCapacity());
    }
    THROW_LOGIC_EXCEPTION("Unsupported priority queue state type: " + static_cast<int>(priorityQueueStateType_));
}