* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "pch.h"
#include "SummaryDef.h"
#include "OperatorProtocolRequest.h"
#include "OperatorGroupConverter.h"
#include "OperatorProtocolResponse.h"
#include "OperatorProtocol.h"
#include "WsSessionManager.h"
#include "OperatorProtocolEvent.h"
#include "TraceTime.h"
#include "DataBaseManager.h"
#include "TrackInfoManager.h"
#include "DbSummaryDataBase.h"
namespace Dic::Module::FullDb {
using namespace Server;
using namespace Dic::Module::Timeline;
/**
 * Opens the database and selects the block-count column name used by later queries.
 *
 * The open chain short-circuits: the meta-version query and the communication-op table
 * upgrade only run when the preceding step succeeded. The column probe always runs.
 *
 * @param dbPath        path handed to Database::OpenDb
 * @param clearAllTable forwarded unchanged to Database::OpenDb
 * @return combined result of the open chain
 */
bool DbSummaryDataBase::OpenDb(const std::string &dbPath, bool clearAllTable) {
    bool opened = Database::OpenDb(dbPath, clearAllTable);
    opened = opened && QueryMetaVersion();
    opened = opened && AddCommunicationOpTableOpTypeIfNotExists();
    // COMPUTE_TASK_INFO exposes the block count as either "blockNum" or "blockDim".
    const bool hasBlockNum = CheckColumnExist(TABLE_COMPUTE_TASK_INFO, "blockNum");
    blockNumColumnName = hasBlockNum ? "blockNum" : "blockDim";
    return opened;
}
/**
 * Runs the paged compute-operator detail query and replaces computeDetails with the rows read.
 *
 * Bind order follows the placeholders emitted by GenComputeSql: start-time baseline,
 * task type value (params.timeFlag), page size, row offset.
 *
 * @param params         request parameters; timeFlag, currentPage and pageSize are read here
 * @param computeDetails overwritten with the fetched rows on success, untouched on failure
 * @return false when the statement cannot be prepared, true otherwise
 */
bool DbSummaryDataBase::QueryComputeOpDetail(
    Protocol::ComputeDetailParams params, std::vector<Protocol::ComputeDetail> &computeDetails) {
    const std::string sql = GenComputeSql(params);
    // Clamp so the unsigned start time fits the signed 64-bit value SQLite binds.
    uint64_t startTime =
        NumberUtil::CeilingClamp(Timeline::TraceTime::Instance().GetStartTime(), static_cast<uint64_t>(INT64_MAX));
    double offset = (params.currentPage - 1) * params.pageSize;
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Query operator detail failed! Failed to prepare sql.", sqlite3_errmsg(db));
        return false;
    }
    int index = bindStartIndex;
    sqlite3_bind_int64(stmt, index++, startTime);
    // A null destructor means SQLite does not copy the text; params outlives the statement.
    sqlite3_bind_text(stmt, index++, params.timeFlag.c_str(), params.timeFlag.length(), nullptr);
    sqlite3_bind_double(stmt, index++, params.pageSize);
    sqlite3_bind_double(stmt, index++, offset);
    std::vector<Protocol::ComputeDetail> rows;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        int col = resultStartIndex;
        Protocol::ComputeDetail computeDetail{};
        computeDetail.name = sqlite3_column_string(stmt, col++);
        computeDetail.type = sqlite3_column_string(stmt, col++);
        computeDetail.startTime = sqlite3_column_string(stmt, col++);
        computeDetail.duration = sqlite3_column_double(stmt, col++);
        computeDetail.waitTime = sqlite3_column_double(stmt, col++);
        computeDetail.blockNum = sqlite3_column_int64(stmt, col++);
        computeDetail.inputShapes = sqlite3_column_string(stmt, col++);
        computeDetail.inputDataTypes = sqlite3_column_string(stmt, col++);
        computeDetail.inputFormats = sqlite3_column_string(stmt, col++);
        computeDetail.outputShapes = sqlite3_column_string(stmt, col++);
        computeDetail.outputDataTypes = sqlite3_column_string(stmt, col++);
        computeDetail.outputFormats = sqlite3_column_string(stmt, col++);
        rows.emplace_back(computeDetail);
    }
    sqlite3_finalize(stmt);
    // Hand the rows over without copying every element a second time.
    computeDetails.swap(rows);
    return true;
}
/**
 * Builds the paged SELECT used by QueryComputeOpDetail.
 *
 * Placeholders, in order: start-time baseline, task type value, LIMIT, OFFSET.
 * An ORDER BY clause is appended only when request.orderBy passes StringUtil::CheckSqlValid
 * and both orderBy and order are non-empty; a rejected orderBy is logged and ordering is skipped.
 *
 * @param request supplies orderBy and order
 * @return the SQL text
 */
std::string DbSummaryDataBase::GenComputeSql(const Protocol::ComputeDetailParams &request) {
    std::string sql = "SELECT NAME.value AS name, "
                      "OP_TYPE.value as type, "
                      "CASE WHEN startNs == 0 THEN 0 ELSE ROUND((startNs - ?) /(1000.0 * 1000.0), 4) END AS startTime, "
                      "ROUND((endNs - startNs)/1000.0, 2) as duration, "
                      "ROUND((waitNs)/1000.0, 3) as waitTime, ";
    // The block-count column is named differently across database versions (see OpenDb).
    sql += blockNumColumnName;
    sql += " as blockNum, "
           "INPUTSHAPES.value as inputShape, "
           "INPUTDATATYPES.value as inputDataType, "
           "INPUTFORMATS.value as inputFormat, "
           "OUTPUTSHAPES.value as outputShape, "
           "OUTPUTDATATYPES.value as outputDataType, "
           "OUTPUTFORMATS.value as outputFormat "
           "FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " JOIN STRING_IDS AS NAME ON NAME.id = COMPUTE_TASK_INFO.name "
           " JOIN STRING_IDS AS OP_TYPE ON OP_TYPE.id = COMPUTE_TASK_INFO.opType "
           " JOIN STRING_IDS AS INPUTSHAPES ON INPUTSHAPES.id = COMPUTE_TASK_INFO.inputShapes "
           " JOIN STRING_IDS AS INPUTDATATYPES ON INPUTDATATYPES.id = COMPUTE_TASK_INFO.inputDataTypes "
           " JOIN STRING_IDS AS INPUTFORMATS ON INPUTFORMATS.id = COMPUTE_TASK_INFO.inputFormats "
           " JOIN STRING_IDS AS OUTPUTSHAPES ON OUTPUTSHAPES.id = COMPUTE_TASK_INFO.outputShapes "
           " JOIN STRING_IDS AS OUTPUTDATATYPES ON OUTPUTDATATYPES.id = COMPUTE_TASK_INFO.outputDataTypes "
           " JOIN STRING_IDS AS OUTPUTFORMATS ON OUTPUTFORMATS.id = COMPUTE_TASK_INFO.outputFormats "
           " JOIN STRING_IDS AS TASKTYPE ON TASKTYPE.id = COMPUTE_TASK_INFO.taskType "
           " WHERE TASKTYPE.value = ? ";
    if (StringUtil::CheckSqlValid(request.orderBy)) {
        if (!request.orderBy.empty() && !request.order.empty()) {
            // Anything other than "ascend" sorts descending.
            const char *direction = request.order == "ascend" ? "ASC" : "DESC";
            sql += " ORDER by " + request.orderBy + " " + direction;
        }
    } else {
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy to generate compute sql.");
    }
    sql += " LIMIT ? offset ?";
    return sql;
}
/**
 * Counts COMPUTE_TASK_INFO rows (joined with TASK) whose taskType string equals name.
 *
 * @param name     task type value looked up in STRING_IDS
 * @param totalNum receives the count; left untouched if the query yields no row
 * @return false when the statement cannot be prepared, true otherwise
 */
bool DbSummaryDataBase::QueryTotalNumByAcceleratorCore(std::string name, int64_t &totalNum) {
    const std::string sql = "SELECT count(*) as nums FROM " + TABLE_COMPUTE_TASK_INFO +
                            " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
                            " WHERE COMPUTE_TASK_INFO.taskType = (select id from STRING_IDS WHERE value = ?)";
    sqlite3_stmt *stmt = nullptr;
    if (sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
        ServerLog::Error("Get total num failed! Failed to prepare sql.", sqlite3_errmsg(db));
        return false;
    }
    int bindPos = bindStartIndex;
    // A null destructor means SQLite does not copy the text; name outlives the statement.
    sqlite3_bind_text(stmt, bindPos++, name.c_str(), name.length(), nullptr);
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        totalNum = sqlite3_column_int64(stmt, resultStartIndex);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Turns an "_"-separated list of string ids into a display name.
 *
 * The first id is resolved through the string cache and used as-is; every following id is
 * resolved and appended wrapped in square brackets, e.g. "first[second][third]".
 *
 * @param idListStr ids joined by "_"
 * @return the assembled name, or an empty string when the input is empty or yields no ids
 */
std::string DbSummaryDataBase::GetGroupNameByIdListStr(const std::string &idListStr) {
    std::string groupName;
    if (idListStr.empty()) {
        return groupName;
    }
    const std::vector<std::string> ids = StringUtil::Split(idListStr, "_");
    for (size_t pos = 0; pos < ids.size(); ++pos) {
        if (pos == 0) {
            groupName = DbTraceDataBase::GetStringCacheValue(path, ids[pos]);
        } else {
            groupName += "[" + DbTraceDataBase::GetStringCacheValue(path, ids[pos]) + "]";
        }
    }
    return groupName;
}
/**
 * Queries per-category or per-compute-unit operator durations for one device.
 *
 * Rows beyond maxCategorySize are folded into the last kept entry, which is renamed "Others"
 * and accumulates their durations.
 *
 * @param reqParams supplies group, deviceId and topK
 * @param type      CATEGORY selects the category query; any other value selects the compute-unit query
 * @param data      overwritten with the result rows on success
 * @return true when the required table is missing (nothing to do) or the query ran;
 *         false when no SQL could be generated or the statement cannot be prepared
 */
bool DbSummaryDataBase::QueryOperatorDurationInfo(Protocol::OperatorDurationReqParams &reqParams,
    Protocol::QueryType type, std::vector<Protocol::OperatorDurationRes> &data) {
    if (!CheckOperatorTableExist(reqParams.group, "operator duration info")) {
        return true;
    }
    std::string sql;
    if (type == Protocol::QueryType::CATEGORY) {
        sql = GenerateQueryCategoryDurationSql(reqParams);
    } else {
        sql = GenerateQueryComputeUnitDurationSql(reqParams);
    }
    // The category generator returns "" for an unknown group. Preparing empty SQL yields a
    // NULL statement, so every later bind/step would be a misuse silently reported as success.
    if (sql.empty()) {
        ServerLog::Error("Failed to generate query duration sql.");
        return false;
    }
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get Duration Info. Msg: ", sqlite3_errmsg(db), " ", result);
        return false;
    }
    int index = bindStartIndex;
    int deviceId = StringUtil::StringToInt(reqParams.deviceId);
    sqlite3_bind_int64(stmt, index++, deviceId);
    sqlite3_bind_int64(stmt, index++, reqParams.topK);
    std::vector<Protocol::OperatorDurationRes> res;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        Protocol::OperatorDurationRes one{};
        int col = 0;
        one.name = GetGroupNameByIdListStr(sqlite3_column_string(stmt, col++));
        one.duration = sqlite3_column_double(stmt, col++);
        if (res.size() >= maxCategorySize) {
            // Overflow rows are merged into the final slot instead of growing the list.
            res[maxCategorySize - 1].name = "Others";
            res[maxCategorySize - 1].duration += one.duration;
        } else {
            res.emplace_back(one);
        }
    }
    data = res;
    sqlite3_finalize(stmt);
    return true;
}
bool DbSummaryDataBase::ExecSqlGetStatisticInfo(std::string sql, Protocol::OperatorStatisticReqParams &reqParams,
std::vector<Protocol::OperatorStatisticInfoRes> &res) {
sqlite3_stmt *stmt = nullptr;
int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
if (result != SQLITE_OK) {
ServerLog::Error("Failed to get Duration Info. Msg: ", sqlite3_errmsg(db), " ", result);
return false;
}
int index = bindStartIndex;
if (!reqParams.deviceId.empty()) {
int deviceId = StringUtil::StringToInt(reqParams.deviceId);
sqlite3_bind_int64(stmt, index++, deviceId);
}
sqlite3_bind_int64(stmt, index++, reqParams.isCompare ? -1 : reqParams.topK);
BindIdList(reqParams.rangeFilters, stmt, index);
if (!reqParams.isCompare) {
sqlite3_bind_int64(stmt, index++, reqParams.pageSize);
sqlite3_bind_int64(stmt, index++, reqParams.pageSize * (reqParams.current - 1));
}
while (sqlite3_step(stmt) == SQLITE_ROW) {
Protocol::OperatorStatisticInfoRes one{};
int col = 0;
one.opType = DbTraceDataBase::GetStringCacheValue(GetDbPath(), sqlite3_column_string(stmt, col++));
one.opName = DbTraceDataBase::GetStringCacheValue(GetDbPath(), sqlite3_column_string(stmt, col++));
one.inputShape = DbTraceDataBase::GetStringCacheValue(GetDbPath(), sqlite3_column_string(stmt, col++));
one.accCore = DbTraceDataBase::GetStringCacheValue(GetDbPath(), sqlite3_column_string(stmt, col++));
one.totalTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
one.count = Sqlite3ColumnConvertStr(SQLITE_INTEGER, stmt, col++);
one.avgTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
one.maxTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
one.minTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
res.emplace_back(one);
}
sqlite3_finalize(stmt);
return true;
}
/**
 * Fills response with the total count and one page of operator statistics.
 *
 * Each fetched row is stored in the `compare` member of an OperatorStatisticCmpInfoRes.
 * reqParams.rangeFilters is recomputed from reqParams.filters as a side effect.
 *
 * @param reqParams request parameters; rangeFilters is overwritten
 * @param response  receives total and data
 * @return true when the required table is missing (nothing to do) or both queries succeeded;
 *         false when either query fails
 */
bool DbSummaryDataBase::QueryOperatorStatisticInfo(
    Protocol::OperatorStatisticReqParams &reqParams, Protocol::OperatorStatisticInfoResponse &response) {
    const bool tablePresent = CheckOperatorTableExist(reqParams.group, "operator statistic info");
    if (!tablePresent) {
        return true;
    }
    reqParams.rangeFilters = ConvertFiltersToRangeFilters(reqParams.filters);
    const bool totalOk = QueryStatisticTotalNum(reqParams, response.total);
    if (!totalOk) {
        ServerLog::Error("[Operator]Failed to query total num of statistic info.");
        return false;
    }
    std::vector<Protocol::OperatorStatisticInfoRes> rows;
    const std::string statisticSql = GenerateQueryStatisticSql(reqParams);
    if (!ExecSqlGetStatisticInfo(statisticSql, reqParams, rows)) {
        ServerLog::Error("Failed to exec query detail sql.");
        return false;
    }
    std::vector<Protocol::OperatorStatisticCmpInfoRes> wrapped;
    for (const auto &row : rows) {
        OperatorStatisticCmpInfoRes entry;
        entry.compare = row;
        wrapped.emplace_back(entry);
    }
    response.data = wrapped;
    return true;
}
/**
 * Runs the statistic query and appends the rows to res, without computing a total.
 *
 * reqParams.rangeFilters is recomputed from reqParams.filters as a side effect.
 *
 * @param reqParams request parameters; rangeFilters is overwritten
 * @param res       receives the rows
 * @return result of executing the statistic query
 */
bool DbSummaryDataBase::QueryAllOperatorStatisticInfo(
    Protocol::OperatorStatisticReqParams &reqParams, std::vector<Protocol::OperatorStatisticInfoRes> &res) {
    reqParams.rangeFilters = ConvertFiltersToRangeFilters(reqParams.filters);
    const std::string statisticSql = GenerateQueryStatisticSql(reqParams);
    const bool executed = ExecSqlGetStatisticInfo(statisticSql, reqParams, res);
    if (executed) {
        return true;
    }
    ServerLog::Error("Failed to exec query detail sql.");
    return false;
}
/**
 * Builds the statistic SELECT for the requested operator group.
 *
 * Communication groups use GenStatSqlWithCommunication. Other groups use GenStatSqlWithDeviceId
 * when a device id is supplied and GenStatSql otherwise, grouped by "opType, accCore" for the
 * op-type group and by "opName, inputShape, accCore" for every other group. Range filters are
 * appended, then an optional ORDER BY (only when orderBy passes StringUtil::CheckSqlValid),
 * then "LIMIT ? OFFSET ?" unless the request is in compare mode.
 *
 * @param reqParams supplies group, deviceId, rangeFilters, orderBy, order and isCompare
 * @return the SQL text
 */
std::string DbSummaryDataBase::GenerateQueryStatisticSql(Protocol::OperatorStatisticReqParams &reqParams) {
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql;
    if (isCommunication) {
        sql = GenStatSqlWithCommunication();
    } else {
        std::string group = operatorGroup == OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP
                                ? "opType, accCore"
                                : "opName, inputShape, accCore";
        if (!reqParams.deviceId.empty()) {
            sql = GenStatSqlWithDeviceId(group);
        } else {
            sql = GenStatSql(group);
        }
    }
    GenerateRangeQueryFiltersSql(reqParams.rangeFilters, sql);
    if (!StringUtil::CheckSqlValid(reqParams.orderBy)) {
        // The two literals are concatenated by the compiler, so the separating space must be
        // written explicitly; without it the log read "orderByto generate".
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy "
                         "to generate query statistic sql.");
    } else if (!reqParams.orderBy.empty() && !reqParams.order.empty()) {
        // Anything other than "ascend" sorts descending.
        sql += " ORDER by " + reqParams.orderBy + " " + (reqParams.order == "ascend" ? "ASC" : "DESC");
    }
    if (!reqParams.isCompare) {
        sql += " LIMIT ? OFFSET ?";
    }
    return sql;
}
/**
 * Builds the statistic subquery over COMMUNICATION_OP.
 *
 * Durations are converted from nanoseconds by dividing by 1000.0 and rounded to two decimals;
 * inputShape and accCore are returned as NULL. Placeholders, in order: device id, top-K limit.
 *
 * @return the SQL text, ending in the "subquery" alias so filters can be appended by the caller
 */
std::string DbSummaryDataBase::GenStatSqlWithCommunication() {
    std::string sql =
        " SELECT * FROM (SELECT "
        " COMMUNICATION_OP.opType as opType, "
        " COMMUNICATION_OP.opName as opName,"
        " NULL AS inputShape,"
        " NULL as accCore,"
        " ROUND(SUM(COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs) / 1000.0, 2) as totalTime, COUNT(0) as count,"
        " ROUND(SUM(COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs) / 1000.0 / COUNT(0), 2) as avgTime,"
        " ROUND(max(COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs) / 1000.0, 2) as maxTime,"
        " ROUND(min(COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs) / 1000.0, 2) as minTime "
        " FROM COMMUNICATION_OP "
        " JOIN (SELECT DISTINCT deviceId, connectionId FROM ";
    // The device id comes from the task table, matched through connectionId.
    sql += TABLE_TASK;
    sql += ") "
           " AS NTASK ON NTASK.connectionId = COMMUNICATION_OP.connectionId "
           " WHERE NTASK.deviceId = ? "
           " GROUP BY opType "
           " ORDER by totalTime DESC LIMIT ?) subquery ";
    return sql;
}
/**
 * Builds the compute-operator statistic subquery restricted to one device.
 *
 * Placeholders, in order: device id, top-K limit.
 *
 * @param group comma-separated column list placed after GROUP BY
 * @return the SQL text, ending in the "subquery" alias so filters can be appended by the caller
 */
std::string DbSummaryDataBase::GenStatSqlWithDeviceId(const std::string group) {
    std::string sql = " SELECT * FROM ("
                      " SELECT "
                      " COMPUTE_TASK_INFO.opType as opType,"
                      " COMPUTE_TASK_INFO.name AS opName,"
                      " COMPUTE_TASK_INFO.inputShapes AS inputShape,"
                      " COMPUTE_TASK_INFO.taskType as accCore, "
                      " ROUND(SUM(TASK.endNs - TASK.startNs) / 1000.0, 2) as totalTime, COUNT(0) as count,"
                      " ROUND(SUM(endNs - startNs) / 1000.0 / COUNT(0), 2) as avgTime,"
                      " ROUND(max(endNs - startNs) / 1000.0, 2) as maxTime,"
                      " ROUND(min(endNs - startNs) / 1000.0, 2) as minTime"
                      " FROM COMPUTE_TASK_INFO "
                      " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
                      " WHERE TASK.deviceId = ? "
                      " GROUP BY ";
    sql += group;
    sql += " ORDER by totalTime DESC LIMIT ? "
           " ) subquery ";
    return sql;
}
/**
 * Builds the compute-operator statistic subquery across all devices (no device filter).
 *
 * Same projection as GenStatSqlWithDeviceId minus the WHERE clause. Placeholder: top-K limit.
 *
 * @param group comma-separated column list placed after GROUP BY
 * @return the SQL text, ending in the "subquery" alias so filters can be appended by the caller
 */
std::string DbSummaryDataBase::GenStatSql(const std::string group) {
    // The inner SELECT keyword is required: without it the parenthesised column list
    // is not a valid subquery and sqlite3_prepare_v2 rejects the statement.
    return " SELECT * FROM ("
           " SELECT "
           " COMPUTE_TASK_INFO.opType as opType,"
           " COMPUTE_TASK_INFO.name AS opName,"
           " COMPUTE_TASK_INFO.inputShapes AS inputShape,"
           " COMPUTE_TASK_INFO.taskType as accCore, "
           " ROUND(SUM(TASK.endNs - TASK.startNs) / 1000.0, 2) as totalTime, COUNT(0) as count,"
           " ROUND(SUM(endNs - startNs) / 1000.0 / COUNT(0), 2) as avgTime,"
           " ROUND(max(endNs - startNs) / 1000.0, 2) as maxTime,"
           " ROUND(min(endNs - startNs) / 1000.0, 2) as minTime"
           " FROM COMPUTE_TASK_INFO "
           " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " GROUP BY " +
           group +
           " ORDER by totalTime DESC LIMIT ? "
           " ) subquery ";
}
/**
 * Counts the rows the statistic query would return before paging.
 *
 * Communication groups reuse GenStatSqlWithCommunication and wrap it (after filters are
 * appended) in "SELECT COUNT(*) FROM (...)". Other groups count a grouped subquery directly.
 * Bind order: device id, top-K limit, range-filter values.
 *
 * @param reqParams supplies group, deviceId, topK and rangeFilters
 * @param total     receives the count; left untouched if the query yields no row
 * @return false when the statement cannot be prepared, true otherwise
 */
bool DbSummaryDataBase::QueryStatisticTotalNum(Protocol::OperatorStatisticReqParams &reqParams, int64_t &total) {
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql;
    if (!isCommunication) {
        const bool byOpType = operatorGroup == OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP;
        const std::string group = byOpType ? "opType, accCore" : "opName, inputShapes, accCore";
        sql = " SELECT COUNT(*) as nums"
              " FROM ( "
              " SELECT deviceId, startNs, endNs,"
              " COMPUTE_TASK_INFO.taskType AS accCore, "
              " COMPUTE_TASK_INFO.opType AS opType, COMPUTE_TASK_INFO.name AS opName FROM COMPUTE_TASK_INFO"
              " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
              " WHERE TASK.deviceId = ? "
              " GROUP by ";
        sql += group;
        sql += " ORDER by ROUND(SUM(TASK.endNs - TASK.startNs) / 1000.0, 2) DESC LIMIT ?"
               " ) subquery";
    } else {
        sql = GenStatSqlWithCommunication();
    }
    GenerateRangeQueryFiltersSql(reqParams.rangeFilters, sql);
    if (isCommunication) {
        // The communication statement returns rows, so it is wrapped to obtain a count.
        sql = StringUtil::FormatString("SELECT COUNT(*) FROM ({})", sql);
    }
    sqlite3_stmt *stmt = nullptr;
    const int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ServerLog::Error("Failed to get Duration Info. Msg: ", sqlite3_errmsg(db), " ", rc);
        return false;
    }
    int bindPos = bindStartIndex;
    sqlite3_bind_int64(stmt, bindPos++, StringUtil::StringToInt(reqParams.deviceId));
    sqlite3_bind_int64(stmt, bindPos++, reqParams.topK);
    BindIdList(reqParams.rangeFilters, stmt, bindPos);
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        total = sqlite3_column_int64(stmt, resultStartIndex);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Fills response with the total count, one page of operator detail rows, the PMU column
 * headers and the level derived from the fetched rows.
 *
 * Each fetched row is stored in the `compare` member of an OperatorDetailCmpInfoRes.
 *
 * @param reqParams request parameters
 * @param response  receives total, pmuHeaders, data and level
 * @return true when the required table is missing (nothing to do) or both queries succeeded;
 *         false when either query fails
 */
bool DbSummaryDataBase::QueryOperatorDetailInfo(
    Protocol::OperatorStatisticReqParams &reqParams, Protocol::OperatorDetailInfoResponse &response) {
    const bool tablePresent = CheckOperatorTableExist(reqParams.group, "operator detail info");
    if (!tablePresent) {
        return true;
    }
    const bool totalOk = QueryDetailTotalNum(reqParams, response.total);
    if (!totalOk) {
        ServerLog::Error("[Operator]Failed to query total num of detail info.");
        return false;
    }
    std::vector<Protocol::OperatorDetailInfoRes> rows;
    const std::string detailSql = GenerateQueryDetailSql(reqParams);
    if (!ExecSqlGetDetailInfo(detailSql, reqParams, rows)) {
        ServerLog::Error("Failed to exec query detail sql.");
        return false;
    }
    std::vector<Protocol::OperatorDetailCmpInfoRes> wrapped;
    for (const auto &row : rows) {
        OperatorDetailCmpInfoRes entry;
        entry.compare = row;
        wrapped.emplace_back(entry);
    }
    response.pmuHeaders = FetchPmuColumnNames();
    response.data = wrapped;
    response.level = OperatorGetLevel(rows);
    return true;
}
/**
 * Runs the "all rows" detail query and reports the level derived from the result.
 *
 * @param reqParams request parameters
 * @param res       receives the detail rows
 * @param level     set from OperatorGetLevel(res) only when the query succeeded
 * @return true when the required table is missing (nothing to do) or the query succeeded;
 *         false when executing the query fails
 */
bool DbSummaryDataBase::QueryAllOperatorDetailInfo(Protocol::OperatorStatisticReqParams &reqParams,
    std::vector<Protocol::OperatorDetailInfoRes> &res, std::string &level) {
    const bool tablePresent = CheckOperatorTableExist(reqParams.group, "all operator detail info");
    if (!tablePresent) {
        return true;
    }
    const std::string detailSql = GenerateAllQueryDetailSql(reqParams);
    const bool executed = ExecSqlGetDetailInfo(detailSql, reqParams, res);
    if (!executed) {
        ServerLog::Error("Failed to exec query detail sql.");
        return false;
    }
    level = OperatorGetLevel(res);
    return true;
}
/**
 * Checks that the table backing the requested operator group exists.
 *
 * Communication groups are served from COMMUNICATION_OP; every other group is served from
 * COMPUTE_TASK_INFO. A warning is logged when the relevant table is missing.
 *
 * @param group   operator group name from the request
 * @param logInfo short description of the running query, used in the warning
 * @return true when the table needed for this group exists, false otherwise
 */
bool DbSummaryDataBase::CheckOperatorTableExist(const std::string &group, const std::string &logInfo) {
    const bool isCommunication = OperatorGroupConverter::IsCommunication(group);
    if (isCommunication && !CheckTableExist(TABLE_COMMUNICATION_OP)) {
        Server::ServerLog::Warn(
            "Missing table % on querying %, nothing will be done.", TABLE_COMMUNICATION_OP, logInfo);
        return false;
    }
    // Only non-communication groups depend on the compute table. Testing it for communication
    // groups would reject valid communication-only databases and leave compute groups unchecked.
    if (!isCommunication && !CheckTableExist(TABLE_COMPUTE_TASK_INFO)) {
        Server::ServerLog::Warn(
            "Missing table % on querying %, nothing will be done.", TABLE_COMPUTE_TASK_INFO, logInfo);
        return false;
    }
    return true;
}
bool DbSummaryDataBase::ExecSqlGetDetailInfo(std::string sql, Protocol::OperatorStatisticReqParams &reqParams,
std::vector<Protocol::OperatorDetailInfoRes> &res) {
if (sql.empty()) {
ServerLog::Error("Failed to generate query statistic sql.");
return false;
}
sqlite3_stmt *stmt = nullptr;
int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
if (result != SQLITE_OK) {
ServerLog::Error("Failed to get Detail Info. Msg:", sqlite3_errmsg(db), " ", result);
return false;
}
uint64_t startTime = Timeline::TraceTime::Instance().GetStartTime();
bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
int index = bindStartIndex;
sqlite3_bind_int64(stmt, index++, NumberUtil::CeilingClamp(startTime, (uint64_t)INT64_MAX));
if (!isCommunication && !reqParams.deviceId.empty()) {
sqlite3_bind_int64(stmt, index++, StringUtil::StringToInt(reqParams.deviceId));
}
sqlite3_bind_int64(stmt, index++, reqParams.isCompare ? -1 : reqParams.topK);
BindQueryFilters(reqParams, stmt, index);
if (!reqParams.isCompare) {
sqlite3_bind_int64(stmt, index++, reqParams.pageSize);
sqlite3_bind_int64(stmt, index++, (reqParams.current - 1) * reqParams.pageSize);
}
while (sqlite3_step(stmt) == SQLITE_ROW) {
OperatorDetailInfoRes one = GetOperatorDetailRow(stmt);
res.emplace_back(one);
}
sqlite3_finalize(stmt);
return true;
}
/**
 * Decodes the current row of a detail statement into an OperatorDetailInfoRes.
 *
 * Columns are consumed strictly left to right starting at 0: fifteen fixed columns followed
 * by one column per entry of pmuColumns_. A wait time that converts to "NULL" is reported
 * as "Loading".
 *
 * @param stmt statement positioned on a row
 * @return the decoded row
 */
OperatorDetailInfoRes DbSummaryDataBase::GetOperatorDetailRow(sqlite3_stmt *stmt) {
    int col = 0;
    // Both helpers advance the shared column cursor, so call order defines the column mapping.
    auto nextText = [&]() { return sqlite3_column_string(stmt, col++); };
    auto nextAs = [&](int sqliteType) { return Sqlite3ColumnConvertStr(sqliteType, stmt, col++); };

    OperatorDetailInfoRes row{};
    row.rankId = nextText();
    row.stepId = nextText();
    row.name = nextText();
    row.type = nextText();
    row.accCore = nextText();
    row.startTime = nextText();
    row.duration = nextAs(SQLITE_FLOAT);
    const std::string waitText = Sqlite3ColumnConvertStrReturnNull(SQLITE_FLOAT, stmt, col++);
    if (waitText == "NULL") {
        row.waitTime = "Loading";
    } else {
        row.waitTime = waitText;
    }
    row.blockNum = nextAs(SQLITE_INTEGER);
    row.inputShape = nextText();
    row.inputType = nextText();
    row.inputFormat = nextText();
    row.outputShape = nextText();
    row.outputType = nextText();
    row.outputFormat = nextText();
    for (const auto &pmuName : pmuColumns_) {
        row.pmuDatas[pmuName] = nextText();
    }
    return row;
}
/**
 * Counts the rows matching a "more info" request before paging.
 *
 * Bind order: device id and accCore (non-communication groups only), then opType for the
 * op-type and communication-type groups or opName for every other group, then shape for the
 * op-name and input-shape groups, then filter values via BindQueryFilters.
 *
 * @param reqParams supplies group, deviceId, accCore, opType, opName, shape and filters
 * @param total     receives the count; left untouched if the query yields no row
 * @return false when the statement cannot be prepared, true otherwise
 */
bool DbSummaryDataBase::QueryMoreInfoTotalNum(OperatorMoreInfoReqParams &reqParams, int64_t &total) {
    using Group = OperatorGroupConverter::OperatorGroup;
    const Group operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql;
    if (!isCommunication) {
        GenerateMoreInfoTotalNumForOther(sql, operatorGroup);
    } else {
        sql = "SELECT COUNT(*) as nums FROM ( SELECT TYPE.value as type, Name.value as name FROM COMMUNICATION_OP"
              " JOIN STRING_IDS AS NAME ON NAME.id = COMMUNICATION_OP.opName "
              " JOIN STRING_IDS AS TYPE ON TYPE.id = COMMUNICATION_OP.opType"
              " WHERE type = ? ) subquery";
    }
    GenerateQueryFiltersSql<OperatorMoreInfoReqParams>(reqParams, sql);
    sqlite3_stmt *stmt = nullptr;
    const int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ServerLog::Error("Failed to get More Total Num. Msg: ", sqlite3_errmsg(db), " ", rc);
        return false;
    }
    const bool keyedByType = operatorGroup == Group::OP_TYPE_GROUP || operatorGroup == Group::COMMUNICATION_TYPE_GROUP;
    const bool needsShape = operatorGroup == Group::OP_NAME_GROUP || operatorGroup == Group::OP_INPUT_SHAPE_GROUP;
    int bindPos = bindStartIndex;
    if (!isCommunication) {
        const int deviceId = StringUtil::StringToInt(reqParams.deviceId);
        sqlite3_bind_int64(stmt, bindPos++, deviceId);
        sqlite3_bind_text(stmt, bindPos++, reqParams.accCore.c_str(), -1, SQLITE_TRANSIENT);
    }
    const std::string &key = keyedByType ? reqParams.opType : reqParams.opName;
    sqlite3_bind_text(stmt, bindPos++, key.c_str(), -1, SQLITE_TRANSIENT);
    if (needsShape) {
        sqlite3_bind_text(stmt, bindPos++, reqParams.shape.c_str(), -1, SQLITE_TRANSIENT);
    }
    BindQueryFilters(reqParams, stmt, bindPos);
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        total = sqlite3_column_int64(stmt, resultStartIndex);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Builds the paged "more info" SELECT for the requested operator group.
 *
 * The base statement comes from the HCCL or the generic generator. A "type = ?" or "name = ?"
 * condition is appended depending on the group, plus "inputShape = ?" for the op-name and
 * input-shape groups, then the request filters, an optional ORDER BY and "LIMIT ? OFFSET ?".
 *
 * orderBy is spliced into the SQL text, so it is validated with StringUtil::CheckSqlValid
 * here; a rejected value is logged and the ORDER BY clause is omitted, as in the other
 * generators in this file. NOTE(review): GenerateQueryMoreInfoFilters is not visible here;
 * if it already validates orderBy this check is redundant but harmless.
 *
 * @param reqParams supplies group, filters, orderBy and order
 * @return the SQL text, or an empty string when the filters cannot be generated
 */
std::string DbSummaryDataBase::GenerateQueryMoreInfoSql(OperatorMoreInfoReqParams &reqParams) {
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql;
    if (isCommunication) {
        sql = GenerateQueryMoreInfoSqlForHCCL(sql);
    } else {
        sql = GenerateQueryMoreInfoSqlForOther(sql);
    }
    if (operatorGroup == OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP ||
        operatorGroup == OperatorGroupConverter::OperatorGroup::COMMUNICATION_TYPE_GROUP) {
        sql += " WHERE type = ?";
    } else {
        sql += " WHERE name = ? ";
    }
    if (operatorGroup == OperatorGroupConverter::OperatorGroup::OP_NAME_GROUP ||
        operatorGroup == OperatorGroupConverter::OperatorGroup::OP_INPUT_SHAPE_GROUP) {
        sql += "AND inputShape = ?";
    }
    if (!GenerateQueryMoreInfoFilters(reqParams, sql)) {
        return "";
    }
    if (!StringUtil::CheckSqlValid(reqParams.orderBy)) {
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy to generate more info sql.");
    } else if (!reqParams.orderBy.empty() && !reqParams.order.empty()) {
        // "blockNum" is the protocol name; the physical column differs between database versions.
        sql += " ORDER by " + (reqParams.orderBy == "blockNum" ? blockNumColumnName : reqParams.orderBy) + " " +
               (reqParams.order == "ascend" ? "ASC" : "DESC");
    }
    sql += " LIMIT ? OFFSET ?";
    return sql;
}
/**
 * Fills response with the total count, one page of "more info" rows, the level derived from
 * those rows and the PMU column headers.
 *
 * @param reqParams request parameters
 * @param response  receives total, level, data and pmuHeaders
 * @return false when the count query fails, no SQL could be generated, or the statement
 *         cannot be prepared; true otherwise
 */
bool DbSummaryDataBase::QueryOperatorMoreInfo(
    OperatorMoreInfoReqParams &reqParams, OperatorMoreInfoResponse &response) {
    if (!QueryMoreInfoTotalNum(reqParams, response.total)) {
        ServerLog::Error("[Operator]Failed to query total num of more info.");
        return false;
    }
    std::string sql = GenerateQueryMoreInfoSql(reqParams);
    // GenerateQueryMoreInfoSql signals failure with "". Preparing empty SQL yields a NULL
    // statement, so every later bind/step would be a misuse silently reported as success.
    if (sql.empty()) {
        ServerLog::Error("Failed to generate query more info sql.");
        return false;
    }
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get Op More Info. Msg: ", sqlite3_errmsg(db), " ", result);
        return false;
    }
    BindSqliteParam(stmt, reqParams);
    std::vector<OperatorDetailInfoRes> res;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        // Same column layout as the detail query, so the shared row decoder is reused.
        res.emplace_back(GetOperatorDetailRow(stmt));
    }
    sqlite3_finalize(stmt);
    response.level = OperatorGetLevel(res);
    response.data = res;
    response.pmuHeaders = pmuColumns_;
    return true;
}
/**
 * Binds the parameters of the paged "more info" statement.
 *
 * Bind order: start-time baseline, device id, accCore (non-communication groups only),
 * opType for the op-type and communication-type groups or opName for every other group,
 * shape for the op-name and input-shape groups, page size, row offset.
 *
 * @param stmt      prepared statement to bind against
 * @param reqParams supplies the bound values
 */
void DbSummaryDataBase::BindSqliteParam(sqlite3_stmt *stmt, Protocol::OperatorMoreInfoReqParams &reqParams) {
    using Group = OperatorGroupConverter::OperatorGroup;
    const uint64_t traceStart = Timeline::TraceTime::Instance().GetStartTime();
    int bindPos = bindStartIndex;
    // Clamp so the unsigned start time fits the signed 64-bit value SQLite binds.
    sqlite3_bind_int64(stmt, bindPos++, NumberUtil::CeilingClamp(traceStart, static_cast<uint64_t>(INT64_MAX)));
    const Group operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    const int deviceId = StringUtil::StringToInt(reqParams.deviceId);
    sqlite3_bind_int64(stmt, bindPos++, deviceId);
    if (!isCommunication) {
        sqlite3_bind_text(stmt, bindPos++, reqParams.accCore.c_str(), -1, SQLITE_TRANSIENT);
    }
    const bool keyedByType = operatorGroup == Group::OP_TYPE_GROUP || operatorGroup == Group::COMMUNICATION_TYPE_GROUP;
    const std::string &key = keyedByType ? reqParams.opType : reqParams.opName;
    sqlite3_bind_text(stmt, bindPos++, key.c_str(), -1, SQLITE_TRANSIENT);
    const bool needsShape = operatorGroup == Group::OP_NAME_GROUP || operatorGroup == Group::OP_INPUT_SHAPE_GROUP;
    if (needsShape) {
        sqlite3_bind_text(stmt, bindPos++, reqParams.shape.c_str(), -1, SQLITE_TRANSIENT);
    }
    sqlite3_bind_int64(stmt, bindPos++, reqParams.pageSize);
    sqlite3_bind_int64(stmt, bindPos++, (reqParams.current - 1) * reqParams.pageSize);
}
/**
 * Runs the paged query built by GetCommSql and appends the rows to commDetails.
 *
 * Bind order follows the placeholders emitted by GetCommSql: start-time baseline,
 * task type value (params.timeFlag), page size, row offset.
 *
 * @param params      request parameters; timeFlag, currentPage and pageSize are read here
 * @param commDetails rows are appended; existing content is kept
 * @return false when the statement cannot be prepared, true otherwise
 */
bool DbSummaryDataBase::QueryCommunicationOpDetail(
    Protocol::CommunicationDetailParams params, std::vector<Protocol::CommunicationDetail> &commDetails) {
    std::string sql = GetCommSql(params);
    uint64_t startTime = Timeline::TraceTime::Instance().GetStartTime();
    double offset = (params.currentPage - 1) * params.pageSize;
    sqlite3_stmt *stmt = nullptr;
    int index = bindStartIndex;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Query common detail failed! Failed to prepare sql.", sqlite3_errmsg(db));
        return false;
    }
    // Clamp so the unsigned start time fits the signed 64-bit value SQLite binds.
    sqlite3_bind_int64(stmt, index++, NumberUtil::CeilingClamp(startTime, static_cast<uint64_t>(INT64_MAX)));
    // A null destructor means SQLite does not copy the text; params outlives the statement.
    sqlite3_bind_text(stmt, index++, params.timeFlag.c_str(), params.timeFlag.length(), nullptr);
    sqlite3_bind_double(stmt, index++, params.pageSize);
    sqlite3_bind_double(stmt, index++, offset);
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        int col = resultStartIndex;
        CommunicationDetail detail{};
        detail.name = sqlite3_column_string(stmt, col++);
        detail.type = sqlite3_column_string(stmt, col++);
        detail.startTime = sqlite3_column_string(stmt, col++);
        detail.duration = sqlite3_column_double(stmt, col++);
        detail.waitTime = sqlite3_column_double(stmt, col++);
        commDetails.emplace_back(detail);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Builds the paged SELECT used by QueryCommunicationOpDetail.
 *
 * The inner query reads COMPUTE_TASK_INFO joined with TASK and STRING_IDS, with waitTime
 * fixed at 0. Placeholders, in order: start-time baseline, task type value, LIMIT, OFFSET.
 * An ORDER BY clause is appended only when request.orderBy passes StringUtil::CheckSqlValid
 * and both orderBy and order are non-empty; a rejected orderBy is logged and ordering is skipped.
 *
 * @param request supplies orderBy and order
 * @return the SQL text
 */
std::string DbSummaryDataBase::GetCommSql(const CommunicationDetailParams &request) {
    std::string sql = "SELECT name, op_type as type, CASE WHEN start_time == 0 THEN 'NA' "
                      "ELSE ROUND((start_time - ?) / (1000.0 * 1000.0), 4) END AS startTime, "
                      "ROUND(duration, 4) as duration, ROUND(waitTime, 4) as waitTime FROM ( "
                      " SELECT NAME.value AS name, OPTYPE.value AS op_type, "
                      " startNs as start_time, ROUND((endNs - startNs)/1000.0, 3) as duration, "
                      " TASKTYPE.value AS taskTypes, 0 as waitTime FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN STRING_IDS AS NAME ON NAME.id = COMPUTE_TASK_INFO.name"
           " JOIN STRING_IDS AS OPTYPE ON OPTYPE.id = COMPUTE_TASK_INFO.opType"
           " JOIN STRING_IDS AS TASKTYPE ON TASKTYPE.id = COMPUTE_TASK_INFO.taskType"
           " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " WHERE taskTypes = ?"
           " GROUP BY TASK.globalTaskId"
           " ) subquery ";
    if (StringUtil::CheckSqlValid(request.orderBy)) {
        if (!request.orderBy.empty() && !request.order.empty()) {
            // Anything other than "ascend" sorts descending.
            const char *direction = request.order == "ascend" ? "ASC" : "DESC";
            sql += " ORDER by " + request.orderBy + " " + direction;
        }
    } else {
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy to get common sql.");
    }
    sql += " LIMIT ? offset ?";
    return sql;
}
/**
 * Builds the per-category duration query for the requested operator group.
 *
 * Communication groups are delegated to GenerateQueryCategoryDurationSqlForHCCL. For compute
 * groups the name expression, duration expression and GROUP BY clause depend on the group:
 * the op-type group sums per (op_type, task_type), the op-name group lists each task without
 * grouping, and every other group sums per (name, input_shapes, task_type) with the name and
 * input shape ids joined by "_". Placeholders, in order: device id, top-K limit.
 *
 * @param reqParams supplies group
 * @return the SQL text, or an empty string when the group is unknown
 */
std::string DbSummaryDataBase::GenerateQueryCategoryDurationSql(Protocol::OperatorDurationReqParams &reqParams) {
    using Group = OperatorGroupConverter::OperatorGroup;
    const Group operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    if (operatorGroup == Group::UNKNOWN) {
        ServerLog::Error("Generate query category duration sql failed, unknown operator group.");
        return "";
    }
    if (Protocol::OperatorGroupConverter::IsCommunication(reqParams.group)) {
        return GenerateQueryCategoryDurationSqlForHCCL(operatorGroup);
    }
    std::string nameExpr;
    std::string durationExpr;
    std::string groupClause;
    switch (operatorGroup) {
        case Group::OP_TYPE_GROUP:
            nameExpr = "COMPUTE_TASK_INFO.opType";
            groupClause = " GROUP by op_type, task_type";
            durationExpr = " ROUND(sum(endNs - startNs)/1000.0, 2) as duration";
            break;
        case Group::OP_NAME_GROUP:
            nameExpr = "COMPUTE_TASK_INFO.name";
            groupClause = "";
            durationExpr = " ROUND((endNs - startNs)/1000.0, 2) as duration";
            break;
        default:
            nameExpr = "COMPUTE_TASK_INFO.name || '_' || COMPUTE_TASK_INFO.inputShapes";
            groupClause = " GROUP by name, input_shapes, task_type";
            durationExpr = " ROUND(sum(endNs - startNs)/1000.0, 2) as duration";
            break;
    }
    std::string sql = " SELECT name, duration From ("
                      " SELECT ";
    sql += nameExpr;
    sql += " as name,";
    sql += durationExpr;
    sql += " ,COMPUTE_TASK_INFO.taskType as task_type, COMPUTE_TASK_INFO.opType as op_type, "
           "COMPUTE_TASK_INFO.inputShapes as input_shapes"
           " FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " WHERE TASK.deviceId = ? ";
    sql += groupClause;
    sql += " ORDER BY duration DESC LIMIT ?"
           " ) subquery";
    return sql;
}
/**
 * Builds the query that sums operator durations per task type (compute unit).
 *
 * The inner query selects the top-K entries for one device: per task for the operator group,
 * otherwise summed per the group's column list. The outer query then sums those durations
 * per task type. Placeholders, in order: device id, top-K limit.
 *
 * @param reqParams supplies group
 * @return the SQL text
 */
std::string DbSummaryDataBase::GenerateQueryComputeUnitDurationSql(Protocol::OperatorDurationReqParams &reqParams) {
    std::string groupColumns;
    if (reqParams.group == OP_TYPE_GROUP) {
        groupColumns = "COMPUTE_TASK_INFO.opType, COMPUTE_TASK_INFO.taskType";
    } else if (reqParams.group == OPERATOR_GROUP) {
        groupColumns = "COMPUTE_TASK_INFO.name, COMPUTE_TASK_INFO.taskType";
    } else {
        groupColumns = "COMPUTE_TASK_INFO.name, COMPUTE_TASK_INFO.inputShapes, COMPUTE_TASK_INFO.taskType";
    }
    // The operator group reports each task individually, so it neither sums nor groups.
    const bool perTask = reqParams.group == Protocol::OPERATOR_GROUP;
    std::string sql = " SELECT taskTypes as name, ROUND(SUM(duration), 2) as duration"
                      " FROM ("
                      " SELECT ";
    sql += groupColumns;
    sql += ", COMPUTE_TASK_INFO.taskType as taskTypes, ";
    sql += perTask ? " ROUND((endNs - startNs)/1000.0, 2) as duration"
                   : " ROUND(sum(endNs - startNs)/1000.0, 2) as duration";
    sql += " FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " WHERE TASK.deviceId = ?"
           " ";
    if (!perTask) {
        sql += " GROUP BY " + groupColumns;
    }
    sql += " ORDER BY duration DESC LIMIT ?"
           " ) subquery";
    sql += " GROUP by taskTypes"
           " ORDER BY duration DESC";
    return sql;
}
// Counts the rows of the detail table for the requested group after the top-K cut and the
// LIKE filters from reqParams.filters are applied.
// @param reqParams request parameters; group, deviceId, topK and filters are read.
// @param total     receives the row count; left untouched when the query yields no row.
// @return false when the statement cannot be prepared or stepping it fails, true otherwise.
bool DbSummaryDataBase::QueryDetailTotalNum(OperatorStatisticReqParams &reqParams, int64_t &total) {
    bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql;
    if (isCommunication) {
        sql = GenerateQueryDetailSqlForHCCL(sql);
    } else {
        sql = " SELECT COUNT(*) as nums"
              " FROM ("
              " SELECT ROUND((endNs - startNs)/1000.0, 3) as duration, " +
              blockNumColumnName +
              " , deviceId, streamId as step_id,NAME.value AS name,"
              " OPTYPE.value AS type,TASKTYPE.value as accCore, startNs as startTime"
              " FROM " +
              TABLE_COMPUTE_TASK_INFO +
              " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
              " JOIN STRING_IDS AS TASKTYPE ON TASKTYPE.id = COMPUTE_TASK_INFO.taskType"
              " JOIN STRING_IDS AS OPTYPE ON OPTYPE.id = COMPUTE_TASK_INFO.opType"
              " JOIN STRING_IDS AS NAME ON NAME.id = COMPUTE_TASK_INFO.name"
              " WHERE TASK.deviceId = ? "
              " ORDER BY duration DESC LIMIT ?"
              " ) subquery";
    }
    GenerateQueryFiltersSql<OperatorStatisticReqParams>(reqParams, sql);
    if (isCommunication) {
        // The HCCL detail query returns rows, so wrap it to obtain a count.
        sql = StringUtil::FormatString("SELECT COUNT(*) FROM ({})", sql);
    }
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get Detail Total Num. Msg: ", sqlite3_errmsg(db), " ", result);
        return false;
    }
    int index = bindStartIndex;
    int deviceId = StringUtil::StringToInt(reqParams.deviceId);
    // NOTE(review): in the communication branch the first placeholder of the HCCL detail SQL sits
    // inside the startTime expression rather than a device filter, so deviceId lands there; that
    // does not affect the count but should be confirmed as intended.
    sqlite3_bind_int64(stmt, index++, deviceId);
    sqlite3_bind_int64(stmt, index++, reqParams.topK);
    BindQueryFilters(reqParams, stmt, index);
    int stepResult = sqlite3_step(stmt);
    while (stepResult == SQLITE_ROW) {
        total = sqlite3_column_int64(stmt, resultStartIndex);
        stepResult = sqlite3_step(stmt);
    }
    // Anything other than SQLITE_DONE means the statement stopped on an error; report it instead
    // of returning success with a stale count.
    const bool finished = (stepResult == SQLITE_DONE);
    if (!finished) {
        ServerLog::Error("Failed to step Detail Total Num query. Msg: ", sqlite3_errmsg(db), " ", stepResult);
    }
    sqlite3_finalize(stmt);
    return finished;
}
// Returns the PMU metric names used as extra columns, caching them in pmuColumns_.
// The cache is returned directly when it is already filled or when the PMU table is missing.
// Otherwise the names recorded for one randomly chosen globalTaskId are read and each one is
// checked against Summary::PMU_HEADER_WHITE_LIST_REG, because callers splice these names into SQL.
// @return the cached names; an empty set when a name fails the whitelist check; the (empty)
//         cache when the statement cannot be prepared.
std::set<std::string> DbSummaryDataBase::FetchPmuColumnNames() {
    if (!CheckTableExist(TABLE_TASK_PMU_INFO) || !pmuColumns_.empty()) {
        return pmuColumns_;
    }
    std::string queryColumnSql = "SELECT STRING_IDS.value "
                                 "FROM STRING_IDS "
                                 "WHERE STRING_IDS.id IN ( "
                                 " SELECT name "
                                 " FROM TASK_PMU_INFO "
                                 " WHERE globalTaskId = ( "
                                 " SELECT globalTaskId "
                                 " FROM TASK_PMU_INFO "
                                 " ORDER BY RANDOM() "
                                 " LIMIT 1 "
                                 " ) "
                                 ");";
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, queryColumnSql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get pmu cols Info. Msg:", sqlite3_errmsg(db), " ", result);
        return pmuColumns_;
    }
    // Collect into a local set first: if any name is rejected, the cache must stay empty so a
    // later call does not hand out a partially filled column list as if it were complete.
    std::set<std::string> fetched;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        std::string colName = sqlite3_column_string(stmt, 0);
        if (!RegexUtil::RegexMatch(colName, Summary::PMU_HEADER_WHITE_LIST_REG)) {
            sqlite3_finalize(stmt);
            ServerLog::Error("There is an SQL injection attack on colName. error colName: %", colName);
            return {};
        }
        fetched.insert(colName);
    }
    sqlite3_finalize(stmt);
    // pmuColumns_ is known to be empty here (checked on entry), so this publishes exactly `fetched`.
    pmuColumns_.insert(fetched.begin(), fetched.end());
    return pmuColumns_;
}
// Builds a LEFT JOIN clause that pivots TASK_PMU_INFO rows into one column per PMU name,
// keyed by globalTaskId and exposed under the alias PMU. Returns "" when cols is empty.
std::string DbSummaryDataBase::CreatPMUTmpTableSql(const std::set<std::string> &cols) {
    if (cols.empty()) {
        return "";
    }
    std::string pivotSql = "SELECT i.globalTaskId ";
    for (const auto &pmuName : cols) {
        // One MAX(CASE ...) per PMU name turns the name/value rows into a named column.
        pivotSql += ", MAX(CASE WHEN s.value = '";
        pivotSql += pmuName;
        pivotSql += "' THEN i.value END) AS ";
        pivotSql += pmuName;
    }
    pivotSql += " FROM TASK_PMU_INFO i "
                " JOIN STRING_IDS s ON i.name = s.id "
                " GROUP BY i.globalTaskId ";

    std::string joinClause = " LEFT JOIN ( ";
    joinClause += pivotSql;
    joinClause += " ) AS PMU ON COMPUTE_TASK_INFO.globalTaskId = PMU.globalTaskId ";
    return joinClause;
}
// Produces the select-list fragment for the PMU columns: every name prefixed with "PMU.",
// joined by commas and preceded by a leading comma. Returns "" when cols is empty.
std::string DbSummaryDataBase::GetPMUTmpTableColSql(const std::set<std::string> &cols) {
    std::string columnSql;
    if (!cols.empty()) {
        std::vector<std::string> qualified;
        qualified.reserve(cols.size());
        for (const auto &pmuName : cols) {
            qualified.emplace_back("PMU." + pmuName);
        }
        columnSql = "," + StringUtil::join(qualified, ",");
    }
    return columnSql;
}
// Builds the detail query for compute operators, including any PMU columns.
// Placeholders, in order: the value subtracted from startTime, the device id, the LIMIT.
std::string DbSummaryDataBase::GenerateQueryDetailSqlForOperator() {
    const std::set<std::string> pmuCols = FetchPmuColumnNames();

    // Outer projection: converts startTime and exposes the columns in display order.
    std::string sql =
        " SELECT deviceId, step_id, name, type, accCore,"
        " CASE WHEN startTime == 0 THEN 'NA' ELSE ROUND((startTime - ?) / (1000.0 * 1000.0), 2) END AS startTime, "
        " duration, waitTime, ";
    sql += blockNumColumnName;
    sql += ","
           " inputShape, inputType, inputFormat, outputShape, outputType, outputFormat ";
    sql += JoinExtraColName(std::vector<std::string>(pmuCols.begin(), pmuCols.end()));

    // Inner query: resolves string ids and selects the longest-running tasks of the device.
    sql += " FROM ("
           " SELECT ";
    sql += blockNumColumnName;
    sql += ", deviceId, streamId as step_id,NAME.value AS name,"
           " OPTYPE.value AS type,TASKTYPE.value as accCore, startNs as startTime, "
           " ROUND((endNs - startNs)/1000.0, 3) as duration, ROUND(waitNs/1000.0, 3) as waitTime, "
           " INPUTSHAPES.value as inputShape, INPUTDATATYPES.value as inputType, "
           " INPUTFORMATS.value as inputFormat, OUTPUTSHAPES.value as outputShape, "
           " OUTPUTDATATYPES.value as outputType, OUTPUTFORMATS.value as outputFormat ";
    sql += GetPMUTmpTableColSql(pmuCols);
    sql += " FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " JOIN STRING_IDS AS NAME ON NAME.id = COMPUTE_TASK_INFO.name"
           " JOIN STRING_IDS AS OPTYPE ON OPTYPE.id = COMPUTE_TASK_INFO.opType"
           " JOIN STRING_IDS AS TASKTYPE ON TASKTYPE.id = COMPUTE_TASK_INFO.taskType"
           " JOIN STRING_IDS AS INPUTSHAPES ON INPUTSHAPES.id = COMPUTE_TASK_INFO.inputShapes"
           " JOIN STRING_IDS AS INPUTDATATYPES ON INPUTDATATYPES.id = COMPUTE_TASK_INFO.inputDataTypes"
           " JOIN STRING_IDS AS INPUTFORMATS ON INPUTFORMATS.id = COMPUTE_TASK_INFO.inputFormats"
           " JOIN STRING_IDS AS OUTPUTSHAPES ON OUTPUTSHAPES.id = COMPUTE_TASK_INFO.outputShapes"
           " JOIN STRING_IDS AS OUTPUTDATATYPES ON OUTPUTDATATYPES.id = COMPUTE_TASK_INFO.outputDataTypes"
           " JOIN STRING_IDS AS OUTPUTFORMATS ON OUTPUTFORMATS.id = COMPUTE_TASK_INFO.outputFormats ";
    sql += CreatPMUTmpTableSql(pmuCols);
    sql += " WHERE TASK.deviceId = ? "
           " ORDER by duration DESC LIMIT ? ) subquery ";
    return sql;
}
// Builds the unpaged detail query for the requested group: the HCCL query for communication
// groups, the compute-operator query otherwise, followed by the LIKE filters and, when
// orderBy/order are both set and orderBy passes CheckSqlValid, an ORDER BY clause.
// An orderBy that fails validation is logged and the ORDER BY clause is simply omitted.
std::string DbSummaryDataBase::GenerateAllQueryDetailSql(OperatorStatisticReqParams &reqParams) {
    bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql;
    if (isCommunication) {
        sql = GenerateQueryDetailSqlForHCCL(sql);
    } else {
        sql = GenerateQueryDetailSqlForOperator();
    }
    GenerateQueryFiltersSql<OperatorStatisticReqParams>(reqParams, sql);
    if (!StringUtil::CheckSqlValid(reqParams.orderBy)) {
        // The two literals are concatenated by the compiler; the trailing space keeps the words apart.
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy "
                         "to generate all query detail sql.");
    } else if (!reqParams.orderBy.empty() && !reqParams.order.empty()) {
        // "blockNum" is mapped to whichever column name this database actually uses.
        sql += " ORDER by " + (reqParams.orderBy == "blockNum" ? blockNumColumnName : reqParams.orderBy) + " " +
               (reqParams.order == "ascend" ? "ASC" : "DESC");
    }
    return sql;
}
// Returns the detail query with paging placeholders (LIMIT, OFFSET) appended.
// An empty base query is passed through unchanged.
std::string DbSummaryDataBase::GenerateQueryDetailSql(OperatorStatisticReqParams &reqParams) {
    std::string pagedSql = GenerateAllQueryDetailSql(reqParams);
    if (!pagedSql.empty()) {
        pagedSql += " LIMIT ? OFFSET ?";
    }
    return pagedSql;
}
// Builds the top-N duration query for communication operators of one device.
// COMMUNICATION_TYPE_GROUP sums durations per opType; COMMUNICATION_NAME_GROUP lists each
// operator by opName. For any other group the name/duration expressions stay empty.
// Placeholders, in order: device id, LIMIT.
std::string DbSummaryDataBase::GenerateQueryCategoryDurationSqlForHCCL(
    const OperatorGroupConverter::OperatorGroup &operatorGroup) {
    using Group = OperatorGroupConverter::OperatorGroup;

    std::string nameExpr;
    std::string durationExpr;
    std::string groupClause;
    if (operatorGroup == Group::COMMUNICATION_TYPE_GROUP) {
        nameExpr = " COMMUNICATION_OP.opType as name ";
        groupClause = " GROUP by name ";
        durationExpr = " ROUND(sum(COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs)/1000.0, 2) as duration";
    } else if (operatorGroup == Group::COMMUNICATION_NAME_GROUP) {
        nameExpr = "COMMUNICATION_OP.opName as name";
        durationExpr = " ROUND((COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs)/1000.0, 2) as duration";
    }

    std::string sql = " SELECT name, duration From ("
                      " SELECT ";
    sql += nameExpr;
    sql += ",";
    sql += durationExpr;
    sql += " FROM ";
    sql += TABLE_COMMUNICATION_OP;
    sql += " JOIN (SELECT DISTINCT deviceId, connectionId FROM ";
    sql += TABLE_TASK;
    sql += ") "
           " AS NTASK ON NTASK.connectionId = COMMUNICATION_OP.connectionId "
           " WHERE NTASK.deviceId = ? ";
    sql += groupClause;
    sql += " ORDER BY duration DESC LIMIT ?"
           " ) subquery";
    return sql;
}
// Overwrites sql with the "more info" query for communication operators and returns it.
// Columns that only exist for compute operators are projected as NULL.
// Placeholders, in order: the value subtracted from startTime, the device id.
std::string &DbSummaryDataBase::GenerateQueryMoreInfoSqlForHCCL(std::string &sql) const {
    // Outer projection.
    sql = " SELECT rank_id, step_id, name, type, accCore,"
          " CASE WHEN startTime == 0 THEN 'NA' ELSE ROUND((startTime - ?) / (1000.0 * 1000.0), 2)"
          " END AS startTime, duration, waitTime, NULL AS ";
    sql += blockNumColumnName;
    sql += ", NULL AS inputShape,"
           " NULL AS inputType, NULL AS inputFormat, NULL AS outputShape, NULL AS outputType,"
           " NULL AS outputFormat"
           // Inner query over COMMUNICATION_OP with names/types resolved through STRING_IDS.
           " FROM ("
           " SELECT NULL as rank_id, NULL as step_id, NAME.value AS name,"
           " TYPE.value as type,"
           " NULL as accCore,COMMUNICATION_OP.startNs as startTime,"
           " ROUND((COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs)/1000.0, 3) as duration,"
           " ROUND(COMMUNICATION_OP.waitNs/1000.0, 3) as waitTime FROM COMMUNICATION_OP"
           " JOIN STRING_IDS AS NAME ON NAME.id = COMMUNICATION_OP.opName"
           " JOIN STRING_IDS AS TYPE ON TYPE.id = COMMUNICATION_OP.opType"
           " JOIN (SELECT DISTINCT deviceId, connectionId FROM ";
    sql += TABLE_TASK;
    sql += ") "
           " AS NTASK ON NTASK.connectionId = COMMUNICATION_OP.connectionId "
           " WHERE NTASK.deviceId = ? "
           " ORDER by duration DESC ) subquery ";
    return sql;
}
// Overwrites sql with the "more info" query for compute operators (PMU columns included)
// and returns it.
// Placeholders, in order: the value subtracted from startTime, the device id, accCore.
std::string &DbSummaryDataBase::GenerateQueryMoreInfoSqlForOther(std::string &sql) {
    const std::set<std::string> pmuCols = FetchPmuColumnNames();

    // Outer projection.
    sql = " SELECT device_id, step_id, name, type, accCore,"
          " CASE WHEN startTime == 0 THEN 'NA' ELSE ROUND((startTime - ?) / (1000.0 * 1000.0), 2)"
          " END AS startTime, duration, waitTime, ";
    sql += blockNumColumnName;
    sql += ","
           " inputShape, inputType, inputFormat, outputShape, outputType, outputFormat ";
    sql += JoinExtraColName(std::vector<std::string>(pmuCols.begin(), pmuCols.end()));

    // Inner query with string ids resolved.
    sql += " FROM ("
           " SELECT ";
    sql += blockNumColumnName;
    sql += ", deviceId as device_id, streamId as step_id, "
           " NAME.value AS name, OPTYPE.value AS type,"
           " TASKTYPE.value as accCore,startNs as startTime,ROUND((endNs - startNs)/1000.0, 3) as duration,"
           " ROUND((waitNs)/1000.0, 3) as waitTime, INPUTSHAPES.value as inputShape, "
           " INPUTDATATYPES.value as inputType, "
           " INPUTFORMATS.value as inputFormat, OUTPUTSHAPES.value as outputShape, "
           " OUTPUTDATATYPES.value as outputType, OUTPUTFORMATS.value as outputFormat ";
    sql += GetPMUTmpTableColSql(pmuCols);
    sql += " FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId "
           " JOIN STRING_IDS AS NAME ON NAME.id = COMPUTE_TASK_INFO.name"
           " JOIN STRING_IDS AS OPTYPE ON OPTYPE.id = COMPUTE_TASK_INFO.opType"
           " JOIN STRING_IDS AS TASKTYPE ON TASKTYPE.id = COMPUTE_TASK_INFO.taskType"
           " JOIN STRING_IDS AS INPUTSHAPES ON INPUTSHAPES.id = COMPUTE_TASK_INFO.inputShapes"
           " JOIN STRING_IDS AS INPUTDATATYPES ON INPUTDATATYPES.id = COMPUTE_TASK_INFO.inputDataTypes"
           " JOIN STRING_IDS AS INPUTFORMATS ON INPUTFORMATS.id = COMPUTE_TASK_INFO.inputFormats"
           " JOIN STRING_IDS AS OUTPUTSHAPES ON OUTPUTSHAPES.id = COMPUTE_TASK_INFO.outputShapes"
           " JOIN STRING_IDS AS OUTPUTDATATYPES ON OUTPUTDATATYPES.id = COMPUTE_TASK_INFO.outputDataTypes"
           " JOIN STRING_IDS AS OUTPUTFORMATS ON OUTPUTFORMATS.id = COMPUTE_TASK_INFO.outputFormats ";
    sql += CreatPMUTmpTableSql(pmuCols);
    sql += " WHERE TASK.deviceId = ? AND accCore = ?"
           " ORDER by duration DESC ) subquery ";
    return sql;
}
// Overwrites sql with the top-N detail query for communication operators and returns it.
// Columns that only exist for compute operators are projected as NULL.
// Placeholders, in order: the value subtracted from startTime, the LIMIT.
// NOTE(review): unlike the "more info" HCCL query, this one has no deviceId condition even
// though deviceId is selected into NTASK — confirm whether that is intended.
std::string &DbSummaryDataBase::GenerateQueryDetailSqlForHCCL(std::string &sql) const {
    // Outer projection.
    sql = " SELECT rank_id, step_id, name, type, accCore,"
          " CASE WHEN startTime == 0 THEN 'NA' ELSE ROUND((startTime - ?) / (1000.0 * 1000.0), 2)"
          " END AS startTime, duration, waitTime, NULL AS ";
    sql += blockNumColumnName;
    sql += ", NULL AS inputShape,"
           " NULL AS inputType, NULL AS inputFormat, NULL AS outputShape, NULL AS outputType,"
           // Inner query over COMMUNICATION_OP with names/types resolved through STRING_IDS.
           " NULL AS outputFormat FROM ("
           " SELECT NULL as rank_id, NULL as step_id, NAME.value AS name, TYPE.value AS type, "
           " NULL as accCore,COMMUNICATION_OP.startNs as startTime,"
           " ROUND((COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs)/1000.0, 3) as duration,"
           " ROUND(COMMUNICATION_OP.waitNs/1000.0, 3) as waitTime FROM COMMUNICATION_OP"
           " JOIN STRING_IDS AS NAME ON NAME.id = COMMUNICATION_OP.opName "
           " JOIN STRING_IDS AS TYPE ON TYPE.id = COMMUNICATION_OP.opType"
           " JOIN (SELECT DISTINCT deviceId, connectionId FROM ";
    sql += TABLE_TASK;
    sql += ") "
           " AS NTASK ON NTASK.connectionId = COMMUNICATION_OP.connectionId "
           " ORDER by duration DESC LIMIT ? ) subquery ";
    return sql;
}
// Overwrites sql with a COUNT(*) query over compute tasks of one device and accCore.
// OP_TYPE_GROUP additionally matches on type; every other group matches on name and inputShape.
// Placeholders, in order: device id, accCore, then the group-specific value(s).
void DbSummaryDataBase::GenerateMoreInfoTotalNumForOther(
    std::string &sql, OperatorGroupConverter::OperatorGroup opGroup) const {
    std::string groupCondition;
    if (opGroup == OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP) {
        groupCondition = " type = ?";
    } else {
        groupCondition = " name = ? AND inputShape = ?";
    }

    sql = " SELECT COUNT(*) as nums FROM ("
          " SELECT NAME.value AS name, INPUTSHAPES.value AS inputShape, TASKTYPE.value AS accCore, "
          " OPTYPE.value AS type"
          " FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += " JOIN TASK ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId"
           " JOIN STRING_IDS AS NAME ON NAME.id = COMPUTE_TASK_INFO.name"
           " JOIN STRING_IDS AS OPTYPE ON OPTYPE.id = COMPUTE_TASK_INFO.opType"
           " JOIN STRING_IDS AS INPUTSHAPES ON INPUTSHAPES.id = COMPUTE_TASK_INFO.inputShapes"
           " JOIN STRING_IDS AS TASKTYPE ON TASKTYPE.id = COMPUTE_TASK_INFO.taskType"
           " WHERE TASK.deviceId = ? AND accCore = ? AND";
    sql += groupCondition;
    sql += " ) subquery";
}
// Converts (column, keyword) filters into (column, id list) filters by looking each keyword up
// through DbTraceDataBase::GetIdListByFuzzNameFromCache. Filters with an empty keyword are dropped.
std::vector<std::pair<std::string, std::vector<std::string>>> DbSummaryDataBase::ConvertFiltersToRangeFilters(
    std::vector<std::pair<std::string, std::string>> &filters) {
    std::vector<std::pair<std::string, std::vector<std::string>>> rangeFilters;
    for (const auto &[column, keyword] : filters) {
        if (!keyword.empty()) {
            rangeFilters.emplace_back(column, DbTraceDataBase::GetIdListByFuzzNameFromCache(path, keyword, false));
        }
    }
    return rangeFilters;
}
// Appends " WHERE col IN (?,...) AND ..." to sql, one IN predicate per range filter with as
// many placeholders as the filter has ids. Leaves sql untouched when there are no filters.
void DbSummaryDataBase::GenerateRangeQueryFiltersSql(
    std::vector<std::pair<std::string, std::vector<std::string>>> &rangeFilters, std::string &sql) {
    if (rangeFilters.empty()) {
        return;
    }
    std::vector<std::string> predicates;
    for (const auto &[column, ids] : rangeFilters) {
        std::string predicate = column;
        predicate += " IN (";
        predicate += StringUtil::CreateQuestionMarkString(ids.size());
        predicate += ")";
        predicates.push_back(predicate);
    }
    sql += " WHERE " + StringUtil::join(predicates, " AND ");
}
// Appends " WHERE <col> LIKE ? AND <col> LIKE ? ..." to sql, one predicate per entry of
// reqParams.filters. Leaves sql untouched when there are no filters.
// The number of placeholders always equals reqParams.filters.size(), which is what
// BindQueryFilters binds.
// NOTE(review): the column name (filter.first) is concatenated into the SQL verbatim and is not
// validated here — confirm that callers only pass trusted column names.
template <typename T> void DbSummaryDataBase::GenerateQueryFiltersSql(T &reqParams, std::string &sql) {
    if (reqParams.filters.empty()) {
        return;
    }
    sql += " WHERE ";
    bool isFirst = true;
    // Iterate by const reference: the filters are only read, so copying each pair is wasted work.
    for (const auto &filter : reqParams.filters) {
        if (!isFirst) {
            sql += " AND ";
        }
        isFirst = false;
        sql += filter.first;
        sql += " LIKE ?";
    }
}
// Binds every id of every range filter as text, in order, starting at index.
// index is advanced past the last bound parameter.
void DbSummaryDataBase::BindIdList(
    const std::vector<std::pair<std::string, std::vector<std::string>>> &rangeFilters, sqlite3_stmt *stmt, int &index) {
    for (const auto &rangeFilter : rangeFilters) {
        const std::vector<std::string> &ids = rangeFilter.second;
        for (const std::string &id : ids) {
            sqlite3_bind_text(stmt, index, id.c_str(), id.length(), SQLITE_TRANSIENT);
            ++index;
        }
    }
}
// Binds one "%<value>%" LIKE pattern per entry of reqParams.filters, in order, starting at
// index. index is advanced past the last bound parameter.
template <typename T> void DbSummaryDataBase::BindQueryFilters(T &reqParams, sqlite3_stmt *stmt, int &index) {
    // Iterate by const reference: only filter.second is read, so copying each pair is wasted work.
    for (const auto &filter : reqParams.filters) {
        const std::string pattern = "%" + filter.second + "%";
        // SQLITE_TRANSIENT makes SQLite copy the text, so `pattern` may go out of scope afterwards.
        sqlite3_bind_text(stmt, index++, pattern.c_str(), static_cast<int>(pattern.length()), SQLITE_TRANSIENT);
    }
}
// Appends " AND <col> LIKE '%<value>%' " to sql for every entry of reqParams.filters.
// Both column and value are inlined into the SQL text, so each is checked with
// StringUtil::CheckSqlValid first.
// @return false as soon as a filter fails validation (predicates appended before that point
//         remain in sql), true when all filters were appended.
bool DbSummaryDataBase::GenerateQueryMoreInfoFilters(OperatorMoreInfoReqParams &reqParams, std::string &sql) {
    for (const auto &filter : reqParams.filters) {
        if (!StringUtil::CheckSqlValid(filter.first) || !StringUtil::CheckSqlValid(filter.second)) {
            // The two literals are concatenated by the compiler; the trailing space keeps the words apart.
            ServerLog::Error("There is an SQL injection attack on the parameter of filter "
                             "to generate query more info filters.");
            return false;
        }
        sql += " AND " + filter.first + " LIKE '%" + filter.second + "%' ";
    }
    return true;
}
// Notifies the current websocket session that parsing finished.
// With a non-empty rankId an OperatorParseStatusEvent carrying the parse status, message,
// file id and rank list is sent; with an empty rankId a ModuleResetEvent is sent instead.
// Nothing is sent when no session is available.
void DbSummaryDataBase::ParserEnd(
    const std::string &rankId, const std::string &fileId, bool result, const std::string &msg) {
    WsSession *session = WsSessionManager::Instance().GetSession();
    if (session == nullptr) {
        ServerLog::Error("Failed to get session for summary callback.");
        return;
    }

    if (!rankId.empty()) {
        auto statusEvent = std::make_unique<Protocol::OperatorParseStatusEvent>();
        statusEvent->moduleName = MODULE_OPERATOR;
        statusEvent->result = true;
        statusEvent->data.rankId = rankId;
        statusEvent->data.status = result;
        statusEvent->data.error = msg;
        statusEvent->fileId = fileId;
        statusEvent->rankList = TrackInfoManager::Instance().GetRankListByFileId(fileId, rankId);
        session->OnEvent(std::move(statusEvent));
        return;
    }

    auto resetEvent = std::make_unique<Protocol::ModuleResetEvent>();
    resetEvent->moduleName = MODULE_OPERATOR;
    resetEvent->result = true;
    resetEvent->reset = true;
    session->OnEvent(std::move(resetEvent));
}
void DbSummaryDataBase::Reset() {
ServerLog::Info("Summary reset. wait task completed.");
ServerLog::Info("Summary task completed.");
auto databaseList = Timeline::DataBaseManager::Instance().GetAllSummaryDatabase();
for (auto &db : databaseList) {
auto database = dynamic_cast<DbSummaryDataBase *>(db);
if (database != nullptr) {
database->CloseDb();
}
}
Timeline::DataBaseManager::Instance().Clear(Timeline::DatabaseType::SUMMARY);
}
// Queries name, start time and duration of every compute task whose name matches
// LIKE '%matmul%', ordered by start time, and hands the SQL to
// ExecuteQueryBandwidthContentionMatMulData to fill res.
bool DbSummaryDataBase::QueryBandwidthContentionMatMulData(std::vector<BandwidthContentionMatMulInfo> &res) {
    std::string sql = "SELECT ";
    sql += TABLE_STRING_IDS;
    sql += ".value, ROUND(startNs / 1000.0, 3) AS startTime,"
           " ROUND((endNs - startNs) / 1000.0, 3) FROM ";
    sql += TABLE_COMPUTE_TASK_INFO;
    // Resolve the task name through the string table.
    sql += " INNER JOIN ";
    sql += TABLE_STRING_IDS;
    sql += " ON ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += ".name = ";
    sql += TABLE_STRING_IDS;
    sql += ".id INNER JOIN ";
    sql += TABLE_TASK;
    sql += " ON ";
    sql += TABLE_COMPUTE_TASK_INFO;
    sql += ".globalTaskId = ";
    sql += TABLE_TASK;
    sql += ".globalTaskId WHERE ";
    sql += TABLE_STRING_IDS;
    sql += ".value LIKE '%matmul%' ORDER BY startTime";
    return ExecuteQueryBandwidthContentionMatMulData(res, sql);
}
bool DbSummaryDataBase::AddCommunicationOpTableOpTypeIfNotExists() {
if (CheckColumnExist(TABLE_COMMUNICATION_OP, "opType")) {
return true;
}
const std::string opTypeColumnName = "opType";
std::string sql = StringUtil::FormatString(
"ALTER TABLE {} ADD COLUMN {} INTEGER DEFAULT -1;", TABLE_COMMUNICATION_OP, opTypeColumnName);
if (!ExecSql(sql)) {
ServerLog::Error("Failed to add column % to table %", TABLE_COMMUNICATION_OP, opTypeColumnName);
return false;
}
sql = StringUtil::FormatString(
"UPDATE {} "
"SET opType = ( "
" SELECT "
" tid "
" FROM "
" ( "
" SELECT DISTINCT "
" Name.id AS nid, "
" TYPE.id AS tid "
" FROM "
" {} "
" JOIN STRING_IDS AS NAME ON NAME.id = {}.opName "
" JOIN STRING_IDS AS TYPE ON TYPE.value = SUBSTR(NAME.value, 1, INSTR (NAME.value, '__'))) AS SUBVIEW "
"WHERE "
" SUBVIEW.nid == opName)",
TABLE_COMMUNICATION_OP, TABLE_COMMUNICATION_OP, TABLE_COMMUNICATION_OP);
if (!ExecSql(sql)) {
ServerLog::Error("Failed to set opType from opName.");
return false;
}
return true;
}
}