* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#ifndef PROFILER_SERVER_SUMMARY_DATABASE_H
#define PROFILER_SERVER_SUMMARY_DATABASE_H
#include "Database.h"
#include "ProtocolMessage.h"
#include "SummaryDef.h"
#include "SummaryProtocolRequest.h"
#include "SummaryProtocolResponse.h"
#include "OperatorProtocolRequest.h"
#include "OperatorProtocolResponse.h"
#include "TableDefs.h"
#include "NumDefs.h"
#include "CommonDefs.h"
#include "ClusterDef.h"
namespace Dic::Module::Summary {
using namespace Dic::Server;
/**
 * Abstract layer on top of Database: it declares the pure-virtual summary/operator
 * queries that concrete subclasses must implement, and bundles a few shared helpers
 * (row extraction for the bandwidth-contention query, combination-id parsing,
 * level detection and extra-column formatting).
 */
class VirtualSummaryDataBase : public Database {
public:
    // Last level decided by OperatorGetLevel(); reused when a later result set is empty.
    bool levelIsL0 = true;

    explicit VirtualSummaryDataBase(std::recursive_mutex &sqlMutex) : Database(sqlMutex) {}
    ~VirtualSummaryDataBase() override = default;

    // ---------------------------------------------------------------------
    // Queries implemented by concrete subclasses.
    // The meaning of the bool result is not visible in this header —
    // presumably true on success; confirm against the implementations.
    // ---------------------------------------------------------------------
    virtual bool QueryComputeOpDetail(
        Protocol::ComputeDetailParams params, std::vector<Protocol::ComputeDetail> &computeDetails) = 0;
    virtual bool QueryTotalNumByAcceleratorCore(std::string name, int64_t &totalNum) = 0;
    virtual bool QueryCommunicationOpDetail(
        Protocol::CommunicationDetailParams params, std::vector<Protocol::CommunicationDetail> &computeDetails) = 0;
    virtual bool QueryOperatorDurationInfo(Protocol::OperatorDurationReqParams &reqParams, Protocol::QueryType type,
        std::vector<Protocol::OperatorDurationRes> &data) = 0;
    virtual bool QueryOperatorStatisticInfo(
        Protocol::OperatorStatisticReqParams &reqParams, Protocol::OperatorStatisticInfoResponse &response) = 0;
    virtual bool QueryOperatorDetailInfo(
        Protocol::OperatorStatisticReqParams &reqParams, Protocol::OperatorDetailInfoResponse &response) = 0;
    virtual bool QueryAllOperatorDetailInfo(Protocol::OperatorStatisticReqParams &reqParams,
        std::vector<Protocol::OperatorDetailInfoRes> &res, std::string &level) = 0;
    virtual bool QueryOperatorMoreInfo(
        Protocol::OperatorMoreInfoReqParams &reqParams, Protocol::OperatorMoreInfoResponse &response) = 0;
    virtual bool QueryAllOperatorStatisticInfo(
        Protocol::OperatorStatisticReqParams &reqParams, std::vector<Protocol::OperatorStatisticInfoRes> &res) = 0;
    virtual bool QueryBandwidthContentionMatMulData(std::vector<BandwidthContentionMatMulInfo> &res) = 0;

    /**
     * Prepares and runs `sql`, appending one BandwidthContentionMatMulInfo to `res`
     * for every result row. Columns are consumed in the order name, startTime,
     * duration, beginning at `resultStartIndex` (declared outside this class).
     *
     * @param res  Output vector; rows are appended, existing content is kept.
     * @param sql  Statement text handed to sqlite3_prepare_v2.
     * @return true only when the statement was prepared and stepping finished with
     *         SQLITE_DONE. On a prepare failure or a step error the problem is
     *         logged and false is returned; rows appended before a step error
     *         remain in `res`.
     */
    bool ExecuteQueryBandwidthContentionMatMulData(std::vector<BandwidthContentionMatMulInfo> &res, std::string &sql)
    {
        sqlite3_stmt *stmt = nullptr;
        const int prepareResult = sqlite3_prepare_v2(db, sql.c_str(), -1, &stmt, nullptr);
        if (prepareResult != SQLITE_OK) {
            ServerLog::Error("Failed to query bandwidth contention matmul data. Error: ", sqlite3_errmsg(db));
            return false;
        }

        int stepResult = sqlite3_step(stmt);
        while (stepResult == SQLITE_ROW) {
            int col = resultStartIndex;
            // Build the row directly inside the vector instead of copying a local.
            BandwidthContentionMatMulInfo &info = res.emplace_back();
            info.name = sqlite3_column_string(stmt, col++);
            info.startTime = sqlite3_column_double(stmt, col++);
            info.duration = sqlite3_column_double(stmt, col++);
            stepResult = sqlite3_step(stmt);
        }

        // Any terminal code other than SQLITE_DONE (busy, error, misuse, ...) means the
        // result set is incomplete; report that instead of pretending success.
        const bool finished = (stepResult == SQLITE_DONE);
        if (!finished) {
            ServerLog::Error("Failed to read bandwidth contention matmul data. Error: ", sqlite3_errmsg(db));
        }
        sqlite3_finalize(stmt);
        return finished;
    }

    // Declared here, defined elsewhere.
    uint64_t QueryMinStartTime();

    /**
     * Extracts the part of a combination id that sits between MSPROF_PREFIX and the
     * connector.
     *
     * Returns `str` unchanged when it is not longer than MSPROF_PREFIX, does not start
     * with it, or contains none of the characters of MSPROF_CONNECT. Otherwise returns
     * the substring that starts right after the prefix and has length
     * `index - prefixLen - 1`, where `index` is the position of the last character of
     * `str` that belongs to MSPROF_CONNECT.
     *
     * NOTE(review): find_last_of matches any single character of MSPROF_CONNECT, not
     * the connector as a substring; the `- 1` looks like it assumes a two-character
     * connector — confirm against the definition of MSPROF_CONNECT.
     * NOTE(review): when `index <= prefixLen` the unsigned length wraps around and the
     * whole remainder after the prefix is returned.
     */
    static inline std::string GetFileIdFromCombinationId(const std::string &str)
    {
        auto len = MSPROF_PREFIX.length();
        if (str.length() <= len || str.compare(0, len, MSPROF_PREFIX) != 0) {
            return str;
        }
        auto index = str.find_last_of(MSPROF_CONNECT);
        if (index == std::string::npos) {
            return str;
        }
        return str.substr(len, index - len - 1);
    }

    /**
     * Extracts the trailing part of a combination id (the text after the connector).
     *
     * Returns `str` unchanged under the same conditions as GetFileIdFromCombinationId.
     * Otherwise returns everything from position `index + MSPROF_CONNECT.length() - 1`
     * onwards, where `index` is the position of the last character of `str` that
     * belongs to MSPROF_CONNECT.
     *
     * NOTE(review): same find_last_of / two-character-connector assumption as above;
     * with a longer connector the start position could exceed str.size(), in which
     * case std::string::substr throws std::out_of_range — confirm.
     */
    static inline std::string GetDeviceIdFromCombinationId(const std::string &str)
    {
        auto len = MSPROF_PREFIX.length();
        if (str.length() <= len || str.compare(0, len, MSPROF_PREFIX) != 0) {
            return str;
        }
        auto index = str.find_last_of(MSPROF_CONNECT);
        if (index == std::string::npos) {
            return str;
        }
        return str.substr(index + MSPROF_CONNECT.length() - 1);
    }

    /**
     * Derives the level string ("l0" or "l1") from a detail result set.
     *
     * - Empty `res`: nothing to inspect, so the level remembered in `levelIsL0` is used.
     * - Otherwise the first entry decides: an empty inputShape yields "l0", a
     *   non-empty one yields "l1", and `levelIsL0` is updated to match.
     */
    std::string OperatorGetLevel(const std::vector<Protocol::OperatorDetailInfoRes> &res)
    {
        if (res.empty()) {
            return levelIsL0 ? "l0" : "l1";
        }
        levelIsL0 = res.front().inputShape.empty();
        return levelIsL0 ? "l0" : "l1";
    }

    /**
     * Formats extra column names for appending to an existing column list.
     *
     * @return An empty string when `cols` is empty; otherwise "," followed by the
     *         result of StringUtil::join(cols, ',').
     */
    std::string JoinExtraColName(const std::vector<std::string> &cols)
    {
        if (cols.empty()) {
            return std::string();
        }
        std::string pmuColumnNames = ",";
        pmuColumnNames += StringUtil::join(cols, ',');
        return pmuColumnNames;
    }

    // Read-only access to pmuColumns_.
    const std::set<std::string> &GetPmuColumns() const { return pmuColumns_; }

    // Column-name set exposed through GetPmuColumns(); not populated in this header.
    std::set<std::string> pmuColumns_;
};
}
#endif