/*
* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2026 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "ServerLog.h"
#include "Paginator.h"
#include "MemcpyOverallDatabaseAccesser.h"
namespace Dic::Module::Timeline {
/**
 * Resolves the storage format behind this accessor.
 *
 * @return DataType::TEXT when no file id is set; otherwise the value that
 *         DataBaseManager::GetDataType reports for fileId_.
 */
DataType MemcpyOverallDatabaseAccesser::GetDatabaseType() const
{
    // No file id means there is nothing to look up, so fall back to TEXT.
    return fileId_.empty() ? DataType::TEXT : DataBaseManager::Instance().GetDataType(fileId_);
}
bool MemcpyOverallDatabaseAccesser::GetMemcpyRecords(
const uint64_t startTime, const uint64_t endTime, std::vector<MemcpyRecord> &records) const {
if (!database_ || fileId_.empty()) {
return false;
}
DataType dataType = GetDatabaseType();
const uint64_t minTimestamp = TraceTime::Instance().GetStartTime();
uint64_t absStart, absEnd;
if (!SafeAddUint64(startTime, minTimestamp, absStart) || !SafeAddUint64(endTime, minTimestamp, absEnd)) {
Server::ServerLog::Error("Time conversion overflow: relative time + base timestamp exceeds uint64_t limit");
return false;
}
if (dataType == DataType::TEXT) {
GetMemcpyRecordsFromText(absStart, absEnd, records);
} else if (dataType == DataType::DB) {
GetMemcpyRecordsFromDb(absStart, absEnd, records);
} else {
return false;
}
for (auto &record : records) {
if (record.startTime < minTimestamp || record.endTime < minTimestamp) {
ServerLog::Warn("Unexpected condition: slice time is less than min timestamp. "
"time: [",
record.startTime, ", ", record.endTime, " Min time stamp: ", minTimestamp);
return false;
}
record.startTime -= minTimestamp;
record.endTime -= minTimestamp;
}
return true;
}
bool MemcpyOverallDatabaseAccesser::GetMemcpyDetailRecordsPaged(uint64_t startTime, uint64_t endTime,
const std::string &tid, const std::optional<std::string> &memcpyType, uint32_t current, uint32_t pageSize,
const OrderParam &orderParam, std::vector<MemcpyDetailRecord> &records, uint64_t &total) const {
if (!database_ || fileId_.empty()) {
return false;
}
DataType dataType = GetDatabaseType();
const uint64_t minTimestamp = TraceTime::Instance().GetStartTime();
uint64_t absStart, absEnd;
if (!SafeAddUint64(startTime, minTimestamp, absStart) || !SafeAddUint64(endTime, minTimestamp, absEnd)) {
Server::ServerLog::Error("Time conversion overflow: relative time + base timestamp exceeds uint64_t limit");
return false;
}
if (dataType == DataType::TEXT) {
auto [sortField, sortDir] = ParseSortParams(orderParam.orderBy, orderParam.GetNormalizeOrderType());
GetMemcpyDetailRecordsPagedFromText(
absStart, absEnd, tid, memcpyType, current, pageSize, sortField, sortDir, records, total);
} else if (dataType == DataType::DB) {
GetMemcpyDetailRecordsPagedFromDb(
absStart, absEnd, tid, memcpyType, current, pageSize, orderParam.GenerateSql(), records, total);
} else {
return false;
}
for (auto &record : records) {
if (record.timestamp < minTimestamp) {
ServerLog::Warn("Unexpected condition: slice timestamp is less than min timestamp. "
"timestamp: ",
record.timestamp, " Min time stamp: ", minTimestamp);
return false;
}
record.timestamp -= minTimestamp;
}
return true;
}
bool MemcpyOverallDatabaseAccesser::GetMemcpyRecordsFromText(
uint64_t startTime, uint64_t endTime, std::vector<MemcpyRecord> &records) const {
if (!database_) {
return false;
}
try {
const bool useTimeSearch = startTime != endTime;
std::string sql = "SELECT t.tid, t.thread_name, s.args, s.timestamp, s.end_time FROM slice s "
"LEFT JOIN thread t ON s.track_id = t.track_id "
"WHERE s.name = 'MEMCPY_ASYNC' ";
if (useTimeSearch) {
sql += "AND timestamp >= ? AND end_time <= ?";
}
auto stmt = database_->CreatPreparedStatement(sql);
if (stmt == nullptr) {
Server::ServerLog::Error("Querying memcpy records from db has error, Fail to get stmt.");
return false;
}
if (useTimeSearch) {
stmt->BindParams(startTime, endTime);
}
auto resultSet = stmt->ExecuteQuery();
if (resultSet == nullptr) {
Server::ServerLog::Error("Querying memcpy records from db has error, Fail to execute query.");
return false;
}
while (resultSet->Next()) {
MemcpyRecord record;
record.threadId = resultSet->GetString("tid");
record.threadName = resultSet->GetString("thread_name");
std::string argsStr = resultSet->GetString("args");
const auto parsed = ParseOperationAndSizeFromJson(argsStr);
record.memcpyType = parsed.first;
record.size = parsed.second;
record.startTime = resultSet->GetUint64("timestamp");
record.endTime = resultSet->GetUint64("end_time");
record.duration = static_cast<double>(record.endTime - record.startTime);
records.push_back(record);
}
return true;
} catch (...) {
return false;
}
}
bool MemcpyOverallDatabaseAccesser::GetMemcpyRecordsFromDb(
uint64_t startTime, uint64_t endTime, std::vector<MemcpyRecord> &records) const {
if (!database_) {
return false;
}
try {
const bool useTimeSearch = startTime != endTime;
std::string sql = "SELECT t.streamId AS tid, emo.name AS memcpyOperation, mi.size, t.startNs, t.endNs "
"FROM " +
TABLE_TASK + " t JOIN " + TABLE_MEMCPY_INFO +
" mi "
"ON t.globalTaskId = mi.globalTaskId "
"LEFT JOIN " +
TABLE_ENUM_MEMCPY_OPERATION + " emo ON mi.memcpyOperation = emo.id ";
if (useTimeSearch) {
sql += "WHERE t.startNs >= ? AND t.endNs <= ?";
}
auto stmt = database_->CreatPreparedStatement(sql);
if (stmt == nullptr) {
Server::ServerLog::Error("Querying memcpy records from db has error, Fail to get stmt.");
return false;
}
if (useTimeSearch) {
stmt->BindParams(startTime, endTime);
}
auto resultSet = stmt->ExecuteQuery();
if (resultSet == nullptr) {
Server::ServerLog::Error("Querying memcpy records from db has error, Fail to execute query.");
return false;
}
while (resultSet->Next()) {
MemcpyRecord record;
record.threadId = resultSet->GetString("tid");
record.threadName = "Stream " + record.threadId;
record.memcpyType = resultSet->GetString("memcpyOperation");
record.size = resultSet->GetUint64("size");
record.startTime = resultSet->GetUint64("startNs");
record.endTime = resultSet->GetUint64("endNs");
record.duration = static_cast<double>(record.endTime - record.startTime);
records.push_back(record);
}
return true;
} catch (...) {
return false;
}
}
/**
 * Assembles the base SELECT for memcpy detail rows in TEXT data, together with
 * the values for its '?' placeholders.
 *
 * @param startTime Lower bound for s.timestamp.
 * @param endTime   Upper bound for s.end_time. The time filter is added only
 *                  when startTime != endTime.
 * @param tidFilter Thread id to match; no thread filter is added when empty.
 * @return {sql text, placeholder values in the order they appear in the sql}.
 */
std::pair<std::string, std::vector<std::string>> MemcpyOverallDatabaseAccesser::BuildMemcpyDetailBaseQueryText(
    uint64_t startTime, uint64_t endTime, const std::string &tidFilter)
{
    std::string query = "SELECT s.id AS slice_id, s.name, s.args, s.timestamp, s.duration FROM slice s "
                        "LEFT JOIN thread t ON s.track_id = t.track_id "
                        "WHERE s.name = 'MEMCPY_ASYNC' ";
    std::vector<std::string> bindValues;

    const bool hasTimeWindow = startTime != endTime;
    if (hasTimeWindow) {
        query += "AND s.timestamp >= ? AND s.end_time <= ? ";
        bindValues.push_back(std::to_string(startTime));
        bindValues.push_back(std::to_string(endTime));
    }

    const bool hasTidFilter = !tidFilter.empty();
    if (hasTidFilter) {
        query += "AND t.tid = ? ";
        bindValues.push_back(tidFilter);
    }

    return {query, bindValues};
}
/**
 * Reads memcpy detail rows from TEXT data, filters them by operation, sorts
 * them in memory and returns the requested page.
 *
 * The operation lives inside the JSON 'args' column, so the memcpyType filter,
 * the sorting and the paging are all applied after the rows are fetched.
 *
 * @param startTime    Lower time bound passed to the base query.
 * @param endTime      Upper time bound passed to the base query.
 * @param tid          Thread id filter passed to the base query.
 * @param memcpyType   When set, rows whose parsed operation differs are dropped.
 * @param current      Page number handed to Paginator::GetPage.
 * @param pageSize     Page size handed to the Paginator constructor.
 * @param orderByField Field used by SortRecordsInMemory.
 * @param orderDir     Direction used by SortRecordsInMemory.
 * @param records      Overwritten with the requested page.
 * @param total        Receives Paginator::GetTotal() for the filtered rows.
 * @return true on success; false when there is no database, the statement or
 *         result set could not be created, or an exception was raised.
 */
bool MemcpyOverallDatabaseAccesser::GetMemcpyDetailRecordsPagedFromText(uint64_t startTime, uint64_t endTime,
    const std::string &tid, const std::optional<std::string> &memcpyType, uint32_t current, uint32_t pageSize,
    SortField orderByField, SortDirection orderDir, std::vector<MemcpyDetailRecord> &records, uint64_t &total) const
{
    if (!database_) {
        return false;
    }
    try {
        auto [baseQuery, baseParams] = BuildMemcpyDetailBaseQueryText(startTime, endTime, tid);
        if (baseQuery.empty()) {
            return false;
        }
        auto stmt = database_->CreatPreparedStatement(baseQuery);
        if (!stmt) {
            Server::ServerLog::Error("Failed to create prepared statement for memcpy detail (TEXT)");
            return false;
        }
        // Bind the placeholder values one by one, in the order the query builder produced them.
        for (const auto &param : baseParams) {
            stmt->BindParams(param);
        }
        auto resultSet = stmt->ExecuteQuery();
        if (!resultSet) {
            Server::ServerLog::Error("Failed to execute query for memcpy detail (TEXT)");
            return false;
        }
        std::vector<MemcpyDetailRecord> filtered;
        while (resultSet->Next()) {
            const std::string argsStr = resultSet->GetString("args");
            auto [operation, size] = ParseOperationAndSizeFromJson(argsStr);
            // The operation filter cannot be expressed in SQL here, so it is applied per row.
            if (memcpyType.has_value() && operation != *memcpyType) {
                continue;
            }
            MemcpyDetailRecord record;
            record.timestamp = resultSet->GetUint64("timestamp");
            record.duration = resultSet->GetUint64("duration");
            record.size = size;
            record.name = resultSet->GetString("name");
            record.id = std::to_string(resultSet->GetUint64("slice_id"));
            filtered.push_back(record);
        }
        SortRecordsInMemory(filtered, orderByField, orderDir);
        const Paginator<MemcpyDetailRecord> paginator(filtered, pageSize);
        total = paginator.GetTotal();
        records = paginator.GetPage(current);
        return true;
    } catch (const std::exception &e) {
        Server::ServerLog::Error("Exception in GetMemcpyDetailRecordsPagedFromText: " + std::string(e.what()));
        return false;
    } catch (...) {
        Server::ServerLog::Error("Unknown exception in GetMemcpyDetailRecordsPagedFromText");
        return false;
    }
}
std::pair<std::string, std::vector<std::string>> MemcpyOverallDatabaseAccesser::BuildMemcpyDetailBaseQueryDb(
uint64_t startTime, uint64_t endTime, const std::string &tidFilter,
const std::optional<std::string> &memcpyTypeFilter) {
std::vector<std::string> params;
std::ostringstream sql;
* @note 获取 name 的解释
* 1. MSTX_EVENTS 事件目前不可能有 MEMCPY 的算子,忽略
* 2. 非 MSTX_EVENTS 事件算子名称: DbSqlDefs.h 文件中的 ASCEND_THREADS_EXCLUDING_MSTX_BY_PID 语句获取名称的 SQL 是
* `coalesce(c.name, s.name, main.taskType) as name` c 表示 COMPUTE, s 表示 COMMUNICATION
* 在业务逻辑上 MEMCPY 既不是 COMPUTE 也不是 COMMUNICATION,因此只能取 main.taskType, main 表示 TASK
**/
sql << "SELECT t.globalTaskId, t.streamId, emo.name AS memcpyOperation, si.value AS name, "
<< "mi.size AS size, t.startNs AS startTime, t.endNs - t.startNs AS duration FROM " << TABLE_TASK << " t "
<< "JOIN " << TABLE_MEMCPY_INFO << " mi ON t.globalTaskId = mi.globalTaskId "
<< "LEFT JOIN " << TABLE_ENUM_MEMCPY_OPERATION << " emo ON mi.memcpyOperation = emo.id "
<< "LEFT JOIN " << TABLE_STRING_IDS << " si ON si.id = t.taskType "
<< "WHERE 1=1 ";
if (startTime != endTime) {
sql << "AND t.startNs >= ? AND t.endNs <= ? ";
params.emplace_back(std::to_string(startTime));
params.emplace_back(std::to_string(endTime));
}
if (!tidFilter.empty()) {
sql << "AND t.streamId = ? ";
params.emplace_back(tidFilter);
}
if (memcpyTypeFilter.has_value()) {
sql << "AND emo.name = ? ";
params.emplace_back(memcpyTypeFilter.value());
}
return {sql.str(), params};
}
/**
 * Counts the rows produced by `baseQuery` by wrapping it in a CTE and running
 * SELECT COUNT(*) over it.
 *
 * @param baseQuery  SELECT statement to count; may contain '?' placeholders.
 * @param baseParams Values for the placeholders, in order.
 * @param total      Receives the row count; set to 0 when the count cannot be
 *                   obtained, so the caller never sees a stale or unset value.
 */
void MemcpyOverallDatabaseAccesser::GetMemcpyDetailTotalFromDb(
    const std::string &baseQuery, const std::vector<std::string> &baseParams, uint64_t &total) const
{
    // This function returns void, so give the out-parameter a defined value up front.
    total = 0;
    if (!database_) {
        return;
    }
    const std::string totalSql = "WITH filtered AS (" + baseQuery + ") SELECT COUNT(*) AS total FROM filtered";
    auto totalStmt = database_->CreatPreparedStatement(totalSql);
    if (!totalStmt) {
        Server::ServerLog::Error("Failed to create prepared statement for memcpy detail total (DB)");
        return;
    }
    for (const auto &param : baseParams) {
        totalStmt->BindParams(param);
    }
    auto resultSet = totalStmt->ExecuteQuery();
    if (!resultSet || !resultSet->Next()) {
        Server::ServerLog::Error("Failed to execute query for memcpy detail total (DB)");
        return;
    }
    total = resultSet->GetUint64("total");
}
/**
 * Reads one page of memcpy detail rows from DB data; filtering, ordering and
 * paging are all done in SQL.
 *
 * @param startTime  Lower time bound passed to the base query.
 * @param endTime    Upper time bound passed to the base query.
 * @param tid        Stream id filter passed to the base query.
 * @param memcpyType Optional memcpy operation filter passed to the base query.
 * @param current    1-based page number; 0 is treated as the first page.
 * @param pageSize   Number of rows per page (bound to LIMIT).
 * @param orderSql   Ordering clause inserted before LIMIT.
 * @param records    Output vector; rows of the page are appended.
 * @param total      Receives the total number of matching rows.
 * @return true on success; false when there is no database, the statement or
 *         result set could not be created, or an exception was raised.
 */
bool MemcpyOverallDatabaseAccesser::GetMemcpyDetailRecordsPagedFromDb(uint64_t startTime, uint64_t endTime,
    const std::string &tid, const std::optional<std::string> &memcpyType, uint32_t current, uint32_t pageSize,
    const std::string &orderSql, std::vector<MemcpyDetailRecord> &records, uint64_t &total) const
{
    if (!database_) {
        return false;
    }
    try {
        auto [baseQuery, baseParams] = BuildMemcpyDetailBaseQueryDb(startTime, endTime, tid, memcpyType);
        if (baseQuery.empty()) {
            return false;
        }
        GetMemcpyDetailTotalFromDb(baseQuery, baseParams, total);
        // The leading space before LIMIT keeps the statement valid even when
        // orderSql does not end in whitespace.
        const std::string dataSql =
            "WITH filtered AS (" + baseQuery + ") SELECT * FROM filtered " + orderSql + " LIMIT ? OFFSET ?";
        // `current - 1` on a zero page number would wrap to UINT32_MAX, so clamp to the first page.
        const uint64_t pageIndex = current > 0 ? static_cast<uint64_t>(current) - 1 : 0;
        // Both factors fit in 32 bits, so the 64-bit product cannot overflow.
        const uint64_t offset = pageIndex * static_cast<uint64_t>(pageSize);
        auto dataStmt = database_->CreatPreparedStatement(dataSql);
        if (!dataStmt) {
            Server::ServerLog::Error("Failed to create prepared statement for memcpy detail (DB)");
            return false;
        }
        // Base-query placeholders come first, then LIMIT and OFFSET.
        for (const auto &param : baseParams) {
            dataStmt->BindParams(param);
        }
        dataStmt->BindParams(pageSize);
        dataStmt->BindParams(offset);
        auto resultSet = dataStmt->ExecuteQuery();
        if (!resultSet) {
            Server::ServerLog::Error("Failed to execute query for memcpy detail (DB)");
            return false;
        }
        while (resultSet->Next()) {
            MemcpyDetailRecord record;
            record.timestamp = resultSet->GetUint64("startTime");
            record.duration = resultSet->GetUint64("duration");
            record.size = resultSet->GetUint64("size");
            record.name = resultSet->GetString("name");
            record.id = std::to_string(resultSet->GetUint64("globalTaskId"));
            records.push_back(record);
        }
        return true;
    } catch (const std::exception &e) {
        Server::ServerLog::Error("Exception in GetMemcpyDetailRecordsPagedFromDb: " + std::string(e.what()));
        return false;
    } catch (...) {
        Server::ServerLog::Error("Unknown exception in GetMemcpyDetailRecordsPagedFromDb");
        return false;
    }
}
std::pair<std::string, uint64_t> MemcpyOverallDatabaseAccesser::ParseOperationAndSizeFromJson(
const std::string &jsonStr) {
std::string error;
const auto json = JsonUtil::TryParse(jsonStr, error);
if (!json.has_value() || !error.empty()) {
return {"", 0};
}
std::string operation;
if (JsonUtil::IsJsonKeyValid(json.value(), "operation")) {
operation = JsonUtil::GetString(json.value(), "operation");
} else if (JsonUtil::IsJsonKeyValid(json.value(), "Operation")) {
operation = JsonUtil::GetString(json.value(), "Operation");
} else if (JsonUtil::IsJsonKeyValid(json.value(), "OPERATION")) {
operation = JsonUtil::GetString(json.value(), "OPERATION");
}
uint64_t size = 0;
if (JsonUtil::IsJsonKeyValid(json.value(), "size(B)")) {
size = JsonUtil::GetInteger(json.value(), "size(B)");
}
return {operation, size};
}
}