/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include <stdio.h>

#define private public
#define protected public
#include "hcom_all_reduce_fusion.h"
#include "hcom_broadcast_fusion.h"
#include "hcom_reduce_fusion.h"
#include "hcom_all_reduce_fusion.h"
#include "hcom_alltoallvc_fusion.h"
#include "hcom_allgather_fusion.h"
#include "hcom_reducescatter_fusion.h"
#include "plugin_manager.h"
#include "hcom_fusion_optimizer.h"
#include "offline_build_config_parse.h"
#include "hcom_graph_optimizer.h"
#include "hcom_op_utils.h"
#undef private
#undef protected
#include "hccl_stub.h"
#include "v80_rank_table.h"
#include "hcom_ops_kernel_info_store.h"
#include "external/ge/ge_api_types.h" // ge对内options
#include "common/ge_common/ge_types.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/utils/tensor_utils.h"
#include "graph/ge_tensor.h"

#include "evaluator.h"
#include "model.h"
#include "cluster.h"

#include <iostream>
#include <fstream>
#include "llt_hccl_stub_ge.h"

using namespace std;
using namespace hccl;

class HcomGraphOptimizerTest : public testing::Test
{
protected:
    static void SetUpTestCase()
    {
        std::cout << "\033[36m--HcomGraphOptimizerTest SetUP--\033[0m" << std::endl;
    }
    static void TearDownTestCase()
    {

        std::cout << "\033[36m--HcomGraphOptimizerTest TearDown--\033[0m" << std::endl;
    }
    // Some expensive resource shared by all tests.
    virtual void SetUp()
    {
        MOCKER_CPP(&HcomGraphOptimizer::SetSuperKernelScopeAttr).stubs().with(mockcpp::any()).will(returnValue(HCCL_SUCCESS));
        std::cout << "A Test SetUP" << std::endl;
    }
    virtual void TearDown()
    {
        std::cout << "A Test TearDown" << std::endl;
    }
};

TEST_F(HcomGraphOptimizerTest, ut_optimize_graphprepare_SetHcomOpParallelLabel)
{
    ge ::Status ret;
    ge ::Status ge_ret;
    bool bRet;
    std::map<std::string,std::string> options;
    ge::ComputeGraph graph("test_graph");
    std::map<string, GraphOptimizerPtr> graphOptimizers;
    ge::OpDescPtr opDescPtr = nullptr;

    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    // 实验室场景 hcom_init成功:成功
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge::OpDesc op;
    graph.AddNode(opDescPtr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    // const string typeEs = HCCL_KERNEL_OP_TYPE_COLL_REMOTE_LOOKUP_PAIRED;
    // op.SetType(typeEs);

    HcomGraphOptimizer optimizer;
    std::string groupLabel = "hcom_op";
    for (auto nodePtr : graph.GetAllNodes()) {
        HcclResult hRet = optimizer.SetHcomOpParallelLabel(*nodePtr, groupLabel);
        EXPECT_EQ(hRet, HCCL_SUCCESS);
    }

    bRet = ge::AttrUtils::SetInt(opDescPtr, "tag", 1);
    EXPECT_EQ(bRet, true);

    for (auto nodePtr : graph.GetAllNodes()) {
        HcclResult hRet = optimizer.SetHcomOpParallelLabel(*nodePtr, groupLabel);
        EXPECT_EQ(hRet, HCCL_SUCCESS);
    }

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_Initialize_to_Finalize)
{
    ge ::Status ret;
    std::map<std::string,std::string> options;


    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));
    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    // 实验室场景 hcom_init成功:成功
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    ge ::Status ge_ret;
    bool bRet;

    std::map<string, GraphOptimizerPtr> graphOptimizers;
    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    MOCKER_CPP(&HcomAllReduceFusion::GetFusionOps)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    EXPECT_EQ(attrs.engineName, HCCL_OPS_ENGIN);
    EXPECT_EQ(attrs.scope, ge::UNIT);

    // ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    auto descPtr1 = std::make_shared<ge::OpDesc>("Allreduce1", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr1->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr1 = graph->AddNode(descPtr1);
    EXPECT_NE(addedNodePtr1, nullptr);

    ge::ComputeGraphPtr subgraph = std::make_shared<ge::ComputeGraph>("test_subgraph");
    auto descPtr2 = std::make_shared<ge::OpDesc>("Allreduce2", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr2->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr2 = subgraph->AddNode(descPtr2);
    EXPECT_NE(addedNodePtr2, nullptr);

    auto descPtr3 = std::make_shared<ge::OpDesc>("Allreduce3", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr3->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr3 = subgraph->AddNode(descPtr3);
    EXPECT_NE(addedNodePtr3, nullptr);

    graph->AddSubGraph(subgraph);

    HCCL_INFO("start OptimizeOriginalGraph");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeOriginalGraph(*graph);
    HCCL_INFO("end OptimizeOriginalGraph");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomGraphOptimizer::SetOpWorkerSpaceForKnowShape)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ge::OpDescPtr opDescPtr = nullptr;
    ge::AttrUtils::SetStr(opDescPtr, "reduction", "sum"); 
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    
    ret = Finalize();
    EXPECT_EQ(ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_Initialize_to_Finalize_51)
{

    ge ::Status ret;
    std::map<std::string,std::string> options;


    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));
    u32 numHccsLink = 0;
    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    u32 rankSize = 2;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBoundP(&rankSize))
    .will(returnValue(HCCL_SUCCESS));

    // 实验室场景 hcom_init成功:成功
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    ge ::Status ge_ret;
    bool bRet;

    std::map<string, GraphOptimizerPtr> graphOptimizers;
    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    MOCKER_CPP(&HcomAllReduceFusion::GetFusionOps)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    EXPECT_EQ(attrs.engineName, HCCL_OPS_ENGIN);
    EXPECT_EQ(attrs.scope, ge::UNIT);

    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    auto descPtr1 = std::make_shared<ge::OpDesc>("Allreduce1", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr1->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr1 = graph->AddNode(descPtr1);
    EXPECT_NE(addedNodePtr1, nullptr);

    ge::ComputeGraphPtr subgraph = std::make_shared<ge::ComputeGraph>("test_subgraph");
    auto descPtr2 = std::make_shared<ge::OpDesc>("Allreduce2", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr2->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr2 = subgraph->AddNode(descPtr2);
    EXPECT_NE(addedNodePtr2, nullptr);

    auto descPtr3 = std::make_shared<ge::OpDesc>("Allreduce3", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr3->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr3 = subgraph->AddNode(descPtr3);
    EXPECT_NE(addedNodePtr3, nullptr);

    graph->AddSubGraph(subgraph);

    HCCL_INFO("start OptimizeOriginalGraph");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeOriginalGraph(*graph);
    HCCL_INFO("end OptimizeOriginalGraph");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    u64 streamNumber = 4;

    DevType type610 = DevType::DEV_TYPE_310P1;
    MOCKER(GetOffDeviceTypeWithoutDev)
    .stubs()
    .with(outBound(type610))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomGraphOptimizer::SetOpWorkerSpaceForKnowShape)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    
    ret = Finalize();
    EXPECT_EQ(ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

HcclResult GetOffDeviceTypeWithoutDevMockA2(DevType &devType)
{
    devType = DevType::DEV_TYPE_910B;
    HCCL_DEBUG("[offline] Get devtype[%u]....", devType);
    return HCCL_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_OptimizeFusedGraph_GetDeterministic)
{
    MOCKER(GetOffDeviceTypeWithoutDev).stubs().will(invoke(GetOffDeviceTypeWithoutDevMockA2));
    setenv("HCCL_DETERMINISTIC", "STRICT", 1);

    u8 deterministic = 0;
    HcclResult ret = GetDeterministic(deterministic);

    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(deterministic, 2);

    unsetenv("HCCL_DETERMINISTIC");
    GlobalMockObject::verify();
}

HcclResult stub_GetAllInputsTensorMemSize(const ge::OpDescPtr &opDescPtr, uint64_t &tensorSize)
{
    tensorSize = 2048 * 1024 * 1024;
    return HCCL_SUCCESS;
}

HcclResult stub_GetCCLBufferAvailableSize(u64 &size)
{
    size = 1024 * 1024;
    return HCCL_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(7);

    std::string tempStrReduction = "sum";
    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "reduction", tempStrReduction);

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);

    ge::AttrUtils::HasAttr(ops[1]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    std::string tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "group", tempStr);
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[1], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(fusionOps.size(), 1);
    EXPECT_EQ(fusionOps["hccl_world_group"].size(), 0);

    tempStr = "test_group";
    ge::AttrUtils::SetStr(ops[2]->GetOpDesc(), "group", tempStr);
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[2], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(fusionOps.size(), 3);
    EXPECT_EQ(fusionOps["test_group"].size(), 0);

    tempStr = "test_group";
    ge::AttrUtils::SetStr(ops[3]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "fusion", HCOM_ATTR_FUSION_BY_FUSION_ID);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "fusion_id", 0);
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[3], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(fusionOps.size(), 5);
    EXPECT_EQ(fusionOps["test_group"].size(), 0);

    tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(ops[4]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "fusion", 0);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "fusion_id", HCOM_ATTR_FUSION_ID_DEFAULT);
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[4], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(fusionOps.size(), 5);
    EXPECT_EQ(fusionOps["hccl_world_group"].size(), 0);

    tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(ops[5]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[5]->GetOpDesc(), "fusion", 3);
    ge::AttrUtils::SetInt(ops[5]->GetOpDesc(), "fusion_id", HCOM_ATTR_FUSION_ID_DEFAULT);
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[5], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);
    EXPECT_EQ(fusionOps.size(), 5);
    EXPECT_EQ(fusionOps["hccl_world_group"].size(), 0);

    tempStr = "test_group";
    ge::AttrUtils::SetStr(ops[6]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[6]->GetOpDesc(), "fusion", HCOM_ATTR_FUSION_BY_FUSION_ID);
    ge::AttrUtils::SetInt(ops[6]->GetOpDesc(), "fusion_id", 0);
    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[6], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ge::AttrUtils::HasAttr(ops[3]->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");

    MOCKER(GetCCLBufferAvailableSize)
    .stubs()
    .with(mockcpp::any())
    .will(invoke(stub_GetCCLBufferAvailableSize));

    MOCKER(&HcomOpUtils::GetAllInputsTensorMemSize)
    .stubs()
    .will(invoke(stub_GetAllInputsTensorMemSize));

    FusionInfos fusionOpsTemp;
    ret = fusionHcomAllReduceOp.GetFusionOpsSlices(fusionOps, fusionOpsTemp);
}

static std::vector<u32> g_mock_segments = {1, 2};
static HcclResult mock_HcomGetSplitStrategy(
    const char* name,
    const model_feature* feat,
    u32** data_out,
    u32 * size_out,
    bool* configured) {
    
    if (data_out) *data_out = g_mock_segments.data();
    if (size_out) *size_out = g_mock_segments.size();
    return HCCL_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_FuseOps)
{
    HcclResult ret;
    bool bRet;
    ge::ComputeGraph graph("test_graph");
    ge::OpDescPtr nodeGroup = nullptr;
    ge::OpDescPtr nodeFusionId = 0;
    HcomAllReduceFusion fusionHcomAllReduceOp;
    std::map<std::string, std::vector<ge::NodePtr>> fusionOps;
    u32 segmentNum = 0;
    std::vector<u32> segmentIndex;

    std::vector<ge::NodePtr> ops(3);
    std::vector<ge::NodePtr> nodeVec_0;
    std::string group = HCCL_WORLD_GROUP;
    int64_t fusionid = HCOM_ATTR_FUSION_ID_DEFAULT;
    bRet = ge::AttrUtils::SetStr(nodeGroup, "group", group);
    EXPECT_EQ(bRet, true);
    bRet = ge::AttrUtils::SetInt(nodeFusionId, "fusion_id", fusionid);
    EXPECT_EQ(bRet, true);
    nodeVec_0.push_back(ops[0]);
    nodeVec_0.push_back(ops[1]);
    nodeVec_0.push_back(ops[2]);
    fusionOps["hccl_world_group"] = nodeVec_0;

    MOCKER(HcomGetSplitStrategy)
    .stubs()
    .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any())
    .will(invoke(mock_HcomGetSplitStrategy));

    MOCKER_CPP(&HcomAllReduceFusion::RunFusionOps)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = fusionHcomAllReduceOp.FuseOps(graph, fusionOps["hccl_world_group"]);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionStrategy)
{
    HcclResult ret;
    bool bRet;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    ge::OpDescPtr nodeGroup = nullptr;
    ge::OpDescPtr nodeFusionId = 0;
    std::vector<ge::NodePtr> fusionOps(3);
    u32 segmentNum = 0;
    std::vector<u32> segmentIndex;

    u32 segment_num = 2;      // fused to 2 allreduce node
    std::vector<u32> segments;
    bool configured = false;
    segments.push_back(1);
    segments.push_back(2);
    MOCKER(HcomGetSplitStrategy)
    .stubs()
    .with(mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any(), mockcpp::any())
    .will(invoke(mock_HcomGetSplitStrategy));
    std::string group = HCCL_WORLD_GROUP;
    int64_t fusionid = HCOM_ATTR_FUSION_ID_DEFAULT;
    ge::AttrUtils::HasAttr(nodeGroup, "DUMMY_SET_TRUE_GROUP");
    bRet = ge::AttrUtils::SetStr(nodeGroup, "group", group);
    EXPECT_EQ(bRet, true);
    bRet = ge::AttrUtils::SetInt(nodeFusionId, "fusion_id", fusionid);
    EXPECT_EQ(bRet, true);
    ret = fusionHcomAllReduceOp.GetFusionStrategy(graph, fusionOps, segmentNum, segmentIndex);
    EXPECT_EQ(segmentNum, segment_num);
    EXPECT_EQ(segmentIndex[0], segments[0]);
    EXPECT_EQ(segmentIndex[1], segments[1]);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_GetNodeUnknownShapeInfo_known)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    std::vector<ge::NodePtr> fusionOps(3);
    u32 segmentNum = 0;
    std::vector<u32> segmentIndex;
    bool bUnknownShapeNodeStatus = false;

    ret = fusionHcomAllReduceOp.GetNodeUnknownShapeInfo(fusionOps[0], bUnknownShapeNodeStatus);
    EXPECT_EQ(bUnknownShapeNodeStatus, false);
    EXPECT_EQ(fusionHcomAllReduceOp.bHasUnknownShapeNodeGraph_, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ret = fusionHcomAllReduceOp.GetNodeUnknownShapeInfo(fusionOps[0], bUnknownShapeNodeStatus);
    EXPECT_EQ(bUnknownShapeNodeStatus, false);
    EXPECT_EQ(fusionHcomAllReduceOp.bHasUnknownShapeNodeGraph_, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_GetNodeUnknownShapeInfo_unknown)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    std::vector<ge::NodePtr> fusionOps(3);
    u32 segmentNum = 0;
    std::vector<u32> segmentIndex;
    bool is_unknown = true;
    bool bUnknownShapeNodeStatus = false;

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(is_unknown))
    .will(returnValue(ge::GRAPH_SUCCESS));

    ret = fusionHcomAllReduceOp.GetNodeUnknownShapeInfo(fusionOps[0], bUnknownShapeNodeStatus);
    EXPECT_EQ(bUnknownShapeNodeStatus, true);
    EXPECT_EQ(fusionHcomAllReduceOp.bHasUnknownShapeNodeGraph_, true);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ret = fusionHcomAllReduceOp.GetNodeUnknownShapeInfo(fusionOps[0], bUnknownShapeNodeStatus);
    EXPECT_EQ(bUnknownShapeNodeStatus, true);
    EXPECT_EQ(fusionHcomAllReduceOp.bHasUnknownShapeNodeGraph_, true);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

class NodeTest : public ge::Node {
public:
    NodeTest(){;};
    ~NodeTest(){;};
};

TEST_F(HcomGraphOptimizerTest, ut_RunFusionOps)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    std::vector<ge::NodePtr> fusionOps(3);
    std::vector<ge::NodePtr> InOps(3);
    std::vector<ge::NodePtr> OutOps(3);

    fusionOps[0] = std::make_shared<NodeTest>();
    fusionOps[0]->Init();
    fusionOps[1] = std::make_shared<NodeTest>();
    fusionOps[1]->Init();
    fusionOps[2] = std::make_shared<NodeTest>();
    fusionOps[2]->Init();

    InOps[0] = std::make_shared<NodeTest>();
    InOps[0]->Init();
    InOps[1] = std::make_shared<NodeTest>();
    InOps[1]->Init();
    InOps[2] = std::make_shared<NodeTest>();
    InOps[2]->Init();

    OutOps[0] = std::make_shared<NodeTest>();
    OutOps[0]->Init();
    OutOps[1] = std::make_shared<NodeTest>();
    OutOps[1]->Init();
    OutOps[2] = std::make_shared<NodeTest>();
    OutOps[2]->Init();

    InOps[0]->GetOutControlAnchor()->LinkTo(fusionOps[0]->GetInControlAnchor());
    InOps[0]->GetOutDataAnchor(0)->LinkTo(fusionOps[0]->GetInDataAnchor(0));
    InOps[1]->GetOutDataAnchor(0)->LinkTo(fusionOps[1]->GetInControlAnchor());
    InOps[2]->GetOutDataAnchor(0)->LinkTo(fusionOps[2]->GetInDataAnchor(0));

    fusionOps[0]->GetOutControlAnchor()->LinkTo(OutOps[0]->GetInControlAnchor());
    fusionOps[0]->GetOutDataAnchor(0)->LinkTo(OutOps[0]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInControlAnchor());
    fusionOps[2]->GetOutDataAnchor(0)->LinkTo(OutOps[2]->GetInDataAnchor(0));

    u32 segmentNum = 2;
    std::vector<u32> segmentIndex;
    segmentIndex.push_back(0);
    segmentIndex.push_back(1);

    ret = fusionHcomAllReduceOp.RunFusionOps(graph, fusionOps, segmentNum, segmentIndex);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}


TEST_F(HcomGraphOptimizerTest, ut_RunFusionOps_have_duplication)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    std::vector<ge::NodePtr> fusionOps(3);
    std::vector<ge::NodePtr> InOps(3);
    std::vector<ge::NodePtr> OutOps(3);

    fusionOps[0] = std::make_shared<NodeTest>();
    fusionOps[0]->Init();
    fusionOps[1] = std::make_shared<NodeTest>();
    fusionOps[1]->Init();
    fusionOps[2] = std::make_shared<NodeTest>();
    fusionOps[2]->Init();

    InOps[0] = std::make_shared<NodeTest>();
    InOps[0]->Init();
    InOps[1] = std::make_shared<NodeTest>();
    InOps[1]->Init();
    InOps[2] = std::make_shared<NodeTest>();
    InOps[2]->Init();

    OutOps[0] = std::make_shared<NodeTest>();
    OutOps[0]->Init();
    OutOps[1] = std::make_shared<NodeTest>();
    OutOps[1]->Init();
    OutOps[2] = std::make_shared<NodeTest>();
    OutOps[2]->Init();

    InOps[0]->GetOutControlAnchor()->LinkTo(fusionOps[0]->GetInControlAnchor());
    // InOps_0 link to fusionOps_0
    InOps[0]->GetOutDataAnchor(0)->LinkTo(fusionOps[0]->GetInDataAnchor(0));
    // InOps_0 link to fusionOps_1
    InOps[0]->GetOutDataAnchor(0)->LinkTo(fusionOps[1]->GetInDataAnchor(0));
    InOps[2]->GetOutDataAnchor(0)->LinkTo(fusionOps[2]->GetInDataAnchor(0));

    fusionOps[0]->GetOutControlAnchor()->LinkTo(OutOps[0]->GetInControlAnchor());
    fusionOps[0]->GetOutDataAnchor(0)->LinkTo(OutOps[0]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInControlAnchor());
    fusionOps[2]->GetOutDataAnchor(0)->LinkTo(OutOps[2]->GetInDataAnchor(0));

    u32 segmentNum = 2;
    std::vector<u32> segmentIndex;
    segmentIndex.push_back(0);
    segmentIndex.push_back(1);

    ret = fusionHcomAllReduceOp.RunFusionOps(graph, fusionOps, segmentNum, segmentIndex);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}



TEST_F(HcomGraphOptimizerTest, ut_RunFusionOps_bcast)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomBroadcastFusion fusionHcomBroadcastOp;
    std::vector<ge::NodePtr> fusionOps(3);
    std::vector<ge::NodePtr> InOps(3);
    std::vector<ge::NodePtr> OutOps(3);

    fusionOps[0] = std::make_shared<NodeTest>();
    fusionOps[0]->Init();
    fusionOps[1] = std::make_shared<NodeTest>();
    fusionOps[1]->Init();
    fusionOps[2] = std::make_shared<NodeTest>();
    fusionOps[2]->Init();

    InOps[0] = std::make_shared<NodeTest>();
    InOps[0]->Init();
    InOps[1] = std::make_shared<NodeTest>();
    InOps[1]->Init();
    InOps[2] = std::make_shared<NodeTest>();
    InOps[2]->Init();

    OutOps[0] = std::make_shared<NodeTest>();
    OutOps[0]->Init();
    OutOps[1] = std::make_shared<NodeTest>();
    OutOps[1]->Init();
    OutOps[2] = std::make_shared<NodeTest>();
    OutOps[2]->Init();

    InOps[0]->GetOutControlAnchor()->LinkTo(fusionOps[0]->GetInControlAnchor());
    InOps[0]->GetOutDataAnchor(0)->LinkTo(fusionOps[0]->GetInDataAnchor(0));
    InOps[1]->GetOutDataAnchor(0)->LinkTo(fusionOps[1]->GetInControlAnchor());
    InOps[2]->GetOutDataAnchor(0)->LinkTo(fusionOps[2]->GetInDataAnchor(0));

    fusionOps[0]->GetOutControlAnchor()->LinkTo(OutOps[0]->GetInControlAnchor());
    fusionOps[0]->GetOutDataAnchor(0)->LinkTo(OutOps[0]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInControlAnchor());
    fusionOps[2]->GetOutDataAnchor(0)->LinkTo(OutOps[2]->GetInDataAnchor(0));

    u32 segmentNum = 2;
    std::vector<u32> segmentIndex;
    segmentIndex.push_back(0);
    segmentIndex.push_back(1);

    ret = fusionHcomBroadcastOp.RunFusionOps(graph, fusionOps, segmentNum, segmentIndex);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}

TEST_F(HcomGraphOptimizerTest, ut_HcomReduceScatter_OptimizeFusedGraph)
{
    ge ::Status ret;
    std::map<std::string,std::string> options;


    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    ge ::Status ge_ret;
    bool bRet;

    std::map<string, GraphOptimizerPtr> graphOptimizers;
    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    EXPECT_EQ(attrs.engineName, HCCL_OPS_ENGIN);
    EXPECT_EQ(attrs.scope, ge::UNIT);

    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("reducescatter0", HCCL_KERNEL_OP_TYPE_REDUCESCATTER);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_REDUCESCATTER);
    bRet = ge::AttrUtils::SetInt(descPtr0, HCOM_ATTR_RANK_SIZE, 8);
    EXPECT_EQ(bRet, true);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    HCCL_INFO("start hccl OptimizeOriginalGraph");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeOriginalGraph(*graph);
    HCCL_INFO("end hccl OptimizeOriginalGraph");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));
    MOCKER_CPP(&HcomGraphOptimizer::SetOpWorkerSpaceForKnowShape)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_HcomAllGather_OptimizeFusedGraph)
{
    ge ::Status ret;
    std::map<std::string,std::string> options;


    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    ge ::Status ge_ret;
    bool bRet;

    std::map<string, GraphOptimizerPtr> graphOptimizers;
    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    EXPECT_EQ(attrs.engineName, HCCL_OPS_ENGIN);
    EXPECT_EQ(attrs.scope, ge::UNIT);

    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("allgather0", HCCL_KERNEL_OP_TYPE_ALLGATHER);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_ALLGATHER);
    bRet = ge::AttrUtils::SetInt(descPtr0, HCOM_ATTR_RANK_SIZE, 8);
    EXPECT_EQ(bRet, true);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    HCCL_INFO("start hccl OptimizeOriginalGraph");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeOriginalGraph(*graph);
    HCCL_INFO("end hccl OptimizeOriginalGraph");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));
    MOCKER(HcomGetDeviceType)
    .stubs()
    .will(returnValue(DevType::DEV_TYPE_310P3));
    MOCKER_CPP(&HcomGraphOptimizer::SetOpWorkerSpaceForKnowShape)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_HcomRemoteRead_OptimizeFusedGraph)
{
    ge ::Status ret;
    std::map<std::string,std::string> options;


    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    ge ::Status ge_ret;
    bool bRet;

    std::map<string, GraphOptimizerPtr> graphOptimizers;
    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    EXPECT_EQ(attrs.engineName, HCCL_OPS_ENGIN);
    EXPECT_EQ(attrs.scope, ge::UNIT);

    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    HCCL_INFO("start hccl OptimizeOriginalGraph");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeOriginalGraph(*graph);
    HCCL_INFO("end hccl OptimizeOriginalGraph");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_HcomReveive_OptimizeFusedGraph)
{
    ge ::Status ret;
    std::map<std::string,std::string> options;


    // 未设置 rank table:失败
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_TABLE_FILE,"rank_table.json"));
    // 实验室场景 设置 rank id
    options.insert(pair<string,string> (ge::OPTION_EXEC_DEPLOY_MODE,"0"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_RANK_ID,"1"));
    options.insert(pair<string,string> (ge::OPTION_EXEC_PROFILING_MODE,"0"));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = Initialize(options);
    EXPECT_EQ(ret, ge::SUCCESS);

    ge ::Status ge_ret;
    bool bRet;

    std::map<string, GraphOptimizerPtr> graphOptimizers;
    GetGraphOptimizerObjs(graphOptimizers);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Initialize(options, nullptr);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    EXPECT_EQ(attrs.engineName, HCCL_OPS_ENGIN);
    EXPECT_EQ(attrs.scope, ge::UNIT);

    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_RECEIVE;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("remoteread0", HCCL_KERNEL_OP_TYPE_RECEIVE);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_RECEIVE);
    bRet = ge::AttrUtils::SetInt(descPtr0, HCOM_ATTR_RANK_SIZE, 8);
    EXPECT_EQ(bRet, true);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    HCCL_INFO("start hccl OptimizeOriginalGraph");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeOriginalGraph(*graph);
    HCCL_INFO("end hccl OptimizeOriginalGraph");
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));
    ge::AttrUtils::HasAttr(*descPtr0, "DUMMY_SET_TRUE_DTYPE");
    ge::AttrUtils::HasAttr(*descPtr0, "DUMMY_SET_TRUE_SHAPE");
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ge_ret = graphOptimizers.at(HCCL_GRAPH_OPTIMIZER_NAME)->Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_OriginalGraphShapeTypeCfg)
{
    ge::Status ge_ret;
    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(true))
    .will(returnValue(ge::GRAPH_SUCCESS));
    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    const string type = "HcomAllReduce";
    ge::OpDescPtr opPtr_ = std::make_shared<ge::OpDesc>();
    opPtr_->SetType(type);
    graph->AddNode(opPtr_);
    ge_ret = graphOptimizer.OptimizeOriginalGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

ge::graphStatus GetOption1(ge::GEContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    return ge::GRAPH_FAILED;
}

TEST_F(HcomGraphOptimizerTest, ut_SetUnknownShapAttr)
{
    HcclResult ret;
    int64_t memSize = 200*1024*1024 + 1;
    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    const string type = "HcomAllReduce";
    ge::OpDescPtr opPtr_ = std::make_shared<ge::OpDesc>();
    opPtr_->SetType(type);
    graph->AddNode(opPtr_);

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(false))
    .will(returnValue(ge::GRAPH_SUCCESS));

    MOCKER_CPP(&ge::GEContext::GetOption)
    .stubs()
    .will(invoke(GetOption1));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ret = graphOptimizer.SetUnknownShapeAttr(*graph, true);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    const string hcclType = "HcomAllGather";
    opPtr_->SetType(hcclType);
    graph->AddNode(opPtr_);
    ret = graphOptimizer.SetUnknownShapeAttr(*graph, true);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    MOCKER(HcomGetDeviceType)
    .stubs()
    .will(returnValue(DevType::DEV_TYPE_310P3));

    ret = graphOptimizer.SetUnknownShapeAttr(*graph, true);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_SetUnknownShapAttr_AlltoAllv)
{
    HcclResult ret;
    int64_t memSize = 200*1024*1024 + 1;
    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    const string type = "HcomAllToAllV";
    ge::OpDescPtr opPtr_ = std::make_shared<ge::OpDesc>();
    opPtr_->SetType(type);
    graph->AddNode(opPtr_);

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(false))
    .will(returnValue(ge::GRAPH_SUCCESS));

    MOCKER_CPP(&ge::GEContext::GetOption)
    .stubs()
    .will(invoke(GetOption1));
    
    ret = graphOptimizer.SetUnknownShapeAttr(*graph, true);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo_Bcast)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomBroadcastFusion fusionHcomBroadcastOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(5);

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);

    ge::AttrUtils::HasAttr(ops[1]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    std::string tempStr = HCCL_WORLD_GROUP;
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 2);
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 2);
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[1], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(fusionOps.size(), 1);

    ge::AttrUtils::SetStr(ops[2]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[2]->GetOpDesc(), "fusion_id", 3);
    ge::AttrUtils::SetInt(ops[2]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[2]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[2], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ge::AttrUtils::SetStr(ops[3]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "fusion_id", -1);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[3], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);

    ge::AttrUtils::SetStr(ops[4]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "fusion_id", 3);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[4], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "group", "hccl_world_group");
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion_id", -1);
}

TEST_F(HcomGraphOptimizerTest, ut_FuseOps_Bcast)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomBroadcastFusion fusionHcomBroadcastOp;
    std::map<std::string, std::vector<ge::NodePtr>> fusionOps;

    std::vector<ge::NodePtr> ops(3);
    std::vector<ge::NodePtr> nodeVec_0;
    nodeVec_0.push_back(ops[0]);
    nodeVec_0.push_back(ops[1]);
    nodeVec_0.push_back(ops[2]);
    fusionOps["hccl_world_group"] = nodeVec_0;


    MOCKER_CPP(&HcomBroadcastFusion::RunFusionOps)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = fusionHcomBroadcastOp.FuseOps(graph, nodeVec_0);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
	ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "group", "hccl_world_group");
}

TEST_F(HcomGraphOptimizerTest, ut_RunFusionOps_reduce)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomReduceFusion fusionHcomReduceOp;
    std::vector<ge::NodePtr> fusionOps(3);
    std::vector<ge::NodePtr> InOps(3);
    std::vector<ge::NodePtr> OutOps(3);

    fusionOps[0] = std::make_shared<NodeTest>();
    fusionOps[0]->Init();
    fusionOps[1] = std::make_shared<NodeTest>();
    fusionOps[1]->Init();
    fusionOps[2] = std::make_shared<NodeTest>();
    fusionOps[2]->Init();

    InOps[0] = std::make_shared<NodeTest>();
    InOps[0]->Init();
    InOps[1] = std::make_shared<NodeTest>();
    InOps[1]->Init();
    InOps[2] = std::make_shared<NodeTest>();
    InOps[2]->Init();

    OutOps[0] = std::make_shared<NodeTest>();
    OutOps[0]->Init();
    OutOps[1] = std::make_shared<NodeTest>();
    OutOps[1]->Init();
    OutOps[2] = std::make_shared<NodeTest>();
    OutOps[2]->Init();

    InOps[0]->GetOutControlAnchor()->LinkTo(fusionOps[0]->GetInControlAnchor());
    InOps[0]->GetOutDataAnchor(0)->LinkTo(fusionOps[0]->GetInDataAnchor(0));
    InOps[1]->GetOutDataAnchor(0)->LinkTo(fusionOps[1]->GetInControlAnchor());
    InOps[2]->GetOutDataAnchor(0)->LinkTo(fusionOps[2]->GetInDataAnchor(0));

    fusionOps[0]->GetOutControlAnchor()->LinkTo(OutOps[0]->GetInControlAnchor());
    fusionOps[0]->GetOutDataAnchor(0)->LinkTo(OutOps[0]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInControlAnchor());
    fusionOps[2]->GetOutDataAnchor(0)->LinkTo(OutOps[2]->GetInDataAnchor(0));

    ret = fusionHcomReduceOp.RunFusionOpsReduce(graph, fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo_Reduce)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomReduceFusion fusionHcomReduceOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(5);

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);

    std::string tempStrReduction = "sum";
    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "reduction", tempStrReduction);

    ge::AttrUtils::HasAttr(ops[1]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    std::string tempStr = HCCL_WORLD_GROUP;
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 0);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 1);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 1);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[1], fusionOps);

    ge::AttrUtils::SetStr(ops[1]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[1], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(fusionOps.size(), 1);
    EXPECT_EQ(fusionOps["hccl_world_group"].size(), 0);

    ge::AttrUtils::SetStr(ops[2]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[2]->GetOpDesc(), "fusion_id", 3);
    ge::AttrUtils::SetInt(ops[2]->GetOpDesc(), "fusion", 1);
    ge::AttrUtils::SetInt(ops[2]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[2], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);

    ge::AttrUtils::SetStr(ops[3]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "fusion_id", -1);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[3]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[3], fusionOps);
    EXPECT_EQ(ret, HCCL_E_PARA);

    ge::AttrUtils::SetStr(ops[4]->GetOpDesc(), "group", tempStr);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[4]->GetOpDesc(), "root_rank", 0);
    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[4], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "group", "hccl_world_group");
    ge::AttrUtils::SetInt(ops[1]->GetOpDesc(), "fusion_id", -1);

    MOCKER(GetCCLBufferAvailableSize)
    .stubs()
    .with(mockcpp::any())
    .will(invoke(stub_GetCCLBufferAvailableSize));

    MOCKER(&HcomOpUtils::GetAllInputsTensorMemSize)
    .stubs()
    .will(invoke(stub_GetAllInputsTensorMemSize));

    FusionInfos fusionOpsTemp;
    ret = fusionHcomReduceOp.GetFusionOpsSlices(fusionOps, fusionOpsTemp);
}

TEST_F(HcomGraphOptimizerTest, ut_FuseOps_Reduce)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomReduceFusion fusionHcomReduceOp;
    std::map<std::string, std::vector<ge::NodePtr>> fusionOps;

    std::vector<ge::NodePtr> ops(3);
    std::vector<ge::NodePtr> nodeVec_0;
    nodeVec_0.push_back(ops[0]);
    nodeVec_0.push_back(ops[1]);
    nodeVec_0.push_back(ops[2]);
    fusionOps["hccl_world_group"] = nodeVec_0;


    MOCKER_CPP(&HcomReduceFusion::RunFusionOpsReduce)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = fusionHcomReduceOp.FuseOps(graph, nodeVec_0);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
	ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "group", "hccl_world_group");
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionSegments_1)
{
    std::vector<ge::NodePtr> ops(7);
    std::vector<ge::NodePtr> nodes;
    nodes.push_back(ops[0]);
    nodes.push_back(ops[1]);
    nodes.push_back(ops[2]);
    nodes.push_back(ops[3]);
    nodes.push_back(ops[4]);
    nodes.push_back(ops[5]);
    nodes.push_back(ops[6]);

    uint64_t inputTensorSize = 200*1024*1024;
    MOCKER(&HcomOpUtils::GetAllInputsTensorOriginSize)
    .stubs()
    .with(mockcpp::any(), outBound(inputTensorSize))
    .will(returnValue(HCCL_SUCCESS));

    std::vector<uint32_t> segmentIndex;
    HcomBroadcastFusion fusionHcomBroadcastOp;
    fusionHcomBroadcastOp.fusionTensorSizeLimit_ = 500*1024*1024;
    HcclResult ret = fusionHcomBroadcastOp.GetFusionSegments(nodes, segmentIndex);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(segmentIndex.size(), 4);
    EXPECT_EQ(segmentIndex[0], 1);
    EXPECT_EQ(segmentIndex[1], 3);
    EXPECT_EQ(segmentIndex[2], 5);
    EXPECT_EQ(segmentIndex[3], 6);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionSegments_2)
{
    std::vector<ge::NodePtr> ops(3);
    std::vector<ge::NodePtr> nodes;
    nodes.push_back(ops[0]);
    nodes.push_back(ops[1]);
    nodes.push_back(ops[2]);

    uint64_t inputTensorSize = 500*1024*1024;
    MOCKER(&HcomOpUtils::GetAllInputsTensorOriginSize)
    .stubs()
    .with(mockcpp::any(), outBound(inputTensorSize))
    .will(returnValue(HCCL_SUCCESS));

    std::vector<uint32_t> segmentIndex;
    HcomBroadcastFusion fusionHcomBroadcastOp;
    fusionHcomBroadcastOp.fusionTensorSizeLimit_ = 500*1024*1024;
    HcclResult ret = fusionHcomBroadcastOp.GetFusionSegments(nodes, segmentIndex);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    EXPECT_EQ(segmentIndex.size(), 3);
    EXPECT_EQ(segmentIndex[0], 0);
    EXPECT_EQ(segmentIndex[1], 1);
    EXPECT_EQ(segmentIndex[2], 2);
    GlobalMockObject::verify();
}

HcclResult GetOffDeviceTypeWithoutDevMock(DevType &devType)
{
    devType = DevType::DEV_TYPE_310P3;
    HCCL_DEBUG("[offline] Get devtype[%u]....", devType);
    return HCCL_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_CalcOpRunningResources_OpenSource)
{
    // 测试开源版本的CalcOpRunningResources
    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    std::string sCollectiveType;
    u32 streamNum = 0;
    u64 opMemSize = 0;
    u32 taskNum = 0;
    u32 aivCoreNum = 0;

    // 模拟IsUsingOpenSource返回true,使用开源版本
    MOCKER(IsUsingOpenSource)
    .expects(atMost(1))
    .with(outBound(true))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟HcceCreateOpParamGraphMode函数
    OpParamGraphModePtr opParamPtr = reinterpret_cast<OpParamGraphModePtr>(0x12345678);
    MOCKER(HcceCreateOpParamGraphMode)
    .expects(atMost(1))
    .with(outBound(opParamPtr))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟SetHcclOpParam函数
    MOCKER(SetHcclOpParam)
    .expects(atMost(1))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟HcceCalcOpResOfflineGraphMode函数
    MOCKER(HcceCalcOpResOfflineGraphMode)
    .expects(atMost(1))
    .with(mockcpp::any(), outBound(&opMemSize), outBound(&streamNum), outBound(&taskNum), outBound(&aivCoreNum))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟IsOfflineCompilation返回true
    MOCKER(IsOfflineCompilation)
    .expects(atMost(1))
    .will(returnValue(true));

    // 测试CalcOpRunningResources方法
    HcclResult ret = graphOptimizer.CalcOpRunningResources(*addedNodePtr0, sCollectiveType, streamNum, opMemSize, taskNum, aivCoreNum);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_SetHcclOpParam)
{
    // 测试SetHcclOpParam函数
    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    // 准备测试参数
    HcomOpParam hcomOpParam;
    std::string sCollectiveType;
    OpParamGraphModePtr opParamPtr = reinterpret_cast<OpParamGraphModePtr>(0x12345678);
    std::vector<int64_t> sendCounts;
    std::vector<int64_t> sendDispls;
    std::vector<int64_t> recvCounts;
    std::vector<int64_t> recvDispls;

    // 模拟IsUsingOpenSource返回true,使用开源版本
    MOCKER(IsUsingOpenSource)
    .expects(atMost(1))
    .with(outBound(true))
    .will(returnValue(HCCL_SUCCESS));

    // 模拟HcceCreateOpParamGraphMode函数
    MOCKER(HcceCreateOpParamGraphMode)
    .expects(atMost(1))
    .with(outBound(opParamPtr))
    .will(returnValue(HCCL_SUCCESS));

    HcclResult ret = graphOptimizer.SetHcclOpParam(*addedNodePtr0, &hcomOpParam, opParamPtr, sCollectiveType, 
                                                  sendCounts, sendDispls, recvCounts, recvDispls);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_OptimizeFusedGraph_allreduce)
{
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    auto descPtr1 = std::make_shared<ge::OpDesc>("Allreduce1", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr1->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    auto addedNodePtr1 = graph->AddNode(descPtr1);
    EXPECT_NE(addedNodePtr1, nullptr);

    u64 streamNumber = 4;
    char *group = "127.0.0.1%eth0_60000_0_1698475280390992";

    MOCKER_CPP(&HcomGraphOptimizer::SetOpWorkerSpaceForKnowShape)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    HcomGraphOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    MOCKER(IsOfflineCompilation)
    .stubs()
    .will(returnValue(true));

    MOCKER(GetOffDeviceTypeWithoutDev)
    .stubs()
    .will(invoke(GetOffDeviceTypeWithoutDevMock));
    ge_ret = graphOptimizer.OptimizeFusedGraph(*graph);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_OptimizeFusedGraph_broadcast)
{
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("Broadcast0", HCCL_KERNEL_OP_TYPE_BROADCAST);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_BROADCAST);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    auto descPtr1 = std::make_shared<ge::OpDesc>("Allreduce1", HCCL_KERNEL_OP_TYPE_BROADCAST);
    EXPECT_EQ(descPtr1->GetType(), HCCL_KERNEL_OP_TYPE_BROADCAST);
    auto addedNodePtr1 = graph->AddNode(descPtr1);
    EXPECT_NE(addedNodePtr1, nullptr);

    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    
    HcomGraphOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_OptimizeFusedGraph_broadcast_unknown)
{
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    ge::OpDesc op;
    const string type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    op.SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("Broadcast0", HCCL_KERNEL_OP_TYPE_BROADCAST);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_BROADCAST);
    auto addedNodePtr0 = graph->AddNode(descPtr0);
    EXPECT_NE(addedNodePtr0, nullptr);

    auto descPtr1 = std::make_shared<ge::OpDesc>("Allreduce1", HCCL_KERNEL_OP_TYPE_BROADCAST);
    EXPECT_EQ(descPtr1->GetType(), HCCL_KERNEL_OP_TYPE_BROADCAST);
    auto addedNodePtr1 = graph->AddNode(descPtr1);
    EXPECT_NE(addedNodePtr1, nullptr);

    u64 streamNumber = 4;
    u32 rankSize = 8;
    MOCKER(HcomGetRankSize)
    .stubs()
    .with(mockcpp::any(), outBound(&rankSize))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(true))
    .will(returnValue(ge::GRAPH_SUCCESS));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    HcomGraphOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_GetCommFromOpDesc_by_group_pytorch)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomGraphOptimizer hcomGraphOptimizer;
    std::vector<ge::NodePtr> option(2);
    int64_t hcomComm = 0;
    std::string sGroup;
    std::string tempStr = HCCL_WORLD_GROUP;

    ge::AttrUtils::HasAttr(option[0]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::SetStr(option[0]->GetOpDesc(), "group", tempStr);
    ge::OpDescPtr op1 = option[0]->GetOpDesc();
    ret = hcomGraphOptimizer.GetCommFromOpDesc(op1, hcomComm, sGroup);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo_Reduce_by_comm_pytorch)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomReduceFusion fusionHcomReduceOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(1);

    std::string tempStrReduction = "sum";

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");

    ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "reduction", tempStrReduction);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "comm", 0);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "root_rank", 0);

    ret = fusionHcomReduceOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");

    MOCKER(GetCCLBufferAvailableSize)
    .stubs()
    .with(mockcpp::any())
    .will(invoke(stub_GetCCLBufferAvailableSize));

    MOCKER(&HcomOpUtils::GetAllInputsTensorMemSize)
    .stubs()
    .will(invoke(stub_GetAllInputsTensorMemSize));

    FusionInfos fusionOpsTemp;
    ret = fusionHcomReduceOp.GetFusionOpsSlices(fusionOps, fusionOpsTemp);

}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo_AllReduce_by_comm_pytorch)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionHcomAllReduceOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(1);

    std::string tempStrReduction = "sum";

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");

    ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "reduction", tempStrReduction);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "comm", 0);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "root_rank", 0);

    ret = fusionHcomAllReduceOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");

    MOCKER(GetCCLBufferAvailableSize)
    .stubs()
    .with(mockcpp::any())
    .will(invoke(stub_GetCCLBufferAvailableSize));

    MOCKER(&HcomOpUtils::GetAllInputsTensorMemSize)
    .stubs()
    .will(invoke(stub_GetAllInputsTensorMemSize));

    FusionInfos fusionOpsTemp;
    ret = fusionHcomAllReduceOp.GetFusionOpsSlices(fusionOps, fusionOpsTemp);
}

TEST_F(HcomGraphOptimizerTest, ut_FuseOps_AllReduce_by_comm_pytorch)
{
    HcclResult ret;
    bool bRet;
    ge::ComputeGraph graph("test_graph");
    ge::OpDescPtr nodeComm = 0;
    ge::OpDescPtr nodeFusionId = 0;
    HcomAllReduceFusion fusionHcomAllReduceOp;
    std::map<std::string, std::vector<ge::NodePtr>> fusionOps;

    std::vector<ge::NodePtr> ops(3);
    std::vector<ge::NodePtr> nodeVec_0;
    int64_t comm = 645678545;
    int64_t fusionid = HCOM_ATTR_FUSION_ID_DEFAULT;
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");
    bRet = ge::AttrUtils::SetInt(nodeComm, "comm", comm);
    EXPECT_EQ(bRet, true);
    bRet = ge::AttrUtils::SetInt(nodeFusionId, "fusion_id", fusionid);
    EXPECT_EQ(bRet, true);
    nodeVec_0.push_back(ops[0]);
    nodeVec_0.push_back(ops[1]);
    nodeVec_0.push_back(ops[2]);
    fusionOps["hccl_world_group"] = nodeVec_0;

    MOCKER_CPP(&HcomAllReduceFusion::RunFusionOpsReduce)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));
    ret = fusionHcomAllReduceOp.FuseOps(graph, fusionOps["hccl_world_group"]);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_RunFusionOpss_allreduce_by_comm_pytorch)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomAllReduceFusion fusionAllHcomReduceOp;
    std::vector<ge::NodePtr> fusionOps(3);
    std::vector<ge::NodePtr> InOps(3);
    std::vector<ge::NodePtr> OutOps(3);

    fusionOps[0] = std::make_shared<NodeTest>();
    fusionOps[0]->Init();
    fusionOps[1] = std::make_shared<NodeTest>();
    fusionOps[1]->Init();
    fusionOps[2] = std::make_shared<NodeTest>();
    fusionOps[2]->Init();

    InOps[0] = std::make_shared<NodeTest>();
    InOps[0]->Init();
    InOps[1] = std::make_shared<NodeTest>();
    InOps[1]->Init();
    InOps[2] = std::make_shared<NodeTest>();
    InOps[2]->Init();

    OutOps[0] = std::make_shared<NodeTest>();
    OutOps[0]->Init();
    OutOps[1] = std::make_shared<NodeTest>();
    OutOps[1]->Init();
    OutOps[2] = std::make_shared<NodeTest>();
    OutOps[2]->Init();

    InOps[0]->GetOutControlAnchor()->LinkTo(fusionOps[0]->GetInControlAnchor());
    InOps[0]->GetOutDataAnchor(0)->LinkTo(fusionOps[0]->GetInDataAnchor(0));

    InOps[1]->GetOutDataAnchor(0)->LinkTo(fusionOps[1]->GetInControlAnchor());

    InOps[2]->GetOutDataAnchor(0)->LinkTo(fusionOps[2]->GetInDataAnchor(0));

    fusionOps[0]->GetOutControlAnchor()->LinkTo(OutOps[0]->GetInControlAnchor());
    fusionOps[0]->GetOutDataAnchor(0)->LinkTo(OutOps[0]->GetInDataAnchor(0));

    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInDataAnchor(0));
    fusionOps[1]->GetOutDataAnchor(0)->LinkTo(OutOps[1]->GetInControlAnchor());

    fusionOps[2]->GetOutDataAnchor(0)->LinkTo(OutOps[2]->GetInDataAnchor(0));

    ret = fusionAllHcomReduceOp.RunFusionOpsReduce(graph, fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo_Bcast_by_comm_pytorch1)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomBroadcastFusion fusionHcomBroadcastOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(1);

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "comm", 0);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "root_rank", 0);

    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
}

TEST_F(HcomGraphOptimizerTest, ut_GetFusionOpInfo_Bcast_by_comm_pytorch2)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomBroadcastFusion fusionHcomBroadcastOp;
    FusionInfos fusionOps;
    std::vector<ge::NodePtr> ops(1);

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");

    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "comm", 0);
    ge::AttrUtils::SetStr(ops[0]->GetOpDesc(), "group", "hccl_world_group");
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion_id", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "fusion", 2);
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "root_rank", 0);

    ret = fusionHcomBroadcastOp.GetFusionOpInfo(ops[0], fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_GROUP");
}

TEST_F(HcomGraphOptimizerTest, ut_GetCommFromOpDesc_by_comm_pytorch1)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomGraphOptimizer hcomGraphOptimizer;
    std::vector<ge::NodePtr> ops(1);
    int64_t hcomComm = 0;
    std::string sGroup;

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "comm", 645678156);
    ge::OpDescPtr op0 = ops[0]->GetOpDesc();
    ret = hcomGraphOptimizer.GetCommFromOpDesc(op0, hcomComm, sGroup);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
}

TEST_F(HcomGraphOptimizerTest, ut_GetCommFromOpDesc_by_comm_pytorch2)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomGraphOptimizer hcomGraphOptimizer;
    std::vector<ge::NodePtr> ops(1);
    int64_t hcomComm = 0;
    std::string sGroup;

    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "comm");
    ge::AttrUtils::SetInt(ops[0]->GetOpDesc(), "comm", 0);
    ge::OpDescPtr op0 = ops[0]->GetOpDesc();
    ret = hcomGraphOptimizer.GetCommFromOpDesc(op0, hcomComm, sGroup);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::AttrUtils::HasAttr(ops[0]->GetOpDesc(), "DUMMY_SET_FALSE_COMM");
}

TEST_F(HcomGraphOptimizerTest, ut_HcomCalcOpRunningParam_by_comm_pytorch)
{
    HcclResult ret;
    ge::ComputeGraph graph("test_graph");
    HcomGraphOptimizer hcomGraphOptimizer;
    std::vector<ge::Node> ops(1);
    int64_t hcomComm = 0;
    u64 streamNum = 8;
    std::string sGroup;
    const string type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    ops[0].GetOpDesc()->SetType(type);

    auto descPtr0 = std::make_shared<ge::OpDesc>("Allreduce0", HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    EXPECT_EQ(descPtr0->GetType(), HCCL_KERNEL_OP_TYPE_ALLREDUCE);

    ge::AttrUtils::HasAttr(ops[0].GetOpDesc(), "DUMMY_SET_TRUE_COMM");
    ge::AttrUtils::HasAttr(ops[0].GetOpDesc(), "comm");
    ge::AttrUtils::SetInt(ops[0].GetOpDesc(), "comm", 645678156);

    ge::AttrUtils::SetInt(ops[0].GetOpDesc(), "used_stream_num", streamNum);
    MOCKER_CPP(&HcomGraphOptimizer::GetOriginalGraphShapeTypeFromDesc)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    u32 shapeType = ORIGINAL_GRAPH_UNKNOWNSHAPE_TYPE;

    MOCKER_CPP_VIRTUAL(hcomGraphOptimizer, &HcomGraphOptimizer::SetOpOutputMemSize)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomGraphOptimizer::SetOpWorkerSpaceForKnowShape)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomGraphOptimizer::SetOpAtomicInputIndex)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge::AttrUtils::HasAttr(ops[0].GetOpDesc(), "DUMMY_SET_FALSE_COMM");

    ret = hcomGraphOptimizer.HcomCalcOpRunningParam(ops[0], false);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, utFusionVersion)
{
    HcclResult ret;
    std::string socVersion;
    HcomOpUtils ops;
    ret = ops.CreateFusionConfigVersion(socVersion);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}

TEST_F(HcomGraphOptimizerTest, utFusionPath_Tune)
{
    HcclResult ret;
    HcomOpUtils ops;
    std::string fusionPath;
    std::string fusionFile;
    char* env = nullptr;

    ret = ops.GetPathFromEnv(env, fusionPath);
    EXPECT_EQ(ret, HCCL_E_AGAIN);

    env = "/tmp";
    ret = ops.GetPathFromEnv(env, fusionPath);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ret = ops.GetFileNameFromPath(fusionPath, fusionFile);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    env = "./test";
    ret = ops.GetPathFromEnv(env, fusionPath);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    ret = ops.GetFileNameFromPath(fusionPath, fusionFile);
    EXPECT_EQ(ret, HCCL_E_AGAIN);
}

TEST_F(HcomGraphOptimizerTest, utCalculateSegmentIndex)
{
    nlohmann::json rank_table =
    {
        {"status", "completed"},
        {"version", "1.0"},
        {"server_count", "1"},
        {
            "server_list",
            {
                {
                    {"server_id", "10.0.0.10"},
                    {"host_nic_ip", "192.168.0.12:0,192.168.1.12:199"},
                    {
                        "device",
                        {
                            {   {"rank_id", "0"},
                                {"device_id", "0"},
                                {"device_ip", "192.168.0.12,192.168.1.12"}

                            },
                        }
                    },
                }
            }
        }
    };

    char file_name_t[] = "./ut_hcom_get_new_rank_info_muti_ip.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();
    int ret = HCCL_SUCCESS;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    u32 localrankid = 0;
    u32 localranksize = 0;
    char* rank_table_file = "./ut_hcom_get_new_rank_info_muti_ip.json";
    char* rank_ID = "0";

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    u64 tensorFusionLimit = std::numeric_limits<u64>::max();
    std::string fusionHash = "001";
    std::vector<u32> segmentIndex;
    std::string fusionPath = "./Ascend910A_gradient_fusion.json";
    nlohmann::json root;
    root[0]["modelhash"] = "001";
    root[0]["modelgraphid"] = 1;
    root[0]["modelvalue"]["gradientSize"] = {16384000,4000,67108864,16384,150994944,16384,2359296,1024,3538944,1024,2654208,1536,1843200,768,139392,384};
    root[0]["modelvalue"]["gradientTime"] = {9833418590329,15080,100303,86064,408791,246415,634287,686785,103,564542,100,1686982,117497,3490117,202901,117711};
    fstream jFile;
    jFile.open(fusionPath, std::ios::out | std::ios::trunc);
    jFile.close();

    HcomAllReduceFusion ops;
    ret = ops.GetInformationFromLibrary(fusionPath, fusionHash, tensorFusionLimit, segmentIndex);
    EXPECT_EQ(ret, HCCL_E_AGAIN);

    jFile.open(fusionPath, std::ios::out);
    jFile << root;
    jFile.close();

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);
    remove(fusionPath.c_str());
    GlobalMockObject::verify();
}

extern vector<float> CalculateSizeRatio(const vector<float>& sliceSize, float totalSize);
TEST_F(HcomGraphOptimizerTest, utCalculateSizeRatio)
{
    float totalSize = 100.00;
    vector<float> sliceSize{41.00, 40.00, 9.00, 10.00};
    vector<float> ratio;
    ratio = CalculateSizeRatio(sliceSize, totalSize);
}

TEST_F(HcomGraphOptimizerTest, utCalculateSegmentIndexFromHomeExport)
{
    HcclResult ret;
    float tensorFusionLimit = 3.40282e+38;
    std::string fusionHash = "4569785153";
    std::string fusionSocVersion = "Ascend910A";
    std::string fusionStartPath = "/usr/local/ascend/compiler/lib64/";
    std::vector<u32> segmentIndex;
    std::vector<u32> index{0};
    HcomAllReduceFusion ops;

    MOCKER(HcomOpUtils::CreateFusionConfigVersion)
    .stubs()
    .with(outBound(fusionSocVersion))
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomAllReduceFusion::GetPathFromDefault)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    MOCKER_CPP(&HcomAllReduceFusion::GetInformationFromLibrary)
    .stubs()
    .with(mockcpp::any())
    .will(returnValue(HCCL_SUCCESS));

    ret = ops.CalculateSegmentIndex(fusionHash, tensorFusionLimit, segmentIndex);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

ge::graphStatus FakeGetOption3(ge::GEThreadLocalContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    nlohmann::json group_list =
    {
        {
            {"group_name", "aa"},
            {"group_rank_list", {0, 1}}
        }
    };
    if (optionExec == ge::OPTION_EXEC_HCOM_GROUPLIST) {
        dumpDebugValue = group_list.dump();
    }
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_offlinebuild_calcSubStreamNum)
{
    ge::Status ret;
    HcomGraphOptimizer graphOptimizer;
    ge::NodePtr nodeptr(new NodeTest);
    std::string type = HCCL_KERNEL_OP_TYPE_ALLTOALLV;
    nodeptr->GetOpDesc()->SetType(type);

    std::string curGroup = "aa";
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", curGroup);

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(FakeGetOption3));

    MOCKER(&ge::AttrUtils::SetInt)
    .stubs()
    .will(returnValue(false));

    ret = graphOptimizer.HcomCalcOpRunningParam(*nodeptr, false);

    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    std::string nodeName = "ALL_GATHER_NO_CALCULATION";
    nodeptr->GetOpDesc()->SetType(type);
    nodeptr->GetOpDesc()->SetName(nodeName);
    ret = graphOptimizer.HcomCalcOpRunningParam(*nodeptr, false);

    GlobalMockObject::verify();
}

struct OpInfo {
    ge::NodePtr nodePtr;
    std::pair<std::string, std::string> opName;
    std::vector<std::pair<std::string, int64_t>> attrInt;
    std::vector<std::pair<std::string, std::string>> attrStr;
};

HcclResult CreateNodePtr(OpInfo& opInfo, ge::ComputeGraph& graph)
{
    ge::OpDescPtr opDescPtr =
        std::make_shared<ge::OpDesc>(opInfo.opName.first.c_str(), opInfo.opName.second.c_str());
    EXPECT_NE(opDescPtr, nullptr);
    opDescPtr->SetName(opInfo.opName.first.c_str());

    for (auto& it : opInfo.attrInt) {
        bool bErr = ge::AttrUtils::SetInt(opDescPtr, it.first.c_str(), it.second);
        CHK_PRT_RET(!bErr, HCCL_ERROR("node[%s] set attr: %s failed", \
            opInfo.opName.first.c_str(), it.first.c_str()), HCCL_E_INTERNAL);
    }
    for (auto& it : opInfo.attrStr) {
        bool bErr = ge::AttrUtils::SetStr(opDescPtr, it.first.c_str(), it.second.c_str());
        CHK_PRT_RET(!bErr, HCCL_ERROR("node[%s] set attr: %s failed", \
            opInfo.opName.first.c_str(), it.first.c_str()), HCCL_E_INTERNAL);
    }

    opInfo.nodePtr = graph.AddNode(opDescPtr);
    CHK_PRT_RET((opInfo.nodePtr == nullptr), HCCL_ERROR("[Create]node[%s] failed",
        opInfo.opName.first.c_str()), HCCL_E_INTERNAL);
    return HCCL_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_FuseHcomAlltoAllVCNode)
{
    MOCKER_CPP(&HcomAlltoAllVCFusion::RunFusionOpsAlltoAllVC)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 alltoallvcNum = 2;
    for (u32 i = 0; i < alltoallvcNum; i++) {
        // alltoallvc
        OpInfo alltoallvcOpInfo;
        alltoallvcOpInfo.opName = std::make_pair("AlltoAllVC_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_ALLTOALLVC);
        alltoallvcOpInfo.attrInt.push_back(std::make_pair("rank", 0));
        alltoallvcOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        alltoallvcOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        alltoallvcOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(alltoallvcOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(alltoallvcOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        alltoallvcOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_ALLTOALLVC);
    }
    HcomFusionOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.FuseHcomAlltoAllVCNode(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_RunFusionOpsAlltoAllVC)
{
    MOCKER_CPP(&HcomAlltoAllVCFusion::AddOpsEdge)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    HcomAlltoAllVCFusion fusionHcomAlltoAllVCOp;
    constexpr u32 fusionOpNum = 2;
    std::vector<ge::NodePtr> fusionOps(fusionOpNum);
    std::vector<ge::NodePtr> InOps(fusionOpNum * 2);
    std::vector<ge::NodePtr> OutOps(fusionOpNum);

    for (u32 i = 0; i < fusionOps.size(); ++i) {
        fusionOps[i] = std::make_shared<NodeTest>();
        fusionOps[i]->Init();
        fusionOps[i]->Init(); // 多调用一次Init, 让IndataAnchor的size + 1
    }
    for (u32 i = 0; i < InOps.size(); ++i) {
        InOps[i] = std::make_shared<NodeTest>();
        InOps[i]->Init();
    }
    for (u32 i = 0; i < OutOps.size(); ++i) {
        OutOps[i] = std::make_shared<NodeTest>();
        OutOps[i]->Init();
    }

    // link
    for (u32 i = 0; i < fusionOps.size(); ++i) {
        InOps[i * 2]->GetOutDataAnchor(0)->LinkTo(fusionOps[i]->GetInDataAnchor(0));
        InOps[i * 2 + 1]->GetOutDataAnchor(0)->LinkTo(fusionOps[i]->GetInDataAnchor(1));
        fusionOps[i]->GetOutDataAnchor(0)->LinkTo(OutOps[i]->GetInDataAnchor(0));
    }

    ret = fusionHcomAlltoAllVCOp.RunFusionOpsAlltoAllVC(*graph, fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_FuseHcomAllGatherNode)
{
    MOCKER_CPP(&HcomAllGatherFusion::RunFusionOpsAllGather)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 allgatherNum = 2;
    for (u32 i = 0; i < allgatherNum; i++) {
        // allgather
        OpInfo allgatherOpInfo;
        allgatherOpInfo.opName = std::make_pair("AllGather_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_ALLGATHER);
        allgatherOpInfo.attrInt.push_back(std::make_pair("rank_size", 3));
        allgatherOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        allgatherOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        allgatherOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(allgatherOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(allgatherOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        allgatherOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_ALLGATHER);
    }
    HcomFusionOptimizer graphOptimizer;
    ret = graphOptimizer.HcomOptimizeOriginalGraph(*graph);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    ge::Status ge_ret = graphOptimizer.FuseHcomAllgatherNode(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizer.HcomOptimizeSetAttr(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_RunFusionOpsAllGather)
{
    MOCKER_CPP(&HcomAllGatherFusion::AddOpsEdge)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    HcomAllGatherFusion fusionHcomAllGatherOp;
    constexpr u32 fusionOpNum = 2;
    std::vector<ge::NodePtr> fusionOps(fusionOpNum);
    std::vector<ge::NodePtr> InOps(fusionOpNum);
    std::vector<ge::NodePtr> OutOps(fusionOpNum);

    for (u32 i = 0; i < fusionOps.size(); ++i) {
        fusionOps[i] = std::make_shared<NodeTest>();
        fusionOps[i]->Init();
    }
    for (u32 i = 0; i < InOps.size(); ++i) {
        InOps[i] = std::make_shared<NodeTest>();
        InOps[i]->Init();
    }
    for (u32 i = 0; i < OutOps.size(); ++i) {
        OutOps[i] = std::make_shared<NodeTest>();
        OutOps[i]->Init();
    }

    // link
    for (u32 i = 0; i < fusionOps.size(); ++i) {
        InOps[i]->GetOutDataAnchor(0)->LinkTo(fusionOps[i]->GetInDataAnchor(0));
        InOps[i]->GetOutControlAnchor()->LinkTo(fusionOps[i]->GetInControlAnchor());
        fusionOps[i]->GetOutDataAnchor(0)->LinkTo(OutOps[i]->GetInDataAnchor(0));
        fusionOps[i]->GetOutControlAnchor()->LinkTo(OutOps[i]->GetInControlAnchor());
    }

    ret = fusionHcomAllGatherOp.RunFusionOpsAllGather(*graph, fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}


TEST_F(HcomGraphOptimizerTest, ut_FuseHcomReduceScatterNode)
{
    MOCKER_CPP(&HcomReduceScatterFusion::RunFusionOpsReduceScatter)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 reducescatterNum = 2;
    for (u32 i = 0; i < reducescatterNum; i++) {
        // reducescatter
        OpInfo reducescatterOpInfo;
        reducescatterOpInfo.opName = std::make_pair("ReduceScatter_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_REDUCESCATTER);
        reducescatterOpInfo.attrStr.push_back(std::make_pair("reduction", "sum"));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("rank_size", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        reducescatterOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(reducescatterOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(reducescatterOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        reducescatterOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_REDUCESCATTER);
    }
    HcomFusionOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.FuseHcomReduceScatterNode(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ret = graphOptimizer.HcomOptimizeSetAttr(*graph);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_FuseHcomReduceScatterNode1)
{
    MOCKER_CPP(&HcomReduceScatterFusion::RunFusionOpsReduceScatter)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 reducescatterNum = 2;
    for (u32 i = 0; i < reducescatterNum; i++) {
        // reducescatter
        OpInfo reducescatterOpInfo;
        reducescatterOpInfo.opName = std::make_pair("ReduceScatter_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_REDUCESCATTER);
        reducescatterOpInfo.attrStr.push_back(std::make_pair("reduction", "sum"));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("rank_size", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        reducescatterOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(reducescatterOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(reducescatterOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        reducescatterOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_REDUCESCATTER);
    }

    MOCKER(GetCCLBufferAvailableSize)
    .stubs()
    .with(mockcpp::any())
    .will(invoke(stub_GetCCLBufferAvailableSize));

    MOCKER(&HcomOpUtils::GetAllInputsTensorMemSize)
    .stubs()
    .will(invoke(stub_GetAllInputsTensorMemSize));

    HcomFusionOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.FuseHcomReduceScatterNode(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_FuseHcomReduceNode1)
{
    MOCKER_CPP(&HcomReduceScatterFusion::RunFusionOpsReduce)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 reducescatterNum = 2;
    for (u32 i = 0; i < reducescatterNum; i++) {
        OpInfo reducescatterOpInfo;
        reducescatterOpInfo.opName = std::make_pair("Reduce_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_REDUCE);
        reducescatterOpInfo.attrStr.push_back(std::make_pair("reduction", "sum"));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("rank_size", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        reducescatterOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(reducescatterOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(reducescatterOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        reducescatterOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_REDUCE);
    }

    HcomGraphOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.FuseHcomReduceNode(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_FuseHcomReduceNode2)
{
    MOCKER_CPP(&HcomReduceScatterFusion::RunFusionOpsReduce)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 reducescatterNum = 2;
    for (u32 i = 0; i < reducescatterNum; i++) {
        OpInfo reducescatterOpInfo;
        reducescatterOpInfo.opName = std::make_pair("AllReduce_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_ALLREDUCE);
        reducescatterOpInfo.attrStr.push_back(std::make_pair("reduction", "sum"));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("rank_size", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        reducescatterOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        reducescatterOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(reducescatterOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(reducescatterOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        reducescatterOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_ALLREDUCE);
    }
    HcomGraphOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.FuseHcomAllReduceNode(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_RunFusionReduceScatter)
{
    MOCKER_CPP(&HcomReduceScatterFusion::AddOpsEdge)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    HcomReduceScatterFusion fusionHcomReduceScatterOp;
    constexpr u32 fusionOpNum = 2;
    std::vector<ge::NodePtr> fusionOps(fusionOpNum);
    std::vector<ge::NodePtr> InOps(fusionOpNum);
    std::vector<ge::NodePtr> OutOps(fusionOpNum);

    for (u32 i = 0; i < fusionOps.size(); ++i) {
        fusionOps[i] = std::make_shared<NodeTest>();
        fusionOps[i]->Init();
    }
    for (u32 i = 0; i < InOps.size(); ++i) {
        InOps[i] = std::make_shared<NodeTest>();
        InOps[i]->Init();
    }
    for (u32 i = 0; i < OutOps.size(); ++i) {
        OutOps[i] = std::make_shared<NodeTest>();
        OutOps[i]->Init();
    }

    // link
    for (u32 i = 0; i < fusionOps.size(); ++i) {
        InOps[i]->GetOutDataAnchor(0)->LinkTo(fusionOps[i]->GetInDataAnchor(0));
        InOps[i]->GetOutControlAnchor()->LinkTo(fusionOps[i]->GetInControlAnchor());
        fusionOps[i]->GetOutDataAnchor(0)->LinkTo(OutOps[i]->GetInDataAnchor(0));
        fusionOps[i]->GetOutControlAnchor()->LinkTo(OutOps[i]->GetInControlAnchor());
    }

    ret = fusionHcomReduceScatterOp.RunFusionOpsReduceScatter(*graph, fusionOps);
    EXPECT_EQ(ret, HCCL_SUCCESS);
    GlobalMockObject::verify();
}

ge::graphStatus GetOption2(ge::GEContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    dumpDebugValue = "1";
    return ge::GRAPH_SUCCESS;
}

ge::graphStatus GetOption3(ge::GEContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    dumpDebugValue = "2";
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_SetknownShapAttr)
{
    HcclResult ret;
    int64_t memSize = 200*1024*1024 + 1;
    HcomGraphOptimizer graphOptimizer;
    graphOptimizer.optionFeatureBaseRefreshable_ = 1;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    const string type = "HcomAllToAllVC";
    ge::OpDescPtr opPtr_ = std::make_shared<ge::OpDesc>();
    opPtr_->SetType(type);
    graph->AddNode(opPtr_);

    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(false))
    .will(returnValue(ge::GRAPH_SUCCESS));

    MOCKER(&HcomOpUtils::GetAllInputsTensorMemSize)
    .stubs()
    .will(invoke(stub_GetAllInputsTensorMemSize));


    MOCKER_CPP(&ge::GEContext::GetOption)
    .stubs()
    .will(invoke(GetOption2));

    ret = graphOptimizer.SetUnknownShapeAttr(*graph, true);

    MOCKER_CPP(&ge::GEContext::GetOption)
    .stubs()
    .will(invoke(GetOption3));
    ret = graphOptimizer.SetUnknownShapeAttr(*graph, true);
    GlobalMockObject::verify();
}

HcclResult stub_GetVectorFromTensorGraphOptimizer(const ge::GeTensor* tensor, std::vector<int64_t>& vector)
{
    vector.resize(4*4);
    return HCCL_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_getAlltoAllCountsDispl_sendCountMatrix)
{
    ge::NodePtr nodeptr(new NodeTest);
    HcomOpUtils graphOptimizer;

    std::vector<int64_t> sendCountMatrix;

    HcclResult ret = graphOptimizer.GetAlltoAllCountMatrix(*(nodeptr.get()), sendCountMatrix);
    EXPECT_EQ(ret, HCCL_SUCCESS);
}


std::string optionExecTest = "Ascend910";
ge::graphStatus TaskNumGetOption_1(ge::GEThreadLocalContext *that, const std::string &optionExec, std::string &dumpDebugValue)
{
    nlohmann::json group_list =
    {
        {
            {"group_name", "aa"},
            {"group_rank_list", {0, 1}}
        },
        {
            {"group_name", "off_group_rank_list"},
            {"group_rank_list", {0, 1, 2, 3, 4, 5, 6, 7}}
        }
    };
    if (optionExec == ge::OPTION_EXEC_HCOM_GROUPLIST) {
        dumpDebugValue = group_list.dump();
    } else if (optionExec == ge::OPTION_EXEC_HCOM_RANK_MAPPING) {
        dumpDebugValue = R"({"status": "completed","version": "1.1","node_list":[{"node_id": "0","rank_list":[
        {"rank_id": "0","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "1","item_id": "-1","rank_ip":"192.168.2.11"}]}]})";
    } else if (optionExec == ge::OPTION_EXEC_RANK_TABLE) {
        dumpDebugValue = R"({"status": "completed","version": "1.1","node_list":[{"node_id": "0","rank_list":[
        {"rank_id": "0","item_id": "0","rank_ip":"192.168.2.10"},
        {"rank_id": "1","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "2","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "3","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "4","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "5","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "6","item_id": "0","rank_ip":"192.168.2.11"},
        {"rank_id": "7","item_id": "0","rank_ip":"192.168.2.11"}]}]})";
    } else if (optionExec == "ge.socVersion") {
        dumpDebugValue = optionExecTest;
    } else if (optionExec == ge::OPTION_EXEC_RANK_TABLE_FILE) {
        dumpDebugValue = "./llt/ace/comop/hccl/stub/workspace/ut_task_num_one_server_hcom_test.json";
    } else if (optionExec == "ge.offline_hccl_compile") {
        return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
}

TEST_F(HcomGraphOptimizerTest, ut_CalcOpTaskNum_1server_1)
{
    HcclResult ret;
    nlohmann::json rank_table = rank_table_1server_8rank;
    char file_name_t[] = "./llt/ace/comop/hccl/stub/workspace/ut_task_num_one_server_hcom_test.json";
    std::ofstream outfile(file_name_t, std::ios::out | std::ios::trunc | std::ios::binary);

    if (outfile.is_open())
    {
        outfile << std::setw(1) << rank_table << std::endl;
        HCCL_INFO("open %s success", file_name_t);
    }
    else
    {
        HCCL_ERROR("open %s failed", file_name_t);
    }

    outfile.close();

    ge::OpDesc op;
    ge ::Status ge_ret = ge::INTERNAL_ERROR;
    HcomGraphOptimizer hcomKernelInfo;

    ret = hrtSetDevice(0);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    DevType deviceType = DevType::DEV_TYPE_910;

    MOCKER_CPP(&ge::GEThreadLocalContext::GetOption)
    .stubs()
    .will(invoke(TaskNumGetOption_1));

    char* rank_table_file = "./llt/ace/comop/hccl/stub/workspace/ut_task_num_one_server_hcom_test.json";
    char* rank_ID = "0";

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    ge::NodePtr nodeptr(new NodeTest);
    int64_t RANK_SIZE = 4;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), "rank_size", RANK_SIZE);
    std::string tempStr = HCCL_WORLD_GROUP;
    ge::AttrUtils::SetStr(nodeptr->GetOpDesc(), "group", tempStr);

    std::string type;
    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    std::string name = HCCL_KERNEL_OP_TYPE_ALLREDUCE + "1server";
    nodeptr->GetOpDesc()->SetName(name);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLREDUCE;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_REDUCESCATTER;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_SEND;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    type = HCCL_KERNEL_OP_TYPE_BROADCAST;
    nodeptr->GetOpDesc()->SetType(type);
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    optionExecTest = "Ascend910_9391";
    ge_ret = hcomKernelInfo.CalcOpRunningParam(*nodeptr, true);
    EXPECT_EQ(ge_ret, ge::SUCCESS);

    ret = HcomDestroy();
    EXPECT_EQ(ret, HCCL_SUCCESS);
    remove(file_name_t);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_HcomOptimizeOriginalGraph)
{
    MOCKER_CPP(&HcomAllGatherFusion::RunFusionOpsAllGather)
        .stubs()
        .will(returnValue(HCCL_SUCCESS));

    HcclResult ret;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");

    u32 allgatherNum = 2;
    for (u32 i = 0; i < allgatherNum; i++) {
        // allgather
        OpInfo allgatherOpInfo;
        allgatherOpInfo.opName = std::make_pair("AllGather_" + std::to_string(i), HCCL_KERNEL_OP_TYPE_ALLGATHER);
        allgatherOpInfo.attrInt.push_back(std::make_pair("rank_size", 3));
        allgatherOpInfo.attrInt.push_back(std::make_pair("fusion", 2));
        allgatherOpInfo.attrInt.push_back(std::make_pair("fusion_id", 1));
        allgatherOpInfo.attrStr.push_back(std::make_pair("group", HCCL_WORLD_GROUP));
        ret = CreateNodePtr(allgatherOpInfo, *graph);
        EXPECT_EQ(ret, HCCL_SUCCESS);

        ge::AttrUtils::HasAttr(allgatherOpInfo.nodePtr->GetOpDesc(), "DUMMY_SET_TRUE_GROUP");
        allgatherOpInfo.nodePtr->GetOpDesc()->SetType(HCCL_KERNEL_OP_TYPE_ALLGATHER);
    }
    HcomFusionOptimizer graphOptimizer;
    ge::Status ge_ret = graphOptimizer.HcomOptimizeOriginalGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizer.OptimizeFusedGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizer.OptimizeGraphPrepare(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizer.OptimizeOriginalGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizer.OptimizeWholeGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge::GraphOptimizerAttribute attrs;
    ge_ret = graphOptimizer.GetAttributes(attrs);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    ge_ret = graphOptimizer.Finalize();
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_OptimizeOriginalGraphDynamicGraphNoSuperkernel)
{
    ge::Status ge_ret;
    MOCKER(&ge::NodeUtils::GetNodeUnknownShapeStatus)
    .stubs()
    .with(mockcpp::any(), outBound(true))
    .will(returnValue(ge::GRAPH_SUCCESS));

    HcomGraphOptimizer graphOptimizer;
    ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("test_graph");
    const string type = "HcomAllReduce";
    ge::OpDescPtr opPtr_ = std::make_shared<ge::OpDesc>();
    opPtr_->SetType(type);
    graph->AddNode(opPtr_);
    ge_ret = graphOptimizer.OptimizeOriginalGraph(*graph);
    EXPECT_EQ(ge_ret, ge::SUCCESS);
    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, ut_offlinebuild_calcSubStreamNum_1)
{
    HcomGraphOptimizer graphOptimizer;
    ge::NodePtr nodeptr(new NodeTest);

    MOCKER(HcomLoadRanktableFile)
    .stubs()
    .will(returnValue(HCCL_SUCCESS));

    std::string type = HCCL_KERNEL_OP_TYPE_ALLGATHER;
    std::string nodeName = "ALL_GATHER";
    nodeptr->GetOpDesc()->SetType(type);
    MOCKER(&ge::AttrUtils::SetInt).stubs().will(returnValue(true));
    ge::Status ret = graphOptimizer.HcomCalcOpRunningParam(*nodeptr, false);
    EXPECT_EQ(ret, HCCL_SUCCESS);

    GlobalMockObject::verify();
}

TEST_F(HcomGraphOptimizerTest, Ut_GetInfoFromContentedLibrary_When_AllParamsValid_Expect_CorrectSegmentIndex) {
    std::string fusionPath = "./test_fusion.json";
    u64 tensorFusionLimit = std::numeric_limits<u64>::max();
    std::string fusionHash = "001";
    std::vector<u32> segmentIndex;
    nlohmann::json root;
    root[0]["modelhash"] = "001";
    root[0]["modelgraphid"] = 1;
    root[0]["modelvalue"]["gradientSize"] = {16384000,4000,67108864,16384,150994944,16384,2359296,1024,3538944,1024,2654208,1536,1843200,768,139392,384};
    root[0]["modelvalue"]["gradientTime"] = {9833418590329,15080,100303,86064,408791,246415,634287,686785,103,564542,100,1686982,117497,3490117,202901,117711};
    fstream jFile;
    jFile.open(fusionPath, std::ios::out | std::ios::trunc);
    if (jFile.is_open()) {
        // 将root内容写入文件
        jFile << root;
        jFile.close();
    } else {
        // 处理文件打开失败的情况
        EXPECT_TRUE(false) << "Failed to open file: " << fusionPath;
        return;
    }

    std::vector<int>  vecint ={0,1,2,3};
    MOCKER_CPP(&EvaluatorDataParallel::run).stubs().will(returnValue(vecint));
    HcomAllReduceFusion ops;
    HcclResult result = ops.GetInfoFromContentedLibrary(fusionPath, fusionHash, tensorFusionLimit, segmentIndex);
    EXPECT_EQ(result, HCCL_SUCCESS);
    std::remove(fusionPath.c_str());

    std::string fusionPath1;
    result = ops.GetPathFromDefault(fusionPath1);
    EXPECT_EQ(result, HCCL_SUCCESS);
}

TEST_F(HcomGraphOptimizerTest, Ut_GetFusionOption_When_FusionAttrNoFusion_Expect_Success) {
    ge::NodePtr nodeptr(new NodeTest);
    nodeptr->GetOpDesc()->SetName("test_node");
    HcomAllReduceFusion allreduce;
    FusionOption fusionOption;

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_SPLIT_STRATEGY);
    HcclResult result = allreduce.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_FUSION_ID);
    result = allreduce.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, 3);
    result = allreduce.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_NO_FUSION);
    result = allreduce.GetFusionOption(nodeptr, fusionOption);
    EXPECT_EQ(result, HCCL_SUCCESS);

    HcomAlltoAllVCFusion alltoallVC;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_FUSION_ID);
    result = alltoallVC.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_SPLIT_STRATEGY);
    result = alltoallVC.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_NO_FUSION);
    result = alltoallVC.GetFusionOption(nodeptr, fusionOption);
    EXPECT_EQ(result, HCCL_SUCCESS);

    HcomAllGatherFusion allgather;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_FUSION_ID);
    result = allgather.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_SPLIT_STRATEGY);
    result = allgather.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_NO_FUSION);
    result = allgather.GetFusionOption(nodeptr, fusionOption);
    EXPECT_EQ(result, HCCL_SUCCESS);

    HcomBroadcastFusion broadcast;
    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_FUSION_ID);
    result = broadcast.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_BY_SPLIT_STRATEGY);
    result = broadcast.GetFusionOption(nodeptr, fusionOption);

    ge::AttrUtils::SetInt(nodeptr->GetOpDesc(), HCOM_ATTR_NAME_FUSION, HCOM_ATTR_FUSION_NO_FUSION);
    result = broadcast.GetFusionOption(nodeptr, fusionOption);
    EXPECT_EQ(result, HCCL_SUCCESS);
}