/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "gtest/gtest.h"
#include <mockcpp/mokc.h>
#include <mockcpp/mockcpp.hpp>
#define private public
#define protected public
#include "base_config.h"
#include "cfg_field.h"
#include "env_config.h"
#include "env_func.h"
#include "coll_service_default_impl.h"
#include "communicator_impl.h"
#include "virtual_topo.h"
#include "json_parser.h"
#include "internal_exception.h"
#include "dev_ub_connection.h"
#include "virtual_topo_stub.h"
#include "coll_alg_component_builder.h"
#include "local_rma_buffer.h"
#include "local_ub_rma_buffer.h"
#include "local_ipc_rma_buffer.h"
#include "not_support_exception.h"
#include "rdma_handle_manager.h"
#undef private
#undef protected
#include <hccl/hccl_types.h>
#include "coll_service_base.h"
#include "ccu_ctx_mgr.h"
#include "ccu_transport_manager.h"
#include "hccl.h"

using namespace Hccl;
using namespace std;

// Stub signature for CCU profiling-info retrieval. Only declared here; presumably defined in another
// test translation unit (it is not referenced anywhere in the code visible in this file section).
HcclResult GetProfilingInfoStub(
    s32 deviceLogicId, CcuTaskArg &ccuTaskArg, const uint64_t executorId, std::vector<std::vector<CcuProfilingInfo>> &ccuProfilingInfo);

class CollServiceDefaultImplTest : public testing::Test {
protected:
    static void SetUpTestCase()
    {
        std::cout << "CollServiceDefaultImpl SetUP" << std::endl;
    }

    static void TearDownTestCase()
    {
        std::cout << "CollServiceDefaultImpl TearDown" << std::endl;
    }

    virtual void SetUp()
    {
        std::cout << "A Test case in CollServiceDefaultImpl SetUP" << std::endl;
    }

    virtual void TearDown()
    {
        std::cout << "A Test case in CollServiceDefaultImpl TearDown" << std::endl;
        GlobalMockObject::verify();
    }
};
// Scratch file names, relative to the current working directory, used by the Gen*/Del* helpers.
// NOTE(review): filePath and ranktable4pPath name the same file, so generating one overwrites the other.
constexpr char filePath[] = "ranktable.json";
constexpr char ranktable4pPath[] = "ranktable.json";
constexpr char topoPath[] = "topo.json";

// Version "1.0" rank table describing one server (server_id "1") with eight devices, where
// device_id == rank_id for 0..7. In the code visible here it is read only by GenRankTableFile1Ser8Dev().
const std::string RankTable1Ser8Dev = R"(
    {
    "server_count":"1",
    "server_list":
    [
        {
            "device":[
                        {
                        "device_id":"0",
                        "rank_id":"0"
                        },
                        {
                        "device_id":"1",
                        "rank_id":"1"
                        },
                        {
                        "device_id":"2",
                        "rank_id":"2"
                        },
                        {
                        "device_id":"3",
                        "rank_id":"3"
                        },
                        {
                        "device_id":"4",
                        "rank_id":"4"
                        },
                        {
                        "device_id":"5",
                        "rank_id":"5"
                        },
                        {
                        "device_id":"6",
                        "rank_id":"6"
                        },
                        {
                        "device_id":"7",
                        "rank_id":"7"
                        }
                    ],
            "server_id":"1"
        }
    ],
    "status":"completed",
    "version":"1.0"
    }
    )";

// Version "2.0" rank table with four ranks (rank_id/local_id 0..3). Each rank has a single level-0
// entry in "az0-rack0" with fabric_type INNER and no rank addresses. It is written to disk by
// GenRankTableFile4p() and also passed directly as the rank-table string to CommunicatorImpl::Init
// in test_base_register_offload_buf.
const std::string RankTable4p = R"(
    {
        "version": "2.0",
        "rank_count" : "4",
        "rank_list": [
            {
                "rank_id": 0,
                "local_id": 0,
                "level_list":  [
                    {
                        "level": 0,
                        "id" : "az0-rack0",
                        "fabric_type": "INNER",
                        "rank_addr_type": "",
                        "rank_addrs": []
                    }
                ]
            },
            {
                "rank_id": 1,
                "local_id": 1,
                "level_list":  [
                    {
                        "level": 0,
                        "id" : "az0-rack0",
                        "fabric_type": "INNER",
                        "rank_addr_type": "",
                        "rank_addrs": []
                    }
                ]
            },
            {
                "rank_id": 2,
                "local_id": 2,
                "level_list":  [
                    {
                        "level": 0,
                        "id" : "az0-rack0",
                        "fabric_type": "INNER",
                        "rank_addr_type": "",
                        "rank_addrs": []
                    }
                ]
            },
            {
                "rank_id": 3,
                "local_id": 3,
                "level_list":  [
                    {
                        "level": 0,
                        "id" : "az0-rack0", 
                        "fabric_type": "INNER",
                        "rank_addr_type": "",
                        "rank_addrs": []
                    }
                ]
            }
        ]
    }
    )";

// Version "2.0" topology for hardware_type "950-2D-Fullmsh_64_plus_1": four peers (ids 0..3) fully
// meshed by six level-0 UB-CTP device-to-device edges, each with its own address pair.
// It is written to disk by GenTopoFile().
// NOTE(review): despite the "1Ser8Dev" name, this describes 4 peers (matching RankTable4p), not 8 devices.
const std::string Topo1Ser8Dev = R"(
    {
        "version": "2.0",
        "hardware_type" : "950-2D-Fullmsh_64_plus_1",
        "peer_count" : 4,
        "peer_list" :[
            { "id" : 0},
            { "id" : 1},
            { "id" : 2},
            { "id" : 3}
        ],
        "edge_count" : 6,
        "edge_list": [
            {
                "level": 0,
                "protocol": "UB-CTP",
                "u_endpoint": {"type": "PEER", "id": 0, "addr": "192.168.7.100", "position" : "device"},
                "v_endpoint": {"type": "PEER", "id": 1, "addr": "192.168.7.101", "position" : "device"}
            },
            {
                "level": 0,
                "protocol": "UB-CTP",
                "u_endpoint": {"type": "PEER", "id": 0, "addr": "192.168.6.100", "position" : "device"},
                "v_endpoint": {"type": "PEER", "id": 2, "addr": "192.168.6.101", "position" : "device"}
            },
            {
                "level": 0,
                "protocol": "UB-CTP",
                "u_endpoint": {"type": "PEER", "id": 0, "addr": "192.168.5.100", "position" : "device"},
                "v_endpoint": {"type": "PEER", "id": 3, "addr": "192.168.5.101", "position" : "device"}
            },
            {
                "level": 0,
                "protocol": "UB-CTP",
                "u_endpoint": {"type": "PEER", "id": 1, "addr": "192.168.8.100", "position" : "device"},
                "v_endpoint": {"type": "PEER", "id": 2, "addr": "192.168.8.101", "position" : "device"}
            },
            {
                "level": 0,
                "protocol": "UB-CTP",
                "u_endpoint": {"type": "PEER", "id": 1, "addr": "192.168.9.100", "position" : "device"},
                "v_endpoint": {"type": "PEER", "id": 3, "addr": "192.168.9.101", "position" : "device"}
            },
            {
                "level": 0,
                "protocol": "UB-CTP",
                "u_endpoint": {"type": "PEER", "id": 2, "addr": "192.168.10.100", "position" : "device"},
                "v_endpoint": {"type": "PEER", "id": 3, "addr": "192.168.10.101", "position" : "device"}
            }
        ]
    }
    )";

/**
 * Parses RankTable1Ser8Dev with nlohmann::json and writes the re-serialized JSON to filePath.
 * The outcome is logged to stdout. Parse, open and write failures are all reported as
 * "generate failed!"; no exception escapes.
 */
static void GenRankTableFile1Ser8Dev()
{
    try {
        nlohmann::json rankTableJson = nlohmann::json::parse(RankTable1Ser8Dev);
        std::ofstream out(filePath, std::ofstream::out);
        out << rankTableJson;
        // std::ofstream signals open/write errors through its state bits rather than by throwing,
        // so the catch below would never see them. Close (flushing buffered data) and check explicitly.
        out.close();
        if (out.fail()) {
            std::cout << filePath << " generate failed!" << std::endl;
            return;
        }
    } catch (...) {
        std::cout << filePath << " generate failed!" << std::endl;
        return;
    }
    std::cout << filePath << " generated." << std::endl;
}

/**
 * Removes the file named by filePath with unlink() and logs on stdout whether it succeeded.
 */
static void DelRankTableFile()
{
    const bool removed = (unlink(filePath) != -1);
    if (removed) {
        std::cout << filePath << " deleted." << std::endl;
    } else {
        std::cout << filePath << " delete failed!" << std::endl;
    }
}

/**
 * Parses RankTable4p with nlohmann::json and writes the re-serialized JSON to ranktable4pPath.
 * The outcome is logged to stdout. Parse, open and write failures are all reported as
 * "generate failed!"; no exception escapes.
 */
static void GenRankTableFile4p()
{
    try {
        nlohmann::json rankTableJson = nlohmann::json::parse(RankTable4p);
        std::ofstream out(ranktable4pPath, std::ofstream::out);
        out << rankTableJson;
        // std::ofstream signals open/write errors through its state bits rather than by throwing,
        // so the catch below would never see them. Close (flushing buffered data) and check explicitly.
        out.close();
        if (out.fail()) {
            std::cout << ranktable4pPath << " generate failed!" << std::endl;
            return;
        }
    } catch (...) {
        std::cout << ranktable4pPath << " generate failed!" << std::endl;
        return;
    }
    std::cout << ranktable4pPath << " generated." << std::endl;
}

/**
 * Removes the file named by ranktable4pPath with unlink() and logs on stdout whether it succeeded.
 */
static void DelRankTableFile4p()
{
    if (unlink(ranktable4pPath) != -1) {
        std::cout << ranktable4pPath << " deleted." << std::endl;
        return;
    }
    std::cout << ranktable4pPath << " delete failed!" << std::endl;
}

/**
 * Parses Topo1Ser8Dev with nlohmann::json and writes the re-serialized JSON to topoPath.
 * The outcome is logged to stdout. Parse, open and write failures are all reported as
 * "generate failed!"; no exception escapes.
 */
static void GenTopoFile()
{
    try {
        nlohmann::json topoJson = nlohmann::json::parse(Topo1Ser8Dev);
        std::ofstream out(topoPath, std::ofstream::out);
        out << topoJson;
        // std::ofstream signals open/write errors through its state bits rather than by throwing,
        // so the catch below would never see them. Close (flushing buffered data) and check explicitly.
        out.close();
        if (out.fail()) {
            std::cout << topoPath << " generate failed!" << std::endl;
            return;
        }
    } catch (...) {
        std::cout << topoPath << " generate failed!" << std::endl;
        return;
    }
    std::cout << topoPath << " generated." << std::endl;
}

/**
 * Removes the file named by topoPath with unlink() and logs on stdout whether it succeeded.
 */
static void DelTopoFile()
{
    const int unlinkStatus = unlink(topoPath);
    const char *outcome = (unlinkStatus == -1) ? " delete failed!" : " deleted.";
    std::cout << topoPath << outcome << std::endl;
}

// Initializes a 4-rank DEV_TYPE_950 CommunicatorImpl from RankTable4p on a stubbed device runtime.
// The stubs make HrtMalloc return nullptr, stub out the IPC/memory helpers and
// InitCollService/InitNotifyFixedValue, and make LocalRmaBufManager::Reg return a fixed
// LocalIpcRmaBuffer. The test then checks that RegisterOpBufToBufMgr accepts an operator whose
// input/output/scratch buffers are all set.
TEST(CollServiceDefaultImplTest, test_base_register_offload_buf)
{
    MOCKER(HrtGetDevice).stubs().will(returnValue(0));
    MOCKER(HrtGetDevicePhyIdByIndex).stubs().will(returnValue(static_cast<DevId>(1)));
    DevType devType = DevType::DEV_TYPE_950;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));
    MOCKER(HrtIpcSetMemoryName).stubs();
    MOCKER(HrtDevMemAlignWithPage).stubs();
    MOCKER(HrtIpcDestroyMemoryName).stubs();
    void *devPtr = nullptr;
    MOCKER(HrtMalloc).stubs().with(mockcpp::any(),mockcpp::any()).will(returnValue(devPtr));
    MOCKER(HrtMemset).stubs().with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any());
    MOCKER_CPP(&CommunicatorImpl::InitNotifyFixedValue).stubs().will(ignoreReturnValue());
    // Scratch rank-table/topology files. They are deleted further down only if nothing in between
    // throws; otherwise they are left behind in the working directory.
    GenRankTableFile4p();
    GenTopoFile();

    MOCKER_CPP(&CommunicatorImpl::InitCollService).stubs().will(returnValue(HcclResult::HCCL_SUCCESS));
    MOCKER(HrtSetDevice).stubs().with(mockcpp::any()).will(ignoreReturnValue());

    // Buffer handed back by every LocalRmaBufManager::Reg call (the 5-argument overload selected below).
    std::shared_ptr<DevBuffer> devBuf = DevBuffer::Create(0x100, 0x100);
    LocalIpcRmaBuffer localRmaBuf(devBuf);
    MOCKER_CPP(&HcclOneSidedService::AddOpCounterMems).stubs().will(ignoreReturnValue());
    MOCKER_CPP(&LocalRmaBufManager::Reg,
        LocalRmaBuffer * (LocalRmaBufManager::*)(const string &, BufferType, std::shared_ptr<Buffer>, const PortData &, LinkProtocol))
        .stubs()
        .will(returnValue(dynamic_cast<LocalRmaBuffer *>(&localRmaBuf)));

    CommunicatorImpl comm;
    CommParams commParams;
    commParams.commId = "commId";
    HcclCommConfig config;
    HcclCommConfigInit(&config);  // NOTE(review): return value is not checked.
    commParams.myRank = 1;
    commParams.rankSize = 4;
    commParams.devType = DevType::DEV_TYPE_950;
    // NOTE(review): commParams.myRank is 1, but the rank graph is built for rank 0 with only peer 0
    // populated. Confirm this mismatch is intended.
    comm.rankGraph = make_unique<RankGraph>(0);
    comm.rankGraph->peers_[0] = make_shared<NetInstance::Peer>(0, 0, 0, 0);
    // Init receives the RankTable4p JSON text itself, not the path of the generated file.
    comm.Init(commParams, RankTable4p, config);
    comm.InitDataBufferManager();
    DelRankTableFile4p();
    DelTopoFile();

    CollServiceDefaultImpl service(&comm);
    CollOperator op;
    op.inputMem = DevBuffer::Create(0x100, 1);
    op.outputMem = DevBuffer::Create(0x100, 1);
    op.scratchMem = DevBuffer::Create(0x100, 1);
    op.opTag = "test_opTag";
    EXPECT_NO_THROW(service.RegisterOpBufToBufMgr(op));
}

// Registering an offload master stream under an op tag must not throw. Stream creation,
// destruction and id lookup are stubbed, so no device runtime is needed.
TEST(CollServiceDefaultImplTest, test_base_register_offload_stream)
{
    MOCKER(HrtStreamDestroy).stubs();
    // Any non-null handle will do; it is only ever returned by the HrtStreamCreateWithFlags stub.
    void *fakeStreamHandle = reinterpret_cast<void *>(0x1);
    MOCKER(HrtStreamCreateWithFlags).stubs().will(returnValue(fakeStreamHandle));
    MOCKER(HrtGetStreamId).stubs().will(returnValue(0));

    CommunicatorImpl comm;
    comm.id = "commId";
    comm.streamManager = make_unique<StreamManager>(&comm);
    CollServiceDefaultImpl service(&comm);
    std::string opTag = "test";
    // make_unique already yields a prvalue. Wrapping it in std::move only turns it into an xvalue
    // and forces a move instead of direct initialization.
    EXPECT_NO_THROW(service.RegisterOffloadMasterStream(opTag, make_unique<Stream>()));
}

// Sets PRIM_QUEUE_GEN_NAME=CcuAllReduceMesh1D and installs a 4-rank DEV_TYPE_950 algorithm component
// built on VirtualTopoStub::TopoInit91095OneTimesFour with HOSTCPU_TS acceleration. It then expects
// CommunicatorImpl::CalcCollOffloadOpRes for a 100-unit INT8 ALLREDUCE to report 16 sub-queues and a
// scratch size of 256*1024*1024 (presumably bytes).
TEST(CollServiceDefaultImplTest, test_calc_coll_offload_op_res_with_hccl_success_returned)
{
    // Overrides PRIM_QUEUE_GEN_NAME only for the duration of this test. On scope exit the previous
    // value is restored (or the variable is unset), so the override does not leak into tests that
    // run later in the same process. It is declared first so it is destroyed last, after every
    // object below.
    struct ScopedEnvOverride {
        std::string name;
        bool hadValue = false;
        std::string previous;

        ScopedEnvOverride(const char *envName, const char *value) : name(envName)
        {
            const char *old = getenv(envName);
            if (old != nullptr) {
                hadValue = true;
                previous = old;
            }
            setenv(envName, value, 1);
        }
        ~ScopedEnvOverride()
        {
            if (hadValue) {
                setenv(name.c_str(), previous.c_str(), 1);
            } else {
                unsetenv(name.c_str());
            }
        }
        ScopedEnvOverride(const ScopedEnvOverride &) = delete;
        ScopedEnvOverride &operator=(const ScopedEnvOverride &) = delete;
    } primQueueEnv("PRIM_QUEUE_GEN_NAME", "CcuAllReduceMesh1D");

    CommunicatorImpl comm;
    comm.id = "commId";
    comm.streamManager = make_unique<StreamManager>(&comm);
    CollServiceDefaultImpl service(&comm);

    OpType opType = OpType::ALLREDUCE;
    u64 dataSize = 100;
    CollOffloadOpResReq resReq1;
    resReq1.requiredSubQueNum = 2;
    resReq1.requiredScratchMemSize = 0;

    VirtualTopoStub virtTopo(0);
    string rankTable = "test";
    virtTopo.TopoInit91095OneTimesFour(rankTable);

    RankId myRank = 0;
    u32 rankSize = 4;

    CollAlgComponentBuilder collAlgComponentBuilder;
    std::shared_ptr<CollAlgComponent> collAlgComponent = collAlgComponentBuilder.SetRankGraph(&virtTopo)
                                   .SetDevType(DevType::DEV_TYPE_950)
                                   .SetMyRank(myRank)
                                   .SetRankSize(rankSize)
                                   .Build();
    comm.collAlgComponent = collAlgComponent;
    OpExecuteConfig opConfig;  // host-side expansion, used in graph mode
    opConfig.accState = AcceleratorState::HOSTCPU_TS;
    comm.opExecuteConfig = opConfig;
    EXPECT_NO_THROW(comm.CalcCollOffloadOpRes(opType, dataSize, HCCL_DATA_TYPE_INT8, resReq1));
    EXPECT_EQ(16, resReq1.requiredSubQueNum);
    EXPECT_EQ(256*1024*1024, resReq1.requiredScratchMemSize);
}

// Test double for CollAlgComponent. Both Orchestrate overloads (InsQuePtr and PrimQuePtr queues)
// ignore their arguments and report HCCL_SUCCESS. It is built for DEV_TYPE_950 with a null rank
// graph and constructor arguments 0, 1 (presumably my-rank / rank-size).
class FakeCollAlgComponent : public CollAlgComponent {
public:
    FakeCollAlgComponent() : CollAlgComponent(nullptr, DevType::DEV_TYPE_950, 0, 1)
    {}

    HcclResult Orchestrate(const CollAlgOperator & /* op */, const CollAlgParams & /* params */,
        const string & /* algName */, InsQuePtr /* queue */) override
    {
        return HcclResult::HCCL_SUCCESS;
    }

    HcclResult Orchestrate(const CollAlgOperator & /* op */, const CollAlgParams & /* params */,
        const string & /* algName */, PrimQuePtr /* queue */) override
    {
        return HcclResult::HCCL_SUCCESS;
    }
};

class FakeCollAlgComponentWithError : public CollAlgComponent {
public:
    FakeCollAlgComponentWithError() : CollAlgComponent(nullptr, DevType::DEV_TYPE_950, 0, 1)
    {}
    HcclResult Orchestrate(
        const CollAlgOperator &op, const CollAlgParams &params, const string &algName, InsQuePtr queue) override
    {
        return HcclResult::HCCL_E_INTERNAL;
    }
};

// OrchestrateWithIns on an op-based INT8 SUM ALLREDUCE must not throw when the algorithm component
// (a FakeCollAlgComponent whose InsQuePtr Orchestrate is additionally stubbed to return
// HCCL_SUCCESS) succeeds.
TEST(CollServiceDefaultImplTest, test_orchestrate_with_ins)
{
    CommunicatorImpl comm;
    comm.id = "test";
    CollServiceDefaultImpl service(&comm);

    CollOperator allReduceOp;
    allReduceOp.opMode = OpMode::OPBASE;
    allReduceOp.opType = OpType::ALLREDUCE;
    allReduceOp.reduceOp = ReduceOp::SUM;
    allReduceOp.dataType = DataType::INT8;
    allReduceOp.dataCount = 4;
    allReduceOp.root = 0;
    allReduceOp.scratchMem = DevBuffer::Create(0x100, 1);

    auto fakeAlgComponent = std::make_shared<FakeCollAlgComponent>();
    comm.collAlgComponent = fakeAlgComponent;
    MOCKER_CPP_VIRTUAL(*fakeAlgComponent,
        &CollAlgComponent::Orchestrate,
        HcclResult(CollAlgComponent::*)(
            const CollAlgOperator &op, const CollAlgParams &params, const string &algName, InsQuePtr queue))
        .stubs()
        .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any())
        .will(returnValue(HcclResult::HCCL_SUCCESS));

    EXPECT_NO_THROW(service.OrchestrateWithIns(allReduceOp));
}

// When the algorithm component's InsQuePtr Orchestrate returns HCCL_E_INTERNAL, OrchestrateWithIns
// is expected to surface the failure as an InternalException.
TEST(CollServiceDefaultImplTest, test_orchestrate_with_ins_throw_exception)
{
    CommunicatorImpl comm;
    comm.id = "test";
    CollServiceDefaultImpl service(&comm);

    CollOperator failingOp;
    failingOp.opMode = OpMode::OPBASE;
    failingOp.opType = OpType::ALLREDUCE;
    failingOp.reduceOp = ReduceOp::SUM;
    failingOp.dataType = DataType::INT8;
    failingOp.dataCount = 4;
    failingOp.root = 0;
    failingOp.scratchMem = DevBuffer::Create(0x100, 1);

    auto failingAlgComponent = std::make_shared<FakeCollAlgComponentWithError>();
    comm.collAlgComponent = failingAlgComponent;

    EXPECT_THROW(service.OrchestrateWithIns(failingOp), InternalException);
}

// AllocQNotifyForSingleQ over a queue holding a single InsLocalWaitGroup must not throw.
// QueueWaitGroupCntNotifyManager::ApplyFor is stubbed out.
TEST(CollServiceDefaultImplTest, alloc_queue_notify_for_single_queue)
{
    CommunicatorImpl comm;
    auto queueNotifyManager = std::make_unique<QueueNotifyManager>(comm);
    auto queueWaitGroupCntNotifyManager = std::make_unique<QueueWaitGroupCntNotifyManager>();
    CollServiceDefaultImpl service(&comm);
    comm.queueNotifyManager = std::move(queueNotifyManager);
    comm.queueWaitGroupCntNotifyManager = std::move(queueWaitGroupCntNotifyManager);
    MOCKER_CPP(&QueueWaitGroupCntNotifyManager::ApplyFor).stubs().with();

    InsQueue insQueue;
    auto insLocalWaitGroup = std::make_unique<InsLocalWaitGroup>(0);
    insQueue.Append(std::move(insLocalWaitGroup));
    // EXPECT_NO_THROW makes the expected outcome explicit. It also stops an exception from escaping
    // the test body and skipping the mock reset below, which would leak the ApplyFor stub into
    // later tests.
    EXPECT_NO_THROW(service.AllocQNotifyForSingleQ(insQueue));

    GlobalMockObject::reset();
}

// AllocQNotifyForSingleQ over a queue holding a single InsLocalBcastPost must not throw.
// QueueBcastPostCntNotifyManager::ApplyFor is stubbed out.
TEST(CollServiceDefaultImplTest, alloc_cnt_notify_for_single_queue)
{
    CommunicatorImpl comm;
    auto queueNotifyManager = std::make_unique<QueueNotifyManager>(comm);
    auto queueBcastPostCntNotifyManager = std::make_unique<QueueBcastPostCntNotifyManager>();
    CollServiceDefaultImpl service(&comm);
    comm.queueNotifyManager = std::move(queueNotifyManager);
    comm.queueBcastPostCntNotifyManager = std::move(queueBcastPostCntNotifyManager);
    MOCKER_CPP(&QueueBcastPostCntNotifyManager::ApplyFor).stubs().with();

    InsQueue insQueue;
    auto insLocalBcastPost = std::make_unique<InsLocalBcastPost>(0);
    insQueue.Append(std::move(insLocalBcastPost));
    // EXPECT_NO_THROW makes the expected outcome explicit. It also stops an exception from escaping
    // the test body and skipping the mock reset below, which would leak the ApplyFor stub into
    // later tests.
    EXPECT_NO_THROW(service.AllocQNotifyForSingleQ(insQueue));

    GlobalMockObject::reset();
}

// With the notify/device runtime calls stubbed, the communicator must return the
// HostDeviceSyncNotifyManager instance it holds, and GetIdIndex() must be 0 for a
// default-constructed communicator.
TEST(CollServiceDefaultImplTest, test_communicator_impl_hostDeviceSyncNotifyManager_GetIdIndex)
{
    void *fakeNotifyHandle = reinterpret_cast<void *>(1);
    MOCKER(HrtGetDevice).stubs().will(returnValue(0));
    MOCKER(HrtNotifyCreate).stubs().will(returnValue(fakeNotifyHandle));
    MOCKER(HrtNotifyCreateWithFlag).stubs().will(returnValue(fakeNotifyHandle));
    MOCKER(HrtGetNotifyID).stubs().will(returnValue(1));
    MOCKER(HrtGetDevicePhyIdByIndex).stubs().will(returnValue(static_cast<DevId>(1)));

    CommunicatorImpl comm;
    comm.hostDeviceSyncNotifyManager = std::make_unique<HostDeviceSyncNotifyManager>();

    auto *heldManager = comm.hostDeviceSyncNotifyManager.get();
    EXPECT_EQ(heldManager, &comm.GetHostDeviceSyncNotifyManager());

    EXPECT_EQ(0, comm.GetIdIndex());
}

// Drives CollServiceDefaultImpl::LoadWithOpBasedMode for an op-based INT8 SEND through a heavily
// stubbed pipeline and expects it not to throw. The stubs cover algorithm orchestration, prim
// translation, connection building, buffer/stream registration, socket creation, DFX mirroring
// and the algo env config.
// NOTE(review): the test name says "orchestrate_with_ins", but the call under test is LoadWithOpBasedMode.
TEST(CollServiceDefaultImplTest, coll_service_default_impl_orchestrate_with_ins_success)
{
    CommunicatorImpl comm;
    comm.id = "test";
    comm.rmaConnectionManager = std::make_unique<RmaConnManager>(comm);
    comm.connLocalNotifyManager = std::make_unique<ConnLocalNotifyManager>(&comm);
    comm.connLocalCntNotifyManager = std::make_unique<ConnLocalCntNotifyManager>(&comm);
    comm.streamManager = make_unique<StreamManager>(&comm);
    comm.streamManager->opbase = make_unique<OpbaseStreamManager>(&comm);
    comm.socketManager = std::make_unique<SocketManager>(comm, 1, 1, 1);
    comm.trace = std::make_unique<Trace>();
    comm.memTransportManager = make_unique<MemTransportManager>(comm);
    comm.cclBuffer = DevBuffer::Create(0x100, 200);
    CollOpParams collOpParams;
    collOpParams.opType = OpType::SEND;
    collOpParams.dataType = DataType::INT8;  // sizeof(int8) = 1
    collOpParams.reduceOp = ReduceOp::SUM;
    collOpParams.sendBuf = nullptr;
    collOpParams.recvBuf = nullptr;
    collOpParams.count = 10;
    collOpParams.root = 0;
    collOpParams.staticAddr = true;
    collOpParams.staticShape = true;
    collOpParams.outputDataType = DataType::INT8;
    collOpParams.debugCase = 1;
    collOpParams.dstRank = 0;
    std::string name = "test";
    comm.CovertToCurrentCollOperator(name, collOpParams, OpMode::OPBASE);

    CollServiceDefaultImpl service(&comm);
    MOCKER_CPP(&Interpreter::Submit).stubs();
    std::shared_ptr<FakeCollAlgComponent> collAlgComponent = std::make_shared<FakeCollAlgComponent>();
    comm.collAlgComponent = collAlgComponent;
    MOCKER_CPP_VIRTUAL(*collAlgComponent,
    &CollAlgComponent::Orchestrate,
    HcclResult(CollAlgComponent::*)(
        const CollAlgOperator &op, const CollAlgParams &params, const string &algName, InsQuePtr queue))
    .stubs()
    .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any())
    .will(returnValue(HcclResult::HCCL_SUCCESS));
    service.connectionsBuilders[comm.id] = std::make_unique<ConnectionsBuilder>(comm);

    DevType devType = DevType::DEV_TYPE_910A;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));
    MOCKER_CPP(&CollServiceDefaultImpl::RegisterOpBufToBufMgr).stubs();
    MOCKER_CPP(&CollServiceDefaultImpl::RegisterOpbasedStream).stubs();
    MOCKER_CPP(&SocketManager::BatchCreateSockets).stubs();
    // NOTE(review): SaveMirrorDfxOpInfo is stubbed again further down with explicit
    // .with/.will clauses; one of the two registrations is probably redundant.
    MOCKER_CPP(&CollServiceBase::SaveMirrorDfxOpInfo).stubs();

    // No links to connect and an empty translated instruction queue.
    vector<LinkData> links;
    MOCKER_CPP(&InsQueue::GetUniqueLinks).stubs().will(returnValue(links));

    shared_ptr<InsQueue> insQueue = make_shared<InsQueue>();
    MOCKER_CPP(&PrimTranslator::Translate).stubs().will(returnValue(insQueue));

    MOCKER_CPP(&ConnectionsBuilder::BatchBuild).stubs();

    CollOperator op;
    op.opMode = OpMode::OPBASE;
    op.inputMem = DevBuffer::Create(0x100, 1);
    op.outputMem = DevBuffer::Create(0x100, 1);
    op.scratchMem = DevBuffer::Create(0x100, 1);
    MOCKER(HrtGetStreamId).stubs().with(mockcpp::any()).will(returnValue(0));
    auto stream = std::make_unique<Stream>(nullptr);
    comm.streamManager->opbase->master = make_unique<Stream>(&comm);
    comm.currentCollOperator->opMode = OpMode::OPBASE;
    comm.rankGraph = std::make_unique<RankGraph>(comm.GetMyRank());
    comm.rankGraph->peers_[comm.GetMyRank()] = std::make_shared<NetInstance::Peer>(comm.GetMyRank(), 0, 0, 0);

    // Algo config returned by the EnvConfig::GetAlgoConfig stub. HCCL_BUFFSIZE defaults to 200 and
    // the parse callback multiplies the value by 1024 * 1024.
    EnvAlgoConfig algoCfg;
    algoCfg.bufferSize =
        CfgField<u64>("HCCL_BUFFSIZE", 200 * 1024 * 1024, Str2T<u64>, CHK_RANGE_CLOSED<u64>(1, ULLONG_MAX), [](u64 &i) {
            i *= 1024 * 1024;
        });
    algoCfg.bufferSize.isParsed = true;
    MOCKER_CPP(&EnvConfig::GetAlgoConfig).stubs().will(returnValue(algoCfg));
    MOCKER_CPP(&CollServiceBase::SaveMirrorDfxOpInfo).stubs().with(mockcpp::any()).will(ignoreReturnValue());
    EXPECT_NO_THROW(service.LoadWithOpBasedMode(op, std::move(stream)));
}

// Converts a SEND CollOpParams (static address/shape, send/recv pointing at a local u32) into the
// communicator's current CollOperator and expects the conversion not to throw.
// NOTE(review): despite the name, no RmaConnManager/DTO API is exercised or checked here.
TEST(CollServiceDefaultImplTest, coll_service_default_RmaConnManager_Get_allDTO_return_empty)
{
    CommunicatorImpl comm;

    CollOpParams collOpParams;
    collOpParams.opType = OpType::SEND;
    collOpParams.dataType = DataType::INT8;  // sizeof(int8) = 1
    collOpParams.reduceOp = ReduceOp::SUM;
    collOpParams.dstRank = 1;
    u32 buffer = 10;
    collOpParams.sendBuf = static_cast<void *>(&buffer);
    collOpParams.recvBuf = static_cast<void *>(&buffer);
    collOpParams.count = 10;
    collOpParams.root = 0;
    collOpParams.staticAddr = true;
    collOpParams.staticShape = true;

    comm.cclBuffer = make_shared<DevBuffer>(10);
    string tag = "optag";
    EXPECT_NO_THROW(comm.CovertToCurrentCollOperator(tag, collOpParams, OpMode::OPBASE));
}

// Exercises CollServiceDefaultImpl::UpdateUbCiIfNeed on two paths, expecting neither to throw:
// first with no updating event (connection CIs are saved), then with an event whose wait status
// is reported as COMPLETE (connection CIs are updated). The UB-CI manager calls are stubbed.
TEST(CollServiceDefaultImplTest, col_service_default_impl_update_ub_ci_if_need_success)
{
    DevType devType = DevType::DEV_TYPE_910A;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));
    CommunicatorImpl comm;
    CollOpParams collOpParams;
    collOpParams.opType      = OpType::SEND;
    collOpParams.dataType    = DataType::INT8; // sizeof(int8) = 1
    collOpParams.reduceOp    = ReduceOp::SUM;
    collOpParams.sendBuf     = nullptr;
    collOpParams.recvBuf     = nullptr;
    collOpParams.count       = 10;
    collOpParams.root        = 0;
    collOpParams.staticAddr  = true;
    collOpParams.staticShape = true;
    collOpParams.outputDataType = DataType::INT8;
    collOpParams.debugCase = 1;
    collOpParams.dstRank = 0;
    std::string name = "test";
    comm.cclBuffer = DevBuffer::Create(0x100, 200);
    comm.CovertToCurrentCollOperator(name, collOpParams, OpMode::OPBASE);
    comm.id = "test";
    comm.rmaConnectionManager = std::make_unique<RmaConnManager>(comm);
    comm.connLocalNotifyManager = std::make_unique<ConnLocalNotifyManager>(&comm);
    comm.connLocalCntNotifyManager = std::make_unique<ConnLocalCntNotifyManager>(&comm);
    comm.streamManager = make_unique<StreamManager>(&comm);
    comm.streamManager->opbase = make_unique<OpbaseStreamManager>(&comm);
    // The handle must be initialized: passing an indeterminate pointer value to Stream is
    // undefined behaviour.
    void *masterHandle = nullptr;
    unique_ptr<Stream> master = make_unique<Stream>(masterHandle);
    comm.streamManager->opbase->master = std::move(master);
    comm.socketManager = std::make_unique<SocketManager>(comm, 1, 1, 1);
    comm.memTransportManager = make_unique<MemTransportManager>(comm);

    CollServiceDefaultImpl service(&comm);

    // Path 1: no updating event yet. IfNeedUpdatingUbCi reports true and SaveConnsCi is stubbed.
    service.updatingUbCiEvent = nullptr;
    MOCKER(IfNeedUpdatingUbCi).stubs().will(returnValue(true));
    MOCKER_CPP(&UbCiUpdaterManager::SaveConnsCi).stubs().with(mockcpp::any()).will(ignoreReturnValue());
    CollOperator op;
    op.opTag = "test";
    EXPECT_NO_THROW(service.UpdateUbCiIfNeed(op.opTag));

    // Path 2: an event exists and the query writes ACL_EVENT_WAIT_STATUS_COMPLETE through its
    // out-parameter. UpdateConnsCi is stubbed.
    service.updatingUbCiEvent = make_unique<MaskEvent>();
    aclrtEventWaitStatus status = ACL_EVENT_WAIT_STATUS_COMPLETE;
    MOCKER(aclrtQueryEventWaitStatus).stubs().with(mockcpp::any(), outBoundP(&status, sizeof(status))).will(returnValue(ACL_SUCCESS));
    MOCKER_CPP(&UbCiUpdaterManager::UpdateConnsCi).stubs().with(mockcpp::any());
    EXPECT_NO_THROW(service.UpdateUbCiIfNeed(op.opTag));
}

// AddCountTask(true) must not throw on an op-based communicator whose master stream exists and
// whose runtime calls are all stubbed (stream create/destroy/id, device type, reduce, memcpy,
// malloc, memset).
TEST(CollServiceDefaultImplTest, AddCountTask)
{
    // Presumably clears mocks still registered by earlier TEST()s, which do not run the fixture's
    // TearDown, before this test installs its own stubs.
    GlobalMockObject::verify();

    void *nullHandle = nullptr;
    DevType fakeDevType = DevType::DEV_TYPE_910A;

    MOCKER(HrtStreamDestroy).stubs().with(mockcpp::any());
    MOCKER(HrtStreamCreateWithFlags).stubs().with(mockcpp::any(), mockcpp::any()).will(returnValue(nullHandle));
    MOCKER(HrtGetStreamId).stubs().with(mockcpp::any()).will(returnValue(0));
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(fakeDevType));
    MOCKER(HrtReduceAsync).stubs().with(mockcpp::any());
    MOCKER(HrtMemcpy).stubs().with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any());
    MOCKER(HrtMalloc).stubs().with(mockcpp::any(), mockcpp::any()).will(returnValue(nullHandle));
    MOCKER(HrtMemset).stubs().with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any());

    CommunicatorImpl comm;
    comm.streamManager = make_unique<StreamManager>(&comm);
    comm.streamManager->opbase = make_unique<OpbaseStreamManager>(&comm);
    comm.streamManager->opbase->master = make_unique<Stream>(&comm);
    comm.currentCollOperator = make_unique<CollOperator>();
    comm.currentCollOperator->opMode = OpMode::OPBASE;

    CollServiceDefaultImpl countService(&comm);
    EXPECT_NO_THROW(countService.AddCountTask(true));
}


// RecoverTransport on a DEV_TYPE_950 device is expected to throw NotSupportException. The input is
// one UB_CTP P2P link plus a link group holding that link and a second UB_CTP link.
TEST(CollServiceDefaultImplTest, Test_RecoverTransport)
{
    std::unique_ptr<CommunicatorImpl> comm = std::make_unique<CommunicatorImpl>();
    CollServiceDefaultImpl collServiceDefaultImpl(comm.get());

    // Construct the stubbed (die, func) id pair with its exact type instead of forcing
    // make_pair's deduced template arguments.
    MOCKER_CPP(&RdmaHandleManager::GetDieAndFuncId)
        .stubs()
        .will(returnValue(std::pair<uint32_t, uint32_t>(0, 0)));
    DevType devType = DevType::DEV_TYPE_950;
    MOCKER(HrtGetDeviceType).stubs().will(returnValue(devType));

    vector<LinkData> links;
    vector<std::pair<LinkGroup, u32>> linkGroupPair;
    LinkData linkData(PortDeploymentType::P2P, LinkProtocol::UB_CTP, 0, 1, IpAddress{"10.0.0.1"}, IpAddress{"10.0.0.2"});
    links.push_back(linkData);
    LinkGroup linkGroup{};
    linkGroup.AddLink({linkData});
    LinkData otherLinkData(PortDeploymentType::P2P, LinkProtocol::UB_CTP, 1, 1, IpAddress{"10.0.0.3"}, IpAddress{"10.0.0.4"});
    linkGroup.AddLink({otherLinkData});
    linkGroupPair.push_back(make_pair(linkGroup, 0));

    EXPECT_THROW(collServiceDefaultImpl.RecoverTransport(links, linkGroupPair), NotSupportException);
}