* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#ifndef DIC_COMMUNICATION_PROTOCOL_COMMUNICATION_REQUEST_H
#define DIC_COMMUNICATION_PROTOCOL_COMMUNICATION_REQUEST_H
#include <string>
#include <optional>
#include <vector>
#include "ProtocolMessage.h"
#include "ProtocolDefs.h"
namespace Dic {
namespace Protocol {
struct OperatorDetailsParam {
std::string iterationId;
std::string rankId;
std::string orderBy;
std::string order;
std::string stage;
std::string queryType = "Comparison";
std::string pgName;
std::string clusterPath;
std::string groupIdHash;
int pageSize{};
int currentPage{};
bool CheckParams(std::string &errorMsg) const {
if (!CheckPageValid(this->pageSize, this->currentPage, errorMsg)) {
return false;
}
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->rankId, paramError)) {
errorMsg = "[Communication] Failed to check rank id." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->orderBy, paramError)) {
errorMsg = "[Communication] Failed to check orderBy." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->order, paramError)) {
errorMsg = "[Communication] Failed to check order type." + paramError;
return false;
}
if (!CheckStrParamValidWithoutLenLimit(this->stage, paramError)) {
errorMsg = "[Communication] Failed to check stage." + paramError;
return false;
}
if (!CheckStrParamValid(this->queryType, paramError)) {
errorMsg = "[Communication] Failed to check query type." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->pgName, paramError)) {
errorMsg = "[Communication] Failed to check pg name." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
if (!CheckStrParamValid(groupIdHash, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
return true;
}
OperatorDetailsParam() = default;
OperatorDetailsParam(const OperatorDetailsParam ¶m) {
this->iterationId = param.iterationId;
this->rankId = param.rankId;
this->orderBy = param.orderBy;
this->order = param.order;
this->stage = param.stage;
this->pageSize = param.pageSize;
this->currentPage = param.currentPage;
this->pgName = param.pgName;
this->groupIdHash = param.groupIdHash;
}
OperatorDetailsParam &operator=(const OperatorDetailsParam ¶m) = delete;
};
struct OperatorDetailsRequest : public Request {
OperatorDetailsRequest() : Request(REQ_RES_COMMUNICATION_OPERATOR_DETAILS) {};
OperatorDetailsParam params;
};
struct BandwidthDataParam {
std::string iterationId;
std::string rankId;
std::string operatorName;
std::string stage;
std::string pgName;
std::string clusterPath;
std::string groupIdHash;
bool CheckParams(std::string &errorMsg) const {
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValid(this->rankId, paramError)) {
errorMsg = "[Communication] Failed to check rank id." + paramError;
return false;
}
if (!CheckStrParamValid(this->operatorName, paramError)) {
errorMsg = "[Communication] Failed to operator name." + paramError;
return false;
}
if (!CheckStrParamValidWithoutLenLimit(this->stage, paramError)) {
errorMsg = "[Communication] Failed to check stage." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->pgName, paramError)) {
errorMsg = "[Communication] Failed to check pg name." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
if (!CheckStrParamValid(groupIdHash, paramError)) {
errorMsg = "[Communication] Failed to check group id hash." + paramError;
return false;
}
return true;
}
};
struct BandwidthDataRequest : public Request {
BandwidthDataRequest() : Request(REQ_RES_COMMUNICATION_BANDWIDTH) {};
BandwidthDataParam params;
};
struct DistributionDataParam {
std::string iterationId;
std::string rankId;
std::string operatorName;
std::string transportType;
std::string stage;
std::string pgName;
std::string clusterPath;
std::string groupIdHash;
bool CheckParams(std::string &errorMsg) const {
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValid(this->rankId, paramError)) {
errorMsg = "[Communication] Failed to check rank id." + paramError;
return false;
}
if (!CheckStrParamValid(this->operatorName, paramError)) {
errorMsg = "[Communication] Failed to check operator name." + paramError;
return false;
}
if (!CheckStrParamValid(this->transportType, paramError)) {
errorMsg = "[Communication] Failed to check transport type." + paramError;
return false;
}
if (!CheckStrParamValidWithoutLenLimit(this->stage, paramError)) {
errorMsg = "[Communication] Failed to check stage." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->pgName, paramError)) {
errorMsg = "[Communication] Failed to check pg name." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
if (!CheckStrParamValid(groupIdHash, paramError)) {
errorMsg = "[Communication] Failed to check group id hash." + paramError;
return false;
}
return true;
}
};
struct DistributionDataRequest : public Request {
DistributionDataRequest() : Request(REQ_RES_COMMUNICATION_DISTRIBUTION) {};
DistributionDataParam params;
};
struct IterationsParams {
bool isCompare = false;
std::string clusterPath;
inline bool CheckParams(std::string &errMsg) const {
std::string paramErr;
if (!CheckStrParamValid(clusterPath, paramErr)) {
errMsg = "[Communication] Failed to check cluster." + paramErr;
return false;
}
return true;
}
};
struct IterationsRequest : public Request {
IterationsRequest() : Request(REQ_RES_COMMUNICATION_ITERATIONS) {};
IterationsParams params;
};
struct OperatorNamesParams {
std::string iterationId;
std::vector<std::string> rankList = {};
std::string stage;
std::string pgName;
std::string clusterPath;
std::string groupIdHash;
bool CheckParams(std::string &errorMsg) const {
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValidWithoutLenLimit(this->stage, paramError)) {
errorMsg = "[Communication] Failed to check stage." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->pgName, paramError)) {
errorMsg = "[Communication] Failed to check pg name." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
if (!CheckStrParamValid(groupIdHash, paramError)) {
errorMsg = "[Communication] Failed to check group id hash." + paramError;
return false;
}
return true;
}
OperatorNamesParams() = default;
OperatorNamesParams(const OperatorNamesParams ¶ms) {
this->iterationId = params.iterationId;
for (const auto &item : params.rankList) {
this->rankList.push_back(item);
}
this->stage = params.stage;
this->groupIdHash = params.groupIdHash;
}
OperatorNamesParams &operator=(const OperatorNamesParams ¶ms) = delete;
};
struct OperatorNamesRequest : public Request {
OperatorNamesRequest() : Request(REQ_RES_COMMUNICATION_OPERATORNAMES) {};
OperatorNamesParams params;
};
struct MatrixSortOpNamesRequest : public Request {
MatrixSortOpNamesRequest() : Request(REQ_RES_COMMUNICATION_SORT_OP) {};
OperatorNamesParams params;
};
struct DurationListParams {
std::string iterationId;
std::vector<std::string> rankList = {};
std::string operatorName;
std::string stage;
std::string targetOperatorName;
bool isCompare = false;
std::string baselineIterationId;
std::string pgName;
std::string clusterPath;
std::string groupIdHash;
std::string baselineGroupIdHash;
bool CheckParams(std::string &errorMsg) const {
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValid(this->operatorName, paramError)) {
errorMsg = "[Communication] Failed to operator name." + paramError;
return false;
}
if (!CheckStrParamValidWithoutLenLimit(this->stage, paramError)) {
errorMsg = "[Communication] Failed to check stage." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->baselineIterationId, paramError)) {
errorMsg = "[Communication] Failed to check baseline iteration id." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->pgName, paramError)) {
errorMsg = "[Communication] Failed to check pg name." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->baselineGroupIdHash, paramError)) {
errorMsg = "[Communication] Failed to check baseline group id hash." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->groupIdHash, paramError)) {
errorMsg = "[Communication] Failed to check group id hash." + paramError;
return false;
}
return true;
}
DurationListParams() = default;
DurationListParams(const DurationListParams ¶ms) {
this->iterationId = params.iterationId;
this->operatorName = params.operatorName;
this->stage = params.stage;
this->targetOperatorName = params.targetOperatorName;
this->pgName = params.pgName;
for (const auto &item : params.rankList) {
this->rankList.push_back(item);
}
this->groupIdHash = params.groupIdHash;
}
DurationListParams &operator=(const DurationListParams ¶ms) = delete;
};
struct DurationListRequest : public Request {
DurationListRequest() : Request(REQ_RES_COMMUNICATION_LIST) {};
DurationListParams params;
};
struct MatrixGroupParam {
std::string iterationId;
std::string baselineIterationId;
bool isCompare = false;
std::string clusterPath;
bool CheckParams(std::string &errorMsg) const {
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->baselineIterationId, paramError)) {
errorMsg = "[Communication] Failed to check baseline iteration id." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
return true;
}
};
struct MatrixGroupRequest : public Request {
MatrixGroupRequest() : Request(REQ_RES_COMMUNICATION_MATRIX_GROUP) {};
MatrixGroupParam params;
};
struct MatrixBandwidthParam {
std::string stage;
std::string operatorName;
std::string iterationId;
std::string pgName;
std::string groupIdHash;
bool isCompare = false;
std::string baselineIterationId;
std::string clusterPath;
std::string baselineGroupIdHash;
bool CheckParams(std::string &errorMsg) const {
std::string paramError;
if (!CheckStrParamValidEmptyAllowed(this->iterationId, paramError)) {
errorMsg = "[Communication] Failed to check iteration id." + paramError;
return false;
}
if (!CheckStrParamValid(this->operatorName, paramError)) {
errorMsg = "[Communication] Failed to operator name." + paramError;
return false;
}
if (!CheckStrParamValidWithoutLenLimit(this->stage, paramError)) {
errorMsg = "[Communication] Failed to check stage." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->baselineIterationId, paramError)) {
errorMsg = "[Communication] Failed to check baseline iteration id." + paramError;
return false;
}
if (!CheckStrParamValidEmptyAllowed(this->pgName, paramError)) {
errorMsg = "[Communication] Failed to check pg name." + paramError;
return false;
}
if (!CheckStrParamValid(clusterPath, paramError)) {
errorMsg = "[Communication] Failed to check cluster." + paramError;
return false;
}
return true;
}
};
struct MatrixBandwidthRequest : public Request {
MatrixBandwidthRequest() : Request(REQ_RES_COMMUNICATION_MATRIX_BANDWIDTH) {};
MatrixBandwidthParam params;
};
struct CommunicationAdvisorParam {
std::string clusterPath;
inline bool CheckParams(std::string &errMsg) const {
std::string paramErr;
if (!CheckStrParamValid(clusterPath, paramErr)) {
errMsg = "[Communication] Failed to check cluster." + paramErr;
return false;
}
return true;
}
};
struct CommunicationAdvisorRequest : public Request {
CommunicationAdvisorRequest() : Request(REQ_RES_COMMUNICATION_ADVISOR) {};
CommunicationAdvisorParam params;
};
}
}
#endif