* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <gtest/gtest.h>
#include <GlobalDefs.h>
#include "ProtocolDefs.h"
#include "SummaryProtocol.h"
#include "SummaryProtocolUtil.h"
#include "SummaryProtocolRequest.h"
#include "SummaryProtocolResponse.h"
#ifdef GetObject
#pragma push_macro("GetObject")
#define RAPIDJSON_WINDOWS_GETOBJECT_WORKAROUND_APPLIED
#undef GetObject
#endif
using namespace Dic::Protocol;
using namespace Dic::Module;
class SummaryProtocolUtilTest : public ::testing::Test {
protected:
void SetUp() override { protocol.Register(); }
void TearDown() override { protocol.UnRegister(); }
Dic::Protocol::SummaryProtocol protocol;
};
// A complete "summary/queryTopData" request deserializes into a request whose
// base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToTopNRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/queryTopData", "resultCallbackId": 0,
"params": {"isCompare": true, "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/queryTopData");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// Dropping the top-level "id" makes the topN request unparsable; the protocol
// must return null and report the base-info error.
TEST_F(SummaryProtocolUtilTest, ToTopNRequestLackIdTestReturnNull) {
    const std::string requestText = R"({"moduleName": "summary", "type": "request",
"command": "summary/queryTopData", "resultCallbackId": 0, "params": {"isCompare": true}})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    const auto parsed = protocol.FromJson(document, errorMessage);
    EXPECT_TRUE(parsed == nullptr);
    EXPECT_EQ(errorMessage, "Failed to set request base info of topN request.");
}
// A complete "summary/statistic" request deserializes into a request whose
// base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToStatisticsRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/statistic", "resultCallbackId": 0,
"params": {"rankId": "0", "stepId": "1", "timeFlag": "all", "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/statistic");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// The payload spells the parameter object as "paramx", so "params" is absent;
// the statistics request must be rejected with the base-info error.
TEST_F(SummaryProtocolUtilTest, ToStatisticsRequestLackParamsTestReturnNull) {
    const std::string requestText = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/statistic", "resultCallbackId": 0, "paramx": {"rankId": "0"}})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    const auto parsed = protocol.FromJson(document, errorMessage);
    EXPECT_TRUE(parsed == nullptr);
    EXPECT_EQ(errorMessage, "Failed to set request base info of statistics request.");
}
// A complete "summary/queryComputeDetail" request deserializes into a request
// whose base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToComputeDetailRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/queryComputeDetail", "resultCallbackId": 0,
"params": {"rankId": "0", "dbPath": "/data/db", "currentPage": 1, "timeFlag": "all", "pageSize": 10,
"orderBy": "time", "order": "desc", "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/queryComputeDetail");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// Without "moduleName" the compute-detail request cannot populate its base
// info; the protocol must return null and report why.
TEST_F(SummaryProtocolUtilTest, ToComputeDetailRequestLackModuleNameTestReturnNull) {
    const std::string requestText = R"({"id": 1, "type": "request",
"command": "summary/queryComputeDetail", "resultCallbackId": 0, "params": {"rankId": "0"}})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    const auto parsed = protocol.FromJson(document, errorMessage);
    EXPECT_TRUE(parsed == nullptr);
    EXPECT_EQ(errorMessage, "Failed to set request base info of compute detail request.");
}
// A complete "summary/queryCommunicationDetail" request deserializes into a
// request whose base fields (id, command, type, moduleName) reflect the JSON.
TEST_F(SummaryProtocolUtilTest, ToCommunicationRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/queryCommunicationDetail", "resultCallbackId": 0,
"params": {"rankId": "0", "currentPage": 1, "timeFlag": "all", "pageSize": 10, "orderBy": "time",
"order": "desc", "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/queryCommunicationDetail");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// A communication-detail request without "id" is rejected with the
// base-info error message.
TEST_F(SummaryProtocolUtilTest, ToCommunicationRequestLackIdTestReturnNull) {
    const std::string requestText = R"({"moduleName": "summary", "type": "request",
"command": "summary/queryCommunicationDetail", "resultCallbackId": 0, "params": {"rankId": "0"}})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    const auto parsed = protocol.FromJson(document, errorMessage);
    EXPECT_TRUE(parsed == nullptr);
    EXPECT_EQ(errorMessage, "Failed to set request base info of communication request.");
}
// A complete "summary/importExpertData" request deserializes into a request
// whose base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToImportExpertDataRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/importExpertData", "resultCallbackId": 0,
"params": {"filePath": "/data/expert.json", "version": "1.0", "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/importExpertData");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// An importExpertData request with no "params" object at all is rejected.
TEST_F(SummaryProtocolUtilTest, ToImportExpertDataRequestLackParamsTestReturnNull) {
    const std::string requestText = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/importExpertData", "resultCallbackId": 0})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    EXPECT_TRUE(protocol.FromJson(document, errorMessage) == nullptr);
}
// A complete "summary/queryExpertHotspot" request deserializes into a request
// whose base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToQueryExpertHotspotRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/queryExpertHotspot", "resultCallbackId": 0,
"params": {"modelStage": "prefill", "version": "1.0", "layerNum": 60, "expertNum": 8,
"denseLayerList": [0, 1], "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/queryExpertHotspot");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// A queryExpertHotspot request with no "params" object at all is rejected.
TEST_F(SummaryProtocolUtilTest, ToQueryExpertHotspotRequestLackParamsTestReturnNull) {
    const std::string requestText = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/queryExpertHotspot", "resultCallbackId": 0})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    EXPECT_TRUE(protocol.FromJson(document, errorMessage) == nullptr);
}
// A complete "summary/queryModelInfo" request deserializes into a request
// whose base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToQueryModelInfoRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/queryModelInfo", "resultCallbackId": 0, "params": {"clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/queryModelInfo");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// A queryModelInfo request without "id" is rejected with the base-info
// error message.
TEST_F(SummaryProtocolUtilTest, ToQueryModelInfoRequestLackIdTestReturnNull) {
    const std::string requestText = R"({"moduleName": "summary", "type": "request",
"command": "summary/queryModelInfo", "resultCallbackId": 0, "params": {"clusterPath": "/data"}})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    const auto parsed = protocol.FromJson(document, errorMessage);
    EXPECT_TRUE(parsed == nullptr);
    EXPECT_EQ(errorMessage, "Failed to set request base info of query model info request.");
}
// A complete "summary/slowRank/advisor" request deserializes into a request
// whose base fields (id, command, type, moduleName) reflect the JSON envelope.
TEST_F(SummaryProtocolUtilTest, ToSummarySlowRankAdvisorRequestNormalTest) {
    std::string reqJson = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/slowRank/advisor", "resultCallbackId": 0,
"params": {"algorithm": "test", "tpSize": 2, "ppSize": 3, "dpSize": 4, "epSize": 1,
"dimension": "ep-dp-pp-cp-tp", "clusterPath": "/data"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto result = protocol.FromJson(json, err);
    // Fatal check: every assertion below dereferences result.
    ASSERT_NE(result, nullptr) << err;
    EXPECT_EQ(result->id, 1);
    EXPECT_EQ(result->command, "summary/slowRank/advisor");
    EXPECT_EQ(result->type, ProtocolMessage::Type::REQUEST);
    EXPECT_EQ(result->moduleName, MODULE_SUMMARY);
}
// A slow-rank advisor request whose params omit "algorithm" is rejected, and
// the error names the missing key.
TEST_F(SummaryProtocolUtilTest, ToSummarySlowRankAdvisorRequestLackKeyAlgorithmTestReturnNull) {
    const std::string requestText = R"({"id": 1, "moduleName": "summary", "type": "request",
"command": "summary/slowRank/advisor", "resultCallbackId": 0, "params": {"tpSize": 2}})";
    std::string errorMessage;
    Dic::document_t document;
    document.Parse(requestText.c_str());
    const auto parsed = protocol.FromJson(document, errorMessage);
    EXPECT_TRUE(parsed == nullptr);
    EXPECT_EQ(errorMessage, "Query parallelism arrangement request didn't have key: algorithm");
}
// "summary/query/parallelStrategy" with an empty params object must still
// deserialize into a QueryParallelStrategyRequest carrying that command.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelStrategyRequestTest) {
    std::string reqJson = R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "summary/query/parallelStrategy", "resultCallbackId": 0, "params": {}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied; a wrong dynamic type throws
    // std::bad_cast, which gtest reports as a failure.
    auto &result = dynamic_cast<QueryParallelStrategyRequest &>(*request);
    EXPECT_EQ(result.command, REQ_RES_SUMMARY_QUERY_PARALLEL_STRATEGY);
}
// "summary/set/parallelStrategy": an empty params object is rejected. With
// algorithm/tp/pp/dp supplied, the request parses and cpSize/epSize are
// expected to come back as 1.
TEST_F(SummaryProtocolUtilTest, ToSetParallelStrategyRequestTest) {
    std::string reqJson = R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "summary/set/parallelStrategy", "resultCallbackId": 0, "params": {}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto request = protocol.FromJson(json, err);
    EXPECT_TRUE(request == nullptr);
    reqJson = R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "summary/set/parallelStrategy", "resultCallbackId": 0, "params":
{"algorithm": "test", "tpSize": 2, "ppSize": 3, "dpSize": 4}})";
    json.Parse(reqJson.c_str());
    request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied.
    auto &result = dynamic_cast<SetParallelStrategyRequest &>(*request);
    EXPECT_EQ(result.command, REQ_RES_SUMMARY_SET_PARALLEL_STRATEGY);
    EXPECT_EQ(result.params.config.algorithm, "test");
    EXPECT_EQ(result.params.config.tpSize, 2);
    EXPECT_EQ(result.params.config.ppSize, 3);
    EXPECT_EQ(result.params.config.dpSize, 4);
    EXPECT_EQ(result.params.config.cpSize, 1);
    EXPECT_EQ(result.params.config.epSize, 1);
}
// When cpSize and epSize are given explicitly, the set-parallel-strategy
// request carries them through unchanged alongside tp/pp/dp.
TEST_F(SummaryProtocolUtilTest, ToSetParallelStrategyRequestWithCpAndTpTest) {
    Dic::document_t json;
    std::string err;
    std::string reqJson = R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "summary/set/parallelStrategy", "resultCallbackId": 0, "params":
{"algorithm": "test", "tpSize": 2, "ppSize": 3, "dpSize": 4, "cpSize": 5, "epSize": 6}})";
    json.Parse(reqJson.c_str());
    auto request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied.
    auto &result = dynamic_cast<SetParallelStrategyRequest &>(*request);
    const int64_t expectCp = 5;
    const int64_t expectEp = 6;
    EXPECT_EQ(result.command, REQ_RES_SUMMARY_SET_PARALLEL_STRATEGY);
    EXPECT_EQ(result.params.config.algorithm, "test");
    EXPECT_EQ(result.params.config.tpSize, 2);
    EXPECT_EQ(result.params.config.ppSize, 3);
    EXPECT_EQ(result.params.config.dpSize, 4);
    EXPECT_EQ(result.params.config.cpSize, expectCp);
    EXPECT_EQ(result.params.config.epSize, expectEp);
}
// A fully specified "parallelism/arrangement/all" request deserializes every
// config size plus the dimension string.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismArrangementRequestWillReturnTrueWhenInputCorrect) {
    Dic::document_t json;
    std::string err;
    std::string reqJson = R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/arrangement/all", "resultCallbackId": 0, "params": {"algorithm": "test",
"tpSize": 2, "ppSize": 3, "dpSize": 4, "cpSize": 5, "epSize": 6,
"moeTpSize": 7, "dimension": "ep-dp-pp-cp-tp"}})";
    json.Parse(reqJson.c_str());
    auto request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied.
    auto &result = dynamic_cast<QueryParallelismArrangementRequest &>(*request);
    EXPECT_EQ(result.command, REQ_RES_PARALLELISM_ARRANGEMENT_ALL);
    EXPECT_EQ(result.params.config.algorithm, "test");
    EXPECT_EQ(result.params.config.tpSize, 2);
    EXPECT_EQ(result.params.config.ppSize, 3);
    EXPECT_EQ(result.params.config.dpSize, 4);
    EXPECT_EQ(result.params.config.cpSize, 5);
    EXPECT_EQ(result.params.config.epSize, 6);
    EXPECT_EQ(result.params.config.moeTpSize, 7);
    EXPECT_EQ(result.params.dimension, "ep-dp-pp-cp-tp");
}
// Each payload omits one field (algorithm, tpSize, ppSize respectively); every
// variant must fail to deserialize.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismArrangementRequestWillReturnNullWhenInputWrong) {
    const std::string invalidRequests[] = {
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/arrangement/all", "resultCallbackId": 0, "params": {
"tpSize": 2, "ppSize": 3, "dpSize": 4, "cpSize": 5, "epSize": 6, "dimension": "ep-dp-pp-cp-tp"}})",
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/arrangement/all", "resultCallbackId": 0, "params": {"algorithm": "test",
"ppSize": 3, "dpSize": 4, "cpSize": 5, "epSize": 6, "dimension": "ep-dp-pp-cp-tp"}})",
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/arrangement/all", "resultCallbackId": 0, "params": {"algorithm": "test",
"tpSize": 2, "dpSize": 4, "cpSize": 5, "epSize": 6, "dimension": "ep-dp-pp-cp-tp"}})",
    };
    Dic::document_t document;
    std::string errorMessage;
    for (const auto &requestText : invalidRequests) {
        document.Parse(requestText.c_str());
        EXPECT_TRUE(protocol.FromJson(document, errorMessage) == nullptr);
    }
}
// A fully specified "parallelism/performance/data" request (no resultCallbackId)
// deserializes every config size plus dimension and step.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismPerformanceRequestWillReturnTrueWhenInputCorrect) {
    Dic::document_t json;
    std::string err;
    std::string reqJson = R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/performance/data", "params": {"algorithm": "test", "tpSize": 2, "ppSize": 3,
"dpSize": 4, "cpSize": 5, "epSize": 6, "moeTpSize": 7,
"dimension": "ep-dp-pp-cp-tp", "step": "all"}})";
    json.Parse(reqJson.c_str());
    auto request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied.
    auto &result = dynamic_cast<QueryParallelismPerformanceRequest &>(*request);
    EXPECT_EQ(result.command, REQ_RES_PARALLELISM_PERFORMANCE_DATA);
    EXPECT_EQ(result.params.config.algorithm, "test");
    EXPECT_EQ(result.params.config.tpSize, 2);
    EXPECT_EQ(result.params.config.ppSize, 3);
    EXPECT_EQ(result.params.config.dpSize, 4);
    EXPECT_EQ(result.params.config.cpSize, 5);
    EXPECT_EQ(result.params.config.epSize, 6);
    EXPECT_EQ(result.params.config.moeTpSize, 7);
    EXPECT_EQ(result.params.dimension, "ep-dp-pp-cp-tp");
    EXPECT_EQ(result.params.step, "all");
}
// A frontend-shaped performance request (extra projectName, null baselineStep,
// isCompare) still deserializes all config fields, step and clusterPath.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismPerformanceRequestTest) {
    Dic::document_t json;
    std::string err;
    std::string reqJson = "{\"id\":46,\"moduleName\":\"summary\",\"type\":\"request\",\"command\":"
                          "\"parallelism/performance/data\",\"projectName\":\"test\""
                          ",\"params\":{\"step\":\"All\",\"baselineStep\":null,"
                          "\"algorithm\":\"megatron-lm(tp-cp-ep-dp-pp)\","
                          "\"dimension\":\"ep-dp\",\"ppSize\":2,\"tpSize\":2,"
                          "\"cpSize\":2,\"dpSize\":2,\"epSize\":2,\"moeTpSize\":1,"
                          "\"clusterPath\":\"test\",\"isCompare\":false}}";
    json.Parse(reqJson.c_str());
    auto request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied.
    auto &result = dynamic_cast<QueryParallelismPerformanceRequest &>(*request);
    EXPECT_EQ(result.command, REQ_RES_PARALLELISM_PERFORMANCE_DATA);
    EXPECT_EQ(result.params.config.algorithm, "megatron-lm(tp-cp-ep-dp-pp)");
    EXPECT_EQ(result.params.config.tpSize, 2);
    EXPECT_EQ(result.params.config.ppSize, 2);
    EXPECT_EQ(result.params.config.dpSize, 2);
    EXPECT_EQ(result.params.config.cpSize, 2);
    EXPECT_EQ(result.params.config.epSize, 2);
    EXPECT_EQ(result.params.config.moeTpSize, 1);
    EXPECT_EQ(result.params.dimension, "ep-dp");
    EXPECT_EQ(result.params.step, "All");
    EXPECT_EQ(result.params.clusterPath, "test");
}
// Each payload omits one field (algorithm, tpSize, ppSize respectively); every
// variant must fail to deserialize.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismPerformanceRequestWillReturnNullWhenInputWithWrong) {
    const std::string invalidRequests[] = {
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/performance/data", "params": {"tpSize": 2, "ppSize": 3,
"dpSize": 4, "cpSize": 5, "epSize": 6, "dimension": "ep-dp-pp-cp-tp", "step": "all"}})",
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/performance/data", "params": {"algorithm": "test", "ppSize": 3, "cpSize": 4,
"dpSize": 5, "epSize": 6, "dimension": "ep-dp-pp-cp-tp",
"step": "all"}})",
        R"({"id": 2, "moduleName": "summary", "type": "request", "command": "parallelism/performance/data",
"params": {"algorithm": "test", "tpSize": 3, "cpSize": 4, "dpSize": 5, "epSize": 6,
"dimension": "ep-dp-pp-cp-tp", "step": "all"}})",
    };
    Dic::document_t document;
    std::string errorMessage;
    for (const auto &requestText : invalidRequests) {
        document.Parse(requestText.c_str());
        EXPECT_TRUE(protocol.FromJson(document, errorMessage) == nullptr);
    }
}
// Serializing a QueryParallelStrategyResponse writes the algorithm and each
// parallel size into "body" under the corresponding KEY_* names.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelStrategyResponseTest) {
    const int64_t expectCp = 7;
    const int64_t expectEp = 9;
    const int64_t expectMoeTp = 9;
    Dic::Protocol::QueryParallelStrategyResponse response;
    auto &config = response.config;
    config.algorithm = "megatron-lm";
    config.tpSize = 8;
    config.ppSize = 4;
    config.dpSize = 2;
    config.cpSize = expectCp;
    config.epSize = expectEp;
    config.moeTpSize = expectMoeTp;
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    const auto &body = jsonOptional.value()["body"];
    EXPECT_EQ(body[KEY_ALGORITHM.c_str()], config.algorithm.c_str());
    EXPECT_EQ(body[KEY_TP_SIZE.c_str()], config.tpSize);
    EXPECT_EQ(body[KEY_PP_SIZE.c_str()], config.ppSize);
    EXPECT_EQ(body[KEY_DP_SIZE.c_str()], config.dpSize);
    EXPECT_EQ(body[KEY_CP_SIZE.c_str()], config.cpSize);
    EXPECT_EQ(body[KEY_EP_SIZE.c_str()], config.epSize);
    EXPECT_EQ(body[KEY_MOE_TP_SIZE.c_str()], config.moeTpSize);
}
// Serializing a SetParallelStrategyResponse writes its result flag and message
// into "body".
TEST_F(SummaryProtocolUtilTest, ToSetParallelStrategyResponseTest) {
    Dic::Protocol::SetParallelStrategyResponse response;
    response.msg = "test";
    response.result = false;
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    const auto &body = jsonOptional.value()["body"];
    EXPECT_EQ(body[KEY_RESULT.c_str()], response.result);
    EXPECT_EQ(body[KEY_MSG.c_str()], response.msg.c_str());
}
// A fwd/bwd timeline request is rejected both with empty params and with
// params that only contain a "step" key.
TEST_F(SummaryProtocolUtilTest, ToQueryFwdBwdTimelineRequestWillReturnNullWhenInputWrong) {
    const std::string invalidRequests[] = {
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/pipeline/fwdBwdTimeline", "resultCallbackId": 0, "params": {}})",
        R"({"id": 2, "moduleName": "summary", "type": "request",
"command": "parallelism/pipeline/fwdBwdTimeline", "resultCallbackId": 0, "params": {"step": "2"}})",
    };
    Dic::document_t document;
    std::string errorMessage;
    for (const auto &requestText : invalidRequests) {
        document.Parse(requestText.c_str());
        EXPECT_TRUE(protocol.FromJson(document, errorMessage) == nullptr);
    }
}
// A fwd/bwd timeline request with stepId and stageId deserializes into a
// PipelineFwdBwdTimelineRequest carrying both ids.
TEST_F(SummaryProtocolUtilTest, ToQueryFwdBwdTimelineRequestWillReturnTrueWhenInputCorrect) {
    std::string reqJson = R"({"id": 2, "moduleName": "summary", "type": "request", "resultCallbackId": 0,
"command": "parallelism/pipeline/fwdBwdTimeline", "params": {"stepId": "2", "stageId": "3"}})";
    Dic::document_t json;
    json.Parse(reqJson.c_str());
    std::string err;
    auto request = protocol.FromJson(json, err);
    // Guard before dereferencing: *nullptr is undefined behavior, not a test failure.
    ASSERT_NE(request, nullptr) << err;
    // Bind by reference so the request is not copied.
    auto &result = dynamic_cast<PipelineFwdBwdTimelineRequest &>(*request);
    EXPECT_EQ(result.command, REQ_RES_PIPELINE_FWD_BWD_TIMELINE);
    EXPECT_EQ(result.params.stepId, "2");
    EXPECT_EQ(result.params.stageId, "3");
}
// Serializing a default-constructed fwd/bwd timeline response must still emit
// a body containing both time bounds and an empty rankList.
TEST_F(SummaryProtocolUtilTest, ToQueryFwdBwdTimelineResponseTestWillReturnWhenEmptyInput) {
    Dic::Protocol::PipelineFwdBwdTimelineResponse response{};
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    // Fatal checks: the lookups below would hit a missing optional / member otherwise.
    ASSERT_TRUE(jsonOptional.has_value()) << err;
    ASSERT_TRUE(jsonOptional.value().HasMember("body"));
    const auto &body = jsonOptional.value()["body"];
    EXPECT_TRUE(body.HasMember("minTime"));
    EXPECT_TRUE(body.HasMember("maxTime"));
    ASSERT_TRUE(body.HasMember("rankList"));
    EXPECT_EQ(body["rankList"].GetArray().Size(), 0U);
}
// Serializes a two-rank fwd/bwd timeline and checks that each rank and each of
// its components appear in order with the expected trace counts.
TEST_F(SummaryProtocolUtilTest, ToQueryFwdBwdTimelineResponseTestWillReturnWhenNormalInput) {
    // Rank 0: a FWD/BWD component with two traces and a P2P component with one.
    PipelineFwdBwdTimelineByComponent rank0FwdBwd = {"FWD/BWD",
        {{"FP", 3000, 123456, 135678, 0, "FWD/BWD", "0", "1", "FWD"},
         {"BP", 5000, 147890, 234567, 0, "FWD/BWD", "0", "1", "BWD"}}};
    PipelineFwdBwdTimelineByComponent rank0P2P = {
        "P2P", {{"hcom_send", 3000, 136789, 137890, 0, "P2P", "0", "1", "SEND"}}};
    PipelineFwdBwdTimelineByRank rank0 = {"0", {"FWD/BWD", "P2P"}, {rank0FwdBwd, rank0P2P}};
    // Rank 1: a single FWD/BWD component with one trace.
    PipelineFwdBwdTimelineByComponent rank1FwdBwd = {
        "FWD/BWD", {{"FP", 3000, 123456, 135678, 0, "FWD/BWD", "0", "1", "FWD"}}};
    PipelineFwdBwdTimelineByRank rank1 = {"1", {"FWD/BWD"}, {rank1FwdBwd}};

    Dic::Protocol::PipelineFwdBwdTimelineResponse response{};
    response.body.maxTime = 10086;
    response.body.minTime = 10010;
    response.body.rankLists = {"0", "1"};
    response.body.rankDataList = {rank0, rank1};

    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    const auto &body = jsonOptional.value()["body"];
    EXPECT_TRUE(body.HasMember("minTime"));
    EXPECT_EQ(body["minTime"], response.body.minTime);
    EXPECT_TRUE(body.HasMember("rankList"));
    const auto &rankArray = body["rankList"];
    EXPECT_EQ(rankArray.GetArray().Size(), response.body.rankDataList.size());

    int rankIndex = 0;
    for (const auto &rankJson : rankArray.GetArray()) {
        const auto &expectedRank = response.body.rankDataList.at(rankIndex);
        EXPECT_TRUE(rankJson.HasMember("rank"));
        EXPECT_EQ(rankJson["rank"].GetString(), expectedRank.rankId);
        EXPECT_TRUE(rankJson.HasMember("componentList"));
        EXPECT_EQ(rankJson["componentList"].GetArray().Size(), expectedRank.componentDataList.size());
        int componentIndex = 0;
        for (const auto &componentJson : rankJson["componentList"].GetArray()) {
            const auto &expectedComponent = expectedRank.componentDataList.at(componentIndex);
            EXPECT_TRUE(componentJson.HasMember("component"));
            EXPECT_EQ(componentJson["component"].GetString(), expectedComponent.component);
            EXPECT_TRUE(componentJson.HasMember("traceList"));
            EXPECT_EQ(componentJson["traceList"].GetArray().Size(), expectedComponent.traceList.size());
            ++componentIndex;
        }
        ++rankIndex;
    }
}
// Serializes an arrangement response with one element and one indicator, then
// checks the arrangements, indicators and connections sections of "body".
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismArrangementResponseTestWillReturnWhenNormalInput) {
    IndicatorAttr attr = {.number = 0,
        .key = "computingTime",
        .name = "computing time",
        .renderHeatMap = true,
        .renderChart = false,
        .visible = true,
        .chart = "bar",
        .stack = "time",
        .yAxisType = "time"};
    const Position origin = {0, 0};
    Element element;
    element.index = 0;
    element.name = "ep0-dp0-cp0-pp0-tp0";
    element.position = origin;
    element.indexAttributes["tpIndex"] = 0;

    Dic::Protocol::ParallelismArrangementResponse response{};
    auto &arrangeData = response.arrangeData;
    arrangeData.indicators.push_back(attr);
    arrangeData.arrangements.push_back(element);
    arrangeData.size = arrangeData.arrangements.size();

    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    const auto &body = jsonOptional.value()["body"];

    EXPECT_TRUE(body.HasMember("arrangements"));
    ASSERT_EQ(body["arrangements"].GetArray().Size(), arrangeData.size);
    int position = 0;
    for (const auto &arrangementJson : body["arrangements"].GetArray()) {
        const auto &expected = arrangeData.arrangements.at(position);
        EXPECT_EQ(arrangementJson["index"].GetUint(), expected.index);
        EXPECT_EQ(arrangementJson["name"].GetString(), expected.name);
        EXPECT_TRUE(arrangementJson.HasMember("position"));
        EXPECT_TRUE(arrangementJson.HasMember("attribute"));
        ++position;
    }

    EXPECT_TRUE(body.HasMember("indicators"));
    ASSERT_EQ(body["indicators"].GetArray().Size(), arrangeData.indicators.size());
    position = 0;
    for (const auto &indicatorJson : body["indicators"].GetArray()) {
        const auto &expected = arrangeData.indicators.at(position);
        EXPECT_EQ(indicatorJson["key"].GetString(), expected.key);
        EXPECT_EQ(indicatorJson["name"].GetString(), expected.name);
        ++position;
    }

    EXPECT_TRUE(body.HasMember("connections"));
}
// Serializes a performance response with one indicator entry and checks that
// "performance" mirrors it and that an "advice" member is present.
TEST_F(SummaryProtocolUtilTest, ToQueryParallelismPerformanceResponseTestWillReturnWhenNormalInput) {
    IndicatorDataStructVo indicator;
    indicator.index = 0;
    indicator.indicators.compare["computingTime"] = 100;
    Dic::Protocol::ParallelismPerformanceResponse response{};
    auto &expectedItems = response.indicatorData.performanceData;
    expectedItems.push_back(indicator);

    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    const auto &body = jsonOptional.value()["body"];
    EXPECT_TRUE(body.HasMember("performance"));
    const auto &performanceJson = body["performance"];
    ASSERT_EQ(performanceJson.GetArray().Size(), expectedItems.size());
    int position = 0;
    for (const auto &itemJson : performanceJson.GetArray()) {
        // Copy: compare[] below needs a mutable map.
        auto expected = expectedItems.at(position);
        EXPECT_EQ(itemJson["index"].GetUint(), expected.index);
        EXPECT_EQ(itemJson["indicators"]["compare"]["computingTime"].GetDouble(),
            expected.indicators.compare["computingTime"]);
        ++position;
    }
    EXPECT_TRUE(body.HasMember("advice"));
}
// Serializes a slow-rank advisor response and checks each topNElements entry
// (index, name, tp/cp/dp synchronize times), plus the presence of the
// hasSlowRank and matchSuccess keys.
TEST_F(SummaryProtocolUtilTest, ToSummarySlowRankAdvisorResponseTestWillReturnWhenNormalInput) {
    Dic::Protocol::SummarySlowRankAdvisorResponse response{};
    Module::AdviceInfoForSlowRank adviceInfo;
    adviceInfo.name = "dp4-pp0-cp1-tp2";
    adviceInfo.index = 9;
    adviceInfo.synchronizeTime["tp"] = 1.23;
    adviceInfo.synchronizeTime["cp"] = 4.56;
    adviceInfo.synchronizeTime["dp"] = 7.89;
    response.body.topNElements.push_back(adviceInfo);
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    // Fatal checks: the lookups and at(i) below rely on these holding.
    ASSERT_TRUE(jsonOptional.has_value()) << err;
    ASSERT_TRUE(jsonOptional.value().HasMember("body"));
    ASSERT_TRUE(jsonOptional.value()["body"].HasMember("topNElements"));
    ASSERT_EQ(jsonOptional.value()["body"]["topNElements"].GetArray().Size(), response.body.topNElements.size());
    int i = 0;
    for (const auto &item : jsonOptional.value()["body"]["topNElements"].GetArray()) {
        // Copy: synchronizeTime[] below needs a mutable map.
        auto tmp = response.body.topNElements.at(i);
        EXPECT_EQ(item["index"].GetUint(), tmp.index);
        EXPECT_EQ(item["name"].GetString(), tmp.name);
        EXPECT_EQ(item.HasMember("tpSynchronizeTime"), true);
        EXPECT_EQ(item["tpSynchronizeTime"], tmp.synchronizeTime["tp"]);
        EXPECT_EQ(item.HasMember("cpSynchronizeTime"), true);
        EXPECT_EQ(item["cpSynchronizeTime"], tmp.synchronizeTime["cp"]);
        EXPECT_EQ(item.HasMember("dpSynchronizeTime"), true);
        EXPECT_EQ(item["dpSynchronizeTime"], tmp.synchronizeTime["dp"]);
        // Advance so the next JSON item is compared with the matching expected element.
        i++;
    }
    EXPECT_EQ(jsonOptional.value()["body"].HasMember("hasSlowRank"), true);
    EXPECT_EQ(jsonOptional.value()["body"].HasMember("matchSuccess"), true);
}
// Serializes a top-rank response and verifies that both baseInfo.baseline and
// baseInfo.compare objects reproduce every SummaryBaseInfo field.
TEST_F(SummaryProtocolUtilTest, SummaryTopRankResponseSuccess) {
    const SummaryBaseInfo baseline{5, {"1", "2"}, 100, 100, "filePath", 10, 5, {"1"}};
    const SummaryBaseInfo compare{5, {"1", "2"}, 100, 100, "filePath", 10, 5, {"1"}};
    SummaryTopRankResponse response;
    response.body.baseInfo.compare = compare;
    response.body.baseInfo.baseline = baseline;

    // Field-by-field comparison of one serialized base-info object with its source.
    const auto expectBaseInfoMatches = [](const auto &object, const SummaryBaseInfo &info) {
        EXPECT_EQ(object["rankCount"].GetUint(), info.rankCount);
        EXPECT_EQ(object["rankList"].GetArray().Size(), info.rankList.size());
        EXPECT_EQ(object["dataSize"].GetDouble(), info.dataSize);
        EXPECT_EQ(object["collectStartTime"].GetInt(), info.collectStartTime);
        EXPECT_EQ(object["filePath"].GetString(), info.filePath);
        EXPECT_EQ(object["collectDuration"].GetDouble(), info.collectDuration);
        EXPECT_EQ(object["stepNum"].GetUint(), info.stepNum);
        EXPECT_EQ(object["stepList"].GetArray().Size(), info.stepList.size());
    };

    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    EXPECT_TRUE(jsonOptional.value()["body"].HasMember("baseInfo"));
    auto &baseInfoJson = jsonOptional.value()["body"]["baseInfo"];
    EXPECT_TRUE(baseInfoJson.HasMember("baseline"));
    expectBaseInfoMatches(baseInfoJson["baseline"].GetObject(), baseline);
    EXPECT_TRUE(baseInfoJson.HasMember("compare"));
    expectBaseInfoMatches(baseInfoJson["compare"].GetObject(), compare);
}
// Serializes a model-info response and checks layerNum, expertNum and the
// length of denseLayerList in "body".
TEST_F(SummaryProtocolUtilTest, QueryModelInfoResponseSuccess) {
    QueryModelInfoResponse response;
    response.body = {60, {0, 1}, 200};
    const auto &modelInfo = response.body;
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    const auto bodyJson = jsonOptional.value()["body"].GetObject();
    EXPECT_EQ(bodyJson["layerNum"].GetInt(), modelInfo.layerNum);
    EXPECT_EQ(bodyJson["expertNum"].GetInt(), modelInfo.expertNum);
    EXPECT_EQ(bodyJson["denseLayerList"].GetArray().Size(), modelInfo.denseLayerList.size());
}
// Serializes an expert-hotspot response with one entry and checks that
// "hotspotInfos" has the same number of elements.
TEST_F(SummaryProtocolUtilTest, ToQueryExpertHotspotResponseJsonWhenNormalInput) {
    const ExpertHotspotStruct hotspot{"prefill", 0, 100, 0, 0, "1", 0};
    QueryExpertHotspotResponse response{};
    response.body.hotspotInfos.push_back(hotspot);
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    const auto &body = jsonOptional.value()["body"];
    EXPECT_TRUE(body.HasMember("hotspotInfos"));
    ASSERT_EQ(body["hotspotInfos"].GetArray().Size(), response.body.hotspotInfos.size());
}
// Serializes a successful import-expert-data response and checks that
// body.result is true.
TEST_F(SummaryProtocolUtilTest, ToImportExpertDataResponseJsonWhenNormalInput) {
    ImportExpertDataResponse response{};
    response.msg = "";
    response.result = true;
    std::string err;
    std::optional<Dic::document_t> jsonOptional = protocol.ToJson(response, err);
    EXPECT_TRUE(jsonOptional.has_value());
    EXPECT_TRUE(jsonOptional.value().HasMember("body"));
    const auto &body = jsonOptional.value()["body"];
    EXPECT_TRUE(body.HasMember("result"));
    EXPECT_TRUE(body["result"].GetBool());
}
// With ppSize = 0 the config is expected to be invalid. SetDefault() should
// keep the (tp-cp-ep-dp-pp) algorithm and tpSize, and bring ppSize to 1.
TEST_F(SummaryProtocolUtilTest, ToTestQueryParallelStrategyResponse) {
    using namespace Dic::Module;
    const int expectPp = 1;
    const int expectTp = 2;
    Dic::Protocol::QueryParallelStrategyResponse response{};
    response.config = {
        .algorithm = MEGATRON_LM_TP_CP_EP_DP_PP_ALG, .ppSize = 0, .tpSize = 2, .dpSize = 1, .cpSize = 1, .epSize = 1};
    EXPECT_FALSE(response.IsValid());
    response.SetDefault();
    EXPECT_EQ(response.config.algorithm, MEGATRON_LM_TP_CP_EP_DP_PP_ALG);
    EXPECT_EQ(response.config.ppSize, expectPp);
    EXPECT_EQ(response.config.tpSize, expectTp);
}
// With ppSize = 0 the config is expected to be invalid. SetDefault() should
// keep the (tp-cp-pp-ep-dp) algorithm and tpSize, and bring ppSize to 1.
TEST_F(SummaryProtocolUtilTest, ToTestQueryParallelStrategyResponse2) {
    using namespace Dic::Module;
    const int expectPp = 1;
    const int expectTp = 2;
    Dic::Protocol::QueryParallelStrategyResponse response{};
    response.config = {
        .algorithm = MEGATRON_LM_TP_CP_PP_EP_DP_ALG, .ppSize = 0, .tpSize = 2, .dpSize = 1, .cpSize = 1, .epSize = 1};
    EXPECT_FALSE(response.IsValid());
    response.SetDefault();
    EXPECT_EQ(response.config.algorithm, MEGATRON_LM_TP_CP_PP_EP_DP_ALG);
    EXPECT_EQ(response.config.ppSize, expectPp);
    EXPECT_EQ(response.config.tpSize, expectTp);
}
// Same scenario for the tp-cp-pp-ep-dp ordering: an invalid ppSize of 0 is
// expected to be repaired to 1 by SetDefault(), while tpSize is left at 2.
TEST_F(SummaryProtocolUtilTest, ToTestQueryParallelStrategyResponseWhenTpPpDp) {
    using namespace Dic::Module;
    const int expectPp = 1;
    const int expectTp = 2;
    Dic::Protocol::QueryParallelStrategyResponse response{};
    response.config = {
        .algorithm = MEGATRON_LM_TP_CP_PP_EP_DP_ALG, .ppSize = 0, .tpSize = 2, .dpSize = 1, .cpSize = 1, .epSize = 1};
    EXPECT_FALSE(response.IsValid());
    response.SetDefault();
    EXPECT_EQ(response.config.algorithm, MEGATRON_LM_TP_CP_PP_EP_DP_ALG);
    EXPECT_EQ(response.config.ppSize, expectPp);
    EXPECT_EQ(response.config.tpSize, expectTp);
}
// Same scenario for the tp-cp-ep-dp-pp ordering: an invalid ppSize of 0 is
// expected to be repaired to 1 by SetDefault(), while tpSize is left at 2.
TEST_F(SummaryProtocolUtilTest, ToTestQueryParallelStrategyResponseWhenTpDpPp) {
    using namespace Dic::Module;
    const int expectPp = 1;
    const int expectTp = 2;
    Dic::Protocol::QueryParallelStrategyResponse response{};
    response.config = {
        .algorithm = MEGATRON_LM_TP_CP_EP_DP_PP_ALG, .ppSize = 0, .tpSize = 2, .dpSize = 1, .cpSize = 1, .epSize = 1};
    EXPECT_FALSE(response.IsValid());
    response.SetDefault();
    EXPECT_EQ(response.config.algorithm, MEGATRON_LM_TP_CP_EP_DP_PP_ALG);
    EXPECT_EQ(response.config.ppSize, expectPp);
    EXPECT_EQ(response.config.tpSize, expectTp);
}
// An unknown algorithm name plus ppSize = 0 is invalid. SetDefault() is
// expected to fall back to the tp-cp-ep-dp-pp algorithm and ppSize 1, leaving
// tpSize at 2.
TEST_F(SummaryProtocolUtilTest, ToTestQueryParallelStrategyResponseWhenInvalid) {
    using namespace Dic::Module;
    const int expectPp = 1;
    const int expectTp = 2;
    Dic::Protocol::QueryParallelStrategyResponse response{};
    response.config = {.algorithm = "LLLLLLLLLLLL", .ppSize = 0, .tpSize = 2, .dpSize = 1, .cpSize = 1, .epSize = 1};
    EXPECT_FALSE(response.IsValid());
    response.SetDefault();
    EXPECT_EQ(response.config.algorithm, MEGATRON_LM_TP_CP_EP_DP_PP_ALG);
    EXPECT_EQ(response.config.ppSize, expectPp);
    EXPECT_EQ(response.config.tpSize, expectTp);
}