* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "TrackInfoManager.h"
#include "HostFlowRepo.h"
namespace Dic::Module::Timeline {
void HostFlowRepo::QueryFwdbwd(const FlowQuery &flowQuery, std::vector<FlowPoint> &flowPointVec) {
uint64_t nameId = QueryEnqueueNameId(flowQuery);
std::vector<PytorchApiPO> pythonApiPOS = QueryNotEnqueuePythonApi(flowQuery, nameId);
std::unordered_map<uint64_t, uint64_t> connectionIdMap = QueryConnectionIdMap(flowQuery);
auto &instance = TrackInfoManager::Instance();
std::unordered_map<std::string, PytorchApiPO> pytorchApiPOLog;
for (const auto &item : pythonApiPOS) {
if (connectionIdMap.count(item.connectionId) == 0) {
continue;
}
std::string flowId = std::to_string(connectionIdMap[item.connectionId]);
if (pytorchApiPOLog.count(flowId) == 0) {
pytorchApiPOLog[flowId] = item;
continue;
}
FlowPoint startPoint;
startPoint.type = "s";
startPoint.flowId = flowId;
PytorchApiPO start = pytorchApiPOLog[flowId];
startPoint.id = start.id;
startPoint.timestamp = start.timestamp - flowQuery.minTimestamp;
startPoint.trackId = instance.GetTrackId(flowQuery.fileId, std::to_string(start.globalTid), pythonApiTid);
startPoint.rankId = flowQuery.fileId;
FlowPoint endPoint;
endPoint.type = "f";
endPoint.flowId = flowId;
endPoint.id = item.id;
endPoint.timestamp = item.timestamp - flowQuery.minTimestamp;
endPoint.trackId = instance.GetTrackId(flowQuery.fileId, std::to_string(item.globalTid), pythonApiTid);
endPoint.rankId = flowQuery.fileId;
flowPointVec.emplace_back(startPoint);
flowPointVec.emplace_back(endPoint);
pytorchApiPOLog.erase(flowId);
}
std::sort(flowPointVec.begin(), flowPointVec.end());
}
uint64_t HostFlowRepo::QueryEnqueueNameId(const FlowQuery &flowQuery) {
uint64_t nameId = 0;
std::vector<StringIdsPO> stringIdsPOS;
stringIdsTable->Select(StringIdsColumn::ID)
.Like(StringIdsColumn::VALUE, "Enqueue%")
.ExcuteQuery(flowQuery.fileId, stringIdsPOS);
if (!std::empty(stringIdsPOS)) {
nameId = stringIdsPOS[0].id;
}
return nameId;
}
std::vector<PytorchApiPO> HostFlowRepo::QueryNotEnqueuePythonApi(const FlowQuery &flowQuery, uint64_t nameId) {
std::vector<PytorchApiPO> pythonApiPOS;
uint64_t minConnectionId = 0;
pytorchApiTable->Select(PytorchApiColumn::ID, PytorchApiColumn::TIMESTAMP)
.Select(PytorchApiColumn::CONNECTIONID, PytorchApiColumn::GLOBAL_TID)
.NotEq(PytorchApiColumn::NAME, nameId)
.GreaterEq(PytorchApiColumn::CONNECTIONID, minConnectionId)
.OrderBy(PytorchApiColumn::TIMESTAMP, TableOrder::ASC)
.ExcuteQuery(flowQuery.fileId, pythonApiPOS);
return pythonApiPOS;
}
void HostFlowRepo::QueryAsyncTaskQueue(const FlowQuery &flowQuery, std::vector<FlowPoint> &flowPointVec) {
std::vector<StringIdsPO> stringIdsPOS;
stringIdsTable->Select(StringIdsColumn::ID)
.Like(StringIdsColumn::VALUE, "Enqueue%")
.ExcuteQuery(flowQuery.fileId, stringIdsPOS);
if (std::empty(stringIdsPOS)) {
return;
}
uint64_t nameId = stringIdsPOS[0].id;
std::vector<uint64_t> pythonConnectionIds = QueryEnqueueFlowConnectionIds(flowQuery, nameId);
std::vector<uint64_t> realConnectionIds = QueryRealConnectionIds(flowQuery, pythonConnectionIds);
std::unordered_map<uint64_t, uint64_t> connectionIdMap;
std::vector<uint64_t> allPythonConnectionIds;
QueryAllPythonConnectionIds(flowQuery, realConnectionIds, connectionIdMap, allPythonConnectionIds);
std::vector<PytorchApiPO> pythonApiPOS;
pytorchApiTable->Select(PytorchApiColumn::ID, PytorchApiColumn::TIMESTAMP)
.Select(PytorchApiColumn::CONNECTIONID, PytorchApiColumn::GLOBAL_TID)
.In(PytorchApiColumn::CONNECTIONID, allPythonConnectionIds)
.OrderBy(PytorchApiColumn::TIMESTAMP, TableOrder::ASC)
.ExcuteQuery(flowQuery.fileId, pythonApiPOS);
std::unordered_set<std::string> connectionIdLog;
auto &instance = TrackInfoManager::Instance();
std::string dbPath = flowQuery.fileId;
for (auto &item : pythonApiPOS) {
FlowPoint flowPoint;
flowPoint.flowId = std::to_string(connectionIdMap[item.connectionId]);
if (connectionIdLog.count(flowPoint.flowId) == 0) {
flowPoint.type = "s";
connectionIdLog.emplace(flowPoint.flowId);
} else {
flowPoint.type = "f";
}
flowPoint.id = item.id;
flowPoint.trackId = instance.GetTrackId(flowQuery.fileId, std::to_string(item.globalTid), pythonApiTid);
flowPoint.timestamp = item.timestamp - flowQuery.minTimestamp;
flowPoint.rankId = dbPath;
flowPointVec.emplace_back(flowPoint);
}
std::sort(flowPointVec.begin(), flowPointVec.end());
}
std::unordered_map<uint64_t, uint64_t> HostFlowRepo::QueryConnectionIdMap(const FlowQuery &flowQuery) {
std::unordered_map<uint64_t, uint64_t> connectionIdMap;
std::vector<ConnectionIdsPO> connectionIdsPOVec;
connectionIdsTable->Select(ConnectionIdsColumn::ID, ConnectionIdsColumn::CONNECTIONID)
.OrderBy(ConnectionIdsColumn::ID, TableOrder::ASC)
.ExcuteQuery(flowQuery.fileId, connectionIdsPOVec);
for (const auto &item : connectionIdsPOVec) {
connectionIdMap[item.id] = item.connectionId;
}
return connectionIdMap;
}
std::vector<uint64_t> HostFlowRepo::QueryEnqueueFlowConnectionIds(const FlowQuery &flowQuery, uint64_t nameId) {
std::vector<PytorchApiPO> pythonApiConnectionIdPOS;
pytorchApiTable->Select(PytorchApiColumn::CONNECTIONID)
.Eq(PytorchApiColumn::NAME, nameId)
.ExcuteQuery(flowQuery.fileId, pythonApiConnectionIdPOS);
std::vector<uint64_t> pythonConnectionIds;
pythonConnectionIds.reserve(pythonApiConnectionIdPOS.size());
for (const auto &item : pythonApiConnectionIdPOS) {
pythonConnectionIds.emplace_back(item.connectionId);
}
return pythonConnectionIds;
}
std::vector<uint64_t> HostFlowRepo::QueryRealConnectionIds(
const FlowQuery &flowQuery, const std::vector<uint64_t> &pythonConnectionIds) {
std::vector<ConnectionIdsPO> connectionIdsPOVec;
connectionIdsTable->Select(ConnectionIdsColumn::CONNECTIONID)
.In(ConnectionIdsColumn::ID, pythonConnectionIds)
.ExcuteQuery(flowQuery.fileId, connectionIdsPOVec);
std::vector<uint64_t> realConnectionIds;
realConnectionIds.reserve(connectionIdsPOVec.size());
for (const auto &item : connectionIdsPOVec) {
realConnectionIds.emplace_back(item.connectionId);
}
return realConnectionIds;
}
void HostFlowRepo::QueryAllPythonConnectionIds(const FlowQuery &flowQuery,
const std::vector<uint64_t> &realConnectionIds, std::unordered_map<uint64_t, uint64_t> &connectionIdMap,
std::vector<uint64_t> &allPythonConnectionIds) {
std::vector<ConnectionIdsPO> allConnectionIdsPOVec;
connectionIdsTable->Select(ConnectionIdsColumn::ID, ConnectionIdsColumn::CONNECTIONID)
.In(ConnectionIdsColumn::CONNECTIONID, realConnectionIds)
.ExcuteQuery(flowQuery.fileId, allConnectionIdsPOVec);
for (const auto &item : allConnectionIdsPOVec) {
allPythonConnectionIds.emplace_back(item.id);
connectionIdMap[item.id] = item.connectionId;
}
}
void HostFlowRepo::AddAsyncNpuFlowPoint(const FlowQuery &flowQuery, std::vector<FlowPoint> &flowPointVec) {
uint64_t nameId = QueryEnqueueNameId(flowQuery);
std::vector<PytorchApiPO> pythonApiPOS = QueryNotEnqueuePythonApi(flowQuery, nameId);
std::unordered_map<uint64_t, uint64_t> connectionIdMap = QueryConnectionIdMap(flowQuery);
std::unordered_set<std::string> pytorchApiFlowIdSet;
std::unordered_set<std::string> fwdbwdFlowIdSet;
for (const auto &item : pythonApiPOS) {
if (connectionIdMap.count(item.connectionId) == 0) {
continue;
}
std::string flowId = std::to_string(connectionIdMap[item.connectionId]);
if (pytorchApiFlowIdSet.count(flowId) > 0) {
fwdbwdFlowIdSet.emplace(flowId);
continue;
}
pytorchApiFlowIdSet.emplace(flowId);
}
auto &instance = TrackInfoManager::Instance();
for (const auto &item : pythonApiPOS) {
uint64_t realConnectionId = connectionIdMap[item.connectionId];
std::string flowId = std::to_string(realConnectionId);
if (fwdbwdFlowIdSet.count(flowId) > 0) {
continue;
}
FlowPoint startPoint;
startPoint.type = "s";
startPoint.flowId = flowId;
startPoint.id = item.id;
startPoint.timestamp = item.timestamp - flowQuery.minTimestamp;
startPoint.rankId = flowQuery.fileId;
startPoint.trackId = instance.GetTrackId(flowQuery.fileId, std::to_string(item.globalTid), pythonApiTid);
flowPointVec.emplace_back(startPoint);
}
}
std::vector<uint64_t> HostFlowRepo::AddMstxFlowPoint(const FlowQuery &flowQuery, std::vector<FlowPoint> &flowPointVec) {
std::vector<MstxEventsPO> mstxPOs;
mstxEventsTable->Select(MstxEventsColumn::ID, MstxEventsColumn::CONNECTION_ID)
.Select(MstxEventsColumn::GLOBAL_TID, MstxEventsColumn::TIMESTAMP, MstxEventsColumn::DOMAIN_ID)
.NotEq(MstxEventsColumn::CONNECTION_ID, WRONG_DATA)
.ExcuteQuery(flowQuery.fileId, mstxPOs);
std::vector<uint64_t> connectionIds;
auto &instance = TrackInfoManager::Instance();
for (const auto &item : mstxPOs) {
FlowPoint startPoint;
startPoint.type = "s";
startPoint.id = item.id;
startPoint.flowId = std::to_string(item.connectionId);
startPoint.timestamp = item.timestamp - flowQuery.minTimestamp;
startPoint.rankId = flowQuery.fileId;
startPoint.trackId =
instance.GetTrackId(flowQuery.fileId, std::to_string(item.globalTid), std::to_string(item.domainId));
flowPointVec.emplace_back(startPoint);
connectionIds.emplace_back(item.connectionId);
}
return connectionIds;
}
void HostFlowRepo::AddCANNFlowPoint(const FlowQuery &flowQuery, std::vector<FlowPoint> &flowPointVec) {
std::vector<CANNApiPO> cannApiPOs;
cannApiTable->Select(CANNApiColumn::ID, CANNApiColumn::TIMESTAMP)
.Select(CANNApiColumn::GLOBAL_TID, CANNApiColumn::TYPE)
.ExcuteQuery(flowQuery.fileId, cannApiPOs);
auto &instance = TrackInfoManager::Instance();
for (const auto &item : cannApiPOs) {
FlowPoint startPoint;
startPoint.type = "s";
startPoint.id = item.id;
startPoint.flowId = std::to_string(item.id);
if (item.timestamp < flowQuery.minTimestamp) {
ServerLog::Warn("CANN flow timestamp invalid! id: %", item.id);
}
startPoint.timestamp = item.timestamp - flowQuery.minTimestamp;
startPoint.rankId = flowQuery.fileId;
startPoint.trackId =
instance.GetTrackId(flowQuery.fileId, std::to_string(item.globalTid), std::to_string(item.type));
flowPointVec.emplace_back(startPoint);
}
}
}