* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "pch.h"
#include "WsSender.h"
#include "TraceTime.h"
#include "TableDefs.h"
#include "DataBaseManager.h"
#include "ProtocolDefs.h"
#include "TrackInfoManager.h"
#include "DbMemoryDataBase.h"
namespace Dic {
namespace Module {
namespace FullDb {
using namespace Dic::Server;
using namespace Dic::Module::Timeline;
// Definition of the static registry shared by every DbMemoryDataBase instance. It maps a rank ID to
// that rank's memory parse status and starts out empty; ParserEnd() adds or updates entries and
// Reset() / ParseCallBack() with an empty rank ID clear it.
std::map<std::string, Protocol::MemorySuccess> FullDb::DbMemoryDataBase::ranks = {};
/**
 * Opens the database through the base class and, if that succeeds, runs QueryMetaVersion().
 *
 * @param dbPath        path forwarded to Database::OpenDb
 * @param clearAllTable flag forwarded to Database::OpenDb
 * @return true only when both the base-class open and QueryMetaVersion() succeed
 */
bool DbMemoryDataBase::OpenDb(const std::string &dbPath, bool clearAllTable) {
    bool opened = false;
    if (Database::OpenDb(dbPath, clearAllTable)) {
        opened = QueryMetaVersion();
    }
    // The device-id column name is assigned whether or not the open succeeded.
    deviceIdColumnName = "deviceId";
    return opened;
}
/**
 * Delegates to ExecuteMemoryType(); note that the argument order is swapped for that call.
 *
 * @param type    output argument handed to ExecuteMemoryType
 * @param graphId output argument handed to ExecuteMemoryType
 * @return whatever ExecuteMemoryType reports
 */
bool DbMemoryDataBase::QueryMemoryType(std::string &type, std::vector<std::string> &graphId) {
    const bool succeeded = ExecuteMemoryType(graphId, type);
    return succeeded;
}
/**
 * Reports the memory resource type, which is the fixed string "Pytorch" for this database.
 *
 * @param type receives "Pytorch"
 * @return always true
 */
bool DbMemoryDataBase::QueryMemoryResourceType(std::string &type) {
    type.assign("Pytorch");
    return true;
}
std::string DbMemoryDataBase::BuildOperatorDetailSql(const uint64_t baseTimestamp) {
std::string selectColumns = GetSelectOperatorMemoryFullColumnsWithCount(baseTimestamp);
std::string nameJoinStringIdsAlias = GetJoinStringIDSAlias(OpMemoryColumn::NAME);
std::string sql = StringUtil::FormatString("SELECT {} FROM {} JOIN STRING_IDS AS {} ON {}.id = {} WHERE {} = ? ",
selectColumns, TABLE_OPERATOR_MEMORY, nameJoinStringIdsAlias, nameJoinStringIdsAlias, OpMemoryColumn::NAME,
OpMemoryColumn::DEVICE_ID);
return sql;
}
/**
 * Queries operator memory details for the rank named in the request.
 *
 * @param requestParams request; rankId selects the time offset, and the whole object is passed on
 *                      to AddOperatorSql() and ExecuteOperatorDetail()
 * @param opDetails     output list handed to ExecuteOperatorDetail()
 * @return 0 when the operator memory table is missing, -1 when the reference time would overflow
 *         or the file is not PYTORCH data, otherwise the result of ExecuteOperatorDetail()
 */
int64_t DbMemoryDataBase::QueryOperatorDetail(
    Protocol::MemoryOperatorParams &requestParams, std::vector<Protocol::MemoryOperator> &opDetails) {
    if (!GetMemoryDbContext().withOperatorMemory) {
        ServerLog::Warn("Missing table % on querying operator detail, nothing will be done.", TABLE_OPERATOR_MEMORY);
        return 0;
    }

    const FileType fileType = DataBaseManager::Instance().GetFileType(path);
    const uint64_t traceStart = Timeline::TraceTime::Instance().GetStartTime();
    const uint64_t rankOffset =
        Timeline::TraceTime::Instance().GetOffsetByFileIdUsingMinTimestamp(requestParams.rankId);

    // traceStart + rankOffset must fit into uint64_t; the overflow check comes before the file-type check.
    const uint64_t headroom = std::numeric_limits<uint64_t>::max() - rankOffset;
    if (traceStart > headroom) {
        ServerLog::Error("Failed to calculate relative to the reference time due to integer overflow.");
        return -1;
    }
    if (fileType != FileType::PYTORCH) {
        ServerLog::Error("Memory tab does not support msprof data.");
        return -1;
    }

    std::string sql = DbMemoryDataBase::BuildOperatorDetailSql(traceStart + rankOffset);
    AddOperatorSql(requestParams, sql);
    return ExecuteOperatorDetail(requestParams, opDetails, sql);
}
/**
 * Queries the complete operator memory table (no paging or filtering clauses are appended here).
 *
 * @param requestParams request passed on to ExecuteQueryEntireOperatorTable()
 * @param opDetails     output list handed to ExecuteQueryEntireOperatorTable()
 * @param offsetTime    offset added to the trace start time to form the base timestamp
 * @return true when the operator memory table is missing (nothing to do), false when the file is
 *         not PYTORCH data, otherwise the result of ExecuteQueryEntireOperatorTable()
 */
bool DbMemoryDataBase::QueryEntireOperatorTable(Protocol::MemoryOperatorParams &requestParams,
    std::vector<Protocol::MemoryOperator> &opDetails, uint64_t offsetTime) {
    if (!GetMemoryDbContext().withOperatorMemory) {
        ServerLog::Warn(
            "Missing table % on querying entire operator table, nothing will be done.", TABLE_OPERATOR_MEMORY);
        return true;
    }
    FileType type = DataBaseManager::Instance().GetFileType(path);
    if (type != FileType::PYTORCH) {
        ServerLog::Error("Memory tab does not support msprof data.");
        return false;
    }
    // The previously computed string forms of startTime/offsetTime were never used and are gone;
    // the base timestamp is formed numerically via NumberSafe::Add.
    uint64_t startTime = Timeline::TraceTime::Instance().GetStartTime();
    std::string sql = BuildOperatorDetailSql(NumberSafe::Add(startTime, offsetTime));
    return ExecuteQueryEntireOperatorTable(requestParams, opDetails, sql);
}
/**
 * Queries one page of component (module) memory details.
 *
 * The statement selects, per module whose peak totalReserved reaches componentThresholdByte, the
 * module name, the peak size in MiB and the earliest timestamp (ms, relative to trace start plus
 * the rank offset) at which that peak was recorded. It contains three placeholders the executor
 * must bind: the device id, then LIMIT and OFFSET.
 *
 * @param requestParams    request; rankId selects the time offset, order/orderBy select sorting
 * @param columnAttr       output handed to ExecuteComponentDetail()
 * @param componentDetails output handed to ExecuteComponentDetail()
 * @return true when the module memory table is missing (nothing to do); false when the file is not
 *         PYTORCH data or orderBy names an unknown column; otherwise the result of
 *         ExecuteComponentDetail()
 */
bool DbMemoryDataBase::QueryComponentDetail(Protocol::MemoryComponentParams &requestParams,
    std::vector<Protocol::MemoryTableColumnAttr> &columnAttr,
    std::vector<Protocol::MemoryComponent> &componentDetails) {
    if (!GetMemoryDbContext().withNpuModuleMem) {
        ServerLog::Warn("Missing table % on querying component detail, nothing will be done.", TABLE_NPU_MODULE_MEM);
        return true;
    }
    std::string sql;
    FileType type = DataBaseManager::Instance().GetFileType(path);
    if (type != FileType::PYTORCH) {
        ServerLog::Error("Failed to query component detail: Memory tab does not support msprof data.");
        return false;
    }
    uint64_t startTime = Timeline::TraceTime::Instance().GetStartTime();
    uint64_t offsetTime = Timeline::TraceTime::Instance().GetOffsetByFileIdUsingMinTimestamp(requestParams.rankId);
    sql = "SELECT t4.name AS componentColumn, ROUND(t3.size / (1024.0 * 1024.0), 2) AS totalReservedColumn,"
          " t3.timestamp_maxsize AS timestampColumn FROM "
          "(SELECT t1.moduleId AS id, t1.totalReserved AS size, MIN(ROUND((t1.timestampNs - " +
          std::to_string(NumberSafe::Add(startTime, offsetTime)) +
          ") / (1000.0 * 1000.0), 3)) AS timestamp_maxsize FROM " + TABLE_NPU_MODULE_MEM + " AS t1 JOIN " +
          "(SELECT moduleId, MAX(totalReserved) AS max_total_reserved FROM " + TABLE_NPU_MODULE_MEM +
          " GROUP BY moduleId HAVING max_total_reserved >= " + std::to_string(componentThresholdByte) +
          ") AS t2 ON t1.moduleId = t2.moduleId AND t1.totalReserved = t2.max_total_reserved "
          " WHERE t1.deviceId = ? "
          "GROUP BY t1.moduleId, t1.totalReserved) AS t3 JOIN ENUM_MODULE AS t4 ON t3.id = t4.id";
    if (!requestParams.order.empty() && !requestParams.orderBy.empty()) {
        // orderBy comes from the request and is spliced into the statement as an identifier
        // (identifiers cannot be bound as parameters). Accept only the keys whose "<key>Column"
        // alias exists in the SELECT list above, so arbitrary SQL can never be injected here.
        const auto &orderBy = requestParams.orderBy;
        const bool knownColumn = orderBy == "component" || orderBy == "totalReserved" || orderBy == "timestamp";
        if (!knownColumn) {
            ServerLog::Error("Failed to query component detail: unsupported order-by column.");
            return false;
        }
        sql += " ORDER BY " + orderBy + "Column";
        // Anything other than "ascend" sorts in descending order.
        if (requestParams.order == "ascend") {
            sql += " ASC ";
        } else {
            sql += " DESC ";
        }
    }
    sql += " LIMIT ? OFFSET ? ";
    return ExecuteComponentDetail(requestParams, columnAttr, componentDetails, sql);
}
/**
 * Queries the complete component (module) memory table without ordering or paging.
 *
 * Per module whose peak totalReserved reaches componentThresholdByte, the statement yields the
 * module name, the peak size in MiB and the earliest relative timestamp (ms) of that peak. The
 * single "?" placeholder is the device id.
 *
 * @param requestParams    request passed on to ExecuteQueryEntireComponentTable()
 * @param componentDetails output handed to ExecuteQueryEntireComponentTable()
 * @param offsetTime       offset added to the trace start time to form the base timestamp
 * @return true when the module memory table is missing, false when the file is not PYTORCH data,
 *         otherwise the result of ExecuteQueryEntireComponentTable()
 */
bool DbMemoryDataBase::QueryEntireComponentTable(Protocol::MemoryComponentParams &requestParams,
    std::vector<Protocol::MemoryComponent> &componentDetails, uint64_t offsetTime) {
    if (!GetMemoryDbContext().withNpuModuleMem) {
        ServerLog::Warn("Missing table % on querying entire component, nothing will be done.", TABLE_NPU_MODULE_MEM);
        return true;
    }

    const FileType fileType = DataBaseManager::Instance().GetFileType(path);
    if (fileType != FileType::PYTORCH) {
        ServerLog::Error("Failed to query entire component table: Memory tab does not support msprof data.");
        return false;
    }

    uint64_t traceStart = Timeline::TraceTime::Instance().GetStartTime();
    std::string baseTimestamp = std::to_string(NumberSafe::Add(traceStart, offsetTime));
    std::string reservedFloor = std::to_string(componentThresholdByte);

    // Assemble the statement fragment by fragment; the resulting text is unchanged.
    std::string sql = "SELECT t4.name, ROUND(t3.size / (1024.0 * 1024.0), 2), t3.timestamp_maxsize FROM ";
    sql += "(SELECT t1.moduleId AS id, t1.totalReserved AS size, MIN(ROUND((t1.timestampNs - ";
    sql += baseTimestamp;
    sql += ") / (1000.0 * 1000.0), 3)) AS timestamp_maxsize FROM ";
    sql += TABLE_NPU_MODULE_MEM;
    sql += " AS t1 JOIN ";
    sql += "(SELECT moduleId, MAX(totalReserved) AS max_total_reserved FROM ";
    sql += TABLE_NPU_MODULE_MEM;
    sql += " GROUP BY moduleId HAVING max_total_reserved >= ";
    sql += reservedFloor;
    sql += ") AS t2 ON t1.moduleId = t2.moduleId AND t1.totalReserved = t2.max_total_reserved ";
    sql += "WHERE t1.deviceId = ? ";
    sql += "GROUP BY t1.moduleId, t1.totalReserved) AS t3 JOIN ENUM_MODULE AS t4 ON t3.id = t4.id ";

    return ExecuteQueryEntireComponentTable(requestParams, componentDetails, sql);
}
/**
 * Builds and runs the memory-view query, then converts the rows into graph data.
 *
 * The statement reads MEMORY_RECORD (component name resolved through STRING_IDS) and, when the
 * NPU_MEM table exists, unions in its 'app' rows as an 'APP' component. The outer query filters on
 * the device id column via one "?" placeholder.
 *
 * @param requestParams request passed on to both execution helpers
 * @param operatorBody  output filled by ExecuteQueryMemoryViewGetGraph()
 * @param offsetTime    offset subtracted (together with the trace start) from MEMORY_RECORD timestamps
 * @return true when the memory record table is missing, false when the file is not PYTORCH data or
 *         the SQL execution helper fails, otherwise the result of ExecuteQueryMemoryViewGetGraph()
 */
bool DbMemoryDataBase::QueryMemoryView(
    Protocol::MemoryViewParams &requestParams, Protocol::MemoryViewData &operatorBody, uint64_t offsetTime) {
    const FileType fileType = DataBaseManager::Instance().GetFileType(path);
    const auto dbContext = GetMemoryDbContext();
    if (!dbContext.withMemoryRecord) {
        ServerLog::Warn("Missing table % on querying memory view, nothing will be done.", TABLE_MEMORY_RECORD);
        return true;
    }
    const uint64_t traceStart = Timeline::TraceTime::Instance().GetStartTime();
    if (fileType != FileType::PYTORCH) {
        ServerLog::Error("Memory tab does not support msprof data.");
        return false;
    }

    const std::string traceStartText = std::to_string(traceStart);
    std::string sql = "select * from ( ";
    sql += "SELECT NAME.value AS component, ROUND((timestamp - ";
    sql += traceStartText;
    sql += " - ";
    sql += std::to_string(offsetTime);
    sql += ") / (1000.0 * 1000.0), 3) as timestamp, "
           "ROUND(totalAllocated / (1024.0 * 1024.0), 2) as totalAllocated, "
           " ROUND(totalReserved / (1024.0 * 1024.0), 2) as totalReserve, "
           "ROUND(totalActive / (1024.0 * 1024.0), 2) as totalActive, streamPtr as stream, ";
    sql += deviceIdColumnName;
    sql += " FROM ";
    sql += TABLE_MEMORY_RECORD;
    sql += " JOIN STRING_IDS AS NAME ON NAME.id = MEMORY_RECORD.component ";
    if (dbContext.withNpuMem) {
        // NOTE(review): this branch subtracts only the trace start (no offsetTime) and rounds to
        // 2 decimals, unlike the MEMORY_RECORD branch above - kept as-is, confirm it is intended.
        sql += " UNION ALL select 'APP' as component, ROUND((timestampNs - ";
        sql += traceStartText;
        sql += " ) / (1000.0 * 1000.0), 2) as timestampNs, "
               " 0 as totalAllocated, ROUND((hbm + ddr) / (1024.0 * 1024.0), 2) as totalReserve, "
               " 0 as totalActive, '' as stream, deviceId from NPU_MEM join STRING_IDS as ids on ids.id = type "
               " where value = 'app' ";
    }
    sql += " ) WHERE ";
    sql += deviceIdColumnName;
    sql += " = ? ";

    std::vector<Protocol::ComponentDto> componentRows;
    std::vector<std::string> streamNames;
    const bool fetched =
        ExecuteQueryMemoryViewExecuteSql(requestParams, componentRows, streamNames, sql, deviceIdColumnName);
    if (!fetched) {
        return false;
    }
    return ExecuteQueryMemoryViewGetGraph(requestParams, componentRows, streamNames, operatorBody);
}
/**
 * Counts the modules whose peak totalReserved reaches componentThresholdByte for one device.
 *
 * The statement has a single "?" placeholder (the device id) that the executor must bind.
 *
 * @param requestParams request passed on to ExecuteComponentTotalNum()
 * @param totalNum      output handed to ExecuteComponentTotalNum(); left untouched when the module
 *                      memory table is missing or the file type is unsupported
 * @return true when the module memory table is missing (nothing to do), false when the file is not
 *         PYTORCH data, otherwise the result of ExecuteComponentTotalNum()
 */
bool DbMemoryDataBase::QueryComponentsTotalNum(Protocol::MemoryComponentParams &requestParams, int64_t &totalNum) {
    if (!GetMemoryDbContext().withNpuModuleMem) {
        // The message used to say "component detail" (copied from QueryComponentDetail), which
        // pointed log readers at the wrong query.
        ServerLog::Warn(
            "Missing table % on querying components total num, nothing will be done.", TABLE_NPU_MODULE_MEM);
        return true;
    }
    std::string sql;
    FileType type = DataBaseManager::Instance().GetFileType(path);
    if (type == FileType::PYTORCH) {
        sql = "SELECT count(*) FROM (SELECT t2.name FROM " + TABLE_NPU_MODULE_MEM +
              " AS t1 JOIN ENUM_MODULE AS t2 ON t1.moduleId = t2.id WHERE deviceId = ? "
              " GROUP BY t2.name HAVING MAX(t1.totalReserved) >= " +
              std::to_string(componentThresholdByte) + ") AS t3";
    } else {
        ServerLog::Error("Failed to query components total num: Memory tab does not support msprof data.");
        return false;
    }
    return ExecuteComponentTotalNum(requestParams, totalNum, sql);
}
/**
 * Queries the smallest and largest operator allocation size (divided by 1024, rounded to two
 * decimals) for one device.
 *
 * @param requestParams request passed on to ExecuteOperatorSize()
 * @param min           output handed to ExecuteOperatorSize(); untouched on the early-return paths
 * @param max           output handed to ExecuteOperatorSize(); untouched on the early-return paths
 * @return true when the operator memory table is missing, false when the file is not PYTORCH data,
 *         otherwise the result of ExecuteOperatorSize()
 */
bool DbMemoryDataBase::QueryOperatorSize(Protocol::MemoryOperatorSizeParams &requestParams, double &min, double &max) {
    const FileType fileType = DataBaseManager::Instance().GetFileType(path);
    if (!GetMemoryDbContext().withOperatorMemory) {
        ServerLog::Warn("Missing table % on querying operator size, nothing will be done.", TABLE_OPERATOR_MEMORY);
        return true;
    }
    if (fileType != FileType::PYTORCH) {
        ServerLog::Error("Memory tab does not support msprof data.");
        return false;
    }

    std::string sql = "SELECT ROUND(min(size)/ 1024.0, 2) as minSize, "
                      " ROUND(max(size)/ 1024.0, 2) as maxSize FROM ";
    sql += TABLE_OPERATOR_MEMORY;
    sql += " WHERE ";
    sql += deviceIdColumnName;
    sql += " = ? ";
    return ExecuteOperatorSize(requestParams, min, max, sql);
}
/**
 * Stub: this database type does not implement the static-operator size query.
 *
 * @return always false; none of the arguments are read or written
 */
bool DbMemoryDataBase::QueryStaticOperatorSize(
    Protocol::StaticOperatorSizeParams &requestParams, double &min, double &max) {
    static_cast<void>(requestParams);
    static_cast<void>(min);
    static_cast<void>(max);
    return false;
}
/**
 * Stub: this database type does not implement the static-operator list query.
 *
 * @return always -1; none of the arguments are read or written
 */
int64_t DbMemoryDataBase::QueryStaticOperatorList(
    Protocol::StaticOperatorListParams &requestParams, std::vector<Protocol::StaticOperatorItem> &opDetails) {
    static_cast<void>(requestParams);
    static_cast<void>(opDetails);
    return -1;
}
/**
 * Stub: this database type does not implement the full static-operator table query.
 *
 * @return always false; none of the arguments are read or written
 */
bool DbMemoryDataBase::QueryEntireStaticOperatorTable(
    Protocol::StaticOperatorListParams &requestParams, std::vector<Protocol::StaticOperatorItem> &opDetails) {
    static_cast<void>(requestParams);
    static_cast<void>(opDetails);
    return false;
}
/**
 * Stub: this database type does not implement the static-operator graph query.
 *
 * @return always false; none of the arguments are read or written
 */
bool DbMemoryDataBase::QueryStaticOperatorGraph(
    Protocol::StaticOperatorGraphParams &requestParams, Protocol::StaticOperatorGraphItem &graphItem) {
    static_cast<void>(requestParams);
    static_cast<void>(graphItem);
    return false;
}
void DbMemoryDataBase::ParserEnd(std::string rankId, bool result, std::string fileId) {
if (!result) {
return;
}
Server::ServerLog::Info("[Memory]Parser ends, Rank ID: ", rankId);
if (ranks.count(rankId) == 0) {
Protocol::MemorySuccess success;
success.rankId = rankId;
success.parseSuccess = true;
success.hasFile = true;
success.fileId = fileId;
auto rankInfos = TrackInfoManager::Instance().GetRankListByFileId(fileId, rankId);
if (!rankInfos.empty()) {
success.rankInfo = rankInfos[0];
}
ranks.emplace(rankId, success);
} else {
ranks[rankId].parseSuccess = true;
ranks[rankId].hasFile = true;
}
}
/**
 * Notifies the client about a memory parse outcome.
 *
 * With an empty rankId the rank registry is cleared and a module reset event for the memory module
 * is sent. Otherwise a parse-completed event carrying the registry entry for rankId is sent.
 *
 * @param rankId rank the callback refers to; empty means "reset"
 * @param fileId file identifier copied into the completed event
 * @param result parse result copied into the completed event
 * @param msg    not used by this implementation
 */
void DbMemoryDataBase::ParseCallBack(
    const std::string &rankId, const std::string &fileId, bool result, const std::string &msg) {
    if (!rankId.empty()) {
        auto completed = std::make_unique<Protocol::ParseMemoryCompletedEvent>();
        completed->moduleName = Protocol::MODULE_TIMELINE;
        completed->result = result;
        completed->isCluster = true;
        completed->fileId = fileId;
        // std::map::operator[] inserts a default-constructed entry when rankId was never
        // registered (e.g. ParserEnd() skipped it), so the registry may grow here.
        completed->memoryResult = std::vector<Protocol::MemorySuccess>{ranks[rankId]};
        SendEvent(std::move(completed));
        return;
    }

    ranks.clear();
    auto resetNotice = std::make_unique<Protocol::ModuleResetEvent>();
    resetNotice->moduleName = Protocol::MODULE_MEMORY;
    resetNotice->result = true;
    resetNotice->reset = true;
    SendEvent(std::move(resetNotice));
}
// Returns a copy of the static rank registry (rank ID -> parse status); later changes to the
// registry are not reflected in the returned map.
std::map<std::string, Protocol::MemorySuccess> DbMemoryDataBase::GetRanks() { return ranks; }
void DbMemoryDataBase::Reset() {
ServerLog::Info("Memory reset. Wait task completed.");
ranks.clear();
ServerLog::Info("Memory task completed.");
auto databaseList = Timeline::DataBaseManager::Instance().GetAllMemoryDatabase();
for (auto &db : databaseList) {
auto database = dynamic_cast<DbMemoryDataBase *>(db);
if (database != nullptr) {
database->CloseDb();
}
}
Timeline::DataBaseManager::Instance().Clear(Timeline::DatabaseType::MEMORY);
}
/**
 * Produces the SELECT expression and result alias for one operator-memory column key.
 *
 * The key "id" maps to the table-qualified ID column and keeps its own name as alias. Every other
 * key gets the alias "_<key>" and an expression chosen by the first matching rule, in this order:
 * size-set columns (value / 1024^2, 2 decimals), timestamp-set columns (ms relative to
 * baseTimestamp, or to 0 for the two duration columns, 3 decimals), the SIZE column (value / 1024,
 * 2 decimals), the NAME column (value of the joined STRING_IDS alias), the ID column (table rowid),
 * and finally the raw key itself.
 *
 * @param columnKey     logical column key
 * @param baseTimestamp reference time subtracted from non-duration timestamp columns
 * @param column        receives the SQL expression
 * @param alias         receives the result alias
 */
void DbMemoryDataBase::GetSelectOperatorMemoryColumnAndAlias(
    std::string_view columnKey, uint64_t baseTimestamp, std::string &column, std::string &alias) {
    if (columnKey == "id") {
        column = StringUtil::FormatString("{}.{}", TABLE_OPERATOR_MEMORY, OpMemoryColumn::ID);
        alias = std::string(columnKey);
        return;
    }

    alias = StringUtil::FormatString("_{}", columnKey);

    const bool isMiBSizeColumn =
        OPERATOR_MEMORY_ARA_SIZE_COLUMNS.find(columnKey) != OPERATOR_MEMORY_ARA_SIZE_COLUMNS.end();
    const bool isTimestampColumn =
        OPERATOR_MEMORY_TIMESTAMP_NS_COLUMNS_SET.find(columnKey) != OPERATOR_MEMORY_TIMESTAMP_NS_COLUMNS_SET.end();

    if (isMiBSizeColumn) {
        column = StringUtil::FormatString("ROUND({}/(1024.0*1024.0), 2)", columnKey);
    } else if (isTimestampColumn) {
        // Durations are already relative, so nothing is subtracted from them.
        const bool isDuration =
            columnKey == OpMemoryColumn::DURATION || columnKey == OpMemoryColumn::ACTIVE_DURATION;
        std::string subtrahend = isDuration ? std::string("0") : std::to_string(baseTimestamp);
        column = StringUtil::FormatString("ROUND(({} - {})/(1000.0*1000.0), 3)", columnKey, subtrahend);
    } else if (columnKey == OpMemoryColumn::SIZE) {
        column = StringUtil::FormatString("ROUND({}/1024.0, 2)", columnKey);
    } else if (columnKey == OpMemoryColumn::NAME) {
        column = StringUtil::FormatString("{}.value", GetJoinStringIDSAlias(columnKey));
    } else if (columnKey == OpMemoryColumn::ID) {
        column = StringUtil::FormatString("{}.rowid", TABLE_OPERATOR_MEMORY);
    } else {
        column = std::string(columnKey);
    }
}
/**
 * Derives the table alias used when joining STRING_IDS for a given column.
 *
 * @param joinCol column name the join is made on
 * @return "SI_" followed by joinCol
 */
std::string DbMemoryDataBase::GetJoinStringIDSAlias(std::string_view joinCol) {
    std::string alias = StringUtil::FormatString("SI_{}", joinCol);
    return alias;
}
/**
 * Returns which memory-related tables exist in this database.
 *
 * The four CheckTableExist() probes run only on the first call; afterwards the stored result is
 * returned as-is.
 *
 * @return a copy of the context flags (memory record, operator memory, NPU module mem, NPU mem)
 */
MemoryDataBaseContext DbMemoryDataBase::GetMemoryDbContext() {
    if (initContextFlag) {
        return memDbContext;
    }
    memDbContext.withMemoryRecord = CheckTableExist(TABLE_MEMORY_RECORD);
    memDbContext.withOperatorMemory = CheckTableExist(TABLE_OPERATOR_MEMORY);
    memDbContext.withNpuModuleMem = CheckTableExist(TABLE_NPU_MODULE_MEM);
    memDbContext.withNpuMem = CheckTableExist(TABLE_NPU_MEM);
    initContextFlag = true;
    return memDbContext;
}
}
}
}