* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <algorithm>
#include "pch.h"
#include "BaselineManager.h"
#include "DataBaseManager.h"
#include "OperatorProtocolRequest.h"
#include "OperatorProtocol.h"
#include "QueryOpDetailInfoHandler.h"
namespace {
using namespace Dic::Server;
using namespace Dic;
using DetailCmpRes = Dic::Protocol::OperatorDetailCmpInfoRes;
using DetailCmpFun = std::function<bool(const DetailCmpRes &, const DetailCmpRes &)>;
static std::unordered_map<std::string, DetailCmpFun> DetailDescCompareFunctions = {
{"deviceId", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.rankId > b.diff.rankId; }},
{"step_id", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.stepId > b.diff.stepId; }},
{"name", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.name > b.diff.name; }},
{"type", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.type > b.diff.type; }},
{"accCore", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.accCore > b.diff.accCore; }},
{"startTime",
[](const DetailCmpRes &a, const DetailCmpRes &b) {
return NumberUtil::IsStr2DoubleDesc(a.diff.startTime, b.diff.startTime);
}},
{"duration",
[](const DetailCmpRes &a, const DetailCmpRes &b) {
return NumberUtil::IsStr2DoubleDesc(a.diff.duration, b.diff.duration);
}},
{"waitTime",
[](const DetailCmpRes &a, const DetailCmpRes &b) {
return NumberUtil::IsStr2DoubleDesc(a.diff.waitTime, b.diff.waitTime);
}},
};
static std::unordered_map<std::string, DetailCmpFun> DetailAsceCompareFunctions = {
{"deviceId", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.rankId < b.diff.rankId; }},
{"step_id", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.stepId < b.diff.stepId; }},
{"name", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.name < b.diff.name; }},
{"type", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.type < b.diff.type; }},
{"accCore", [](const DetailCmpRes &a, const DetailCmpRes &b) { return a.diff.accCore < b.diff.accCore; }},
{"startTime",
[](const DetailCmpRes &a, const DetailCmpRes &b) {
return NumberUtil::IsStr2DoubleAsce(a.diff.startTime, b.diff.startTime);
}},
{"duration",
[](const DetailCmpRes &a, const DetailCmpRes &b) {
return NumberUtil::IsStr2DoubleAsce(a.diff.duration, b.diff.duration);
}},
{"waitTime",
[](const DetailCmpRes &a, const DetailCmpRes &b) {
return NumberUtil::IsStr2DoubleAsce(a.diff.waitTime, b.diff.waitTime);
}},
};
bool DetailDescCmp(const DetailCmpRes &a, const DetailCmpRes &b, const std::string orderBy) {
auto it = DetailDescCompareFunctions.find(orderBy);
if (it != DetailDescCompareFunctions.end()) {
return it->second(a, b);
}
return a.diff.duration > b.diff.duration;
}
bool DetailAsceCmp(const DetailCmpRes &a, const DetailCmpRes &b, const std::string orderBy) {
auto it = DetailAsceCompareFunctions.find(orderBy);
if (it != DetailAsceCompareFunctions.end()) {
return it->second(a, b);
}
return a.diff.duration > b.diff.duration;
}
bool StaticCmp(const DetailCmpRes &a, const DetailCmpRes &b, const std::string order, const std::string orderBy) {
if (order == "ascend") {
return DetailAsceCmp(a, b, orderBy);
} else {
return DetailDescCmp(a, b, orderBy);
}
}
};
namespace Dic::Module::Operator {
using namespace Dic::Server;
/**
 * Handles one operator-detail request: validates the parameters, runs either the comparison
 * query or the plain detail query, then fills in and sends the response.
 * The reference dynamic_cast throws std::bad_cast if the request is not an OperatorDetailInfoRequest.
 * @return true when the query succeeded, false on a parameter or query failure.
 */
bool QueryOpDetailInfoHandler::HandleRequest(std::unique_ptr<Protocol::Request> requestPtr) {
    auto &request = dynamic_cast<OperatorDetailInfoRequest &>(*requestPtr);
    auto responsePtr = std::make_unique<OperatorDetailInfoResponse>();
    OperatorDetailInfoResponse &response = *responsePtr;

    bool succeeded = false;
    std::string errorMsg;
    if (request.params.CommonCheck(errorMsg)) {
        succeeded = request.params.isCompare ? HandleCompareDataRequest(request, response)
                                             : HandleDetailDataRequest(request, response);
    } else {
        // Invalid parameters: report the reason and answer with a failed response.
        ServerLog::Error(errorMsg);
        SetOperatorError(ErrorCode::PARAMS_ERROR);
    }

    // The response tail is the same for every outcome; only the result flag differs.
    SetBaseResponse(request, response);
    SetResponseResult(response, succeeded);
    SendResponse(std::move(responsePtr), succeeded);
    return succeeded;
}
/**
 * Fills the response with the paged comparison between the current rank and the baseline.
 *
 * Queries all operator detail rows of both databases, pairs them up, and stores the total row
 * count, the merged PMU header set and the requested page in the response.
 *
 * @param request  request parameters; params.deviceId is overwritten with the device id of the
 *                 rank being queried (current rank first, then baseline).
 * @param response receives level, total, pmuHeaders and data.
 * @return true on success, also when topK is 0 or when one of the two databases does not exist
 *         (the response is then left without data); false when a device id, the baseline id or
 *         a query cannot be obtained.
 */
bool QueryOpDetailInfoHandler::HandleCompareDataRequest(
    OperatorDetailInfoRequest &request, OperatorDetailInfoResponse &response) {
    if (request.params.topK == 0) {
        return true;
    }
    std::string rankId = Summary::VirtualSummaryDataBase::GetFileIdFromCombinationId(request.params.rankId);
    auto database = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(rankId);
    std::string deviceId = Timeline::DataBaseManager::Instance().GetDeviceIdFromRankId(rankId);
    if (deviceId.empty()) {
        // Report the failure the same way HandleDetailDataRequest does instead of failing silently.
        ServerLog::Error("[Operator]Failed to get device id by rankId.");
        SetOperatorError(ErrorCode::GET_DEVICE_ID_FAILED);
        return false;
    }
    request.params.deviceId = deviceId;
    if (!database) {
        ServerLog::Warn("[Operator]Not exist operator database. Fail get op detail info.");
        return true;
    }
    std::vector<Protocol::OperatorDetailInfoRes> cmpRes;
    if (!database->QueryAllOperatorDetailInfo(request.params, cmpRes, response.level)) {
        ServerLog::Error("[Operator]Failed to query current detail info by rankId.");
        SetOperatorError(ErrorCode::QUERY_ALL_DETAIL_FAILED);
        return false;
    }
    std::set<std::string> cmpPmuHeader = database->GetPmuColumns();

    std::string baselineId = Global::BaselineManager::Instance().GetBaselineId();
    if (baselineId.empty()) {
        ServerLog::Error("[Operator]Failed to get baseline id.");
        SetOperatorError(ErrorCode::GET_BASELINE_ID_FAILED);
        return false;
    }
    auto databaseBaseline = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(baselineId);
    // The same request parameters are reused for the baseline query, now with the baseline device id.
    request.params.deviceId = Timeline::DataBaseManager::Instance().GetDeviceIdFromRankId(baselineId);
    if (!databaseBaseline) {
        ServerLog::Warn("[Operator]Not exist baseline operator database. Fail get op detail info.");
        return true;
    }
    std::vector<Protocol::OperatorDetailInfoRes> baselineRes;
    if (!databaseBaseline->QueryAllOperatorDetailInfo(request.params, baselineRes, response.level)) {
        ServerLog::Error("[Operator]Failed to query baseline detail Info by baselineId.");
        SetOperatorError(ErrorCode::QUERY_ALL_DETAIL_FAILED);
        return false;
    }
    std::set<std::string> basePmuHeader = databaseBaseline->GetPmuColumns();

    std::vector<Protocol::OperatorDetailCmpInfoRes> fullCmpData = GetCmpDataVec(baselineRes, cmpRes);

    // Clamp while the size is still unsigned; clamping after the cast to int64_t cannot catch anything.
    constexpr uint64_t MAX_TOTAL = static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
    const int64_t safeSize = static_cast<int64_t>(std::min<uint64_t>(fullCmpData.size(), MAX_TOTAL));
    if (request.params.topK > 0) {
        response.total = std::min(request.params.topK, safeSize);
    } else {
        // A negative topK means "no limit".
        response.total = safeSize;
    }

    cmpPmuHeader.insert(basePmuHeader.begin(), basePmuHeader.end());
    response.pmuHeaders = cmpPmuHeader;
    // NOTE(review): cmpPmuHeader is the union of both column sets at this point, so
    // GetFixNumDiffCmpData treats every baseline column as shared by both sides - confirm this is intended.
    response.data = GetFixNumDiffCmpData(fullCmpData, request.params, response.total, basePmuHeader, cmpPmuHeader);
    return true;
}
/**
 * Runs the plain (non-comparison) operator detail query for the requested rank.
 * @param request  request parameters; params.deviceId is set to the device id resolved from the rank.
 * @param response filled by the database query.
 * @return true on success or when no summary database exists for the rank (only a warning is
 *         logged); false when the device id is unknown or the query fails.
 */
bool QueryOpDetailInfoHandler::HandleDetailDataRequest(
    OperatorDetailInfoRequest &request, OperatorDetailInfoResponse &response) {
    const std::string fileId = Summary::VirtualSummaryDataBase::GetFileIdFromCombinationId(request.params.rankId);

    auto summaryDb = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(fileId);
    if (!summaryDb) {
        ServerLog::Warn("[Operator]Not exist operator database. Fail to get op detail info.");
        return true;
    }

    const std::string resolvedDeviceId = Timeline::DataBaseManager::Instance().GetDeviceIdFromRankId(fileId);
    if (resolvedDeviceId.empty()) {
        SetOperatorError(ErrorCode::GET_DEVICE_ID_FAILED);
        return false;
    }
    request.params.deviceId = resolvedDeviceId;

    const bool queried = summaryDb->QueryOperatorDetailInfo(request.params, response);
    if (queried) {
        return true;
    }
    ServerLog::Error("[Operator]Failed to query detail Info by rankId");
    SetOperatorError(ErrorCode::QUERY_DETAIL_FAILED);
    return false;
}
/**
 * Sorts both row lists in place by operator name, breaking ties by startTime.
 * NOTE(review): startTime is compared with operator<; if it is stored as a string this is a
 * lexicographic comparison, not a numeric one - confirm against the struct definition.
 */
void QueryOpDetailInfoHandler::SortDataBynameAndStartTime(
    std::vector<Protocol::OperatorDetailInfoRes> &baseDbData, std::vector<Protocol::OperatorDetailInfoRes> &cmpDbData) {
    // One ordering shared by both lists, so that they can be merged by name afterwards.
    const auto byNameThenStart = [](const OperatorDetailInfoRes &lhs, const OperatorDetailInfoRes &rhs) {
        return (lhs.name == rhs.name) ? (lhs.startTime < rhs.startTime) : (lhs.name < rhs.name);
    };
    std::sort(baseDbData.begin(), baseDbData.end(), byNameThenStart);
    std::sort(cmpDbData.begin(), cmpDbData.end(), byNameThenStart);
}
/**
 * Pairs baseline rows with compare rows by operator name.
 *
 * Both inputs are first sorted in place (name, then startTime) and then merged: rows with equal
 * names are paired position by position, a row whose name has no remaining partner on the other
 * side is emitted with only its own side filled in.
 *
 * @param baseDbData baseline rows; reordered by this call.
 * @param cmpDbData  compare rows; reordered by this call.
 * @return one entry per pair or unmatched row, in ascending name order.
 */
std::vector<Protocol::OperatorDetailCmpInfoRes> QueryOpDetailInfoHandler::GetCmpDataVec(
    std::vector<Protocol::OperatorDetailInfoRes> &baseDbData, std::vector<Protocol::OperatorDetailInfoRes> &cmpDbData) {
    SortDataBynameAndStartTime(baseDbData, cmpDbData);

    auto baseIt = baseDbData.begin();
    auto cmpIt = cmpDbData.begin();
    const auto baseEnd = baseDbData.end();
    const auto cmpEnd = cmpDbData.end();

    std::vector<Protocol::OperatorDetailCmpInfoRes> merged;
    // Classic two-way merge: each step consumes the smaller name, or both sides when names match.
    while (baseIt != baseEnd && cmpIt != cmpEnd) {
        OperatorDetailCmpInfoRes row;
        if (baseIt->name == cmpIt->name) {
            row.baseline = *baseIt++;
            row.compare = *cmpIt++;
        } else if (baseIt->name < cmpIt->name) {
            row.baseline = *baseIt++;
        } else {
            row.compare = *cmpIt++;
        }
        merged.emplace_back(row);
    }
    // Whatever is left on either side has no partner.
    for (; baseIt != baseEnd; ++baseIt) {
        OperatorDetailCmpInfoRes row;
        row.baseline = *baseIt;
        merged.emplace_back(row);
    }
    for (; cmpIt != cmpEnd; ++cmpIt) {
        OperatorDetailCmpInfoRes row;
        row.compare = *cmpIt;
        merged.emplace_back(row);
    }
    return merged;
}
/**
 * Builds the diff text for one value pair.
 * @param comPmu  value from the compare side.
 * @param basePmu value from the baseline side.
 * @return NumberUtil::StringDoubleMinus(comPmu, basePmu) when NumberUtil::IsDouble accepts both
 *         values, otherwise the two values joined as "comPmu->basePmu".
 */
std::string QueryOpDetailInfoHandler::CalPmuDataCompare(const std::string &comPmu, const std::string &basePmu) {
    const bool bothNumeric = NumberUtil::IsDouble(comPmu) && NumberUtil::IsDouble(basePmu);
    if (!bothNumeric) {
        return comPmu + "->" + basePmu;
    }
    return NumberUtil::StringDoubleMinus(comPmu, basePmu);
}
/**
 * Fills data.diff from data.compare and data.baseline.
 *
 * - rankId / stepId / name: the compare value, or the baseline value when the compare one is empty.
 * - startTime / duration / waitTime / blockNum: result of CalPmuDataCompare(compare, baseline).
 * - type, accCore and the input/output shape, type and format fields: "compare->baseline".
 * - PMU columns: CalPmuDataCompare for columns in `intersection`, "->baseline" for columns in
 *   `baseDiff`, "compare->" for columns in `cmpDiff`.
 *
 * NOTE(review): `intersection` is taken by value, so the set is copied on every call; switching
 * to a const reference requires changing the declaration in the header as well.
 */
void QueryOpDetailInfoHandler::FromatDatailData(Protocol::OperatorDetailCmpInfoRes &data,
    const std::set<std::string> &baseDiff, const std::set<std::string> &cmpDiff,
    const std::set<std::string> intersection) {
    auto &compare = data.compare;
    auto &baseline = data.baseline;
    auto &diff = data.diff;

    // Identity fields come from whichever side actually has the row.
    const auto preferCompare = [](const auto &fromCompare, const auto &fromBaseline) -> const auto & {
        return fromCompare.empty() ? fromBaseline : fromCompare;
    };
    // Descriptive fields are shown side by side.
    const auto sideBySide = [](const auto &fromCompare, const auto &fromBaseline) {
        return fromCompare + "->" + fromBaseline;
    };

    diff.rankId = preferCompare(compare.rankId, baseline.rankId);
    diff.stepId = preferCompare(compare.stepId, baseline.stepId);
    diff.name = preferCompare(compare.name, baseline.name);

    diff.type = sideBySide(compare.type, baseline.type);
    diff.accCore = sideBySide(compare.accCore, baseline.accCore);

    diff.startTime = CalPmuDataCompare(compare.startTime, baseline.startTime);
    diff.duration = CalPmuDataCompare(compare.duration, baseline.duration);
    diff.waitTime = CalPmuDataCompare(compare.waitTime, baseline.waitTime);
    diff.blockNum = CalPmuDataCompare(compare.blockNum, baseline.blockNum);

    diff.inputShape = sideBySide(compare.inputShape, baseline.inputShape);
    diff.inputType = sideBySide(compare.inputType, baseline.inputType);
    diff.inputFormat = sideBySide(compare.inputFormat, baseline.inputFormat);
    diff.outputShape = sideBySide(compare.outputShape, baseline.outputShape);
    diff.outputType = sideBySide(compare.outputType, baseline.outputType);
    diff.outputFormat = sideBySide(compare.outputFormat, baseline.outputFormat);

    for (const auto &column : intersection) {
        diff.pmuDatas[column] = CalPmuDataCompare(compare.pmuDatas[column], baseline.pmuDatas[column]);
    }
    for (const auto &column : baseDiff) {
        diff.pmuDatas[column] = "->" + baseline.pmuDatas[column];
    }
    for (const auto &column : cmpDiff) {
        diff.pmuDatas[column] = compare.pmuDatas[column] + "->";
    }
}
/**
 * Computes the diff of every comparison row, sorts the rows and returns one page of the
 * first `total` rows.
 *
 * @param datailData    comparison rows; their diff part is filled in and the vector is sorted in place.
 * @param reqParams     supplies orderBy, order, pageSize (values <= 0 mean 10) and current (1-based page).
 * @param total         number of leading rows, after sorting, that are eligible for paging.
 *                      Values outside [0, datailData.size()] are clamped instead of running past the vector.
 * @param basePmuHeader PMU column names of the baseline side.
 * @param cmpPmuHeader  PMU column names of the compare side.
 * @return the requested page. A page beyond the end yields the last page, a non-positive
 *         `current` yields the first page, and no eligible rows yield an empty vector.
 */
std::vector<Protocol::OperatorDetailCmpInfoRes> QueryOpDetailInfoHandler::GetFixNumDiffCmpData(
    std::vector<Protocol::OperatorDetailCmpInfoRes> &datailData, Protocol::OperatorStatisticReqParams &reqParams,
    const int64_t total, const std::set<std::string> &basePmuHeader, const std::set<std::string> &cmpPmuHeader) {
    if (datailData.empty()) {
        return datailData;
    }

    // Split the PMU columns into baseline-only, compare-only and shared ones.
    std::set<std::string> baseDiff;
    std::set_difference(basePmuHeader.begin(), basePmuHeader.end(), cmpPmuHeader.begin(), cmpPmuHeader.end(),
        std::inserter(baseDiff, baseDiff.begin()));
    std::set<std::string> cmpDiff;
    std::set_difference(cmpPmuHeader.begin(), cmpPmuHeader.end(), basePmuHeader.begin(), basePmuHeader.end(),
        std::inserter(cmpDiff, cmpDiff.begin()));
    std::set<std::string> intersection;
    std::set_intersection(basePmuHeader.begin(), basePmuHeader.end(), cmpPmuHeader.begin(), cmpPmuHeader.end(),
        std::inserter(intersection, intersection.begin()));
    for (auto &data : datailData) {
        FromatDatailData(data, baseDiff, cmpDiff, intersection);
    }

    const std::string &orderBy = reqParams.orderBy;
    const std::string &order = reqParams.order;
    std::sort(datailData.begin(), datailData.end(),
        [&order, &orderBy](const Protocol::OperatorDetailCmpInfoRes &a, const Protocol::OperatorDetailCmpInfoRes &b) {
            return StaticCmp(a, b, order, orderBy);
        });

    // Only the first `total` sorted rows take part in paging; clamp to what actually exists.
    const uint64_t dataSize =
        total <= 0 ? 0 : std::min<uint64_t>(static_cast<uint64_t>(total), datailData.size());
    if (dataSize == 0) {
        return {};
    }
    const uint64_t pageSize = (reqParams.pageSize <= 0 ? 10 : static_cast<uint64_t>(reqParams.pageSize));
    // Work with page indices rather than byte offsets so that a huge or non-positive page number
    // cannot wrap around in unsigned arithmetic.
    const uint64_t requestedPage = (reqParams.current <= 0 ? 0 : static_cast<uint64_t>(reqParams.current) - 1);
    const uint64_t lastPage = (dataSize - 1) / pageSize;
    const uint64_t offset = std::min(requestedPage, lastPage) * pageSize;
    const uint64_t count = std::min(pageSize, dataSize - offset);

    const auto first = datailData.begin() + static_cast<int64_t>(offset);
    const auto last = first + static_cast<int64_t>(count);
    return std::vector<Protocol::OperatorDetailCmpInfoRes>(first, last);
}
}