* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cmd_flag_info.h"
#include "base/err_msg.h"
#include <iostream>
#include <algorithm>
#include <memory>
#include <regex>
#include <set>
#include <unordered_set>
#include <string>
#include "mmpa/mmpa_api.h"
#include "framework/common/debug/ge_log.h"
#include "common/checker.h"
namespace ge {
namespace {
// Flags that are still accepted but are reported through GELOG_DEPRECATED after their value
// has been set successfully (see CmdFlagInfo::SetFlagValue).
const static std::unordered_set<std::string> kDeprecatedFlags = {"shape_generalized_build_mode"};
// Allowed values per flag name. A flag that has an entry here only accepts one of the listed
// strings (exact, case-sensitive match); flags without an entry are not range-checked.
// The same table is used to render the "valid range" part of the error message.
const static std::map<std::string, std::set<std::string>> kStrValueRange = {
{"sparsity", {"0", "1"}},
{"dump_mode", {"0", "1"}},
{"virtual_type", {"0", "1"}},
{"status_check", {"0", "1"}},
{"deterministic", {"0", "1"}},
{"deterministic_level", {"0", "1", "2"}},
{"external_weight", {"0", "1", "2"}},
{"display_model_info", {"0", "1"}},
{"atomic_clean_policy", {"0", "1"}},
{"disable_reuse_memory", {"0", "1"}},
{"framework", {"0", "1", "3", "5"}},
{"op_debug_level", {"0", "1", "2", "3", "4"}},
{"mode", {"0", "1", "3", "5", "6", "7","30"}},
{"core_type", {"AiCore", "VectorCore"}},
{"log", {"debug", "info", "warning", "error", "null"}},
{"op_compiler_cache_mode", {"enable", "disable", "force"}},
{"buffer_optimize", {"l1_optimize", "l2_optimize", "off_optimize"}},
{"shape_generalized_build_mode", {"shape_generalized", "shape_precise"}},
{"op_select_implmode",
{"high_precision", "high_performance", "high_precision_for_all", "high_performance_for_all"}},
{"precision_mode",
{"force_fp16", "force_fp32", "cube_fp16in_fp32out", "allow_mix_precision", "allow_fp32_to_fp16",
"must_keep_origin_dtype", "allow_mix_precision_fp16", "allow_mix_precision_bf16", "allow_fp32_to_bf16"}},
{"precision_mode_v2",
{"fp16", "origin", "cube_fp16in_fp32out", "mixed_float16", "mixed_bfloat16", "cube_hif8", "mixed_hif8"}},
};
}
namespace flgs {
// Writes one printf-style formatted line to stderr, prefixed with "ERROR: " and terminated by a
// newline. `fmt` must be a string literal: it is joined with the prefix and the newline by
// adjacent-literal concatenation. The do/while wrapper makes the macro usable as one statement.
#define GFPRINTF(fmt, ...) \
do { \
fprintf(stderr, "ERROR: " fmt "\n", ##__VA_ARGS__); \
} while (false)
namespace {
// Names of flags that get special treatment; the help flag is mapped to a fixed index
// instead of an automatically assigned one.
namespace FlagName {
constexpr char HELP[] = "help";
}  // namespace FlagName
// Values compared against the result of the option parser in ParseCommandLine.
namespace FlagIndex {
// Result that ParseCommandLine treats as "flag is missing its argument".
constexpr int32_t COLON = ':';
// Result that ParseCommandLine treats as "unknown command line flag".
constexpr int32_t QUESTION_MARK = '?';
// Fixed index of the help flag.
constexpr int32_t HELP = 'h';
// Automatically assigned indices are pre-incremented from this value, so the first
// registered non-help flag gets CUSTOM_START_INDEX + 1.
constexpr int32_t CUSTOM_START_INDEX = 'z';
}  // namespace FlagIndex
// Short-option string handed to mmGetOptLong. Assuming mmGetOptLong follows getopt conventions
// (TODO confirm), the leading ':' makes a missing argument be reported as ':' and 'h' declares
// the short help option; ParseCommandLine handles both results.
const char CMD_SHORT_OPTS[] = ":h";
class CmdFlagInfo;
// Registry type: flag index -> shared description of the registered flag.
using CmdFlagInfoMap = std::map<int32_t, std::shared_ptr<CmdFlagInfo>>;
// Conversion helpers used by CmdFlagInfo::SetFlagValue; declared here because their
// definitions come after that member function.
static GfStatus StringToBool(const std::string &value, bool &out);
static GfStatus StringToInt32(const std::string &value, int32_t &out);
// Describes one registered command-line flag: how it is presented to the option parser
// (name, has_arg, index), its help text, and the storage for its current value.
//
// Each constructor selects the flag's data type and initialises the matching value slot.
// The other value slots fall back to the default member initializers below, so reading any
// getter is well defined regardless of which constructor was used.
class CmdFlagInfo {
 public:
  // String-typed flag.
  CmdFlagInfo(int32_t has_arg, int32_t index, const std::string &flag_name,
              const std::string &default_val, const std::string &msg);
  // Bool-typed flag.
  CmdFlagInfo(int32_t has_arg, int32_t index, const std::string &flag_name,
              const bool default_val, const std::string &msg);
  // Int32-typed flag.
  CmdFlagInfo(int32_t has_arg, int32_t index, const std::string &flag_name,
              const int32_t default_val, const std::string &msg);

  // Stores `value` as the raw string, validates it and converts it into the typed slot.
  GfStatus SetFlagValue(const std::string &value);

  int32_t GetFlagHasArg() const { return has_arg_; }
  int32_t GetFlagIndex() const { return index_; }
  const std::string& GetFlagName() const { return flag_name_; }
  const std::string& GetFlagMsg() const { return msg_; }

  // Mutable references to the value slots; the RegisterParam* functions return these to
  // their callers, so they must stay valid for as long as the flag is registered.
  std::string& GetFlagValueString() { return value_string_; }
  bool& GetFlagValueBool() { return value_bool_; }
  int32_t& GetFlagValueInt32() { return value_int32_; }

  // Error reporting helpers used by SetFlagValue.
  void PrintTypeError(const char *type);
  void PrintValueError();
  // Renders the allowed values of this flag, or an empty string if it has no value table.
  std::string PrintValueRange();

 private:
  enum class DataType {
    Unknown,
    String,
    Bool,
    Int32
  };

  int32_t has_arg_ = 0;
  int32_t index_ = 0;
  std::string flag_name_;
  std::string default_val_;
  std::string msg_;
  DataType data_type_ = DataType::Unknown;
  // The raw text of the last value passed to SetFlagValue (or the string default).
  std::string value_string_;
  // Typed slots; only the one matching data_type_ is meaningful, the others keep these defaults.
  bool value_bool_ = false;
  int32_t value_int32_ = 0;
};
CmdFlagInfo::CmdFlagInfo(int32_t has_arg, int32_t index, const std::string &flag_name, const std::string &default_val,
const std::string &msg)
: has_arg_(has_arg),
index_(index),
flag_name_(flag_name),
default_val_(default_val),
msg_(msg),
data_type_(CmdFlagInfo::DataType::String),
value_string_(default_val) {}
CmdFlagInfo::CmdFlagInfo(int32_t has_arg, int32_t index, const std::string &flag_name, const bool default_val,
const std::string &msg)
: has_arg_(has_arg),
index_(index),
flag_name_(flag_name),
default_val_(std::to_string(default_val)),
msg_(msg),
data_type_(CmdFlagInfo::DataType::Bool),
value_bool_(default_val) {}
CmdFlagInfo::CmdFlagInfo(int32_t has_arg, int32_t index, const std::string &flag_name, const int32_t default_val,
const std::string &msg)
: has_arg_(has_arg),
index_(index),
flag_name_(flag_name),
default_val_(std::to_string(default_val)),
msg_(msg),
data_type_(CmdFlagInfo::DataType::Int32),
value_int32_(default_val) {}
// Reports error E10003 for this flag: the current raw value cannot be converted to `type`.
// `type` is the human-readable type name embedded in the reason text; it must not be null.
void CmdFlagInfo::PrintTypeError(const char *type) {
  std::string reason("The value type must be [");
  reason += std::string(type);
  reason += "].";
  REPORT_PREDEFINED_ERR_MSG("E10003", std::vector<const char *>({"value", "parameter", "reason"}),
                            std::vector<const char *>({value_string_.c_str(), flag_name_.c_str(), reason.c_str()}));
}
// Renders the allowed values of this flag as "[v1] [v2] ... " (each entry followed by a
// space, in the set's sorted order). Returns an empty string when the flag has no entry in
// kStrValueRange.
std::string CmdFlagInfo::PrintValueRange() {
  const auto entry = kStrValueRange.find(flag_name_);
  if (entry == kStrValueRange.cend()) {
    return "";
  }
  std::string listing;
  for (const std::string &allowed : entry->second) {
    listing.append("[").append(allowed).append("] ");
  }
  return listing;
}
// Reports that the current raw value is outside the allowed set. The error code depends on
// the flag: a fixed group of flags uses E10006, "log" uses E10010, "framework" uses E10007,
// and every other flag uses the generic E10001 together with the rendered value range.
void CmdFlagInfo::PrintValueError() {
  // Flags whose invalid values are reported with E10006.
  static const std::unordered_set<std::string> kInvalidValueFlags = {
      "status_check", "deterministic", "external_weight", "display_model_info",
      "atomic_clean_policy", "dump_mode", "disable_reuse_memory", "sparsity"};

  if (kInvalidValueFlags.count(flag_name_) > 0U) {
    REPORT_PREDEFINED_ERR_MSG("E10006", std::vector<const char *>({"value", "parameter"}),
                              std::vector<const char *>({value_string_.c_str(), flag_name_.c_str()}));
    return;
  }
  if (flag_name_ == "log") {
    REPORT_PREDEFINED_ERR_MSG("E10010", std::vector<const char *>({"loglevel"}),
                              std::vector<const char *>({value_string_.c_str()}));
    return;
  }
  if (flag_name_ == "framework") {
    const std::string support = "0(Caffe) or 1(MindSpore) or 3(TensorFlow) or 5(Onnx)";
    REPORT_PREDEFINED_ERR_MSG("E10007", std::vector<const char *>({"parameter", "support"}),
                              std::vector<const char *>({"framework", support.c_str()}));
    return;
  }

  // Generic case: name the flag with its "--" prefix and list the accepted values.
  std::string parameter("--");
  parameter += flag_name_;
  std::string reason("The value is not within the range of values. The valid range is ");
  reason += PrintValueRange();
  reason += ".";
  REPORT_PREDEFINED_ERR_MSG("E10001", std::vector<const char *>({"value", "parameter", "reason"}),
                            std::vector<const char *>({value_string_.c_str(), parameter.c_str(), reason.c_str()}));
}
// Assigns a new value to the flag.
//
// The raw string is stored first (so error reports can quote it), then checked against the
// flag's entry in kStrValueRange, if any, and finally converted into the typed slot for
// bool/int32 flags. A deprecation notice is logged for flags listed in kDeprecatedFlags.
// Returns GF_SUCCESS on success and GF_FAILED after reporting a range or type error.
GfStatus CmdFlagInfo::SetFlagValue(const std::string &value) {
  value_string_ = value;

  const auto range = kStrValueRange.find(flag_name_);
  const bool is_restricted = (range != kStrValueRange.cend());
  if (is_restricted && (range->second.count(value_string_) == 0U)) {
    PrintValueError();
    return GF_FAILED;
  }

  switch (data_type_) {
    case DataType::Bool:
      if (StringToBool(value_string_, value_bool_) != GF_SUCCESS) {
        PrintTypeError("bool");
        return GF_FAILED;
      }
      break;
    case DataType::Int32:
      if (StringToInt32(value_string_, value_int32_) != GF_SUCCESS) {
        PrintTypeError("int32");
        return GF_FAILED;
      }
      break;
    default:
      // String (and unknown) flags keep only the raw text.
      break;
  }

  const bool is_deprecated = (kDeprecatedFlags.count(flag_name_) > 0U);
  if (is_deprecated) {
    GELOG_DEPRECATED(flag_name_);
  }
  return GF_SUCCESS;
}
// Parses a boolean flag value, case-insensitively.
//
// Accepted spellings: "1", "t", "true", "y", "yes" -> true and "0", "f", "false", "n", "no"
// -> false. An empty string also yields true, which is what a bool flag given without a value
// (e.g. "--help") receives from the parser.
// Returns GF_SUCCESS and writes `out` on success; returns GF_FAILED and leaves `out` untouched
// for any other input.
GfStatus StringToBool(const std::string &value, bool &out)
{
  // Built once instead of on every call.
  static const std::set<std::string> kTrueValues = { "1", "t", "true", "y", "yes" };
  static const std::set<std::string> kFalseValues = { "0", "f", "false", "n", "no" };

  std::string val = value;
  // tolower has undefined behaviour for negative arguments other than EOF, so each character
  // is passed as unsigned char (matters for non-ASCII bytes when char is signed).
  std::transform(val.begin(), val.end(), val.begin(),
                 [](const unsigned char c) { return static_cast<char>(::tolower(c)); });

  if (val.empty() || (kTrueValues.count(val) > 0U)) {
    out = true;
    return GF_SUCCESS;
  }
  if (kFalseValues.count(val) > 0U) {
    out = false;
    return GF_SUCCESS;
  }
  return GF_FAILED;
}
// Parses a non-negative decimal integer.
//
// Only strings consisting solely of decimal digits are accepted; a sign, whitespace, an empty
// string or a value that does not fit into int is rejected.
// Returns GF_SUCCESS and writes `out` on success; returns GF_FAILED and leaves `out` untouched
// otherwise.
GfStatus StringToInt32(const std::string &value, int32_t &out)
{
  if (value.empty()) {
    return GF_FAILED;
  }
  // isdigit has undefined behaviour for negative arguments other than EOF, so each character
  // is passed as unsigned char (matters for non-ASCII bytes when char is signed).
  const bool has_non_digit = std::any_of(value.begin(), value.end(),
                                         [](const unsigned char c) { return std::isdigit(c) == 0; });
  if (has_non_digit) {
    return GF_FAILED;
  }
  try {
    out = std::stoi(value);
  } catch (const std::exception &) {
    // std::out_of_range for digit strings that exceed the int range.
    return GF_FAILED;
  }
  return GF_SUCCESS;
}
// Returns the registry of all registered flags, keyed by flag index. The map is a
// function-local static so it exists before the first registration uses it.
static CmdFlagInfoMap& GetCmdFlagInfoMap()
{
  static CmdFlagInfoMap registry;
  return registry;
}
// Returns the parser index for a flag that is about to be registered. The help flag always
// maps to FlagIndex::HELP and does not consume a counter value; every other name receives
// the next value of a running counter that starts just above FlagIndex::CUSTOM_START_INDEX.
static int32_t GetNextFlagIndex(const std::string &name)
{
  static int32_t last_assigned = FlagIndex::CUSTOM_START_INDEX;
  const bool is_help = (name == FlagName::HELP);
  if (!is_help) {
    ++last_assigned;
    return last_assigned;
  }
  return FlagIndex::HELP;
}
template <class T>
static std::shared_ptr<CmdFlagInfo> RegisterCmdFlagInfo(int32_t has_arg, const std::string &name,
const T &default_val, const std::string &msg)
{
std::shared_ptr<CmdFlagInfo> cmd_flag_info = nullptr;
try {
cmd_flag_info = std::make_shared<CmdFlagInfo>(has_arg, GetNextFlagIndex(name), name, default_val, msg);
} catch (...) {
return nullptr;
}
GetCmdFlagInfoMap()[cmd_flag_info->GetFlagIndex()] = cmd_flag_info;
return cmd_flag_info;
}
// Applies a value parsed from the command line to the flag registered under `index`.
//
// On success the flag's value is updated and the pair (flag name, raw value) is recorded in
// the user-options map. The record is overwritten rather than inserted-if-absent so that,
// when a flag is given more than once, the user options agree with the flag storage, which
// always holds the last value.
// Returns GF_FAILED if no flag uses `index` or if the value is rejected.
static GfStatus UpdateCmdFlagInfo(const int32_t index, const std::string &value)
{
  const CmdFlagInfoMap &info_map = GetCmdFlagInfoMap();
  const auto info = info_map.find(index);
  if (info == info_map.cend()) {
    return GF_FAILED;
  }
  GE_ASSERT_NOTNULL(info->second);
  const auto ret = info->second->SetFlagValue(value);
  if (ret == GF_SUCCESS) {
    GetUserOptions()[info->second->GetFlagName()] = value;
  }
  return ret;
}
static std::vector<mmStructOption>& GetOptionsVec()
{
static std::vector<mmStructOption> options_vec;
options_vec.clear();
const CmdFlagInfoMap &info_map = GetCmdFlagInfoMap();
for (const auto &info : info_map) {
mmStructOption opt;
opt.name = info.second->GetFlagName().c_str();
opt.has_arg = info.second->GetFlagHasArg();
opt.flag = nullptr;
opt.val = info.second->GetFlagIndex();
options_vec.push_back(opt);
if (info.second->GetFlagName() == "help") {
opt.name = "h";
options_vec.push_back(opt);
}
}
mmStructOption tmp_opt{};
options_vec.push_back(tmp_opt);
return options_vec;
}
static std::string GetFlagMsg(const std::string &flag_name)
{
const std::regex r("^-+");
const std::string &name = std::regex_replace(flag_name, r, "");
const CmdFlagInfoMap &info_map = GetCmdFlagInfoMap();
for (const auto &info : info_map) {
if (info.second->GetFlagName() == name) {
return info.second->GetFlagMsg();
}
}
return "";
}
// Turns a single-dash long flag into its double-dash form: a leading '-' that is directly
// followed by at least two letters becomes "--" (e.g. "-mode=1" -> "--mode=1"). Anything
// else, including "-h" and arguments that already start with "--", is left unchanged.
static void ReplaceFlagPrefixSingleMinus(std::string &str)
{
  const std::regex single_dash_prefix("^-(?=[[:alpha:]]{2,})");
  std::string rewritten = std::regex_replace(str, single_dash_prefix, "--");
  str.swap(rewritten);
}
// Normalises the name part of a double-dash flag: every '-' between the flag's first two
// letters and the first '=' (or the end of the string) becomes '_', e.g.
// "--input-shape=a-b" -> "--input_shape=a-b". The value after '=' is not touched. Strings
// that do not start with "--" followed by at least two letters are left unchanged.
static void ReplaceFlagNameMinus(std::string &str)
{
  const std::regex long_flag_prefix("^--(?=[[:alpha:]]{2,})");
  if (!std::regex_search(str, long_flag_prefix)) {
    return;
  }
  // The match guarantees at least "--" plus two letters, none of which needs rewriting.
  const int checked_prefix_len = 4;
  const auto name_begin = str.begin() + checked_prefix_len;
  const auto name_end = std::find(name_begin, str.end(), '=');
  std::replace(name_begin, name_end, '-', '_');
}
// Produces a normalised copy of the argument vector for the option parser.
//
// Each argument is copied, a single leading '-' before a long name is turned into "--", and
// '-' inside a long flag name is turned into '_'. A null `argv` or a null entry is treated as
// an empty argument so that the result always has `argc` usable entries; the array is
// additionally terminated with a null pointer, following the usual argv layout.
//
// The returned array and the strings it points to live in function-static storage: they stay
// valid until the next call of this function and may be permuted by the option parser.
static char** FormatArgv(int32_t argc, char *argv[])
{
  static std::vector<std::string> arg_storage;
  static std::vector<char *> arg_ptrs;
  arg_storage.clear();
  arg_ptrs.clear();

  for (int32_t i = 0; i < argc; ++i) {
    const char *raw = ((argv != nullptr) && (argv[i] != nullptr)) ? argv[i] : "";
    arg_storage.emplace_back(raw);
    ReplaceFlagPrefixSingleMinus(arg_storage.back());
    ReplaceFlagNameMinus(arg_storage.back());
  }

  // Pointers are collected only after all strings are stored, because growing arg_storage
  // may move the strings (and with them their character buffers).
  arg_ptrs.reserve(arg_storage.size() + 1U);
  for (std::string &arg : arg_storage) {
    arg_ptrs.push_back(&arg[0]);
  }
  arg_ptrs.push_back(nullptr);
  return arg_ptrs.data();
}
// Extracts the bare flag name from a command-line argument by dropping the leading '-'
// characters and everything from the first '=' onwards, e.g. "--mode=1" -> "mode".
static std::string GetFlagName(const std::string &flag_name)
{
  const std::regex dashes_or_value("^-+|=.*");
  const std::string nothing;
  std::string bare_name = std::regex_replace(flag_name, dashes_or_value, nothing);
  return bare_name;
}
}
std::string& RegisterParamString(const std::string &name, const std::string &default_val, const std::string &msg)
{
return RegisterCmdFlagInfo(MMPA_REQUIRED_ARGUMENT, name, default_val, msg)->GetFlagValueString();
}
bool& RegisterParamBool(const std::string &name, bool default_val, const std::string &msg)
{
return RegisterCmdFlagInfo(MMPA_OPTIONAL_ARGUMENT, name, default_val, msg)->GetFlagValueBool();
}
int32_t& RegisterParamInt32(const std::string &name, int32_t default_val, const std::string &msg)
{
return RegisterCmdFlagInfo(MMPA_REQUIRED_ARGUMENT, name, default_val, msg)->GetFlagValueInt32();
}
// Returns the storage for the usage text shown for the help flag. It starts out empty and
// is the same object on every call.
static std::string& GetUsageMessage()
{
  static std::string usage_text;
  return usage_text;
}
// Prints the usage text followed by a newline to stdout (flushing the stream); prints
// nothing when no usage text has been set.
static void PrintUsageMessage()
{
  const std::string &usage_text = GetUsageMessage();
  if (usage_text.empty()) {
    return;
  }
  std::cout << usage_text << std::endl;
}
// Verifies that no arguments are left over after option parsing.
//
// The parser's current argument index marks the first unprocessed argument. If it is at or
// beyond `argc`, everything was consumed and GF_SUCCESS is returned. Otherwise each remaining
// argument is printed as an error and GF_FAILED is returned; a negative index also fails.
static GfStatus CheckNonOptionParameters(const int32_t argc, char** const argv)
{
  const int32_t first_leftover = mmGetOptInd();
  if (first_leftover >= argc) {
    return GF_SUCCESS;
  }
  if (first_leftover < 0) {
    return GF_FAILED;
  }
  for (int32_t pos = first_leftover; pos < argc; ++pos) {
    GFPRINTF("Non-option parameter: %s", argv[pos]);
  }
  return GF_FAILED;
}
// Replaces the usage text that is printed when the help flag is given.
void SetUsageMessage(const std::string &usage)
{
  std::string &stored_usage = GetUsageMessage();
  stored_usage = usage;
}
// Returns the storage for the command line as one space-separated string; it is refilled by
// ParseCommandLine and is the same object on every call.
std::string& GetArgv()
{
  static std::string joined_args;
  return joined_args;
}
// Returns the map of flags that were set successfully from the command line
// (flag name -> raw value text). It is the same object on every call.
std::unordered_map<std::string, std::string> &GetUserOptions()
{
  static std::unordered_map<std::string, std::string> options_from_user;
  return options_from_user;
}
// Parses the command line and updates all registered flags.
//
// Steps:
//   1. The arguments are joined with single spaces into the string returned by GetArgv().
//   2. A normalised copy of argv is built (see FormatArgv) and scanned with mmGetOptLong.
//   3. Every recognised flag is applied through UpdateCmdFlagInfo.
//   4. Left-over non-option arguments are reported as errors.
//
// Returns GF_SUCCESS when all arguments were consumed, GF_HELP after printing the usage text
// because the help flag was given with a true value, and GF_FAILED on any error (null
// argument, missing flag argument, unknown flag, rejected value, left-over arguments).
GfStatus ParseCommandLine(int32_t argc, char* argv[])
{
  std::string &cmd_line = GetArgv();
  cmd_line.clear();
  for (int32_t i = 0; i < argc; ++i) {
    // Reject null pointers up front instead of constructing a std::string from them.
    if ((argv == nullptr) || (argv[i] == nullptr)) {
      GFPRINTF("input argument value is empty or command line format is invalid");
      return GF_FAILED;
    }
    if (i > 0) {
      cmd_line += " ";
    }
    cmd_line += argv[i];
  }

  mmSetOptErr(0);
  char **buff = FormatArgv(argc, argv);
  optind = 1;
  // The registry does not change while parsing, so the option table is built once rather
  // than on every loop iteration.
  mmStructOption *const long_opts = GetOptionsVec().data();

  int32_t index = 0;
  while ((index = mmGetOptLong(argc, buff, CMD_SHORT_OPTS, long_opts, nullptr)) != -1) {
    // Position of the argument that was just processed.
    const int32_t idx = mmGetOptInd() - 1;
    if ((idx < 0) || (idx >= argc)) {
      GFPRINTF("input argument value is empty or command line format is invalid");
      return GF_FAILED;
    }
    if (index == FlagIndex::COLON) {
      const char *flag_name = buff[idx];
      const std::string flag_msg = GetFlagMsg(flag_name);
      if (flag_msg.empty()) {
        GFPRINTF("flag '%s' is missing its argument", flag_name);
      } else {
        GFPRINTF("flag '%s' is missing its argument; flag description: %s", flag_name, flag_msg.c_str());
      }
      return GF_FAILED;
    }
    if (index == FlagIndex::QUESTION_MARK) {
      GFPRINTF("unknown command line flag '%s'", GetFlagName(buff[idx]).c_str());
      return GF_FAILED;
    }

    const char *opt_arg = mmGetOptArg();
    const std::string value = (opt_arg == nullptr) ? "" : opt_arg;
    if (UpdateCmdFlagInfo(index, value) != GF_SUCCESS) {
      return GF_FAILED;
    }

    if (index == FlagIndex::HELP) {
      const CmdFlagInfoMap &info_map = GetCmdFlagInfoMap();
      const auto it = info_map.find(FlagIndex::HELP);
      if ((it != info_map.end()) && it->second->GetFlagValueBool()) {
        PrintUsageMessage();
        return GF_HELP;
      }
    }
  }
  return CheckNonOptionParameters(argc, buff);
}
}
}