* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <gtest/gtest.h>
#include "CommunicationProtocolRequest.h"
#include "DataBaseManager.h"
#include "ClusterDomainObject.h"
#include "ClusterDef.h"
#include "../../TestSuit.h"
// Fixture type for the communication tests, derived from the shared TestSuit
// fixture. The inheritance must be public: with the default (private)
// inheritance of `class`, googletest cannot convert the generated test class
// to its fixture base, so TEST_F(CommunicationTest, ...) would not compile.
// NOTE(review): the tests visible in this file use TestSuit directly.
class CommunicationTest : public TestSuit {};
// Checks that QueryIterations on the cluster database yields exactly one
// entry and that its id is "2".
TEST_F(TestSuit, QueryIterationsData) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    std::vector<Dic::Protocol::IterationsOrRanksObject> responseBody;
    database->QueryIterations(responseBody);
    const std::size_t expectSize = 1;
    // Fatal check: indexing below would be undefined behaviour if the
    // result were empty, so stop the test here on a size mismatch.
    ASSERT_EQ(responseBody.size(), expectSize);
    EXPECT_EQ(responseBody[0].iterationOrRankId, "2");
}
// Checks that QueryOperatorNames for iteration "2" and the given group hash
// yields two operator names, in the order "Total Op Info", "hcom_send__822_0".
TEST_F(TestSuit, QueryOperatorNameData) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    Dic::Protocol::OperatorNamesParams requestParams;
    requestParams.iterationId = "2";
    requestParams.groupIdHash = "13337953553061386822";
    std::vector<Dic::Protocol::OperatorNamesObject> responseBody;
    database->QueryOperatorNames(requestParams, responseBody);
    const std::size_t expectSize = 2;
    // Fatal check: elements 0 and 1 are accessed below, which would be
    // out of bounds if fewer entries came back.
    ASSERT_EQ(responseBody.size(), expectSize);
    EXPECT_EQ(responseBody[0].operatorName, "Total Op Info");
    EXPECT_EQ(responseBody[1].operatorName, "hcom_send__822_0");
}
// Queries step and communication group for kernel "hcom_broadcast__483_0"
// on rank "0" and compares both fields with the expected strings.
TEST_F(TestSuit, QueryIterationAndCommunicationGroupSuccess) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    Dic::Protocol::CommunicationKernelParams kernelQuery;
    kernelQuery.name = "hcom_broadcast__483_0";
    kernelQuery.rankId = "0";

    Dic::Protocol::CommunicationKernelBody kernelInfo;
    clusterDb->QueryIterationAndCommunicationGroup(kernelQuery, kernelInfo);

    EXPECT_EQ(kernelInfo.step, "2");
    EXPECT_EQ(kernelInfo.group, "(0, 1, 2, 3, 4, 5, 6, 7)");
}
// QueryOperatorNames with the group hash "test" is expected to leave the
// output list empty.
TEST_F(TestSuit, QueryOperatorNamesFailWithErrorGroupIdHash) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    Dic::Protocol::OperatorNamesParams query;
    query.iterationId = "2";
    query.groupIdHash = "test";

    std::vector<Dic::Protocol::OperatorNamesObject> operatorNames;
    clusterDb->QueryOperatorNames(query, operatorNames);

    const std::size_t expectedCount = 0;
    EXPECT_EQ(operatorNames.size(), expectedCount);
}
// The parse status reported by the cluster database must be "FINISH".
TEST_F(TestSuit, QueryParseClusterStatusSuccess) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);
    const std::string parseStatus = clusterDb->QueryParseClusterStatus();
    EXPECT_EQ(parseStatus, "FINISH");
}
// Checks that QueryRanksHandler yields 16 entries and that the first two
// carry the ids "0" and "1".
TEST_F(TestSuit, QueryRanksData) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    std::vector<Dic::Protocol::IterationsOrRanksObject> responseBody;
    database->QueryRanksHandler(responseBody);
    const std::size_t expectSize = 16;
    // Fatal check: elements 0 and 1 are read below and must exist.
    ASSERT_EQ(responseBody.size(), expectSize);
    EXPECT_EQ(responseBody[0].iterationOrRankId, "0");
    EXPECT_EQ(responseBody[1].iterationOrRankId, "1");
}
// QueryDurationList for operator "hcom_send__822_0" in iteration "2" is
// expected to produce two duration records.
TEST_F(TestSuit, QueryDurationData) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    Dic::Protocol::DurationListParams query;
    query.iterationId = "2";
    query.operatorName = "hcom_send__822_0";
    query.groupIdHash = "13337953553061386822";

    std::vector<Dic::Module::DurationDo> durations;
    clusterDb->QueryDurationList(query, durations);

    const std::size_t expectedCount = 2;
    EXPECT_EQ(durations.size(), expectedCount);
}
// QueryDistributionData for the HCCS transport of "hcom_broadcast__483_1" on
// rank "1" must return the expected serialized distribution string.
TEST_F(TestSuit, QueryBandwidthDistributionData) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    Dic::Protocol::DistributionDataParam query;
    Dic::Protocol::DistributionResBody distribution;
    query.iterationId = "2";
    query.stage = "(0, 1, 2, 3, 4, 5, 6, 7)";
    query.operatorName = "hcom_broadcast__483_1";
    query.transportType = "HCCS";
    query.rankId = "1";
    query.groupIdHash = "9350434047717501483";

    clusterDb->QueryDistributionData(query, distribution);

    // Raw literal: same characters as the escaped form, without backslashes.
    const std::string expected = R"({"0.016512":[7,0.01114],"0.015504":[1,0.00166]})";
    EXPECT_EQ(distribution.distributionData, expected);
}
// Checks that QueryBandwidthData for "hcom_send__822_0" on rank "1" yields
// four items and validates the fields of the first one.
TEST_F(TestSuit, QueryBandwidthData) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    Dic::Protocol::BandwidthDataParam requestParams;
    Dic::Protocol::BandwidthDataResBody responseBody;
    requestParams.iterationId = "2";
    requestParams.operatorName = "hcom_send__822_0";
    requestParams.rankId = "1";
    requestParams.groupIdHash = "13337953553061386822";
    database->QueryBandwidthData(requestParams, responseBody);
    const std::size_t expectSize = 4;
    // Fatal check: items[0] is read below and must exist.
    ASSERT_EQ(responseBody.items.size(), expectSize);
    EXPECT_EQ(responseBody.items[0].transportType, "RDMA");
    // Floating-point fields are compared with EXPECT_DOUBLE_EQ (ULP-based)
    // rather than bitwise equality; it stays far tighter than the fourth
    // decimal place these expectations are written to.
    EXPECT_DOUBLE_EQ(responseBody.items[0].transitSize, 20.9715);
    EXPECT_DOUBLE_EQ(responseBody.items[0].transitTime, 0.8638);
    EXPECT_DOUBLE_EQ(responseBody.items[0].bandwidth, 24.2777);
    EXPECT_DOUBLE_EQ(responseBody.items[0].largePacketRatio, 0.0);
}
// QueryBandwidthData with iteration "100" and stage "p2p" is expected to
// return no items.
TEST_F(TestSuit, QueryBandwidthDataWithErrorParamReturnExpectSize0) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    Dic::Protocol::BandwidthDataParam query;
    Dic::Protocol::BandwidthDataResBody bandwidth;
    query.iterationId = "100";
    query.stage = "p2p";
    query.operatorName = "Total Op Info";
    query.rankId = "1";

    clusterDb->QueryBandwidthData(query, bandwidth);

    const std::size_t expectedCount = 0;
    EXPECT_EQ(bandwidth.items.size(), expectedCount);
}
// Runs QueryOperatorsCount twice for iteration "2" / rank "1", once per
// group hash, expecting a count of 1 and then 2. The same request and
// response objects are reused for both calls.
TEST_F(TestSuit, QueryOperatorsCount) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    Dic::Protocol::OperatorDetailsParam query;
    Dic::Protocol::OperatorDetailsResBody details;

    // First group hash.
    query.iterationId = "2";
    query.rankId = "1";
    query.groupIdHash = "13337953553061386822";
    clusterDb->QueryOperatorsCount(query, details);
    const int firstGroupCount = 1;
    EXPECT_EQ(details.count, firstGroupCount);

    // Second group hash; every field is assigned again before the call.
    query.iterationId = "2";
    query.rankId = "1";
    query.groupIdHash = "9350434047717501483";
    clusterDb->QueryOperatorsCount(query, details);
    const int secondGroupCount = 2;
    EXPECT_EQ(details.count, secondGroupCount);
}
// GetGroups is expected to fill the list with 28 group records.
TEST_F(TestSuit, GetCommunicationGroups) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::vector<Dic::Module::GroupInfoDo> groups;
    clusterDb->GetGroups(groups);

    const std::size_t expectedCount = 28;
    EXPECT_EQ(groups.size(), expectedCount);
}
// GetAllRankFromStepStatisticInfo is expected to return 16 rank strings.
TEST_F(TestSuit, GetAllRankFromStepStatisticInfoSuccess) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    const std::vector<std::string> ranks = clusterDb->GetAllRankFromStepStatisticInfo();

    const std::size_t expectedCount = 16;
    EXPECT_EQ(ranks.size(), expectedCount);
}
// Checks that QueryMatrixList for "hcom_send__822_0" in iteration "2" yields
// two matrix entries and validates the fields of the second one.
TEST_F(TestSuit, QueryMatrixData) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    Dic::Protocol::MatrixBandwidthParam requestParams;
    std::vector<Dic::Module::MatrixInfoDo> matrixList;
    requestParams.iterationId = "2";
    requestParams.operatorName = "hcom_send__822_0";
    requestParams.groupIdHash = "13337953553061386822";
    database->QueryMatrixList(requestParams, matrixList);
    const std::size_t expectSize = 2;
    // Fatal check: matrixList[1] is read below and must exist.
    ASSERT_EQ(matrixList.size(), expectSize);
    EXPECT_EQ(matrixList[1].srcRank, 0);
    EXPECT_EQ(matrixList[1].dstRank, 8);
    EXPECT_EQ(matrixList[1].transportType, "RDMA");
    // ULP-based comparison instead of bitwise equality on floating-point
    // values; still strict enough to tell apart the fourth decimal place.
    EXPECT_DOUBLE_EQ(matrixList[1].transitSize, 20.9715);
    EXPECT_DOUBLE_EQ(matrixList[1].transitTime, 0.8638);
    EXPECT_DOUBLE_EQ(matrixList[1].bandwidth, 24.2778);
}
// Checks that QueryAllOperators (page 1, page size 10) for iteration "2",
// rank "1" yields one operator and validates its timing fields.
TEST_F(TestSuit, QueryAllCommunicationOperatorsDetails) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    Dic::Protocol::OperatorDetailsParam requestParams;
    Dic::Protocol::OperatorDetailsResBody responseBody;
    requestParams.iterationId = "2";
    requestParams.rankId = "1";
    requestParams.groupIdHash = "13337953553061386822";
    requestParams.currentPage = 1;
    requestParams.pageSize = 10;
    database->QueryAllOperators(requestParams, responseBody);
    const std::size_t expectSize = 1;
    // Fatal check: allOperators[0] is read below and must exist.
    ASSERT_EQ(responseBody.allOperators.size(), expectSize);
    EXPECT_EQ(responseBody.allOperators[0].operatorName, "hcom_send__822_0");
    // Fractional expectations use EXPECT_DOUBLE_EQ (ULP-based) rather than
    // bitwise equality on floating-point values.
    EXPECT_DOUBLE_EQ(responseBody.allOperators[0].elapseTime, 0.8966);
    EXPECT_DOUBLE_EQ(responseBody.allOperators[0].transitTime, 0.8638);
    EXPECT_EQ(responseBody.allOperators[0].synchronizationTime, 0);
    EXPECT_EQ(responseBody.allOperators[0].waitTime, 0);
    EXPECT_DOUBLE_EQ(responseBody.allOperators[0].idleTime, 0.0327);
    EXPECT_DOUBLE_EQ(responseBody.allOperators[0].synchronizationTimeRatio, 0.0);
    EXPECT_DOUBLE_EQ(responseBody.allOperators[0].waitTimeRatio, 0.0);
}
// QueryMatrixSortOpNames with an empty rank list is expected to return two
// operator names for iteration "2" and the given group hash.
TEST_F(TestSuit, QueryMatrixSortOpNames) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::vector<Dic::Protocol::OperatorNamesObject> operatorNames;
    Dic::Protocol::OperatorNamesParams query;
    query.iterationId = "2";
    query.groupIdHash = "9350434047717501483";
    query.rankList = {};

    clusterDb->QueryMatrixSortOpNames(query, operatorNames);

    const std::size_t expectedCount = 2;
    EXPECT_EQ(operatorNames.size(), expectedCount);
}
// Checks QueryAllPerformanceDataByStep for a single step: step "2" must
// yield 16 entries (with key 0 having prepareTime 0), step "3" must succeed
// and yield none.
TEST_F(TestSuit, QueryAllPerformanceDataByStepWhenSingleStep) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    std::string step = "2";
    std::unordered_map<std::uint32_t, Dic::Module::StepStatistic> data{};
    auto result = database->QueryAllPerformanceDataByStep(step, data);
    // Fatal checks: do not inspect the map if the query itself failed.
    ASSERT_TRUE(result);
    ASSERT_EQ(data.size(), 16U);
    // data.at(0) throws std::out_of_range when key 0 is absent; report that
    // as a normal assertion failure instead.
    ASSERT_EQ(data.count(0U), 1U);
    EXPECT_EQ(data.at(0).prepareTime, 0);
    step = "3";
    data.clear();
    result = database->QueryAllPerformanceDataByStep(step, data);
    EXPECT_TRUE(result);
    EXPECT_EQ(data.size(), 0U);
}
// GetCommTimeForRankDim for step "2" is expected to return four records.
TEST_F(TestSuit, GetCommTimeForRankDimByStepWhenSingleStep) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::string stepId = "2";
    const std::vector<Dic::Module::CommInfoUnderRank> commInfos = clusterDb->GetCommTimeForRankDim(stepId);

    const std::size_t expectedCount = 4;
    EXPECT_EQ(commInfos.size(), expectedCount);
}
// GetCommTimeForRankDim with an empty step string is expected to return
// four records.
TEST_F(TestSuit, GetCommTimeForRankDimByStepWhenAllStep) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::string stepId = "";
    const std::vector<Dic::Module::CommInfoUnderRank> commInfos = clusterDb->GetCommTimeForRankDim(stepId);

    const std::size_t expectedCount = 4;
    EXPECT_EQ(commInfos.size(), expectedCount);
}
// Checks QueryAllPerformanceDataByStep with the step given as "" and as
// "All": both calls must succeed, yield 16 entries, and report the same
// prepareTime / freeTime for key 0.
TEST_F(TestSuit, QueryAllPerformanceDataByStepWhenAllStep) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    std::string step = "";
    std::unordered_map<std::uint32_t, Dic::Module::StepStatistic> data{};
    auto result = database->QueryAllPerformanceDataByStep(step, data);
    // Fatal checks: data.at(0) throws std::out_of_range when key 0 is
    // absent, so confirm the query worked and the key exists first.
    ASSERT_TRUE(result);
    ASSERT_EQ(data.size(), 16U);
    ASSERT_EQ(data.count(0U), 1U);
    EXPECT_EQ(data.at(0).prepareTime, 0);
    EXPECT_NEAR(data.at(0).freeTime, 69909.504, 0.0001);
    step = "All";
    data.clear();
    result = database->QueryAllPerformanceDataByStep(step, data);
    ASSERT_TRUE(result);
    ASSERT_EQ(data.size(), 16U);
    ASSERT_EQ(data.count(0U), 1U);
    EXPECT_EQ(data.at(0).prepareTime, 0);
    EXPECT_NEAR(data.at(0).freeTime, 69909.504, 0.0001);
}
// QueryPacketAnalyzerData must succeed and return five records whose types
// are, in order: RDMA, RDMA, SDMA, SDMA, SDMA.
TEST_F(TestSuit, QueryPacketAnalyzerDataTest) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::vector<Dic::Module::PacketAnalyzerData> packets;
    const bool ok = clusterDb->QueryPacketAnalyzerData(packets);
    ASSERT_TRUE(ok);

    const std::vector<std::string> expectedTypes{"RDMA", "RDMA", "SDMA", "SDMA", "SDMA"};
    // Fatal size check keeps the indexed reads below in bounds.
    ASSERT_EQ(packets.size(), expectedTypes.size());
    for (std::size_t idx = 0; idx < expectedTypes.size(); ++idx) {
        EXPECT_EQ(packets[idx].type, expectedTypes[idx]);
    }
}
// Runs QueryBandwidthContentionAnalyzerData for rank "0" and then rank "1",
// checking the number of returned entries and their names in order.
TEST_F(TestSuit, QueryBandwidthContentionAnalyzerDataTest) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);
    std::vector<Dic::Module::BandwidthContentionSDMAInfo> sdmaInfos;

    // Rank "0": two entries.
    std::string rank = "0";
    bool ok = clusterDb->QueryBandwidthContentionAnalyzerData(sdmaInfos, rank);
    ASSERT_TRUE(ok);
    ASSERT_EQ(sdmaInfos.size(), 2U);
    EXPECT_EQ(sdmaInfos[0].name, "Total Op Info");
    EXPECT_EQ(sdmaInfos[1].name, "hcom_broadcast__483_1");

    // Rank "1": three entries; the output list is cleared before reuse.
    rank = "1";
    sdmaInfos.clear();
    ok = clusterDb->QueryBandwidthContentionAnalyzerData(sdmaInfos, rank);
    ASSERT_TRUE(ok);
    ASSERT_EQ(sdmaInfos.size(), 3U);
    EXPECT_EQ(sdmaInfos[0].name, "Total Op Info");
    EXPECT_EQ(sdmaInfos[1].name, "hcom_broadcast__483_0");
    EXPECT_EQ(sdmaInfos[2].name, "hcom_broadcast__483_1");
}
// QueryRetransmissionAnalyzerData must succeed and return two records.
TEST_F(TestSuit, QueryRetransmissionAnalyzerClassificationData) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::vector<Dic::Module::RetransmissionClassificationInfo> classifications;
    const bool ok = clusterDb->QueryRetransmissionAnalyzerData(classifications);

    const std::size_t expectedCount = 2;
    ASSERT_TRUE(ok);
    ASSERT_EQ(classifications.size(), expectedCount);
}
// GetOpStatByStepId for step "2" must return two entries whose op types are
// "hcombroadcast" and "hcomsend", in that order.
TEST_F(TestSuit, GetOpStatByStepIdSuccess) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::string step = "2";
    const std::vector<Dic::Module::OpTypeStatistics> opStats = clusterDb->GetOpStatByStepId(step);

    const std::size_t expectedCount = 2;
    // Fatal size check keeps the indexed reads below in bounds.
    ASSERT_EQ(opStats.size(), expectedCount);
    EXPECT_EQ(opStats[0].opType, "hcombroadcast");
    EXPECT_EQ(opStats[1].opType, "hcomsend");
}
// GetCommTimeForRankDim for iteration "2" is expected to return four records.
TEST_F(TestSuit, GetCommTimeForRankDimTest) {
    auto&& dbManager = Dic::Module::Timeline::DataBaseManager::Instance();
    auto clusterDb = dbManager.GetClusterDatabase(clusterPath);

    std::string iteration = "2";
    auto commInfos = clusterDb->GetCommTimeForRankDim(iteration);

    EXPECT_EQ(commInfos.size(), 4U);
}
// Checks that QuerySlowOpByCommDuration, comparing slow rank "0" against
// fastest rank "1" in iteration "2", succeeds and fills two op details.
TEST_F(TestSuit, QuerySlowOpByCommDurationTest) {
    auto database = Dic::Module::Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    std::string fastestRankId = "1";
    // Fully qualified like every other protocol type in this file, so the
    // test does not depend on a using-directive pulled in from a header.
    Dic::Protocol::DurationListParams params;
    params.iterationId = "2";
    params.groupIdHash = "9350434047717501483";
    // NOTE(review): RankDetailsForSlowRank is still resolved through a
    // using-directive from an included header; its namespace is not visible here.
    RankDetailsForSlowRank slowRank;
    slowRank.rankId = "0";
    auto res = database->QuerySlowOpByCommDuration(params, fastestRankId, slowRank);
    EXPECT_TRUE(res);
    // Unsigned literal avoids a signed/unsigned comparison against size().
    EXPECT_EQ(slowRank.opDetails.size(), 2U);
}