/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "attribute_group/attr_group_shape_env.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "graph/detail/attributes_holder.h"
#include "proto/ge_ir.pb.h"
#include "graph/expression/expression_impl.h"
#include "graph/symbolizer/symbolic_utils.h"

namespace ge {
namespace {
static thread_local ShapeEnvAttr *shape_env_context{nullptr};
static std::map<ge::DataType, std::string> kGeDType2CppDtype = {
  {ge::DT_INT32, "int32_t"},
  {ge::DT_INT64, "int64_t"},
  {ge::DT_UINT32, "uint32_t"},
  {ge::DT_UINT64, "uint64_t"},
};
}

thread_local std::string ShapeEnvAttr::guard_dfx_info_ = "";
ShapeEnvAttr *GetCurShapeEnvContext() {
  return shape_env_context;
}

void SetCurShapeEnvContext(ShapeEnvAttr *shape_env) {
  shape_env_context = shape_env;
}

std::string Source::GetGlobalIndexStr() const {
  return "context->GetInputPointer<int64_t>(" + std::to_string(global_index_) + ")";
}

graphStatus ShapeEnvAttr::SerializeSymbolInfo(proto::ShapeEnvAttrGroupsDef *shape_env_group) {
  GE_ASSERT_NOTNULL(shape_env_group);
  shape_env_group->clear_symbol_to_value();
  auto symbol_to_value_def = shape_env_group->mutable_symbol_to_value();
  GE_ASSERT_NOTNULL(symbol_to_value_def);
  GELOGI("symbol_to_value size: %zu", symbol_to_value_.size());
  for (const auto &iter : symbol_to_value_) {
    GE_ASSERT_TRUE(!iter.first.IsConstExpr(),
        "Symbol in symbol_to_value of shape env attr should be a variable, but get: %s",
        iter.first.Serialize().get());
    symbol_to_value_def->insert({iter.first.Serialize().get(), iter.second});
  }
  auto value_to_symbol_def = shape_env_group->mutable_value_to_symbol();
  GE_ASSERT_NOTNULL(value_to_symbol_def);
  for (const auto &iter : value_to_symbol_) {
    GE_ASSERT_TRUE(!iter.second.empty());
    proto::SymbolInfoDef symbol_infos;
    for (const auto &sym_iter : iter.second) {
      GE_ASSERT_TRUE(!sym_iter.IsConstExpr(),
          "Symbol in value_to_symbol of shape env attr should be a variable, but get: %s",
          sym_iter.Serialize().get());
      symbol_infos.add_symbols(sym_iter.Serialize().get());
    }
    value_to_symbol_def->insert({iter.first, symbol_infos});
  }

  auto symbol_to_source_def = shape_env_group->mutable_symbol_to_source();
  GE_ASSERT_NOTNULL(symbol_to_source_def);
  // todoo: symbol_to_source_实现序列化
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::SerializeSymbolCheckInfos(proto::ShapeEnvAttrGroupsDef *shape_env_group) {
  GE_ASSERT_NOTNULL(shape_env_group);
  auto replacements_def = shape_env_group->mutable_replacements();
  for (const auto &iter : replacements_) {
    proto::ReplacementDef rep_def;
    rep_def.set_replace_expr(iter.second.replace_expr.Serialize().get());
    rep_def.set_rank(iter.second.rank);
    replacements_def->insert({iter.first.Serialize().get(), rep_def});
  }
  shape_env_group->clear_symbol_check_infos();
  for (const auto &iter : symbol_check_infos_) {
    proto::SymbolCheckInfoDef *symbol_check_info_def = shape_env_group->add_symbol_check_infos();
    symbol_check_info_def->set_expr(iter.expr.Serialize().get());
    symbol_check_info_def->set_file(iter.file);
    symbol_check_info_def->set_line(iter.line);
    symbol_check_info_def->set_dfx(iter.dfx_info);
  }
  shape_env_group->clear_symbol_assert_infos();
  for (const auto &iter : symbol_assert_infos_) {
    proto::SymbolCheckInfoDef *symbol_assert_info_def = shape_env_group->add_symbol_assert_infos();
    symbol_assert_info_def->set_expr(iter.expr.Serialize().get());
    symbol_assert_info_def->set_file(iter.file);
    symbol_assert_info_def->set_line(iter.line);
    symbol_assert_info_def->set_dfx(iter.dfx_info);
  }
  return GRAPH_SUCCESS;
}

ShapeEnvAttr::ShapeEnvAttr(const ShapeEnvAttr& other) {
  shape_env_setting_ = other.shape_env_setting_;
  symbol_to_value_ = other.symbol_to_value_;
  value_to_symbol_ = other.value_to_symbol_;
  symbol_to_source_ = other.symbol_to_source_;
  replacements_ = other.replacements_;
  symbol_check_infos_ = other.symbol_check_infos_;
  symbol_assert_infos_ = other.symbol_assert_infos_;
  unique_sym_id_ = other.unique_sym_id_;
}

ShapeEnvAttr& ShapeEnvAttr::operator=(const ShapeEnvAttr& other) {
  if (this != &other) {
    shape_env_setting_ = other.shape_env_setting_;
    symbol_to_value_ = other.symbol_to_value_;
    value_to_symbol_ = other.value_to_symbol_;
    symbol_to_source_ = other.symbol_to_source_;
    replacements_ = other.replacements_;
    symbol_check_infos_ = other.symbol_check_infos_;
    symbol_assert_infos_ = other.symbol_assert_infos_;
    unique_sym_id_ = other.unique_sym_id_;
  }
  return *this;
}

graphStatus ShapeEnvAttr::Serialize(proto::AttrGroupDef &attr_group_def) {
  auto shape_env_group = attr_group_def.mutable_shape_env_attr_group();
  GE_ASSERT_SUCCESS(SerializeSymbolInfo(shape_env_group));
  GE_ASSERT_SUCCESS(SerializeSymbolCheckInfos(shape_env_group));
  proto::ShapeEnvSettingDef *shape_env_setting_def = shape_env_group->mutable_shape_setting();
  shape_env_setting_def->set_specialize_zero_one(shape_env_setting_.specialize_zero_one);
  shape_env_setting_def->set_dynamic_mode(static_cast<int32_t>(shape_env_setting_.dynamic_mode));
  shape_env_group->set_unique_sym_id(unique_sym_id_);
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::DeserializeSymbolInfo(const proto::ShapeEnvAttrGroupsDef &shape_env_group) {
  symbol_to_value_.clear();
  GELOGI("symbol_to_value size: %zu", shape_env_group.symbol_to_value_size());
  for (const auto &iter : shape_env_group.symbol_to_value()) {
    Expression sym = Expression::Deserialize(iter.first.c_str());
    GE_ASSERT_TRUE(!sym.IsConstExpr(),
        "Symbol in symbol_to_value of shape env attr should be a variable, but get: %s",
        iter.first.c_str());
    symbol_to_value_.emplace(std::make_pair(sym, iter.second));
  }
  value_to_symbol_.clear();
  for (const auto &iter : shape_env_group.value_to_symbol()) {
    std::vector<Expression> symbol_infos;
    for (const auto &sym_iter : iter.second.symbols()) {
      Expression sym = Expression::Deserialize(sym_iter.c_str());
      GE_ASSERT_TRUE(!sym.IsConstExpr(),
          "Symbol in value_to_symbol of shape env attr should be a variable, but get: %s",
          sym_iter.c_str());
      symbol_infos.emplace_back(sym);
    }
    value_to_symbol_.emplace(std::make_pair(iter.first, symbol_infos));
  }
  symbol_to_source_.clear();
  // todoo: symbol_to_source_实现反序列化
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::DeserializeSymbolCheckInfos(const proto::ShapeEnvAttrGroupsDef &shape_env_group) {
  replacements_.clear();
  for (const auto &iter : shape_env_group.replacements()) {
    Expression expr = Expression::Deserialize(iter.first.c_str());
    Expression replace_expr = Expression::Deserialize(iter.second.replace_expr().c_str());
    replacements_.emplace(std::make_pair(expr, Replacement(replace_expr, iter.second.rank())));
  }
  symbol_check_infos_.clear();
  for (const auto &iter : shape_env_group.symbol_check_infos()) {
    Expression expr = Expression::Deserialize(iter.expr().c_str());
    symbol_check_infos_.emplace(SymbolCheckInfo(expr, iter.file(), iter.line(), iter.dfx()));
  }
  symbol_assert_infos_.clear();
  for (const auto &iter : shape_env_group.symbol_assert_infos()) {
    Expression expr = Expression::Deserialize(iter.expr().c_str());
    symbol_assert_infos_.emplace(SymbolCheckInfo(expr, iter.file(), iter.line(), iter.dfx()));
  }
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) {
  (void) attr_holder;
  const auto& shape_env_group = attr_group_def.shape_env_attr_group();
  DeserializeSymbolInfo(shape_env_group);
  DeserializeSymbolCheckInfos(shape_env_group);
  shape_env_setting_ =
      ShapeEnvSetting(shape_env_group.shape_setting().specialize_zero_one(),
          static_cast<DynamicMode>(shape_env_group.shape_setting().dynamic_mode()));
  unique_sym_id_ = shape_env_group.unique_sym_id();
  return GRAPH_SUCCESS;
}

std::unique_ptr<AttrGroupsBase> ShapeEnvAttr::Clone() {
  std::unique_ptr<AttrGroupsBase> new_attr = ComGraphMakeUnique<ShapeEnvAttr>(*this);
  GE_ASSERT_NOTNULL(new_attr);
  return new_attr;
}

bool ShapeEnvAttr::HasSymbolCheckInfo(const ge::Expression &e) const {
  auto expr = e.CanonicalizeBoolExpr();
  if (symbol_check_infos_.find(SymbolCheckInfo(expr)) != symbol_check_infos_.end()) {
    return true;
  }
  return false;
}

bool ShapeEnvAttr::HasSymbolAssertInfo(const ge::Expression &e) const {
  auto expr = e.CanonicalizeBoolExpr();
  if (symbol_assert_infos_.find(SymbolCheckInfo(expr)) != symbol_assert_infos_.end()) {
    return true;
  }
  return false;
}

ge::Expression ShapeEnvAttr::FindReplacements(const ge::Expression &expr) {
  auto iter = replacements_.find(expr);
  if (iter == replacements_.end()) {
    return expr;
  }
  if (iter->second.has_replace) {
    GELOGD("Find replace expr: %s of expr: %s has replace",
        iter->second.replace_expr.Str().get(), expr.Str().get());
    return expr;
  }
  auto replace_expr = iter->second.replace_expr;
  GELOGD("Find replace expr: %s of expr: %s",
      replace_expr.Str().get(), expr.Str().get());
  if (replace_expr == expr) {
    return expr;
  }
  std::vector<std::pair<Expression, Expression>> var_replacements;
  iter->second.has_replace = true;
  for (auto &sym : replace_expr.FreeSymbols()) {
    auto replace_sym = FindReplacements(sym);
    var_replacements.emplace_back(std::make_pair(sym, replace_sym));
  }
  iter->second.has_replace = false;
  return replace_expr.Replace(var_replacements);
}

const std::vector<SymbolCheckInfo> ShapeEnvAttr::GetAllSymbolCheckInfos() const {
  std::vector<SymbolCheckInfo> results;
  for (const auto &iter : symbol_check_infos_) {
    results.emplace_back(iter);
  }
  return results;
}

const std::vector<SymbolCheckInfo> ShapeEnvAttr::GetAllSymbolAssertInfos() const {
  std::vector<SymbolCheckInfo> results;
  for (const auto &iter : symbol_assert_infos_) {
    results.emplace_back(iter);
  }
  return results;
};

void ShapeEnvAttr::SimplifySymbolCheckInfo(
    std::set<SymbolCheckInfo, SymbolCheckInfoKeyLess> &symbol_check_infos) const {
  std::vector<SymbolCheckInfo> simplify_symbol_check_info;
  for (auto &iter : symbol_check_infos) {
    const auto simple_expr = iter.expr.Simplify().CanonicalizeBoolExpr();
    if (simple_expr.IsConstExpr()) {
      continue;
    }
    (void)simplify_symbol_check_info.emplace_back(SymbolCheckInfo(simple_expr));
  }
  (void)symbol_check_infos.insert(simplify_symbol_check_info.begin(),
      simplify_symbol_check_info.end());
}

void ShapeEnvAttr::SimplifySymbolCheckInfo() {
  GELOGD("Start simplify guard");
  SimplifySymbolCheckInfo(symbol_check_infos_);
  SimplifySymbolCheckInfo(symbol_assert_infos_);
}

ge::Expression ShapeEnvAttr::Simplify(const ge::Expression &expr) {
  std::vector<std::pair<Expression, Expression>> var_replacements;
  // 初始化replacements遍历状态
  for (auto &iter : replacements_) {
    iter.second.has_replace = false;
  }
  for (const auto &sym : expr.FreeSymbols()) {
    auto replace_expr = FindReplacements(sym);
    if ((!replace_expr.IsVariableExpr()) || (replace_expr != sym)) {
      var_replacements.emplace_back(std::make_pair(sym, replace_expr));
    }
  }
  if (!var_replacements.empty()) {
    auto result_expr = expr.Replace(var_replacements);
    GELOGI("Simplify expr: %s to expr: %s",
        SymbolicUtils::ToString(expr).c_str(), SymbolicUtils::ToString(result_expr).c_str());
    GE_ASSERT_NOTNULL(result_expr.impl_);
    return Expression(result_expr.impl_->Simplify());
  }
  return Expression(expr.impl_->Simplify());
}

ge::Expression ShapeEnvAttr::EvaluateExpr(const ge::Expression &expr) {
  std::vector<std::pair<Expression, Expression>> var_to_val;
  auto free_symbols = expr.FreeSymbols();
  for (const auto &free_sym : free_symbols) {
    const auto &iter = symbol_to_value_.find(free_sym);
    if (iter != symbol_to_value_.end()) {
      var_to_val.emplace_back(std::make_pair(free_sym, Symbol(iter->second)));
    }
  }
  return expr.Subs(var_to_val);
}

TriBool ShapeEnvAttr::HasSymbolInfo(const Expression &expr) const {
  Expression e = expr.CanonicalizeBoolExpr();
  if (HasSymbolCheckInfo(e) || HasSymbolAssertInfo(e)) {
    return TriBool::kTrue;
  }
  return TriBool::kUnknown;
}

void ShapeEnvAttr::AppendInitReplacement(const ge::Expression &expr) {
  if (replacements_.find(expr) == replacements_.end()) {
    (void)replacements_.emplace(std::make_pair(expr, Replacement(expr, 1)));
  }
}

graphStatus ShapeEnvAttr::FindRootExpr(const ge::Expression &expr, ge::Expression &root_expr) const {
  const auto &iter = replacements_.find(expr);
  GE_ASSERT_TRUE(iter != replacements_.end(),
    "Cannot find replacement of expr: %s", SymbolicUtils::ToString(expr).c_str());
  if (iter->second.replace_expr == expr) {
    root_expr = expr;
    return GRAPH_SUCCESS;
  }
  GE_ASSERT_SUCCESS(FindRootExpr(iter->second.replace_expr, root_expr));
  return GRAPH_SUCCESS;
}

std::vector<std::pair<Expression, SourcePtr>> ShapeEnvAttr::GetAllSym2Src() {
  std::vector<std::pair<Expression, SourcePtr>> result;
  for (const auto &iter : symbol_to_source_) {
    result.emplace_back(iter.first, iter.second);
  }
  return result;
}

bool Replacement::operator<=(const Replacement &other) {
  // 并查集的根节点优先级: 常数 > 表达式 > 变量
  if (replace_expr.IsConstExpr()) {
    if (other.replace_expr.IsConstExpr()) {
      return rank <= other.rank;
    }
    return false;
  }
  if (replace_expr.IsVariableExpr()) {
    if (other.replace_expr.IsVariableExpr()) {
      return rank <= other.rank;
    }
    return true;
  }
  if (other.replace_expr.IsConstExpr()) {
    return true;
  }
  if (other.replace_expr.IsVariableExpr()) {
    return false;
  }
  return rank <= other.rank;
}

graphStatus ShapeEnvAttr::MergeReplacement(const ge::Expression &expr1,
    const ge::Expression &expr2) {
  ge::Expression father_expr1;
  GE_ASSERT_SUCCESS(FindRootExpr(expr1, father_expr1));
  ge::Expression father_expr2;
  GE_ASSERT_SUCCESS(FindRootExpr(expr2, father_expr2));
  auto &replacement_1 = replacements_[father_expr1];
  auto &replacement_2 = replacements_[father_expr2];
  if (replacement_1 <= replacement_2) {
    replacement_1.replace_expr = father_expr2;
    if (replacement_2.rank <= replacement_1.rank) {
      replacement_2.rank = replacement_1.rank + 1;
    }
  } else {
    replacement_2.replace_expr = father_expr1;
    if (replacement_1.rank <= replacement_2.rank) {
      replacement_1.rank = replacement_2.rank + 1;
    }
  }
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::MergePath() {
  for (auto &iter : replacements_) {
    ge::Expression root_expr;
    GE_ASSERT_SUCCESS(FindRootExpr(iter.first, root_expr));
    iter.second.replace_expr = root_expr;
    iter.second.rank = 1;
  }
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::AppendReplacement(const ge::Expression &target, const ge::Expression &replacement) {
    if (target == replacement) {
        return GRAPH_SUCCESS;
    }
    ge::Expression expr1 = target;
    ge::Expression expr2 = replacement;
    auto expr = sym::Eq(target, replacement).CanonicalizeBoolExpr();
    std::vector<Expression> args = expr.GetArgs();
    if (args.size() == kSizeTwo) {
        expr1 = args[0];
        expr2 = args[1];
        GELOGD("expr1 %s->%s, expr2 %s->%s",
          SymbolicUtils::ToString(target).c_str(), SymbolicUtils::ToString(expr1).c_str(),
          SymbolicUtils::ToString(replacement).c_str(), SymbolicUtils::ToString(expr2).c_str());
    }

  // 仅支持 符号->常量,符号->表达式,符号->符号 映射
  if (expr1.IsConstExpr()) {
    if (!expr2.IsVariableExpr()) {
      GELOGW("Unsupport append replacement %s to %s",
          SymbolicUtils::ToString(expr1).c_str(), SymbolicUtils::ToString(expr2).c_str());
      return GRAPH_SUCCESS;
    }
  } else if (!expr1.IsVariableExpr()) {
    if (!expr2.IsVariableExpr()) {
      GELOGW("Unsupport append replacement %s to %s",
          SymbolicUtils::ToString(expr1).c_str(), SymbolicUtils::ToString(expr2).c_str());
      return GRAPH_SUCCESS;
    }
  }
  // 判断replacement是否成环
  if (CheckReplacementCycle(expr1, expr2)) {
    GELOGW("Unsupport append replacement %s to %s, replacement contains the other.",
      SymbolicUtils::ToString(expr1).c_str(), SymbolicUtils::ToString(expr2).c_str());
    return GRAPH_SUCCESS;
  }
  AppendInitReplacement(expr1);
  AppendInitReplacement(expr2);
  GE_ASSERT_SUCCESS(MergeReplacement(expr1, expr2));
  // 路径压缩
  GE_ASSERT_SUCCESS(MergePath());
  // replace插入后全量化简已有的guard
  SimplifySymbolCheckInfo();
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::AppendSymbolAssertInfo(const ge::Expression &expr,
    const std::string &file, const int64_t line) {
  GE_ASSERT_TRUE(expr.IsBooleanExpr(),
      "Assert expr: %s should be boolean", SymbolicUtils::ToString(expr).c_str());
  if (!expr.IsConstExpr()) {
    (void)symbol_assert_infos_.emplace(SymbolCheckInfo(expr.CanonicalizeBoolExpr(), file, line, GetGuardDfxContextInfo()));
  }
  return GRAPH_SUCCESS;
}

graphStatus ShapeEnvAttr::AppendSymbolCheckInfo(const ge::Expression &expr,
    const std::string &file, const int64_t line) {
  GE_ASSERT_TRUE(expr.IsBooleanExpr(),
      "Check expr: %s should be boolean", SymbolicUtils::ToString(expr).c_str());
  if (!expr.IsConstExpr()) {
    (void)symbol_check_infos_.emplace(SymbolCheckInfo(expr.CanonicalizeBoolExpr(), file, line, GetGuardDfxContextInfo()));
  }
  return GRAPH_SUCCESS;
}

void ShapeEnvAttr::SetGuardDfxContextInfo(const std::string &guard_dfx_info) const {
  guard_dfx_info_ = guard_dfx_info;
}

void ShapeEnvAttr::ClearGuardDfxContextInfo() const {
  guard_dfx_info_.clear();
}

std::string ShapeEnvAttr::GetGuardDfxContextInfo() const {
  return guard_dfx_info_;
}

bool ShapeEnvAttr::CheckReplacementCycle(const Expression &expr1, const Expression &expr2) const {
  Expression root_expr1;
  if (replacements_.find(expr1) == replacements_.end()) {
    root_expr1 = expr1;
  } else {
    GE_ASSERT_SUCCESS(FindRootExpr(expr1, root_expr1));
  }
  Expression root_expr2;
  if (replacements_.find(expr2) == replacements_.end()) {
    root_expr2 = expr2;
  } else {
    GE_ASSERT_SUCCESS(FindRootExpr(expr2, root_expr2));
  }

  /*
   * *判断exp1和expr2是否相互包含,
   * 1) 先判断原表达式是否包含, 例如expr1: s0 + s1, expr2: s1, 则expr1包含expr2
   * 2) 然后判断化简后是否包含, 例如已有replacement: s1 == s2, expr1: s0 + s1, expr2: s2
   *    则expr1化简后为s0 + s2, 包含expr2
   */
  if (root_expr2.IsVariableExpr()) {
    return root_expr1.ContainVar(root_expr2) || root_expr1.Simplify().ContainVar(root_expr2);
  } else if (root_expr1.IsVariableExpr()) {
    return root_expr2.ContainVar(root_expr1) || root_expr2.Simplify().ContainVar(root_expr1);
  }
  return false;
}

} // namespace ge