* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <cstddef>
#include <gtest/gtest.h>
#include "vector"
#include "../../TestSuit.h"
#include "DataBaseManager.h"
#include "OperatorProtocolRequest.h"
#include "OperatorProtocolResponse.h"
// Fixture for the operator-analysis tests; it adds no members beyond TestSuit.
// NOTE(review): every TEST_F in this file registers against TestSuit directly, so
// this fixture is not referenced here -- confirm whether it should be used or removed.
class OperatorTestSuit : public TestSuit {};

// Grouping values placed in the operator request parameters used by the tests.
const std::string GROUP_OPERATOR{"Operator"};
const std::string GROUP_OPERATOR_TYPE{"Operator Type"};
const std::string GROUP_INPUT_SHAPE{"Input Shape"};
// Queries rank "0" operator durations grouped by GROUP_OPERATOR_TYPE, first with
// QueryType::CATEGORY and then with QueryType::COMPUTE_UNIT, and checks how many
// rows each query returns.
TEST_F(TestSuit, QueryOperatorDurationInfoByOpType) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorDurationReqParams params = {"0", "0", GROUP_OPERATOR_TYPE, 15};
    std::vector<Dic::Protocol::OperatorDurationRes> data;

    // A failed query makes the row-count checks meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorDurationInfo(params, Dic::Protocol::QueryType::CATEGORY, data));
    // Expected counts are size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedCategoryRows = 8;
    EXPECT_EQ(data.size(), expectedCategoryRows);

    data.clear();
    ASSERT_TRUE(db->QueryOperatorDurationInfo(params, Dic::Protocol::QueryType::COMPUTE_UNIT, data));
    const std::size_t expectedComputeUnitRows = 6;
    EXPECT_EQ(data.size(), expectedComputeUnitRows);
}
// Queries rank "0" operator durations grouped by GROUP_INPUT_SHAPE, first with
// QueryType::CATEGORY and then with QueryType::COMPUTE_UNIT, and checks how many
// rows each query returns.
TEST_F(TestSuit, QueryOperatorDurationInfoByOpTypeAndInputShape) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorDurationReqParams params = {"0", "0", GROUP_INPUT_SHAPE, 15};
    std::vector<Dic::Protocol::OperatorDurationRes> data;

    // A failed query makes the row-count checks meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorDurationInfo(params, Dic::Protocol::QueryType::CATEGORY, data));
    // Expected counts are size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedCategoryRows = 9;
    EXPECT_EQ(data.size(), expectedCategoryRows);

    data.clear();
    ASSERT_TRUE(db->QueryOperatorDurationInfo(params, Dic::Protocol::QueryType::COMPUTE_UNIT, data));
    const std::size_t expectedComputeUnitRows = 6;
    EXPECT_EQ(data.size(), expectedComputeUnitRows);
}
// Queries rank "0" operator durations grouped by GROUP_OPERATOR, first with
// QueryType::CATEGORY and then with QueryType::COMPUTE_UNIT, and checks how many
// rows each query returns.
TEST_F(TestSuit, QueryOperatorDurationInfoByOperator) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorDurationReqParams params = {"0", "0", GROUP_OPERATOR, 15};
    std::vector<Dic::Protocol::OperatorDurationRes> data;

    // A failed query makes the row-count checks meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorDurationInfo(params, Dic::Protocol::QueryType::CATEGORY, data));
    // Expected counts are size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedCategoryRows = 15;
    EXPECT_EQ(data.size(), expectedCategoryRows);

    data.clear();
    ASSERT_TRUE(db->QueryOperatorDurationInfo(params, Dic::Protocol::QueryType::COMPUTE_UNIT, data));
    const std::size_t expectedComputeUnitRows = 6;
    EXPECT_EQ(data.size(), expectedComputeUnitRows);
}
// Queries rank "0" operator statistics grouped by GROUP_OPERATOR_TYPE and checks
// both the reported total and the number of rows in the response.
TEST_F(TestSuit, QueryOperatorStatisticInfoByOpType) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorStatisticReqParams reqParams = {false, "0", "0", GROUP_OPERATOR_TYPE, 15, 1, 10, "", ""};
    Dic::Protocol::OperatorStatisticInfoResponse response = {};

    // A failed query makes the checks below meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorStatisticInfo(reqParams, response));
    const int expectedTotal = 8;
    EXPECT_EQ(response.total, expectedTotal);
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 8;
    EXPECT_EQ(response.data.size(), expectedRows);
}
// Fetches all rank "0" operator statistics grouped by GROUP_OPERATOR_TYPE through
// QueryAllOperatorStatisticInfo and checks the number of rows returned.
TEST_F(TestSuit, QueryAllOperatorStatisticInfoByOpType) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorStatisticReqParams reqParams = {true, "0", "0", GROUP_OPERATOR_TYPE, 15, 1, 10, "", ""};
    // Fully qualified to match the Dic::Protocol types used throughout this file.
    std::vector<Dic::Protocol::OperatorStatisticInfoRes> compareRes;

    // A failed query makes the row-count check meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryAllOperatorStatisticInfo(reqParams, compareRes));
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 8;
    EXPECT_EQ(compareRes.size(), expectedRows);
}
// Queries rank "0" operator statistics grouped by GROUP_INPUT_SHAPE and checks the
// reported total against the number of rows actually returned.
// NOTE(review): despite "All" in the test name, this calls QueryOperatorStatisticInfo,
// not QueryAllOperatorStatisticInfo -- confirm which API this test is meant to cover.
TEST_F(TestSuit, QueryAllOperatorStatisticInfoByOpTypeAndInputShape) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    // Presumably the 1, 5 pair selects page 1 with 5 rows per page: the expectations below
    // (total 9, 5 rows returned) are consistent with that -- confirm against the request type.
    Dic::Protocol::OperatorStatisticReqParams reqParams = {false, "0", "0", GROUP_INPUT_SHAPE, 15, 1, 5, "", ""};
    Dic::Protocol::OperatorStatisticInfoResponse response = {};

    // A failed query makes the checks below meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorStatisticInfo(reqParams, response));
    const int expectedTotal = 9;
    EXPECT_EQ(response.total, expectedTotal);
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 5;
    EXPECT_EQ(response.data.size(), expectedRows);
}
// Queries rank "0" operator details grouped by GROUP_OPERATOR and checks the reported
// total, the reported level, and the number of rows returned.
TEST_F(TestSuit, QueryOperatorDetailInfoByOperator) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorStatisticReqParams reqParams = {false, "0", "0", GROUP_OPERATOR, 15, 1, 10, "", ""};
    Dic::Protocol::OperatorDetailInfoResponse response = {};

    // A failed query makes the checks below meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorDetailInfo(reqParams, response));
    const int expectedTotal = 15;
    EXPECT_EQ(response.total, expectedTotal);
    EXPECT_EQ(response.level, "l1");
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 10;
    EXPECT_EQ(response.data.size(), expectedRows);
}
// Fetches all rank "0" operator details grouped by GROUP_OPERATOR through
// QueryAllOperatorDetailInfo and checks the level it writes back and the number of rows.
TEST_F(TestSuit, QueryAllOperatorDetailInfoByOperator) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorStatisticReqParams reqParams = {false, "0", "0", GROUP_OPERATOR, 15, 1, 10, "", ""};
    // Only response.level is used; it is passed as the out-parameter for the level.
    Dic::Protocol::OperatorDetailInfoResponse response = {};
    // Fully qualified to match the Dic::Protocol types used throughout this file.
    std::vector<Dic::Protocol::OperatorDetailInfoRes> baselineRes;

    // A failed query makes the checks below meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryAllOperatorDetailInfo(reqParams, baselineRes, response.level));
    EXPECT_EQ(response.level, "l1");
    // NOTE(review): QueryOperatorDetailInfo with identical params expects a total of 15,
    // while this expects 16 rows -- confirm the difference is intentional.
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 16;
    EXPECT_EQ(baselineRes.size(), expectedRows);
}
// Queries rank "0" "more info" rows for operator type "Cast" on "AI_CORE", grouped by
// GROUP_OPERATOR_TYPE, and checks the total, the level, and the number of rows.
TEST_F(TestSuit, QueryOperatorMoreInfoByOpType) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorMoreInfoReqParams reqParams = {
        "0", "0", GROUP_OPERATOR_TYPE, 15, "Cast", "", "", "AI_CORE", 1, 10, "", ""};
    Dic::Protocol::OperatorMoreInfoResponse response = {};

    // A failed query makes the checks below meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorMoreInfo(reqParams, response));
    const int64_t expectedTotal = 1;
    EXPECT_EQ(response.total, expectedTotal);
    EXPECT_EQ(response.level, "l1");
    // Separate size_t expectation so vector::size() is not compared against a signed value.
    const std::size_t expectedRows = 1;
    EXPECT_EQ(response.data.size(), expectedRows);
}
// Queries rank "0" "more info" rows for "NonZero" with the given shape string on
// "MIX_AIV", grouped by GROUP_INPUT_SHAPE, and expects an empty result with level "l1".
TEST_F(TestSuit, QueryOperatorMoreInfoByInputShape) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    Dic::Protocol::OperatorMoreInfoReqParams reqParams = {
        "0", "0", GROUP_INPUT_SHAPE, 15, "", "NonZero", R"("""16""")", "MIX_AIV", 1, 10, "", ""};
    Dic::Protocol::OperatorMoreInfoResponse response = {};

    // A failed query makes the checks below meaningless, so stop here on failure.
    ASSERT_TRUE(db->QueryOperatorMoreInfo(reqParams, response));
    const int expectedTotal = 0;
    EXPECT_EQ(response.total, expectedTotal);
    EXPECT_EQ(response.level, "l1");
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 0;
    EXPECT_EQ(response.data.size(), expectedRows);
}
// Queries bandwidth-contention MatMul data from the rank "0" database and expects
// the query to succeed with no rows.
TEST_F(TestSuit, QueryBandwidthContentionMatMulDataTest) {
    auto db = Dic::Module::Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId("0");
    // Fail this test cleanly instead of crashing the whole binary if the database is missing.
    ASSERT_NE(db, nullptr);
    std::vector<Dic::Module::BandwidthContentionMatMulInfo> res;

    ASSERT_TRUE(db->QueryBandwidthContentionMatMulData(res));
    // size_t to match vector::size() (no signed/unsigned comparison).
    const std::size_t expectedRows = 0;
    ASSERT_EQ(res.size(), expectedRows);
}