* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "generator/preprocess/args_replace.h"
#include <string>
#include <array>
#include "common/checker.h"
#include "util/base_types_printer.h"
#include "base/att_const_values.h"
#include "generator/preprocess/preprocess_utils.h"
namespace att {
// Suffix appended to an expression's printed name (Str(expr)) to name the variable that stands for
// "expr divided by its alignment".
const std::string kAlignDelim = "_div_align";
// Suffix appended to an expression's printed name to name the exponent variable used in the
// "align * 2^variable" form of a replacement.
const std::string kPowBase = "_base";
// NOTE(review): not referenced in the visible part of this file, and the identifier looks like a
// misspelling of "Default" - confirm there are no other users before renaming or removing it.
const int32_t kDefualtInitValue = 0;
// Returns true when n is a positive value with exactly one bit set (1, 2, 4, 8, ...).
// Zero and negative inputs are never reported as powers of two.
inline bool IsPowerOfTwo(int32_t n) {
  if (n > 0) {
    const uint32_t bits = static_cast<uint32_t>(n);
    // Clearing the lowest set bit leaves zero only if that bit was the only one set.
    return (bits & (bits - 1U)) == 0U;
  }
  return false;
}
// Exports the outcome of the last replacement run into the three output maps.
// Only entries of new_expr_ori_expr_map_ whose new expression differs from the original one are exported:
//   replaced_vars:         new expression      -> original expression
//   replacements:          original expression -> value stored for it in ori_expr_new_expr_map_
//   new_expr_replacements: new expression      -> value stored for it in new_expr_replacements_
// Both member lookups use operator[], exactly as before.
void ArgsReplacer::GetReplaceResult(ExprExprMap &replaced_vars, ExprExprMap &replacements,
                                    ExprExprMap &new_expr_replacements) {
  for (auto &entry : new_expr_ori_expr_map_) {
    const auto &new_expr = entry.first;
    auto &ori_expr = entry.second;
    if (new_expr != ori_expr) {
      replaced_vars.emplace(new_expr, ori_expr);
      replacements.emplace(ori_expr, ori_expr_new_expr_map_[ori_expr]);
      new_expr_replacements.emplace(new_expr, new_expr_replacements_[new_expr]);
    }
  }
}
// Copies the initial values recorded for the newly created variables into new_expr_init_values,
// replacing whatever the output map contained before the call.
void ArgsReplacer::GetNewExprInitValue(ExprExprMap &new_expr_init_values) {
  new_expr_init_values = new_expr_init_values_;
}
// Runs the complete replacement pipeline for the given equality constraints.
//
// Stages, in order:
//   1. Reset()              - drop all state left by a previous run.
//   2. InitWithEqCons()     - record the expression/factor relations from eq_exprs.
//   3. GetLeafExprs()       - replace the leaf factors and seed the work queue.
//   4. ReplaceParentExprs() - propagate the replacement through the queued expressions.
//   5. ReplaceNaiveExpr()   - handle every variable not replaced by the stages above.
//
// Returns false as soon as stage 2, 3 or 4 fails (a warning is logged for stages 3 and 4),
// true once all stages have run.
bool ArgsReplacer::DoReplace(const std::map<std::string, std::vector<std::pair<Expr, Expr>>> &eq_exprs) {
  Reset();

  const bool initialized = InitWithEqCons(eq_exprs);
  if (!initialized) {
    return false;
  }

  const bool leaves_ready = GetLeafExprs();
  if (!leaves_ready) {
    GELOGW("Replace failed. Get leaf expr failed.");
    return false;
  }

  const bool parents_ready = ReplaceParentExprs();
  if (!parents_ready) {
    GELOGW("Replace failed. Replace parent expr failed.");
    return false;
  }

  ReplaceNaiveExpr();
  return true;
}
// Builds the replacement of expression e with respect to alignment align.
//
// If align is a constant that is statically known to be <= kMinDimLength, e is kept as it is:
// it is mapped to itself in both output maps and returned unchanged.
//
// Otherwise a new variable named "<Str(e)>_div_align" is created and:
//   - new_expr_ori_expr_map records  new variable -> e
//   - new_expr_replacements records  new variable -> e / align
//   - new_expr_init_values_ (member, not a parameter) records  new variable -> 1
// and the function returns align * new variable.
Expr ArgsReplacer::ReplaceCommonExpr(const Expr &e, const Expr &align,
                                     ExprExprMap &new_expr_ori_expr_map,
                                     ExprExprMap &new_expr_replacements) {
  const bool keep_as_is =
      align.IsConstExpr() &&
      (af::SymbolicUtils::StaticCheckLe(align, af::Symbol(kMinDimLength)) == af::TriBool::kTrue);
  if (keep_as_is) {
    new_expr_ori_expr_map.emplace(e, e);
    new_expr_replacements.emplace(e, e);
    return e;
  }

  const auto aligned_name = Str(e) + kAlignDelim;
  Expr aligned_var = CreateExpr(aligned_name.c_str());
  new_expr_ori_expr_map.emplace(aligned_var, e);
  new_expr_replacements.emplace(aligned_var, af::sym::Div(e, align));
  new_expr_init_values_.emplace(aligned_var, af::sym::kSymbolOne);
  return af::sym::Mul(align, aligned_var);
}
bool ArgsReplacer::IsAllFactorReplaced(const ExprExprMap &replaced_vars, std::vector<Expr> &factors) const {
for (const auto &factor : factors) {
if (replaced_vars.find(factor) == replaced_vars.end()) {
return false;
}
}
return true;
}
bool ArgsReplacer::InitWithEqCons(const std::map<std::string, std::vector<std::pair<Expr, Expr>>> &eq_exprs) {
for (const auto &eq_cons : eq_exprs) {
GE_ASSERT_TRUE(eq_cons.first == kFatherToChildNoTail,
"CreateExpr[%s] Repalce doesn't support non notail eq exprs.", eq_cons.first.c_str());
for (const auto &expr : eq_cons.second) {
expr_factors_map_[expr.first].emplace_back(expr.second);
factor_expr_map_[expr.second] = expr.first;
}
}
return true;
}
// Replaces every leaf factor and seeds replaced_expr_queue_ with it.
//
// A leaf is a factor listed in expr_factors_map_ that is not itself a key of expr_factors_map_.
// For each leaf:
//   - if it has no entry in vars_infos_, the function returns false;
//   - if its align is a constant statically known to be <= kMinDimLength, or do_search is off,
//     it is replaced by itself (GetSelfReplacedVars);
//   - otherwise a new variable "<Str(factor)>_base" is created and the factor is replaced by
//     align * 2^variable, with variable = log2(factor / align) and initial value 0.
//     A constant align must be a power of two, and a factor must not already have a replacement
//     (multi-parent case); both are enforced by assertions.
//
// Returns true when all leaves have been handled.
bool ArgsReplacer::GetLeafExprs() {
  for (auto &expr_factors : expr_factors_map_) {
    for (auto &factor : expr_factors.second) {
      if (expr_factors_map_.find(factor) != expr_factors_map_.end()) {
        // The factor has factors of its own, so it is not a leaf.
        continue;
      }
      if (!IsInExprInfo(vars_infos_, factor)) {
        return false;
      }
      const auto &factor_info = vars_infos_.at(factor);
      Expr factor_align = factor_info.align;
      const bool small_const_align =
          factor_align.IsConstExpr() &&
          (af::SymbolicUtils::StaticCheckLe(factor_align, af::Symbol(kMinDimLength)) == af::TriBool::kTrue);
      if (small_const_align || !factor_info.do_search) {
        GetSelfReplacedVars(factor);
        continue;
      }
      if (factor_align.IsConstExpr()) {
        // Initialised on purpose: the outcome of GetConstValue() is not inspected here, so without
        // an initial value a failed extraction would feed an indeterminate value into the check
        // below. Zero is rejected by IsPowerOfTwo(), which turns such a failure into the assertion.
        int32_t factor_align_const_value = 0;
        factor_align.GetConstValue(factor_align_const_value);
        GE_ASSERT_TRUE(IsPowerOfTwo(factor_align_const_value), "CreateExpr Repalce doesn't support align is not power of 2.");
      }
      GE_ASSERT_TRUE(ori_expr_new_expr_map_.find(factor) == ori_expr_new_expr_map_.end(),
                     "CreateExpr Repalce doesn't support multi-parent case.");
      Expr new_variable = CreateExpr((Str(factor) + kPowBase).c_str());
      Expr new_factor_expr = af::sym::Mul(factor_align, af::sym::Pow(CreateExpr(kBaseTwo), new_variable));
      new_expr_replacements_.emplace(new_variable, af::sym::Log(af::sym::Div(factor, factor_align),
                                                                CreateExpr(kBaseTwo)));
      new_expr_init_values_.emplace(new_variable, af::sym::kSymbolZero);
      new_expr_ori_expr_map_.emplace(new_variable, factor);
      ori_expr_new_expr_map_.emplace(factor, new_factor_expr);
      replaced_expr_queue_.push(factor);
    }
  }
  return true;
}
// Marks expr as "replaced by itself": it is mapped to itself in all three bookkeeping maps
// and then queued so that the expressions derived from it get processed.
void ArgsReplacer::GetSelfReplacedVars(const Expr &expr) {
  // Identity entries; emplace leaves an already existing entry for expr untouched.
  new_expr_replacements_.emplace(expr, expr);
  new_expr_ori_expr_map_.emplace(expr, expr);
  ori_expr_new_expr_map_.emplace(expr, expr);

  replaced_expr_queue_.push(expr);
}
void ArgsReplacer::GetAlignReplacedVars(const Expr &expr) {
Expr replaced_expr = ReplaceCommonExpr(expr, vars_infos_.at(expr).align,
new_expr_ori_expr_map_, new_expr_replacements_);
ori_expr_new_expr_map_.emplace(expr, replaced_expr);
replaced_expr_queue_.push(expr);
}
void ArgsReplacer::GetFactorReplacedVars(const Expr &expr) {
if (IsAllFactorReplaced(ori_expr_new_expr_map_, expr_factors_map_[expr])) {
Expr new_align_expr = vars_infos_.at(expr).align;
for (auto &factor : expr_factors_map_[expr]) {
new_align_expr = af::sym::Max(new_align_expr, ori_expr_new_expr_map_[factor]);
}
Expr new_variable;
Expr new_expr;
if (factor_expr_map_.find(expr) == factor_expr_map_.end()) {
new_variable = CreateExpr((Str(expr) + kAlignDelim).c_str());
new_expr = af::sym::Mul(new_align_expr, new_variable);
new_expr_replacements_.emplace(new_variable, af::sym::Div(expr, new_align_expr));
new_expr_init_values_.emplace(new_variable, af::sym::kSymbolOne);
} else {
new_variable = CreateExpr((Str(expr) + kPowBase).c_str());
new_expr = af::sym::Mul(new_align_expr, af::sym::Pow(CreateExpr(kBaseTwo), new_variable));
new_expr_replacements_.emplace(new_variable, af::sym::Log(af::sym::Div(expr, new_align_expr), CreateExpr(kBaseTwo)));
new_expr_init_values_.emplace(new_variable, af::sym::kSymbolZero);
}
new_expr_ori_expr_map_.emplace(new_variable, expr);
ori_expr_new_expr_map_.emplace(expr, new_expr);
replaced_expr_queue_.push(expr);
}
}
// Drains replaced_expr_queue_ and replaces the expressions listed in from_axis_size of every
// dequeued expression ("parents" below).
//
// For each parent:
//   - parents without an entry in vars_infos_ are skipped;
//   - parents with do_search switched off are replaced by themselves;
//   - parents without factors are replaced according to their alignment;
//   - parents with factors are handled by GetFactorReplacedVars, which acts only once all of
//     their factors have been replaced.
// The helpers push the parents they replace back onto the queue, so the loop keeps going until
// nothing new is produced.
//
// Always returns true. Every queued expression is expected to have an entry in vars_infos_
// (at() is used for the lookup).
bool ArgsReplacer::ReplaceParentExprs() {
  while (!replaced_expr_queue_.empty()) {
    // Copy the front element: pop() destroys it.
    Expr current_expr = replaced_expr_queue_.front();
    replaced_expr_queue_.pop();
    // Bound by reference: nothing in this loop modifies vars_infos_, so copying the whole list
    // for every dequeued expression is unnecessary.
    const std::vector<Expr> &parent_exprs = vars_infos_.at(current_expr).from_axis_size;
    for (const auto &parent_expr : parent_exprs) {
      if (!IsInExprInfo(vars_infos_, parent_expr)) {
        continue;
      }
      if (!vars_infos_.at(parent_expr).do_search) {
        GetSelfReplacedVars(parent_expr);
        continue;
      }
      if (expr_factors_map_.find(parent_expr) == expr_factors_map_.end()) {
        GetAlignReplacedVars(parent_expr);
      } else {
        GetFactorReplacedVars(parent_expr);
      }
    }
  }
  return true;
}
void ArgsReplacer::ReplaceNaiveExpr() {
for (const auto &expr_info : vars_infos_) {
if (ori_expr_new_expr_map_.find(expr_info.first) == ori_expr_new_expr_map_.end()) {
if (!expr_info.second.do_search) {
ori_expr_new_expr_map_.emplace(expr_info.first, expr_info.first);
continue;
}
Expr new_expr = ReplaceCommonExpr(expr_info.first, expr_info.second.align,
new_expr_ori_expr_map_, new_expr_replacements_);
ori_expr_new_expr_map_.emplace(expr_info.first, new_expr);
}
}
}
// Drops all state accumulated by a previous DoReplace() run: the bookkeeping maps,
// the expression/factor relations and the pending work queue.
void ArgsReplacer::Reset() {
  // Expression/factor relations.
  expr_factors_map_.clear();
  factor_expr_map_.clear();

  // Replacement bookkeeping.
  ori_expr_new_expr_map_.clear();
  new_expr_ori_expr_map_.clear();
  new_expr_replacements_.clear();
  new_expr_init_values_.clear();

  // std::queue has no clear(); swapping with an empty temporary empties it.
  std::queue<Expr>().swap(replaced_expr_queue_);
}
}