* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <unordered_map>
#include <algorithm>
#include <functional>
#include "pch.h"
#include "BaselineManager.h"
#include "DataBaseManager.h"
#include "OperatorProtocolRequest.h"
#include "OperatorProtocolResponse.h"
#include "OperatorGroupConverter.h"
#include "OperatorProtocol.h"
#include "QueryOpStatisticInfoHandler.h"
namespace {
using namespace Dic;
using namespace Dic::Server;
using StatisticCmpRes = Dic::Protocol::OperatorStatisticCmpInfoRes;
using GetObjFunc = std::function<OperatorStatisticInfoRes(const StatisticCmpRes &)>;
using StatisticCmpFun = std::function<bool(const StatisticCmpRes &, const StatisticCmpRes &, GetObjFunc)>;
std::unordered_map<std::string, StatisticCmpFun> StatisticDescCompareFunctions = {
{"opType",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).opType > func(b).opType;
}},
{"opName",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).opName > func(b).opName;
}},
{"inputShape",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).inputShape > func(b).inputShape;
}},
{"accCore",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).accCore > func(b).accCore;
}},
{"totalTime",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleDesc(func(a).totalTime, func(b).totalTime);
}},
{"count",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleDesc(func(a).count, func(b).count);
}},
{"avgTime",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleDesc(func(a).avgTime, func(b).avgTime);
}},
{"maxTime",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleDesc(func(a).maxTime, func(b).maxTime);
}},
{"minTime", [](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleDesc(func(a).minTime, func(b).minTime);
}}};
std::unordered_map<std::string, StatisticCmpFun> StatisticAsceCompareFunctions = {
{"opType",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).opType < func(b).opType;
}},
{"opName",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).opName < func(b).opName;
}},
{"inputShape",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).inputShape < func(b).inputShape;
}},
{"accCore",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return func(a).accCore < func(b).accCore;
}},
{"totalTime",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleAsce(func(a).totalTime, func(b).totalTime);
}},
{"count",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleAsce(func(a).count, func(b).count);
}},
{"avgTime",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleAsce(func(a).avgTime, func(b).avgTime);
}},
{"maxTime",
[](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleAsce(func(a).maxTime, func(b).maxTime);
}},
{"minTime", [](const StatisticCmpRes &a, const StatisticCmpRes &b, GetObjFunc func) {
return NumberUtil::IsStr2DoubleAsce(func(a).minTime, func(b).minTime);
}}};
bool StatisticDescCmp(const StatisticCmpRes &a, const StatisticCmpRes &b, const std::string orderBy, GetObjFunc func) {
auto it = StatisticDescCompareFunctions.find(orderBy);
if (it != StatisticDescCompareFunctions.end()) {
return it->second(a, b, func);
}
return func(a).totalTime > func(b).totalTime;
}
bool StatisticAsceCmp(const StatisticCmpRes &a, const StatisticCmpRes &b, const std::string orderBy, GetObjFunc func) {
auto it = StatisticAsceCompareFunctions.find(orderBy);
if (it != StatisticAsceCompareFunctions.end()) {
return it->second(a, b, func);
}
return func(a).totalTime < func(b).totalTime;
}
bool StaticCmp(const StatisticCmpRes &a, const StatisticCmpRes &b, const std::string order, const std::string orderBy,
GetObjFunc func) {
if (order == "ascend") {
return StatisticAsceCmp(a, b, orderBy, func);
} else {
return StatisticDescCmp(a, b, orderBy, func);
}
}
};
namespace Dic::Module::Operator {
using namespace Dic::Server;
bool QueryOpStatisticInfoHandler::HandleRequest(std::unique_ptr<Protocol::Request> requestPtr) {
OperatorStatisticInfoRequest &request = dynamic_cast<OperatorStatisticInfoRequest &>(*requestPtr);
std::unique_ptr<OperatorStatisticInfoResponse> responsePtr = std::make_unique<OperatorStatisticInfoResponse>();
OperatorStatisticInfoResponse &response = *responsePtr;
bool rst = false;
std::string errorMsg;
if ((request.params.topK != 0) && request.params.CommonCheck(errorMsg) &&
request.params.StatisticGroupCheck(errorMsg)) {
rst = request.params.isCompare
? HandleCompareDataRequest(request, dynamic_cast<OperatorStatisticInfoResponse &>(*responsePtr))
: HandleStatisticcDataRequest(request, dynamic_cast<OperatorStatisticInfoResponse &>(*responsePtr));
} else {
SetOperatorError(ErrorCode::PARAMS_ERROR);
}
SetBaseResponse(request, response);
SendResponse(std::move(responsePtr), rst);
return rst;
}
bool QueryOpStatisticInfoHandler::HandleCompareDataRequest(
OperatorStatisticInfoRequest &request, OperatorStatisticInfoResponse &response) {
std::string rankId = Summary::VirtualSummaryDataBase::GetFileIdFromCombinationId(request.params.rankId);
auto database = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(rankId);
std::string deviceId = Timeline::DataBaseManager::Instance().GetDeviceIdFromRankId(rankId);
if (deviceId.empty()) {
SetOperatorError(ErrorCode::GET_DEVICE_ID_FAILED);
return false;
}
request.params.deviceId = deviceId;
std::vector<Protocol::OperatorStatisticInfoRes> compareRes;
if (!database) {
ServerLog::Warn("[Operator]Not exist compare operator database. Fail get op statistic info.");
return true;
}
if (!database->QueryAllOperatorStatisticInfo(request.params, compareRes)) {
ServerLog::Error("[Operator]Failed to query current Statistic Info by rankId.");
SetOperatorError(ErrorCode::QUERY_ALL_STATISTIC_FAILED);
return false;
}
std::string baselineId = Global::BaselineManager::Instance().GetBaselineId();
if (baselineId == "") {
ServerLog::Error("[Operator]Failed to get baseline id.");
SetOperatorError(ErrorCode::GET_BASELINE_ID_FAILED);
return false;
}
auto databaseBaseline = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(baselineId);
std::vector<Protocol::OperatorStatisticInfoRes> baselineRes;
request.params.deviceId = Timeline::DataBaseManager::Instance().GetDeviceIdFromRankId(baselineId);
if (!databaseBaseline) {
ServerLog::Warn("[Operator]Not exist baseline operator database. Fail get op statistic info.");
return true;
}
if (!databaseBaseline->QueryAllOperatorStatisticInfo(request.params, baselineRes)) {
ServerLog::Error("[Operator]Failed to query baseline Statistic Info by baselineId.");
SetOperatorError(ErrorCode::QUERY_ALL_STATISTIC_FAILED);
return false;
}
std::vector<Protocol::OperatorStatisticCmpInfoRes> res;
res = GetCmpDataVec(request.params.group, baselineRes, compareRes);
constexpr int64_t MAX_INT64 = std::numeric_limits<int64_t>::max();
int64_t safeSize = (res.size() > MAX_INT64) ? MAX_INT64 : static_cast<int64_t>(res.size());
if (request.params.topK > 0) {
response.total = std::min(request.params.topK, safeSize);
} else {
response.total = safeSize;
}
response.data = GetFixNumDiffCmpData(res, request.params, response.total);
return true;
}
bool QueryOpStatisticInfoHandler::HandleStatisticcDataRequest(
OperatorStatisticInfoRequest &request, OperatorStatisticInfoResponse &response) {
std::string rankId = Summary::VirtualSummaryDataBase::GetFileIdFromCombinationId(request.params.rankId);
auto database = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(rankId);
if (!database) {
ServerLog::Warn("[Operator]Not exist operator database. Fail get statistic info.");
return true;
}
std::string deviceId = Timeline::DataBaseManager::Instance().GetDeviceIdFromRankId(rankId);
if (deviceId.empty()) {
SetOperatorError(ErrorCode::GET_DEVICE_ID_FAILED);
return false;
}
request.params.deviceId = deviceId;
if (!database->QueryOperatorStatisticInfo(request.params, response)) {
ServerLog::Error("[Operator]Failed to query Statistic Info by rankId.");
SetOperatorError(ErrorCode::QUERY_STATISTIC_FAILED);
return false;
}
GetObjFunc func = [](const StatisticCmpRes &stat) { return stat.compare; };
std::sort(response.data.begin(), response.data.end(),
[&request, func](OperatorStatisticCmpInfoRes &a, OperatorStatisticCmpInfoRes &b) {
return StaticCmp(a, b, request.params.order, request.params.orderBy, func);
});
return true;
}
std::string QueryOpStatisticInfoHandler::GetGroup(const std::string ¶msGroup, OperatorStatisticInfoRes &data) {
OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(paramsGroup);
switch (operatorGroup) {
case OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP:
return data.opType + data.accCore;
case OperatorGroupConverter::OperatorGroup::COMMUNICATION_TYPE_GROUP:
return data.opType;
case OperatorGroupConverter::OperatorGroup::OP_INPUT_SHAPE_GROUP:
return data.opName + data.inputShape + data.accCore;
default:
return "";
}
}
void QueryOpStatisticInfoHandler::GroupingData(const std::string ¶msGroup, OpStaticResVec &datFromDb,
std::map<std::string, Protocol::OperatorStatisticCmpInfoRes> &groupMap, bool isBaselineData) {
std::string group;
for (auto &data : datFromDb) {
group = GetGroup(paramsGroup, data);
if (group.empty()) {
continue;
}
if (isBaselineData) {
groupMap[group].baseline = data;
} else {
groupMap[group].compare = data;
}
}
}
std::string QueryOpStatisticInfoHandler::CalDataCompare(const std::string &com, const std::string &base) {
return NumberUtil::StringDoubleMinus(com, base);
}
void QueryOpStatisticInfoHandler::SetOpInputShapeGroupData(OperatorStatisticCmpInfoRes &data) {
data.diff.opName = data.compare.opName.empty() ? data.baseline.opName : data.compare.opName;
data.diff.inputShape = data.compare.inputShape.empty() ? data.baseline.inputShape : data.compare.inputShape;
data.diff.opType = data.compare.opType + " -> " + data.compare.opType;
}
void QueryOpStatisticInfoHandler::SetOpOrHcclTypeGroupData(OperatorStatisticCmpInfoRes &data) {
data.diff.opType = data.compare.opType.empty() ? data.baseline.opType : data.compare.opType;
data.diff.opName = data.compare.opName + "->" + data.baseline.opName;
data.diff.inputShape = data.compare.inputShape + "->" + data.baseline.inputShape;
}
void QueryOpStatisticInfoHandler::CalDiffDataByGroup(
const std::string ¶msGroup, OperatorStatisticCmpInfoRes &data) {
data.diff.accCore = data.compare.accCore.empty() ? data.baseline.accCore : data.compare.accCore;
data.diff.totalTime = CalDataCompare(data.compare.totalTime, data.baseline.totalTime);
data.diff.count = CalDataCompare(data.compare.count, data.baseline.count);
data.diff.avgTime = CalDataCompare(data.compare.avgTime, data.baseline.avgTime);
data.diff.maxTime = CalDataCompare(data.compare.maxTime, data.baseline.maxTime);
data.diff.minTime = CalDataCompare(data.compare.minTime, data.baseline.minTime);
OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(paramsGroup);
switch (operatorGroup) {
case OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP:
case OperatorGroupConverter::OperatorGroup::COMMUNICATION_TYPE_GROUP:
SetOpOrHcclTypeGroupData(data);
break;
case OperatorGroupConverter::OperatorGroup::OP_INPUT_SHAPE_GROUP:
SetOpInputShapeGroupData(data);
break;
default:
break;
}
}
std::vector<Protocol::OperatorStatisticCmpInfoRes> QueryOpStatisticInfoHandler::GetCmpDataVec(
std::string &group, OpStaticResVec &base, OpStaticResVec &cmp) {
std::map<std::string, Protocol::OperatorStatisticCmpInfoRes> groupMap;
std::vector<Protocol::OperatorStatisticCmpInfoRes> cmpRes;
GroupingData(group, base, groupMap, true);
GroupingData(group, cmp, groupMap, false);
for (auto &CmpInfo : groupMap) {
CalDiffDataByGroup(group, CmpInfo.second);
cmpRes.emplace_back(CmpInfo.second);
}
return cmpRes;
}
std::vector<Protocol::OperatorStatisticCmpInfoRes> QueryOpStatisticInfoHandler::GetFixNumDiffCmpData(
std::vector<Protocol::OperatorStatisticCmpInfoRes> &statisticData, Protocol::OperatorStatisticReqParams &reqParams,
const int64_t total) {
if (statisticData.empty()) {
return statisticData;
}
std::string dbOrderBy = reqParams.orderBy;
const std::string order = reqParams.order;
GetObjFunc func = [](const StatisticCmpRes &stat) { return stat.diff; };
std::sort(statisticData.begin(), statisticData.end(),
[&order, &dbOrderBy, func](OperatorStatisticCmpInfoRes &a, OperatorStatisticCmpInfoRes &b) {
return StaticCmp(a, b, order, dbOrderBy, func);
});
uint64_t MAX_UINT64 = std::numeric_limits<uint64_t>::max();
uint64_t safeTotal = total < 0 ? MAX_UINT64 : static_cast<uint64_t>(total);
safeTotal = std::min(safeTotal, static_cast<uint64_t>(statisticData.size()));
std::vector<Protocol::OperatorStatisticCmpInfoRes> topKStatisticData(
statisticData.begin(), statisticData.begin() + safeTotal);
uint64_t dataSize = topKStatisticData.size();
uint64_t pageSize =
(reqParams.pageSize <= 0 ? 10
: static_cast<uint64_t>(reqParams.pageSize));
uint64_t offset = pageSize * static_cast<uint64_t>(reqParams.current - 1);
if (offset >= dataSize) {
offset = dataSize - ((dataSize % pageSize) == 0 ? pageSize : (dataSize % pageSize));
}
std::vector<Protocol::OperatorStatisticCmpInfoRes>::const_iterator start = topKStatisticData.begin() + offset;
std::vector<Protocol::OperatorStatisticCmpInfoRes>::const_iterator end =
topKStatisticData.begin() + std::min(offset + pageSize, dataSize);
std::vector<Protocol::OperatorStatisticCmpInfoRes> result;
result.assign(start, end);
return result;
}
}