/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef ATT_L0_SOLVER_H_
#define ATT_L0_SOLVER_H_
#include <sstream>
#include <string>
#include "util/base_types_printer.h"

namespace att {
inline std::string GenVarDef() {
  std::string strs = "";
  strs += AddAnotationLine(" L0Var的备选值的个数\n", "");
  strs += "static const uint32_t candidate_size = 7u;\n";
  strs += AddAnotationLine(" L0Var的备选值\n", "");
  strs += "static const uint32_t candidate_value[] = {16u,  32u,  64u,  128u,\n";
  strs += "                                           256u, 512u, 1024u};\n";
  strs += AddAnotationLine(" 表达L0的求解值至少要满足核数的比例,可手动修改\n", "");
  strs += "static const double CORE_NUM_RATIO = 0.6f;\n";
  strs += AddAnotationLine(" 表达L0的求解值pad之后的值不允许超过原始值大小的倍数,可手动修改\n", "");
  strs += "static const uint32_t UPPER_BOUND_RATIO = 2u;\n";
  strs += AddAnotationLine(" 表达最大L0Var的个数\n", "");
  strs += "static const uint32_t MAX_L0_VAR_NUM = 3u;\n";
  strs += "\n";
  return strs;
}
  
inline std::string GenL0VarDef() {
  std::string strs = "";
  strs += AddAnotationLine(" L0相关变量的数据结构\n", "");
  strs += "struct L0Var {\n";
  strs += AddAnotationLine(" 最大值,初始化为输入原始轴的大小\n", "  ");
  strs += "  uint32_t max_value{0u};\n";
  strs += AddAnotationLine(" 是否绑多核\n", "  ");
  strs += "  bool bind_multicore{false};\n";
  strs += "  bool is_innermost{false};\n";
  strs += AddAnotationLine(" 对齐值\n", "  ");
  strs += "  uint32_t align{0u};\n";
  strs += AddAnotationLine(" 提示当前L0Var的最佳对齐值,通常源于父轴的对齐值,\n", "  ");
  strs += AddAnotationLine(" 举例,stepm是basem的父轴,stepm的对齐要求是256和basem,那么basem的prompt_align就是256,\n", "  ");
  strs += AddAnotationLine(" 同时约束L0Var的取值必须是256对齐或者是256的因子,因为这样父轴stepm才能既满足256也满足basem对齐\n", "  ");
  strs += "  uint32_t prompt_align{0u};\n";
  strs += AddAnotationLine(" L0变量的索引\n", "  ");
  strs += "  uint32_t idx;\n";
  strs += AddAnotationLine(" L0变量的值\n", "  ");
  strs += "  uint32_t value{0u};\n";
  strs += "};\n";
  strs += "\n";
  return strs;
}

inline std::string GenL0Input() {
  std::string strs = "";
  strs += AddAnotationLine(" 求解器接收的输入\n", "");
  strs += "struct L0TileInput {\n";
  strs += AddAnotationLine(" 待求解L0变量的集合\n", "  ");
  strs += "  L0Var *l0_vars{nullptr};\n";
  strs += AddAnotationLine(" 待求解L0变量的数量\n", "  ");
  strs += "  uint32_t size;\n";
  strs += AddAnotationLine(" 核数\n", "  ");
  strs += "  uint32_t core_num;\n";
  strs += "};\n";
  strs += "\n";
  return strs;
}

inline std::string GenL0VarCmpAnnotation() {
  std::string strs = "";
  strs += " * 比较两个 L0Var 类型变量的大小\n";
  strs += " *\n";
  strs += " * 这个函数用来比较两个 L0Var "
        "类型变量的大小。它遵循特定的比较逻辑:\n";
  strs += " * 1. 如果 a 变量绑定到多核且 b 变量没有绑定多核,则 a 被认为大于 "
        "b,函数返回\n";
  strs += " * true。\n";
  strs += " * 2. 如果 a 变量没有绑定多核且 b 变量绑定多核,则 a 被认为小于 "
        "b,函数返回\n";
  strs += " * false。\n";
  strs += " * 3. 如果 a 和 b 变量都绑定多核或都未绑定多核,则比较它们的 "
        "prompt_align\n";
  strs += " * 属性,prompt_align 属性值大的变量被认为更大。\n";
  strs += " *\n";
  strs += " * @param a 第一个要比较的 L0Var 变量。\n";
  strs += " * @param b 第二个要比较的 L0Var 变量。\n";
  strs += " * @return 如果 a 大于 b,返回 true;如果 a 小于 b,返回 false。\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenL0VarCmp() {
  std::string strs = "";
  strs += GenL0VarCmpAnnotation();
  strs += "static bool L0VarCmp(L0Var a, L0Var b) {\n";
  strs += "  if (a.bind_multicore && !b.bind_multicore) {\n";
  strs += "    return true;\n";
  strs += "  }\n";
  strs += "  if (!a.bind_multicore && b.bind_multicore) {\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  if (a.is_innermost && !b.is_innermost) {\n";
  strs += "    return true;\n";
  strs += "  }\n";
  strs += "  if (!a.is_innermost && b.is_innermost) {\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  return a.prompt_align > b.prompt_align;\n";
  strs += "}\n";
  return strs;
}

inline std::string GenL0SolverAnnotaion() {
  std::string annotations = "";
  annotations += " * L0 求解器类\n";
  return AddAnotationBlock(annotations, "");
}

inline std::string GenL0SolverGenAnnotaion() {
  std::string annotations = "";
  annotations += "   * 构造函数\n";
  annotations += "   *\n";
  annotations += "   * @param input 一个 L0TileInput 结构体,包含了 L0 变量相关的信息\n";
  annotations += "   *\n";
  annotations += "   * 这个构造函数初始化了 L0TileSolver 对象\n";
  return AddAnotationBlock(annotations, "  ");
}

inline std::string GenL0SolverDegenDef() {
  std::string strs = "";
  std::string annotations = "";
  annotations += "   * 析构函数\n";
  annotations += "   *\n";
  annotations += "   * 当 L0TileSolver 对象被销毁时,析构函数被调用\n";
  annotations += "   * 用来释放使用 new 运算符动态分配的内存,确保没有内存泄漏\n";
  strs += AddAnotationBlock(annotations, "  ");
  strs += "  ~L0TileSolver() {\n";
  strs += "    if (sortedvars_ != nullptr) {\n";
  strs += "      delete[] sortedvars_;\n";
  strs += "    }\n";
  strs += "    if (output_ != nullptr) {\n";
  strs += "      delete[] output_;\n";
  strs += "    }\n";
  strs += "    if (input_.l0_vars != nullptr) {\n";
  strs += "      delete[] input_.l0_vars;\n";
  strs += "    }\n";
  strs += "  }\n";
  return strs;
}

inline std::string GenRunAnnotaion() {
  std::string annotations = "";
  annotations += "   * 运行求解器\n";
  annotations += "   *\n";
  annotations += "   * @return 如果求解成功,返回 true;否则返回 false\n";
  annotations += "   *\n";
  annotations += "   * 这个方法是算法的入口点,调用它会启动求解过程\n";
  annotations += "   * 成功与否取决于 CheckBufferUseValid() 方法的返回值\n";
  return AddAnotationBlock(annotations, "  ");
}

inline std::string GenGetOutputAnnotaion() {
  std::string annotations = "";
  annotations += "   * 获取优化结果\n";
  annotations += "   *\n";
  annotations += "   * @return 指向求解结果数据的指针\n";
  annotations += "   *\n";
  annotations += "   * 如果有求解结果,这个方法将返回一个指向结果数据的指针\n";
  annotations += "   * 结果数据的内存使用完后,应该使用 delete[] 释放内存\n";
  return AddAnotationBlock(annotations, "  ");
}

inline std::string GenCheckBufferUseValidAnnotaion() {
  std::string annotations = "";
  annotations += "   * 检查是否满足buffer约束\n";
  annotations += "   *\n";
  annotations += "   * @return 如果满足,返回 true;否则返回 false\n";
  annotations += "   *\n";
  annotations += "   * 这个纯虚函数要求派生类提供实现,以确保缓冲区使用是有效的\n";
  annotations += "   * 在 L0TileSolver 类中,它是一个抽象方法,需要在子类中实现\n";
  return AddAnotationBlock(annotations, "  ");
}

inline std::string GenCheckInputDef() {
  std::string strs = "";
  std::string annotations = "";
  annotations += "   * 检查输入数据的完整性和正确性\n";
  annotations += "   *\n";
  annotations += "   * @return 如果输入数据有效,返回 true;否则返回 false\n";
  annotations += "   *\n";
  annotations += "   * 这个私有方法检查输入数据的格式和逻辑,确保它们适用求解算法\n";
  strs += AddAnotationBlock(annotations, "  ");
  strs += "  bool CheckInput();\n";
  strs += "\n";
  return strs;
}

inline std::string GenInitInputDef() {
  std::string strs = "";
  std::string annotations = "";
  annotations += "   * 使用输入数据初始化算法所需的内部数据结构\n";
  annotations += "   *\n";
  annotations += "   * 这个方法根据输入的 L0TileInput "
        "结构体中的数据,初始化算法所需的内部数据结构\n";
  annotations += "   * 确保 sortedvars_ 和 output_ 成员变量被正确初始化\n";
  strs += AddAnotationBlock(annotations, "  ");
  strs += "  void InitInput();\n";
  strs += "\n";
  return strs;
}

inline std::string GenCheckOutputDef() {
  std::string strs = "";
  std::string annotations = "";
  annotations += "   * 检查算法运行的结果,确保它们符合预期\n";
  annotations += "   *\n";
  annotations += "   * @return 如果输出数据有效,返回 true;否则返回 false\n";
  annotations += "   *\n";
  annotations += "   * 这个方法检查运行算法后得到的结果,确保它们在逻辑上是合理的\n";
  strs += AddAnotationBlock(annotations, "  ");
  strs += "  bool CheckOutput();\n";
  strs += "\n";
  return strs;
}

inline std::string GenUpdateAlignAnnotaion() {
  std::string annotations = "";
  annotations += "   * 更新算法执行过程中的对齐设置\n";
  annotations += "   *\n";
  annotations += "   * "
        "这个方法根据算法执行过程中的数据更新对齐提示值,确保结果按照预期的方式"
        "对齐\n";
  return AddAnotationBlock(annotations, "  ");
}

inline std::string GenBestAlignAnnotaion() {
  std::string strs = "";
  strs += "   * 为指定索引的 L0 变量获取最佳对齐值\n";
  strs += "   *\n";
  strs += "   * @param i 想要的变量索引值\n";
  strs += "   * @return 最佳对齐值\n";
  strs += "   *\n";
  strs += "   * 这个方法计算并返回给定索引值的 L0 变量的最佳对齐值,\n";
  strs += "   * 确保变量以最恰当的方式对齐,从而提高效率或者减少资源浪费\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenIterativeRunAnnotaion() {
  std::string strs = "";
  strs += "   * 为 L0 变量找到最优值进行迭代运行\n";
  strs += "   *\n";
  strs += "   * @param loop_id 当前循环索引,表示正在处理的 L0 变量的位置\n";
  strs += "   * @param best_var_value 一个指针,指向用于存储每个 L0\n";
  strs += "   * 变量迄今为止找到的最佳值的数组\n";
  strs += "   *\n";
  strs += "   * 这个函数使用递归方法来遍历 L0\n";
  strs += "   * "
        "变量的所有可能值。对于每个值,它检查是否满足约束条件,如小于上界且满足"
        "对齐要求。如果满足这些条件,它将继续下一个循环或者递归调用自身来处理下"
        "一个\n";
  strs += "   * L0 变量。如果是最后一个 L0\n";
  strs += "   * "
        "变量,它将检查当前组合是否满足优化条件,如核心数量和数据处理量。如果满"
        "足条件,它将当前组合存储为最优解。\n";
  strs += "   *\n";
  strs += "   * 请注意,这个函数没有返回值,而是将最优解存储在传入的 "
        "best_var_value\n";
  strs += "   * 数组中。\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenMaxCoreNumAnnotaion() {
  std::string strs = "";
  strs += "   * 根据 L0 变量信息和总核心数计算可以分配的最大核心数\n";
  strs += "   *\n";
  strs += "   * @param l0_vars 一个指向 L0Var 结构体数组的指针\n";
  strs += "   * @param core_num 可用于分配的总核心数\n";
  strs += "   * @return 可以分配的最大核心数\n";
  strs += "   *\n";
  strs += "   * 这个函数计算在给定 L0 "
        "变量信息和总核心数的情况下,可以分配的最大核心数。\n";
  strs += "   * 它遍历输入的 L0Var\n";
  strs += "   * "
        "结构体数组,对于每个变量,根据其是否绑定多核心以及最大、当前和提示对齐"
        "值计算所需的块数。\n";
  strs += "   * 通过将所有变量的块数相乘,得到总块数。\n";
  strs += "   * "
        "最大核心数是总块数和总核心数中的最小值,以确保核心数不会超过可用资源。"
        "\n";
  strs += "   *\n";
  strs += "   * 返回值表示可以分配给 L0\n";
  strs += "   * 变量的最大核心数,这对于在多核心系统中进行资源分配是有用的。\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenGetMacUseAnnotaion() {
  std::string strs = "";
  strs += "   * 计算所有 L0 变量值的乘积,作为mac计算量的度量\n";
  strs += "   *\n";
  strs += "   * @return mac计算量\n";
  strs += "   *\n";
  strs += "   * 这个函数计算所有 L0 变量值的乘积,结果是一个数字。\n";
  strs += "   * 这个数字可以作为数据处理量的度量,例如在评估算法性能时。\n";
  strs += "   * 通过不断更新 usage 变量,乘法操作确保了所有 L0 "
        "变量的影响都被计入。\n";
  strs += "   * 最终 usage 变量中的值就是所有 L0 "
        "变量值的乘积,代表了整体的数据处理量。\n";
  strs += "   * "
        "返回值可以帮助了解算法处理的数据量,从而对算法的效率和扩展性有更直观的"
        "认识。\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenSortedvarsAnnotaion() {
  std::string strs = "";
  strs += "   * 用于排序的 L0Var 对象数组\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenMaxcoreAnnotaion() {
  std::string strs = "";
  strs += "   * 最大核心数\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenMacuseAnnotaion() {
  std::string strs = "";
  strs += "   * 最大 MAC 使用量\n";
  return AddAnotationBlock(strs, "  ");
}

inline std::string GenL0TileSolver() {
  std::string strs = "";
  strs += GenL0SolverAnnotaion();
  strs += "class L0TileSolver {\n";
  strs += "public:\n";
  strs += GenL0SolverGenAnnotaion();
  strs += "  explicit L0TileSolver(L0TileInput input) : input_(input) {}\n";
  strs += "  L0TileSolver() {};\n";
  strs += GenL0SolverDegenDef();
  strs += GenRunAnnotaion();
  strs += "  bool Run();\n";
  strs += GenGetOutputAnnotaion();
  strs += "  uint32_t *GetOutput() { return output_; }\n";
  strs += "\n";
  strs += "protected:\n";
  strs += GenCheckBufferUseValidAnnotaion();
  strs += "  virtual bool CheckBufferUseValid() = 0;\n";
  strs += "  L0TileInput input_;\n";
  strs += "  uint32_t *output_{nullptr};\n";
  strs += "\n";
  strs += "private:\n";
  strs += GenCheckInputDef();
  strs += GenInitInputDef();
  strs += GenCheckOutputDef();
  strs += GenUpdateAlignAnnotaion();
  strs += "  void UpdateAlign();\n";
  strs += "\n";
  strs += GenBestAlignAnnotaion();
  strs += "  uint32_t GetBestAlign(uint32_t i) const;\n";
  strs += "\n";
  strs += GenIterativeRunAnnotaion();
  strs += "  void IterativeRun(uint32_t loop_id, uint32_t *best_var_value);\n";
  strs += "\n";
  strs += GenMaxCoreNumAnnotaion();
  strs += "  int32_t MaxCoreNum(const L0Var *l0_vars, const uint32_t "
        "&core_num);\n";
  strs += "\n";
  strs += GenGetMacUseAnnotaion();
  strs += "  uint32_t GetMacUse() const;\n";
  strs += GenSortedvarsAnnotaion();
  strs += "  L0Var *sortedvars_{nullptr};\n";
  strs += GenMaxcoreAnnotaion();
  strs += "  int64_t max_corenum_{-1};\n";
  strs += GenMacuseAnnotaion();
  strs += "  int32_t max_macuse_{-1};\n";
  strs += "};\n";
  strs += "\n";
  return strs;
}

inline std::string GenGetBestAlignFuncAnnotation() {
  std::string strs = "";
  strs += " * 获取给定索引的 L0 变量的最佳对齐值\n";
  strs += " *\n";
  strs += " * @param i 想要的变量索引值\n";
  strs += " * @return 最佳对齐值\n";
  strs += " *\n";
  strs += " * 这个方法为给定索引的 L0\n";
  strs += " * "
        "变量计算最佳对齐值。它考虑到变量的最大、当前和提示对齐值,以确保数据存"
        "储和访问的效率。\n";
  strs += " * "
        "根据变量的原始值(ori_"
        "value),它首先确定最小和最大对齐值的范围。然后,通过在这个范围内以二"
        "的幂次方递增,它找到最大的满足条件的值。\n";
  strs += " * "
        "如果没有找到这样的值,它将返回最小对齐值。如果在范围内找到了一个值,它"
        "将返回这个值的二分之一,作为最佳对齐值。\n";
  strs += " * 这个最佳对齐值可以用于确保数据以最有效的方式存储\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenGetBestAlignFunc() {
  std::string strs = "";
  strs += GenGetBestAlignFuncAnnotation();
  strs += "uint32_t L0TileSolver::GetBestAlign(uint32_t i) const {\n";
  strs += "  uint32_t ori_value = input_.l0_vars[i].max_value;\n";
  strs += "  uint32_t min_align = input_.l0_vars[i].align;\n";
  strs += "  uint32_t max_align = input_.l0_vars[i].prompt_align;\n";
  strs += "  uint32_t ori_align = min_align;\n";
  strs += "  uint32_t max_value = std::min(ori_value, max_align);\n";
  strs += "  while (ori_align <= max_value) {\n";
  strs += "    ori_align = ori_align << 1;\n";
  strs += "  }\n";
  strs += "  if (ori_align == min_align) {\n";
  strs += "    return min_align;\n";
  strs += "  }\n";
  strs += "  return std::max(1u, ori_align >> 1);\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenMaxCoreNumFuncAnnotation() {
  std::string strs = "";
  strs += " * 根据给定的 L0 变量信息计算可以分配的最大核心数\n";
  strs += " *\n";
  strs += " * @param l0_vars 指向 L0Var 结构数组的指针\n";
  strs += " * @param core_num 总核心数\n";
  strs += " * @return 可以分配的最大核心数\n";
  strs += " *\n";
  strs += " * 这个函数遍历 L0Var 结构数组,根据每个变量的 bind_multicore "
        "属性以及\n";
  strs += " * max_value、value 和 prompt_align\n";
  strs += " * "
        "的值来计算每个变量所需的块数。对于绑定多核心的变量,块数计算方式为:("
        "max_value\n";
  strs += " * + max(value,prompt_align)-1)/\n";
  strs += " * max(value,prompt_align)。对于未绑定多核心的变量,块数为 1。\n";
  strs += " * 总块数通过将所有变量的块数相乘得到。之后,通过比较总块数和\n";
  strs += " * "
        "core_"
        "num,返回两者中的最小值,作为可以分配的最大核心数。如果总块数超过了\n";
  strs += " * core_num,那么系统的核心数将成为瓶颈,因此需要将 core_num\n";
  strs += " * 设置为最大核心数。如果总块数小于等于\n";
  strs += " * core_num,那么总块数就是可以分配的最大核心数。\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenMaxCoreNumFunc() {
  std::string strs = "";
  strs += GenMaxCoreNumFuncAnnotation();
  strs += "int32_t L0TileSolver::MaxCoreNum(const L0Var *l0_vars,\n";
  strs += "                                 const uint32_t &core_num) {\n";
  strs += "  uint32_t total_block_size = 1u;\n";
  strs += "  for (uint32_t i = 0u; i < input_.size; i++) {\n";
  strs += "    auto var = l0_vars[i];\n";
  strs += "    uint32_t block_num =\n";
  strs += "        var.bind_multicore\n";
  strs += "            ? ((var.max_value + std::max(var.value, var.prompt_align) "
        "- 1)) /\n";
  strs += "                  std::max(var.value, var.prompt_align)\n";
  strs += "            : 1;\n";
  strs += "    total_block_size *= block_num;\n";
  strs += "  }\n";
  strs += "  int64_t max_core_num =\n";
  strs += "      total_block_size > core_num ? core_num : total_block_size;\n";
  strs += "  return max_core_num;\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenGetMacUseFuncAnnotation() {
  std::string strs = "";
  strs += " * 计算所有 L0 变量值的乘积,作为数据处理量的度量\n";
  strs += " *\n";
  strs += " * @return 数据处理量\n";
  strs += " *\n";
  strs += " * 这个函数遍历 L0TileInput 结构体中的所有 L0Var 对象,计算它们的 "
        "value\n";
  strs += " * 属性的乘积。这个乘积代表了所有 L0\n";
  strs += " * 变量值的联合效应,或者说数据处理量的一个度量。 通过不断更新 "
        "usage\n";
  strs += " * 变量,乘法操作确保了所有 L0 变量的贡献都被包含在内。最终 usage\n";
  strs += " * 变量中的值就是所有 L0 变量值的乘积。\n";
  strs += " * "
        "返回值可以帮助评估算法在处理给定输入数据时的效率,以及比较不同算法或优"
        "化策略的数据处理量。\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenGetMacUseFunc() {
  std::string strs = "";
  strs += GenGetMacUseFuncAnnotation();
  strs += "uint32_t L0TileSolver::GetMacUse() const {\n";
  strs += "  uint32_t usage = 1u;\n";
  strs += "  for (uint32_t j = 0; j < input_.size; j++) {\n";
  strs += "    usage *= input_.l0_vars[j].value;\n";
  strs += "  }\n";
  strs += "  return usage;\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenIterativeRunFuncAnnotation() {
  std::string strs = "";
  strs += " * 为 L0 变量找到最优值进行迭代运行\n";
  strs += " *\n";
  strs += " * @param loop_id 当前循环索引,表示正在处理的 L0 变量的位置\n";
  strs += " * @param best_var_value 一个指针,指向用于存储每个 L0 "
        "变量迄今为止找到的最佳值的数组\n";
  strs += " *\n";
  strs += " * 这个函数使用递归方法来遍历 L0 "
        "变量的所有可能值。对于每个值,它检查是否满足约束条件,如小于上界且满足"
        "对齐要求。\n";
  strs += " * 如果满足这些条件,它将继续下一个循环或者递归调用自身来处理下一个L0 "
        "变量。\n";
  strs += " * 如果是最后一个 L0 "
        "变量,它将检查当前组合是否满足优化条件,如核心数量和数据处理量。如果满"
        "足条件,它将当前组合存储为最优解。\n";
  strs += " *\n";
  strs += " * 请注意,这个函数没有返回值,而是将最优解存储在传入的 "
        "best_var_value 数组中。\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenIterativeRunFunc() {
  std::string strs = "";
  strs += GenIterativeRunFuncAnnotation();
  strs += "void L0TileSolver::IterativeRun(uint32_t loop_id, uint32_t "
        "*best_var_value) {\n";
  strs += "  for (uint32_t i = 0u; i < candidate_size; i++) {\n";
  strs += "    uint32_t candi_value = candidate_value[i];\n";
  strs += "    const auto &l0_tile = sortedvars_[loop_id];\n";
  strs += "    // L0Var的上限\n";
  strs += "    uint32_t upper_bound = l0_tile.max_value * UPPER_BOUND_RATIO;\n";
  strs += "    if (candi_value >= upper_bound) {\n";
  strs += "      continue;\n";
  strs += "    }\n";
  strs += "    // 必须满足prompt_align对齐或者是prompt_align的因子\n";
  strs += "    if ((candi_value % l0_tile.prompt_align != 0) &&\n";
  strs += "        (l0_tile.prompt_align % candi_value != 0)) {\n";
  strs += "      continue;\n";
  strs += "    }\n";
  strs += "    auto idx = l0_tile.idx;\n";
  strs += "    input_.l0_vars[idx].value = candi_value;\n";
  strs += "    // 终止条件为遍历到最后一个变量\n";
  strs += "    if (loop_id == input_.size - 1) {\n";
  strs += "      if (!CheckBufferUseValid()) {\n";
  strs += "        break;\n";
  strs += "      }\n";
  strs += "      int32_t usage = GetMacUse();\n";
  strs += "      int32_t core_num = MaxCoreNum(input_.l0_vars, "
        "input_.core_num);\n";
  strs += "      // "
        "最大核数如果满足核数*系数(默认0."
        "6),则比较mac利用率即可,否则需要比较核数的使用和mac利用率\n";
  strs += "      if (((core_num >= max_corenum_) ||\n";
  strs += "           (core_num >=\n";
  strs += "            static_cast<int32_t>(input_.core_num * CORE_NUM_RATIO))) "
        "&&\n";
  strs += "          (usage >= max_macuse_)) {\n";
  strs += "        max_corenum_ = core_num;\n";
  strs += "        max_macuse_ = usage;\n";
  strs += "        for (uint32_t k = 0u; k < input_.size; k++) {\n";
  strs += "          best_var_value[k] = input_.l0_vars[k].value;\n";
  strs += "        }\n";
  strs += "      }\n";
  strs += "    } else {\n";
  strs += "      IterativeRun(loop_id + 1, best_var_value);\n";
  strs += "    }\n";
  strs += "  }\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenUpdateAlignFuncAnnotation() {
  std::string strs = "";
  strs += " * 更新 L0Var 对象的对齐值\n";
  strs += " *\n";
  strs += " * 这个函数用于更新 L0Var 对象的 prompt_align\n";
  strs += " * 值,以确保它们在内存中按照最优方式对齐。它遍历 input_ 对象中的 "
        "l0_vars\n";
  strs += " * 数组,为每个 L0Var 对象计算并设置最佳的对齐值。\n";
  strs += " *\n";
  strs += " * @param 无\n";
  strs += " * @return 无\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenUpdateAlignFunc() {
  std::string strs = "";
  strs += GenUpdateAlignFuncAnnotation();
  strs += "void L0TileSolver::UpdateAlign() {\n";
  strs += "  for (uint32_t i = 0u; i < input_.size; i++) {\n";
  strs += "    uint32_t best_align = GetBestAlign(i);\n";
  strs += "    input_.l0_vars[i].prompt_align = best_align;\n";
  strs += "  }\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenCheckInputFuncAnnotation() {
  std::string strs = "";
  strs += " * 检查输入数据的有效性\n";
  strs += " *\n";
  strs += " * 这个函数用来检查 L0TileSolver "
        "类的输入数据是否有效。它验证以下几个方面:\n";
  strs += " * - 基础变量指针(l0_vars)是否为空。\n";
  strs += " * - 输入数据的大小(size)是否为0,表示没有 L0 参数需要求解。\n";
  strs += " * - "
        "输入数据的大小(size)是否超过最大支持的参数数量(MAX_L0_VAR_"
        "NUM)。\n";
  strs += " * - 核心数量(core_num)是否为0。\n";
  strs += " * - 对于输入数据中的每个 L0Var 对象(通过索引 i "
        "访问),它检查几个属性:\n";
  strs += " *   - max_value、align 和 prompt_align 是否都不等于0。\n";
  strs += " *   - align 是否不大于 prompt_align。\n";
  strs += " *\n";
  strs += " * 如果以上任何一个条件不满足,函数将通过 OP_LOG "
        "宏记录一条错误消息,并返回\n";
  strs += " * false,表示输入无效。如果所有条件都满足,函数返回 "
        "true,表示输入有效。\n";
  strs += " *\n";
  strs += " * @return 如果输入数据有效,则返回 true;否则返回 false。\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenCheckInputFunc() {
  std::string strs = "";
  strs += GenCheckInputFuncAnnotation();
  strs += "bool L0TileSolver::CheckInput() {\n";
  strs += "  if (input_.l0_vars == nullptr) {\n";
  strs += "    OP_LOGW(OP_NAME, \"Input basevar is null\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  if (input_.size == 0u) {\n";
  strs += "    OP_LOGW(OP_NAME, \"Size is 0, no l0 arg to be solved\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  if (input_.size > MAX_L0_VAR_NUM) {\n";
  strs += "    OP_LOGW(OP_NAME, \"L0 solver does not support more than 3 input args\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  if (input_.core_num == 0) {\n";
  strs += "    OP_LOGW(OP_NAME, \"Corenum is 0\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  for (uint32_t i = 0u; i < input_.size; i++) {\n";
  strs += "    auto var = input_.l0_vars[i];\n";
  strs += "    if ((var.max_value == 0) || (var.align == 0) || (var.prompt_align "
        "== 0)) {\n";
  strs += "      OP_LOGW(OP_NAME, \"Input [%u] exists 0\", i);\n";
  strs += "      return false;\n";
  strs += "    }\n";
  strs += "    if (var.align > var.prompt_align) {\n";
  strs += "      OP_LOGW(OP_NAME, \"Input [%u] align is larger than prompt align\", i);\n";
  strs += "      return false;\n";
  strs += "    }\n";
  strs += "  }\n";
  strs += "  return true;\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenInitInputFuncAnnotation() {
  std::string strs = "";
  strs += " * 初始化 L0Var 对象数组\n";
  strs += " *\n";
  strs += " * 这个函数用于初始化 L0TileSolver 类的 input_ 对象中的 l0_vars "
        "数组。它遍历\n";
  strs += " * l0_vars 数组中的每一个元素,对于每个元素,执行以下操作:\n";
  strs += " * 1. 通过访问索引 i 对应的 L0Var 对象的引用 var,重置其 max_value\n";
  strs += " * 属性。具体重置方式是,先将 max_value 增加 align 属性值减 1,再除以 "
        "align\n";
  strs += " * 属性值,最后乘以 align 属性值。这样做的目的可能是为了确保 "
        "max_value 是 align\n";
  strs += " * 的整数倍。\n";
  strs += " * 2. 将当前循环的索引值 i 设置为 var 的 idx "
        "属性。这可能是为了标记每个 L0Var\n";
  strs += " * 对象在数组中的位置,以便后续处理。\n";
  strs += " *\n";
  strs += " * @param 无\n";
  strs += " * @return 无\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenInitInputFunc() {
  std::string strs = "";
  strs += GenInitInputFuncAnnotation();
  strs += "void L0TileSolver::InitInput() {\n";
  strs += "  for (uint32_t i = 0u; i < input_.size; i++) {\n";
  strs += "    auto &var = input_.l0_vars[i];\n";
  strs += "    var.max_value = (var.max_value + var.align - 1) / var.align * "
        "var.align;\n";
  strs += "    var.idx = i;\n";
  strs += "  }\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenCheckOutputFuncAnnotation() {
  std::string strs = "";
  strs += " * 检查输出数据的有效性\n";
  strs += " *\n";
  strs += " * 这个函数用于检查 L0TileSolver 类的输出数据是否有效。它首先检查 "
        "output_\n";
  strs += " * 指针是否为空。如果 output_ 指针为空,通过 OP_LOG "
        "宏记录一条错误消息,并返回\n";
  strs += " * false,表示输出无效。\n";
  strs += " *\n";
  strs += " * 接着,函数遍历 output_ "
        "数组中的每个元素。对于每个元素,它检查其值是否为\n";
  strs += " * 0。如果发现任何一个元素的值为 0,函数会通过 OP_LOG\n";
  strs += " * 宏记录相应的错误消息,并返回 "
        "false,表示输出数据中存在无效的元素。\n";
  strs += " *\n";
  strs += " * 如果输出数据有效,即 output_ 指针不为空且 output_ 数组中没有 0\n";
  strs += " * 值元素,函数返回 true。\n";
  strs += " *\n";
  strs += " * @return 如果输出数据有效,则返回 true;否则返回 false。\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenCheckOutputFunc() {
  std::string strs = "";
  strs += GenCheckOutputFuncAnnotation();
  strs += "bool L0TileSolver::CheckOutput() {\n";
  strs += "  if (output_ == nullptr) {\n";
  strs += "    OP_LOGW(OP_NAME, \"Output is null\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  for (uint32_t i = 0u; i < input_.size; i++) {\n";
  strs += "    if (output_[i] == 0u) {\n";
  strs += "      OP_LOGW(OP_NAME, \"Output [%u] is 0\", i);\n";
  strs += "      return false;\n";
  strs += "    }\n";
  strs += "  }\n";
  strs += "  return true;\n";
  strs += "}\n";
  strs += "\n";
  return strs;
}

inline std::string GenRunFuncAnnotation() {
  std::string strs = "";
  strs += " * 执行 L0TileSolver 类的主要流程\n";
  strs += " * @return 如果所有操作成功并且输出有效,则返回 true;否则返回false\n";
  return AddAnotationBlock(strs, "");
}

inline std::string GenRunFunc() {
  std::string strs = "";
  strs += GenRunFuncAnnotation();
  strs += "bool L0TileSolver::Run() {\n";
  strs += "  // 检查输入数据的有效性\n";
  strs += "  if (!CheckInput()) {\n";
  strs += "    // 如果输入检查失败,则记录一条错误日志,并返回 false\n";
  strs += "    OP_LOGW(OP_NAME, \"Check input failed\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  // 初始化输入数据\n";
  strs += "  InitInput();\n";
  strs += "  // 更新 L0Var 对象的对齐值\n";
  strs += "  UpdateAlign();\n";
  strs += "  output_ = new (std::nothrow) uint32_t[input_.size]();\n";
  strs += "  bool is_fast_mode = true;\n";
  strs += "  for (uint32_t i=0u; i < input_.size; i++) {\n";
  strs += "    auto &var = input_.l0_vars[i];\n";
  strs += "    uint32_t upper_bound = var.max_value * UPPER_BOUND_RATIO;\n";
  strs += "    if ((var.value == 0u) || (var.value > upper_bound)) {\n";
  strs += "      is_fast_mode = false;\n";
  strs += "      break;\n";
  strs += "    }\n";
  strs += "  }\n";
  strs += "  if (is_fast_mode && CheckBufferUseValid()) {\n";
  strs += "    for (uint32_t k=0u; k < input_.size; k++) {\n";
  strs += "      output_[k] = input_.l0_vars[k].value;\n";
  strs += "    }\n";
  strs += "  } else {\n";
  strs += "  // 为排序后的变量申请内存,并初始化为 0\n";
  strs += "    sortedvars_ = new (std::nothrow) L0Var[input_.size];\n";
  strs += "  // 将输入数据复制到新的内存中\n";
  strs += "    std::copy(input_.l0_vars, input_.l0_vars + input_.size, sortedvars_);\n";
  strs += "  // 根据比较函数对变量进行排序\n";
  strs += "    std::sort(sortedvars_, sortedvars_ + input_.size, L0VarCmp);\n";
  strs += "    // 调用 IterativeRun 函数,传递参数 0 和 output_ 数组的指针\n";
  strs += "    IterativeRun(0u, output_);\n";
  strs += "  }\n";
  strs += "  // 检查输出数据的有效性\n";
  strs += "  if (!CheckOutput()) {\n";
  strs += "    // 如果输出检查失败,则记录一条错误日志,并返回 false\n";
  strs += "    OP_LOGW(OP_NAME, \"Check output failed\");\n";
  strs += "    return false;\n";
  strs += "  }\n";
  strs += "  // 如果所有操作都成功,返回 true\n";
  strs += "  return true;\n";
  strs += "}\n";
  return strs;
}

inline std::string GetL0SolverHead() {
  std::string strs = "";
  strs += GenVarDef();
  strs += GenL0VarDef();
  strs += GenL0Input();
  strs += GenL0VarCmp();
  strs += GenL0TileSolver();
  return strs;
}

inline std::string GetL0SolverFunc() {
  std::string strs = "";
  strs += GenGetBestAlignFunc();
  strs += GenMaxCoreNumFunc();
  strs += GenGetMacUseFunc();
  strs += GenIterativeRunFunc();
  strs += GenUpdateAlignFunc();
  strs += GenCheckInputFunc();
  strs += GenInitInputFunc();
  strs += GenCheckOutputFunc();
  strs += GenRunFunc();
  return strs;
}

inline const std::string L0_SOLVER_CODE_HEAD = GetL0SolverHead();
inline const std::string L0_SOLVER_CODE_FUNC = GetL0SolverFunc();
}  // namespace att
#endif