* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <algorithm>
#include <cstdint>
#include <map>
#include <fstream>
#include <sstream>
#include <cstring>
#include <numeric>
#include <re2/re2.h>
#include "include/ErrorCode.h"
#include "include/Macro.h"
#include "utils/FileUtils.h"
#include "base/ErrorInfosManager.h"
#include "DebuggerConfigFieldMap.h"
#include "DebuggerConfig.h"
namespace MindStudioDebugger {
/**
 * @brief Copy the value stored under `field` in `content` into `output`.
 *
 * @param content   JSON value that is searched for `field`.
 * @param field     Key to look up.
 * @param output    Receives the converted value; left untouched when the key is missing
 *                  or the conversion throws.
 * @param mandatory When true, a missing key is reported as an error instead of being ignored.
 * @return DebuggerErrno::OK on success or when an optional key is missing,
 *         DebuggerErrno::ERROR_FIELD_NOT_EXISTS when a mandatory key is missing,
 *         DebuggerErrno::ERROR_INVALID_FORMAT when the conversion to T raises a type_error.
 */
template<typename T>
DebuggerErrno ParseJsonBaseObj2Var(const nlohmann::json& content, const std::string& field, T& output,
                                   bool mandatory = false)
{
    const nlohmann::json::const_iterator pos = content.find(field);
    if (pos == content.end()) {
        // Absence is only an error for mandatory fields.
        return mandatory ? DebuggerErrno::ERROR_FIELD_NOT_EXISTS : DebuggerErrno::OK;
    }

    try {
        output = pos->get<T>();
    } catch (const nlohmann::detail::type_error&) {
        return DebuggerErrno::ERROR_INVALID_FORMAT;
    }
    return DebuggerErrno::OK;
}
/**
 * @brief Read a string field and translate it to an enum value through a name table.
 *
 * @param content   JSON value that is searched for `field`.
 * @param field     Key whose value must be a string.
 * @param enum2name Table passed to GetEnumIdFromName to resolve the string to an enum id.
 * @param output    Receives static_cast<T>(id) on success; otherwise left untouched.
 * @param mandatory When false, a missing key yields DebuggerErrno::OK without touching `output`.
 * @return DebuggerErrno::OK on success (or optional key missing), the error from
 *         ParseJsonBaseObj2Var when reading the string fails, or
 *         DebuggerErrno::ERROR_UNKNOWN_VALUE when the string is not found in the table.
 */
template<typename T>
DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content,
                                      const std::string& field,
                                      const std::map<int32_t, std::string>& enum2name,
                                      T& output,
                                      bool mandatory = false)
{
    std::string text;
    // Always query as mandatory so that "missing" can be told apart from "present".
    const DebuggerErrno status = ParseJsonBaseObj2Var<std::string>(content, field, text, true);
    if (status != DebuggerErrno::OK) {
        const bool missingButOptional = (status == DebuggerErrno::ERROR_FIELD_NOT_EXISTS) && !mandatory;
        return missingButOptional ? DebuggerErrno::OK : status;
    }

    const int32_t id = GetEnumIdFromName(enum2name, text);
    if (id == DEBUGGER_INVALID_ENUM) {
        return DebuggerErrno::ERROR_UNKNOWN_VALUE;
    }

    output = static_cast<T>(id);
    return DebuggerErrno::OK;
}
// Reads the optional `field` of `content` into `output` through ParseJsonBaseObj2Var and reports
// any non-OK result through LOG_ERROR. The body is wrapped in do/while(0) so that the expansion
// together with the caller's trailing semicolon forms exactly one statement, which keeps the
// macro safe inside an unbraced if/else.
#define PARSE_OPTIONAL_FIELD_CHECK_RET(content, field, output) \
    do { \
        if (ParseJsonBaseObj2Var<decltype(output)>(content, field, output) != DebuggerErrno::OK) { \
            LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, \
                      "Field " + std::string(field) + " cannot be parsed."); \
        } \
    } while (0)
// Reads the optional string `field` of `content`, translates it to an enum through `transMap`
// (see ParseJsonStringAndTrans) and stores it in `output`; any non-OK result is reported through
// LOG_ERROR. The body is wrapped in do/while(0) so that the expansion together with the caller's
// trailing semicolon forms exactly one statement, which keeps the macro safe inside an unbraced
// if/else.
#define PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, field, transMap, output) \
    do { \
        if (ParseJsonStringAndTrans<decltype(output)>(content, field, transMap, output) != DebuggerErrno::OK) { \
            LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, \
                      "Value of field " + std::string(field) + " is unknown."); \
        } \
    } while (0)
/**
 * @brief Split a range expression of the form "a-b" into its two unsigned borders.
 *
 * The expression must contain exactly one '-', both borders must parse as unsigned integers,
 * nothing but whitespace may follow the right border, and the left border must be strictly
 * smaller than the right one. Every violation is reported through LOG_ERROR.
 *
 * @param exp   Range expression to parse.
 * @param left  Receives the left border.
 * @param right Receives the right border.
 * @return true when the expression is valid; false otherwise (the out-parameters may have been
 *         partially written in that case).
 */
static bool DebuggerCfgParseUIntRangeGetBorder(const std::string& exp, uint32_t& left, uint32_t& right)
{
    if (std::count(exp.begin(), exp.end(), '-') != 1) {
        LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT,
                  "When using a range expression, it should be formatted as \"a-b\".");
        return false;
    }

    std::istringstream iss(exp);
    char dash = '\0';
    iss >> left >> dash >> right;
    if (iss.fail() || dash != '-') {
        LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT,
                  "When using a range expression, it should be formatted as \"a-b\".");
        return false;
    }

    // Anything other than whitespace after the right border (e.g. "1-5abc") means the
    // expression is malformed; previously such input was silently accepted as "1-5".
    std::string trailing;
    if (iss >> trailing) {
        LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT,
                  "When using a range expression, it should be formatted as \"a-b\".");
        return false;
    }

    if (left >= right) {
        LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT,
                  "When using a range expression, the left border should be smaller than the right.");
        return false;
    }
    return true;
}
/**
 * @brief Parse an array of unsigned integers and "a-b" range expressions into a flat list.
 *
 * If `name` is absent from `content`, `range` is left untouched. Otherwise `range` is cleared
 * and refilled: single numbers are appended first (in array order), followed by every value of
 * each range expression (borders inclusive). Array elements that are neither numbers nor
 * strings are skipped. All failures are reported through LOG_ERROR and abort the parse, which
 * may leave `range` holding only the single numbers seen so far.
 *
 * @param content JSON value holding the field.
 * @param name    Key of the array field.
 * @param range   Output list.
 */
void DebuggerCfgParseUIntRange(const nlohmann::json& content, const std::string& name, std::vector<uint32_t>& range)
{
    if (!content.contains(name)) {
        return;
    }
    const nlohmann::json& array = content[name];
    if (!array.is_array()) {
        LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, name + " should be empty or an array.");
        return;
    }

    range.clear();
    range.reserve(array.size());
    std::vector<std::pair<uint32_t, uint32_t>> buf;
    buf.reserve(array.size());

    // Counted in 64 bits: "0-4294967295" spans 2^32 values, which wraps to 0 in 32-bit
    // arithmetic and used to slip past both size checks below.
    uint64_t realLen = 0;
    for (const auto& element : array) {
        if (element.is_number()) {
            if (element.is_number_integer()) {
                // get<uint32_t>() would silently wrap negative or oversized integers.
                const bool negative = !element.is_number_unsigned() && element.get<int64_t>() < 0;
                if (negative || element.get<uint64_t>() > UINT32_MAX) {
                    LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE,
                              name + " should only contain integers between 0 and 4294967295.");
                    return;
                }
            }
            range.emplace_back(element.get<uint32_t>());
            realLen++;
        } else if (element.is_string()) {
            const std::string exp = element.get<std::string>();
            uint32_t begin = 0;
            uint32_t end = 0;
            if (!DebuggerCfgParseUIntRangeGetBorder(exp, begin, end)) {
                LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "Failed to parse " + name + ".");
                return;
            }
            const uint64_t rangeSize = static_cast<uint64_t>(end) - begin + 1;
            if (realLen + rangeSize > UINT32_MAX) {
                LOG_ERROR(DebuggerErrno::ERROR_VALUE_OVERFLOW, name + " size exceeds limit");
                return;
            }
            realLen += rangeSize;
            buf.emplace_back(begin, end);
        }
    }

    constexpr uint32_t maxEleNum = 65536;
    if (realLen > maxEleNum) {
        LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT,
                  "When using a range expression in " + name + ", maximum of 65536 elements can be expressed.");
        return;
    }

    if (!buf.empty()) {
        range.reserve(static_cast<size_t>(realLen));
        for (const auto& border : buf) {
            // A 64-bit counter is required: with uint32_t the condition `i <= border.second`
            // can never become false when the right border is UINT32_MAX.
            for (uint64_t i = border.first; i <= border.second; ++i) {
                range.emplace_back(static_cast<uint32_t>(i));
            }
        }
    }
}
/**
 * @brief Collect the configured task types from the TASK (string) or TASKS (string list) field.
 *
 * TASK is tried first; TASKS is consulted only when TASK does not exist. Unknown task names
 * produce a warning and are skipped; task types already present in `tasks` are not added again.
 * Any parse failure (including both fields missing) is reported through LOG_ERROR.
 *
 * @param content JSON value holding the fields.
 * @param tasks   List that recognised task types are appended to.
 */
void CommonCfgParseTasks(const nlohmann::json& content, std::vector<DebuggerTaskType>& tasks)
{
    std::vector<std::string> names;
    std::string single;

    DebuggerErrno status = ParseJsonBaseObj2Var<std::string>(content, TASK, single, true);
    if (status != DebuggerErrno::ERROR_FIELD_NOT_EXISTS) {
        names.emplace_back(single);
    } else {
        status = ParseJsonBaseObj2Var<std::vector<std::string>>(content, TASKS, names, true);
    }

    if (status != DebuggerErrno::OK) {
        LOG_ERROR(status, "Value of field task(s) should be string or list.");
        return;
    }

    for (const auto& candidate : names) {
        const int32_t id = GetEnumIdFromName(TASK_TYPE_ENUM_2_NAME, candidate);
        if (id == DEBUGGER_INVALID_ENUM) {
            LOG_WARNING(DebuggerErrno::ERROR_UNKNOWN_VALUE, "Task " + candidate + " is unknown.");
            continue;
        }
        const DebuggerTaskType type = static_cast<DebuggerTaskType>(id);
        if (!ELE_IN_VECTOR(tasks, type)) {
            tasks.emplace_back(type);
        }
    }
}
// Markers that KernelListMatcher::Parse looks for at the start and the end of an expression to
// decide whether it is a regular expression rather than a full kernel name.
constexpr char REGEX_PREFIX[] = "name-regex(";
constexpr char REGEX_SUFFIX[] = ")";
// String lengths; sizeof also counts the terminating NUL, hence the "- 1".
constexpr size_t REGEX_PREFIX_LEN = sizeof(REGEX_PREFIX) - 1;
constexpr size_t REGEX_SUFFIX_LEN = sizeof(REGEX_SUFFIX) - 1;
// Offset of the last prefix character (the opening parenthesis); the stored pattern starts here.
constexpr size_t REGEX_INDEX = REGEX_PREFIX_LEN - 1;
/**
 * @brief Sort kernel expressions into regular expressions and full kernel names.
 *
 * An expression that starts with REGEX_PREFIX and ends with REGEX_SUFFIX is stored in
 * `regexList` from its opening parenthesis onwards (so the stored pattern is "(...)");
 * every other non-empty expression is stored verbatim in `fullNameList`.
 * Expressions shorter than the suffix are skipped with a warning.
 *
 * @param expressions Expressions to classify.
 */
void KernelListMatcher::Parse(const std::vector<std::string>& expressions)
{
    for (const auto& expr : expressions) {
        const size_t length = expr.size();
        if (length < REGEX_SUFFIX_LEN) {
            LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid expression when kernal match, too short: " + expr);
            continue;
        }

        const bool isRegex = expr.compare(0, REGEX_PREFIX_LEN, REGEX_PREFIX) == 0 &&
                             expr.compare(length - REGEX_SUFFIX_LEN, REGEX_SUFFIX_LEN, REGEX_SUFFIX) == 0;
        if (!isRegex) {
            fullNameList.emplace_back(expr);
            continue;
        }
        // Keep the surrounding parentheses: the pattern is stored as one capturing group.
        regexList.emplace_back(expr.substr(REGEX_INDEX));
    }
}
std::vector<std::string> KernelListMatcher::GenRealKernelList(const char** fullKernelList) const
{
std::vector<std::string> output;
if (this->Empty() || fullKernelList == nullptr) {
return output;
}
output = fullNameList;
for (auto& pattern : regexList) {
re2::RE2 reg(pattern, re2::RE2::Quiet);
if (reg.ok()) {
for (const char** ss = fullKernelList; *ss != nullptr; ++ss) {
std::string ret;
if (re2::RE2::FullMatch(*ss, reg, &ret)) {
output.emplace_back(*ss);
}
}
}
}
if (output.empty()) {
output.emplace_back("");
LOG_INFO("No kernel matches, so nothing will be dumped.");
}
return output;
}
/**
 * @brief Fill the common configuration from the top-level JSON value.
 *
 * The task list is parsed first; when it ends up empty, none of the remaining fields is read.
 * Otherwise the optional fields are read in turn, and the output path is converted with
 * FileUtils::GetAbsPath. Parse problems are reported by the helpers through LOG_ERROR.
 *
 * @param content Top-level JSON value of the configuration file.
 */
void CommonCfg::Parse(const nlohmann::json& content)
{
    CommonCfgParseTasks(content, tasks);
    if (!tasks.empty()) {
        PARSE_OPTIONAL_FIELD_CHECK_RET(content, OUTPUT_PATH, outputPath);
        outputPath = FileUtils::GetAbsPath(outputPath);

        DebuggerCfgParseUIntRange(content, RANK, rank);
        DebuggerCfgParseUIntRange(content, STEP, step);

        PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, LEVEL, DEBUGGER_LEVEL_ENUM_2_NAME, level);
        PARSE_OPTIONAL_FIELD_CHECK_RET(content, SEED, seed);
        PARSE_OPTIONAL_FIELD_CHECK_RET(content, IS_DETERMINISTIC, isDeterministic);
        PARSE_OPTIONAL_FIELD_CHECK_RET(content, ENABLE_DATALOADER, enableDataloader);
        PARSE_OPTIONAL_FIELD_CHECK_RET(content, ACL_CONFIG, aclConfig);
    }
}
/**
 * @brief Derive the data direction and input/output selection from the DATA_MODE list.
 *
 * Both outputs default to "both". They stay that way when the list is missing, empty or
 * contains DATA_MODE_ALL. Otherwise each output is narrowed only when exactly one keyword of
 * its pair (forward/backward, input/output) is listed.
 *
 * @param content   JSON value holding the DATA_MODE field.
 * @param direction Receives the forward/backward selection.
 * @param inout     Receives the input/output selection.
 */
void DebuggerCfgParseDataMode(const nlohmann::json& content, DebuggerDataDirection& direction, DebuggerDataInOut& inout)
{
    direction = DebuggerDataDirection::DIRECTION_BOTH;
    inout = DebuggerDataInOut::INOUT_BOTH;

    std::vector<std::string> modes;
    PARSE_OPTIONAL_FIELD_CHECK_RET(content, DATA_MODE, modes);

    const bool wantsAll = std::find(modes.begin(), modes.end(), DATA_MODE_ALL) != modes.end();
    if (wantsAll || modes.empty()) {
        return;
    }

    const bool hasForward = std::find(modes.begin(), modes.end(), DIRECTION_FORWARD) != modes.end();
    const bool hasBackward = std::find(modes.begin(), modes.end(), DIRECTION_BACKWARD) != modes.end();
    const bool hasInput = std::find(modes.begin(), modes.end(), INOUT_INPUT) != modes.end();
    const bool hasOutput = std::find(modes.begin(), modes.end(), INOUT_OUTPUT) != modes.end();

    // Listing both or neither keyword of a pair keeps the "both" default.
    if (hasForward != hasBackward) {
        direction = hasForward ? DebuggerDataDirection::DIRECTION_FORWARD
                               : DebuggerDataDirection::DIRECTION_BACKWARD;
    }
    if (hasInput != hasOutput) {
        inout = hasInput ? DebuggerDataInOut::INOUT_INPUT : DebuggerDataInOut::INOUT_OUTPUT;
    }
}
/**
 * @brief Translate the SUMMARY_MODE field into a list of summary options.
 *
 * The field may be a string or a list of strings:
 *  - missing field, the string STATISTICS, or an empty list: MAX, MIN, MEAN and L2NORM;
 *  - the string MD5: MD5 only;
 *  - a non-empty list: each name resolved through SUMMARY_OPTION_ENUM_2_NAME, in order.
 * Unknown names and wrongly typed values are reported through LOG_ERROR; an unknown name in a
 * list stops the parse after the entries that preceded it were appended.
 *
 * @param content       JSON value holding the field.
 * @param summaryOption List that the resulting options are appended to.
 */
void StatisticsCfgParseSummary(const nlohmann::json& content, std::vector<DebuggerSummaryOption>& summaryOption)
{
    const auto appendDefaults = [&summaryOption]() {
        summaryOption.push_back(DebuggerSummaryOption::MAX);
        summaryOption.push_back(DebuggerSummaryOption::MIN);
        summaryOption.push_back(DebuggerSummaryOption::MEAN);
        summaryOption.push_back(DebuggerSummaryOption::L2NORM);
    };

    // First attempt: a single string. A missing field keeps the STATISTICS default.
    std::string mode = STATISTICS;
    DebuggerErrno status = ParseJsonBaseObj2Var<std::string>(content, SUMMARY_MODE, mode);
    if (status == DebuggerErrno::OK) {
        if (mode == STATISTICS) {
            appendDefaults();
        } else if (mode == MD5) {
            summaryOption.push_back(DebuggerSummaryOption::MD5);
        } else {
            LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, "Summary mode " + mode + " is unknown.");
        }
        return;
    }

    // Second attempt: a list of strings.
    std::vector<std::string> modeNames;
    status = ParseJsonBaseObj2Var<std::vector<std::string>>(content, SUMMARY_MODE, modeNames);
    if (status != DebuggerErrno::OK) {
        LOG_ERROR(status, "Value of field summary_mode should be string or list.");
        return;
    }
    if (modeNames.empty()) {
        appendDefaults();
        return;
    }

    for (const auto& modeName : modeNames) {
        const int32_t id = GetEnumIdFromName(SUMMARY_OPTION_ENUM_2_NAME, modeName);
        if (id == DEBUGGER_INVALID_ENUM) {
            LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, "Summary mode " + modeName + " is unknown.");
            return;
        }
        summaryOption.push_back(static_cast<DebuggerSummaryOption>(id));
    }
}
/**
 * @brief Fill the statistics task configuration from its JSON section.
 *
 * Reads the optional scope and list fields, drops list entries that are empty or consist only
 * of spaces, feeds the list to the kernel matcher when the debug level is L2, and finally
 * parses the data mode and summary mode.
 *
 * @param content JSON section of the statistics task.
 */
void StatisticsCfg::Parse(const nlohmann::json& content)
{
    PARSE_OPTIONAL_FIELD_CHECK_RET(content, SCOPE, scope);

    std::vector<std::string> entries;
    PARSE_OPTIONAL_FIELD_CHECK_RET(content, LIST, entries);

    const auto isBlank = [](const std::string& entry) {
        return std::all_of(entry.begin(), entry.end(), [](char c) { return c == ' '; });
    };
    entries.erase(std::remove_if(entries.begin(), entries.end(), isBlank), entries.end());
    list = std::move(entries);

    const bool kernelLevel = DebuggerConfig::GetInstance().GetDebugLevel() == DebuggerLevel::L2;
    if (kernelLevel) {
        matcher.Parse(list);
    }

    DebuggerCfgParseDataMode(content, direction, inout);
    StatisticsCfgParseSummary(content, summaryOption);
}
/**
 * @brief Fill the tensor-dump task configuration from its JSON section.
 *
 * Reads the optional scope and list fields, drops list entries that are empty or consist only
 * of spaces, feeds the list to the kernel matcher when the debug level is L2, then parses the
 * data mode, the file format (translated through DUMP_FILE_FORMAT_ENUM_2_NAME) and the
 * backward-input field.
 *
 * @param content JSON section of the tensor-dump task.
 */
void DumpTensorCfg::Parse(const nlohmann::json& content)
{
    PARSE_OPTIONAL_FIELD_CHECK_RET(content, SCOPE, scope);

    std::vector<std::string> entries;
    PARSE_OPTIONAL_FIELD_CHECK_RET(content, LIST, entries);

    const auto isBlank = [](const std::string& entry) {
        return std::all_of(entry.begin(), entry.end(), [](char c) { return c == ' '; });
    };
    entries.erase(std::remove_if(entries.begin(), entries.end(), isBlank), entries.end());
    list = std::move(entries);

    const bool kernelLevel = DebuggerConfig::GetInstance().GetDebugLevel() == DebuggerLevel::L2;
    if (kernelLevel) {
        matcher.Parse(list);
    }

    DebuggerCfgParseDataMode(content, direction, inout);
    PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, FILE_FORMAT, DUMP_FILE_FORMAT_ENUM_2_NAME, fileFormat);
    PARSE_OPTIONAL_FIELD_CHECK_RET(content, BACKWARD_INPUT, backwardInput);
}
// Fills the overflow-check task configuration from its JSON section.
// Both fields are optional; values that cannot be parsed are reported through LOG_ERROR by the
// helper macros.
void OverflowCheckCfg::Parse(const nlohmann::json& content)
{
PARSE_OPTIONAL_FIELD_CHECK_RET(content, OVERFLOW_NUMS, overflowNums);
// CHECK_MODE is read as a string and translated to an enum through OP_CHECK_LEVEL_ENUM_2_NAME.
PARSE_OPTIONAL_FIELD_TRANS_CHECK_RET(content, CHECK_MODE, OP_CHECK_LEVEL_ENUM_2_NAME, checkMode);
}
// Discards the currently held configuration: the common configuration is replaced by a
// default-constructed one, the three per-task configurations are released, and the object is
// marked as not loaded.
void DebuggerConfig::Reset()
{
LOG_INFO("Reset configuration.");
commonCfg = CommonCfg();
statisticCfg.reset();
dumpTensorCfg.reset();
overflowCheckCfg.reset();
// Lets LoadConfig treat the next call as a first-time initialization.
loaded = false;
}
void DebuggerConfig::Parse()
{
std::ifstream cfgFile;
DebuggerErrno ret = FileUtils::OpenFile(cfgFilePath_, cfgFile);
if (ret != DebuggerErrno::OK) {
LOG_ERROR(ret, "Failed to open file " + cfgFilePath_ + ".");
return;
}
nlohmann::json content;
nlohmann::json::const_iterator iter;
try {
cfgFile >> content;
} catch (const nlohmann::json::parse_error& e) {
LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "Failed to parse json file " + cfgFilePath_ + ".");
return;
}
commonCfg.Parse(content);
#define PARSE_SUBTASK_CONFIG(enumeration, name, member, basetype) \
do { \
if (ELE_IN_VECTOR(commonCfg.tasks, enumeration)) { \
iter = content.find(name); \
if (iter != content.end()) { \
member = std::make_shared<basetype>(); \
((member)->Parse(*(iter))); \
} \
} \
} while (0)
PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_DUMP_STATISTICS, TASK_STATISTICS, statisticCfg, StatisticsCfg);
PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_DUMP_TENSOR, TASK_DUMP_TENSOR, dumpTensorCfg, DumpTensorCfg);
PARSE_SUBTASK_CONFIG(DebuggerTaskType::TASK_OVERFLOW_CHECK, TASK_OVERFLOW_CHECK, overflowCheckCfg, OverflowCheckCfg);
#undef PARSE_SUBTASK_CONFIG
return;
}
int32_t DebuggerConfig::LoadConfig(const std::string& framework, const std::string& cfgFilePath)
{
if (loaded) {
LOG_WARNING(DebuggerErrno::ERROR, "Repeated initialization, which may lead to errors.");
Reset();
}
cfgFilePath_ = FileUtils::GetAbsPath(cfgFilePath);
if (cfgFilePath_ == "") {
LOG_ERROR(DebuggerErrno::ERROR_CANNOT_PARSE_PATH, "Cannot parse path " + cfgFilePath + ".");
return -1;
}
DebuggerErrno ret = FileUtils::CheckFileBeforeRead(cfgFilePath_, "r", FileType::JSON);
if (ret != DebuggerErrno::OK) {
LOG_ERROR(ret, "Config file " + cfgFilePath + " is invalid.");
return -1;
}
int32_t enumId = GetEnumIdFromName(FRAMEWORK_ENUM_2_NAME, framework);
if (enumId == DEBUGGER_INVALID_ENUM) {
LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_VALUE, "Unknown framework " + framework + ".");
return -1;
}
framework_ = static_cast<DebuggerFramework>(enumId);
Parse();
if (ErrorInfosManager::GetTopErrLevelInDuration() >= DebuggerErrLevel::LEVEL_ERROR) {
LOG_ERROR(DebuggerErrno::ERROR, "Failed to parse config file " + cfgFilePath + ".");
return -1;
}
CheckConfigValidity();
if (ErrorInfosManager::GetTopErrLevelInDuration() >= DebuggerErrLevel::LEVEL_ERROR) {
LOG_ERROR(DebuggerErrno::ERROR, "Config file " + cfgFilePath + " is invalid.");
return -1;
}
loaded = true;
return 0;
}
/**
 * @brief Sanity-check the parsed configuration.
 *
 * Currently the only check is for an empty task list, which is reported as a warning.
 *
 * @return Always true.
 */
bool DebuggerConfig::CheckConfigValidity()
{
    const bool nothingToDo = commonCfg.tasks.empty();
    if (nothingToDo) {
        LOG_WARNING(DebuggerErrno::ERROR, "No task configured. MsProbe will do nothing.");
    }
    return true;
}
}