* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "pch.h"
#include "CommunicationProtocolUtil.h"
#include "CommunicationProtocolRequest.h"
#include "CommunicationProtocolResponse.h"
#include "CommunicationProtocol.h"
namespace Dic {
namespace Protocol {
// Maps each communication request command to the function that builds its
// Request object from incoming JSON. Registration order is kept as-is.
void CommunicationProtocol::RegisterJsonToRequestFuncs() {
    auto &factory = jsonToReqFactory;
    factory.emplace(REQ_RES_COMMUNICATION_OPERATOR_DETAILS, ToOperatorDetailsRequest);
    factory.emplace(REQ_RES_COMMUNICATION_DISTRIBUTION, ToDistributionRequest);
    factory.emplace(REQ_RES_COMMUNICATION_BANDWIDTH, ToBandwidthDataRequest);
    factory.emplace(REQ_RES_COMMUNICATION_ITERATIONS, ToIterationsRequest);
    factory.emplace(REQ_RES_COMMUNICATION_OPERATORNAMES, ToOperatorNamesRequest);
    factory.emplace(REQ_RES_COMMUNICATION_SORT_OP, ToMatrixOpNamesRequest);
    // The duration-list parser is shared by the LIST, OPERATOR_LISTS and
    // DURATION_SLOW_RANK_LIST commands.
    factory.emplace(REQ_RES_COMMUNICATION_LIST, ToDurationRequest);
    factory.emplace(REQ_RES_COMMUNICATION_MATRIX_GROUP, ToMatrixGroupRequest);
    factory.emplace(REQ_RES_COMMUNICATION_MATRIX_BANDWIDTH, ToMatrixListRequest);
    factory.emplace(REQ_RES_COMMUNICATION_OPERATOR_LISTS, ToDurationRequest);
    factory.emplace(REQ_RES_COMMUNICATION_ADVISOR, ToCommunicationAdvisorRequest);
    factory.emplace(REQ_RES_COMMUNICATION_DURATION_SLOW_RANK_LIST, ToDurationRequest);
}
// Maps each communication response command to the function that serializes
// its Response object to JSON. Registration order is kept as-is.
void CommunicationProtocol::RegisterResponseToJsonFuncs() {
    auto &registry = resToJsonFactory;
    registry.emplace(REQ_RES_COMMUNICATION_OPERATOR_DETAILS, ToOperatorDetailsResponse);
    registry.emplace(REQ_RES_COMMUNICATION_DISTRIBUTION, ToDistributionResponse);
    registry.emplace(REQ_RES_COMMUNICATION_BANDWIDTH, ToBandwidthDataResponse);
    registry.emplace(REQ_RES_COMMUNICATION_ITERATIONS, ToIterationsResponse);
    registry.emplace(REQ_RES_COMMUNICATION_OPERATORNAMES, ToOperatorNamesResponse);
    registry.emplace(REQ_RES_COMMUNICATION_SORT_OP, ToMatrixOpNamesResponse);
    registry.emplace(REQ_RES_COMMUNICATION_LIST, ToDurationResponse);
    registry.emplace(REQ_RES_COMMUNICATION_OPERATOR_LISTS, ToOperatorListResponse);
    registry.emplace(REQ_RES_COMMUNICATION_MATRIX_GROUP, ToMatrixGroupResponse);
    registry.emplace(REQ_RES_COMMUNICATION_MATRIX_BANDWIDTH, ToMatrixListResponse);
    registry.emplace(REQ_RES_COMMUNICATION_ADVISOR, ToCommunicationAdvisorResponse);
    registry.emplace(REQ_RES_COMMUNICATION_DURATION_SLOW_RANK_LIST, ToDurationSlowRankResponse);
}
// Event serialization hook for this protocol.
void CommunicationProtocol::RegisterEventToJsonFuncs() {
    // Intentionally empty: no event-to-JSON converters are registered here.
}
#pragma region <<Json To Request>>
std::unique_ptr<Request> CommunicationProtocol::ToOperatorDetailsRequest(const json_t &json, std::string &error) {
std::unique_ptr<OperatorDetailsRequest> reqPtr = std::make_unique<OperatorDetailsRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of operator details.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.rankId, json["params"], "rankId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.orderBy, json["params"], "orderBy");
JsonUtil::SetByJsonKeyValue(reqPtr->params.order, json["params"], "order");
JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, json["params"], "stage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pageSize, json["params"], "pageSize");
JsonUtil::SetByJsonKeyValue(reqPtr->params.currentPage, json["params"], "currentPage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.queryType, json["params"], "queryType");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, json["params"], "pgName");
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, json["params"], "groupIdHash");
return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToDistributionRequest(const json_t &json, std::string &error) {
std::unique_ptr<DistributionDataRequest> reqPtr = std::make_unique<DistributionDataRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of distribution request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.rankId, json["params"], "rankId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.operatorName, json["params"], "operatorName");
JsonUtil::SetByJsonKeyValue(reqPtr->params.transportType, json["params"], "transportType");
JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, json["params"], "stage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, json["params"], "pgName");
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, json["params"], "groupIdHash");
return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToBandwidthDataRequest(const json_t &json, std::string &error) {
std::unique_ptr<BandwidthDataRequest> reqPtr = std::make_unique<BandwidthDataRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of bandwidth data request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.rankId, json["params"], "rankId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.operatorName, json["params"], "operatorName");
JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, json["params"], "stage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, json["params"], "pgName");
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, json["params"], "groupIdHash");
return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToMatrixGroupRequest(const json_t &json, std::string &error) {
std::unique_ptr<MatrixGroupRequest> reqPtr = std::make_unique<MatrixGroupRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of matrix group request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.isCompare, json["params"], "isCompare");
JsonUtil::SetByJsonKeyValue(reqPtr->params.baselineIterationId, json["params"], "baselineIterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
return reqPtr;
}
// Builds a MatrixBandwidthRequest from `json`.
// Returns nullptr and sets `error` when the common request header cannot be
// read; otherwise copies the fields of json["params"] via JsonUtil.
std::unique_ptr<Request> CommunicationProtocol::ToMatrixListRequest(const json_t &json, std::string &error) {
    std::unique_ptr<MatrixBandwidthRequest> reqPtr = std::make_unique<MatrixBandwidthRequest>();
    if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
        // Wording aligned with the sibling parsers ("... of <kind> request.").
        error = "Failed to set request base info of matrix list request.";
        return nullptr;
    }
    // Look "params" up once instead of once per field.
    const auto &params = json["params"];
    JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, params, "iterationId");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.operatorName, params, "operatorName");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, params, "stage");
    // Comparison mode: baseline iteration / communication group.
    JsonUtil::SetByJsonKeyValue(reqPtr->params.isCompare, params, "isCompare");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.baselineIterationId, params, "baselineIterationId");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, params, "pgName");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, params, "clusterPath");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, params, "groupIdHash");
    JsonUtil::SetByJsonKeyValue(reqPtr->params.baselineGroupIdHash, params, "baselineGroupIdHash");
    return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToIterationsRequest(const json_t &json, std::string &error) {
std::unique_ptr<IterationsRequest> reqPtr = std::make_unique<IterationsRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of iterations request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.isCompare, json["params"], "isCompare");
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToDurationRequest(const json_t &json, std::string &error) {
std::unique_ptr<DurationListRequest> reqPtr = std::make_unique<DurationListRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of duration request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.operatorName, json["params"], "operatorName");
JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, json["params"], "stage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, json["params"], "pgName");
if (json["params"].HasMember("targetOperatorName")) {
JsonUtil::SetByJsonKeyValue(reqPtr->params.targetOperatorName, json["params"], "targetOperatorName");
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.isCompare, json["params"], "isCompare");
JsonUtil::SetByJsonKeyValue(reqPtr->params.baselineIterationId, json["params"], "baselineIterationId");
if (json["params"].HasMember("rankList") && json["params"]["rankList"].IsArray()) {
for (const auto &rankId : json["params"]["rankList"].GetArray()) {
if (rankId.IsString()) {
reqPtr->params.rankList.emplace_back(rankId.GetString());
}
}
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.baselineGroupIdHash, json["params"], "baselineGroupIdHash");
JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, json["params"], "groupIdHash");
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToOperatorNamesRequest(const json_t &json, std::string &error) {
std::unique_ptr<OperatorNamesRequest> reqPtr = std::make_unique<OperatorNamesRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of operator names request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, json["params"], "stage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, json["params"], "pgName");
if (json["params"].HasMember("rankList") && json["params"]["rankList"].IsArray()) {
for (const auto &rankId : json["params"]["rankList"].GetArray()) {
if (rankId.IsString()) {
reqPtr->params.rankList.emplace_back(rankId.GetString());
}
}
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, json["params"], "groupIdHash");
return reqPtr;
}
std::unique_ptr<Request> CommunicationProtocol::ToMatrixOpNamesRequest(const json_t &json, std::string &error) {
std::unique_ptr<MatrixSortOpNamesRequest> reqPtr = std::make_unique<MatrixSortOpNamesRequest>();
if (!ProtocolUtil::SetRequestBaseInfo(*reqPtr, json)) {
error = "Failed to set request base info of matrix op names request.";
return nullptr;
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.iterationId, json["params"], "iterationId");
JsonUtil::SetByJsonKeyValue(reqPtr->params.stage, json["params"], "stage");
JsonUtil::SetByJsonKeyValue(reqPtr->params.pgName, json["params"], "pgName");
if (json["params"].HasMember("rankList") && json["params"]["rankList"].IsArray()) {
for (const auto &rankId : json["params"]["rankList"].GetArray()) {
if (rankId.IsString()) {
reqPtr->params.rankList.emplace_back(rankId.GetString());
}
}
}
JsonUtil::SetByJsonKeyValue(reqPtr->params.clusterPath, json["params"], "clusterPath");
JsonUtil::SetByJsonKeyValue(reqPtr->params.groupIdHash, json["params"], "groupIdHash");
return reqPtr;
}
// Builds a CommunicationAdvisorRequest from `json`.
// Returns nullptr and sets `error` when the common request header cannot be
// read; otherwise copies clusterPath from json["params"].
std::unique_ptr<Request> CommunicationProtocol::ToCommunicationAdvisorRequest(
    const Dic::json_t &json, std::string &error) {
    auto request = std::make_unique<CommunicationAdvisorRequest>();
    if (!ProtocolUtil::SetRequestBaseInfo(*request, json)) {
        error = "Failed to set request base info of communication advisor request.";
        return nullptr;
    }

    const auto &src = json["params"];
    JsonUtil::SetByJsonKeyValue(request->params.clusterPath, src, "clusterPath");
    return request;
}
#pragma endregion
#pragma region <<Response To Json>>
// Serializes an OperatorDetailsResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToOperatorDetailsResponse(const Response &response) {
    const auto &typed = dynamic_cast<const OperatorDetailsResponse &>(response);
    return ToResponseJson<OperatorDetailsResponse>(typed);
}
// Serializes a BandwidthDataResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToBandwidthDataResponse(const Response &response) {
    const auto &typed = dynamic_cast<const BandwidthDataResponse &>(response);
    return ToResponseJson<BandwidthDataResponse>(typed);
}
// Serializes a DistributionResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToDistributionResponse(const Response &response) {
    const auto &typed = dynamic_cast<const DistributionResponse &>(response);
    return ToResponseJson<DistributionResponse>(typed);
}
// Serializes the reply to an iterations query as an IterationsOrRanksResponse.
// The reference dynamic_cast throws std::bad_cast on a type mismatch.
std::optional<document_t> CommunicationProtocol::ToIterationsResponse(const Response &response) {
    const auto &typed = dynamic_cast<const IterationsOrRanksResponse &>(response);
    return ToResponseJson<IterationsOrRanksResponse>(typed);
}
// Serializes an OperatorNamesResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToOperatorNamesResponse(const Response &response) {
    const auto &typed = dynamic_cast<const OperatorNamesResponse &>(response);
    return ToResponseJson<OperatorNamesResponse>(typed);
}
// Serializes a MatrixSortOpNamesResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToMatrixOpNamesResponse(const Dic::Protocol::Response &response) {
    const auto &typed = dynamic_cast<const MatrixSortOpNamesResponse &>(response);
    return ToResponseJson<MatrixSortOpNamesResponse>(typed);
}
// Serializes a DurationResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToDurationResponse(const Response &response) {
    const auto &typed = dynamic_cast<const DurationResponse &>(response);
    return ToResponseJson<DurationResponse>(typed);
}
// Serializes an OperatorListsResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToOperatorListResponse(const Response &response) {
    const auto &typed = dynamic_cast<const OperatorListsResponse &>(response);
    return ToResponseJson<OperatorListsResponse>(typed);
}
// Serializes a MatrixGroupResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToMatrixGroupResponse(const Response &response) {
    const auto &typed = dynamic_cast<const MatrixGroupResponse &>(response);
    return ToResponseJson<MatrixGroupResponse>(typed);
}
// Serializes a MatrixListResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToMatrixListResponse(const Response &response) {
    const auto &typed = dynamic_cast<const MatrixListResponse &>(response);
    return ToResponseJson<MatrixListResponse>(typed);
}
// Serializes a CommunicationAdvisorResponse. The reference dynamic_cast
// throws std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToCommunicationAdvisorResponse(
    const Dic::Protocol::Response &response) {
    const auto &typed = dynamic_cast<const CommunicationAdvisorResponse &>(response);
    return ToResponseJson<CommunicationAdvisorResponse>(typed);
}
// Serializes the slow-rank duration analysis reply as a
// CommunicationSlowRankAnalysisResponse. The reference dynamic_cast throws
// std::bad_cast if `response` has a different dynamic type.
std::optional<document_t> CommunicationProtocol::ToDurationSlowRankResponse(const Dic::Protocol::Response &response) {
    const auto &typed = dynamic_cast<const CommunicationSlowRankAnalysisResponse &>(response);
    return ToResponseJson<CommunicationSlowRankAnalysisResponse>(typed);
}
#pragma endregion
}
}