/* -------------------------------------------------------------------------
 *  This file is part of the MindStudio project.
 * Copyright (c) 2025 Huawei Technologies Co.,Ltd.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * ------------------------------------------------------------------------- */

#include "arg_checker.h"

#include <vector>
#include <unordered_set>

#include "common/prof_args.h"
#include "ascend_helper.h"
#include "filesystem.h"
#include "common/hal_helper.h"
#include "common/defs.h"
#include "log.h"

using namespace Common;
using namespace Utility;

namespace Parser {
// Maximum number of characters accepted for the launch-count argument string (see CheckLaunchCount).
constexpr uint32_t LAUNCH_COUNT_MAX_LENGTH = 4;
// Inclusive upper bound of the launch count; CheckLaunchCount accepts values in [1, 5000].
constexpr int32_t MAX_LAUNCH_COUNT = 5000;
// Inclusive upper bound of launch-skip-before-match; CheckLaunchSkipBeforeMatch accepts values in [0, 1000].
constexpr int32_t MAX_LAUNCH_SKIP_NUMBER = 1000;
// Inclusive upper bound of the warm-up times; CheckWarmUp accepts values in [0, 500].
constexpr int32_t MAX_WARM_UP_TIMES = 500;
// Largest core id accepted by --core-id; CheckCoreId accepts ids in [0, 49].
constexpr int32_t MAX_PARSE_CORE_ID = 49;
// Inclusive upper bound of --timeout; CheckTimeout accepts values in [1, 2880].
// NOTE(review): the time unit is not visible in this file -- confirm against the option's documentation.
constexpr int32_t MAX_TIMEOUT = 2880;

// AIC metrics supported in on-board ("device") run mode.
// Key: metric name; value: list of product series on which the metric is supported.
// CheckAicMetrics passes this table to CheckMetrics when runMode is "device".
// NOTE(review): ALL_PRODUCT_TYPE presumably matches every product series; the actual
// matching rule lives in IsChipSeriesTypeValid, which is not visible here.
const AicMetricsSupportMap DEVICE_AIC_METRICS_SUPPORT_MAP{
    {std::string(Common::MsprofMetrics::PIPE_UTILIZATION),        {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::ARITHMETIC_UTILIZATION),  {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::L2_CACHE),                {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::MEMORY),                  {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::MEMORY_L0),               {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::MEMORY_UB),               {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::RESOURCE_CONFLICT_RATIO), {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::DEFAULT),                 {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::KERNEL_SCALE),            {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::OCCUPANCY),               {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::TIMELINE_DETAIL),         {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES}},
    {std::string(Common::MsprofMetrics::ROOFLINE),                {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND310P_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::BASIC_INFO),              {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES,
                                                                   ChipProductType::ASCEND310P_SERIES}},
    {std::string(Common::MsprofMetrics::SOURCE),                  {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::MEMORYDETAIL),            {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::PIPE_TIMELINE),           {ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::INSTR_TIMELINE),          {ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::PCSAMPLING),              {ChipProductType::ASCEND950_SERIES}},
};
// AIC metrics supported in simulator run mode.
// Key: metric name; value: list of product series on which the metric is supported.
// CheckAicMetrics passes this table to CheckMetrics when runMode is "simulator".
const AicMetricsSupportMap SIMULATOR_AIC_METRICS_SUPPORT_MAP{
    {std::string(Common::MsprofMetrics::PIPE_UTILIZATION),        {ChipProductType::ALL_PRODUCT_TYPE}},
    {std::string(Common::MsprofMetrics::RESOURCE_CONFLICT_RATIO), {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND310P_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::PMSAMPLING),              {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES,
                                                                   ChipProductType::ASCEND950_SERIES}},
    {std::string(Common::MsprofMetrics::OVERHEAD),                {ChipProductType::ASCEND910B_SERIES,
                                                                   ChipProductType::ASCEND910_93_SERIES}},
};

bool ArgChecker::CheckDeviceChipSupport(
    const std::string &argName, const std::vector<ChipProductType> &supportTypes, std::string &msg) const {
    ChipType chipType = Common::HalHelper::Instance().GetPlatformType();
    auto iter = CHIP_ARCHITECTURE_TO_PRODUCT_SERIES.find(chipType);
    if (iter == CHIP_ARCHITECTURE_TO_PRODUCT_SERIES.end()) {
        msg = "chiptype " + std::to_string(static_cast<int>(chipType)) + " not support.";
        return false;
    }
    for (const auto &supportType : supportTypes) {
        if (IsChipSeriesTypeValid(iter->second, supportType)) {
            return true;
        }
    }
    msg = "Unexpected argument " + argName + ", maybe in wrong soc platform";
    return false;
}

/**
 * @brief Builds the ordered list of argument checkers for the given run mode.
 *
 * The common checkers are registered first, followed by the simulator-only checkers when
 * @p runMode equals "simulator", or by the device-side checkers for any other value.
 * Check() runs the checkers in registration order, so the order below is significant.
 *
 * @param runMode Run mode string taken from the command line.
 */
ArgChecker::ArgChecker(const std::string &runMode)
{
    using Checker = bool (ArgChecker::*)(const ProfArgs &, std::string &) const;

    // Checkers that apply regardless of the run mode.
    const Checker commonCheckers[] = {
        &ArgChecker::CheckRunModeValid,
        &ArgChecker::CheckApplicationValid,
        &ArgChecker::CheckOutputPathValid,
        &ArgChecker::CheckKernelNameValid,
        &ArgChecker::CheckLaunchCount,
        &ArgChecker::CheckMstx,
        &ArgChecker::CheckMstxInclude,
        &ArgChecker::CheckAicMetrics,
        &ArgChecker::CheckCoreId,
        &ArgChecker::CheckDump,
    };
    // Checkers that are only registered for the simulator run mode.
    const Checker simulatorCheckers[] = {
        &ArgChecker::CheckExportPathValid,
        &ArgChecker::CheckSimSocVersion,
        &ArgChecker::CheckTimeout,
    };
    // Checkers that are registered for every non-simulator run mode.
    const Checker deviceCheckers[] = {
        &ArgChecker::CheckLaunchSkipBeforeMatch,
        &ArgChecker::CheckKillAdvance,
        &ArgChecker::CheckReplayMode,
        &ArgChecker::CheckWarmUp,
        &ArgChecker::CheckInstrTimelinePipe,
        &ArgChecker::CheckCustomInput,
    };

    for (const Checker checker : commonCheckers) {
        checkers_.emplace_back(checker);
    }
    if (runMode == "simulator") {
        for (const Checker checker : simulatorCheckers) {
            checkers_.emplace_back(checker);
        }
    } else {
        for (const Checker checker : deviceCheckers) {
            checkers_.emplace_back(checker);
        }
    }
}

/**
 * @brief Validates every requested AIC metric against a support table.
 *
 * @param metricsVec  Metric names requested through --aic-metrics.
 * @param productType Product series the metrics will be collected on.
 * @param supports    Table mapping a metric name to the product series that support it.
 * @param msg         Receives the failure reason when the function returns false.
 * @return true if every metric is present in @p supports and is supported on @p productType.
 */
bool ArgChecker::CheckMetrics(const std::vector<std::string> &metricsVec, const ChipProductType &productType,
                              const AicMetricsSupportMap &supports, std::string &msg) const
{
    for (const auto &metric : metricsVec) {
        const auto entry = supports.find(metric);
        if (entry == supports.end()) {
            // The metric does not exist in the table of the current run mode.
            msg = "Unexpected argument in --aic-metrics, maybe in wrong run mode";
            return false;
        }

        // The metric is known; look for a supported series that matches the product type.
        bool matched = false;
        for (const auto &series : entry->second) {
            if (IsChipSeriesTypeValid(productType, series)) {
                matched = true;
                break;
            }
        }
        if (!matched) {
            msg = "Unexpected argument in --aic-metrics, maybe in wrong soc platform";
            return false;
        }
    }
    return true;
}

/**
 * @brief Validates --aic-metrics for the configured run mode.
 *
 * - "device": the chip type reported by HalHelper must be known, 'PipeTimeline' and
 *   'InstrTimeline' must not be enabled together, and every metric must be supported according
 *   to DEVICE_AIC_METRICS_SUPPORT_MAP.
 * - "simulator": the soc version comes from --soc-version, else from the environment, else it
 *   falls back to "Ascend910B1"; metrics are checked against SIMULATOR_AIC_METRICS_SUPPORT_MAP.
 * - any other run mode: nothing is checked here.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the metrics are acceptable.
 */
bool ArgChecker::CheckAicMetrics(const Common::ProfArgs &config, std::string &msg) const
{
    const std::vector<std::string> &requestedMetrics = config.argAicMetrics.metricVec;

    if (config.runMode == "simulator") {
        std::string socVersion = config.argSocVersion;
        if (socVersion.empty() && !GetSocVersionFromEnvVar(socVersion)) {
            socVersion = "Ascend910B1";
        }
        const ChipProductType simulatedSeries = GetProductTypeBySocVersion(socVersion);
        return CheckMetrics(requestedMetrics, simulatedSeries, SIMULATOR_AIC_METRICS_SUPPORT_MAP, msg);
    }
    if (config.runMode != "device") {
        return true;
    }

    const ChipType platform = Common::HalHelper::Instance().GetPlatformType();
    const auto seriesIt = CHIP_ARCHITECTURE_TO_PRODUCT_SERIES.find(platform);
    if (seriesIt == CHIP_ARCHITECTURE_TO_PRODUCT_SERIES.end()) {
        msg = "chiptype " + std::to_string(static_cast<int>(platform)) + " not support.";
        return false;
    }
    const bool bothTimelines = config.argAicMetrics.pipeTimelineEnable && config.argAicMetrics.instrTimelineEnable;
    if (bothTimelines) {
        msg = "Unexpected argument --aic-metrics, 'PipeTimeline' and 'InstrTimeline' cannot be specified together.";
        return false;
    }
    const ChipProductType deviceSeries = seriesIt->second;
    return CheckMetrics(requestedMetrics, deviceSeries, DEVICE_AIC_METRICS_SUPPORT_MAP, msg);
}

/**
 * @brief Runs all registered checkers in registration order.
 *
 * Evaluation stops at the first checker that fails; later checkers are not executed.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Passed to every checker; holds the failure reason of the failing checker.
 * @return true if every checker succeeded.
 */
bool ArgChecker::Check(const ProfArgs &config, std::string &msg) const
{
    for (auto it = checkers_.begin(); it != checkers_.end(); ++it) {
        const auto checker = *it;
        if ((this->*checker)(config, msg)) {
            continue;
        }
        return false;
    }
    return true;
}

bool ArgChecker::CheckRunModeValid(const ProfArgs &config, std::string &msg) const
{
    if (config.runMode == "device") {
        if (!HalHelper::Instance().IsSupportPlatform()) {
            msg = "Device profiling is not supported on current chip.";
            return false;
        }
        return true;
    }

    if (config.runMode != "simulator") {
        msg = "unexpected run mode";
        return false;
    }
    if (config.argAicMetrics.isDeviceToSimulator) {
        msg = "--aic-metrics=TimelineDetail is invalid in simulator";
        return false;
    }
    std::vector<std::string> sims;
    if (!GetSimulators(sims)) {
        msg = "get simulators from ascend path failed";
        return false;
    }

    return true;
}

/**
 * @brief Checks that exactly one input source is given and that an application, if given, is runnable.
 *
 * The input sources are the config argument, the application command line and the export
 * argument. When the application command line is the chosen source, its first element must not
 * be a directory and must be executable.
 *
 * @param args Parsed profiling arguments.
 * @param msg  Receives the failure reason when the function returns false.
 * @return true if the input source selection is valid.
 */
bool ArgChecker::CheckApplicationValid(const ProfArgs &args, std::string &msg) const
{
    // Count how many of the mutually exclusive input sources were supplied.
    int32_t suppliedSources = 0;
    if (!args.argConfig.empty()) {
        ++suppliedSources;
    }
    if (!args.cmd.empty()) {
        ++suppliedSources;
    }
    if (!args.argExport.empty()) {
        ++suppliedSources;
    }

    constexpr int32_t expectedSources = 1;
    if (suppliedSources != expectedSources) {
        msg = "Input parameter config, export and application can not be used together or empty at the same time";
        return false;
    }
    if (args.cmd.empty()) {
        // Config or export mode: there is no application to inspect.
        return true;
    }

    const auto &program = args.cmd[0];
    const bool runnable = !IsDir(program) && IsExecutable(program);
    if (!runnable) {
        msg = "application to be profiled is not exist or not executable.";
    }
    return runnable;
}

/**
 * @brief Validates the --output path.
 *
 * Fails when the path exists but is not a directory, when IsStringCharValid rejects it, or when
 * PathLenCheckValid rejects its length. Afterwards the deepest already existing directory of
 * the path is located; soft-link and writability problems on it are only logged as warnings,
 * and CheckOwnerPermission is invoked on it.
 *
 * NOTE(review): the value returned by CheckOwnerPermission is not used, although it is handed
 * @p msg -- confirm that ignoring it is intended.
 * NOTE(review): the directory walk always starts with a leading path separator, so it
 * presumably expects config.argOutput to be an absolute path -- verify against the parser.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return false only for the three hard failures listed above, true otherwise.
 */
bool ArgChecker::CheckOutputPathValid(const ProfArgs &config, std::string &msg) const
{
    if (IsExist(config.argOutput) && !IsDir(config.argOutput)) {
        msg = "--output parameter is not a folder but already exist, please check output path is correct.";
        return false;
    }

    std::string invalidCharInfo;
    if (!IsStringCharValid(config.argOutput, invalidCharInfo)) {
        msg = "--output parameter contains " + invalidCharInfo;
        return false;
    }

    if (!PathLenCheckValid(config.argOutput)) {
        msg = "--output parameter length is larger than 200.";
        return false;
    }

    // Walk the path segment by segment and remember the longest prefix that already exists as a
    // directory. When a segment does not exist, the prefix keeps the trailing separator.
    std::vector<std::string> segments;
    SplitString(config.argOutput, '/', segments);
    std::string existingPrefix;
    for (const auto &segment : segments) {
        if (segment.empty()) {
            continue;
        }
        const std::string extended = existingPrefix + (Utility::PATH_SEP + segment);
        if (!Utility::IsDir(extended)) {
            existingPrefix = extended.substr(0, extended.size() - segment.size());
            break;
        }
        existingPrefix = extended;
    }

    if (IsSoftLinkRecursively(existingPrefix)) {
        LogWarn("Output path contains soft link, may cause security problems");
    }
    if (!IsWritable(existingPrefix)) {
        LogWarn("Output dir is not writable: %s", existingPrefix.c_str());
    }
    CheckOwnerPermission(existingPrefix, msg);
    return true;
}

/**
 * @brief Inspects the --export path without ever rejecting it.
 *
 * When --export is set, CheckInputFileValid is run in "dir" mode and a warning is logged if it
 * fails; CheckPermission is then called on the path. The function always returns true and
 * never writes to @p msg.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Unused.
 * @return Always true.
 */
bool ArgChecker::CheckExportPathValid(const ProfArgs &config, std::string &msg) const
{
    if (!config.argExport.empty()) {
        if (!CheckInputFileValid(config.argExport, "dir")) {
            LogWarn("In input parameter --export receive parent dir permission wrong");
        }
        CheckPermission(config.argExport);
    }
    return true;
}

/**
 * @brief Validates --kernel-name.
 *
 * An empty value is accepted. Otherwise the option is only allowed in application mode (neither
 * the config nor the export argument may be set), the whole string must be shorter than
 * MAX_KERNEL_NAME_LENGTH, and every '|'-separated name must be non-empty and consist solely of
 * the characters A-Z a-z 0-9 _ *.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the kernel name filter is acceptable.
 */
bool ArgChecker::CheckKernelNameValid(const Common::ProfArgs &config, std::string &msg) const
{
    if (config.argKernelName.empty()) {
        return true;
    }
    if (!config.argConfig.empty() || !config.argExport.empty()) {
        msg = "--kernel-name only supports application mode";
        return false;
    }

    if (config.argKernelName.size() >= MAX_KERNEL_NAME_LENGTH) {
        msg = "--kernel-name input length exceeds limitation.";
        return false;
    }

    std::set<std::string> kernelNameSet;
    Utility::SplitString(config.argKernelName, '|', kernelNameSet);
    // The pattern never changes, so compile it once instead of on every call.
    static const std::regex namePattern("^[A-Za-z0-9_*]+$");
    for (const auto &kernelName : kernelNameSet) {
        // The length test is defensive: the total length was already bounded above.
        if (kernelName.empty() || kernelName.length() > MAX_KERNEL_NAME_LENGTH) {
            msg = "invalid kernel name, name is too long or empty";
            return false;
        }
        if (!std::regex_match(kernelName, namePattern)) {
            // The two literals are concatenated; the first one must end with a separator so the
            // sentences do not run together.
            msg = "invalid kernel name, name contains unsupported character in name. "
                  "Support characters in one kernel name are: A-Z a-z 0-9 _ *";
            return false;
        }
    }
    return true;
}

/**
 * @brief Validates the launch count argument.
 *
 * The string may be at most LAUNCH_COUNT_MAX_LENGTH characters long, must be convertible to an
 * int32_t by StringToNum, and the value must lie within [1, MAX_LAUNCH_COUNT].
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the launch count is valid.
 */
bool ArgChecker::CheckLaunchCount(const Common::ProfArgs &config, std::string &msg) const
{
    const auto &rawCount = config.argLaunchCount;
    // Reject over-long input before attempting any numeric conversion.
    if (rawCount.length() > LAUNCH_COUNT_MAX_LENGTH) {
        msg = "Launch count should in [1, 5000]";
        return false;
    }

    int32_t launchCount = 0;
    if (!StringToNum<int32_t>(rawCount, launchCount)) {
        msg = "Launch count should be number and within [1, 5000]";
        return false;
    }

    const bool inRange = (launchCount >= 1) && (launchCount <= MAX_LAUNCH_COUNT);
    if (!inRange) {
        msg = "Launch count should within [1, 5000]";
    }
    return inRange;
}

/**
 * @brief Validates the launch-skip-before-match argument.
 *
 * The string must be shorter than five characters, must be convertible to an int32_t by
 * StringToNum, and the value must lie within [0, MAX_LAUNCH_SKIP_NUMBER].
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the skip count is valid.
 */
bool ArgChecker::CheckLaunchSkipBeforeMatch(const Common::ProfArgs &config, std::string &msg) const
{
    constexpr size_t lengthLimit = 5;
    const auto &rawSkip = config.argLaunchSkipBeforeMatch;

    int32_t skipCount = 0;
    // The conversion is attempted only when the length guard passes.
    const bool parsed = (rawSkip.size() < lengthLimit) && StringToNum<int32_t>(rawSkip, skipCount);
    if (!parsed) {
        msg = "Launch-skip-before-match should be number and within [0, 1000]";
        return false;
    }

    if (skipCount >= 0 && skipCount <= MAX_LAUNCH_SKIP_NUMBER) {
        return true;
    }
    msg = "Launch-skip-before-match should within [0, 1000]";
    return false;
}

/**
 * @brief Validates --replay-mode and its interaction with other options.
 *
 * "application" and "kernel" are always accepted; "range" is additionally accepted when the
 * platform type is ASCEND910B or ASCEND950. "application" cannot be combined with the
 * isDeviceToSimulator metric flag. "range" requires --mstx=on and cannot be combined with the
 * isDeviceToSimulator, isSource or isMemoryDetail metric flags.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the replay mode is valid.
 */
bool ArgChecker::CheckReplayMode(const Common::ProfArgs &config, std::string &msg) const
{
    const ChipType platform = Common::HalHelper::Instance().GetPlatformType();
    const bool rangeCapable = (platform == ChipType::ASCEND910B) || (platform == ChipType::ASCEND950);

    std::set<std::string> allowedModes = {"application", "kernel"};
    if (rangeCapable) {
        allowedModes.insert("range");
    }

    const auto &mode = config.argReplayMode;
    if (allowedModes.find(mode) == allowedModes.end()) {
        msg = "Replay mode should be " + Join(allowedModes.begin(), allowedModes.end(), "/");
        return false;
    }

    const auto &metrics = config.argAicMetrics;
    if (mode == "application") {
        if (metrics.isDeviceToSimulator) {
            msg = "--aic-metrics=TimelineDetail is invalid when --replay-mode=application";
            return false;
        }
        return true;
    }
    if (mode != "range") {
        return true;
    }

    // Range replay: requires mstx and excludes several metrics.
    if (config.argMstx != "on") {
        msg = "--replay-mode=range only support when --mstx=on";
        return false;
    }
    if (metrics.isDeviceToSimulator || metrics.isSource || metrics.isMemoryDetail) {
        msg = "--aic-metrics=TimelineDetail/Source/MemoryDetail is invalid when --replay-mode=range";
        return false;
    }
    return true;
}

/**
 * @brief Validates that the kill option is either "on" or "off".
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if config.argKill is "on" or "off".
 */
bool ArgChecker::CheckKillAdvance(const Common::ProfArgs &config, std::string &msg) const
{
    const bool isSwitchValue = (config.argKill == "on") || (config.argKill == "off");
    if (!isSwitchValue) {
        msg = "Kill should be on/off";
    }
    return isSwitchValue;
}

/**
 * @brief Validates --dump.
 *
 * The value must be "on" or "off". When it is "on", the target must belong to the
 * ASCEND910B or ASCEND910_93 series: in "device" mode this is checked against the real chip,
 * in "simulator" mode against the soc version (from --soc-version, else from the environment,
 * else "Ascend910B1"). Other run modes are not restricted.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the dump option is valid for the target.
 */
bool ArgChecker::CheckDump(const Common::ProfArgs &config, std::string &msg) const
{
    const bool dumpOn = (config.argDump == "on");
    if (!dumpOn && config.argDump != "off") {
        msg = "--dump should be on/off";
        return false;
    }
    if (!dumpOn) {
        // Dump disabled: no platform restriction applies.
        return true;
    }

    if (config.runMode == "device") {
        return CheckDeviceChipSupport(
            "--dump", {ChipProductType::ASCEND910B_SERIES, ChipProductType::ASCEND910_93_SERIES}, msg);
    }
    if (config.runMode != "simulator") {
        return true;
    }

    std::string socVersion = config.argSocVersion;
    if (socVersion.empty() && !GetSocVersionFromEnvVar(socVersion)) {
        socVersion = "Ascend910B1";
    }
    const ChipProductType simulatedSeries = GetProductTypeBySocVersion(socVersion);
    const bool supported = IsChipSeriesTypeValid(simulatedSeries, ChipProductType::ASCEND910B_SERIES) ||
                           IsChipSeriesTypeValid(simulatedSeries, ChipProductType::ASCEND910_93_SERIES);
    if (!supported) {
        msg = "Unexpected argument --dump, maybe in wrong soc platform";
    }
    return supported;
}

/**
 * @brief Validates --mstx: the value must be empty, "on" or "off".
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the mstx switch is acceptable.
 */
bool ArgChecker::CheckMstx(const Common::ProfArgs &config, std::string &msg) const
{
    const auto &mstx = config.argMstx;
    const bool acceptable = mstx.empty() || mstx == "on" || mstx == "off";
    if (!acceptable) {
        msg = "--mstx should use on/off";
    }
    return acceptable;
}

/**
 * @brief Validates --mstx-include.
 *
 * An empty value is accepted. Otherwise the option is rejected when --mstx is "off", when the
 * whole string is not shorter than MAX_KERNEL_NAME_LENGTH, or when any '|'-separated entry
 * fails CheckInputStringValid with MAX_MSTX_INCLUDE_NAME_LENGTH.
 *
 * NOTE(review): an empty --mstx value passes the guard below, whereas CheckReplayMode requires
 * argMstx == "on" -- confirm that accepting the empty value here is intended.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the include filter is acceptable.
 */
bool ArgChecker::CheckMstxInclude(const Common::ProfArgs &config, std::string &msg) const
{
    const auto &includeFilter = config.argMstxInclude;
    if (includeFilter.empty()) {
        return true;
    }

    if (config.argMstx == "off") {
        msg = "--mstx-include only support when --mstx=on";
        return false;
    }

    if (includeFilter.size() >= MAX_KERNEL_NAME_LENGTH) {
        msg = "--mstx-include input length exceeds limitation";
        return false;
    }

    std::set<std::string> includeEntries;
    Utility::SplitString(includeFilter, '|', includeEntries);
    for (auto entry = includeEntries.begin(); entry != includeEntries.end(); ++entry) {
        if (Utility::CheckInputStringValid(*entry, MAX_MSTX_INCLUDE_NAME_LENGTH)) {
            continue;
        }
        msg = "invalid include string, include string is too long or use unsupported character in include string. "
                "Support characters in one message are: A-Z a-z 0-9 _";
        return false;
    }
    return true;
}

/**
 * @brief Validates --soc-version for the simulator run mode.
 *
 * Nothing is checked outside simulator mode or when the option is empty. The option is
 * rejected in config mode. The Ascend home path must be resolvable. A version starting with
 * "Ascend950" that is present in SOC_STRING_TO_CHIP_PRODUCT is accepted directly; any other
 * version must match a name returned by GetFileNames for "<ascend home>/tools/simulator".
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the soc version is acceptable.
 */
bool ArgChecker::CheckSimSocVersion(const Common::ProfArgs &config, std::string &msg) const
{
    const auto &socVersion = config.argSocVersion;
    const bool nothingToCheck = (config.runMode != "simulator") || socVersion.empty();
    if (nothingToCheck) {
        return true;
    }

    if (!config.argConfig.empty()) {
        msg = "--soc-version is not effective in config mode";
        return false;
    }

    std::string ascendHome;
    if (!GetAscendHomePath(ascendHome)) {
        msg = "$ASCEND_HOME_PATH not found";
        return false;
    }

    // Known Ascend950 versions are accepted without looking at the simulator directory.
    if (StartsWith(socVersion, "Ascend950")) {
        if (SOC_STRING_TO_CHIP_PRODUCT.find(socVersion) != SOC_STRING_TO_CHIP_PRODUCT.end()) {
            return true;
        }
    }

    std::vector<std::string> installedSimulators;
    if (!GetFileNames(ascendHome + "/tools/simulator", installedSimulators)) {
        msg = "get simulator failed, please check $ASCEND_HOME_PATH/tools/simulator";
        return false;
    }

    const auto match = std::find(installedSimulators.begin(), installedSimulators.end(), socVersion);
    if (match == installedSimulators.end()) {
        msg = "--soc-version is invalid, please specify a simulator in $ASCEND_HOME_PATH/tools/simulator";
        return false;
    }
    return true;
}

/**
 * @brief Validates the warm-up times argument.
 *
 * The string must be convertible to an int32_t by StringToNum and the value must lie within
 * [0, MAX_WARM_UP_TIMES].
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the warm-up times value is valid.
 */
bool ArgChecker::CheckWarmUp(const Common::ProfArgs &config, std::string &msg) const
{
    int32_t warmUpTimes = 0;
    if (!StringToNum<int32_t>(config.argWarmUp, warmUpTimes)) {
        msg = "Warm up times should be number and within [0, 500]";
        return false;
    }

    const bool inRange = (warmUpTimes >= 0) && (warmUpTimes <= MAX_WARM_UP_TIMES);
    if (!inRange) {
        msg = "Warm up times should within [0, 500]";
    }
    return inRange;
}

/**
 * @brief Validates --core-id.
 *
 * An empty value is accepted. Otherwise the string may be at most MAX_INPUT_STR_LENGTH
 * characters long and every '|'-separated entry must convert to a uint16_t not greater than
 * MAX_PARSE_CORE_ID. In "device" mode the chip must additionally belong to the ASCEND910B or
 * ASCEND910_93 series.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the core id list is valid.
 */
bool ArgChecker::CheckCoreId(const Common::ProfArgs &config, std::string &msg) const
{
    const auto &coreIdList = config.argCoreId;
    if (coreIdList.empty()) {
        return true;
    }

    if (coreIdList.size() > MAX_INPUT_STR_LENGTH) {
        msg = "--core-id input length exceeds limitation.";
        return false;
    }

    std::set<std::string> coreTokens;
    Utility::SplitString(coreIdList, '|', coreTokens);
    for (auto token = coreTokens.begin(); token != coreTokens.end(); ++token) {
        uint16_t parsedId = 0;
        // Every token has to be an integer in [0, 49].
        const bool idValid = Utility::StringToNum<uint16_t>(*token, parsedId) && parsedId <= MAX_PARSE_CORE_ID;
        if (!idValid) {
            msg = "--core-id is invalid, the cores to be parsed should be separated by '|',"
                  " and each core id should be an integer which within [0, 49].";
            return false;
        }
    }

    if (config.runMode != "device") {
        return true;
    }
    return CheckDeviceChipSupport(
        "--core-id", {ChipProductType::ASCEND910B_SERIES, ChipProductType::ASCEND910_93_SERIES}, msg);
}

/**
 * @brief Validates --timeout.
 *
 * An empty value is accepted. Otherwise the string must convert to a uint32_t by StringToNum
 * and the value must lie within [1, MAX_TIMEOUT].
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the timeout is valid.
 */
bool ArgChecker::CheckTimeout(const Common::ProfArgs &config, std::string &msg) const
{
    if (config.argTimeout.empty()) {
        return true;
    }

    uint32_t timeout = 0;
    if (StringToNum<uint32_t>(config.argTimeout, timeout) && timeout != 0 && timeout <= MAX_TIMEOUT) {
        return true;
    }
    msg = "--timeout is invalid, it should be an integer which within [1, 2880].";
    return false;
}

/**
 * @brief Validates --instr-timeline-pipe.
 *
 * An empty value is accepted. Otherwise the InstrTimeline metric must be enabled, the string
 * must be shorter than MAX_INPUT_STR_LENGTH, and every '|'-separated entry must be accepted by
 * DfxPipe::IsValidDfxPipe.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if the pipe list is valid.
 */
bool ArgChecker::CheckInstrTimelinePipe(const Common::ProfArgs &config, std::string &msg) const
{
    const auto &pipeList = config.argInstrTimelinePipe;
    if (pipeList.empty()) {
        return true;
    }

    if (!config.argAicMetrics.instrTimelineEnable) {
        msg = "--instr-timeline-pipe only support when --aic-metrics include InstrTimeline.";
        return false;
    }

    if (pipeList.size() >= MAX_INPUT_STR_LENGTH) {
        msg = "--instr-timeline-pipe input length exceeds limitation.";
        return false;
    }

    std::set<std::string> requestedPipes;
    Utility::SplitString(pipeList, '|', requestedPipes);
    for (auto pipeIt = requestedPipes.begin(); pipeIt != requestedPipes.end(); ++pipeIt) {
        if (DfxPipe::IsValidDfxPipe(*pipeIt)) {
            continue;
        }
        msg = "--instr-timeline-pipe only support pipes in [cube fixp vector mte1 mte2 mte3] and separated by '|'.";
        return false;
    }
    return true;
}

/**
 * @brief Validates --custom-input.
 *
 * An empty value is accepted. Otherwise the path is handed to CheckInputFileValid with the
 * "json" type, the MAX_JSON_FILE_SIZE limit and the "custom-input" tag.
 *
 * Like every other checker in this class, a failure reports its reason through @p msg, so the
 * caller never receives a rejection without an explanation.
 *
 * @param config Parsed profiling arguments.
 * @param msg    Receives the failure reason when the function returns false.
 * @return true if no custom input is given or the file passes CheckInputFileValid.
 */
bool ArgChecker::CheckCustomInput(const Common::ProfArgs &config, std::string &msg) const
{
    if (config.argCustomInput.empty()) {
        return true;
    }
    if (!CheckInputFileValid(config.argCustomInput, "json", MAX_JSON_FILE_SIZE, "custom-input")) {
        // NOTE(review): CheckInputFileValid may log details on its own; this text is the summary
        // returned to the caller of Check().
        msg = "--custom-input is invalid, please check the input json file.";
        return false;
    }
    return true;
}
} // namespace Parser