/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include <stdio.h>
#define private public
#define protected public
#include "hcom_ops_kernel_info_store.h"
#include "hcom_ops_kernel_builder.h"
#include "hcom_graph_optimizer.h"
#undef protected
#undef private
#include "hccl_stub.h"
#include "external/ge/ge_api_types.h" // ge对内options
#include "common/ge_common/ge_types.h"
#include "hcom_executor.h"
#include "v80_rank_table.h"
#include <iostream>
#include <fstream>
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "llt_hccl_stub_ge.h"
#include "offline_build_config_parse.h"
#include "hcom_op_utils.h"

using namespace std;
using namespace hccl;

class HcomKernelBuilderTest : public testing::Test
{
protected:
    static void SetUpTestCase()
    {
        nlohmann::json rank_table =
        {
            {"status", "completed"},
            {"deploy_mode", "lab"},
            {"device_num", "4"},
            {"server_num", "2"},
            {"boardType", "0"},
            {"para_plane_location", "device"},
            {"para_plane_nic_num", "2"},
            {"para_plane_nic_name", {"eth0", "eth1"}},
            {"instance_count", "4"},
            {"device_count", "4"},
            {
                "instance_list",
                {
                    {   {"pod_name", ""}, {"rank_id", "0"}, {"server_id", "10.0.0.10"},
                        {
                            "devices", {{{"device_id", "1"}, {"device_ip", "192.168.0.12"}, {"ref_ip", "192.168.10.13"}}}
                        }
                    },
                    {   {"pod_name", ""}, {"rank_id", "1"}, {"server_id", "10.0.0.10"},
                        {
                            "devices", {{{"device_id", "0"}, {"device_ip", "192.168.1.12"}, {"ref_ip", "192.168.11.13"}}}
                        }
                    },
                    {   {"pod_name", ""}, {"rank_id", "2"}, {"server_id", "10.0.0.11"},
                        {
                            "devices", {{{"device_id", "0"}, {"device_ip", "192.168.0.14"}, {"ref_ip", "192.168.10.15"}}}
                        }
                    },
                    {   {"pod_name", ""}, {"rank_id", "3"}, {"server_id", "10.0.0.11"},
                        {
                            "devices", {{{"device_id", "1"}, {"device_ip", "192.168.1.14"}, {"ref_ip", "192.168.11.15"}}}
                        }
                    }
                }
            },
            {
                "server_list",
                {
                    {
                        {"server_id", "192.168.10.2"},
                        {
                            "para_plane_info",
                            {{
                                    {"eth1", "192.168.210.2"},
                                    {"ref_ip", "192.168.210.1"}
                                },
                                {
                                    {"eth0", "192.168.200.2"},
                                    {"ref_ip", "192.168.200.1"}
                                }
                            }
                        }

                    },
                    {
                        {"server_id", "192.168.10.3"},
                        {
                            "para_plane_info",
                            {{
                                    {"eth0", "192.168.200.3"},
                                    {"ref_ip", "192.168.200.1"}
                                },
                                {
                                    {"eth1", "192.168.210.3"},
                                    {"ref_ip", "192.168.210.1"}
                                }
                            }
                        }

                    },

                }
            }
        };
        char file_name[] = "./ut_HcomKernelBuilderTest.json";

        std::ofstream outfile(file_name, std::ios::out | std::ios::trunc | std::ios::binary);

        if (outfile.is_open())
        {
            HCCL_INFO("open %s success", file_name);
        }
        else
        {
            HCCL_INFO("open %s failed", file_name);
        }

        outfile << std::setw(4) << rank_table << std::endl;
        outfile.close();

        std::cout << "\033[36m--HcomKernelInfoTest SetUP--\033[0m" << std::endl;
    }
    static void TearDownTestCase()
    {
        char file_name[] = "./ut_HcomKernelBuilderTest.json";
        remove(file_name);
        std::cout << "\033[36m--HcomKernelInfoTest TearDown--\033[0m" << std::endl;
    }
    // Some expensive resource shared by all tests.
    virtual void SetUp()
    {
        std::cout << "A Test SetUP" << std::endl;
    }
    virtual void TearDown()
    {
        std::cout << "A Test TearDown" << std::endl;
    }
};

class NodeTest : public ge::Node {
public:
    NodeTest(){;};
    ~NodeTest(){;};
};

ge::graphStatus OfflineRankMappingOption(ge::GEThreadLocalContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    nlohmann::json group_list =
    {
        {
            {"group_name", "aa"},
            {"group_rank_list", {0, 1}}
        },
        {
            {"group_name", "off_group_rank_list"},
            {"group_rank_list", {0, 1, 2, 3, 4, 5, 6, 7}}
        }
    };
    if (optionExec == ge::OPTION_EXEC_HCOM_GROUPLIST) {
        dumpDebugValue = group_list.dump();
    } else if (optionExec == ge::OPTION_EXEC_RANK_TABLE) {
        dumpDebugValue = R"({"status": "completed","version": "1.1","node_list":[{"node_id": "0","rank_list":[
        {"rank_id": "0","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "1","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "2","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "3","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "4","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "5","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "6","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "7","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "8","item_id": "-1","rank_ip":"192.168.2.11"}]}]})";
    } else if (optionExec == "ge.socVersion") {
        dumpDebugValue = "Ascend910";
    }
    HCCL_INFO("dumpDebugValue:[%s]", dumpDebugValue.c_str());
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpRunningParam_common)
{
    struct model_feature feature;
    u32 segment_num = 10;
    std::vector<u32> segment_index;
    HcclResult ret;

    nlohmann::json rank_table =
    {
        {"status", "completed"},
        {"deploy_mode", "lab"},
        {"group_count", "1"},
        {"chip_info", "910"},
        {"board_id", "0x0000"},
        {"para_plane_nic_location", "device"},
        {"para_plane_nic_num", "1"},
        {"para_plane_nic_name", {"eth0"}},
        {
            "group_list",
            {
                {
                    {"group_name", ""},
                    {"device_num", "1"},
                    {"server_num", "1"},
                    {"instance_count", "1"},
                        {
                            "instance_list",
                            {
                                {   {"rank_id", "0"}, {"server_id", "172.17.1.120"},
                                    {
                                        "devices", {{{"device_id", "0"}, {"device_ip", "192.168.1.120"}}}
                                    }
                                }
                            }
                        },
                }
            }
        }
    };

    char file_name_t[] = "./ut_CalcOpRunningParam_common.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomOpsKernelBuilder hcomKernelInfo;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomOpsKernelBuilder::GetOriginalGraphShapeTypeFromDesc)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge::NodePtr nodeptr(new NodeTest);
    nodeptr->GetOpDesc()->SetType("");
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::INTERNAL_ERROR);

    std::string type;
    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "_super_kernel_scope", "super_kernel_scope");
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    std::vector<int64_t> workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "reduction", "sum"); 
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_RANK_SIZE");
    int64_t RANK_SIZE = 1;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    workSpaceBytes.clear();
    workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    workSpaceBytes.clear();
    workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();    

    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption).stubs().will(invoke(OfflineRankMappingOption));
    uint32_t graphId = 1;
    MOCKER_CPP(&HcomOpsKernelBuilder::GetRootGraphID)
        .stubs()
        .with(mockcpp::any(), outBound(graphId))
        .will(returnValue(HCCL_SUCCESS));
    std::string curGroup = "off_group_rank_list";
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DTYPE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SHAPE");
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(OfflineRankMappingOption));
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_DTYPE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SHAPE");

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);
    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpRunningParam_common_51)
{
    struct model_feature feature;
    u32 segment_num = 10;
    std::vector<u32> segment_index;
    HcclResult ret;

    nlohmann::json rank_table =
    {
        {"status", "completed"},
        {"deploy_mode", "lab"},
        {"group_count", "1"},
        {"chip_info", "910"},
        {"board_id", "0x0000"},
        {"para_plane_nic_location", "device"},
        {"para_plane_nic_num", "1"},
        {"para_plane_nic_name", {"eth0"}},
        {
            "group_list",
            {
                {
                    {"group_name", ""},
                    {"device_num", "1"},
                    {"server_num", "1"},
                    {"instance_count", "1"},
                        {
                            "instance_list",
                            {
                                {   {"rank_id", "0"}, {"server_id", "172.17.1.120"},
                                    {
                                        "devices", {{{"device_id", "0"}, {"device_ip", "192.168.1.120"}}}
                                    }
                                }
                            }
                        },
                }
            }
        }
    };

    char file_name_t[] = "./ut_CalcOpRunningParam_common_51.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }
    outfile.close();

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomOpsKernelBuilder hcomKernelInfo;
    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomOpsKernelBuilder::GetOriginalGraphShapeTypeFromDesc)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    DevType type610 = DevType::DEV_TYPE_310P1;
    ge::NodePtr nodeptr(new NodeTest);
    nodeptr->GetOpDesc()->SetType("");
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::INTERNAL_ERROR);

    std::string type;
    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", HCCL_WORLD_GROUP);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "_super_kernel_scope", "super_kernel_scope");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    std::vector<int64_t> workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_RANK_SIZE");
    int64_t RANK_SIZE = 1;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    workSpaceBytes.clear();
    workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    workSpaceBytes.clear();
    workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();    

    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(OfflineRankMappingOption));
    std::string curGroup = "off_group_rank_list";
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DTYPE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SHAPE");
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(OfflineRankMappingOption));
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_DTYPE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SHAPE");

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);
    GlobalMockObject::verify();
}


TEST_F(HcomKernelBuilderTest, ut_generateTask)
{
    ge::NodePtr nodeptr(new NodeTest);
    ge::Buffer tempBuffer;
    ge::RunContext runContext_dummy;
    HcomOpsKernelBuilder hcomKernelInfo;

    std::vector<domi::TaskDef> taskDefList;

    rtError_t rt_ret = RT_ERROR_NONE;
    rtStream_t stream;
    s64 streamId = 10000;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);

	std::string name = "HcomTag";
	nodeptr->GetOpDesc()->SetName(name);

    // -------------------HcomBroadcast test----------------
    std::string type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);

    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    std::string tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_FISSION");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPSIZE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPTYPE");
    s64 tempInt = 5;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "dest_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "src_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "sr_tag", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "_fission_factor", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_size", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_type", 0);
    HCCL_INFO("node[%p] run context[%p]", nodeptr.get(), &runContext_dummy);
    HCCL_INFO("----------%s", nodeptr->GetOpDesc()->GetType().c_str());
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_NEEDMAPRANK");
    ge::AttrUtils::SetBool(nodeptr->GetOpDesc(), "_need_map_rank_id", true);


    s32 ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);

    u32 result_type = taskDefList[0].type();
    u32 result_stream_id = taskDefList[0].stream_id();
    std::string result_hccl_hccl_type = taskDefList[0].mutable_kernel_hccl()->hccl_type();
    std::string result_private_def = taskDefList[0].private_def();
    char private_def_buf[sizeof(HCCL_KERNEL_INFO_PRIVATE_DEF)];
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    HCCL_KERNEL_INFO_PRIVATE_DEF *privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    std::string result_group = reinterpret_cast<const char*>(privateDefBuf->group);
    u32 result_srcRank = (privateDefBuf->srcRank);
    u32 result_destRank = (privateDefBuf->destRank);
    u32 result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_group, tempStr);
    EXPECT_EQ(result_srcRank, 0);
    EXPECT_EQ(result_destRank, 0);
    EXPECT_EQ(result_srTag, 0);


    // -------------------HcomSend test----------------
    MOCKER(HcomGetRankId)
    .expects(atMost(8))
    .will(returnValue(HCCL_SUCCESS));
    type = HCCL_KERNEL_OP_TYPE_SEND;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    result_type = taskDefList[1].type();
    result_stream_id = taskDefList[1].stream_id();
    result_hccl_hccl_type = taskDefList[1].mutable_kernel_hccl()->hccl_type();
    result_private_def = taskDefList[1].private_def();
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    result_group = reinterpret_cast<const char*>(privateDefBuf->group);
    result_srcRank = (privateDefBuf->srcRank);
    result_destRank = (privateDefBuf->destRank);
    result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_group, tempStr);
    std::string tmpTag = result_group+"5"+"0"+"5";
    EXPECT_EQ(result_srcRank, 0);
    EXPECT_EQ(result_destRank, tempInt);
    EXPECT_EQ(result_srTag, tempInt);

    // Send: 未设定 destRank 时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_DESTRANK");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");

    // Send: 未设定 srTag 时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRTAG");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");

    // -------------------HcomReceive test----------------
    type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    result_type = taskDefList[2].type();
    result_stream_id = taskDefList[2].stream_id();
    result_hccl_hccl_type = taskDefList[2].mutable_kernel_hccl()->hccl_type();
    result_private_def = taskDefList[2].private_def();
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    result_group = reinterpret_cast<const char*>(privateDefBuf->group);
    result_srcRank = (privateDefBuf->srcRank);
    result_destRank = (privateDefBuf->destRank);
    result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_group, tempStr);
    tmpTag = result_group+"5"+"5"+"0";
    EXPECT_EQ(result_srcRank, tempInt);
    EXPECT_EQ(result_destRank, 0);
    EXPECT_EQ(result_srTag, tempInt);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");

    // Receive: srcRank未设定时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRCRANK");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");

    // Receive: srTag未设定时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRTAG");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");

    // -------------------HcomAllReduce test----------------
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");

    // -------------------incalid type test----------------
    type = " ";
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
}

TEST_F(HcomKernelBuilderTest, ut_GenerateTask_unknown)
{
    ge::Status ret;
    //HcomOpsKernelInfoStore  hcomOpsKernelInfoStore_;
    map<string, string> options;
    HcomOpsKernelBuilder hcomOpsKernelInfoStore_;
    ret = hcomOpsKernelInfoStore_.Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);
    bool is_unknown = true;
    ge::NodePtr nodeptr(new NodeTest);
    ge::RunContext runContext;
    std::vector<domi::TaskDef> taskDefList;

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(is_unknown))
    .will(returnValue(ge::GRAPH_SUCCESS));
    ret = hcomOpsKernelInfoStore_.GenerateTask(*nodeptr, runContext, taskDefList);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
    ret = hcomOpsKernelInfoStore_.Finalize();
    EXPECT_EQ(ret, ge::SUCCESS);
}

HcclResult MockGetOffDeviceTypeWithoutDev(DevType &devType)
{
    devType = DevType::DEV_TYPE_310P3;
    HCCL_DEBUG("[offline] Get devtype[%u]....", devType);
    return HCCL_SUCCESS;
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpRunningParam_unknown)
{
    ge::Status ret;
    //HcomOpsKernelInfoStore  hcomOpsKernelInfoStore_;
    HcomOpsKernelBuilder hcomOpsKernelInfoStore_;
    bool is_unknown = true;
    ge::NodePtr nodeptr(new NodeTest);
    ge::RunContext runContext;
    std::vector<domi::TaskDef> taskDefList;

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(is_unknown))
    .will(returnValue(ge::GRAPH_SUCCESS));

    ret = hcomOpsKernelInfoStore_.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(IsOfflineCompilation)
    .stubs()
    .will(returnValue(true));

    MOCKER(GetOffDeviceTypeWithoutDev)
    .stubs()
    .will(invoke(MockGetOffDeviceTypeWithoutDev));
    hcomOpsKernelInfoStore_.CalcOpRunningParam(*nodeptr);
    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpRunningParam_V51)
{
    struct model_feature feature;
    u32 segment_num = 10;
    std::vector<u32> segment_index;
    HcclResult ret;

    nlohmann::json rank_table =
    {
        {"status", "completed"},
        {"deploy_mode", "lab"},
        {"group_count", "1"},
        {"chip_info", "310P3"},
        {"board_id", "0x2000"},
        {"para_plane_nic_location", "device"},
        {"para_plane_nic_num", "1"},
        {"para_plane_nic_name", {"eth0"}},
        {
            "group_list",
            {
                {
                    {"group_name", ""},
                    {"device_num", "1"},
                    {"server_num", "1"},
                    {"instance_count", "1"},
                        {
                            "instance_list",
                            {
                                {   {"rank_id", "0"}, {"server_id", "172.17.1.120"},
                                    {
                                        "devices", {{{"device_id", "0"}, {"device_ip", "192.168.1.120"}}}
                                    }
                                }
                            }
                        },
                }
            }
        }
    };

    char file_name_t[] = "./ut_CalcOpRunningParam_V51.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();

    set_board_id(0x2000);

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomOpsKernelBuilder hcomKernelInfo;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    ge::NodePtr nodeptr(new NodeTest);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::INTERNAL_ERROR);

    std::string type;
    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    std::vector<int64_t> workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_RANK_SIZE");
    int64_t RANK_SIZE = 1;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    workSpaceBytes.clear();
    workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    workSpaceBytes.clear();
    workSpaceBytes = nodeptr->GetOpDesc()->GetWorkspaceBytes();    

    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DTYPE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SHAPE");
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_DTYPE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SHAPE");

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);

    set_board_id(0);
    remove(file_name_t);
    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_GenerateTask_AlltoAllV)
{
    ge::NodePtr nodeptr(new NodeTest);
    HcomOpsKernelBuilder hcomKernelInfo;

    s64 streamId = 10000;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);

    std::string name = "HcomTag";
    nodeptr->GetOpDesc()->SetName(name);

    nodeptr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_ALLTOALLV);
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", HCCL_WORLD_GROUP);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "sr_tag", 5);
    std::vector<int64_t> sendCounts;
    ge::AttrUtils::SetListInt(nodeptr->GetOpDesc(), "send_counts", sendCounts);

    ge::RunContext runContext;
    std::vector<domi::TaskDef> taskDefList;

    s32 ret = hcomKernelInfo.GenerateTask(*nodeptr, runContext, taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
}

TEST_F(HcomKernelBuilderTest, ut_GenerateTask_AlltoAllVC)
{
    ge::NodePtr nodeptr(new NodeTest);
    HcomOpsKernelBuilder hcomKernelInfo;

    s64 streamId = 10000;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);

    std::string name = "HcomTag";
    nodeptr->GetOpDesc()->SetName(name);

    nodeptr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_ALLTOALLVC);
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", HCCL_WORLD_GROUP);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "sr_tag", 5);
    std::vector<int64_t> sendCounts;
    ge::AttrUtils::SetListInt(nodeptr->GetOpDesc(), "send_counts", sendCounts);

    ge::RunContext runContext;
    std::vector<domi::TaskDef> taskDefList;

    s32 ret = hcomKernelInfo.GenerateTask(*nodeptr, runContext, taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
}

TEST_F(HcomKernelBuilderTest, ut_generateTask_by_comm_pytorch)
{
    ge::NodePtr nodeptr(new NodeTest);
    ge::Buffer tempBuffer;
    ge::RunContext runContext_dummy;
    HcomOpsKernelBuilder hcomKernelBuilder;

    std::vector<domi::TaskDef> taskDefList;

    rtError_t rt_ret = RT_ERROR_NONE;
    rtStream_t stream;
    s64 streamId = 10000;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);

	std::string name = "HcomTag";
	nodeptr->GetOpDesc()->SetName(name);

    // -------------------HcomBroadcast test----------------
    std::string type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);

    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "comm");
    int64_t hcomComm = 645678156;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "comm", hcomComm);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_FISSION");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPSIZE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPTYPE");
    s64 tempInt = 5;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "dest_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "src_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "sr_tag", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "_fission_factor", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_size", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_type", 0);
    HCCL_INFO("node[%p] run context[%p]", nodeptr.get(), &runContext_dummy);
    HCCL_INFO("----------%s", nodeptr->GetOpDesc()->GetType().c_str());
    s32 ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);

    u32 result_type = taskDefList[0].type();
    u32 result_stream_id = taskDefList[0].stream_id();
    std::string result_hccl_hccl_type = taskDefList[0].mutable_kernel_hccl()->hccl_type();
    std::string result_private_def = taskDefList[0].private_def();
    char private_def_buf[sizeof(HCCL_KERNEL_INFO_PRIVATE_DEF)];
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    HCCL_KERNEL_INFO_PRIVATE_DEF *privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    int64_t result_comm = (privateDefBuf->comm);
    u32 result_srcRank = (privateDefBuf->srcRank);
    u32 result_destRank = (privateDefBuf->destRank);
    u32 result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_comm, hcomComm);
    EXPECT_EQ(result_srcRank, 0);
    EXPECT_EQ(result_destRank, 0);
    EXPECT_EQ(result_srTag, 0);

    // -------------------HcomSend test----------------
    type = HCCL_KERNEL_OP_TYPE_SEND;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    result_type = taskDefList[1].type();
    result_stream_id = taskDefList[1].stream_id();
    result_hccl_hccl_type = taskDefList[1].mutable_kernel_hccl()->hccl_type();
    result_private_def = taskDefList[1].private_def();
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    result_comm = (privateDefBuf->comm);
    result_srcRank = (privateDefBuf->srcRank);
    result_destRank = (privateDefBuf->destRank);
    result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_comm, hcomComm);
    EXPECT_EQ(result_srcRank, 0);
    EXPECT_EQ(result_destRank, tempInt);
    EXPECT_EQ(result_srTag, tempInt);

    // Send: 未设定 destRank 时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_DESTRANK");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");

    // Send: 未设定 srTag 时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRTAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");

    // -------------------HcomReceive test----------------
    type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    result_type = taskDefList[2].type();
    result_stream_id = taskDefList[2].stream_id();
    result_hccl_hccl_type = taskDefList[2].mutable_kernel_hccl()->hccl_type();
    result_private_def = taskDefList[2].private_def();
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    result_comm = (privateDefBuf->comm);
    result_srcRank = (privateDefBuf->srcRank);
    result_destRank = (privateDefBuf->destRank);
    result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_comm, hcomComm);
    EXPECT_EQ(result_srcRank, tempInt);
    EXPECT_EQ(result_destRank, 0);
    EXPECT_EQ(result_srTag, tempInt);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");

    // Receive: srcRank未设定时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRCRANK");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");

    // Receive: srTag未设定时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRTAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");

    // -------------------HcomAllReduce test----------------
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");

    // -------------------incalid type test----------------
    type = " ";
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
}

TEST_F(HcomKernelBuilderTest, ut_generateTask_by_comm_pytorch2)
{
    ge::NodePtr nodeptr(new NodeTest);
    ge::Buffer tempBuffer;
    ge::RunContext runContext_dummy;
    HcomOpsKernelBuilder hcomKernelBuilder;

    std::vector<domi::TaskDef> taskDefList;

    rtError_t rt_ret = RT_ERROR_NONE;
    rtStream_t stream;
    s64 streamId = 10000;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);

	std::string name = "HcomTag";
	nodeptr->GetOpDesc()->SetName(name);

    // -------------------HcomBroadcast test----------------
    std::string type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);

    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "comm");
    int64_t hcomComm = 0;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "comm", hcomComm);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_FISSION");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPSIZE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPTYPE");
    s64 tempInt = 5;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "dest_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "src_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "sr_tag", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "_fission_factor", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_size", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_type", 0);
    HCCL_INFO("node[%p] run context[%p]", nodeptr.get(), &runContext_dummy);
    HCCL_INFO("----------%s", nodeptr->GetOpDesc()->GetType().c_str());
    s32 ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);

    u32 result_type = taskDefList[0].type();
    u32 result_stream_id = taskDefList[0].stream_id();
    std::string result_hccl_hccl_type = taskDefList[0].mutable_kernel_hccl()->hccl_type();
    std::string result_private_def = taskDefList[0].private_def();
    char private_def_buf[sizeof(HCCL_KERNEL_INFO_PRIVATE_DEF)];
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    HCCL_KERNEL_INFO_PRIVATE_DEF *privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    int64_t result_comm = (privateDefBuf->comm);
    u32 result_srcRank = (privateDefBuf->srcRank);
    u32 result_destRank = (privateDefBuf->destRank);
    u32 result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_comm, hcomComm);
    EXPECT_EQ(result_srcRank, 0);
    EXPECT_EQ(result_destRank, 0);
    EXPECT_EQ(result_srTag, 0);

    // -------------------HcomSend test----------------
    type = HCCL_KERNEL_OP_TYPE_SEND;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    result_type = taskDefList[1].type();
    result_stream_id = taskDefList[1].stream_id();
    result_hccl_hccl_type = taskDefList[1].mutable_kernel_hccl()->hccl_type();
    result_private_def = taskDefList[1].private_def();
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    result_comm = (privateDefBuf->comm);
    result_srcRank = (privateDefBuf->srcRank);
    result_destRank = (privateDefBuf->destRank);
    result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_comm, hcomComm);
    EXPECT_EQ(result_srcRank, 0);
    EXPECT_EQ(result_destRank, tempInt);
    EXPECT_EQ(result_srTag, tempInt);

    // Send: 未设定 destRank 时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_DESTRANK");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");

    // Send: 未设定 srTag 时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRTAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");

    // -------------------HcomReceive test----------------
    type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    result_type = taskDefList[2].type();
    result_stream_id = taskDefList[2].stream_id();
    result_hccl_hccl_type = taskDefList[2].mutable_kernel_hccl()->hccl_type();
    result_private_def = taskDefList[2].private_def();
    sal_memcpy(&private_def_buf[0],sizeof(private_def_buf),result_private_def.c_str(),sizeof(private_def_buf));
    privateDefBuf = (HCCL_KERNEL_INFO_PRIVATE_DEF *)&private_def_buf[0];
    result_comm = (privateDefBuf->comm);
    result_srcRank = (privateDefBuf->srcRank);
    result_destRank = (privateDefBuf->destRank);
    result_srTag = (privateDefBuf->srTag);
    EXPECT_EQ(result_type, RT_MODEL_TASK_HCCL);
    EXPECT_EQ(result_stream_id, streamId);
    EXPECT_EQ(result_hccl_hccl_type, type);
    EXPECT_EQ(result_comm, hcomComm);
    EXPECT_EQ(result_srcRank, tempInt);
    EXPECT_EQ(result_destRank, 0);
    EXPECT_EQ(result_srTag, tempInt);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");

    // Receive: srcRank未设定时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRCRANK");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");

    // Receive: srTag未设定时,报错
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_SRTAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");

    // -------------------HcomAllReduce test----------------
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");

    // -------------------incalid type test----------------
    type = " ";
    nodeptr->GetOpDesc()->SetType(type);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ret = hcomKernelBuilder.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::INTERNAL_ERROR);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
}

HcclResult FakeGetOffDeviceTypeWithoutDev(DevType &devType)
{
    devType = DevType::DEV_TYPE_910B;
    return HCCL_SUCCESS;
}

HcclResult stub_GetVectorFromTensor(const ge::GeTensor* tensor, std::vector<int64_t>& vector)
{
    vector.resize(4*4);
    return HCCL_SUCCESS;
}

TEST_F(HcomKernelBuilderTest, ut_CheckAlltoAllvcRank)
{
    ge::NodePtr nodeptr(new NodeTest);
    HcomOpUtils hcomKernelInfo;
    u32 alltoallvcRank = 0;
    const string sGroup = "test_group";
    const int64_t hcomComm = 0;

    MOCKER(&HcomOpUtils::GetRankId)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    ge::AttrUtils::SetInt((*(nodeptr.get())).GetOpDesc(), "rank", alltoallvcRank);
    HcclResult ret = hcomKernelInfo.CheckAlltoAllvcRank(*(nodeptr.get()), hcomComm, sGroup);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpRunningResources_OpenSource)
{
    // 测试开源版本的CalcOpRunningResources
    HcomOpsKernelBuilder kernelBuilder;
    ge::NodePtr nodeptr(new NodeTest);
    std::string sCollectiveType = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    std::string sGroup = "test_group";
    u32 streamNum = 0;
    u64 opMemSize = 0;
    u32 taskNum = 0;
    u32 aivCoreNum = 0;

    // 模拟IsUsingOpenSource返回true,使用开源版本
    MOCKER(IsUsingOpenSource)
    .expects(atMost(1))
    .with(outBound(true))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟SetHcomOpParam函数
    MOCKER(SetHcomOpParam)
    .expects(atMost(1))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟HcceCreateOpParamGraphMode函数
    OpParamGraphModePtr opParamPtr = reinterpret_cast<OpParamGraphModePtr>(0x12345678);
    MOCKER(HcceCreateOpParamGraphMode)
    .expects(atMost(1))
    .with(outBound(opParamPtr))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟SetHcclOpParam函数
    MOCKER(SetHcclOpParam)
    .expects(atMost(1))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟HcceCalcOpResOfflineGraphMode函数
    MOCKER(HcceCalcOpResOfflineGraphMode)
    .expects(atMost(1))
    .with(mockcpp::any(), outBound(&opMemSize), outBound(&streamNum), outBound(&taskNum), outBound(&aivCoreNum))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟IsOfflineCompilation返回true
    MOCKER(IsOfflineCompilation)
    .expects(atMost(1))
    .will(returnValue(true));

    // 测试CalcOpRunningResources方法
    HcclResult ret = kernelBuilder.CalcOpRunningResources(*nodeptr, sCollectiveType, sGroup, streamNum, opMemSize, taskNum, aivCoreNum);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_SetHcclOpParam)
{
    // 测试SetHcclOpParam函数
    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    // 准备测试参数
    HcomOpParam hcomOpParam;
    std::string sCollectiveType;
    OpParamGraphModePtr opParamPtr = reinterpret_cast<OpParamGraphModePtr>(0x12345678);
    std::vector<int64_t> sendCounts;
    std::vector<int64_t> sendDispls;
    std::vector<int64_t> recvCounts;
    std::vector<int64_t> recvDispls;

    // 模拟IsUsingOpenSource返回true,使用开源版本
    MOCKER(IsUsingOpenSource)
    .expects(atMost(1))
    .with(outBound(true))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟HcceCreateOpParamGraphMode函数
    MOCKER(HcceCreateOpParamGraphMode)
    .expects(atMost(1))
    .with(outBound(opParamPtr))
    .will(returnValue(HCCL_SUCCESS));

    HcclResult ret = graphOptimizer.SetHcclOpParam(*addedNodePtr0, &hcomOpParam, opParamPtr, sCollectiveType, 
                                                  sendCounts, sendDispls, recvCounts, recvDispls);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_getAlltoAllCountsDispl_across_graph)
{
    ge::NodePtr nodeptr(new NodeTest);
    HcomOpUtils graphOptimizer;
    MOCKER(&ge::OpDescUtils::GetInputConstData)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue((ge::ConstGeTensorBarePtr)nullptr));

    std::vector<int64_t> sendCounts;
    std::vector<int64_t> sendDispls;
    std::vector<int64_t> recvCounts;
    std::vector<int64_t> recvDispls;

    HcclResult ret = graphOptimizer.GetAlltoAllCountsDispl(*(nodeptr.get()), sendCounts, sendDispls, recvCounts, recvDispls);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}

HcclResult GetDeviceTypeA2Stub(const char *group, DevType &deviceType) {
    deviceType = DevType::DEV_TYPE_910B;
    return HCCL_SUCCESS;      
}

ge::graphStatus FakeGetOption2(ge::GEThreadLocalContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    nlohmann::json group_list =
    {
        {
            {"group_name", "aa"},
            {"group_rank_list", {0, 1}}
        },
        {
            {"group_name", "off_group_rank_list"},
            {"group_rank_list", {0, 1, 2, 3, 4, 5, 6, 7}}
        }
    };
    if (optionExec == ge::OPTION_EXEC_HCOM_GROUPLIST) {
        dumpDebugValue = group_list.dump();
    } else if (optionExec == ge::OPTION_EXEC_HCOM_RANK_MAPPING) {
        dumpDebugValue = R"([{"rank_id": "0","device_index": [0,0,0]},{"rank_id": "1","device_index": [0,1,1]}])";
    } else if (optionExec == ge::OPTION_EXEC_RANK_TABLE) {
        dumpDebugValue = R"({"status": "completed","version": "1.1","node_list":[{"node_id": "0","rank_list":[
        {"rank_id": "0","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "1","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "2","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "3","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "4","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "5","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "6","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "7","item_id": "0","rank_ip":"192.168.2.11"}]}]})";
    } else if (optionExec == "ge.socVersion") {
        dumpDebugValue = "Ascend910";
    }
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomKernelBuilderTest, ut_offlinebuild_calcSubStreamNum)
{
    ge::Status ret;
    HcomOpsKernelBuilder hcomOpsKernelInfoStore_;
    ge::NodePtr nodeptr(new NodeTest);
    std::string type = HCCL_KERNEL_OP_TYPE_ALLTOALLV;
    nodeptr->GetOpDesc()->SetType(type);


    std::string curGroup = "aa";
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(FakeGetOption2));

    MOCKER(&ge::AttrUtils::SetInt)
    .stubs()
    .will(returnValue(false));

    ret = hcomOpsKernelInfoStore_.HcomCalcOpRunningParam(*nodeptr);


    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    std::string nodeName = "ALL_GATHER_NO_CALCULATION";
    nodeptr->GetOpDesc()->SetType(type);
    nodeptr->GetOpDesc()->SetName(nodeName);
    ret = hcomOpsKernelInfoStore_.HcomCalcOpRunningParam(*nodeptr);
    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_offlinebuild_calcSubStreamNumAllToAllVC)
{
    HcclResult ret;
    HcomOpsKernelBuilder hcomOpsKernelInfoStore;
    ge::NodePtr nodeptr(new NodeTest);
    std::string type = HCCL_KERNEL_OP_TYPE_ALLTOALLVC;
    nodeptr->GetOpDesc()->SetType(type);

    std::string curGroup = "off_group_rank_list";
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(FakeGetOption2));

    u32 rankId = 0;
    MOCKER(HcomGetRankId)
    .stubs()
    .with(mockcpp::any(), outBound(&rankId))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    MOCKER(&HcomOpUtils::CheckAlltoAllvcRank)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomGraphOptimizer::GetOriginalGraphShapeTypeFromDesc)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_SEND_COUNT_MATRIX");
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank", 7);
    ret = hcomOpsKernelInfoStore.HcomCalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_GenerateTaskDef)
{
    ge::NodePtr nodeptr(new NodeTest);
    HcomOpsKernelBuilder hcomKernelInfo;

    s64 streamId = 10000;
    s64 Id = 1;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);
    std::string name = "HcomTag";
    nodeptr->GetOpDesc()->SetName(name);
    // std::string type = HCCL_KERNEL_OP_TYPE_REMOTE_LOOKUP;
    std::string type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    nodeptr->GetOpDesc()->SetId((s64)Id);
    HCCL_KERNEL_INFO_PRIVATE_DEF privateDefBuf;
    domi::TaskDef taskDef;

    s32 ret = hcomKernelInfo.GenerateTaskDef(*nodeptr, privateDefBuf, taskDef);
    EXPECT_EQ(ret, ge::SUCCESS);
}

ge::graphStatus TaskNumGetOption(ge::GEThreadLocalContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    nlohmann::json group_list =
    {
        {
            {"group_name", "aa"},
            {"group_rank_list", {0, 1}}
        },
        {
            {"group_name", "off_group_rank_list"},
            {"group_rank_list", {0, 1, 2, 3, 4, 5, 6, 7}}
        }
    };
    if (optionExec == ge::OPTION_EXEC_HCOM_GROUPLIST) {
        dumpDebugValue = group_list.dump();
    } else if (optionExec == ge::OPTION_EXEC_HCOM_RANK_MAPPING) {
        dumpDebugValue = R"({"status": "completed","version": "1.1","node_list":[{"node_id": "0","rank_list":[
        {"rank_id": "0","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "1","item_id": "-1","rank_ip":"192.168.2.11"}]}]})";
    } else if (optionExec == ge::OPTION_EXEC_RANK_TABLE) {
        dumpDebugValue = R"({"status": "completed","version": "1.1","node_list":[{"node_id": "0","rank_list":[
        {"rank_id": "0","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "1","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "2","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "3","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "4","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "5","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "6","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "7","item_id": "0","rank_ip":"192.168.2.11"}]}]})";
    } else if (optionExec == "ge.socVersion") {
        dumpDebugValue = "Ascend910";
    } else if (optionExec == ge::OPTION_EXEC_RANK_TABLE_FILE) {
        dumpDebugValue = "./ut_task_num_one_server_hcom_test.json";
    } else if (optionExec == "ge.offline_hccl_compile") {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpTaskNum)
{
    HcclResult ret;
    nlohmann::json rank_table = rank_table_910_2server_8rank;
    char file_name_t[] = "./ut_task_num_one_server_hcom_test.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomOpsKernelBuilder hcomKernelInfo;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    ge::NodePtr nodeptr(new NodeTest);
    int64_t RANK_SIZE = 4;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    std::string tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", tempStr);

    std::string type;
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    std::string name = HCCL_KERNEL_OP_TYPE_ALLREDUCE + "1server";
    nodeptr->GetOpDesc()->SetName(name);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(TaskNumGetOption));

    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);

    GlobalMockObject::verify();
}

ge::graphStatus OfflineRankMappingOption1(ge::GEThreadLocalContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    nlohmann::json group_list =
    {
        {
            {"group_name", "aa"},
            {"group_rank_list", {0, 1}}
        },
        {
            {"group_name", "off_group_rank_list"},
            {"group_rank_list", {0, 1, 2, 3, 4, 5, 6, 7}}
        }
    };
    if (optionExec == ge::OPTION_EXEC_HCOM_GROUPLIST) {
        dumpDebugValue = group_list.dump();
        return ge::GRAPH_SUCCESS;
    } else if (optionExec == "ge.exec.rankTable" || optionExec == "ge.offline_hccl_compile" ||
        optionExec == "ge.exec.hcomRankMapping") {
        return ge::GRAPH_FAILED;
    } else if (optionExec == "ge.exec.rankMap") {
        dumpDebugValue = R"({"rank_map":[{"logic_rank_id":1,"model_rank_id":0},{"logic_rank_id":2,"model_rank_id":1}]})";
        return ge::GRAPH_SUCCESS;
    } else if (optionExec == "ge.socVersion") {
	    dumpDebugValue = "Ascend910";
	    return ge::GRAPH_SUCCESS;
	} else if (optionExec == "ge.exec.rankTableFile") {
	    dumpDebugValue = "./ut_task_num_one_server_stream_test.json";
	    return ge::GRAPH_SUCCESS;
	}
    dumpDebugValue.push_back('1');
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomKernelBuilderTest, ut_CalcOpTaskNum_1server_stream)
{
    HcclResult ret;
    nlohmann::json rank_table = rank_table_1server_8rank;
    char file_name_t[] = "./ut_task_num_one_server_stream_test.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomOpsKernelBuilder hcomKernelInfo;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    ge::NodePtr nodeptr(new NodeTest);
    int64_t RANK_SIZE = 4;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    std::string tempStr = "aa";
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", tempStr);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(OfflineRankMappingOption1));
    std::string type;
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    std::string name = HCCL_KERNEL_OP_TYPE_ALLREDUCE + "1server";
    nodeptr->GetOpDesc()->SetName(name);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    s32 deviceNumPerServer = 8;
    s32 serverNum = 9;

    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);

    GlobalMockObject::verify();
}
TEST_F(HcomKernelBuilderTest, ut_CalcOpTaskNum_1server)
{
    HcclResult ret;
    nlohmann::json rank_table = rank_table_1server_8rank;
    char file_name_t[] = "./ut_task_num_one_server_hcom_test.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomOpsKernelBuilder hcomKernelInfo;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    ge::NodePtr nodeptr(new NodeTest);
    int64_t RANK_SIZE = 4;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    std::string tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", tempStr);

    std::string type;
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    std::string name = HCCL_KERNEL_OP_TYPE_ALLREDUCE + "1server";
    nodeptr->GetOpDesc()->SetName(name);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(TaskNumGetOption));

    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);

    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_GenerateTaskAivCoreLimit)
{
    ge::NodePtr nodeptr(new NodeTest);
    ge::Buffer tempBuffer;
    ge::RunContext runContext_dummy;
    HcomOpsKernelBuilder hcomKernelInfo;

    std::vector<domi::TaskDef> taskDefList;

    rtError_t rt_ret = RT_ERROR_NONE;
    rtStream_t stream;
    s64 streamId = 10000;
    nodeptr->GetOpDesc()->SetStreamId((s64)streamId);

	std::string name = "HcomTag";
	nodeptr->GetOpDesc()->SetName(name);

    std::string type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);

    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    std::string tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRCRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DESTRANK");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_SRTAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_FALSE_TAG");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_FISSION");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPSIZE");
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_DUMPTYPE");
    s64 tempInt = 5;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "dest_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "src_rank", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "sr_tag", tempInt);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "_fission_factor", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_size", 1);
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "global_workspace_type", 0);
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "DUMMY_SET_TRUE_NEEDMAPRANK");
    ge::AttrUtils::SetBool(nodeptr->GetOpDesc(), "_need_map_rank_id", true);

    // -------------------aiv core limit----------------
    // AivCoreLimit use default value 48
    HCCL_KERNEL_INFO_PRIVATE_DEF privateDef;
    MOCKER_CPP(&HcomOpsKernelBuilder::GenerateTaskPrivateDef).stubs().with(mockcpp::any(), spy(privateDef), mockcpp::any(), mockcpp::any()).will(returnValue(HCCL_SUCCESS));

    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "ATTR_OP_VECTORCORE_NUM_CLEAR");
    s32 ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    EXPECT_EQ(privateDef.aivCoreLimit, 48U);

    // AivCoreLimit from _op_vectorcore_num
    ge::AttrUtils::HasAttr(nodeptr->GetOpDesc(), "ATTR_OP_VECTORCORE_NUM");
    std::string aviCoreNum("4");
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "_op_vectorcore_num", aviCoreNum);
    ret = hcomKernelInfo.GenerateTask(*nodeptr,runContext_dummy,taskDefList);
    EXPECT_EQ(ret, ge::SUCCESS);
    EXPECT_EQ(privateDef.aivCoreLimit, 4U);

    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, ut_offlinebuild_calcSubStreamNumv2_When_Normal_Expect_ReturnlsHCCL_SUCCESS)
{
    ge::Status ret;
    HcomOpsKernelBuilder hcomOpsKernelInfoStore_;
    ge::NodePtr nodeptr(new NodeTest);
    std::string type = HCCL_KERNEL_OP_TYPE_ALLTOALLV;
    nodeptr->GetOpDesc()->SetType(type);


    std::string curGroup = "aa";
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(FakeGetOption2));

    MOCKER(&ge::AttrUtils::SetInt)
    .stubs()
    .will(returnValue(false));

#ifdef MACRO_DEV_TYPE_NEW
    MOCKER(HcomGetDeviceType).stubs().with(mockcpp::any()).will(returnValue(DevType::DEV_TYPE_950));
#else
    MOCKER(HcomGetDeviceType).stubs().with(mockcpp::any()).will(returnValue(DevType::DEV_TYPE_910_95));
#endif
    ret = hcomOpsKernelInfoStore_.HcomCalcOpRunningParam(*nodeptr);

    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    std::string nodeName = "ALL_GATHER_NO_CALCULATION";
    nodeptr->GetOpDesc()->SetType(type);
    nodeptr->GetOpDesc()->SetName(nodeName);
    ret = hcomOpsKernelInfoStore_.HcomCalcOpRunningParam(*nodeptr);
    GlobalMockObject::verify();
}

TEST_F(HcomKernelBuilderTest, Ut_GetCrackParamsInfo_When_AllParamsValid_Expect_CorrectCrackParams) {
    ge::Node node;
    HcomOpsKernelBuilder hcomOpsKernelInfoStore_;
    u32 tensorNum = 2;
    int64_t tensorOffset[2] = {0, 100};
    int64_t tensorSize[2] = {50, 60};
    int64_t crackOffset[2] = {0, 0};
    int64_t crackSize[2] = {0, 0};

    HcclResult result = hcomOpsKernelInfoStore_.GetCrackParamsInfo(node, tensorNum, tensorOffset, tensorSize, crackOffset, crackSize);
    EXPECT_EQ(result, HCCL_SUCCESS);

    ge::NodePtr nodeptr(new NodeTest);
    std::string type = HCCL_KERNEL_OP_TYPE_ALLTOALLV;
    nodeptr->GetOpDesc()->SetType(type);
    std::string curGroup = "aa";
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);
    MOCKER_CPP(&HcomOpUtils::GetAllTensorSize).stubs().will(returnValue(HCCL_SUCCESS));
    result = hcomOpsKernelInfoStore_.GetTensorParamsInfo(*nodeptr, tensorNum, tensorOffset, tensorSize);
    EXPECT_EQ(result, HCCL_SUCCESS);
}

TEST_F(HcomKernelBuilderTest, Ut_SetPrivateDefWithTensorInfo) {
    ge::NodePtr nodeptr(new NodeTest);
    std::string type = HCCL_KERNEL_OP_TYPE_ALLTOALLV;
    nodeptr->GetOpDesc()->SetType(type);
    std::string curGroup = "aa";
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);
    HCCL_KERNEL_INFO_PRIVATE_DEF privateDefBuf;
    privateDefBuf.tensorNum = 2;
    domi::TaskDef taskDef;

    MOCKER_CPP(&HcomOpsKernelBuilder::GetTensorParamsInfo).stubs().will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&HcomOpsKernelBuilder::GetCrackParamsInfo).stubs().will(returnValue(HCCL_SUCCESS));

    HcomOpsKernelBuilder hcomOpsKernelInfoStore_;
    HcclResult result = hcomOpsKernelInfoStore_.SetPrivateDefWithTensorInfo(*nodeptr, privateDefBuf, taskDef);
    EXPECT_EQ(result, HCCL_SUCCESS);
}