* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "gtest/gtest.h"
#include <mockcpp/mockcpp.hpp>
#include <stdio.h>
#define protected public
#define private public
#include "alg_template_register.h"
#include "alg_template_base_pub.h"
#include "alltoallv_staged_base_pub.h"
#include "alltoallv_for_310p_pub.h"
#undef private
#undef protected
using namespace std;
using namespace hccl;
class AlgTemplateBaseTest : public testing::Test {
protected:
static void SetUpTestCase()
{
std::cout << "AlgTemplateBaseTest Testcase SetUP" << std::endl;
}
static void TearDownTestCase()
{
std::cout << "AlgTemplateBaseTest Testcase TearDown" << std::endl;
}
virtual void SetUp()
{
std::cout << "AlgTemplateBaseTest SetUP" << std::endl;
}
virtual void TearDown()
{
GlobalMockObject::verify();
std::cout << "AlgTemplateBaseTest TearDown" << std::endl;
}
};
/**
 * Calls each Prepare overload exposed by AlgTemplateBase directly on a base-class
 * instance and expects every one of them to report HCCL_E_PARA.
 *
 * All arguments are default-constructed placeholders; the test only checks the
 * returned status code, not any effect of the call.
 */
TEST_F(AlgTemplateBaseTest, alg_template_base_prepare)
{
    HcclDispatcher disp;
    struct PrepareData params;
    DeviceMem devMem;
    Stream mainStream;
    std::vector<Stream> subStreams;
    std::vector<LINK> links;
    std::vector<SendRecvInfo> sendrecvInfo;
    std::vector<std::shared_ptr<LocalNotify>> signals;
    std::map<u32, std::vector<u64>> displsMap;
    StageAlltoAllVAddrInfo alltoallAddr;
    AlltoAllVBufferInfo alltoallBuf;
    // Scalars are passed by value into Prepare below, so they must hold a
    // determinate value; zero matches the other tests in this file.
    u64 attr = 0;
    u32 idx = 0;
    HcomCollOpInfo opInfo;
    std::vector<u32> order;
    std::vector<std::vector<u32>> multiOrders;
    std::vector<Slice> slices;
    std::vector<std::vector<Slice>> multiSlices;
    SubCommInfo comInfo;
    AlgTemplateBase tempBase(disp);

    // Overload taking the aggregated PrepareData structure.
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(params));

    // Overloads exercised with all-to-all-v style arguments.
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, devMem, signals, signals,
        mainStream, subStreams, links, 0, 0, sendrecvInfo));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(alltoallBuf, alltoallBuf, true, mainStream,
        HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE, displsMap, displsMap));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(alltoallBuf, alltoallBuf, devMem, devMem, true, mainStream,
        HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE, displsMap, displsMap));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, alltoallAddr, alltoallAddr,
        true, mainStream));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, devMem, alltoallAddr, alltoallAddr,
        true, mainStream));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, alltoallAddr, alltoallAddr,
        true, 0, mainStream, subStreams, signals, signals));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, devMem, alltoallAddr, alltoallAddr,
        true, 0, mainStream, subStreams, signals, signals));

    // Two-argument overloads distinguished only by the type of the second argument.
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(attr, nullptr));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(attr, false));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(attr, idx));

    // Remaining overloads, exercised with op info, slices, sub-streams and notify signals.
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(attr, &opInfo, idx, subStreams, signals, signals,
        order, slices, false));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(&opInfo, devMem, attr, attr, attr,
        comInfo, comInfo, mainStream, subStreams, signals, signals, attr));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, order, idx));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        subStreams, signals, signals, idx, &opInfo));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(mainStream, comInfo, devMem, devMem, devMem, devMem,
        attr, subStreams, signals, signals, HcclDataType::HCCL_DATA_TYPE_INT8, HcclReduceOp::HCCL_REDUCE_SUM,
        multiSlices, attr));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, idx, attr, UserMemType::INPUT_MEM, UserMemType::INPUT_MEM));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, &opInfo));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, idx, &opInfo));
    EXPECT_EQ(HcclResult::HCCL_E_PARA, tempBase.Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, multiSlices, HcclReduceOp::HCCL_REDUCE_SUM,
        idx, attr, false, attr, &opInfo, idx, subStreams, signals, signals, multiOrders, multiSlices));
}
/**
 * Builds an AlltoAllVStagedBase directly and checks that its six-argument
 * Prepare overload (two DeviceMem, two StageAlltoAllVAddrInfo, a bool and the
 * main stream) returns HCCL_SUCCESS when given default-constructed inputs.
 */
TEST_F(AlgTemplateBaseTest, alltoallv_staged_base_prepare)
{
    HcclDispatcher dispatcher;
    DeviceMem mem;
    StageAlltoAllVAddrInfo addrInfo;
    Stream stream;
    AlltoAllVStagedBase stagedTemplate(dispatcher);

    // The same objects are deliberately reused for the paired parameters.
    const HcclResult prepareRet = stagedTemplate.Prepare(mem, mem, addrInfo, addrInfo, true, stream);

    EXPECT_EQ(HcclResult::HCCL_SUCCESS, prepareRet);
}
/**
 * Obtains all-to-all-v templates from AlgTemplateRegistry and checks the result
 * of the Prepare overloads invoked on them:
 *  - TEMPLATE_ALL_2_ALL_V_FOR310P: the 12-argument overload must succeed;
 *  - TEMPLATE_ALL_2_ALL_V_STAGED_MESH: the 10-argument overload must succeed,
 *    while the 12-argument overload (with two extra DeviceMem) must not.
 *
 * Each registry lookup is guarded by a fatal assertion so that a missing
 * registration fails this test instead of dereferencing a null pointer.
 */
TEST_F(AlgTemplateBaseTest, alltoallv_template_instance_init)
{
    HcclDispatcher disp;
    DeviceMem devMem = DeviceMem::create(nullptr, 1024);
    Stream mainStream;
    std::vector<Stream> subStreams(2);
    std::vector<LINK> links;
    std::vector<SendRecvInfo> sendrecvInfo;
    std::shared_ptr<LocalNotify> s0 = std::make_shared<LocalNotify>();
    std::shared_ptr<LocalNotify> s1 = std::make_shared<LocalNotify>();
    std::vector<std::shared_ptr<LocalNotify>> signals = {s0, s1};
    StageAlltoAllVAddrInfo alltoallAddr;
    std::unique_ptr<AlgTemplateBase> tempAlg;

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_ALL_2_ALL_V_FOR310P, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    HcclResult ret = tempAlg->Prepare(devMem, devMem, devMem, devMem, signals, signals,
        mainStream, subStreams, links, 0, 1, sendrecvInfo);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, ret);

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_ALL_2_ALL_V_STAGED_MESH, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, alltoallAddr, alltoallAddr,
        false, 0, mainStream, subStreams, signals, signals));
    // The overload with the additional pair of DeviceMem arguments is expected to be rejected here.
    EXPECT_NE(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, devMem, alltoallAddr, alltoallAddr,
        false, 0, mainStream, subStreams, signals, signals));
}
/**
 * Obtains each reduce-scatter template type from AlgTemplateRegistry and checks
 * that the Prepare overload(s) invoked on it return HCCL_SUCCESS.
 *
 * Inputs are mostly default-constructed placeholders with zeroed scalars; only
 * opInfo's data type / reduce op and comInfo.localRankSize are set explicitly.
 * Each registry lookup is guarded by a fatal assertion so that a missing
 * registration fails this test instead of dereferencing a null pointer.
 */
TEST_F(AlgTemplateBaseTest, reducescatter_template_instance_init)
{
    HcclDispatcher disp;
    DeviceMem devMem;
    Stream mainStream;
    std::vector<Stream> subStreams(2);
    std::vector<std::shared_ptr<LocalNotify>> signals(2);
    u64 attr = 0ULL;
    u32 idx = 0UL;
    HcomCollOpInfo opInfo;
    std::vector<Slice> slices;
    std::vector<std::vector<Slice>> multiSlices(2);
    SubCommInfo comInfo;
    opInfo.dataType = HcclDataType::HCCL_DATA_TYPE_INT8;
    opInfo.reduceOp = HcclReduceOp::HCCL_REDUCE_SUM;
    comInfo.localRankSize = 0;
    std::unique_ptr<AlgTemplateBase> tempAlg;

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_HDSTAGE, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, &opInfo));

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_LOCAL_REDUCE, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, &opInfo));

    // The NB template is prepared through two different overloads in sequence.
    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_NB, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(attr));
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM));

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_PIPELINE, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(&opInfo, devMem, attr, attr, attr, comInfo, comInfo,
        mainStream, subStreams, signals, signals, attr));

    // The unified-march case is run with two slices in the first group and a local rank size of 2.
    multiSlices[0].resize(2);
    comInfo.localRankSize = 2;
    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_UNIFIED_MARCH, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(mainStream, comInfo, devMem, devMem, devMem, devMem,
        attr, subStreams, signals, signals, HcclDataType::HCCL_DATA_TYPE_INT8, HcclReduceOp::HCCL_REDUCE_SUM,
        multiSlices, attr));

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_MESH_DIRECT, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, &opInfo));

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_MESH_ATOMIC, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, &opInfo));

    // The MESH_MIX_SS template is prepared through two different overloads in sequence.
    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_MESH_MIX_SS, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(attr, idx));
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx, slices, attr));

    tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_MESH_MIX, disp);
    ASSERT_TRUE(tempAlg != nullptr);
    EXPECT_EQ(HcclResult::HCCL_SUCCESS, tempAlg->Prepare(devMem, devMem, devMem, attr,
        HcclDataType::HCCL_DATA_TYPE_INT8, mainStream, HcclReduceOp::HCCL_REDUCE_SUM, idx,
        slices, attr, attr, subStreams, signals, signals, idx, idx, &opInfo));
}