* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <algorithm>
#include "WsSessionManager.h"
#include "DataBaseManager.h"
#include "TraceTime.h"
#include "TraceDatabaseHelper.h"
#include "QueryFwdBwdTimelineHandler.h"
namespace Dic::Module::Summary {
using namespace Dic::Server;
using namespace Dic::Module::Timeline;
// Scratch storage for one request, keyed by rank id. HandleRequest clears it, adds one entry per requested
// rank, lets the per-rank query tasks fill their entry, copies the entries into the response and clears it
// again after the response is sent.
// NOTE(review): this is a single static shared by every call of the handler; concurrent requests would
// interfere with each other — confirm requests for this handler are serialized.
std::map<std::string, PipelineFwdBwdTimelineByRank> QueryFwdBwdTimelineHandler::dataMap;
/**
 * Handles a pipeline fwd/bwd timeline request.
 *
 * Validates the parameters, splits the stage id into rank ids, queries every rank on a shared thread pool,
 * links matching p2p operators into flows and computes the overall time bounds of the returned traces.
 * A response is sent on every path; the return value reports whether the request succeeded.
 *
 * @param requestPtr request object; expected to be a PipelineFwdBwdTimelineRequest (dynamic_cast to a
 *                   reference, so any other type throws std::bad_cast)
 * @return true when a successful response was sent, false when parameters or rank ids were invalid
 */
bool QueryFwdBwdTimelineHandler::HandleRequest(std::unique_ptr<Protocol::Request> requestPtr) {
    auto &request = dynamic_cast<PipelineFwdBwdTimelineRequest &>(*requestPtr.get());
    std::unique_ptr<PipelineFwdBwdTimelineResponse> responsePtr = std::make_unique<PipelineFwdBwdTimelineResponse>();
    PipelineFwdBwdTimelineResponse &response = *responsePtr.get();
    SetBaseResponse(request, response);
    std::string err;
    if (!request.params.CheckParams(err)) {
        // Keep the validation message instead of silently discarding it.
        ServerLog::Warn("Invalid params of fwd/bwd timeline request: ", err);
        SetSummaryError(ErrorCode::PARAMS_ERROR);
        SendResponse(std::move(responsePtr), false);
        return false;
    }
    std::vector<std::string> rankIds = StringUtil::SplitStringWithParenthesesByComma(request.params.stageId);
    if (rankIds.empty()) {
        SetSummaryError(ErrorCode::GET_RANK_ID_FAILED);
        SendResponse(std::move(responsePtr), false);
        return false;
    }
    // NOTE(review): dataMap and the thread pool are statics shared by every call of this handler; this is only
    // safe if requests for this handler are processed one at a time — confirm.
    dataMap.clear();
    // Register every rank before the first task is queued. The tasks look up their entry in dataMap, so
    // inserting into the std::map while they already run would be a data race. A rank id listed twice is
    // queried only once, so two tasks never append to the same entry concurrently.
    std::vector<std::string> ranksToQuery;
    for (auto const &rankId : rankIds) {
        response.body.rankLists.push_back(rankId);
        PipelineFwdBwdTimelineByRank rank = {rankId, {}, {}};
        if (dataMap.emplace(rankId, rank).second) {
            ranksToQuery.push_back(rankId);
        }
    }
    static ThreadPool threadPool = ThreadPool(4);
    for (auto const &rankId : ranksToQuery) {
        threadPool.AddTask(QueryFwdBwdTimelineByRank, TraceIdManager::GetTraceId(), rankId, request.params.stepId,
            request.params.clusterPath);
    }
    threadPool.WaitForAllTasks();
    CalFlowInfo(response.body.flowList, rankIds);
    for (auto const &rankId : rankIds) {
        response.body.rankDataList.push_back(dataMap[rankId]);
    }
    // The bounds cover every returned trace, so they do not depend on the order of the trace lists.
    // Traces are read by reference instead of being copied.
    for (const auto &rank : response.body.rankDataList) {
        for (const auto &component : rank.componentDataList) {
            for (const auto &trace : component.traceList) {
                response.body.maxTime = std::max(response.body.maxTime, trace.startTime + trace.duration);
                response.body.minTime = std::min(response.body.minTime, trace.startTime);
            }
        }
    }
    SendResponse(std::move(responsePtr), true);
    dataMap.clear();
    return true;
}
/**
 * Thread-pool task that fills the dataMap entry of one rank.
 *
 * Resolves the rank's trace database (first by rank id, then inside the cluster directory) and tries the
 * mstx-based query first, falling back to the flow-based query when mstx yields nothing.
 *
 * @param rankId      rank whose entry in dataMap is populated; must already be registered
 * @param stepId      step used to bound the query
 * @param clusterPath cluster directory used when the rank has no directly registered database
 * @return false when the rank is unknown, no database is available, or both query paths fail
 */
bool QueryFwdBwdTimelineHandler::QueryFwdBwdTimelineByRank(
    const std::string &rankId, const std::string &stepId, const std::string &clusterPath) {
    // Only ranks registered by HandleRequest are queried.
    if (dataMap.count(rankId) == 0) {
        return false;
    }
    auto database = DataBaseManager::Instance().GetTraceDatabaseByRankId(rankId);
    if (database == nullptr) {
        database = DataBaseManager::Instance().GetTraceDatabaseInCluster(clusterPath, rankId);
    }
    if (database == nullptr) {
        ServerLog::Error("Failed to query fwd/bwd timeline data by rank due to null connection for database.");
        return false;
    }
    // Short-circuit: the flow query only runs when the mstx query did not succeed.
    return QueryFwdBwdTimelineFromMstx(rankId, stepId, database) ||
        QueryFwdBwdTimelineFromFlow(rankId, stepId, database);
}
bool QueryFwdBwdTimelineHandler::QueryFwdBwdTimelineFromMstx(const std::string &rankId, const std::string &stepId,
const std::shared_ptr<Dic::Module::Timeline::VirtualTraceDatabase> &database) {
if (!database->CheckTableExist(TABLE_STEP_TASK_INFO)) {
ServerLog::Warn("The table of step task is not exist, skip to query fwd/bwd info from mstx.");
return false;
}
auto rank = &dataMap.at(rankId);
rank->rankId = rankId;
PipelineFwdBwdTimelineByComponent fwdBwdData = {LANE_FP_BP, {}};
if (!database->QueryFwdBwdFromMstx(fwdBwdData.traceList) || fwdBwdData.traceList.empty()) {
ServerLog::Error("Fail to query fwd/bwd info from mstx.");
return false;
}
rank->componentDataList.push_back(fwdBwdData);
PipelineFwdBwdTimelineByComponent p2pOpData = {LANE_P2P_OP, {}};
if (!database->QueryP2PCommunicationOpHaveConnectionId(p2pOpData.traceList)) {
ServerLog::Warn("Query Fwd/Bwd timeline from mstx without p2p communication op info.");
return true;
}
rank->componentDataList.push_back(p2pOpData);
return true;
}
/**
 * Builds the p2p flow list for the given ranks.
 *
 * Collects the LANE_P2P_OP traces of every rank that carry an operator connection id and groups them by
 * that id. Only groups with exactly two endpoints become a flow. Within a flow, an operator whose name
 * contains "send" comes first and one whose name contains "receive" comes last (case-insensitive).
 *
 * @param flowList receives one FlowInfo per complete endpoint pair
 * @param rankIds  ranks whose entries in dataMap are scanned; unknown ranks are skipped
 */
void QueryFwdBwdTimelineHandler::CalFlowInfo(std::vector<FlowInfo> &flowList, const std::vector<std::string> &rankIds) {
    std::map<std::string, std::vector<FlowPointInfo>> pointMap;
    for (const auto &rank : rankIds) {
        // Look the rank up by reference: no copy of its traces and no default-insertion of unknown ranks.
        auto pipelineIt = dataMap.find(rank);
        if (pipelineIt == dataMap.end()) {
            continue;
        }
        const auto &components = pipelineIt->second.componentDataList;
        auto p2pIt = std::find_if(components.begin(), components.end(),
            [](const auto &lane) { return lane.component == LANE_P2P_OP; });
        if (p2pIt == components.end()) {
            continue;
        }
        for (const auto &item : p2pIt->traceList) {
            if (item.opConnectionId.empty()) {
                continue;
            }
            FlowPointInfo point{rank, item.startTime, item.name};
            pointMap[item.opConnectionId].push_back(point);
        }
    }
    // Sort key for a flow endpoint: senders first, receivers last, anything else in between. Comparing
    // keys gives std::sort the strict weak ordering it requires; a comparator that returns true for two
    // senders (or two receivers) is undefined behavior.
    auto endpointRank = [](const FlowPointInfo &point) {
        const auto lowerName = StringUtil::ToLower(point.opName);
        if (StringUtil::Contains(lowerName, "send")) {
            return 0;
        }
        if (StringUtil::Contains(lowerName, "receive")) {
            return 2;
        }
        return 1;
    };
    const size_t flowPointNumber = 2;
    for (auto &item : pointMap) {
        if (item.second.size() != flowPointNumber) {
            continue;
        }
        FlowInfo flowInfo;
        // pointMap is local, so its endpoint list can be moved instead of copied.
        flowInfo.flowPointList = std::move(item.second);
        std::stable_sort(flowInfo.flowPointList.begin(), flowInfo.flowPointList.end(),
            [&endpointRank](const FlowPointInfo &pointA, const FlowPointInfo &pointB) {
                return endpointRank(pointA) < endpointRank(pointB);
            });
        flowList.push_back(std::move(flowInfo));
    }
}
bool QueryFwdBwdTimelineHandler::QueryFwdBwdTimelineFromFlow(const std::string &rankId, const std::string &stepId,
const std::shared_ptr<Dic::Module::Timeline::VirtualTraceDatabase> &database) {
uint64_t offset = Timeline::TraceTime::Instance().GetStartTime();
auto rank = &dataMap.at(rankId);
Protocol::ExtremumTimestamp range = {offset, (uint64_t)INT64_MAX};
database->QueryStepDuration(stepId, range.minTimestamp, range.maxTimestamp);
rank->rankId = rankId;
PipelineFwdBwdTimelineByComponent fwdBwdData = {LANE_FP_BP, {}};
if (!database->QueryFwdBwdDataByFlow(rankId, offset, range, fwdBwdData.traceList)) {
ServerLog::Warn("Failed to query fwd/bwd detail trace data for rank ", rankId);
}
rank->componentDataList.push_back(fwdBwdData);
PipelineFwdBwdTimelineByComponent p2pOpData = {LANE_P2P_OP, {}};
if (!database->QueryP2PCommunicationOpData(rankId, offset, range, p2pOpData.traceList)) {
ServerLog::Warn("Failed to query p2p operator detail for rank ", rankId);
}
rank->componentDataList.push_back(p2pOpData);
return true;
}
}