* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "pch.h"
#include "SummaryProtocolRequest.h"
#include "SummaryProtocolResponse.h"
#include "TraceTime.h"
#include "OperatorProtocol.h"
#include "ConstantDefs.h"
#include "DataBaseManager.h"
#include "TextSummaryDataBase.h"
namespace Dic::Module::Summary {
using namespace Server;
/** Passes sqlMutex through to the VirtualSummaryDataBase base class. */
TextSummaryDataBase::TextSummaryDataBase(std::recursive_mutex &sqlMutex)
    : VirtualSummaryDataBase(sqlMutex)
{
}
/** Finalizes the cached insert statement when one was prepared, then closes the database. */
TextSummaryDataBase::~TextSummaryDataBase()
{
    const bool stmtPrepared = hasInitStmt;
    if (stmtPrepared) {
        ReleaseStmt();
    }
    CloseDb();
}
/** Delegates configuration to Database::SetConfig and returns its result. */
bool TextSummaryDataBase::SetConfig()
{
    const bool configured = Database::SetConfig();
    return configured;
}
bool TextSummaryDataBase::CreateTable() {
if (!isOpen) {
ServerLog::Error("Failed to set config. Database is not open.");
return false;
}
std::string sql = "CREATE TABLE " + TABLE_KERNEL +
" (id INTEGER PRIMARY KEY AUTOINCREMENT, deviceId TEXT, step_id TEXT, " +
"task_id TEXT, name TEXT, op_type TEXT, accelerator_core TEXT, start_time INTEGER, duration INTEGER, " +
"wait_time INTEGER, block_num INTEGER, input_shapes TEXT, input_data_types TEXT, input_formats TEXT, " +
"output_shapes TEXT, output_data_types TEXT, output_formats TEXT);" + "CREATE INDEX rank_index ON " +
TABLE_KERNEL + " (deviceId);";
std::lock_guard<std::recursive_mutex> lock(mutex);
return ExecSql(sql);
}
/** Drops the kernel detail table while holding the SQL mutex. */
bool TextSummaryDataBase::DropTable()
{
    std::vector<std::string> kernelTables;
    kernelTables.push_back(TABLE_KERNEL);
    std::lock_guard<std::recursive_mutex> guard(mutex);
    const bool dropped = DropSomeTables(kernelTables);
    return dropped;
}
/**
 * Prepares the cached multi-row INSERT statement (cacheSize rows per execution).
 *
 * @param columns extra column names appended after the fixed kernel columns.
 * @return true when the statement is already prepared or was prepared successfully.
 */
bool TextSummaryDataBase::InitStmt(const std::vector<std::string> &columns)
{
    if (hasInitStmt) {
        return true;
    }
    // One ", <column>" and one ",?" per extra column.
    std::string extraColumns;
    std::string extraMarks;
    for (const auto &columnName : columns) {
        extraColumns += ", " + columnName;
        extraMarks += ",?";
    }
    const std::string followingRow = ",(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?" + extraMarks + ")";
    std::string sql = "INSERT INTO " + TABLE_KERNEL +
        " (deviceId, step_id, task_id, name, op_type, accelerator_core, " +
        "start_time, duration, wait_time, block_num, input_shapes, input_data_types, input_formats, " +
        "output_shapes, output_data_types, output_formats" + extraColumns + ") " +
        "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?" + extraMarks + ")";
    // The first row tuple is already part of sql; append the remaining ones.
    for (size_t row = 0; row < cacheSize - 1; ++row) {
        sql.append(followingRow);
    }
    const int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &insertKernelStmt, nullptr);
    if (rc != SQLITE_OK) {
        ServerLog::Error("Failed to prepare insert kernel detail statement. error:", sqlite3_errmsg(db));
        return false;
    }
    hasInitStmt = true;
    return true;
}
/** Finalizes the cached insert statement, if any, and clears the prepared flag. */
void TextSummaryDataBase::ReleaseStmt()
{
    if (insertKernelStmt == nullptr) {
        return;
    }
    sqlite3_finalize(insertKernelStmt);
    insertKernelStmt = nullptr;
    hasInitStmt = false;
}
void TextSummaryDataBase::InsertKernelDetailList(
const std::vector<Kernel> &kernelVec, const std::vector<std::string> &columns) {
sqlite3_stmt *stmt = GetKernelStmt(kernelVec.size(), columns);
if (stmt == nullptr) {
ServerLog::Error("Failed to get kernel stmt.");
return;
}
int idx = bindStartIndex;
for (const auto &event : kernelVec) {
sqlite3_bind_text(stmt, idx++, event.rankId.c_str(), event.rankId.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.stepId.c_str(), event.stepId.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.taskId.c_str(), event.taskId.length(), SQLITE_TRANSIENT),
sqlite3_bind_text(stmt, idx++, event.name.c_str(), event.name.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.type.c_str(), event.type.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.acceleratorCore.c_str(), event.acceleratorCore.length(), SQLITE_TRANSIENT);
sqlite3_bind_int64(stmt, idx++, event.startTime);
sqlite3_bind_double(stmt, idx++, event.duration);
sqlite3_bind_double(stmt, idx++, event.waitTime);
sqlite3_bind_int64(stmt, idx++, event.blockNum);
sqlite3_bind_text(stmt, idx++, event.inputShapes.c_str(), event.inputShapes.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.inputDataTypes.c_str(), event.inputDataTypes.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.inputFormats.c_str(), event.inputFormats.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.outputShapes.c_str(), event.outputShapes.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.outputDataTypes.c_str(), event.outputDataTypes.length(), SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, idx++, event.outputFormats.c_str(), event.outputFormats.length(), SQLITE_TRANSIENT);
if (columns.empty() || columns.size() != event.utilizationInfo.size()) {
continue;
}
for (const auto &item : event.utilizationInfo) {
sqlite3_bind_text(stmt, idx++, item.c_str(), item.length(), SQLITE_TRANSIENT);
}
}
std::lock_guard<std::recursive_mutex> lock(mutex);
auto result = sqlite3_step(stmt);
if (kernelVec.size() != cacheSize) {
sqlite3_finalize(stmt);
}
if (result != SQLITE_DONE) {
ServerLog::Error("Insert kernel detail fail. ", sqlite3_errmsg(db));
sqlite3_exec(db, "ROLLBACK", nullptr, nullptr, nullptr);
return;
}
}
/**
 * Appends kernel to the in-memory cache and writes the cache out once it holds cacheSize entries.
 *
 * @param kernel  kernel record to cache.
 * @param columns extra column names forwarded to InsertKernelDetailList.
 */
void TextSummaryDataBase::InsertKernelDetail(const Kernel &kernel, const std::vector<std::string> &columns)
{
    std::lock_guard guard(mutex);
    kernelCache.emplace_back(kernel);
    if (kernelCache.size() != cacheSize) {
        return;
    }
    // Cache is full: write it out and start over.
    InsertKernelDetailList(kernelCache, columns);
    kernelCache.clear();
}
void TextSummaryDataBase::SaveKernelDetail(const std::vector<std::string> &columns) {
if (kernelCache.size() > 0) {
InsertKernelDetailList(kernelCache, columns);
kernelCache.clear();
}
}
/**
 * Returns an INSERT statement with paramLen row tuples.
 *
 * @param paramLen number of rows the statement must accept.
 * @param columns  extra column names appended after the fixed kernel columns.
 * @return the cached statement (reset) when paramLen equals cacheSize, a newly prepared
 *         statement otherwise, or nullptr when paramLen is 0 or preparation fails.
 */
sqlite3_stmt *TextSummaryDataBase::GetKernelStmt(uint64_t paramLen, const std::vector<std::string> &columns)
{
    if (paramLen == 0) {
        // Without this guard "paramLen - 1" below would wrap around to UINT64_MAX.
        ServerLog::Error("Failed to prepare for insert kernel. length is 0.");
        return nullptr;
    }
    if (paramLen == cacheSize) {
        sqlite3_reset(insertKernelStmt);
        return insertKernelStmt;
    }
    std::string columnSql;
    std::string placeHolder;
    for (const auto &column : columns) {
        columnSql += ", " + column;
        placeHolder += ",?";
    }
    std::string sql = "INSERT INTO " + TABLE_KERNEL +
        " (deviceId, step_id, task_id, name, op_type, accelerator_core, " +
        "start_time, duration, wait_time, block_num, input_shapes, input_data_types, input_formats, " +
        "output_shapes, output_data_types, output_formats" + columnSql + ") " +
        "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?" + placeHolder + ")";
    const std::string extraRow = ",(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?" + placeHolder + ")";
    // The first row tuple is already in sql; add the remaining paramLen - 1.
    for (uint64_t i = 1; i < paramLen; ++i) {
        sql.append(extraRow);
    }
    sqlite3_stmt *stmt = nullptr;
    if (sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
        ServerLog::Error("Failed to prepare for insert kernel. length:", paramLen, " . error:", sqlite3_errmsg(db));
        return nullptr;
    }
    return stmt;
}
/** Stores status under the kernelParseState key via UpdateValueIntoStatusInfoTable. */
bool TextSummaryDataBase::UpdateParseStatus(const std::string &status)
{
    const bool updated = UpdateValueIntoStatusInfoTable(kernelParseState, status);
    return updated;
}
/** Returns the result of checking the kernelParseState entry against FINISH_STATUS. */
bool TextSummaryDataBase::HasFinishedParseLastTime()
{
    const bool finished = CheckValueFromStatusInfoTable(kernelParseState, FINISH_STATUS);
    return finished;
}
/**
 * Opens the database, applies configuration, then runs the version-change check.
 *
 * @return false as soon as any of the three steps fails.
 */
bool TextSummaryDataBase::OpenDb(const std::string &dbPath, bool clearAllTable)
{
    const bool opened = Database::OpenDb(dbPath, clearAllTable);
    if (!opened) {
        return false;
    }
    if (!SetConfig()) {
        return false;
    }
    return CheckAndResetDatabaseOnVersionChange();
}
/**
 * Queries the smallest positive start_time in the kernel table.
 *
 * @return 0 when the query cannot be prepared; INT64_MAX when no positive value is found;
 *         otherwise the minimum start_time.
 */
uint64_t TextSummaryDataBase::QueryMinStartTime()
{
    const std::string query = "Select MIN(start_time) FROM " + TABLE_KERNEL + " WHERE start_time > 0";
    sqlite3_stmt *stmt = nullptr;
    if (sqlite3_prepare_v2(db, query.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
        ServerLog::Error("Failed to prepare sql for query minimum start time.", sqlite3_errmsg(db));
        return 0;
    }
    int64_t earliest = 0;
    for (int rc = sqlite3_step(stmt); rc == SQLITE_ROW; rc = sqlite3_step(stmt)) {
        const int col = resultStartIndex;
        earliest = sqlite3_column_int64(stmt, col);
    }
    sqlite3_finalize(stmt);
    if (earliest > 0) {
        return earliest;
    }
    return INT64_MAX;
}
/**
 * Queries one page of kernel rows whose accelerator_core equals params.timeFlag.
 *
 * @param params         paging (currentPage, pageSize), ordering and the timeFlag filter value.
 * @param computeDetails replaced with the rows read when the query could be prepared.
 * @return false when the SQL cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryComputeOpDetail(
    Protocol::ComputeDetailParams params, std::vector<Protocol::ComputeDetail> &computeDetails)
{
    std::string sql = GenComputeSql(params);
    uint64_t startTime = NumberUtil::CeilingClamp(Timeline::TraceTime::Instance().GetStartTime(), (uint64_t)INT64_MAX);
    int64_t offset = (params.currentPage - 1) * params.pageSize;
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Query operator detail failed! Failed to prepare sql.", sqlite3_errmsg(db));
        return false;
    }
    // Placeholder order: start time, accelerator_core, LIMIT, OFFSET.
    int index = bindStartIndex;
    sqlite3_bind_int64(stmt, index++, startTime);
    sqlite3_bind_text(stmt, index++, params.timeFlag.c_str(), params.timeFlag.length(), SQLITE_TRANSIENT);
    sqlite3_bind_int64(stmt, index++, params.pageSize);
    sqlite3_bind_int64(stmt, index++, offset);
    std::vector<Protocol::ComputeDetail> computeVec;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        int col = resultStartIndex;
        Protocol::ComputeDetail computeDetail{};
        computeDetail.name = sqlite3_column_string(stmt, col++);
        computeDetail.type = sqlite3_column_string(stmt, col++);
        computeDetail.startTime = sqlite3_column_string(stmt, col++);
        computeDetail.duration = sqlite3_column_double(stmt, col++);
        computeDetail.waitTime = sqlite3_column_double(stmt, col++);
        computeDetail.blockNum = sqlite3_column_int64(stmt, col++);
        computeDetail.inputShapes = sqlite3_column_string(stmt, col++);
        computeDetail.inputDataTypes = sqlite3_column_string(stmt, col++);
        computeDetail.inputFormats = sqlite3_column_string(stmt, col++);
        computeDetail.outputShapes = sqlite3_column_string(stmt, col++);
        computeDetail.outputDataTypes = sqlite3_column_string(stmt, col++);
        computeDetail.outputFormats = sqlite3_column_string(stmt, col++);
        computeVec.emplace_back(computeDetail);
    }
    sqlite3_finalize(stmt);
    // Hand the rows over without copying the whole vector.
    computeDetails.swap(computeVec);
    return true;
}
/**
 * Builds an " ORDER BY <orderBy> ASC|DESC" clause.
 *
 * @param orderBy column expression to sort on; rejected when StringUtil::CheckSqlValid fails.
 * @param order   "descend" selects DESC; any other value selects ASC.
 * @return the clause, or an empty string when orderBy fails validation.
 */
std::string TextSummaryDataBase::GenSortSql(std::string orderBy, std::string order)
{
    if (!StringUtil::CheckSqlValid(orderBy)) {
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy to generate sort sql.");
        return std::string();
    }
    const char *direction = (order == "descend") ? " DESC" : " ASC";
    std::string clause = " ORDER BY ";
    clause += orderBy;
    clause += direction;
    return clause;
}
/**
 * Builds the paged SELECT used by QueryComputeOpDetail.
 *
 * Placeholders, in order: start time, accelerator_core, LIMIT, OFFSET. The sort clause is
 * inserted only when request.orderBy is non-empty.
 */
std::string TextSummaryDataBase::GenComputeSql(Protocol::ComputeDetailParams request)
{
    const std::string sortClause = GenSortSql(request.orderBy, request.order);
    std::string sql =
        "SELECT name, op_type as type, "
        "CASE WHEN start_time == 0 THEN 0 ELSE ROUND((start_time - ?) / (1000.0 * 1000.0), 4) END AS startTime, "
        "duration, wait_time as waitTime, block_num as blockNum, "
        "input_shapes as inputShapes, input_data_types as inputDataTypes, input_formats as inputFormats, "
        "output_shapes as outputShapes, output_data_types as outputDataTypes, output_formats as outputFormats "
        "FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE accelerator_core = ?";
    if (!request.orderBy.empty()) {
        sql += " ";
        sql += sortClause;
    }
    sql += " LIMIT ? offset ?";
    return sql;
}
/**
 * Counts kernel rows for one accelerator core, or for the communication cores when
 * name is "Communication".
 *
 * @param name     accelerator_core value, or "Communication".
 * @param totalNum receives the count when a row is returned.
 * @return false when the SQL cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryTotalNumByAcceleratorCore(std::string name, int64_t &totalNum)
{
    const bool isCommunication = (name == "Communication");
    std::string sql = "SELECT count(*) as nums FROM " + TABLE_KERNEL + " WHERE accelerator_core ";
    if (isCommunication) {
        sql += GetAcceleratorCoreSql(true);
    } else {
        sql += "= ?";
    }
    sqlite3_stmt *stmt = nullptr;
    if (sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
        ServerLog::Error("Get total num failed! Failed to prepare sql.", sqlite3_errmsg(db));
        return false;
    }
    if (!isCommunication) {
        // Only the single-core variant has a placeholder.
        const int index = bindStartIndex;
        sqlite3_bind_text(stmt, index, name.c_str(), name.length(), nullptr);
    }
    for (int rc = sqlite3_step(stmt); rc == SQLITE_ROW; rc = sqlite3_step(stmt)) {
        totalNum = sqlite3_column_int64(stmt, resultStartIndex);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Builds the paged SELECT used by QueryCommunicationOpDetail.
 *
 * Placeholders, in order: start time, LIMIT, OFFSET. The sort clause is appended only when
 * request.orderBy is non-empty.
 */
std::string TextSummaryDataBase::GetCommSql(Protocol::CommunicationDetailParams request)
{
    const std::string sortClause = GenSortSql(request.orderBy, request.order);
    std::string sql =
        "SELECT name, op_type as type, CASE WHEN start_time == 0 THEN 'NA' "
        "ELSE ROUND((start_time - ?) / (1000.0 * 1000.0), 4) END AS startTime, "
        "ROUND(duration, 4) as duration, ROUND(wait_time, 4) as waitTime FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE accelerator_core ";
    sql += GetAcceleratorCoreSql(true);
    if (!request.orderBy.empty()) {
        sql += sortClause;
    }
    sql += " LIMIT ? offset ?";
    return sql;
}
/**
 * Queries one page of communication kernel rows and appends them to computeDetails.
 *
 * @param params         paging (currentPage, pageSize) and ordering.
 * @param computeDetails rows are appended; existing content is kept.
 * @return false when the SQL cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryCommunicationOpDetail(
    Protocol::CommunicationDetailParams params, std::vector<Protocol::CommunicationDetail> &computeDetails)
{
    const std::string sql = GetCommSql(params);
    const uint64_t startTime = Timeline::TraceTime::Instance().GetStartTime();
    const int64_t offset = (params.currentPage - 1) * params.pageSize;
    sqlite3_stmt *stmt = nullptr;
    if (sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) {
        ServerLog::Error("Query common detail failed! Failed to prepare sql.", sqlite3_errmsg(db));
        return false;
    }
    // Placeholder order: start time, LIMIT, OFFSET.
    int index = bindStartIndex;
    sqlite3_bind_int64(stmt, index++, NumberUtil::CeilingClamp(startTime, (uint64_t)INT64_MAX));
    sqlite3_bind_int64(stmt, index++, params.pageSize);
    sqlite3_bind_int64(stmt, index++, offset);
    for (int rc = sqlite3_step(stmt); rc == SQLITE_ROW; rc = sqlite3_step(stmt)) {
        int col = resultStartIndex;
        Protocol::CommunicationDetail row{};
        row.name = sqlite3_column_string(stmt, col++);
        row.type = sqlite3_column_string(stmt, col++);
        row.startTime = sqlite3_column_string(stmt, col++);
        row.duration = sqlite3_column_double(stmt, col++);
        row.waitTime = sqlite3_column_double(stmt, col++);
        computeDetails.emplace_back(row);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Builds the top-K duration query for the operator group named by reqParams.group.
 *
 * Placeholders, in order: deviceId, LIMIT.
 * @return the SQL text, or an empty string when the group is unknown.
 */
std::string TextSummaryDataBase::GenerateQueryCategoryDurationSql(Protocol::OperatorDurationReqParams &reqParams)
{
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    if (operatorGroup == OperatorGroupConverter::OperatorGroup::UNKNOWN) {
        ServerLog::Error("Category duration sql generate failed, unknown operator group.");
        return "";
    }
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    const std::string whereClause = " WHERE deviceId = ? AND accelerator_core " + GetAcceleratorCoreSql(isCommunication);

    const bool groupedByName = operatorGroup == OperatorGroupConverter::OperatorGroup::OP_NAME_GROUP ||
        operatorGroup == OperatorGroupConverter::OperatorGroup::COMMUNICATION_NAME_GROUP;
    if (groupedByName) {
        // Name groups list individual kernels, so no aggregation is applied.
        return " SELECT name as name , ROUND(duration, 2) as duration FROM " + TABLE_KERNEL + whereClause +
            " ORDER BY duration DESC LIMIT ?";
    }

    std::string label;
    std::string groupKey;
    if (IsOperatorGroupInType(operatorGroup)) {
        label = "op_type";
        groupKey = "op_type || accelerator_core";
    } else {
        label = R"(name || '[' || input_shapes || ']')";
        groupKey = R"(name || '[' || input_shapes || ']' || accelerator_core)";
    }
    std::string sql = " SELECT ";
    sql += label;
    sql += " as name, ROUND(sum(duration), 2) as duration FROM ";
    sql += TABLE_KERNEL;
    sql += whereClause;
    sql += " GROUP by ";
    sql += groupKey;
    sql += " ORDER BY duration DESC LIMIT ?";
    return sql;
}
/**
 * Builds the query that sums durations per accelerator_core over the top-K entries of the
 * non-communication kernels.
 *
 * Placeholders, in order: deviceId, LIMIT.
 */
std::string TextSummaryDataBase::GenerateQueryComputeUnitDurationSql(Protocol::OperatorDurationReqParams &reqParams)
{
    const bool perOperator = (reqParams.group == Protocol::OPERATOR_GROUP);
    std::string groupKey;
    if (reqParams.group == Protocol::OP_TYPE_GROUP) {
        groupKey = "op_type || accelerator_core";
    } else if (perOperator) {
        groupKey = "name || accelerator_core";
    } else {
        groupKey = R"(name || '[' || input_shapes || ']' || accelerator_core)";
    }

    std::string sql = " SELECT accelerator_core as name, ROUND(SUM(duration), 2) as duration"
                      " FROM ("
                      " SELECT ";
    sql += groupKey;
    sql += ", accelerator_core,";
    // The per-operator variant selects raw rows; the other variants aggregate per group key.
    sql += perOperator ? " ROUND(duration, 2) as duration" : " ROUND(SUM(duration), 2) as duration";
    sql += " FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE deviceId = ? AND accelerator_core ";
    sql += GetAcceleratorCoreSql(false);
    if (perOperator) {
        sql += " ";
    } else {
        sql += " GROUP BY ";
        sql += groupKey;
    }
    sql += " ORDER BY duration DESC LIMIT ?"
           " ) subquery"
           " GROUP by accelerator_core"
           " ORDER BY duration DESC";
    return sql;
}
/**
 * Queries operator durations, either per category or per compute unit.
 *
 * @param reqParams supplies group, deviceId and topK.
 * @param type      CATEGORY selects the category query; any other value selects the compute-unit query.
 * @param data      replaced with the result rows on success. Rows beyond maxCategorySize are
 *                  folded into the last entry, which is renamed "Others".
 * @return false when no SQL could be generated or prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryOperatorDurationInfo(Protocol::OperatorDurationReqParams &reqParams,
    Protocol::QueryType type, std::vector<Protocol::OperatorDurationRes> &data)
{
    std::string sql;
    if (type == Protocol::QueryType::CATEGORY) {
        sql = GenerateQueryCategoryDurationSql(reqParams);
    } else {
        sql = GenerateQueryComputeUnitDurationSql(reqParams);
    }
    if (sql.empty()) {
        // The generator returns "" for an unknown group. Preparing empty SQL succeeds with a
        // NULL statement, which previously made this function report success without querying.
        ServerLog::Error("Failed to generate query duration sql.");
        return false;
    }
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get Duration Info. Msg: ", sqlite3_errmsg(db), " ", result);
        return false;
    }
    int index = bindStartIndex;
    sqlite3_bind_text(stmt, index++, reqParams.deviceId.c_str(), reqParams.deviceId.length(), SQLITE_TRANSIENT);
    sqlite3_bind_int64(stmt, index++, reqParams.topK);
    std::vector<Protocol::OperatorDurationRes> res;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        Protocol::OperatorDurationRes one{};
        int col = 0;
        one.name = sqlite3_column_string(stmt, col++);
        one.duration = sqlite3_column_double(stmt, col++);
        if (res.size() >= maxCategorySize) {
            // Overflow rows are accumulated into the last kept entry.
            res[maxCategorySize - 1].name = "Others";
            res[maxCategorySize - 1].duration += one.duration;
        } else {
            res.emplace_back(one);
        }
    }
    sqlite3_finalize(stmt);
    // Hand the rows over without copying the whole vector.
    data.swap(res);
    return true;
}
/**
 * Counts the grouped statistic rows that match reqParams (device, topK and filters).
 *
 * @param total receives the count when a row is returned.
 * @return false when the SQL cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryStatisticTotalNum(Protocol::OperatorStatisticReqParams &reqParams, int64_t &total)
{
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    const std::string groupKey = IsOperatorGroupInType(operatorGroup)
        ? "op_type || accelerator_core"
        : R"(name || '[' || input_shapes || ']' || accelerator_core)";

    std::string sql = " SELECT COUNT(*) as nums"
                      " FROM ("
                      " SELECT accelerator_core as accCore, op_type as opType, name as opName FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE deviceId = ? AND accelerator_core ";
    sql += GetAcceleratorCoreSql(isCommunication);
    sql += " GROUP by ";
    sql += groupKey;
    sql += " ORDER by ROUND(SUM(duration), 2) DESC LIMIT ?"
           " ) subquery";
    GenerateQueryFiltersSql<Protocol::OperatorStatisticReqParams>(reqParams, sql);

    sqlite3_stmt *stmt = nullptr;
    const int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ServerLog::Error("Failed to get Statistic Num. Msg: ", sqlite3_errmsg(db), " ", rc);
        return false;
    }
    // Placeholder order: deviceId, topK, then the filter values.
    int index = bindStartIndex;
    sqlite3_bind_text(stmt, index++, reqParams.deviceId.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_int64(stmt, index++, reqParams.topK);
    BindQueryFilters(reqParams, stmt, index);
    for (int step = sqlite3_step(stmt); step == SQLITE_ROW; step = sqlite3_step(stmt)) {
        total = sqlite3_column_int64(stmt, resultStartIndex);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Builds the paged statistic query: base SQL, optional ORDER BY, then "LIMIT ? OFFSET ?".
 *
 * The ORDER BY clause is added only when orderBy passes StringUtil::CheckSqlValid and both
 * orderBy and order are non-empty; "ascend" selects ASC, anything else DESC.
 * @return the SQL text, or an empty string when the base SQL is empty.
 */
std::string TextSummaryDataBase::GenerateQueryStatisticSql(Protocol::OperatorStatisticReqParams &reqParams)
{
    std::string sql = GetQueryBaseStaticSql(reqParams);
    if (sql.empty()) {
        return "";
    }
    if (!StringUtil::CheckSqlValid(reqParams.orderBy)) {
        // The two literals used to join without a separating space ("orderByto generate").
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy "
                         "to generate query statistic sql.");
    } else if (!reqParams.orderBy.empty() && !reqParams.order.empty()) {
        sql += " ORDER by " + reqParams.orderBy + " " + (reqParams.order == "ascend" ? "ASC" : "DESC");
    }
    sql += " LIMIT ? OFFSET ?";
    return sql;
}
/**
 * Builds the unpaged statistic query for reqParams.group and appends the filter conditions.
 *
 * Type groups aggregate by op_type and select an empty operator name; other groups
 * aggregate by name and input shape.
 */
std::string TextSummaryDataBase::GetQueryBaseStaticSql(Protocol::OperatorStatisticReqParams &reqParams)
{
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    const bool groupedByType = IsOperatorGroupInType(operatorGroup);
    const std::string groupKey =
        groupedByType ? "op_type || accelerator_core" : R"(name || input_shapes || accelerator_core)";
    const std::string nameExpr = groupedByType ? "''" : "name";
    std::string sql = GetQuerySqlNofilter(reqParams, isCommunication, groupKey, nameExpr);
    GenerateQueryFiltersSql<Protocol::OperatorStatisticReqParams>(reqParams, sql);
    return sql;
}
/**
 * Returns the SQL predicate tail that selects ('HCCL', 'COMMUNICATION') cores when
 * isCommunication is true, or excludes them otherwise.
 */
std::string TextSummaryDataBase::GetAcceleratorCoreSql(const bool isCommunication)
{
    const char *predicate =
        isCommunication ? " IN ('HCCL', 'COMMUNICATION') " : " NOT IN ('HCCL', 'COMMUNICATION') ";
    return predicate;
}
/**
 * Builds the aggregated statistic SELECT without filter conditions.
 *
 * Three variants exist:
 *  - not compare mode: restricted to one deviceId and limited to the top-K groups;
 *  - compare mode with a deviceId: restricted to that deviceId, no limit;
 *  - compare mode without a deviceId: all devices, no limit.
 *
 * @param group SQL expression used in GROUP by.
 * @param name  SQL expression selected as opName.
 */
std::string TextSummaryDataBase::GetQuerySqlNofilter(Protocol::OperatorStatisticReqParams &reqParams,
    const bool isCommunication, const std::string &group, const std::string &name)
{
    const std::string coreFilter = GetAcceleratorCoreSql(isCommunication);
    const std::string selectHead = " SELECT * FROM ("
                                   " SELECT op_type as opType, " +
        name;
    // Aggregate columns plus the FROM keyword, shared by every variant.
    const std::string aggregates = " ROUND(SUM(duration), 2) as totalTime, COUNT(0) as count,"
                                   " ROUND(SUM(duration) / COUNT(0), 2) as avgTime,"
                                   " ROUND(max(duration), 2) as maxTime,"
                                   " ROUND(min(duration), 2) as minTime"
                                   " FROM ";

    if (!reqParams.isCompare) {
        return selectHead +
            " as opName, input_shapes as inputShape, "
            " accelerator_core as accCore," +
            aggregates + TABLE_KERNEL + " WHERE deviceId = ? AND accelerator_core " + coreFilter +
            " GROUP by " + group +
            " ORDER by totalTime DESC LIMIT ?"
            " ) subquery ";
    }
    if (reqParams.deviceId.empty()) {
        return selectHead +
            " as opName, input_shapes as inputShape, "
            " accelerator_core as accCore," +
            aggregates + TABLE_KERNEL + " WHERE accelerator_core " + coreFilter + " GROUP by " + group +
            " ) subquery";
    }
    return selectHead + " as opName, input_shapes as inputShape, accelerator_core as accCore," + aggregates +
        TABLE_KERNEL + " WHERE deviceId = ? AND accelerator_core " + coreFilter + " GROUP by " + group +
        " ) subquery";
}
/** Runs the unpaged statistic query for reqParams and appends the rows to res. */
bool TextSummaryDataBase::QueryAllOperatorStatisticInfo(
    Protocol::OperatorStatisticReqParams &reqParams, std::vector<Protocol::OperatorStatisticInfoRes> &res)
{
    const std::string baseSql = GetQueryBaseStaticSql(reqParams);
    const bool succeeded = ExecSqlGetStaticInfo(baseSql, reqParams, res);
    return succeeded;
}
/**
 * Fills response with the total row count and one page of statistic rows.
 *
 * Each row is stored in the compare member of an OperatorStatisticCmpInfoRes.
 * @return false when either the count query or the page query fails.
 */
bool TextSummaryDataBase::QueryOperatorStatisticInfo(
    Protocol::OperatorStatisticReqParams &reqParams, Protocol::OperatorStatisticInfoResponse &response)
{
    const bool counted = QueryStatisticTotalNum(reqParams, response.total);
    if (!counted) {
        ServerLog::Error("[Operator]Failed to query total num of statistic info.");
        return false;
    }
    const std::string pageSql = GenerateQueryStatisticSql(reqParams);
    std::vector<Protocol::OperatorStatisticInfoRes> rows;
    if (!ExecSqlGetStaticInfo(pageSql, reqParams, rows)) {
        return false;
    }
    std::vector<Protocol::OperatorStatisticCmpInfoRes> wrapped;
    for (auto &row : rows) {
        OperatorStatisticCmpInfoRes entry;
        entry.compare = row;
        wrapped.emplace_back(entry);
    }
    response.data = wrapped;
    return true;
}
/**
 * Prepares sql, binds the request parameters and appends every result row to res.
 *
 * @param sql       statistic query text; an empty string is rejected.
 * @param reqParams supplies deviceId, topK, paging values and filters.
 * @param res       rows are appended; existing content is kept.
 * @return false when sql is empty or cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::ExecSqlGetStaticInfo(const std::string &sql, Protocol::OperatorStatisticReqParams &reqParams,
    std::vector<Protocol::OperatorStatisticInfoRes> &res)
{
    if (sql.empty()) {
        ServerLog::Error("Failed to generate query statistic sql.");
        return false;
    }
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get Statistic Info. Msg: ", sqlite3_errmsg(db), " ", result);
        return false;
    }
    int index = bindStartIndex;
    // The non-compare SQL built by GetQuerySqlNofilter always contains "deviceId = ?"; the compare
    // SQL contains it only when deviceId is non-empty. Binding must follow the same rule, otherwise
    // an empty deviceId in non-compare mode shifts every later parameter by one.
    if (!reqParams.isCompare || !reqParams.deviceId.empty()) {
        sqlite3_bind_text(stmt, index++, reqParams.deviceId.c_str(), reqParams.deviceId.length(), SQLITE_TRANSIENT);
    }
    if (!reqParams.isCompare) {
        sqlite3_bind_int64(stmt, index++, reqParams.topK);
        BindQueryFilters(reqParams, stmt, index);
        sqlite3_bind_int64(stmt, index++, reqParams.pageSize);
        sqlite3_bind_int64(stmt, index++, reqParams.pageSize * (reqParams.current - 1));
    } else {
        BindQueryFilters(reqParams, stmt, index);
    }
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        int col = 0;
        Protocol::OperatorStatisticInfoRes one{};
        one.opType = sqlite3_column_string(stmt, col++);
        one.opName = sqlite3_column_string(stmt, col++);
        one.inputShape = sqlite3_column_string(stmt, col++);
        one.accCore = sqlite3_column_string(stmt, col++);
        one.totalTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
        one.count = Sqlite3ColumnConvertStr(SQLITE_INTEGER, stmt, col++);
        one.avgTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
        one.maxTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
        one.minTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
        res.emplace_back(one);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Counts the detail rows that match reqParams (device, topK and filters).
 *
 * @param total receives the count when a row is returned.
 * @return false when the SQL cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryDetailTotalNum(Protocol::OperatorStatisticReqParams &reqParams, int64_t &total)
{
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string sql = " SELECT COUNT(*) as nums"
                      " FROM ("
                      " SELECT name, op_type as type, accelerator_core as accCore FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE deviceId = ? AND accelerator_core ";
    sql += GetAcceleratorCoreSql(isCommunication);
    sql += " ORDER BY duration DESC LIMIT ?"
           " ) subquery";
    GenerateQueryFiltersSql<Protocol::OperatorStatisticReqParams>(reqParams, sql);

    sqlite3_stmt *stmt = nullptr;
    const int rc = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
    if (rc != SQLITE_OK) {
        ServerLog::Error("Failed to get Detail Total Num. Msg: ", sqlite3_errmsg(db), " ", rc);
        return false;
    }
    // Placeholder order: deviceId, topK, then the filter values.
    int index = bindStartIndex;
    sqlite3_bind_text(stmt, index++, reqParams.deviceId.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_int64(stmt, index++, reqParams.topK);
    BindQueryFilters(reqParams, stmt, index);
    for (int step = sqlite3_step(stmt); step == SQLITE_ROW; step = sqlite3_step(stmt)) {
        total = sqlite3_column_int64(stmt, resultStartIndex);
    }
    sqlite3_finalize(stmt);
    return true;
}
/**
 * Returns the names of the kernel-table columns from 'aicore_time_us_' onwards (by column id).
 *
 * The result is cached in pmuColumns_ once a non-empty set has been read. Every name must match
 * PMU_HEADER_WHITE_LIST_REG; if any name fails the check, an empty set is returned and nothing
 * is cached.
 */
std::set<std::string> TextSummaryDataBase::FetchPmuColumnNames()
{
    if (!pmuColumns_.empty()) {
        return pmuColumns_;
    }
    std::string queryColumnSql = "SELECT name "
                                 "FROM pragma_table_info('" +
        TABLE_KERNEL +
        "') "
        "WHERE cid >= ( "
        "    SELECT cid "
        "    FROM pragma_table_info('" +
        TABLE_KERNEL +
        "') "
        "    WHERE name = 'aicore_time_us_'"
        ");";
    sqlite3_stmt *stmt = nullptr;
    int result = sqlite3_prepare_v2(db, queryColumnSql.c_str(), -1, &stmt, nullptr);
    if (result != SQLITE_OK) {
        ServerLog::Error("Failed to get Detail Info. Msg:", sqlite3_errmsg(db), " ", result);
        return pmuColumns_;
    }
    // Collect into a local set first. Inserting straight into pmuColumns_ would leave a partial
    // list cached if a later column name fails the whitelist check.
    std::set<std::string> fetched;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        std::string colName = sqlite3_column_string(stmt, 0);
        if (!RegexUtil::RegexMatch(colName, PMU_HEADER_WHITE_LIST_REG)) {
            ServerLog::Error("There is an SQL injection attack on this parameter. error param: %", colName);
            sqlite3_finalize(stmt);
            return {};
        }
        fetched.insert(colName);
    }
    sqlite3_finalize(stmt);
    pmuColumns_.swap(fetched);
    return pmuColumns_;
}
/**
 * Builds the detail SELECT (without ORDER BY / paging) and appends the filter conditions.
 *
 * @param isLimit when true and a deviceId is given, the inner query keeps only the top-K rows
 *                by duration; otherwise the inner query is ordered by start_time and unlimited.
 *
 * PMU columns are appended to the select list for non-communication groups only.
 */
std::string TextSummaryDataBase::GetQueryDetailBaseSql(Protocol::OperatorStatisticReqParams &reqParams, bool isLimit)
{
    const bool isCommunication = Protocol::OperatorGroupConverter::IsCommunication(reqParams.group);
    std::string pmuColumnNames;
    if (!isCommunication) {
        std::set<std::string> pmuClos = FetchPmuColumnNames();
        pmuColumnNames = JoinExtraColName(std::vector<std::string>(pmuClos.begin(), pmuClos.end()));
    }
    const bool hasDevice = !reqParams.deviceId.empty();
    const bool topKOnly = isLimit && hasDevice;

    std::string sql =
        " SELECT deviceId, step_id , name, op_type as type, accelerator_core as accCore,"
        " CASE WHEN start_time == 0 THEN 'NA' ELSE ROUND((start_time - ?) / (1000.0 * 1000.0), 2)"
        " END AS startTime, duration, wait_time as waitTime, block_num as blockNum,"
        " input_shapes as inputShape, input_data_types as inputType, input_formats as inputFormat, "
        " output_shapes as outputShape, output_data_types as outputType, output_formats as outputFormat ";
    sql += pmuColumnNames;
    sql += " FROM ("
           " SELECT * FROM ";
    sql += TABLE_KERNEL;
    // The device predicate is present only when a deviceId was supplied.
    sql += hasDevice ? " WHERE deviceId = ? AND accelerator_core " : " WHERE accelerator_core ";
    sql += GetAcceleratorCoreSql(isCommunication);
    sql += topKOnly ? " ORDER by duration DESC LIMIT ?" : " ORDER by start_time DESC";
    sql += " ) subquery ";
    GenerateQueryFiltersSql<Protocol::OperatorStatisticReqParams>(reqParams, sql);
    return sql;
}
/**
 * Builds the paged detail query: base SQL, optional ORDER BY, then "LIMIT ? OFFSET ?".
 *
 * The ORDER BY clause is added only when orderBy passes StringUtil::CheckSqlValid and both
 * orderBy and order are non-empty; "ascend" selects ASC, anything else DESC.
 * @return the SQL text, or an empty string when the base SQL is empty.
 */
std::string TextSummaryDataBase::GenerateQueryDetailSql(Protocol::OperatorStatisticReqParams &reqParams)
{
    std::string sql = GetQueryDetailBaseSql(reqParams, true);
    if (sql.empty()) {
        return "";
    }
    if (!StringUtil::CheckSqlValid(reqParams.orderBy)) {
        // The two literals used to join without a separating space ("orderByto generate").
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy "
                         "to generate query detail sql.");
    } else if (!reqParams.orderBy.empty() && !reqParams.order.empty()) {
        sql += " ORDER by " + reqParams.orderBy + " " + (reqParams.order == "ascend" ? "ASC" : "DESC");
    }
    sql += " LIMIT ? OFFSET ?";
    return sql;
}
bool TextSummaryDataBase::ExecSqlGetDetailInfo(std::string sql, Protocol::OperatorStatisticReqParams &reqParams,
std::vector<Protocol::OperatorDetailInfoRes> &res, std::string &level) {
if (sql.empty()) {
ServerLog::Error("Failed to generate query statistic sql.");
return false;
}
sqlite3_stmt *stmt = nullptr;
int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
if (result != SQLITE_OK) {
ServerLog::Error("Failed to get Detail Info. Msg:", sqlite3_errmsg(db), " ", result);
return false;
}
int index = bindStartIndex;
sqlite3_bind_int64(
stmt, index++, NumberUtil::CeilingClamp(Timeline::TraceTime::Instance().GetStartTime(), (uint64_t)INT64_MAX));
if (!reqParams.deviceId.empty()) {
sqlite3_bind_text(stmt, index++, reqParams.deviceId.c_str(), reqParams.deviceId.length(), SQLITE_TRANSIENT);
}
if (!reqParams.isCompare) {
sqlite3_bind_int64(stmt, index++, reqParams.topK);
BindQueryFilters(reqParams, stmt, index);
sqlite3_bind_int64(stmt, index++, reqParams.pageSize);
sqlite3_bind_int64(stmt, index++, (reqParams.current - 1) * reqParams.pageSize);
} else {
BindQueryFilters(reqParams, stmt, index);
}
ExecSqlGetRes(stmt, res);
level = OperatorGetLevel(res);
sqlite3_finalize(stmt);
return true;
}
/**
 * Steps stmt to completion and appends one OperatorDetailInfoRes per row to res.
 *
 * The fixed columns are read in select-list order, followed by one value per PMU column name
 * returned by FetchPmuColumnNames.
 *
 * @return always true.
 */
bool TextSummaryDataBase::ExecSqlGetRes(sqlite3_stmt *stmt, std::vector<Protocol::OperatorDetailInfoRes> &res)
{
    // Fetch the PMU column names once. Calling FetchPmuColumnNames per row copied the set for every
    // row and, when the set is empty, re-ran the pragma query for every row.
    const std::set<std::string> pmuColumns = FetchPmuColumnNames();
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        int col = 0;
        Protocol::OperatorDetailInfoRes one{};
        one.rankId = sqlite3_column_string(stmt, col++);
        one.stepId = sqlite3_column_string(stmt, col++);
        one.name = sqlite3_column_string(stmt, col++);
        one.type = sqlite3_column_string(stmt, col++);
        one.accCore = sqlite3_column_string(stmt, col++);
        one.startTime = sqlite3_column_string(stmt, col++);
        one.duration = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
        one.waitTime = Sqlite3ColumnConvertStr(SQLITE_FLOAT, stmt, col++);
        one.blockNum = Sqlite3ColumnConvertStr(SQLITE_INTEGER, stmt, col++);
        one.inputShape = sqlite3_column_string(stmt, col++);
        one.inputType = sqlite3_column_string(stmt, col++);
        one.inputFormat = sqlite3_column_string(stmt, col++);
        one.outputShape = sqlite3_column_string(stmt, col++);
        one.outputType = sqlite3_column_string(stmt, col++);
        one.outputFormat = sqlite3_column_string(stmt, col++);
        for (const auto &pmuCol : pmuColumns) {
            one.pmuDatas[pmuCol] = sqlite3_column_string(stmt, col++);
        }
        res.emplace_back(one);
    }
    return true;
}
/**
 * Queries operator detail rows using the base detail SQL generated with its second argument false.
 *
 * @param reqParams request parameters used to build and bind the query.
 * @param res       receives the detail rows.
 * @param level     receives the level computed by ExecSqlGetDetailInfo.
 * @return the result of ExecSqlGetDetailInfo.
 */
bool TextSummaryDataBase::QueryAllOperatorDetailInfo(Protocol::OperatorStatisticReqParams &reqParams,
    std::vector<Protocol::OperatorDetailInfoRes> &res, std::string &level) {
    return ExecSqlGetDetailInfo(GetQueryDetailBaseSql(reqParams, false), reqParams, res, level);
}
/**
 * Answers a paged operator detail request.
 *
 * Fills response.total first, then runs the paged detail query and stores every row in the
 * `compare` member of a fresh OperatorDetailCmpInfoRes. Finally sets response.pmuHeaders.
 *
 * @param reqParams request parameters (filters, ordering, paging).
 * @param response  receives total, level, data and pmuHeaders.
 * @return false when the total count or the detail query fails, true otherwise.
 */
bool TextSummaryDataBase::QueryOperatorDetailInfo(
    Protocol::OperatorStatisticReqParams &reqParams, Protocol::OperatorDetailInfoResponse &response) {
    if (!QueryDetailTotalNum(reqParams, response.total)) {
        ServerLog::Error("[Operator]Failed to query total num of detail info.");
        return false;
    }

    std::vector<Protocol::OperatorDetailInfoRes> detailRows;
    const bool queried = ExecSqlGetDetailInfo(GenerateQueryDetailSql(reqParams), reqParams, detailRows,
        response.level);
    if (!queried) {
        return false;
    }

    // Wrap each row so it fits the response's element type; only `compare` is populated here.
    std::vector<Protocol::OperatorDetailCmpInfoRes> wrappedRows;
    wrappedRows.reserve(detailRows.size());
    for (auto &row : detailRows) {
        OperatorDetailCmpInfoRes wrapped;
        wrapped.compare = row;
        wrappedRows.emplace_back(wrapped);
    }
    response.data = wrappedRows;
    response.pmuHeaders = FetchPmuColumnNames();
    return true;
}
/**
 * Counts the kernel rows that match a "more info" request.
 *
 * The inner query selects by deviceId, accelerator_core and either op_type (type-based groups)
 * or name + input_shapes (all other groups); request filters are appended to the outer query.
 *
 * @param reqParams request parameters supplying the group, key values and filters.
 * @param total     receives the count; left untouched if the statement yields no row.
 * @return false when the statement cannot be prepared, true otherwise.
 */
bool TextSummaryDataBase::QueryMoreInfoTotalNum(Protocol::OperatorMoreInfoReqParams &reqParams, int64_t &total) {
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    const bool groupedByType = IsOperatorGroupInType(operatorGroup);
    const std::string condition = groupedByType ? " op_type = ?" : " name = ? AND input_shapes = ?";

    std::string sql = " SELECT COUNT(*) as nums"
                      " FROM ("
                      " SELECT name, op_type as type, accelerator_core as accCore"
                      " FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE deviceId = ? AND accelerator_core = ? AND";
    sql += condition;
    sql += " ) subquery";
    GenerateQueryFiltersSql<Protocol::OperatorMoreInfoReqParams>(reqParams, sql);

    sqlite3_stmt *countStmt = nullptr;
    int prepareCode = sqlite3_prepare_v2(db, sql.c_str(), -1, &countStmt, nullptr);
    if (prepareCode != SQLITE_OK) {
        ServerLog::Error("Failed to get More Total Num. Msg: ", sqlite3_errmsg(db), " ", prepareCode);
        return false;
    }

    // Values are bound in the same order as the placeholders appear in the SQL text above.
    int bindPos = bindStartIndex;
    sqlite3_bind_text(countStmt, bindPos++, reqParams.deviceId.c_str(), reqParams.deviceId.length(),
        SQLITE_TRANSIENT);
    sqlite3_bind_text(countStmt, bindPos++, reqParams.accCore.c_str(), reqParams.accCore.length(),
        SQLITE_TRANSIENT);
    if (groupedByType) {
        sqlite3_bind_text(countStmt, bindPos++, reqParams.opType.c_str(), reqParams.opType.length(),
            SQLITE_TRANSIENT);
    } else {
        sqlite3_bind_text(countStmt, bindPos++, reqParams.opName.c_str(), reqParams.opName.length(),
            SQLITE_TRANSIENT);
        sqlite3_bind_text(countStmt, bindPos++, reqParams.shape.c_str(), reqParams.shape.length(),
            SQLITE_TRANSIENT);
    }
    BindQueryFilters(reqParams, countStmt, bindPos);

    while (sqlite3_step(countStmt) == SQLITE_ROW) {
        total = sqlite3_column_int64(countStmt, resultStartIndex);
    }
    sqlite3_finalize(countStmt);
    return true;
}
/**
 * Builds the paged SQL text for an operator "more info" request.
 *
 * Placeholder order in the produced statement: trace start time, deviceId, accelerator_core,
 * then type (type-based groups) or name + inputShape (other groups), one value per filter,
 * and finally LIMIT and OFFSET.
 *
 * @param reqParams request parameters; group, filters, orderBy and order are read here.
 * @return the SQL text.
 */
std::string TextSummaryDataBase::GenerateQueryMoreInfoSql(Protocol::OperatorMoreInfoReqParams &reqParams) {
    std::set<std::string> pmuClos = FetchPmuColumnNames();
    std::string pmuColumnNames = JoinExtraColName(std::vector<std::string>(pmuClos.begin(), pmuClos.end()));
    std::string sql =
        " SELECT deviceId, step_id, name, op_type as type, accelerator_core as accCore,"
        " CASE WHEN start_time == 0 THEN 'NA' ELSE ROUND((start_time - ?) / (1000.0 * 1000.0), 2)"
        " END AS startTime, duration, wait_time as waitTime, block_num as blockNum,"
        " input_shapes as inputShape, input_data_types as inputType, input_formats as inputFormat, "
        " output_shapes as outputShape, output_data_types as outputType, output_formats as outputFormat " +
        pmuColumnNames +
        " FROM ("
        " SELECT * FROM " +
        TABLE_KERNEL +
        " WHERE deviceId = ? AND accelerator_core = ?"
        " ORDER by duration DESC"
        " ) subquery ";
    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    if (IsOperatorGroupInType(operatorGroup)) {
        sql += " WHERE type = ?";
    } else {
        sql += " WHERE name = ? AND inputShape = ?";
    }
    // NOTE(review): filter.first (a column name) is spliced into the statement text unchecked.
    // It cannot simply be skipped here because the bind step emits one value per filter;
    // confirm that callers validate the filter keys.
    for (const auto &filter : reqParams.filters) {
        sql += " AND " + filter.first + " LIKE ? ";
    }
    // orderBy is spliced into the statement text as well, so it is only used after it passes
    // StringUtil::CheckSqlValid; when the check fails the query is built without ORDER BY.
    if (!StringUtil::CheckSqlValid(reqParams.orderBy)) {
        ServerLog::Error("There is an SQL injection attack on the parameter of orderBy "
                         "to generate query more info sql.");
    } else if (!reqParams.orderBy.empty() && !reqParams.order.empty()) {
        // Any order value other than "ascend" sorts descending.
        sql += " ORDER by " + reqParams.orderBy + " " + (reqParams.order == "ascend" ? "ASC" : "DESC");
    }
    sql += " LIMIT ? OFFSET ?";
    return sql;
}
/**
 * Answers a paged operator "more info" request.
 *
 * Fills response.total, prepares the statement from GenerateQueryMoreInfoSql, rejects a page
 * number that is not positive, then binds, executes and copies the rows into the response.
 *
 * @param reqParams request parameters (group, key values, filters, ordering, paging).
 * @param response  receives total, level, data and pmuHeaders.
 * @return false when the count query fails, the SQL is empty, the statement cannot be prepared
 *         or reqParams.current is <= 0; true otherwise.
 */
bool TextSummaryDataBase::QueryOperatorMoreInfo(
    Protocol::OperatorMoreInfoReqParams &reqParams, Protocol::OperatorMoreInfoResponse &response) {
    if (!QueryMoreInfoTotalNum(reqParams, response.total)) {
        ServerLog::Error("[Operator]Failed to query total num of more info.");
        return false;
    }

    const std::string moreInfoSql = GenerateQueryMoreInfoSql(reqParams);
    if (moreInfoSql.empty()) {
        ServerLog::Error("Failed to generate query statistic sql.");
        return false;
    }

    sqlite3_stmt *query = nullptr;
    int prepareCode = sqlite3_prepare_v2(db, moreInfoSql.c_str(), -1, &query, nullptr);
    if (prepareCode != SQLITE_OK) {
        ServerLog::Error("Failed to get Op More Info. Msg: ", sqlite3_errmsg(db), " ", prepareCode);
        return false;
    }

    // The page number feeds the OFFSET computation, so it is validated before anything is bound.
    if (reqParams.current <= 0) {
        ServerLog::Error("The current page is less than or equal to 0");
        sqlite3_finalize(query);
        return false;
    }

    BindSqliteParam(query, reqParams);
    std::vector<Protocol::OperatorDetailInfoRes> rows = ExecSqlGetMoreInfo(query);
    response.level = OperatorGetLevel(rows);
    response.data = rows;
    response.pmuHeaders = pmuColumns_;
    sqlite3_finalize(query);
    return true;
}
/**
 * Steps through a prepared "more info" statement and returns one OperatorDetailInfoRes per row.
 *
 * Columns are read positionally: the fifteen fixed columns first, followed by one column per
 * entry of pmuColumns_, in that container's iteration order.
 *
 * @param stmt prepared and bound statement; it is not finalized here.
 * @return the rows read until the first step result that is not SQLITE_ROW.
 */
std::vector<Protocol::OperatorDetailInfoRes> TextSummaryDataBase::ExecSqlGetMoreInfo(sqlite3_stmt *stmt) {
    std::vector<Protocol::OperatorDetailInfoRes> rows;
    while (sqlite3_step(stmt) == SQLITE_ROW) {
        int column = 0;
        // Each helper reads the current column and then advances to the next one.
        auto nextText = [&] { return sqlite3_column_string(stmt, column++); };
        auto nextNumber = [&](int sqliteType) { return Sqlite3ColumnConvertStr(sqliteType, stmt, column++); };

        Protocol::OperatorDetailInfoRes row{};
        row.rankId = nextText();
        row.stepId = nextText();
        row.name = nextText();
        row.type = nextText();
        row.accCore = nextText();
        row.startTime = nextText();
        row.duration = nextNumber(SQLITE_FLOAT);
        row.waitTime = nextNumber(SQLITE_FLOAT);
        row.blockNum = nextNumber(SQLITE_INTEGER);
        row.inputShape = nextText();
        row.inputType = nextText();
        row.inputFormat = nextText();
        row.outputShape = nextText();
        row.outputType = nextText();
        row.outputFormat = nextText();
        // The remaining columns are the PMU values, keyed by their column name.
        for (const auto &pmuName : pmuColumns_) {
            row.pmuDatas[pmuName] = nextText();
        }
        rows.push_back(row);
    }
    return rows;
}
/**
 * Binds every placeholder of the "more info" statement.
 *
 * Bind order: trace start time, deviceId, accCore, then opType (type-based groups) or
 * opName + shape (other groups), the filter values, pageSize and finally the row offset.
 *
 * @param stmt      prepared statement to bind into.
 * @param reqParams request parameters supplying the values.
 */
void TextSummaryDataBase::BindSqliteParam(sqlite3_stmt *stmt, Protocol::OperatorMoreInfoReqParams &reqParams) {
    uint64_t traceStart = Timeline::TraceTime::Instance().GetStartTime();
    int bindPos = bindStartIndex;

    // The start time is clamped to INT64_MAX before it is handed to the signed 64-bit bind call.
    sqlite3_bind_int64(stmt, bindPos++, NumberUtil::CeilingClamp(traceStart, static_cast<uint64_t>(INT64_MAX)));
    sqlite3_bind_text(stmt, bindPos++, reqParams.deviceId.c_str(), reqParams.deviceId.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, bindPos++, reqParams.accCore.c_str(), reqParams.accCore.length(), SQLITE_TRANSIENT);

    OperatorGroupConverter::OperatorGroup operatorGroup = Protocol::OperatorGroupConverter::ToEnum(reqParams.group);
    if (!IsOperatorGroupInType(operatorGroup)) {
        sqlite3_bind_text(stmt, bindPos++, reqParams.opName.c_str(), reqParams.opName.length(), SQLITE_TRANSIENT);
        sqlite3_bind_text(stmt, bindPos++, reqParams.shape.c_str(), reqParams.shape.length(), SQLITE_TRANSIENT);
    } else {
        sqlite3_bind_text(stmt, bindPos++, reqParams.opType.c_str(), reqParams.opType.length(), SQLITE_TRANSIENT);
    }

    BindQueryFilters(reqParams, stmt, bindPos);

    // Paging: LIMIT receives the page size, OFFSET the number of rows on the preceding pages.
    sqlite3_bind_int64(stmt, bindPos++, reqParams.pageSize);
    sqlite3_bind_int64(stmt, bindPos++, (reqParams.current - 1) * reqParams.pageSize);
}
/**
 * Appends a WHERE clause with one "<column> LIKE ?" term per request filter, joined by AND.
 * Nothing is appended when reqParams.filters is empty.
 *
 * NOTE(review): filter.first is placed into the statement text as-is; only the filter value
 * is a bound parameter. Confirm that the filter keys are validated before they reach this point.
 *
 * @tparam T        request type exposing a `filters` sequence of (column, value) pairs.
 * @param reqParams request whose filters are rendered.
 * @param sql       statement text that is extended in place.
 */
template <typename T> void TextSummaryDataBase::GenerateQueryFiltersSql(T &reqParams, std::string &sql) {
    // The first term is introduced by WHERE, every following term by AND.
    const char *joiner = " WHERE ";
    for (const std::pair<std::string, std::string> &filter : reqParams.filters) {
        sql += joiner;
        sql += filter.first;
        sql += " LIKE ? ";
        joiner = " AND ";
    }
}
/**
 * Binds one text value per request filter, each wrapped as "%value%" for a LIKE comparison.
 *
 * @tparam T        request type exposing a `filters` sequence of (column, value) pairs.
 * @param reqParams request whose filter values are bound.
 * @param stmt      prepared statement to bind into.
 * @param index     next bind position; advanced by one for every filter bound.
 */
template <typename T> void TextSummaryDataBase::BindQueryFilters(T &reqParams, sqlite3_stmt *stmt, int &index) {
    for (const std::pair<std::string, std::string> &filter : reqParams.filters) {
        const std::string likePattern = "%" + filter.second + "%";
        // SQLITE_TRANSIENT is passed because likePattern goes out of scope after each iteration.
        sqlite3_bind_text(stmt, index++, likePattern.c_str(), likePattern.length(), SQLITE_TRANSIENT);
    }
}
std::set<std::string> TextSummaryDataBase::QueryRankIds() {
std::set<std::string> rankIds = {};
std::string sql = "SELECT deviceId FROM " + TABLE_KERNEL + " GROUP BY deviceId";
sqlite3_stmt *stmt = nullptr;
int result = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
if (result != SQLITE_OK) {
ServerLog::Error("Failed to prepare stmt for QueryRankIds. Msg: ", sqlite3_errmsg(db));
return rankIds;
}
while (sqlite3_step(stmt) == SQLITE_ROW) {
int col = resultStartIndex;
std::string rank = sqlite3_column_string(stmt, col++);
rankIds.emplace(rank);
}
sqlite3_finalize(stmt);
return rankIds;
}
/**
 * Tells whether the operator group is one of the type-based groups.
 *
 * @param operatorGroup group to classify.
 * @return true for OP_TYPE_GROUP and COMMUNICATION_TYPE_GROUP, false for every other value.
 */
bool TextSummaryDataBase::IsOperatorGroupInType(OperatorGroupConverter::OperatorGroup operatorGroup) {
    if (operatorGroup == OperatorGroupConverter::OperatorGroup::OP_TYPE_GROUP) {
        return true;
    }
    return operatorGroup == OperatorGroupConverter::OperatorGroup::COMMUNICATION_TYPE_GROUP;
}
/**
 * Queries name, start time (start_time / 1000.0 rounded to 3 decimals) and duration of every
 * kernel whose name matches LIKE '%matmul%', ordered by start time.
 *
 * @param res receives the rows produced by ExecuteQueryBandwidthContentionMatMulData.
 * @return the result of ExecuteQueryBandwidthContentionMatMulData.
 */
bool TextSummaryDataBase::QueryBandwidthContentionMatMulData(std::vector<BandwidthContentionMatMulInfo> &res) {
    std::string sql = "SELECT name, ROUND(start_time / 1000.0, 3) AS startTime, duration FROM ";
    sql += TABLE_KERNEL;
    sql += " WHERE name LIKE '%matmul%' ORDER BY startTime";
    return ExecuteQueryBandwidthContentionMatMulData(res, sql);
}
}