/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "exe_time_pass.h"
#include "base/att_const_values.h"
#include "common/common_utils.h"
#include "ascir_ops.h"
#include "optimize/platform/platform_factory.h"

namespace att {
namespace {
[[maybe_unused]] void AddUsedArgs(const Expr &expr, std::vector<Expr> &used_args) {
  for (const auto &arg : expr.FreeSymbols()) {
    used_args.emplace_back(arg);
  }
}

void CheckSplit(const std::map<std::string, std::set<std::string>> &axis_list, 
                                    const std::string &node_name, const std::string &orig_name, bool &is_split) {
  const auto iter = axis_list.find(node_name);
  if (iter != axis_list.end()) {
    if (iter->second.size() > 0) {
      if (iter->second.find(orig_name) != iter->second.end()) {
        is_split = true;
      }
    }
  }
}

void InsertAxis(const SubAxis *cur_dim, const NodeInfo &node_info,
                std::map<std::string, std::set<std::string>> &axis_list) {
  for (const auto &dim : cur_dim->orig_axis) {
    for (const auto &node_name : node_info.from_data) {
      axis_list[node_name].insert(dim->name);
    }
  }
}
}
void ExeTimePassManager::AddBAxis(const std::string &dim_name, const Expr &repeat, const Expr &stride,
                                  const NodeInfo &node_info) {
  TensorPtr output_tensor = node_info.outputs[0];
  for (size_t i = 0; i < output_tensor->dim_info.size(); i++) {
    auto &cur_dim = output_tensor->dim_info[i];
    if ((cur_dim->name == dim_name)) {
      if ((output_tensor->repeat[i] != af::sym::kSymbolOne && repeat == af::sym::kSymbolOne) ||
          (output_tensor->stride[i] != af::sym::kSymbolZero && stride == af::sym::kSymbolZero)) {
        InsertAxis(cur_dim, node_info, broadcast_axis_);
      }
      break;
    }
  }
}

void ExeTimePassManager::AddRAxis(const std::string &dim_name, const Expr &repeat, const Expr &stride,
                                  const NodeInfo &node_info) {
  TensorPtr output_tensor = node_info.outputs[0];
  for (size_t i = 0; i < output_tensor->dim_info.size(); i++) {
    auto &cur_dim = output_tensor->dim_info[i];
    if ((cur_dim->name == dim_name)) {
      if ((repeat != af::sym::kSymbolOne && output_tensor->repeat[i] == af::sym::kSymbolOne) ||
          (stride != af::sym::kSymbolZero && output_tensor->stride[i] == af::sym::kSymbolZero)) {
        InsertAxis(cur_dim, node_info, reduce_axis_);
      } else {
        InsertAxis(cur_dim, node_info, non_reduce_axis_);
      }
      break;
    }
  }
}

void ExeTimePassManager::CheckBroadcast(const NodeInfo &node) {
  const std::string kNodeBroadcast = "Broadcast";
  if (node.node_type == kNodeBroadcast || ascgen_utils::IsGeneralizeBrcInlineScene(node.node_ptr)) {
    auto &tensor = node.inputs[0];
    for (size_t i = 0; i < tensor->dim_info.size(); i++) {
      AddBAxis(tensor->dim_info[i]->name, tensor->repeat[i], tensor->stride[i], node);
    }
  }
}

ge::Status CheckExecConditionBroadcast(const TuningSpacePtr tuning_space, const NodeInfo &node, SubAxis *axis,
                                       bool &is_split, Expr **fused_axis) {
  GE_ASSERT_NOTNULL(tuning_space);
  GE_ASSERT_NOTNULL(axis);
  GE_ASSERT_NOTNULL(fused_axis);
  if (node.exec_condition != af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis &&
      node.exec_condition != af::ExecuteCondition::kCacheBlockSplitOriginBroadcastAxis) {
    GELOGD("exec_condition not match");
    return ge::SUCCESS;
  }
  is_split = true;
  if (node.exec_condition == af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis) {
    for (const auto &axes : tuning_space->sub_axes) {
      GE_ASSERT_NOTNULL(axes);
      if (!axes->is_split || axes->axis_type != AxisPosition::OUTER) {
        continue;
      }
      for (const auto parent_axis : axes->parent_axis) {
        GE_ASSERT_NOTNULL(parent_axis);
        if (parent_axis->name == axis->name) {
          GELOGD("fused_axis %s", Str(axes->repeat).c_str());
          *fused_axis = &axes->repeat;
          break;
        }
      }
    }
  }
  return ge::SUCCESS;
}

void ExeTimePassManager::CheckReduce(const NodeInfo &node) {
  const std::vector<std::string> kNodeReduce = {"Sum", "Max", "Min"};
  if (std::find(kNodeReduce.begin(), kNodeReduce.end(), node.node_type) != kNodeReduce.end()) {
    auto &tensor = node.inputs[0];
    for (size_t i = 0; i < tensor->dim_info.size(); i++) {
      AddRAxis(tensor->dim_info[i]->name, tensor->repeat[i], tensor->stride[i], node);
    }
  }
}

void ExeTimePassManager::UpdateBufNode(const std::vector<NodeInfo> &nodes) {
  std::string log;
  bool broadcast_node = false;
  for (int32_t i = nodes.size() - 1; i >= 0; --i) {
    auto &node = nodes[i];
    if (node.exec_condition == af::ExecuteCondition::kCacheBlockSplitFusedBroadcastAxis ||
        node.exec_condition == af::ExecuteCondition::kCacheBlockSplitOriginBroadcastAxis) {
      GELOGD("insert brc inline node by exec_condition %s", node.name.c_str());
      brc_buf_node_.insert(node.name);
    }
    if (broadcast_node) {
      brc_buf_node_.insert(node.name);
    }
    if (node.node_type == kBroadcast) {
      brc_buf_node_.insert(node.name);
      broadcast_node = true;
    } else if (node.node_type == kLoad) {
      broadcast_node = false;
      const auto platform = optimize::PlatformFactory::GetInstance().GetPlatform();
      if (platform != nullptr && ascgen_utils::IsLinkToBrdcst(node.node_ptr, platform->BroadcastTypes())) {
        GELOGD("insert brc inline node %s", node.name.c_str());
        brc_buf_node_.insert(node.name);
      }
    }
  }
  for (const auto &node_name : brc_buf_node_) {
    log += node_name + ", ";
  }
  GELOGD("Brc related node is {%s}.", log.c_str());
}

void ExeTimePassManager::GenLog(const std::string &type_name, const std::map<std::string, std::set<std::string>> &axis_list) const {
  std::string log;
  for (const auto &pair : axis_list) {
    log.clear();
    for (const auto &node_name : pair.second) {
      log += node_name + ", ";
    }
    GELOGD("%s axis for node[%s] is {%s}.", type_name.c_str(), pair.first.c_str(), log.c_str());
  }
}

bool ExeTimePassManager::GetRLoop(const NodeInfo &node, Expr &r_loop) const {
  for (const auto &axis : node.loop_axes) {
    GE_ASSERT_NOTNULL(axis, "Get axis failed.");
    if (axis->axis_type != AxisPosition::OUTER && axis->axis_type != AxisPosition::INNER) {
      continue;
    }
    for (const auto &node_name : node.from_data) {
      bool r_split = false;
      for (const auto &orig_axis : axis->orig_axis) {
        CheckSplit(reduce_axis_, node_name, orig_axis->name, r_split);
      }
      if (r_split) {
        r_loop = axis->repeat;
        GELOGD("r_loop of [%s] is [%s].", node.name.c_str(), Str(r_loop).c_str());
        return true;
      }
    }
  }
  return false;
}

/*
brc缓存执行逻辑:
1.递归计算节点的输入信息,将每一个节点与一个或多个Data节点绑定
2.扫描Broadcast节点与Reduce节点,根据绑定的Data节点确定每个Data节点的B轴,A轴与R轴
3.扫描Load至Broadcast间的所有节点
4.判断这些节点的循环轴间是否存在绑定的Data节点的B切分轴,若存在,则使能缓存:
1)若切分轴同时为R轴或非A轴,则使用性能公式除以切分轴的循环次数
2)若切分轴同时是A轴,且当前loop axis中与R轴相关的循环轴的循环轮次为1,则使用性能公式除以B切分轴的循环次数,否则不做压缩
*/
bool ExeTimePassManager::CheckAxisSplit(const NodeInfo &node, const SubAxis *axis, Expr *fused_axis) const {
  bool b_split = false;
  bool a_split = false;
  bool r_split = false;
  for (const auto &node_name : node.from_data) {
    for (const auto &orig_axis : axis->orig_axis) {
      CheckSplit(broadcast_axis_, node_name, orig_axis->name, b_split);
      CheckSplit(reduce_axis_, node_name, orig_axis->name, r_split);
      CheckSplit(non_reduce_axis_, node_name, orig_axis->name, a_split);
      CheckExecConditionBroadcast(tuning_space_, node, orig_axis, b_split, &fused_axis);
    }
    if (b_split) {
      return true;
    }
  }
  return b_split;
}

TernaryOp ExeTimePassManager::HandleBroadcastSplit(const NodeInfo &node, const Expr &exe_time,
                                                     const Expr &fused_exe_time) const {
  GELOGD("[DFX] fused broadcast updates [%s] exe time : [%s] -> [%s]", node.name.c_str(),
         Str(exe_time).c_str(), Str(fused_exe_time).c_str());
  return TernaryOp(fused_exe_time);
}

TernaryOp ExeTimePassManager::HandleReduceOrNormalSplit(const NodeInfo &node, const Expr &exe_time,
                                                         const SubAxis *axis,
                                                         bool r_split, bool a_split) const {
  Expr cur_exe_time = af::sym::Div(exe_time, axis->repeat);
  if (r_split || !a_split) {
    GELOGD("Axis [%s] r_split and not asplit.", axis->name.c_str());
    GELOGD("Brc buf module updates [%s] exe time : [%s] -> [%s]", node.name.c_str(),
           Str(exe_time).c_str(), Str(cur_exe_time).c_str());
    return TernaryOp(cur_exe_time);
  }
  GELOGD("Axis [%s] a_split.", axis->name.c_str());
  Expr r_loop;
  if (GetRLoop(node, r_loop)) {
    GELOGD("Brc buf module updates [%s] exe time : [%s] -> [%s == 1 ? %s : %s]",
           node.name.c_str(), Str(exe_time).c_str(), Str(r_loop).c_str(),
           Str(cur_exe_time).c_str(), Str(exe_time).c_str());
    return TernaryOp(CondType::K_EQ, r_loop, CreateExpr(1.0f), af::sym::Div(cur_exe_time, r_loop), exe_time);
  }
  return TernaryOp(exe_time);
}

TernaryOp ExeTimePassManager::UpdateNodeExeTime(const NodeInfo &node, const Expr &exe_time) const {
  if (brc_buf_node_.find(node.name) == brc_buf_node_.end()) {
    return TernaryOp(exe_time);
  }
  Expr *fused_axis = nullptr;
  bool b_split = false;
  bool a_split = false;
  bool r_split = false;

  for (const auto &axis : node.loop_axes) {
    GE_ASSERT_NOTNULL(axis, "Get axis failed.");
    if (axis->axis_type != AxisPosition::OUTER && axis->axis_type != AxisPosition::INNER) {
      continue;
    }
    // Reset flags for each axis
    b_split = false;
    a_split = false;
    r_split = false;

    for (const auto &node_name : node.from_data) {
      for (const auto &orig_axis : axis->orig_axis) {
        CheckSplit(broadcast_axis_, node_name, orig_axis->name, b_split);
        CheckSplit(reduce_axis_, node_name, orig_axis->name, r_split);
        CheckSplit(non_reduce_axis_, node_name, orig_axis->name, a_split);
        CheckExecConditionBroadcast(tuning_space_, node, orig_axis, b_split, &fused_axis);
      }
      if (b_split) {
        break;
      }
    }

    if (fused_axis != nullptr) {
      Expr fused_exe_time = af::sym::Max(af::sym::kSymbolOne, af::sym::Div(exe_time, *fused_axis));
      return HandleBroadcastSplit(node, exe_time, fused_exe_time);
    }

    if (b_split) {
      return HandleReduceOrNormalSplit(node, exe_time, axis, r_split, a_split);
    }
  }
  return TernaryOp(exe_time);
}
}  // namespace att