/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "common/fe_context_utils.h"
#include "common/fe_log.h"
#include "common/aicore_util_constants.h"
#include "common/string_utils.h"
#include "common/fe_type_utils.h"
#include "common/platform_utils.h"
#include "common/fe_report_error.h"
#include "graph/ge_context.h"
#include "ge/ge_api_types.h"
#include "graph/tuning_utils.h"
#include "common/platform_utils.h"
#include "platform/platform_info.h"

namespace fe {
namespace {
// Accepted precision-mode option strings and the PrecisionMode each one resolves to.
// Several spellings intentionally share one enum value (e.g. ALLOW_MIX_PRECISION and
// ALLOW_MIX_PRECISION_FP16 both resolve to ENUM_ALLOW_MIX_PRECISION_FP16).
// NOTE(review): V2_ORIGIN used to be listed twice with the same value; only one entry is
// kept because duplicate keys in the initializer list are ignored anyway. Confirm that no
// other key was meant to occupy the second slot.
const std::unordered_map<std::string, PrecisionMode> kPrecisionModeMap = {
    {ALLOW_MIX_PRECISION, PrecisionMode::ENUM_ALLOW_MIX_PRECISION_FP16},
    {FORCE_FP16, PrecisionMode::ENUM_FORCE_FP16},
    {FORCE_FP32, PrecisionMode::ENUM_FORCE_FP32},
    {ALLOW_FP32_TO_FP16, PrecisionMode::ENUM_ALLOW_FP32_TO_FP16},
    {MUST_KEEP_ORIGIN_DTYPE, PrecisionMode::ENUM_MUST_KEEP_ORIGIN_DTYPE},
    {FORCE_LOWERPRECISION, PrecisionMode::ENUM_FORCE_LOWERPRECISION},
    {ALLOW_FP32_TO_BF16, PrecisionMode::ENUM_ALLOW_FP32_TO_BF16},
    {ALLOW_FP32_TO_LOWPRECISION, PrecisionMode::ENUM_ALLOW_FP32_TO_LOWPRECISION},
    {ALLOW_MIX_PRECISION_FP16, PrecisionMode::ENUM_ALLOW_MIX_PRECISION_FP16},
    {ALLOW_MIX_PRECISION_BF16, PrecisionMode::ENUM_ALLOW_MIX_PRECISION_BF16},
    {CUBE_FP16IN_FP32OUT, PrecisionMode::ENUM_FORCE_FP32},
    {V2_FP16, PrecisionMode::ENUM_FORCE_FP16},
    {V2_MIX_FP16, PrecisionMode::ENUM_ALLOW_MIX_PRECISION_FP16},
    {V2_MIX_BF16, PrecisionMode::ENUM_ALLOW_MIX_PRECISION_BF16},
    {V2_ORIGIN, PrecisionMode::ENUM_MUST_KEEP_ORIGIN_DTYPE},
    {kCubeHif8, PrecisionMode::ENUM_CUBE_HIF8},
    {kMixedHif8, PrecisionMode::ENUM_MIXED_HIF8}
};

// Entries looked for in the platform's "Intrinsic_vconv" dtype list. Finding any one of them
// is what the BF16 precision-mode checks in this file treat as "BF16 is supported".
// (Presumably FP32 -> BF16 conversions in different rounding modes -- TODO confirm.)
const std::unordered_set<std::string> kIntrinsicvconvSet = {
    "f322bf16r", "f322bf16f", "f322bf16c", "f322bf16a", "f322bf16z"
};

// Textual switch values understood by GetGeContextBoolValue: "1" -> true, "0" -> false.
const std::map<std::string, bool> kSwitchMap {{"1", true}, {"0", false}};

// Radix handed to std::strtol when parsing numeric option values.
const int32_t kBase = 10;

// Key of the vconv entry inside the AI Core intrinsic dtype map.
const std::string INTRINSIC_VCONV = "Intrinsic_vconv";
}

/**
 * @brief Reads the raw precision-mode option from the GE context.
 *
 * ge::PRECISION_MODE is consulted first; ge::PRECISION_MODE_V2 is only read when the
 * first option yields an empty string.
 *
 * @return The option text, or an empty string when neither option provides a value.
 */
std::string FEContextUtils::GetPrecisionMode() {
  std::string mode = GetGeContextValue(ge::PRECISION_MODE);
  if (!mode.empty()) {
    return mode;
  }
  return GetGeContextValue(ge::PRECISION_MODE_V2);
}

/**
 * @brief Tells whether the graph run mode stored in the GE context is a training mode.
 *
 * The option ge::OPTION_GRAPH_RUN_MODE is parsed as a base-10 integer ("1" is used when
 * the option cannot be read) and converted to ge::GraphRunMode.
 *
 * @return true when the parsed run mode is ge::TRAIN or greater, false otherwise.
 */
bool FEContextUtils::IsTrainMode() {
  const std::string run_mode_text = GetGeContextValue(ge::OPTION_GRAPH_RUN_MODE, "1");
  const auto run_mode = ge::GraphRunMode(std::strtol(run_mode_text.c_str(), nullptr, kBase));
  return run_mode >= ge::TRAIN;
}

/**
 * @brief Fills in the default precision mode (string form) and logs the choice.
 *
 * - Not in train mode: FORCE_FP16.
 * - Train mode with cube high precision enabled on the platform: MUST_KEEP_ORIGIN_DTYPE.
 * - Train mode otherwise: ALLOW_FP32_TO_FP16.
 *
 * @param[out] precision_mode Receives the selected default.
 */
void FEContextUtils::SetDefaultPrecisionMode(std::string &precision_mode) {
  if (IsTrainMode()) {
    precision_mode = PlatformUtils::Instance().IsEnableCubeHighPrecision() ? MUST_KEEP_ORIGIN_DTYPE
                                                                          : ALLOW_FP32_TO_FP16;
  } else {
    precision_mode = FORCE_FP16;
  }
  FE_LOGI("The value of [%s] is empty, set to default [%s].", ge::PRECISION_MODE.c_str(), precision_mode.c_str());
}

/**
 * @brief Fills in the default precision mode (enum form) and logs the choice.
 *
 * - Not in train mode: ENUM_FORCE_FP16.
 * - Train mode with cube high precision enabled on the platform: ENUM_MUST_KEEP_ORIGIN_DTYPE.
 * - Train mode otherwise: ENUM_ALLOW_FP32_TO_FP16.
 *
 * @param[out] precision_mode Receives the selected default.
 */
void FEContextUtils::SetDefaultPrecisionMode(fe::PrecisionMode &precision_mode) {
  if (!IsTrainMode()) {
    precision_mode = PrecisionMode::ENUM_FORCE_FP16;
  } else if (PlatformUtils::Instance().IsEnableCubeHighPrecision()) {
    precision_mode = PrecisionMode::ENUM_MUST_KEEP_ORIGIN_DTYPE;
  } else {
    precision_mode = PrecisionMode::ENUM_ALLOW_FP32_TO_FP16;
  }
  // "%d" expects an int. Convert explicitly so the argument type matches the format
  // specifier regardless of how the enum (and its underlying type) is declared.
  FE_LOGI("The value for [%s] is empty; setting it to the default [%d].", ge::PRECISION_MODE.c_str(),
          static_cast<int>(precision_mode));
}

namespace {
/**
 * @brief Checks whether the current platform lists any entry of kIntrinsicvconvSet under
 *        the "Intrinsic_vconv" key of its AI Core intrinsic dtype map.
 *
 * @param[out] is_supported true when at least one matching entry is found; false when the
 *                          key is missing, no entry matches, or platform info is unavailable.
 * @return SUCCESS when the platform information could be obtained, FAILED otherwise
 *         (the failure is logged).
 */
Status CheckBf16ConversionSupport(bool &is_supported) {
  is_supported = false;
  PlatFormInfos platform_infos;
  OptionalInfos optional_infos;
  if (PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platform_infos, optional_infos) != SUCCESS) {
    FE_LOGE("Getting platform information without SOC version has failed");
    return FAILED;
  }
  // Bind by const reference: no extra copy of the map or of the dtype vector is needed,
  // and this works whether the getter returns a value or a reference.
  const std::map<std::string, std::vector<std::string>> &intrinsic_map = platform_infos.GetAICoreIntrinsicDtype();
  const auto intr_iter = intrinsic_map.find(INTRINSIC_VCONV);
  if (intr_iter == intrinsic_map.end()) {
    return SUCCESS;
  }
  for (const std::string &intrinsic_name : intr_iter->second) {
    if (kIntrinsicvconvSet.count(intrinsic_name) > 0) {
      is_supported = true;
      break;
    }
  }
  return SUCCESS;
}

/**
 * @brief Validates a non-empty precision-mode option string and resolves it to its enum.
 *
 * The string must be a key of kPrecisionModeMap. When it resolves to
 * ENUM_ALLOW_MIX_PRECISION_BF16 or ENUM_ALLOW_FP32_TO_BF16, the platform must additionally
 * pass CheckBf16ConversionSupport. Invalid or unsupported values are logged and reported
 * as EM_INPUT_OPTION_INVALID.
 *
 * @param[in]  precision_mode_str  Option text to validate.
 * @param[out] precision_mode_enum Written only when SUCCESS is returned.
 * @return SUCCESS when the value is valid and usable on this platform, FAILED otherwise.
 */
Status ResolvePrecisionMode(const std::string &precision_mode_str, PrecisionMode &precision_mode_enum) {
  const auto mode_iter = kPrecisionModeMap.find(precision_mode_str);
  if (mode_iter == kPrecisionModeMap.end()) {
    FE_LOGE("Precision mode value %s is incorrect, please check it.", precision_mode_str.c_str());
    ErrorMessageDetail err_msg(EM_INPUT_OPTION_INVALID, {precision_mode_str, ge::PRECISION_MODE,
                               "The current value is not within the valid range"});
    ReportErrorMessage(err_msg);
    return FAILED;
  }

  const PrecisionMode resolved_mode = mode_iter->second;
  const bool needs_bf16 = (resolved_mode == PrecisionMode::ENUM_ALLOW_MIX_PRECISION_BF16 ||
                           resolved_mode == PrecisionMode::ENUM_ALLOW_FP32_TO_BF16);
  if (needs_bf16) {
    bool bf16_supported = false;
    if (CheckBf16ConversionSupport(bf16_supported) != SUCCESS) {
      return FAILED;
    }
    if (!bf16_supported) {
      // Both public overloads now emit this same text; previously they logged two
      // different messages for the identical condition.
      FE_LOGE("The AI core doesn't support mixed_bfloat16, allow_mix_precision_bf16 or allow_fp32_to_bf16.");
      ErrorMessageDetail err_msg(EM_INPUT_OPTION_INVALID,
                                 {precision_mode_str, ge::PRECISION_MODE, "Current soc not support dtype of BFloat16"});
      ReportErrorMessage(err_msg);
      return FAILED;
    }
  }

  precision_mode_enum = resolved_mode;
  return SUCCESS;
}
}  // namespace

/**
 * @brief Obtains the effective precision mode as an option string.
 *
 * When the context provides no value, the default chosen by SetDefaultPrecisionMode is
 * returned. Otherwise the configured value is validated (known key, BF16 availability).
 *
 * @param[out] precision_mode The default, or the configured value exactly as read from the
 *                            context (also left in place when validation fails).
 * @return SUCCESS on a default or valid value, FAILED when validation fails.
 */
Status FEContextUtils::GetPrecisionMode(std::string &precision_mode) {
  precision_mode = GetPrecisionMode();
  if (precision_mode.empty()) {
    SetDefaultPrecisionMode(precision_mode);
    return SUCCESS;
  }
  // Only the validation result matters here; the resolved enum is discarded.
  PrecisionMode ignored_enum{};
  return ResolvePrecisionMode(precision_mode, ignored_enum);
}

/**
 * @brief Obtains the effective precision mode as a PrecisionMode enum.
 *
 * When the context provides no value, the default chosen by SetDefaultPrecisionMode is
 * returned. Otherwise the configured value is validated (known key, BF16 availability)
 * and converted through kPrecisionModeMap.
 *
 * @param[out] precision_mode Written on SUCCESS only; left untouched when FAILED is returned.
 * @return SUCCESS on a default or valid value, FAILED when validation fails.
 */
Status FEContextUtils::GetPrecisionMode(fe::PrecisionMode &precision_mode) {
  const std::string precision_mode_str = GetPrecisionMode();
  if (precision_mode_str.empty()) {
    SetDefaultPrecisionMode(precision_mode);
    return SUCCESS;
  }
  return ResolvePrecisionMode(precision_mode_str, precision_mode);
}

/** @brief Returns the value of the ge::BUILD_MODE option, or an empty string if it cannot be read. */
std::string FEContextUtils::GetBuildMode() {
  return GetGeContextValue(ge::BUILD_MODE);
}

/** @brief Returns the value of the ge::BUILD_STEP option, or an empty string if it cannot be read. */
std::string FEContextUtils::GetBuildStep() {
  return GetGeContextValue(ge::BUILD_STEP);
}

/** @brief Returns the value of the ge::CORE_TYPE option, or an empty string if it cannot be read. */
std::string FEContextUtils::GetCoreType() {
  return GetGeContextValue(ge::CORE_TYPE);
}

/** @brief Returns the value of the ge::FUSION_SWITCH_FILE option, or an empty string if it cannot be read. */
std::string FEContextUtils::GetFusionSwitchFilePath() {
  return GetGeContextValue(ge::FUSION_SWITCH_FILE);
}

/**
 * @brief Reads an option from the GE context, using an empty string as the fallback.
 *
 * @param key Name of the option to read.
 * @return The option value, or an empty string when the option cannot be read.
 */
std::string FEContextUtils::GetGeContextValue(const std::string &key) {
  return GetGeContextValue(key, std::string());
}

/**
 * @brief Reads an option from the GE context with a caller-supplied fallback.
 *
 * @param key           Name of the option to read.
 * @param default_value Returned when the context lookup does not succeed.
 * @return The stored value when the lookup succeeds (this may be an empty string),
 *         otherwise @p default_value.
 */
std::string FEContextUtils::GetGeContextValue(const std::string &key, const std::string &default_value) {
  std::string value_in_context;
  if (ge::GetContext().GetOption(key, value_in_context) == ge::GRAPH_SUCCESS) {
    FE_LOGD("The option value[%s] in ge context is %s.", key.c_str(), value_in_context.c_str());
    return value_in_context;
  }
  FE_LOGD("Cannot get option value [%s].", key.c_str());
  return default_value;
}

/**
 * @brief Reads a "0"/"1" switch option from the GE context as a bool.
 *
 * @param key           Name of the option to read.
 * @param default_value Returned when the option cannot be read, is empty, or holds
 *                      anything other than "0" or "1".
 * @return true for "1", false for "0", otherwise @p default_value.
 */
bool FEContextUtils::GetGeContextBoolValue(const std::string &key, const bool &default_value) {
  std::string raw_value;
  const bool is_fetched = (ge::GetContext().GetOption(key, raw_value) == ge::GRAPH_SUCCESS);
  if (!is_fetched || raw_value.empty()) {
    FE_LOGD("Cannot get option value [%s] or the value is empty.", key.c_str());
    return default_value;
  }
  FE_LOGD("The option value[%s] in ge context is [%s].", key.c_str(), raw_value.c_str());

  const auto switch_iter = kSwitchMap.find(raw_value);
  if (switch_iter != kSwitchMap.end()) {
    return switch_iter->second;
  }
  FE_LOGD("The value [%s] is neither 0 nor 1.", raw_value.c_str());
  return default_value;
}

/**
 * @brief Tells whether the configured build mode is one of the op-tuning modes.
 *
 * @return true when the build mode equals ge::BUILD_MODE_BASELINE, ge::BUILD_MODE_TUNING
 *         or ge::BUILD_MODE_OPAT_RESULT; false otherwise.
 */
bool FEContextUtils::IsOpTuneMode() {
  const std::string current_mode = GetBuildMode();
  if (current_mode == ge::BUILD_MODE_BASELINE) {
    return true;
  }
  if (current_mode == ge::BUILD_MODE_TUNING) {
    return true;
  }
  return current_mode == ge::BUILD_MODE_OPAT_RESULT;
}
}