/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "codegen_infershape.h"
#include "code_printer.h"
namespace codegen {
namespace {
std::string GetFileHeaderDefine() {
  std::string file_header_str = R"(
#include <cmath>
#include <type_traits>
#include <unordered_map>
#include "exe_graph/runtime/infer_shape_context.h"

namespace {
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
#define Pow(a, b) (std::pow(a, b))
#define Exp(a) (std::exp(a))
#define Log(a) (std::log(a))
#define Sqrt(a) (std::sqrt(a))
#define Ceiling(a) (std::ceil(a))
#define Floor(a) (std::floor(a))
#define Abs(a) (std::abs(a))
#define Rational(a, b) (static_cast<double>(a) / static_cast<double>(b))

const double kThreshold = 0.00001;

template <typename Ta, typename Tb>
auto Mod(Ta left, Tb right) -> decltype(left % right) {
  return left % right;
}

// 针对浮点数的特化版本(使用 std::fmod)
template <typename Ta, typename Tb>
auto Mod(Ta left, Tb right) ->
    typename std::enable_if<std::is_floating_point<Ta>::value || std::is_floating_point<Tb>::value,
                            decltype(std::fmod(left, right))>::type {
  return std::fmod(left, right);
}

class InferShapeSymbolEvalContext : public gert::InferShapeContext {
 public:
  const gert::Tensor *GetGraphInputTensor(size_t data_index) const {
    auto *tensor = GetInputPointer<gert::Tensor>(data_index + 1);
    if (tensor == nullptr) {
      return nullptr;
    }
    return tensor;
  }
};
static_assert(std::is_standard_layout<InferShapeSymbolEvalContext>::value,
              "The class InferShapeSymbolEvalContext must be a POD");
} // namespace
)";
  return file_header_str;
}
}  // namespace

std::string InfershapeGen::GenInferShapeFunc(const std::vector<std::vector<std::string>> &symbol_shape_str,
                                             const std::map<std::string, std::string> &shape_info) const {
  ge::CodePrinter printer;
  const std::string blank_space = "  ";
  printer.AddLine(GetFileHeaderDefine());

  std::string common_get_input_str;
  for (const auto &it : shape_info) {
    common_get_input_str += (blank_space + "auto " + it.first + " = " + it.second + ";\n");
  }

  printer.DefineFuncBegin("extern \"C\" ge::graphStatus", "InferShape", "InferShapeSymbolEvalContext *context");
  printer.AddLine(common_get_input_str);

  size_t expr_value_num = 0;
  for (size_t i = 0U; i < symbol_shape_str.size(); ++i) {
    printer.AddLine(blank_space + "context->GetOutputShape(" + std::to_string(i) + ")->SetDimNum(0);");
    for (const auto &sym_expr : symbol_shape_str[i]) {
      std::string append_dim_code = blank_space + "context->GetOutputShape(" + std::to_string(i) + ")->AppendDim(";
      // 如果表达式中包含Rational, 结果可能是浮点数, 强转成整形会舍去小数部分导致结果错误, 因此要进行四舍五入处理
      // 然后对损失的小数部分做校验, 小于设定的阈值才认为计算成功
      if (sym_expr.find("Rational") != std::string::npos) {
        std::string expr_value_name = "expr_value_" + std::to_string(expr_value_num);
        std::string round_value_name = "round_value_" + std::to_string(expr_value_num);
        printer.AddLine(blank_space+ "// 表达式中包含Rational, 结果可能是浮点数, 强转成整形会舍去小数部分导致结果错误, 因此要进行四舍五入处理");
        printer.AddLine(blank_space + "double " + expr_value_name +  " = " + sym_expr + ";");
        printer.AddLine(blank_space + "int64_t " + round_value_name + " = std::round(" + expr_value_name + ");");
        printer.AddLine(blank_space + "// 对损失的小数部分做校验, 小于设定的阈值才认为计算成功");
        printer.AddLine(blank_space + "if ((fabs(" + expr_value_name + " - static_cast<double>(" + round_value_name + ")) > kThreshold)) {");
        printer.AddLine(blank_space + "  return ge::GRAPH_FAILED;");
        printer.AddLine(blank_space + "}");
        append_dim_code += (round_value_name +");");
        expr_value_num++;
      } else {
        append_dim_code += (sym_expr + ");");
      }
      printer.AddLine(append_dim_code);
    }
  }

  printer.AddLine(blank_space + "return ge::GRAPH_SUCCESS;");
  printer.DefineFuncEnd();

  return printer.GetOutputStr();
}
}  // namespace codegen