/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "vec_func_call.h"

#include <sstream>
#include "attr_utils.h"
#include "ascir_ops.h"
#include "common/ge_common/debug/log.h"
#include "graph/ascendc_ir/utils/asc_tensor_utils.h"
#include "common/checker.h"
#include "api_call/utils/api_call_factory.h"
#include "api_call/utils/api_call_utils.h"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"

namespace {
constexpr size_t kVFMaxLoop = 4U;
constexpr size_t kCommaSpaceLength = 2U;  // 逗号和空格的长度 ", "
}  // namespace

namespace codegen {
using namespace std;
using namespace af::ops;
using namespace af::ascir_op;
using namespace ascgen_utils;

namespace {
void ParamPostProcess(std::stringstream &ss) {
  if (ss.str().length() < kCommaSpaceLength) {
    return;
  }
  std::string str = ss.str();
  if (str.substr(str.length() - kCommaSpaceLength) == ", ") {
    str = str.substr(0, str.length() - kCommaSpaceLength);
    ss.str("");
    ss << str;
  }
  return;
}

void CreateTensorAddr(const std::vector<Tensor> &tensors, const std::vector<std::string> &ub_offsets,
                      const std::vector<Tensor> &tensors_scalar, std::stringstream &ss) {
  size_t count = ub_offsets.size();
  string dtype_name;
  for (size_t i = 0; i < tensors.size(); i++) {
    Tensor::DtypeName(tensors[i].dtype, dtype_name);
    std::string ub_offset = count == 0 ? "0" : ub_offsets[i];
    ss << "(__local_mem__ " << dtype_name << " *)" << tensors[i] << "[" << ub_offset << "].GetPhyAddr(), ";
  }
  for (const auto &scalar : tensors_scalar) {
    ss << scalar << ", ";
  }
}

void GetOuterForOffset(const TPipe &tpipe, const std::vector<std::vector<ascir::SizeExpr>> &strides,
                       std::vector<std::string> &outer_offsets) {
  for (size_t i = 0; i < strides.size(); i++) {
    outer_offsets.emplace_back(CalcInnerOffset(tpipe, strides[i]));
  }
}

void CreateSingleStridesParamsInfo(const Tensor &tensor, const std::vector<ascir::SizeExpr> &strides,
                                   std::stringstream &ss) {
  size_t stride_size = strides.size();
  size_t start_idx = stride_size <= kVFMaxLoop ? 0 : stride_size - kVFMaxLoop;
  for (; start_idx < stride_size; start_idx++) {
    // 在生成函数体时,能从图上判断出来0和1的stride,这些轴对应的stride信息可以在代码生成时,直接生成到代码中,不需要额外进行传递
    if (strides[start_idx].Simplify() == af::ops::One || strides[start_idx].Simplify() == af::ops::Zero) {
      continue;
    }
    ss << "uint32_t " << tensor << "_stride_" << start_idx << ", ";
  }
}

void CreateSingleStridesInfo(const TPipe &tpipe, const std::vector<ascir::SizeExpr> &strides, std::stringstream &ss) {
  size_t stride_size = strides.size();
  size_t start_idx = stride_size <= kVFMaxLoop ? 0 : stride_size - kVFMaxLoop;
  for (; start_idx < stride_size; start_idx++) {
    // 在生成函数体时,能从图上判断出来0和1的stride,这些轴对应的stride信息可以在代码生成时,直接生成到代码中,不需要额外进行传递
    auto current_stride = strides[start_idx].Simplify();
    bool current_stride_is_one =
        (af::SymbolicUtils::StaticCheckEq(current_stride, af::sym::kSymbolOne) == af::TriBool::kTrue);
    bool current_stride_is_zero =
        (af::SymbolicUtils::StaticCheckEq(current_stride, af::sym::kSymbolZero) == af::TriBool::kTrue);
    if (current_stride_is_one || current_stride_is_zero) {
      continue;
    }
    ss << tpipe.tiler.Size(strides[start_idx]) << ", ";
  }
}

void GetOuterForStride(const std::vector<std::vector<ascir::SizeExpr>> &origin_strides,
                       std::vector<std::vector<ascir::SizeExpr>> &target_strides) {
  for (size_t i = 0; i < origin_strides.size(); i++) {
    std::vector<ascir::SizeExpr> strides(origin_strides[i].begin(), origin_strides[i].end() - kVFMaxLoop);
    target_strides.emplace_back(strides);
  }
}

// 生成vf函数体时,用于处理vf函数入参中的main scalar
void CreateVFCallDimAndStrideParmas(const std::vector<Tensor> &inputs, const std::vector<Tensor> &inputs_scalar,
                                    const std::vector<Tensor> &outputs, const VectorizedAxisLoopMergeStatus &merge_info,
                                    std::stringstream &ss) {
  string dtype_name;
  for (const auto &output : outputs) {
    Tensor::DtypeName(output.dtype, dtype_name);
    ss << "__local_mem__ " << dtype_name << " *" << output << "_addr" << ", ";
  }

  for (const auto &input : inputs) {
    Tensor::DtypeName(input.dtype, dtype_name);
    ss << "__local_mem__ " << dtype_name << " *" << input << "_addr" << ", ";
  }

  for (const auto &in_scalar : inputs_scalar) {
    Tensor::DtypeName(in_scalar.dtype, dtype_name);
    ss << dtype_name << " " << in_scalar << ", ";
  }

  size_t dim_size = merge_info.merge_repeats_str.size();
  size_t start_idx = dim_size <= kVFMaxLoop ? 0 : dim_size - kVFMaxLoop;
  for (; start_idx < dim_size; start_idx++) {
    ss << "uint32_t output_dims_" << start_idx << ", ";
  }
  // 生成output stride 参数
  for (size_t i = 0; i < merge_info.outputs_strides.size(); i++) {
    CreateSingleStridesParamsInfo(outputs[i], merge_info.outputs_strides[i], ss);
  }
  // 生成input stride 参数
  for (size_t i = 0; i < merge_info.inputs_strides.size(); i++) {
    CreateSingleStridesParamsInfo(inputs[i], merge_info.inputs_strides[i], ss);
  }
  ParamPostProcess(ss);
}

// 生成函数调用入参
void CreateDimAndStrideParmas(const TPipe &tpipe, const VectorizedAxisLoopMergeStatus &merge_info,
                              std::stringstream &ss) {
  // 生成输入dims
  size_t dim_size = merge_info.merge_repeats_str.size();
  size_t start_idx = dim_size <= kVFMaxLoop ? 0 : dim_size - kVFMaxLoop;
  for (; start_idx < dim_size; start_idx++) {
    ss << merge_info.merge_repeats_str[start_idx] << ", ";
  }
  // 生成output stride
  for (size_t i = 0; i < merge_info.outputs_strides.size(); i++) {
    CreateSingleStridesInfo(tpipe, merge_info.outputs_strides[i], ss);
  }
  // 生成input stride
  for (size_t i = 0; i < merge_info.inputs_strides.size(); i++) {
    CreateSingleStridesInfo(tpipe, merge_info.inputs_strides[i], ss);
  }
  ParamPostProcess(ss);
}

void CreateVFCall(const TPipe &tpipe, const std::string &vf_call_name, const std::vector<Tensor> &inputs,
                  const std::vector<Tensor> &outputs, const std::vector<std::string> &input_ub_offsets,
                  const std::vector<std::string> &output_ub_offsets, const std::vector<Tensor> &tensors_scalar,
                  const VectorizedAxisLoopMergeStatus &merge_info, std::stringstream &ss) {
  ss << vf_call_name << "(";
  CreateTensorAddr(outputs, output_ub_offsets, {}, ss);
  CreateTensorAddr(inputs, input_ub_offsets, tensors_scalar, ss);
  CreateDimAndStrideParmas(tpipe, merge_info, ss);
  ss << ");" << std::endl;
}

void CreateOuterForVFCall(const TPipe &tpipe, const std::string &vf_call_name, const std::vector<Tensor> &inputs,
                          const std::vector<Tensor> &outputs, const std::vector<Tensor> &tensors_scalar,
                          const VectorizedAxisLoopMergeStatus &merge_info, std::stringstream &ss) {
  std::vector<std::string> repeats(merge_info.merge_repeats_str.begin(),
                                   merge_info.merge_repeats_str.end() - kVFMaxLoop);
  std::vector<std::vector<ascir::SizeExpr>> inputs_strides;
  std::vector<std::vector<ascir::SizeExpr>> outputs_strides;
  std::vector<std::string> inputs_ub_offsets;
  std::vector<std::string> outputs_ub_offsets;
  GetOuterForStride(merge_info.inputs_strides, inputs_strides);
  GetOuterForStride(merge_info.outputs_strides, outputs_strides);
  GetOuterForOffset(tpipe, inputs_strides, inputs_ub_offsets);
  GetOuterForOffset(tpipe, outputs_strides, outputs_ub_offsets);
  std::stringstream ss1;
  CreateVFCall(tpipe, vf_call_name, inputs, outputs, inputs_ub_offsets, outputs_ub_offsets, tensors_scalar, merge_info,
               ss1);
  CreateComputeNodeOuterFor(repeats, ss1, ss, 0);
}

void GetVFCallFuncBody(const std::string &params, const std::string &vf_body, std::stringstream &ss) {
  ss << "{" << std::endl;
  ss << params << std::endl;
  ss << "  __VEC_SCOPE__\n";
  ss << "  {\n";
  ss << vf_body << std::endl;
  ss << "  }\n";
  ss << "}\n" << std::endl;
}

size_t GeOriginLastAxisPos(const Tiler &tiler, const std::vector<ascir::AxisId> &current_axis_ids,
                           const std::vector<std::vector<ascir::AxisId>> &origin_axis_ids) {
  size_t axis_num = current_axis_ids.size();
  if (origin_axis_ids.size() != axis_num) {
    return axis_num - static_cast<size_t>(1);
  }
  const auto &origin_last_axis = origin_axis_ids.back();
  for (size_t i = 0; i < current_axis_ids.size(); i++) {
    const auto &axis = tiler.GetAxis(current_axis_ids[i]);
    if (axis.type == ascir::Axis::Type::kAxisTypeMerged) {
      std::set<ascir::AxisId> current_ids;
      for (const auto &from : axis.from) {
        current_ids.insert(from);
      }
      size_t match_count = 0;
      for (size_t j = 0; j < origin_last_axis.size(); j++) {
        match_count = (current_ids.count(origin_last_axis[j]) != 0) ? match_count + 1 : match_count;
      }
      if (match_count == origin_last_axis.size()) {
        return i;
      }
      continue;
    }
    if (origin_last_axis.size() == 1 && origin_last_axis.back() == current_axis_ids[i]) {
      return i;
    }
  }
  return axis_num - static_cast<size_t>(1);
}

}  // namespace

void VfCall::SetNodeAxisIds(const std::vector<ascir::AxisId> &origin_axis_ids) {
  if (origin_axis_ids.size() < axis_ids_.size()) {
    return;
  }
  axis_ids_ = origin_axis_ids;
}

Status VfCall::ParseAttr(const ascir::NodeView &node) {
  vf_call_name_ = "VFCall" + node->GetName();
  af::AscGraph sub_graph("sub_graph");
  GE_ASSERT_SUCCESS(af::AscGraphUtils::FromComputeGraph(node->GetOwnerComputeGraph(), sub_graph),
                    "Get sub_graph failed, node:%s", node->GetNamePtr());
  return ParseSubGraph(node, sub_graph);
}

bool VfCall::ShouldInitAsMaskReg(const ascir::NodeView &node, const af::AscTensor *output) const {
  // 条件1:生产者是 Compare 算子
  bool is_compare_output =
      (IsOps<Ge>(node) || IsOps<Eq>(node) || IsOps<Le>(node) || IsOps<Ne>(node) || IsOps<Gt>(node) || IsOps<Lt>(node));
  if (!is_compare_output) {
    return false;
  }

  // 条件2:所有消费者都必须是 Where/Select 的 mask 输入(idx=0)
  const auto &peer_inputs = output->anchor.GetPeerInDataAnchors();
  if (peer_inputs.empty()) {
    return false;
  }
  for (const auto &peer_input : peer_inputs) {
    auto output_node = std::dynamic_pointer_cast<af::AscNode>(peer_input->GetOwnerNode());
    // 只要有一个消费者不是 Where/Select 的 mask 输入,就返回 false
    if (!(IsOps<Where>(output_node) || IsOps<Select>(output_node)) || (peer_input->GetIdx() != 0)) {
      return false;
    }
  }

  // 同时满足才返回 true
  return true;
}

Status VfCall::ParseSubGraph(const ascir::NodeView &vf_node, const ascir::ImplGraph &graph) {
  // 从节点上读取sub_graph_name属性
  const std::string *graph_name = af::AttrUtils::GetStr(vf_node->GetOpDescBarePtr(), "sub_graph_name");
  GE_ASSERT_NOTNULL(graph_name, "Get sub graph name failed, vf node:%s", vf_node->GetNamePtr());
  af::AscGraph sub_graph("vf_sub_graph");
  GE_ASSERT_SUCCESS(graph.FindSubGraph(*graph_name, sub_graph), "Get sub_graph failed, vf node:%s, sub_graph_name:%s",
                    vf_node->GetNamePtr(), graph_name->c_str());
  GELOGI("VF node:%s, sub_graph_name:%s", vf_node->GetNamePtr(), graph_name->c_str());

  uint32_t max_dtype_size = 0;
  for (auto node : sub_graph.GetAllNodes()) {
    // subgraph上的Load api直接使用Tpipe上保存的UB tensor, 因此vf子图上Data节点的输出Tensor不必保存在tensor manager中.
    if (IsOps<Output>(node) || IsOps<Data>(node) || IsOps<Scalar>(node)) {
      continue;
    }
    // broadcast inline场景,子图内会对轴进行重排序,生成vf代码时,需要找到原始最内层轴所在的位置,用于生成UpdateMask动作
    SetNodeAxisIds(node->attr.sched.axis);
    auto desc = node->GetOpDesc();
    for (auto output : node->outputs()) {
      auto output_index = af::ascir::AscTensorUtils::Index(*output);
      auto tensor_name = node->GetName() + "_" + desc->GetOutputNameByIndex(output_index);

      uint32_t dtype_size = 0;
      std::string dtype_name;
      GE_CHK_STATUS_RET(Tensor::DtypeName(output->attr.dtype, dtype_name), "Codegen get data type:%d failed",
                        static_cast<int32_t>(output->attr.dtype));
      dtype_size = af::GetSizeByDataType(output->attr.dtype);
      if (dtype_size > max_dtype_size || this->max_dtype_size_ == "") {
        this->max_dtype_size_ = dtype_name;
        max_dtype_size = dtype_size;
      }
      auto init_as_mask_reg = ShouldInitAsMaskReg(node, output);
      MicroApiTensor tensor(*output, dtype_name, init_as_mask_reg);
      GE_CHK_STATUS_RET(tensor_mgr_.AddTensor(tensor), "Codegen add tensor failed");
    }
  }
  root_loop_.SetMaxDtypeSize(this->max_dtype_size_);
  // Parse for loop
  return root_loop_.ConstructFromNodes(sub_graph.GetAllNodes(), vf_node);
}

Status VfCall::ParseInputOutputInfo(const TPipe &tpipe) const {
  for (const auto &in : inputs) {
    auto tensor_ptr = tpipe.GetTensor(in->id);
    GE_CHK_BOOL_RET_STATUS(tensor_ptr != nullptr, ge::FAILED, "Check[Param] tensor_ptr is nullptr");
    if (tensor_ptr->IsConstScalar()) {
      scalar_inputs_.emplace_back(*tensor_ptr);
    } else {
      ub_inputs_.emplace_back(*tensor_ptr);
    }
  }

  for (const auto &out : outputs) {
    auto tensor_ptr = tpipe.GetTensor(out.id);
    GE_CHK_BOOL_RET_STATUS(tensor_ptr != nullptr, ge::FAILED, "Check[Param] tensor_ptr is nullptr");
    ub_outputs_.emplace_back(*tensor_ptr);
  }
  // 处理合轴信息
  return ge::SUCCESS;
}

void GenerateStridesEqualCheck(const std::vector<Tensor> &inputs, const std::vector<Tensor> &outputs,
                                     const VectorizedAxisLoopMergeStatus &merge_info, std::stringstream &ss) {
  std::vector<std::string> all_stride_names;
  for (size_t j = 0; j < merge_info.inputs_strides.size(); j++) {
    size_t stride_size = merge_info.inputs_strides[j].size();
    size_t start_idx = stride_size <= kVFMaxLoop ? 0 : stride_size - kVFMaxLoop;
    for (; start_idx < stride_size; start_idx++) {
      if (merge_info.inputs_strides[j][start_idx].Simplify() == af::ops::One || 
          merge_info.inputs_strides[j][start_idx].Simplify() == af::ops::Zero) {
        continue;
      }
      std::stringstream tensor_name;
      tensor_name << inputs[j];
      all_stride_names.push_back(tensor_name.str() + "_stride_" + std::to_string(start_idx));
    }
  }

  for (size_t j = 0; j < merge_info.outputs_strides.size(); j++) {
    size_t stride_size = merge_info.outputs_strides[j].size();
    size_t start_idx = stride_size <= kVFMaxLoop ? 0 : stride_size - kVFMaxLoop;
    for (; start_idx < stride_size; start_idx++) {
      if (merge_info.outputs_strides[j][start_idx].Simplify() == af::ops::One || 
          merge_info.outputs_strides[j][start_idx].Simplify() == af::ops::Zero) {
        continue;
      }
      std::stringstream tensor_name;
      tensor_name << outputs[j];
      all_stride_names.push_back(tensor_name.str() + "_stride_" + std::to_string(start_idx));
    }
  }

  if (all_stride_names.empty()) {
    ss << "  uint32_t strides_align = 0;\n  bool strides_equal = false;\n";
    return;
  }

  ss << "  bool strides_equal = false;\n";
  ss << "  uint32_t strides_align = static_cast<uint32_t>(" << all_stride_names[0] << ");\n";
  if (all_stride_names.size() == 1) {
    ss << "  strides_equal = true;\n";
    return;
  }
  ss << "  if (";
  for (size_t i = 1; i < all_stride_names.size(); i++) {
    ss << all_stride_names[0] << " == " << all_stride_names[i];
    if (i < all_stride_names.size() - 1) {
      ss << " && ";
    }
  }
  ss << ") {\n";
  ss << "    strides_equal = true;\n";
  ss << "  }\n";

  return;
}

void OptimizeMergeParamsAndLoopSize(const std::vector<std::string> &loop_size_vec, std::stringstream &ss) {
  if (loop_size_vec.size() < MAX_VF_AXIS_MERGE_SIZE) {
    return;
  }

  const auto &loop_size_0 = loop_size_vec[0];
  const auto &loop_size_1 = loop_size_vec[1];

  ss << "  if (strides_equal) {\n";
  ss << "    " << loop_size_0 << " = 1;\n";
  ss << "    element_count = static_cast<uint32_t>(strides_align * output_dims_0);\n";
  ss << "    " << loop_size_1 << " = static_cast<uint16_t>((element_count + ELEMENT_PER_VECTOR_LENGTH - 1) / ELEMENT_PER_VECTOR_LENGTH);\n";
  ss << "  }\n";
  return;
}

void GenerateVectorFuncParams(const std::string &max_dtype_size, int32_t stride_depth,
                              const std::vector<std::vector<ascir::AxisId>> &merge_axis_ids, std::stringstream &ss) {
  ss << "  constexpr static uint32_t VECTOR_LENGTH = AscendC::GetVecLen();\n";
  ss << "  constexpr static uint32_t SIZE_OF_DTYPE = sizeof(" << max_dtype_size << ");\n";
  ss << "  constexpr static uint32_t ELEMENT_PER_VECTOR_LENGTH = VECTOR_LENGTH / SIZE_OF_DTYPE;\n";
  if (merge_axis_ids.size() == 0) {
    ss << "  uint32_t element_count = 1;\n";
  } else {
    ss << "  uint32_t element_count = static_cast<uint32_t>(" << "output_dims_" << stride_depth << ");\n";
  }
  ss << "  uint16_t loop_times = static_cast<uint16_t>((element_count + ELEMENT_PER_VECTOR_LENGTH - 1) / "
            "ELEMENT_PER_VECTOR_LENGTH);\n";
  return;
}

void GenerateTensorDefs(const TPipe &tpipe, const TensorManager &tensor_mgr, const VFLoop &root_loop,
                        std::stringstream &vf_body) {
  std::string reg_tensor_def;
  tensor_mgr.GenerateVreg(reg_tensor_def);
  vf_body << reg_tensor_def;

  std::vector<std::string> mask_reg_temp_tensors;
  root_loop.CollectMaskRegTempTensors(tpipe, tensor_mgr, mask_reg_temp_tensors);
  for (const auto &tensor_name : mask_reg_temp_tensors) {
    vf_body << "AscendC::MicroAPI::MaskReg " << tensor_name << "_temp_mask;" << std::endl;
  }
}

Status VfCall::GenerateFuncDefinition(const TPipe &tpipe, const Tiler &tiler, std::stringstream &ss) const {
  // 收集输入输出信息,由于GenInnerLoopSizeAndActualSize函数中会刷新tiler对象中的actual_sizes字段,
  // 导致生成函数签名和函数调用时,获取到的size信息不一致,因此生成函数签名和函数调用时均需要调用合轴函数
  GE_ASSERT_SUCCESS(ParseInputOutputInfo(tpipe));

  // 处理合轴信息,生成函数体时,需要知道子图内的合轴状态
  VectorizedAxisLoopMergeStatus merge_info;
  bool status = GenerateVectorizedAxisMergeStatus(this->ub_inputs_, this->ub_outputs_, merge_info, tpipe);
  GE_ASSERT_TRUE(status, "GenerateVectorizedAxisMergeStatus failed");

  ss << "#if defined(__DAV_C310__) || (defined(__NPU_ARCH__) && (__NPU_ARCH__ == 5102 || __NPU_ARCH__ == 3510))"
     << std::endl;
  ss << "\ninline __aicore__ void " << this->vf_call_name_ << "(";
  CreateVFCallDimAndStrideParmas(this->ub_inputs_, this->scalar_inputs_, this->ub_outputs_, merge_info, ss);
  ss << ")" << std::endl;

  // func body
  std::stringstream params;
  std::stringstream vf_body;
  int32_t stride_depth = GeOriginLastAxisPos(tiler, axis_ids_, merge_info.merge_axis_ids);

  //   constexpr static uint32_t VECTOR_LENGTH = AscendC::GetVecLen();
  //   constexpr static uint32_t SIZE_OF_DTYPE = sizeof(half);
  //   constexpr static uint32_t ELEMENT_PER_VECTOR_LENGTH = VECTOR_LENGTH / SIZE_OF_DTYPE;
  //   uint32_t element_count = static_cast<uint32_t>(output_dims_0);
  //   uint16_t loop_times = static_cast<uint16_t>((element_count + ELEMENT_PER_VECTOR_LENGTH - 1) / ELEMENT_PER_VECTOR_LENGTH);
  GenerateVectorFuncParams(max_dtype_size_, stride_depth, merge_info.merge_axis_ids, params);

  std::string dtype_name;
  for (const auto &output : this->ub_outputs_) {
    Tensor::DtypeName(output.dtype, dtype_name);
    vf_body << "    " << "__local_mem__ " << dtype_name << " *" << output << " = "
            << "(__local_mem__ " << dtype_name << " *)" << output << "_addr" << ";\n";
  }

  for (const auto &input : this->ub_inputs_) {
    Tensor::DtypeName(input.dtype, dtype_name);
    vf_body << "    " << "__local_mem__ " << dtype_name << " *" << input << " = "
            << "(__local_mem__ " << dtype_name << " *)" << input << "_addr" << ";\n";
  }

  GenerateTensorDefs(tpipe, tensor_mgr_, root_loop_, vf_body);

  // define preg_main and preg_vl1
  vf_body << "\nAscendC::MicroAPI::MaskReg preg_main = AscendC::MicroAPI::CreateMask<" << max_dtype_size_
          << ", AscendC::MicroAPI::MaskPattern::ALL>();\n";
  vf_body << "AscendC::MicroAPI::MaskReg preg_vl1 = AscendC::MicroAPI::CreateMask<" << max_dtype_size_
          << ", AscendC::MicroAPI::MaskPattern::VL1>();\n";

  std::string loop_body;
  std::string loop_size;
  int32_t only_loop_max_depth = -1;
  std::vector<std::string> loop_size_vec;
  root_loop_.Generate(tpipe, tensor_mgr_, stride_depth, loop_body, loop_size, only_loop_max_depth, loop_size_vec);
  params << std::endl << loop_size << std::endl;
  if (stride_depth == MAX_VF_AXIS_MERGE_SIZE - 1 && only_loop_max_depth == MAX_VF_AXIS_MERGE_SIZE - 1) { // 假如stride_depth为1即两层循环,那实际上loop里递归了三次,分别是0、1、2,在2里单独处理call
    GenerateStridesEqualCheck(this->ub_inputs_, this->ub_outputs_, merge_info, params);
    OptimizeMergeParamsAndLoopSize(loop_size_vec, params);
  }
  vf_body << std::endl << loop_body << std::endl;
  GetVFCallFuncBody(params.str(), vf_body.str(), ss);
  ss << "#endif" << std::endl;

  return ge::SUCCESS;
}

Status VfCall::Generate(const TPipe &tpipe, [[maybe_unused]] const std::vector<ascir::AxisId> &current_axis,
                        [[maybe_unused]] const std::vector<std::reference_wrapper<const Tensor>> &inputs,
                        [[maybe_unused]] const std::vector<std::reference_wrapper<const Tensor>> &outputs, std::string &result) const {
  // 合轴信息只会在子图内部体现,不会在原图上体现
  // 对于broadcast inline场景, 需要再合轴后对ApiParams中的repeats和strides做排序
  VectorizedAxisLoopMergeStatus merge_info;
  bool status = GenerateVectorizedAxisMergeStatus(this->ub_inputs_, this->ub_outputs_, merge_info, tpipe);
  GE_ASSERT_TRUE(status, "GenerateVectorizedAxisMergeStatus failed");

  std::stringstream ss;
  size_t loop_num = merge_info.merge_repeats_str.size();
  ss << "#if defined(__DAV_C310__) || (defined(__NPU_ARCH__) && (__NPU_ARCH__ == 5102 || __NPU_ARCH__ == 3510))"
     << std::endl;
  ss << "AscendC::SetCtrlSpr<60, 60>(0);" << std::endl;
  if (loop_num <= kVFMaxLoop) {
    std::vector<std::string> inputs_ub_offsets = {};
    std::vector<std::string> outputs_ub_offsets = {};
    CreateVFCall(tpipe, this->vf_call_name_, this->ub_inputs_, this->ub_outputs_, inputs_ub_offsets, outputs_ub_offsets,
                 this->scalar_inputs_, merge_info, ss);
  } else {
    CreateOuterForVFCall(tpipe, this->vf_call_name_, this->ub_inputs_, this->ub_outputs_, this->scalar_inputs_,
                         merge_info, ss);
  }
  ss << "#endif" << std::endl;
  result = ss.str();
  return ge::SUCCESS;
}

VfCall::~VfCall() {
  root_loop_.Destruct();
}

static ApiCallRegister<VfCall> register_vf_api_call("VfCall");
}  // namespace codegen