/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef THREAD_MANAGE_H
#define THREAD_MANAGE_H

#include <condition_variable>
#include "threadManage.h"
#include "dispatcher.h"
#include "workflow_pub.h"
#include "externalinput_pub.h"
#include "common.h"
#include "template_v1_utils.h"

namespace hccl {
/**
 * @brief Selects which executor a ThreadManage worker runs for a prepared operation.
 *
 * Values are spelled out explicitly. They equal the implicit 0-based numbering of
 * the declaration order, so existing numeric uses (logging, casts) are unaffected.
 * TYPE_RESERVED is the "not selected" sentinel and the last entry.
 */
enum class ExecutorType {
    REDUCE_SCATTER_RING = 0,
    ALLGATHER_RING = 1,
    REDUCE_SCATTER_RING_DIRECT = 2,
    REDUCE_SCATTER_RING_DIRECT_RDMA = 3,
    ALLGATHER_RING_DIRECT = 4,
    ALLGATHER_RING_DIRECT_RDMA = 5,
    TYPE_RESERVED = 6  // sentinel: no executor chosen
};

/**
 * @brief Owns one worker thread handle (ringThread_) plus the operation parameters it executes with.
 *
 * The member functions are defined outside this header. The lifecycle below is inferred
 * from their names; confirm it against the implementation:
 *   - Init() presumably creates the worker thread.
 *   - Prepare() records the arguments of one operation into the members below.
 *   - NotifyStart()/WaitDone() presumably form the caller side of a start/done handshake.
 *   - WaitStart()/NotifyDone() presumably form the worker side.
 *   - Finalize() presumably stops the worker.
 *
 * Each handshake direction has its own mutex, condition variable and bool flag
 * (startMtx_/startCv_/startReady and doneMtx_/doneCv_/doneReady).
 *
 * The class is neither copyable nor movable. It holds std::mutex and std::condition_variable
 * members and a thread handle, so these operations were already implicitly unavailable;
 * they are deleted explicitly here to document that intent.
 */
class ThreadManage {
public:
    /**
     * @param deviceLogicId logical device id, stored in deviceLogicId_.
     * @param userRank      rank of this user, stored in userRank_.
     * @param dispatcher    dispatcher handle, stored in dispatcher_ (const member).
     */
    explicit ThreadManage(s32 deviceLogicId, u32 userRank, const HcclDispatcher dispatcher);

    ~ThreadManage();

    ThreadManage(const ThreadManage &) = delete;
    ThreadManage &operator=(const ThreadManage &) = delete;
    ThreadManage(ThreadManage &&) = delete;
    ThreadManage &operator=(ThreadManage &&) = delete;

    /// Presumably starts the worker thread; see the .cpp for exact behavior.
    HcclResult Init();

    /**
     * @brief Records the arguments of one operation for the worker to execute.
     *
     * Each argument presumably lands in the member of the same name with a trailing
     * underscore (e.g. count -> count_, type -> executorType_). Confirm in the .cpp.
     *
     * @param type     executor to run; see ExecutorType.
     * @param reduceAttr defaults to 0.
     * @param opInfo   optional and defaults to nullptr. If the pointer is stored (opInfo_ is a
     *                 raw non-owning pointer), the pointee must outlive the operation.
     *
     * The trailing vector parameters default to empty.
     */
    HcclResult Prepare(DeviceMem &inputMem, DeviceMem &outputMem, DeviceMem &scratchMem, const u64 count,
                       const HcclDataType dataType, const Stream &stream, const HcclReduceOp reductionOp,
                       const u32 root, const std::vector<Slice> &slices, const u64 baseOffset,
                       std::vector<u32> nicRankList, const std::string &tag,
                       s32 profStage, const SubCommInfo &ringSubCommInfo, std::shared_ptr<LocalNotify> &signalAux,
                       std::shared_ptr<LocalNotify> &signalMain, u32 ringIndex,
                       ExecutorType type, u64 reduceAttr = 0, const HcomCollOpInfo *opInfo = nullptr,
                       std::vector<Stream> subStreamsInOneRing = {},
                       std::vector<std::shared_ptr<LocalNotify>> mainSignalsInOneRing = {},
                       std::vector<std::shared_ptr<LocalNotify>> subSignalsInOneRing = {},
                       std::vector<u32> ringsOrder = {},
                       std::vector<Slice> userMemInputSlices = {});

    /// Presumably stops and releases the worker thread; see the .cpp.
    HcclResult Finalize();

    // Start/done handshake. The names suggest the caller uses NotifyStart/WaitDone and the
    // worker uses WaitStart/NotifyDone. NOTE(review): confirm against the implementation.
    void NotifyStart();
    void WaitStart();
    void NotifyDone();
    void WaitDone();

    /// Returns the stored thread id (threadId_, defaults to 0). NOTE(review): confirm in the .cpp.
    uint32_t GetTid();

private:
    HcclResult ThreadExecuteFn();  // presumably the worker thread's body; confirm in the .cpp
    HcclResult ExecuteService();   // presumably runs the executor selected by executorType_; confirm

    std::shared_ptr<std::thread> ringThread_;
    uint32_t threadId_ = 0;
    s32 deviceLogicId_;
    u32 userRank_;
    const HcclDispatcher dispatcher_;

    // Handshake state: one mutex/condvar/flag triple per direction.
    // NOTE(review): threadExit is presumably read/written under one of these mutexes; confirm.
    std::mutex startMtx_;
    std::mutex doneMtx_;
    std::condition_variable startCv_;
    std::condition_variable doneCv_;
    bool startReady = false;
    bool doneReady = false;
    bool threadExit = false;

    // Operation parameters recorded by Prepare() (presumably) for the worker thread.
    DeviceMem inputMem_;
    DeviceMem outputMem_;
    DeviceMem scratchMem_;
    Stream stream_;

    u64 count_ = 0;
    HcclDataType dataType_ = HCCL_DATA_TYPE_RESERVED;
    HcclReduceOp reductionOp_ = HCCL_REDUCE_RESERVED;
    u32 root_ = 0;
    std::vector<Slice> slices_;
    u64 baseOffset_ = 0;
    std::vector<u32> nicRankList_;
    std::string tag_;
    s32 profStage_ = 0;
    SubCommInfo ringSubCommInfo_;
    std::shared_ptr<LocalNotify> signalAux_ = nullptr;
    std::shared_ptr<LocalNotify> signalMain_ = nullptr;
    u32 ringIndex_  = 0;
    u64 reduceAttr_ = 0;
    // Non-owning. Initialized to nullptr so it is never indeterminate before Prepare() runs;
    // the only constructor does not take an opInfo argument.
    const HcomCollOpInfo *opInfo_ = nullptr;
    std::vector<Stream> subStreamsInOneRing_;
    std::vector<std::shared_ptr<LocalNotify>> mainSignalsInOneRing_;
    std::vector<std::shared_ptr<LocalNotify>> subSignalsInOneRing_;
    std::vector<u32> ringsOrder_;
    std::vector<Slice> userMemInputSlices_;
    ExecutorType executorType_ = ExecutorType::TYPE_RESERVED;
    // Value-initialized so it does not hold an indeterminate value (e.g. null if it is a pointer typedef).
    HcclRtContext context_{};
    HcclWorkflowMode workflowMode_{HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE};
};
}  // namespace hccl

#endif /* THREAD_MANAGE_H */