/*
* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <gtest/gtest.h>
#include "OperatorMemoryService.h"
#include "../../../DatabaseTestCaseMockUtil.h"
#include "OperatorTable.h"
#include "OpMemoryTable.h"
using namespace Dic::Module::Memory;
/**
 * Test fixture for OperatorMemoryService.
 *
 * Provides table mocks that run their queries against an sqlite3 connection injected by
 * the test (via SetDb) instead of one looked up from the file id. It also provides the DDL
 * for the text-scenario `operator` table and the db-scenario `OP_MEMORY` table.
 */
class OperatorMemoryServiceTest : public ::testing::Test {
protected:
    /** Mixin that stores the sqlite3 connection a mock table queries. It never closes it. */
    class MemoryTableDefaultMock {
    public:
        void SetDb(sqlite3 *dbPtr) { db = dbPtr; }
    protected:
        sqlite3 *db = nullptr;
    };

    /** OperatorTable mock: reads rows from the injected connection, ignoring the file id. */
    class OperatorTableMock : public OperatorTable, public MemoryTableDefaultMock {
    protected:
        // The file id is deliberately unused (left unnamed to avoid -Wunused-parameter):
        // the connection comes from SetDb().
        void ExcuteQuery(const std::string & /* fileId */, std::vector<OperatorPO> &result) override
        {
            OperatorTable::ExcuteQuery(db, result);
            // NOTE(review): presumably resets per-thread state kept by the base table — confirm.
            ClearThreadLocal();
        }
    };

    /** OpMemoryTable mock: reads rows from the injected connection, ignoring the file id. */
    class OpMemoryTableMock : public OpMemoryTable, public MemoryTableDefaultMock {
    protected:
        // The file id is deliberately unused (left unnamed to avoid -Wunused-parameter):
        // the connection comes from SetDb().
        void ExcuteQuery(const std::string & /* fileId */, std::vector<OpMemoryPO> &result) override
        {
            OpMemoryTable::ExcuteQuery(db, result);
            // NOTE(review): presumably resets per-thread state kept by the base table — confirm.
            ClearThreadLocal();
        }
    };

    // DDL for the text-scenario table: name, streamPtr and deviceId are TEXT columns.
    // static inline: built once, not once per fixture instance.
    static inline const std::string operatorSql =
        "CREATE TABLE operator (name TEXT, size INTEGER, allocationTime INTEGER, releaseTime INTEGER, "
        "activeReleaseTime INTEGER, duration INTEGER, activeDuration INTEGER, allocationTotalAllocated INTEGER, "
        "allocationTotalReserved INTEGER, allocationTotalActive INTEGER, releaseTotalAllocated INTEGER, "
        "releaseTotalReserved INTEGER, releaseTotalActive INTEGER, streamPtr TEXT, deviceId TEXT);";

    // DDL for the db-scenario table: every column, including name, is INTEGER.
    static inline const std::string opMemorySql =
        "CREATE TABLE OP_MEMORY (name INTEGER, size INTEGER, allocationTime INTEGER, releaseTime INTEGER, "
        "activeReleaseTime INTEGER, duration INTEGER, activeDuration INTEGER, allocationTotalAllocated INTEGER, "
        "allocationTotalReserved INTEGER, allocationTotalActive INTEGER, releaseTotalAllocated INTEGER, "
        "releaseTotalReserved INTEGER, releaseTotalActive INTEGER, streamPtr INTEGER, deviceId INTEGER);";
};
* text场景根据id查询算子内存分配信息
*/
TEST_F(OperatorMemoryServiceTest, TestComputeAllocationTimeByIdWhenTextSceneThenReturnTextData) {
std::unique_ptr<OperatorTableMock> operatorTable = std::make_unique<OperatorTableMock>();
sqlite3 *db = nullptr;
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::OpenDB(db);
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::CreateTable(db, operatorSql);
const std::string operatorData =
"INSERT INTO \"main\".\"operator\" (\"name\", \"size\", \"allocationTime\", \"releaseTime\", "
"\"activeReleaseTime\", \"duration\", \"activeDuration\", \"allocationTotalAllocated\", "
"\"allocationTotalReserved\", "
"\"allocationTotalActive\", \"releaseTotalAllocated\", \"releaseTotalReserved\", \"releaseTotalActive\", "
"\"streamPtr\", "
"\"deviceId\") VALUES ('aten::empty_strided', 32.5, 1724670453465053360, 1724670453467680330, 2626.97, "
"1724670453467680020, 2626.66, 18180.814453125, 25750, 18180.814453125, 18181.8559570313, 25750, "
"18181.8559570313, "
"'187651271017536', '0');";
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::InsertData(db, operatorData);
operatorTable->SetDb(db);
OperatorMemoryService operatorMemoryService(std::move(operatorTable));
OperatorDomain target = operatorMemoryService.ComputeAllocationTimeById("lll", "1");
const uint64_t expectAllocationTime = 1724670453465053360;
EXPECT_EQ(target.allocationTime, expectAllocationTime);
EXPECT_EQ(target.metaType, "TEXT");
}
* db场景根据id查询算子内存分配信息
*/
TEST_F(OperatorMemoryServiceTest, TestComputeAllocationTimeByIdWhenDbSceneThenReturnDbData) {
std::unique_ptr<OperatorTableMock> operatorTable = std::make_unique<OperatorTableMock>();
std::unique_ptr<OpMemoryTableMock> opMemoryTable = std::make_unique<OpMemoryTableMock>();
sqlite3 *db = nullptr;
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::OpenDB(db);
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::CreateTable(db, opMemorySql);
const std::string opMemoryData =
"INSERT INTO \"main\".\"OP_MEMORY\" (\"name\", \"size\", \"allocationTime\", \"releaseTime\", "
"\"activeReleaseTime\", \"duration\", \"activeDuration\", \"allocationTotalAllocated\", "
"\"allocationTotalReserved\", \"allocationTotalActive\", \"releaseTotalAllocated\", \"releaseTotalReserved\", "
"\"releaseTotalActive\", \"streamPtr\", \"deviceId\") VALUES (536870922, 4608, 1724670453468255710, "
"1724670453468599630, 1724670453468599320, 343920, 343610, 19065050112, 27000832000, 19065050112, 19065045504, "
"27000832000, 19065045504, 187651271017536, 0);";
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::InsertData(db, opMemoryData);
operatorTable->SetDb(db);
opMemoryTable->SetDb(db);
OperatorMemoryService operatorMemoryService(std::move(operatorTable), std::move(opMemoryTable));
OperatorDomain target = operatorMemoryService.ComputeAllocationTimeById("lll", "1");
const uint64_t expectAllocationTime = 1724670453468255710;
EXPECT_EQ(target.allocationTime, expectAllocationTime);
EXPECT_EQ(target.metaType, "PYTORCH_API");
}
* 都不存在
*/
TEST_F(OperatorMemoryServiceTest, TestComputeAllocationTimeByIdWhenDataNotExistThenMetaTypeIsEmpty) {
std::unique_ptr<OperatorTableMock> operatorTable = std::make_unique<OperatorTableMock>();
std::unique_ptr<OpMemoryTableMock> opMemoryTable = std::make_unique<OpMemoryTableMock>();
sqlite3 *db = nullptr;
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::OpenDB(db);
Dic::Global::PROFILER::MockUtil::DatabaseTestCaseMockUtil::CreateTable(db, opMemorySql);
operatorTable->SetDb(db);
opMemoryTable->SetDb(db);
OperatorMemoryService operatorMemoryService(std::move(operatorTable), std::move(opMemoryTable));
OperatorDomain target = operatorMemoryService.ComputeAllocationTimeById("lll", "1");
EXPECT_EQ(std::empty(target.metaType), true);
}