* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <gtest/gtest.h>
#include "graph/tensor.h"
#include <dlfcn.h>
#define private public
#define protected public
#include "tiling_api.h"
#include "platform_stub.h"
#include "include/adv_api/hccl/internal/hccl_tiling_msg.h"
#include "include/adv_api/hccl/hccl_tiling.h"
#include "include/adv_api/hccl/hccl_common.h"
#include "tiling/platform/platform_ascendc.h"
#include "impl/adv_api/tiling/hccl/hccl_symbol_loader.h"
using namespace ge;
using namespace std;
using namespace optiling;
using namespace AscendC;
using namespace HcclApi;
class TestHcclTiling : public testing::Test {
protected:
static void SetUpTestCase() {}
static void TearDownTestCase() {}
virtual void SetUp() {}
void TearDown() {}
};
// Walks through the setters of Mc2CcTilingConfig and checks that the init
// tiling and then a cc tiling can be fetched.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_normal)
{
    ::Mc2InitTiling initTiling;
    ::Mc2CcTiling ccTiling;
    string group = "test";
    string algorithm = "fullmesh";
    uint32_t cmdType = 1;
    uint32_t reduceOp = 1;
    Mc2CcTilingConfig config(group, cmdType, algorithm, reduceOp);

    // Tunables applied before the init tiling is requested.
    EXPECT_EQ(config.SetDebugMode(1U), EXIT_SUCCESS);
    EXPECT_EQ(config.SetQueueNum(40U), EXIT_SUCCESS);
    EXPECT_EQ(config.SetCommBlockNum(48U), EXIT_SUCCESS);
    const uint32_t initRet = config.GetTiling(initTiling);
    EXPECT_EQ(initRet, EXIT_SUCCESS);

    // Op type 0 is expected to be rejected; the following setters are expected to succeed.
    cmdType = 0;
    EXPECT_NE(config.SetOpType(cmdType), EXIT_SUCCESS);
    group = "test1";
    EXPECT_EQ(config.SetGroupName(group), EXIT_SUCCESS);
    algorithm = "doublering";
    EXPECT_EQ(config.SetAlgConfig(algorithm), EXIT_SUCCESS);
    reduceOp = 0;
    EXPECT_EQ(config.SetReduceType(reduceOp), EXIT_SUCCESS);

    // The expected result of SetStepSize depends on the current NPU architecture.
    uint8_t step = 1;
    const bool isDav2201 =
        platform_ascendc::PlatformAscendCManager::GetInstance()->GetCurNpuArch() == NpuArch::DAV_2201;
    if (!isDav2201) {
        EXPECT_EQ(config.SetStepSize(step), EXIT_FAILURE);
    } else {
        EXPECT_EQ(config.SetStepSize(step), EXIT_SUCCESS);
    }

    uint8_t skipLocalRank = 1;
    EXPECT_EQ(config.SetSkipLocalRankCopy(skipLocalRank), EXIT_SUCCESS);
    uint8_t skipBufferWindow = 1;
    EXPECT_EQ(config.SetSkipBufferWindowCopy(skipBufferWindow), EXIT_SUCCESS);
    uint8_t engine = 1;
    EXPECT_EQ(config.SetCommEngine(engine), EXIT_SUCCESS);

    EXPECT_EQ(config.GetTiling(ccTiling), EXIT_SUCCESS);
}
// Feeds out-of-range arguments to the setters and checks that each one is
// rejected, then checks that fetching a cc tiling fails as well.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_failed1)
{
    ::Mc2CcTiling ccTiling;
    string group = "test";
    string algorithm = "fullmesh";
    uint32_t cmdType = 1;
    uint32_t reduceOp = 1;
    Mc2CcTilingConfig config(group, cmdType, algorithm, reduceOp);

    const uint32_t allCmdType = static_cast<uint32_t>(HcclCMDType::HCCL_CMD_ALL);
    EXPECT_NE(config.SetOpType(allCmdType), EXIT_SUCCESS);

    // Boundary strings: going by their names, 129 and 128 characters long.
    // The longer one is expected to be rejected, the shorter one accepted.
    string value129 = "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678";
    string value128 = "01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567";
    EXPECT_NE(config.SetGroupName(value129), EXIT_SUCCESS);
    EXPECT_EQ(config.SetGroupName(value128), EXIT_SUCCESS);
    EXPECT_NE(config.SetAlgConfig(value129), EXIT_SUCCESS);
    EXPECT_EQ(config.SetAlgConfig(value128), EXIT_SUCCESS);

    // Values the setters are expected to refuse.
    EXPECT_NE(config.SetReduceType(HCCL_REDUCE_RESERVED), EXIT_SUCCESS);
    EXPECT_NE(config.SetSkipLocalRankCopy(2), EXIT_SUCCESS);
    EXPECT_NE(config.SetSkipBufferWindowCopy(3), EXIT_SUCCESS);

    // No init tiling was fetched in this test before asking for a cc tiling.
    EXPECT_NE(config.GetTiling(ccTiling), EXIT_SUCCESS);
}
// An ALLREDUCE config built with the reserved reduce type must fail to produce
// an init tiling; after switching the op type to SEND the same call must succeed.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_failed2)
{
    ::Mc2InitTiling initTiling;
    string group = "test";
    string algorithm = "fullmesh";
    uint32_t cmdType = static_cast<uint32_t>(HcclCMDType::HCCL_CMD_ALLREDUCE);
    uint32_t reduceOp = HCCL_REDUCE_RESERVED;
    Mc2CcTilingConfig config(group, cmdType, algorithm, reduceOp);

    const uint32_t allReduceRet = config.GetTiling(initTiling);
    EXPECT_NE(allReduceRet, EXIT_SUCCESS);

    const uint32_t sendCmdType = static_cast<uint32_t>(HcclCMDType::HCCL_CMD_SEND);
    EXPECT_EQ(config.SetOpType(sendCmdType), EXIT_SUCCESS);

    const uint32_t sendRet = config.GetTiling(initTiling);
    EXPECT_EQ(sendRet, EXIT_SUCCESS);
}
// Requesting a cc tiling from a freshly constructed config, without fetching
// the init tiling first, is expected to fail.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_failed3)
{
    // Only the cc tiling is needed here; the previously declared init tiling
    // local was never used and has been removed.
    ::Mc2CcTiling ccTilingInner;
    string groupName = "test";
    uint32_t opType = 1;
    string algConfig = "fullmesh";
    uint32_t reduceType = 1;
    Mc2CcTilingConfig ccTilingConfig(groupName, opType, algConfig, reduceType);
    EXPECT_NE(ccTilingConfig.GetTiling(ccTilingInner), EXIT_SUCCESS);
}
// For a REDUCE_SCATTER config, SetReduceType with FP16 for both data types is
// expected to succeed, and to fail when either data type is out of range.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_SetReduceType_ReduceOp)
{
    const char *groupName = "testGroup";
    uint32_t opType = static_cast<uint32_t>(HcclCMDType::HCCL_CMD_REDUCE_SCATTER);
    std::string algConfig = "ReduceScatter=level0:doublering";
    uint32_t reduceType = static_cast<uint32_t>(HcclReduceOp::HCCL_REDUCE_RESERVED);
    // Cast straight to the declared uint8_t type instead of going through
    // uint32_t and relying on an implicit narrowing conversion.
    const uint8_t fp16DataType = static_cast<uint8_t>(HcclDataType::HCCL_DATA_TYPE_FP16);
    // Same value the former `= -1` assignment produced (all bits set), now explicit.
    const uint8_t invalidDataType = static_cast<uint8_t>(-1);
    Mc2CcTilingConfig mc2CcTilingConfig(groupName, opType, algConfig, reduceType, fp16DataType, fp16DataType);

    // Both data types valid.
    EXPECT_EQ(mc2CcTilingConfig.SetReduceType(HcclReduceOp::HCCL_REDUCE_SUM, fp16DataType, fp16DataType),
        EXIT_SUCCESS);
    // Invalid value in the third argument position.
    EXPECT_EQ(mc2CcTilingConfig.SetReduceType(HcclReduceOp::HCCL_REDUCE_SUM, fp16DataType, invalidDataType),
        EXIT_FAILURE);
    // Invalid value in the second argument position.
    EXPECT_EQ(mc2CcTilingConfig.SetReduceType(HcclReduceOp::HCCL_REDUCE_SUM, invalidDataType, fp16DataType),
        EXIT_FAILURE);
}
// For an ALLGATHER config, SetReduceType with FP16 for both data types is
// expected to succeed.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_SetReduceType_NotReduceOp)
{
    const char *groupName = "testGroup";
    uint32_t opType = static_cast<uint32_t>(HcclCMDType::HCCL_CMD_ALLGATHER);
    std::string algConfig = "AllGather=level0:doublering";
    uint32_t reduceType = static_cast<uint32_t>(HcclReduceOp::HCCL_REDUCE_RESERVED);
    // Cast straight to the declared uint8_t type instead of going through
    // uint32_t and relying on an implicit narrowing conversion.
    const uint8_t srcDataType = static_cast<uint8_t>(HcclDataType::HCCL_DATA_TYPE_FP16);
    const uint8_t dstDataType = static_cast<uint8_t>(HcclDataType::HCCL_DATA_TYPE_FP16);
    Mc2CcTilingConfig mc2CcTilingConfig(groupName, opType, algConfig, reduceType, srcDataType, dstDataType);
    EXPECT_EQ(mc2CcTilingConfig.SetReduceType(HcclReduceOp::HCCL_REDUCE_SUM, srcDataType, dstDataType), EXIT_SUCCESS);
}
// After the init tiling, eight cc-tiling requests are expected to succeed and
// the ninth is expected to fail.
TEST_F(TestHcclTiling, Mc2CcTilingConfig_multiTiling)
{
    ::Mc2InitTiling initTiling;
    ::Mc2CcTiling ccTiling;
    Mc2CcTilingConfig config("test", 1, "fullmesh", 1);
    EXPECT_EQ(config.GetTiling(initTiling), EXIT_SUCCESS);

    constexpr uint32_t acceptedCcTilings = 8;
    uint32_t fetched = 0;
    while (fetched < acceptedCcTilings) {
        EXPECT_EQ(config.GetTiling(ccTiling), EXIT_SUCCESS);
        ++fetched;
    }

    // One request beyond the accepted count.
    EXPECT_NE(config.GetTiling(ccTiling), EXIT_SUCCESS);
}
// Fixture that snapshots ASCEND_HOME_PATH before each test and puts it back
// afterwards, so tests may freely modify or remove the variable.
class TestHcclSymbolLoader : public testing::Test {
protected:
    static void SetUpTestCase() {}
    static void TearDownTestCase() {}

    // Records whether the variable was set and, if so, its value.
    void SetUp() override
    {
        const char *current = std::getenv("ASCEND_HOME_PATH");
        if (current == nullptr) {
            hasOldEnv_ = false;
            return;
        }
        hasOldEnv_ = true;
        oldEnv_ = current;
    }

    // Restores the recorded value, or removes the variable if it was not set.
    void TearDown() override
    {
        if (!hasOldEnv_) {
            unsetenv("ASCEND_HOME_PATH");
            return;
        }
        setenv("ASCEND_HOME_PATH", oldEnv_.c_str(), 1);
    }

    bool hasOldEnv_ = false;  // true when ASCEND_HOME_PATH existed at SetUp time
    std::string oldEnv_;      // value captured at SetUp time (meaningful only if hasOldEnv_)
};
// Loading rtGetSocVersion from libruntime.so under $ASCEND_HOME_PATH/lib64/
// is expected to yield a non-null function pointer.
TEST_F(TestHcclSymbolLoader, Load_validPath_success)
{
    using SocVersionFn = void (*)(char*, uint32_t);

    // The test requires ASCEND_HOME_PATH to be set.
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    ASSERT_NE(ascendHome, nullptr);

    std::string libDir(ascendHome);
    libDir += "/lib64/";

    auto symbol = HcclSymbolLoader::GetInstance().Load<SocVersionFn>("libruntime.so", "rtGetSocVersion", libDir);
    EXPECT_NE(symbol, nullptr);
}
// Two identical Load requests are expected to return the same non-null pointer.
TEST_F(TestHcclSymbolLoader, Load_cached_returnsSamePointer)
{
    using SocVersionFn = void (*)(char*, uint32_t);

    // The test requires ASCEND_HOME_PATH to be set.
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    ASSERT_NE(ascendHome, nullptr);

    std::string libDir(ascendHome);
    libDir += "/lib64/";

    auto &loader = HcclSymbolLoader::GetInstance();
    auto firstLookup = loader.Load<SocVersionFn>("libruntime.so", "rtGetSocVersion", libDir);
    auto secondLookup = loader.Load<SocVersionFn>("libruntime.so", "rtGetSocVersion", libDir);

    EXPECT_NE(firstLookup, nullptr);
    EXPECT_EQ(firstLookup, secondLookup);
}
// A library name and directory that do not exist are expected to produce a null result.
TEST_F(TestHcclSymbolLoader, Load_invalidPath_returnsNull)
{
    using SocVersionFn = void (*)(char*, uint32_t);

    std::string missingDir("/path/that/does/not/exist/");
    auto symbol = HcclSymbolLoader::GetInstance().Load<SocVersionFn>(
        "libruntime_not_exist.so", "rtGetSocVersion", missingDir);
    EXPECT_EQ(symbol, nullptr);
}
// An empty directory combined with a non-existent library name is expected to
// produce a null result.
TEST_F(TestHcclSymbolLoader, Load_emptyPath_returnsNull)
{
    using SocVersionFn = void (*)(char*, uint32_t);

    std::string noDir;  // default-constructed, i.e. the empty string
    auto symbol = HcclSymbolLoader::GetInstance().Load<SocVersionFn>(
        "libruntime_not_exist.so", "rtGetSocVersion", noDir);
    EXPECT_EQ(symbol, nullptr);
}
// Asking libruntime.so for a symbol name that should not exist is expected to
// produce a null result.
TEST_F(TestHcclSymbolLoader, Load_invalidSymbol_returnsNull)
{
    using SocVersionFn = void (*)(char*, uint32_t);

    // The test requires ASCEND_HOME_PATH to be set.
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    ASSERT_NE(ascendHome, nullptr);

    std::string libDir(ascendHome);
    libDir += "/lib64/";

    auto symbol = HcclSymbolLoader::GetInstance().Load<SocVersionFn>(
        "libruntime.so", "rtSymbolDefinitelyNotExist", libDir);
    EXPECT_EQ(symbol, nullptr);
}
// With ASCEND_HOME_PATH present, fetching the init tiling is expected to succeed.
TEST_F(TestHcclSymbolLoader, GetInitTiling_withValidAscendHome_success)
{
    // The test requires ASCEND_HOME_PATH to be set.
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    ASSERT_NE(ascendHome, nullptr);

    ::Mc2InitTiling initTiling;
    Mc2CcTilingConfig config("test", 1, "fullmesh", 1);
    const auto status = config.GetTiling(initTiling);
    EXPECT_EQ(status, EXIT_SUCCESS);
}
// With ASCEND_HOME_PATH removed from the environment, fetching the init tiling
// is expected to fail. The fixture's TearDown restores the variable afterwards.
TEST_F(TestHcclSymbolLoader, GetInitTiling_ascendHomeUnset_failure)
{
    unsetenv("ASCEND_HOME_PATH");
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    ASSERT_EQ(ascendHome, nullptr);

    ::Mc2InitTiling initTiling;
    Mc2CcTilingConfig config("test", 1, "fullmesh", 1);
    const auto status = config.GetTiling(initTiling);
    EXPECT_EQ(status, EXIT_FAILURE);
}
// With ASCEND_HOME_PATH set to the empty string, fetching the init tiling is
// expected to fail. The fixture's TearDown restores the variable afterwards.
TEST_F(TestHcclSymbolLoader, GetInitTiling_ascendHomeEmpty_failure)
{
    setenv("ASCEND_HOME_PATH", "", 1);

    // Confirm the variable exists and is empty before exercising the config.
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    ASSERT_NE(ascendHome, nullptr);
    ASSERT_EQ(ascendHome[0], '\0');

    ::Mc2InitTiling initTiling;
    Mc2CcTilingConfig config("test", 1, "fullmesh", 1);
    const auto status = config.GetTiling(initTiling);
    EXPECT_EQ(status, EXIT_FAILURE);
}