/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <iostream>
#include "stub_ffn_modelinfo.h"

namespace {
// Clamps a symbolic divisor from below at one, i.e. returns max(1, expr).
// Used for the Ceiling(total / tile) tile counts so a zero tile symbol never
// yields a division by zero.
att::Expr GetSafeDivisor(const att::Expr &expr)
{
  const auto &lower_bound = ge::sym::kSymbolOne;
  return ge::sym::Max(lower_bound, expr);
}

// Returns max(1, expr - 1): the divisor `expr - 1` clamped from below at one,
// so it can never be zero or negative.
att::Expr GetSafeOffsetDivisor(const att::Expr &expr)
{
  const auto decremented = att::CreateExpr(-1) + expr;
  return ge::sym::Max(ge::sym::kSymbolOne, decremented);
}

// Symbolic variables shared by the FFN model-info builders. Each field is
// assigned by one of the BuildFfn*Axes / BuildFfnK1Axis helpers. The fields are
// then read by FillFfnModelInfo and by the cube MTE2 cost helpers.
struct FfnExprContext {
  att::Expr expr_maxTokens;  // symbol "maxTokens"; origin axis tiled by base_m1 and base_m2
  att::Expr expr_basem1;     // symbol "base_m1"; tile of maxTokens used by the cube-1 cost
  att::Expr expr_basem2;     // symbol "base_m2"; tile of maxTokens used by the cube-2 cost
  att::Expr expr_ubm;        // symbol "ub_m"; sub-tile of base_m1 (UB-constrained)
  att::Expr expr_n1;         // symbol "N1"; also acts as the K dimension of the cube-2 cost
  att::Expr expr_basen1;     // symbol "base_n1"; tile of N1
  att::Expr expr_k1;         // symbol "K1"; K dimension of the cube-1 cost
  att::Expr expr_n2;         // symbol "N2"; origin axis tiled by base_n2
  att::Expr expr_basen2;     // symbol "base_n2"; tile of N2
};


void BuildFfnMaxTokenAxes(att::ModelInfo &model_info, FfnExprContext &ctx,
                          att::AttAxisPtr &maxTokens, att::AttAxisPtr &basem1, att::AttAxisPtr &ubm,
                          att::AttAxisPtr &basem2)
{
  ctx.expr_maxTokens = att::CreateExpr("maxTokens");
  ctx.expr_basem1 = att::CreateExpr("base_m1");
  ctx.expr_basem2 = att::CreateExpr("base_m2");
  ctx.expr_ubm = att::CreateExpr("ub_m");

  att::SymVarInfoPtr sym_maxTokens = std::make_shared<att::SymVarInfo>(ctx.expr_maxTokens);
  att::SymVarInfoPtr sym_basem1 = std::make_shared<att::SymVarInfo>(ctx.expr_basem1);
  sym_basem1->align = ge::Symbol(8);
  sym_basem1->related_scope = {att::HardwareDef::L0C};
  att::SymVarInfoPtr sym_ubm = std::make_shared<att::SymVarInfo>(ctx.expr_ubm);
  sym_ubm->align = ge::Symbol(8);
  sym_ubm->related_scope = {att::HardwareDef::UB};
  att::SymVarInfoPtr sym_basem2 = std::make_shared<att::SymVarInfo>(ctx.expr_basem2);
  sym_basem2->align = ge::Symbol(8);
  sym_basem2->related_scope = {att::HardwareDef::L0C};

  maxTokens = std::make_shared<att::AttAxis>();
  basem1 = std::make_shared<att::AttAxis>();
  ubm = std::make_shared<att::AttAxis>();
  basem2 = std::make_shared<att::AttAxis>();

  maxTokens->name = "maxTokens";
  maxTokens->axis_pos = att::AxisPosition::ORIGIN;
  maxTokens->bind_multicore = false;
  maxTokens->is_last = false;
  maxTokens->is_node_innerest_dim = false;
  maxTokens->size = sym_maxTokens;

  basem1->name = "base_m1";
  basem1->axis_pos = att::AxisPosition::INNER;
  basem1->bind_multicore = false;
  basem1->is_last = false;
  basem1->is_node_innerest_dim = false;
  basem1->size = sym_basem1;
  basem1->orig_axis.push_back(maxTokens.get());
  basem1->from_axis = {maxTokens.get()};

  ubm->name = "ub_m";
  ubm->axis_pos = att::AxisPosition::INNER;
  ubm->bind_multicore = false;
  ubm->is_last = true;
  ubm->is_node_innerest_dim = false;
  ubm->size = sym_ubm;
  ubm->orig_axis.push_back(maxTokens.get());
  ubm->from_axis = {basem1.get()};

  basem2->name = "base_m2";
  basem2->axis_pos = att::AxisPosition::INNER;
  basem2->bind_multicore = false;
  basem2->is_last = true;
  basem2->is_node_innerest_dim = false;
  basem2->size = sym_basem2;
  basem2->orig_axis.push_back(maxTokens.get());
  basem2->from_axis = {maxTokens.get()};
}

// Builds the N1 origin axis and its inner tile base_n1, and records both
// symbols in ctx. base_n1 is aligned to 8, is the last and innermost axis, and
// counts against the L0C, UB and BTBUF buffers.
void BuildFfnN1Axes(FfnExprContext &ctx, att::AttAxisPtr &n1, att::AttAxisPtr &basen1)
{
  ctx.expr_n1 = att::CreateExpr("N1");
  ctx.expr_basen1 = att::CreateExpr("base_n1");

  // Origin axis N1.
  att::SymVarInfoPtr origin_size = std::make_shared<att::SymVarInfo>(ctx.expr_n1);
  n1 = std::make_shared<att::AttAxis>();
  n1->name = "N1";
  n1->axis_pos = att::AxisPosition::ORIGIN;
  n1->bind_multicore = false;
  n1->is_last = false;
  n1->is_node_innerest_dim = false;
  n1->size = origin_size;

  // Tile axis base_n1, split from N1.
  att::SymVarInfoPtr tile_size = std::make_shared<att::SymVarInfo>(ctx.expr_basen1);
  tile_size->align = ge::Symbol(8);
  tile_size->related_scope = {att::HardwareDef::L0C, att::HardwareDef::UB, att::HardwareDef::BTBUF};
  basen1 = std::make_shared<att::AttAxis>();
  basen1->name = "base_n1";
  basen1->axis_pos = att::AxisPosition::INNER;
  basen1->bind_multicore = false;
  basen1->is_last = true;
  basen1->is_node_innerest_dim = true;
  basen1->size = tile_size;
  basen1->orig_axis.push_back(n1.get());
  basen1->from_axis = {n1.get()};
}

// Builds the K1 origin axis and records its symbol in ctx. This function
// creates no tile axis for K1.
void BuildFfnK1Axis(FfnExprContext &ctx, att::AttAxisPtr &k1)
{
  ctx.expr_k1 = att::CreateExpr("K1");

  k1 = std::make_shared<att::AttAxis>();
  k1->name = "K1";
  k1->axis_pos = att::AxisPosition::ORIGIN;
  k1->bind_multicore = false;
  k1->is_last = false;
  k1->is_node_innerest_dim = false;

  att::SymVarInfoPtr k1_size = std::make_shared<att::SymVarInfo>(ctx.expr_k1);
  k1->size = k1_size;
}

// Builds the N2 origin axis and its inner tile base_n2, and records both
// symbols in ctx. base_n2 is aligned to 8, is the last and innermost axis, and
// counts against the L0C and BTBUF buffers.
void BuildFfnN2Axes(FfnExprContext &ctx, att::AttAxisPtr &n2, att::AttAxisPtr &basen2)
{
  ctx.expr_n2 = att::CreateExpr("N2");
  ctx.expr_basen2 = att::CreateExpr("base_n2");

  // Origin axis N2.
  att::SymVarInfoPtr origin_size = std::make_shared<att::SymVarInfo>(ctx.expr_n2);
  n2 = std::make_shared<att::AttAxis>();
  n2->name = "N2";
  n2->axis_pos = att::AxisPosition::ORIGIN;
  n2->bind_multicore = false;
  n2->is_last = false;
  n2->is_node_innerest_dim = false;
  n2->size = origin_size;

  // Tile axis base_n2, split from N2.
  att::SymVarInfoPtr tile_size = std::make_shared<att::SymVarInfo>(ctx.expr_basen2);
  tile_size->align = ge::Symbol(8);
  tile_size->related_scope = {att::HardwareDef::L0C, att::HardwareDef::BTBUF};
  basen2 = std::make_shared<att::AttAxis>();
  basen2->name = "base_n2";
  basen2->axis_pos = att::AxisPosition::INNER;
  basen2->bind_multicore = false;
  basen2->is_last = true;
  basen2->is_node_innerest_dim = true;
  basen2->size = tile_size;
  basen2->orig_axis.push_back(n2.get());
  basen2->from_axis = {n2.get()};
}

// Shared AIC MTE2 cost formula for one cube pass tiled as base_m x base_n with
// reduction length k. The result is
//   (mte2_mn + mte2_nm) * (n_cnt * m_cnt) / (weight1 + weight2),
// where n_cnt and m_cnt are the tile counts along N and M. The numeric
// coefficients are presumably empirically fitted; verify before changing them.
// weight2 currently has all-zero coefficients. It is kept so the expression
// shape matches the previous per-cube implementations exactly.
att::Expr CalcCubeMte2(const att::Expr &base_m, const att::Expr &base_n, const att::Expr &k,
                       const att::Expr &n_cnt, const att::Expr &m_cnt)
{
  att::Expr expr_mn = ((((att::CreateExpr(0.05624f) * base_m) + att::CreateExpr(0.3984f)) *
                         att::CreateExpr(6.2712e-05f) * k * base_n) +
                       (att::CreateExpr(0.0008295f) * k * base_n));
  att::Expr weight_mn = ((att::CreateExpr(0.05761f) * base_n) + att::CreateExpr(0.0f));
  att::Expr mte2_mn = expr_mn * weight_mn;
  att::Expr expr_nm = ((((att::CreateExpr(0.05940f) * base_n) + att::CreateExpr(20.0944f)) *
                         att::CreateExpr(6.2712e-05f) * k * base_m) +
                       (att::CreateExpr(0.0008295f) * k * base_m));
  att::Expr weight_nm = ((att::CreateExpr(0.07543f) * k) + att::CreateExpr(0.0f));
  att::Expr mte2_nm = expr_nm * weight_nm;
  att::Expr weight1 = (att::CreateExpr(0.000216f) * base_n) +
                      (att::CreateExpr(0.0003614f) * base_m) +
                      (att::CreateExpr(0.0005757f) * k);
  att::Expr weight2 = (att::CreateExpr(0.0f) * k * base_m) +
                      (att::CreateExpr(0.0f) * base_m * base_n) +
                      (att::CreateExpr(0.0f) * k * base_n);
  return (mte2_mn + mte2_nm) * (n_cnt * m_cnt) / (weight1 + weight2);
}

// MTE2 cost of the first cube: tiles base_m1 x base_n1, reduction length K1.
att::Expr CalcCube1Mte2(const FfnExprContext &ctx, const att::Expr &n1_cnt, const att::Expr &m1_cnt)
{
  return CalcCubeMte2(ctx.expr_basem1, ctx.expr_basen1, ctx.expr_k1, n1_cnt, m1_cnt);
}

// MTE2 cost of the second cube: tiles base_m2 x base_n2. N1 is passed as the
// reduction length, matching the formula this replaced.
att::Expr CalcCube2Mte2(const FfnExprContext &ctx, const att::Expr &n2_cnt, const att::Expr &m2_cnt)
{
  return CalcCubeMte2(ctx.expr_basem2, ctx.expr_basen2, ctx.expr_n1, n2_cnt, m2_cnt);
}

// Populates model_info with the FFN cost model:
//  * hardware_cons: BTBUF, L0C and UB occupancy written as 4 * (tile element
//    count). The factor 4 presumably means bytes per element; confirm.
//  * objects: cost expressions for the AIV_MTE2, AIV_MTE3, AIC_MTE2 and
//    AIV_VEC pipes, each scaled by the number of tiles that pipe processes.
//  * eq_exprs[kFatherToChildNoTail]: the pair (base_m1, ub_m).
//  * tiling_case_id = 0, output_size = 1.
// Tile counts are Ceiling(total / GetSafeDivisor(tile)), so a zero tile symbol
// never divides by zero.
void FillFfnModelInfo(att::ModelInfo &model_info, const FfnExprContext &ctx)
{
  // Short aliases for the tile symbols used throughout.
  const att::Expr &bm1 = ctx.expr_basem1;
  const att::Expr &bm2 = ctx.expr_basem2;
  const att::Expr &bn1 = ctx.expr_basen1;
  const att::Expr &bn2 = ctx.expr_basen2;
  const att::Expr &ub_m = ctx.expr_ubm;

  // ---- Buffer occupancy constraints ----
  att::Expr btbuf_use = ge::sym::Max((att::CreateExpr(4) * bn1), (att::CreateExpr(4) * bn2));
  att::Expr l0c_use = ge::sym::Max((att::CreateExpr(4) * bn1 * bm1), (att::CreateExpr(4) * bn2 * bm2));
  att::Expr ub_use = (att::CreateExpr(4) * bn1 * ub_m);
  model_info.hardware_cons[att::HardwareDef::BTBUF] = btbuf_use;
  model_info.hardware_cons[att::HardwareDef::L0C] = l0c_use;
  model_info.hardware_cons[att::HardwareDef::UB] = ub_use;

  // ---- Tile counts along each split axis ----
  att::Expr m1_tiles = ge::sym::Ceiling(ctx.expr_maxTokens / GetSafeDivisor(bm1));
  att::Expr m2_tiles = ge::sym::Ceiling(ctx.expr_maxTokens / GetSafeDivisor(bm2));
  att::Expr n1_tiles = ge::sym::Ceiling(ctx.expr_n1 / GetSafeDivisor(bn1));
  att::Expr n2_tiles = ge::sym::Ceiling(ctx.expr_n2 / GetSafeDivisor(bn2));
  att::Expr ub_tiles = ge::sym::Ceiling(bm1 / GetSafeDivisor(ub_m));

  // ---- AIV vector pipe ----
  att::Expr vec_per_ub = ((att::CreateExpr(4) * bn1 * ub_m) / GetSafeOffsetDivisor(bn1) + att::CreateExpr(4));
  att::Expr vec_per_mm1 = ((att::CreateExpr(8) * bm1 * bn1) / GetSafeOffsetDivisor(bn1) + att::CreateExpr(4));
  att::Expr vec_per_mm2 = ((att::CreateExpr(8) * bm2 * bn2) / GetSafeOffsetDivisor(bn2) + att::CreateExpr(4));
  att::Expr aiv_vec = (vec_per_ub * (m1_tiles * n1_tiles * ub_tiles)) + (vec_per_mm1 * (m1_tiles * n1_tiles)) +
                      (vec_per_mm2 * (m2_tiles * n2_tiles));

  // ---- AIV MTE3 pipe ----
  att::Expr mte3_per_ub = ((att::CreateExpr(0.01741f) * bn1 * ub_m) + att::CreateExpr(0.22f));
  att::Expr aiv_mte3 = mte3_per_ub * (m1_tiles * n1_tiles * ub_tiles);

  // ---- AIV MTE2 pipe ----
  att::Expr load_n1 =
      ((att::CreateExpr(5.01f) / (att::CreateExpr(27240.69f) + bn1)) + att::CreateExpr(1051.66f)) *
      (bn1 / att::CreateExpr(30421.24f));
  att::Expr load_n2 =
      ((att::CreateExpr(5.01f) / (att::CreateExpr(27240.69f) + bn2)) + att::CreateExpr(1051.66f)) *
      (bn2 / att::CreateExpr(30421.24f));
  att::Expr load_ub = (att::CreateExpr(0.007f) * bn1 * ub_m) + att::CreateExpr(7.97f);
  att::Expr aiv_mte2 = load_n1 * n1_tiles + load_n2 * n2_tiles + load_ub * (m1_tiles * n1_tiles * ub_tiles);

  // ---- AIC MTE2 pipe (both cube passes) ----
  att::Expr cube1_mte2 = CalcCube1Mte2(ctx, n1_tiles, m1_tiles);
  att::Expr cube2_mte2 = CalcCube2Mte2(ctx, n2_tiles, m2_tiles);
  att::Expr aic_mte2 = cube1_mte2 + cube2_mte2;

  model_info.objects[att::PipeType::AIV_MTE2] = aiv_mte2;
  model_info.objects[att::PipeType::AIV_MTE3] = aiv_mte3;
  model_info.objects[att::PipeType::AIC_MTE2] = aic_mte2;
  model_info.objects[att::PipeType::AIV_VEC] = aiv_vec;
  model_info.tiling_case_id = 0;
  model_info.eq_exprs[att::kFatherToChildNoTail].push_back(std::pair(bm1, ub_m));
  model_info.output_size = 1;
}

// Appends the axes to model_info.arg_list in this fixed order:
// maxTokens, basen1, basen2, n1, basem1, k1, n2, basem2, ubm.
void AppendFfnArgList(att::ModelInfo &model_info, const att::AttAxisPtr &maxTokens, const att::AttAxisPtr &basen1,
                      const att::AttAxisPtr &basen2, const att::AttAxisPtr &n1, const att::AttAxisPtr &basem1,
                      const att::AttAxisPtr &k1, const att::AttAxisPtr &n2, const att::AttAxisPtr &basem2,
                      const att::AttAxisPtr &ubm)
{
  const att::AttAxisPtr *const ordered_axes[] = {&maxTokens, &basen1, &basen2, &n1, &basem1,
                                                 &k1,        &n2,     &basem2, &ubm};
  for (const att::AttAxisPtr *axis : ordered_axes) {
    model_info.arg_list.emplace_back(*axis);
  }
}
}  // namespace

namespace att {
// Assembles the stub FFN ModelInfo. It builds every axis, fills in the cost
// model and buffer constraints, then registers the axes in arg_list.
ModelInfo GenFFNModelInfo()
{
  ModelInfo info;
  FfnExprContext exprs;

  AttAxisPtr axis_tokens;
  AttAxisPtr axis_base_m1;
  AttAxisPtr axis_ub_m;
  AttAxisPtr axis_base_m2;
  AttAxisPtr axis_n1;
  AttAxisPtr axis_base_n1;
  AttAxisPtr axis_k1;
  AttAxisPtr axis_n2;
  AttAxisPtr axis_base_n2;

  BuildFfnMaxTokenAxes(info, exprs, axis_tokens, axis_base_m1, axis_ub_m, axis_base_m2);
  BuildFfnN1Axes(exprs, axis_n1, axis_base_n1);
  BuildFfnK1Axis(exprs, axis_k1);
  BuildFfnN2Axes(exprs, axis_n2, axis_base_n2);

  FillFfnModelInfo(info, exprs);
  AppendFfnArgList(info, axis_tokens, axis_base_n1, axis_base_n2, axis_n1, axis_base_m1, axis_k1, axis_n2,
                   axis_base_m2, axis_ub_m);
  return info;
}
}  // namespace att