* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <sstream>
#include "DbTraceDataBase.h"
namespace Dic::Module::FullDb {
using namespace Server;
std::string DbTraceDataBase::GetKernelDetailSql(
const KernelDetailsParams &requestParams, const std::string &blockNumColumnName) {
try {
std::ostringstream sqlStream;
sqlStream << "with nameIds as (select id, value as realName from STRING_IDS),\n"
<< GetKernelDetailSqlWithHCCL(requestParams, blockNumColumnName) << ",\n"
<< GetKernelDetailSqlWithoutHCCL(requestParams, blockNumColumnName)
<< ",\n"
<< "main_tmp as (select * from main_hccl UNION ALL select * from main_other), "
<< "main as (SELECT ROWID as id, name, type, acceleratorCore, startTime,\n"
<< "duration, waitTime, " + blockNumColumnName + ", inputShapes, inputDataTypes, inputFormats,\n"
<< "outputShapes, outputDataTypes, outputFormats, taskId FROM main_tmp \n";
if (requestParams.startTime != requestParams.endTime) {
sqlStream << " WHERE (startTime + duration*1000) >= ? AND startTime <= ? ";
}
sqlStream << ") SELECT id, (SELECT COUNT(*) FROM main) as num, name, type, acceleratorCore, startTime, "
<< "duration, waitTime, " + blockNumColumnName + ", inputShapes, inputDataTypes, inputFormats, "
<< "outputShapes, outputDataTypes, outputFormats, taskId FROM main ";
if (!StringUtil::CheckSqlValid(requestParams.orderBy)) {
ServerLog::Error(
"There is an SQL injection attack on this parameter. error param: %", requestParams.orderBy);
} else if (!requestParams.orderBy.empty() && !requestParams.order.empty()) {
sqlStream << " ORDER by "
<< (requestParams.orderBy == "blockNum" ? blockNumColumnName : requestParams.orderBy) << " "
<< (requestParams.order == "ascend" ? "ASC" : "DESC");
}
sqlStream << " LIMIT ? OFFSET ?";
return sqlStream.str();
} catch (DatabaseException &) {
return "";
} catch (const std::exception &e) {
ServerLog::Error("An unexpected exception occurred: %", e.what());
return "";
}
}
std::string DbTraceDataBase::GetKernelDetailFilterSqlWithHCCL(const KernelDetailsParams &requestParams) {
if (requestParams.filters.empty()) {
return "";
}
std::string sql = " WHERE 1";
for (const auto &[key, value] : requestParams.filters) {
(void)value;
if (!StringUtil::CheckSqlValid(key)) {
ServerLog::Error("There is an SQL injection attack on this parameter. param: filter");
throw DatabaseException("filter first value is invalid for sql.");
}
sql += " AND lower(" + key + ") LIKE lower(?) ";
}
return sql;
}
std::string DbTraceDataBase::GetKernelDetailFilterSqlWithoutHCCL(const KernelDetailsParams &requestParams) {
if (requestParams.filters.empty()) {
return "";
}
std::string sql = " WHERE 1";
for (const auto &[key, value] : requestParams.filters) {
(void)value;
if (!StringUtil::CheckSqlValid(key)) {
ServerLog::Error("There is an SQL injection attack on this parameter. param: filter");
throw DatabaseException("filter first value is invalid for sql.");
}
if (key == "name" || key == "taskId") {
sql += " AND lower(" + key + ") LIKE lower(?) ";
} else {
sql += " AND " + key + " IN ( SELECT id FROM STRING_IDS WHERE lower(value) LIKE lower(?) )";
}
}
return sql;
}
std::string DbTraceDataBase::GetKernelDetailSqlWithHCCL(
const KernelDetailsParams &requestParams, const std::string &blockNumColumnName) {
const std::string filterSql = GetKernelDetailFilterSqlWithHCCL(requestParams);
return " main_hccl_tmp as ("
" select info.ROWID, nameIds.realName as name, substr(realName, 0, instr(realName, '__') + 1) as type, "
" 'HCCL' as acceleratorCore, info.startNs as startTime, "
" round((info.endNs - info.startNs)/1000.0, 3) as duration, "
" 0 as " +
blockNumColumnName +
", round(waitNs/1000.0, 3) as waitTime, "
" 'N/A' as inputShapes, 'N/A' as inputDataTypes, 'N/A' as inputFormats, "
" 'N/A' as outputShapes, 'N/A' as outputDataTypes, 'N/A' as outputFormats, "
" TASK.taskId as taskId"
" from COMMUNICATION_OP info JOIN TASK ON info.connectionId = TASK.connectionId "
" join nameIds on opName = nameIds.id group by info.opName "
"), "
"main_hccl as ( select * from main_hccl_tmp " +
filterSql + ") ";
}
std::string DbTraceDataBase::GetKernelDetailSqlWithoutHCCL(
const KernelDetailsParams &requestParams, const std::string &blockNumColumnName) {
const std::string filterSql = GetKernelDetailFilterSqlWithoutHCCL(requestParams);
return " main_other_tmp as ("
" select TASK.ROWID, nameIds.realName as name, opType as type, info.taskType as acceleratorCore,"
" startNs as startTime, round((endNs - startNs)/1000.0, 3) as duration, "
" " +
blockNumColumnName +
", round(waitNs/1000.0, 3) as waitTime, "
" inputShapes, inputDataTypes, inputFormats, outputShapes, outputDataTypes, outputFormats, "
" TASK.taskId as taskId "
" from COMPUTE_TASK_INFO info JOIN TASK ON info.globalTaskId = TASK.globalTaskId "
" join nameIds on name = nameIds.id where deviceId = ?"
"), "
" main_other as (select * from main_other_tmp " +
filterSql + ") ";
}
}