* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#ifndef PROFILER_SERVER_VIRTUALMEMORYDATABASE_H
#define PROFILER_SERVER_VIRTUALMEMORYDATABASE_H
#include <cstdint>
#include <map>
#include <mutex>
#include <set>
#include <string>
#include <string_view>
#include <vector>
#include "Database.h"
#include "MemoryProtocolUtil.h"
#include "MemoryProtocolRequest.h"
#include "MemoryProtocolRespose.h"
#include "MemoryTableView.h"
namespace Dic {
namespace Module {
namespace Memory {
// Shorthand for a list of Protocol::ComponentDto values; used by the overall/component/stream line builders below.
using componentDtoVector = std::vector<Protocol::ComponentDto>;
// Makes Dic::Protocol names (e.g. FiltersParam, RangeFiltersParam, OrderByParam) usable unqualified in this namespace.
// NOTE(review): a using-directive in a header leaks into every includer; kept because dependents may rely on it.
using namespace Dic::Protocol;
/**
 * @brief Four boolean flags describing a memory database.
 *
 * The flags are named after memory data sources (memory record, NPU module memory, operator memory, NPU memory).
 * Presumably each is set when the corresponding data is available; confirm at the call sites.
 * Every flag defaults to false, so a default- or value-initialised context has all flags cleared.
 */
struct MemoryDataBaseContext {
    bool withMemoryRecord = false;
    bool withNpuModuleMem = false;
    bool withOperatorMemory = false;
    bool withNpuMem = false;
};
class VirtualMemoryDataBase : public Database {
public:
explicit VirtualMemoryDataBase(std::recursive_mutex &sqlMutex) : Database(sqlMutex) {};
;
~VirtualMemoryDataBase() override = default;
virtual bool QueryMemoryType(std::string &type, std::vector<std::string> &graphId) = 0;
virtual bool QueryMemoryResourceType(std::string &type) = 0;
virtual int64_t QueryOperatorDetail(
Protocol::MemoryOperatorParams &requestParams, std::vector<Protocol::MemoryOperator> &opDetails) = 0;
virtual bool QueryComponentDetail(Protocol::MemoryComponentParams &requestParams,
std::vector<Protocol::MemoryTableColumnAttr> &columnAttr,
std::vector<Protocol::MemoryComponent> &componentDetails) = 0;
virtual bool QueryMemoryView(
Protocol::MemoryViewParams &requestParams, Protocol::MemoryViewData &operatorBody, uint64_t offsetTime) = 0;
virtual int64_t QueryStaticOperatorList(
Protocol::StaticOperatorListParams &requestParams, std::vector<Protocol::StaticOperatorItem> &opDetails) = 0;
virtual bool QueryStaticOperatorGraph(
Protocol::StaticOperatorGraphParams &requestParams, Protocol::StaticOperatorGraphItem &graphItem) = 0;
virtual bool QueryComponentsTotalNum(Protocol::MemoryComponentParams &requestParams, int64_t &totalNum) = 0;
virtual bool QueryOperatorSize(Protocol::MemoryOperatorSizeParams &requestParams, double &min, double &max) = 0;
virtual bool QueryStaticOperatorSize(
Protocol::StaticOperatorSizeParams &requestParams, double &min, double &max) = 0;
virtual bool QueryEntireOperatorTable(Protocol::MemoryOperatorParams &requestParams,
std::vector<Protocol::MemoryOperator> &opDetails, uint64_t offsetTime) = 0;
virtual bool QueryEntireComponentTable(Protocol::MemoryComponentParams &requestParams,
std::vector<Protocol::MemoryComponent> &componentDetails, uint64_t offsetTime) = 0;
virtual bool QueryEntireStaticOperatorTable(
Protocol::StaticOperatorListParams &requestParams, std::vector<Protocol::StaticOperatorItem> &opDetails) = 0;
virtual void GetSelectOperatorMemoryColumnAndAlias(
std::string_view columnKey, uint64_t baseTimestamp, std::string &column, std::string &alias) = 0;
void GetStaticOperatorColumns(std::vector<Protocol::MemoryTableColumnAttr> ©To);
virtual MemoryDataBaseContext GetMemoryDbContext() = 0;
const int defaultPageSize = 10;
const int64_t maxPageSize = 1000;
protected:
const std::string operatorTable = "operator";
const std::string recordTable = "record";
const int64_t maxUnsignedInt = 4294967295;
const int64_t maxCurrentPage = 10000000000;
const double kbSizeDouble = 1024.0;
const double staticDefaultTotalSize = -1.0;
const double componentThresholdMb = 100.0;
const double componentThresholdByte = 100.0 * 1024.0 * 1024.0;
bool isInference = false;
bool initContextFlag = false;
MemoryDataBaseContext memDbContext = {};
const std::vector<std::string> baseLegends = {
"Time (ms)", "Operators Allocated", "Operators Activated", "Operators Reserved"};
const std::vector<std::string> workspaceLegends = {"Workspace Allocated", "Workspace Reserved"};
const std::vector<std::string> componentTimeLegends = {"Time (ms)"};
const std::vector<std::string> componentPtaLegends = {"PTA Allocated", "PTA Activated", "PTA Reserved"};
const std::vector<std::string> componentGeLegends = {"GE Allocated", "GE Activated", "GE Reserved"};
const std::vector<std::string> staticGraphLegends = {"Node Index", "Size", "Total Size"};
const std::string appLegend = "App Reserved";
const std::vector<Protocol::MemoryTableColumnAttr> tableColumnAttr = {{"Name", "string", "name"},
{"Size(KB)", "number", "size"}, {"Allocation Time(ms)", "number", "allocationTime"},
{"Release Time(ms)", "number", "releaseTime"}, {"Duration(ms)", "number", "duration"},
{"Active Release Time(ms)", "number", "activeReleaseTime"}, {"Active Duration(ms)", "number", "activeDuration"},
{"Allocation Total Allocated(MB)", "number", "allocationAllocated"},
{"Allocation Total Reserved(MB)", "number", "allocationReserved"},
{"Allocation Total Active(MB)", "number", "allocationActive"},
{"Release Total Allocated(MB)", "number", "releaseAllocated"},
{"Release Total Reserved(MB)", "number", "releaseReserved"},
{"Release Total Active(MB)", "number", "releaseActive"}, {"Stream", "string", "streamId"}};
const std::vector<Protocol::MemoryTableColumnAttr> staticOpTableColumnAttr = {
{"Device ID", "string", std::string(StaticOpColumn::DEVICE_ID)},
{"Name", "string", std::string(StaticOpColumn::OP_NAME)},
{"Node Index Start", "number", std::string(StaticOpColumn::NODE_INDEX_START)},
{"Node Index End", "number", std::string(StaticOpColumn::NODE_INDEX_END)},
{"Size(MB)", "number", std::string(StaticOpColumn::SIZE)}};
const std::vector<std::string> activeRelatedColumn = {"Active Release Time(ms)", "Active Duration(ms)",
"Allocation Total Active(MB)", "Release Total Active(MB)", "Stream"};
const std::vector<Protocol::MemoryTableColumnAttr> componentTableColumnAttr = {{"Component", "string", "component"},
{"Peak Memory Reserved(MB)", "number", "totalReserved"}, {"Timestamp(ms)", "number", "timestamp"}};
const std::set<std::string_view> timestampColumn = {
OpMemoryColumn::ALLOCATION_TIME, OpMemoryColumn::RELEASE_TIME, OpMemoryColumn::ACTIVE_RELEASE_TIME};
const std::string COMPONENT_APP = "APP";
const std::string COMPONENT_GE = "GE";
const std::string MIND_SPORE = "MindSpore";
const std::string COMPONENT_PTA = "PTA";
const std::string COMPONENT_PTA_AND_GE = "PTA+GE";
const std::string COMPONENT_WORKSPACE = "WORKSPACE";
const std::string MIND_SPORE_GE = "MindSpore+GE";
const std::set<std::string_view> OPERATOR_MEMORY_ARA_SIZE_COLUMNS = {OpMemoryColumn::ALLOCATION_ALLOCATED,
OpMemoryColumn::ALLOCATION_RESERVE, OpMemoryColumn::ALLOCATION_ACTIVE, OpMemoryColumn::RELEASE_ALLOCATED,
OpMemoryColumn::RELEASE_RESERVE, OpMemoryColumn::RELEASE_ACTIVE};
std::vector<std::string> GetStreamLists(std::string deviceId, std::string deviceIdColumnName);
bool ExecuteMemoryType(std::vector<std::string> &graphId, std::string &type);
bool ExecuteMemoryResourceType(std::string &type, std::string sql);
bool ExecuteOperatorSize(
Protocol::MemoryOperatorSizeParams &requestParams, double &min, double &max, std::string sql);
bool ExecuteStaticOperatorSize(
Protocol::StaticOperatorSizeParams &requestParams, double &min, double &max, const std::string &sql);
bool ExecuteComponentTotalNum(Protocol::MemoryComponentParams &requestParams, int64_t &totalNum, std::string &sql);
bool ExecuteStaticOperatorListTotalNum(
Protocol::StaticOperatorListParams &requestParams, int64_t &totalNum, std::string sql);
bool ExecuteQueryMemoryViewExecuteSql(Protocol::MemoryViewParams &requestParams,
std::vector<Protocol::ComponentDto> &componentDtoVec, std::vector<std::string> &streams, std::string &sql,
std::string deviceIdColumnName);
bool ExecuteQueryMemoryViewGetGraph(Protocol::MemoryViewParams &requestParams,
std::vector<Protocol::ComponentDto> &componentDtoVec, std::vector<std::string> &streams,
Protocol::MemoryViewData &operatorBody);
int64_t ExecuteOperatorDetail(Protocol::MemoryOperatorParams &requestParams,
std::vector<Protocol::MemoryOperator> &opDetails, std::string &sql);
bool ExecuteQueryEntireOperatorTable(Protocol::MemoryOperatorParams &requestParams,
std::vector<Protocol::MemoryOperator> &opDetails, const std::string &sql);
bool ExecuteComponentDetail(Protocol::MemoryComponentParams &requestParams,
std::vector<Protocol::MemoryTableColumnAttr> &columnAttr,
std::vector<Protocol::MemoryComponent> &componentDetails, std::string &sql);
bool ExecuteQueryEntireComponentTable(Protocol::MemoryComponentParams &requestParams,
std::vector<Protocol::MemoryComponent> &componentDetails, std::string &sql);
bool ExecuteStaticOperatorGraph(Protocol::StaticOperatorGraphParams &requestParams,
Protocol::StaticOperatorGraphItem &graphItem, const std::string &totalSql, const std::string &graphStartSql,
const std::string &graphEndSql);
bool ExecuteStaticGraphTotalSize(
Protocol::StaticOperatorGraphParams &requestParams, const std::string &graphStartSql, double &maxIndex);
bool ExecuteStaticGraphStartIndex(Protocol::StaticOperatorGraphParams &requestParams,
const std::string &graphStartSql, std::map<int64_t, double> &graphSizeMap, int64_t &maxIndex);
bool ExecuteStaticGraphEndIndex(Protocol::StaticOperatorGraphParams &requestParams, const std::string &graphEndSql,
std::map<int64_t, double> &graphSizeMap, int64_t &maxIndex);
int64_t ExecuteStaticOperatorDetail(Protocol::StaticOperatorListParams &requestParams,
std::vector<Protocol::StaticOperatorItem> &opDetails, const std::string &sql);
bool ExecuteQueryEntireStaticOperatorTable(Protocol::StaticOperatorListParams &requestParams,
std::vector<Protocol::StaticOperatorItem> &opDetails, const std::string &sql);
void AddOperatorSql(Protocol::MemoryOperatorParams requestParams, std::string &sql);
void AddStableOperatorSql(Protocol::StaticOperatorListParams requestParams, std::string &sql);
std::string GetSelectOperatorMemoryFullColumnsWithCount(uint64_t baseTimestamp);
static std::string BuildQueryOperatorMemoryTimeCondition(const Protocol::MemoryOperatorParams &requestParams);
static std::string BuildQueryFiltersCondition(const FiltersParam &requestParams);
static std::string BuildQueryRangeFiltersCondition(const RangeFiltersParam &requestParams);
static std::string BuildQueryOrderByCondition(const OrderByParam &orderParam);
static void SqlBindQueryFilters(sqlite3_stmt *stmt, int &bindIndex, const FiltersParam ¶ms);
static void SqlBindQueryRangeFilters(sqlite3_stmt *stmt, int &bindIndex, const RangeFiltersParam ¶ms);
private:
void BuildOverallLinesComponentPoints(const Protocol::ComponentDto &item, const std::vector<std::string> &streams,
Protocol::MemoryPeak &peak, std::vector<double> &lines);
void BuildOverallLinesFrameworkPoints(const Protocol::ComponentDto &item, const std::vector<std::string> &streams,
Protocol::MemoryPeak &peak, std::vector<double> &lines);
void BuildOverallLinesWorkspacePoints(const Protocol::ComponentDto &item, const std::vector<std::string> &streams,
Protocol::MemoryPeak &peak, std::vector<double> &lines);
void GetOverallLines(const componentDtoVector &componentDtoVec, std::vector<double> &lines,
std::vector<std::string> &legends, Protocol::MemoryPeak &peak, const std::vector<std::string> &streams);
void GetOverallLinesLegends(const componentDtoVector &componentDtoVec, std::vector<std::string> &legends,
Protocol::MemoryPeak &peak, const std::vector<std::string> &streams);
std::string GetPeakMemory(const Protocol::MemoryPeak &peak, const std::vector<std::string> &streams);
void GetComponentLines(const componentDtoVector &componentDtoVec, std::vector<double> &lines,
std::vector<std::string> &legends, Protocol::MemoryPeak &peak, const std::vector<std::string> &streams);
void GetComponentLinesLegends(
const componentDtoVector &componentDtoVec, std::vector<std::string> &legends, Protocol::MemoryPeak &peak);
void InsertSize(std::vector<double> &points, const Protocol::ComponentDto &item);
void InsertStringNull(std::vector<double> &points, const int times);
void GetStreamLines(const componentDtoVector &componentDtoVec, std::vector<double> &lines,
std::vector<std::string> &legends, Protocol::MemoryPeak &peak, const std::vector<std::string> &streams);
int64_t QueryOperatorDetailByStepWithCount(sqlite3_stmt *stmt, std::vector<Protocol::MemoryOperator> &operators);
std::string GetCurveSql(const Protocol::MemoryViewParams &requestParams, std::string &sql) const;
static std::string ConvertTimestampStr(const std::string ×tampStr);
};
};
}
}
#endif