* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "SummaryService.h"
#include "TraceTime.h"
#include "DataBaseManager.h"
#include "CollectionUtil.h"
#include "BaselineManager.h"
#include "SummaryErrorManager.h"
using namespace Dic::Module::Summary;
using namespace Dic::Module::Global;
using namespace Dic::Protocol;
using namespace Dic::Module;
/**
 * @brief Recomputes the collection start time and duration and writes them back through the database.
 *
 * The extremum timestamps reported by the database are scaled down by numberThousands (twice for the
 * start time, once for the duration) and clamped to INT64_MAX. The duration is afterwards raised to the
 * largest npuTotalTime found in the "All" step statistics whenever one of them exceeds it.
 *
 * @param baseInfo Base info whose collectStartTime and collectDuration are overwritten.
 * @param db       Cluster database that is queried and updated; it is dereferenced without a null check.
 * @return false when the reported minimum is greater than the maximum or when the database update fails,
 *         true otherwise.
 */
bool SummaryService::UpdateStartTimeAndDuration(
    SummaryBaseInfo &baseInfo, std::shared_ptr<VirtualClusterDatabase> &db) {
    // Sentinels: if the query leaves them untouched, earliest > latest and the call is rejected below.
    uint64_t earliest = UINT64_MAX;
    uint64_t latest = 0;
    db->QueryExtremumTimestamp(earliest, latest);
    if (earliest > latest) {
        ServerLog::Warn("Fail to get extremum timestamp when query summary base info.");
        return false;
    }

    baseInfo.collectStartTime = static_cast<int64_t>(NumberUtil::CeilingClamp(
        earliest / (numberThousands * numberThousands), static_cast<uint64_t>(INT64_MAX)));
    baseInfo.collectDuration = NumberUtil::CeilingClamp(
        (latest - earliest) / numberThousands, static_cast<uint64_t>(INT64_MAX));

    // The step statistics may report a longer NPU time than the timestamp span; keep the larger one.
    std::unordered_map<std::uint32_t, StepStatistic> statisticByRank{};
    db->QueryAllPerformanceDataByStep("All", statisticByRank);
    for (const auto &rankEntry : statisticByRank) {
        if (rankEntry.second.npuTotalTime > baseInfo.collectDuration) {
            baseInfo.collectDuration = rankEntry.second.npuTotalTime;
        }
    }

    if (db->UpdateCollectTimeInfo(baseInfo)) {
        return true;
    }
    ServerLog::Warn("Failed to update database for cluster base info.");
    return false;
}
/**
 * @brief Loads the summary base info and fills in missing time information on demand.
 *
 * When the stored collectStartTime or collectDuration is zero, UpdateStartTimeAndDuration is invoked
 * to recompute both values.
 *
 * @param baseInfo Receives the queried (and possibly recomputed) base info.
 * @param db       Cluster database to read from; it is dereferenced without a null check.
 * @return false when the base info query fails or a required recomputation fails, true otherwise.
 */
bool SummaryService::QuerySummaryBaseInfo(SummaryBaseInfo &baseInfo, std::shared_ptr<VirtualClusterDatabase> &db) {
    if (!db->QueryBaseInfo(baseInfo)) {
        ServerLog::Warn("Fail to query summary base info.");
        return false;
    }

    const bool timeInfoMissing = baseInfo.collectStartTime == 0 || baseInfo.collectDuration == 0.0;
    if (!timeInfoMissing) {
        return true;
    }
    if (UpdateStartTimeAndDuration(baseInfo, db)) {
        return true;
    }
    ServerLog::Warn("Fail to update start time and duration for summary base info.");
    return false;
}
/**
 * @brief Fills the compare base info of the response and, in compare mode, the baseline base info too.
 *
 * Failures are only logged: a missing database or a failed query leaves the corresponding part of the
 * response as it was and does not stop the remaining work.
 *
 * @param request  Supplies the cluster path and the isCompare flag.
 * @param response Receives body.baseInfo.compare and, when isCompare is set, body.baseInfo.baseline.
 */
void SummaryService::QueryCompareSummaryBaseInfo(
    const SummaryTopRankRequest &request, SummaryTopRankResponse &response) {
    auto compareDb = Timeline::DataBaseManager::Instance().GetClusterDatabase(request.params.clusterPath);
    const bool compareLoaded =
        compareDb != nullptr && QuerySummaryBaseInfo(response.body.baseInfo.compare, compareDb);
    if (!compareLoaded) {
        ServerLog::Warn("Fail to query compare summary base info");
    }

    if (request.params.isCompare) {
        // The baseline cluster path comes from BaselineManager, not from the request.
        auto baselineDb = Timeline::DataBaseManager::Instance().GetClusterDatabase(
            BaselineManager::Instance().GetBaseLineClusterPath());
        const bool baselineLoaded =
            baselineDb != nullptr && QuerySummaryBaseInfo(response.body.baseInfo.baseline, baselineDb);
        if (!baselineLoaded) {
            ServerLog::Warn("Fail to query baseline summary base info");
        }
    }
}
std::vector<IndicatorDataStruct> SummaryService::GetPerformanceDataByDimension(
std::shared_ptr<VirtualClusterDatabase> &database, const GetPerformanceIndicatorParam ¶ms) {
std::vector<IndicatorDataStruct> indicatorData;
if (database == nullptr) {
ServerLog::Warn("Fail to query compare parallelism info");
return indicatorData;
}
std::unordered_map<std::uint32_t, StepStatistic> stepStatisticData{};
bool result = database->QueryAllPerformanceDataByStep(params.step, stepStatisticData);
if (!result || stepStatisticData.empty()) {
ServerLog::Warn("Failed to query original parallelism performance data.");
return indicatorData;
}
uint32_t worldSize = params.config.dpSize * params.config.cpSize * params.config.tpSize * params.config.ppSize;
if (worldSize == 1 && stepStatisticData.size() > maxRankCountForSummaryWithoutConfig) {
ServerLog::Warn("When no parallel strategy is configured, computation/communication overview is limited to "
"a maximum of " +
std::to_string(maxRankCountForSummaryWithoutConfig) + " cards.");
return indicatorData;
}
auto algPtr = ParallelStrategyAlgorithmManager::Instance().GetAlgorithmByProjectName(database->GetDbPath());
if (algPtr == nullptr) {
ServerLog::Warn("Failed to get algorithm by project name for query parallelism performance.");
return indicatorData;
}
std::string err;
result = algPtr->GetPerformanceIndicatorByDimension(params, stepStatisticData, indicatorData, err);
if (!result) {
ServerLog::Warn(err);
return indicatorData;
}
return indicatorData;
}
/**
 * @brief Computes compare minus baseline for every key present in either map.
 *
 * Each difference is passed through NumberUtil::DoubleReservedNDigits with a precision of 3.
 *
 * @param compare  Values of the compared data set. Keys that only exist in @p baseline are inserted
 *                 here with a value-initialised (0.0) entry as a side effect of the lookup.
 * @param baseline Values of the baseline data set. Keys that only exist in @p compare are inserted
 *                 here in the same way.
 * @return A map holding one difference per key of the union of both inputs.
 */
std::unordered_map<std::string, double> SummaryService::CalDiffIndicators(
    std::unordered_map<std::string, double> &compare, std::unordered_map<std::string, double> &baseline) {
    // Collect the key union first: the lookups below may insert into both maps, so the maps
    // themselves must not be iterated while the differences are computed.
    std::set<std::string> allKeys;
    for (const auto &entry : compare) {
        allKeys.insert(entry.first);
    }
    for (const auto &entry : baseline) {
        allKeys.insert(entry.first);
    }

    int digits = 3;
    std::unordered_map<std::string, double> result;
    for (const auto &key : allKeys) {
        // operator[] deliberately creates a 0.0 entry for keys missing on one side.
        const double compareValue = compare[key];
        const double baselineValue = baseline[key];
        result[key] = NumberUtil::DoubleReservedNDigits(compareValue - baselineValue, digits);
    }
    return result;
}
void SummaryService::MergeParallelismPerformance(std::vector<IndicatorDataStruct> &compare,
std::vector<IndicatorDataStruct> &baseline, PerformanceIndicatorData &indicatorData) {
std::set<uint32_t> indexList;
std::map<uint32_t, IndicatorDataStruct> compareMap;
for (const auto &item : compare) {
indexList.insert(item.index);
compareMap[item.index] = item;
}
std::map<uint32_t, IndicatorDataStruct> baselineMap;
for (const auto &item : baseline) {
indexList.insert(item.index);
baselineMap[item.index] = item;
}
for (const auto &item : indexList) {
IndicatorDataStructVo indicatorDataStructVo;
indicatorDataStructVo.index = item;
if (compareMap.count(item)) {
indicatorDataStructVo.indicators.compare = compareMap[item].indicators;
}
if (baselineMap.count(item)) {
indicatorDataStructVo.indicators.baseline = baselineMap[item].indicators;
}
indicatorDataStructVo.indicators.diff =
CalDiffIndicators(indicatorDataStructVo.indicators.compare, indicatorDataStructVo.indicators.baseline);
indicatorData.performanceData.push_back(indicatorDataStructVo);
}
}
bool SummaryService::QueryParallelismPerformanceInfo(
const ParallelismPerformance ¶ms, PerformanceIndicatorData &indicatorData) {
auto database = Timeline::DataBaseManager::Instance().GetClusterDatabase(params.clusterPath);
GetPerformanceIndicatorParam indicatorParam{params.step, params.dimension, params.config};
std::vector<IndicatorDataStruct> compareIndicatorData = GetPerformanceDataByDimension(database, indicatorParam);
CommInfoMap compareCommInTpDimension;
CommInfoMap compareCommInfo = QueryParallelismCommTime(database, indicatorParam, compareCommInTpDimension);
std::vector<IndicatorDataStruct> baselineIndicatorData;
std::unordered_map<std::string, std::vector<CommInfoUnderRank>> baselineCommInfo;
if (params.isCompare) {
auto databaseBaseline = Timeline::DataBaseManager::Instance().GetClusterDatabase(
BaselineManager::Instance().GetBaseLineClusterPath());
GetPerformanceIndicatorParam baselineParams{params.baselineStep, params.dimension, params.config};
baselineIndicatorData = GetPerformanceDataByDimension(databaseBaseline, baselineParams);
CommInfoMap baseCommInTpDimension;
baselineCommInfo = QueryParallelismCommTime(databaseBaseline, baselineParams, baseCommInTpDimension);
}
if (compareIndicatorData.empty() && baselineIndicatorData.empty()) {
ServerLog::Error("Fail to query parallelism performance info.");
SetSummaryError(ErrorCode::QUERY_PARALLELISM_PERFORMANCE_FAILED);
return false;
}
MergeParallelismPerformance(compareIndicatorData, baselineIndicatorData, indicatorData);
MergeCommDataPerformance(compareCommInfo, baselineCommInfo, indicatorData);
if (!params.isCompare && database != nullptr) {
auto algPtr = ParallelStrategyAlgorithmManager::Instance().GetAlgorithmByProjectName(database->GetDbPath());
if (algPtr != nullptr) {
algPtr->CalAdviceInfo(params.dimension, indicatorData.advices, compareIndicatorData);
if (!algPtr->CalAdviceInfoByCommInfo(compareCommInTpDimension)) {
ServerLog::Warn("Failed to calculate slow rank advice by communication time. Current parallel "
"strategy config do not match the actual model training parameters.");
}
}
}
return true;
}
void SummaryService::MergeCommDataPerformance(std::unordered_map<std::string, std::vector<CommInfoUnderRank>> &compare,
std::unordered_map<std::string, std::vector<CommInfoUnderRank>> &baseline,
PerformanceIndicatorData &indicatorData) {
for (auto &item : indicatorData.performanceData) {
std::string key = std::to_string(item.index);
MergeCommInfo(compare[key], baseline[key], item.commTimeIndicator);
}
}
/**
 * @brief Converts two communication info lists into pgName -> commTime maps and stores them, together
 *        with their difference, in @p commRes.
 *
 * When a pgName occurs more than once in a list, the commTime of its last occurrence is kept.
 *
 * @param compare  Communication infos of the compared data set.
 * @param baseline Communication infos of the baseline data set.
 * @param commRes  Receives the compare map, the baseline map and the diff computed by CalDiffIndicators.
 */
void SummaryService::MergeCommInfo(std::vector<CommInfoUnderRank> compare, std::vector<CommInfoUnderRank> baseline,
    CompareData<std::unordered_map<std::string, double>> &commRes) {
    auto toTimeByGroup = [](const std::vector<CommInfoUnderRank> &infos) -> std::unordered_map<std::string, double> {
        std::unordered_map<std::string, double> timeByGroup;
        for (const auto &info : infos) {
            timeByGroup[info.pgName] = info.commTime;
        }
        return timeByGroup;
    };

    std::unordered_map<std::string, double> compareTimes = toTimeByGroup(compare);
    std::unordered_map<std::string, double> baselineTimes = toTimeByGroup(baseline);

    // Store the maps before computing the diff: CalDiffIndicators adds 0.0 entries for keys
    // missing on one side to the maps it is given, and those must not end up in commRes.
    commRes.compare = compareTimes;
    commRes.baseline = baselineTimes;
    commRes.diff = CalDiffIndicators(compareTimes, baselineTimes);
}
std::unordered_map<std::string, std::vector<CommInfoUnderRank>> SummaryService::QueryParallelismCommTime(
const std::shared_ptr<VirtualClusterDatabase> &database, const GetPerformanceIndicatorParam ¶ms,
CommInfoMap &commInTpDimension) {
if (database == nullptr) {
ServerLog::Warn("Fail to query parallelism communication info, database not exist.");
return {};
}
auto algPtr = ParallelStrategyAlgorithmManager::Instance().GetAlgorithmByProjectName(database->GetDbPath());
if (algPtr == nullptr) {
ServerLog::Warn("Failed to get algorithm by project name for query parallelism communication info.");
return {};
}
std::vector<std::string> importRankList = database->GetAllRankFromStepStatisticInfo();
if (importRankList.empty()) {
ServerLog::Warn("Fail to get all rank from step statistic info.");
return {};
}
std::vector<CommInfoUnderRank> commTimeForRankDim = database->GetCommTimeForRankDim(params.step);
if (commTimeForRankDim.empty()) {
ServerLog::Warn("Fail to get communication time data.");
return {};
}
for (auto &item : commTimeForRankDim) {
if (item.pgName.empty()) {
continue;
}
commInTpDimension[item.rankId].push_back(item);
}
return algPtr->GetCommInfoByDimension(commInTpDimension, params.dimension);
}