/*
* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#ifndef PROFILER_SERVER_DATABASEMANAGER_H
#define PROFILER_SERVER_DATABASEMANAGER_H
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "DBConnectionPool.h"
#include "TextTraceDatabase.h"
#include "DbTraceDataBase.h"
#include "TextClusterDatabase.h"
#include "VirtualMemoryDataBase.h"
#include "MemScopeDatabase.h"
#include "MemSnapshotDatabase.h"
#include "KernelParse.h"
namespace Dic {
namespace Module {
namespace Timeline {
// NOTE(review): this using-directive sits in a header, so it takes effect in every
// translation unit that includes this file. Declarations below may rely on it for
// unqualified name lookup — confirm before narrowing or removing it.
using namespace Dic::Module::FullDb;
// Database category selector; used by DataBaseManager::Clear(DatabaseType) and
// DataBaseManager::HasRankId(). Enumerators keep their implicit values (0..4).
enum class DatabaseType {
    TRACE,
    SUMMARY,
    MEMORY,
    MEM_SCOPE,
    MEM_SNAPSHOT
};
// Data format tag stored per file id (see DataBaseManager::SetDataType / GetDataType).
// Enumerators keep their implicit values (TEXT = 0, DB = 1).
enum class DataType {
    TEXT,
    DB
};
// File origin tag stored per file id (see DataBaseManager::SetFileType / GetFileType).
// Enumerators keep their implicit values (MS_PROF = 0, PYTORCH = 1).
enum class FileType {
    MS_PROF,
    PYTORCH
};
class DataBaseManager {
public:
static DataBaseManager &Instance();
DataBaseManager(const DataBaseManager &) = delete;
DataBaseManager &operator=(const DataBaseManager &) = delete;
DataBaseManager(DataBaseManager &&) = delete;
DataBaseManager &operator=(DataBaseManager &&) = delete;
bool CreateTraceConnectionPool(const std::string &rankId, const std::string &dbPath);
std::shared_ptr<VirtualTraceDatabase> GetTraceDatabaseByRankId(const std::string &rankId);
std::shared_ptr<VirtualTraceDatabase> GetTraceDatabaseByFileId(const std::string &fileId);
std::vector<DBConnectionPool<VirtualTraceDatabase> *> GetAllTraceDatabase();
std::vector<std::string> GetAllRankId();
void Clear();
void Clear(DatabaseType type);
void EraseClusterDb(const std::string &uniqueKey);
void ClearClusterDb();
void ReleaseDatabaseByRankId(const std::string &rankId);
void ReleaseDatabaseByFileId(const std::string &fileId);
bool HasRankId(DatabaseType type, const std::string &rankId);
void CreateClusterConnectionPool(const std::string &projectPath, const std::string &dbPath, DataType type);
std::shared_ptr<VirtualClusterDatabase> GetClusterDatabase(const std::string &uniqueKey);
std::vector<std::shared_ptr<VirtualClusterDatabase>> GetAllClusterDatabase();
std::shared_ptr<Memory::VirtualMemoryDataBase> CreateMemoryDataBase(
const std::string &rankId, const std::string &dbPath);
std::shared_ptr<Memory::VirtualMemoryDataBase> GetMemoryDatabaseByRankId(const std::string &rankId);
std::shared_ptr<Memory::VirtualMemoryDataBase> GetMemoryDatabaseByFileId(const std::string &fileId);
std::vector<Memory::VirtualMemoryDataBase *> GetAllMemoryDatabase();
std::shared_ptr<FullDb::MemScopeDatabase> GetMemScopeDatabase(const std::string &fileId);
std::shared_ptr<FullDb::MemSnapshotDatabase> GetMemSnapshotDatabase(const std::string &fileId);
std::vector<FullDb::MemScopeDatabase *> GetAllMemScopeDatabase();
std::vector<FullDb::MemSnapshotDatabase *> GetAllMemSnapshotDatabase();
std::shared_ptr<Summary::VirtualSummaryDataBase> GetSummaryDatabaseByRankId(const std::string &rankId);
std::shared_ptr<Summary::VirtualSummaryDataBase> GetSummaryDatabaseWithCluster(
const std::string &cluster, const std::string &rankId);
std::shared_ptr<Summary::VirtualSummaryDataBase> GetSummaryDataBaseByFileId(const std::string &fileId);
std::shared_ptr<Summary::VirtualSummaryDataBase> CreateSummaryDatabase(
const std::string &rankId, const std::string &dbPath);
std::vector<Summary::VirtualSummaryDataBase *> GetAllSummaryDatabase();
std::string GetDbPathByRankId(const std::string &rankId);
std::shared_ptr<VirtualTraceDatabase> GetTraceDatabaseWithOutHost(const std::string &rankId);
std::shared_ptr<VirtualTraceDatabase> GetTraceDatabaseInCluster(
const std::string &clusterPath, const std::string &rankId);
DataType GetDataType(const std::string &fileId);
DataType GetDataTypeByRank(const std::string &rankId);
void SetDataType(DataType type, const std::string &fileId);
FileType GetFileType(const std::string &fileId);
FileType GetFileTypeByRankId(const std::string &rankId);
void SetFileType(FileType type, const std::string &fileId);
bool ResetBaseline(bool force);
void SetDbPathMapping(const std::string &rankId, const std::string &dbPath, const std::string &hostId);
bool IsContainDatabasePath(const std::string &databasePath);
std::string GetDeviceIdFromRankId(const std::string &rankId);
inline std::vector<std::string> GetDbPathByHost(const std::string &id) {
if (host2DbPath.find(id) != host2DbPath.end()) {
return host2DbPath[id];
}
return {};
}
std::string GetRankIdByFileId(const std::string &fileId);
std::string GetAnyTraceDatabaseId();
std::string GetFileIdByRankId(const std::string &rankId) const;
void UpdateRankIdToDeviceId(const std::string &fileId, const std::string &rankId, const std::string &deviceId);
void SetRankIdFileIdMapping(const std::string &rankId, const std::string &fileId);
private:
using RankId = std::string;
using FileId = std::string;
using DbPath = FileId;
using HostId = std::string;
using ClusterProjectPath = std::string;
using ClusterDbPath = std::string;
DataBaseManager() = default;
~DataBaseManager() = default;
std::recursive_mutex mutex;
std::unordered_map<FileId, DataType> dataTypeMap;
std::unordered_map<FileId, FileType> fileTypeMap;
std::map<std::string, std::recursive_mutex> dbMutexMap;
std::map<RankId, DbPath> dbFilePathMap;
std::map<RankId, FileId> rankId2FileIdMap;
std::map<FileId, RankId> fileIdToRankIdMap;
std::map<ClusterProjectPath, ClusterDbPath> clusterProject2DbPathMap;
std::map<HostId, std::vector<DbPath>> host2DbPath;
std::unordered_set<std::string> databasePathSet;
std::map<FileId, std::shared_ptr<DBConnectionPool<VirtualTraceDatabase>>> traceDatabaseMap;
std::map<ClusterDbPath, std::shared_ptr<DBConnectionPool<VirtualClusterDatabase>>> clusterDatabaseMap;
std::map<RankId, std::shared_ptr<Memory::VirtualMemoryDataBase>> memoryDatabaseMap;
std::map<FileId, std::shared_ptr<FullDb::MemScopeDatabase>> memScopeDatabaseMap;
std::map<FileId, std::shared_ptr<FullDb::MemSnapshotDatabase>> memSnapshotDatabaseMap;
std::map<RankId, std::shared_ptr<Summary::VirtualSummaryDataBase>> summaryDatabaseMap;
std::map<RankId, std::shared_ptr<Memory::VirtualMemoryDataBase>> memoryBaselineDatabaseMap;
std::map<RankId, std::shared_ptr<Summary::VirtualSummaryDataBase>> summaryBaselineDatabaseMap;
std::map<std::string, std::string> rankIdToDeviceIdMap;
std::recursive_mutex &GetDbMutex(const std::string &fileId);
void SetClusterProjectDbPathMapping(const std::string &projectPath, const std::string &dbPath);
};
} // namespace Timeline
} // namespace Module
} // namespace Dic
#endif // PROFILER_SERVER_DATABASEMANAGER_H