/*
* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <algorithm>
#include <cmath>
#include <utility>
#include "CommonDefs.h"
#include "ServerLog.h"
#include "MetaDataCacheManager.h"
#include "VirtualTraceDatabase.h"
#include "TrackInfoManager.h"
#include "RepositoryFactory.h"
#include "FullDbEnumUtil.h"
#include "TraceTime.h"
#include "SliceAnalyzer.h"
namespace Dic::Module::Timeline {
using namespace Dic::Server;
using namespace Dic::Protocol;
/**
 * Groups the Ascend hardware threads by the model id they belong to.
 *
 * Thread -> model id pairs come from QueryAllModelIdOfAscendHardwareThreads(). Threads whose model id is
 * empty or equals std::to_string(UINT_MAX) are skipped with a warning. One ThreadGroup is produced per
 * distinct model id, in ascending (string) model-id order; within a group, threads are pushed in ascending
 * (string) thread-id order because both maps are std::map.
 *
 * @param threadGroupList output; cleared and then filled with one group per valid model id
 * @return false if the query returned no threads at all; true otherwise (the list may still be empty
 *         when every thread had an invalid model id)
 */
bool VirtualTraceDatabase::QueryGroupedAscendHardwareThreadsByModelId(std::vector<ThreadGroup> &threadGroupList) {
    const std::map<std::string, std::string> tId2ModelIdMap = QueryAllModelIdOfAscendHardwareThreads();
    if (tId2ModelIdMap.empty()) {
        threadGroupList.clear();
        return false;
    }
    // Sentinel model id that does not identify a real model.
    const std::string uintMaxModelId = std::to_string(UINT_MAX);
    std::map<std::string, ThreadGroup> modelIdToThreadsMap;
    for (const auto &[threadId, modelId] : tId2ModelIdMap) {
        if (modelId.empty() || modelId == uintMaxModelId) {
            ServerLog::Warn("Invalid ModelId when querying grouped ascend hardware threads.");
            continue;
        }
        // operator[] default-constructs the group on first use, so no separate existence check is needed.
        modelIdToThreadsMap[modelId].push(threadId);
    }
    threadGroupList.clear();
    threadGroupList.reserve(modelIdToThreadsMap.size());
    for (auto &entry : modelIdToThreadsMap) {
        // The local map is discarded on return, so move each group out instead of copying it.
        threadGroupList.push_back(std::move(entry.second));
    }
    return true;
}
/**
 * Sums the overlap between one trace slice and the "uncovered" intervals, resuming at a shared cursor.
 *
 * The scan starts at @p index. It advances past every uncovered interval that ends at or before the slice
 * starts, and past every interval the slice extends beyond. It stops at the first interval that starts at
 * or after the slice ends, or that the slice does not extend beyond. Callers can therefore reuse the cursor
 * across consecutive slices instead of rescanning from the start.
 * NOTE(review): this assumes @p uncovered is sorted by time and that successive slices passed with the same
 * cursor are in ascending time order — confirm with the query that produces them.
 *
 * @param uncovered uncovered intervals to intersect with the slice
 * @param index     in/out cursor into @p uncovered
 * @param element   slice whose overlap with the uncovered intervals is measured
 * @return total overlap, in the same unit as the timestamps; an addition that would overflow uint64_t is
 *         logged and skipped
 */
uint64_t VirtualTraceDatabase::CalculateUncoveredTime(
    const std::vector<Protocol::ThreadTraces> &uncovered, size_t &index, const ThreadTraces &element) {
    uint64_t totalUncoveredTime = 0;
    while (index < uncovered.size()) {
        // Bind by reference: ThreadTraces carries several strings, so a per-step copy is pure overhead.
        const Protocol::ThreadTraces &uncoveredEle = uncovered[index];
        if (element.startTime >= uncoveredEle.endTime) {
            // Interval ends before the slice starts: skip it.
            ++index;
            continue;
        }
        if (element.endTime <= uncoveredEle.startTime) {
            // Interval starts after the slice ends: nothing further can overlap this slice.
            break;
        }
        const uint64_t overlapStart = std::max(uncoveredEle.startTime, element.startTime);
        const uint64_t overlapEnd = std::min(uncoveredEle.endTime, element.endTime);
        const uint64_t uncoveredTime = overlapEnd >= overlapStart ? overlapEnd - overlapStart : 0;
        // `>=` admits a sum of exactly UINT64_MAX, which does not overflow.
        if (UINT64_MAX - totalUncoveredTime >= uncoveredTime) {
            totalUncoveredTime += uncoveredTime;
        } else {
            ServerLog::Error("Accumulation overflow occurred when calculating total uncovered time: ", uncoveredTime);
        }
        if (element.endTime > uncoveredEle.endTime) {
            // The slice extends past this interval, so the next interval may overlap it as well.
            ++index;
        } else {
            // This interval may still overlap a later slice; leave the cursor on it.
            break;
        }
    }
    return totalUncoveredTime;
}
/**
 * Folds the rows of a communication-summary query into per-group statistics.
 *
 * Each row is one communication slice keyed "<groupName>@<plane>". Rows whose key is absent from
 * @p groupInfoMap are skipped; otherwise the key maps to the summary group the row belongs to.
 * Rows with depth (column "type") 0 are accumulated per plane into the group's taskMap, using
 * flag == 1 as UpdateData's first argument. Rows with any other depth are accumulated into the group's
 * op entry, with false as that argument. Every accumulated row also carries its overlap with
 * @p uncovered, computed by CalculateUncoveredTime.
 *
 * All rows share one cursor into @p uncovered. It is reset to 0 when a group or a plane is first seen,
 * and when a depth-1 row arrives while op.completeTransmitTime is still 0.
 * NOTE(review): the shared cursor is only correct if the query returns each plane's slices contiguously
 * and in ascending start time — confirm against the SQL feeding @p resultSet.
 */
void VirtualTraceDatabase::ExecuteQueryCommunicationSummaryData(
    std::map<std::string, Protocol::CommunicationSummaryInfoByGroup> &summaryInfoMap,
    const std::unique_ptr<SqliteResultSet> &resultSet, const std::map<std::string, std::string> &groupInfoMap,
    const std::vector<Protocol::ThreadTraces> &uncovered) {
    size_t index = 0;
    while (resultSet->Next()) {
        // Build the lookup key from locals. Reading ele.pid / ele.threadId inside ele's own initializer
        // accesses members of an object whose lifetime has not begun yet.
        std::string plane = std::to_string(resultSet->GetInt64("plane"));
        std::string groupName = std::to_string(resultSet->GetUint64("groupName"));
        std::string traceKey = groupName + "@" + plane;
        const auto groupInfoIt = groupInfoMap.find(traceKey);
        if (groupInfoIt == groupInfoMap.end()) {
            continue;
        }
        const std::string &group = groupInfoIt->second;
        // Only rows that survive the filter pay for the remaining column reads and the mojibake fix-up.
        Protocol::ThreadTraces ele = {.name = resultSet->GetString("name"),
            .duration = resultSet->GetUint64("duration"),
            .startTime = resultSet->GetUint64("startTime"),
            .endTime = resultSet->GetUint64("endTime"),
            .depth = resultSet->GetUint32("type"),
            .threadId = std::move(plane),
            .pid = std::move(groupName),
            .id = std::move(traceKey),
            .cname = StringUtil::FixGbkMojibakeStr(resultSet->GetString("threadName"))};
        const uint64_t flag = resultSet->GetUint64("flag");

        auto groupIt = summaryInfoMap.find(group);
        if (groupIt == summaryInfoMap.end()) {
            CommunicationSummaryInfoByGroup newGroup = {group, {group, "", "", 0, 0, 0, 0}, {}};
            groupIt = summaryInfoMap.emplace(group, std::move(newGroup)).first;
            index = 0;
        }
        CommunicationSummaryInfoByGroup &groupInfo = groupIt->second;

        if (ele.depth == 0) {
            auto taskIt = groupInfo.taskMap.find(ele.threadId);
            if (taskIt == groupInfo.taskMap.end()) {
                CommunicationSummaryInfoByThread newPlane = {ele.cname, ele.pid, ele.threadId, 0, 0, 0, 0};
                taskIt = groupInfo.taskMap.emplace(ele.threadId, newPlane).first;
                index = 0;
            }
            const uint64_t uncoveredTime = CalculateUncoveredTime(uncovered, index, ele);
            taskIt->second.UpdateData(flag == 1, ele.duration, uncoveredTime);
            continue;
        }

        // First op-level row of this group (nothing accumulated yet): restart the uncovered scan.
        if (ele.depth == 1 && groupInfo.op.completeTransmitTime == 0) {
            index = 0;
        }
        const uint64_t uncoveredTime = CalculateUncoveredTime(uncovered, index, ele);
        groupInfo.op.group = ele.pid;
        groupInfo.op.plane = ele.threadId;
        groupInfo.op.UpdateData(false, ele.duration, uncoveredTime);
    }
}
* 兼容Text场景和DB场景的Group Name,其中Text场景为"Group {groupNameValue} Communication",DB场景为{groupNameValue}
* 解开 "Group {groupNameValue} Communication" 的形式,获取 {groupNameValue}
* @param str "Group {groupNameValue} Communication"
* @return {groupNameValue}
*/
std::string VirtualTraceDatabase::ExtractGroupNameValue(const std::string &str) {
static const std::regex expr(R"(Group ([\S]+(\s\w*)?) Communication)");
std::smatch match;
if (std::regex_match(str, match, expr) && match.size() > 1) {
return match.str(1);
}
return "";
}
/**
 * Builds the level-2 summary node for one communication group.
 *
 * The node is named data.groupName. The lookup key is the value returned by ExtractGroupNameValue,
 * reduced to its first space-separated token when it has several. When MetaDataCacheManager knows a
 * non-empty parallel-group name for that key, the node name becomes "<parallel group name>:<groupName>".
 * totalTime is op.uncoveredTransmitTime converted with NS_TO_US and rounded to TWO digits. ratio is that
 * total divided by overallHelper.e2eTime, scaled by PERCENTAGE_RATIO_SCALE and rounded the same way.
 * Consumes one id from overallHelper.idCounter.
 * NOTE(review): divides by e2eTime unchecked; the visible caller only calls this when e2eTime > 0.
 */
SystemViewOverallRes VirtualTraceDatabase::CollectCommunicationGroupMetrics(
    const CommunicationSummaryInfoByGroup &data, SystemViewOverallHelper &overallHelper) {
    Protocol::SystemViewOverallRes groupNode = {.totalTime = 0,
        .ratio = 0,
        .nums = 0,
        .avg = 0,
        .max = 0,
        .min = 0,
        .name = data.groupName,
        .children = {},
        .level = 2,
        .id = std::to_string(overallHelper.idCounter++)};

    const std::string rawValue = ExtractGroupNameValue(groupNode.name);
    const std::vector<std::string> tokens = StringUtil::Split(rawValue, " ");
    const std::string &lookupKey = tokens.size() > 1 ? tokens[0] : rawValue;
    const auto parallelGroup = MetaDataCacheManager::Instance().GetParallelGroupInfo(lookupKey);
    const bool hasParallelName = parallelGroup.has_value() && !parallelGroup.value().groupName.empty();
    if (hasParallelName) {
        groupNode.name = parallelGroup.value().groupName + ":" + data.groupName;
    }

    groupNode.totalTime = NumberUtil::DoubleReservedNDigits(data.op.uncoveredTransmitTime * NS_TO_US, TWO);
    const auto percentage = groupNode.totalTime / overallHelper.e2eTime * PERCENTAGE_RATIO_SCALE;
    groupNode.ratio = NumberUtil::DoubleReservedNDigits(percentage, TWO);
    return groupNode;
}
void VirtualTraceDatabase::ComputeCommunicationWaitAndTransmitTimeByGroup(
const std::map<std::string, CommunicationSummaryInfoByGroup> &summaryData, SystemViewOverallHelper &overallHelper,
Protocol::SystemViewOverallRes &result) {
if (summaryData.empty() || overallHelper.e2eTime <= 0) {
return;
}
for (auto &item : summaryData) {
CommunicationSummaryInfoByGroup data = item.second;
SystemViewOverallRes group = CollectCommunicationGroupMetrics(data, overallHelper);
Protocol::SystemViewOverallRes wait = {.totalTime = 0,
.ratio = 0,
.nums = 0,
.avg = 0,
.max = 0,
.min = 0,
.name = WAIT_TIME,
.children = {},
.level = 3,
.id = std::to_string(overallHelper.idCounter++)};
uint64_t minWait = UINT64_MAX;
for (auto &tmpItem : data.taskMap) {
minWait = std::min(minWait, tmpItem.second.uncoveredWaitTime);
}
wait.totalTime = NumberUtil::DoubleReservedNDigits(minWait * NS_TO_US, TWO);
wait.ratio =
NumberUtil::DoubleReservedNDigits(wait.totalTime / overallHelper.e2eTime * PERCENTAGE_RATIO_SCALE, TWO);
Protocol::SystemViewOverallRes transmit = {.totalTime = 0,
.ratio = 0,
.nums = 0,
.avg = 0,
.max = 0,
.min = 0,
.name = TRANSMIT_TIME,
.children = {},
.level = 3,
.id = std::to_string(overallHelper.idCounter++)};
if (data.op.uncoveredTransmitTime > minWait) {
transmit.totalTime =
NumberUtil::DoubleReservedNDigits((data.op.uncoveredTransmitTime - minWait) * NS_TO_US, TWO);
}
transmit.ratio =
NumberUtil::DoubleReservedNDigits(transmit.totalTime / overallHelper.e2eTime * PERCENTAGE_RATIO_SCALE, TWO);
group.children.emplace_back(wait);
group.children.emplace_back(transmit);
result.children.emplace_back(group);
}
}
std::vector<UnitCounterData> VirtualTraceDatabase::DownSampleUnitCounterData(
const std::vector<UnitCounterData> &dataList, size_t targetSize) {
if (targetSize == 0) {
return dataList;
}
std::vector<UnitCounterData> sampledData;
if (dataList.empty()) {
return sampledData;
}
size_t totalSize = dataList.size();
if (totalSize <= targetSize) {
return dataList;
}
double step = static_cast<double>(totalSize) / targetSize;
sampledData.reserve(targetSize);
for (size_t i = 0; i < targetSize; ++i) {
size_t index = static_cast<size_t>(i * step);
if (index >= totalSize) {
index = totalSize - 1;
}
sampledData.push_back(dataList[index]);
}
return sampledData;
}
/**
 * Builds a slice query covering [startTime, startTime + duration] of the slice described by @p sliceInfo.
 * The query is passed through SliceCacheManager::GetSlicePagedQueryForDb, and whatever that call returns
 * is the result.
 *
 * The track id is resolved from (rankId, pid, tid) through TrackInfoManager. metaType is parsed from its
 * string form and left at SliceQuery's default when parsing fails. minTimestamp is TraceTime's start time.
 * NOTE(review): startTime + duration is not overflow-checked.
 */
SliceQuery VirtualTraceDatabase::CreateSliceQueryWithTimeRange(const SliceBaseInfo &sliceInfo) {
    const auto trackId = TrackInfoManager::Instance().GetTrackId(sliceInfo.rankId, sliceInfo.pid, sliceInfo.tid);
    SliceQuery query;
    query.trackId = trackId;
    query.pid = sliceInfo.pid;
    query.tid = sliceInfo.tid;
    query.rankId = sliceInfo.rankId;
    if (const auto parsedType = STR_TO_ENUM<PROCESS_TYPE>(sliceInfo.metaType); parsedType.has_value()) {
        query.metaType = parsedType.value();
    }
    query.minTimestamp = TraceTime::Instance().GetStartTime();
    query.startTime = sliceInfo.startTime;
    query.endTime = sliceInfo.startTime + sliceInfo.duration;
    return SliceCacheManager::GetSlicePagedQueryForDb(query);
}
uint64_t VirtualTraceDatabase::GetSliceDepthForJump(const SliceQuery ¶ms, uint64_t sliceId) {
SliceAnalyzer sliceAnalyzer;
auto repositoryFactory = RepositoryFactory::Instance();
auto repo = repositoryFactory->GetSliceRespo(params.metaType);
if (repo == nullptr) {
return 0;
}
sliceAnalyzer.SetRepository(repo);
std::unordered_map<uint64_t, uint32_t> depthCache;
sliceAnalyzer.ComputeDepthInfoByTrackId(params, depthCache);
return depthCache[sliceId];
}
}