* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ir/transforms/base/mutator.h"
#include <any>
#include <cstddef>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "core/dtype.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/kind_traits.h"
#include "ir/memref.h"
#include "ir/program.h"
#include "ir/scalar_expr.h"
#include "ir/span.h"
#include "ir/stmt.h"
#include "ir/transforms/base/functor.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
namespace {
// Builds a new binary expression node of the concrete class selected by `kind`.
//
// kind   - node kind to construct; must be one of the binary kinds listed in the switch.
// left   - left operand, moved into the new node.
// right  - right operand, moved into the new node.
// dtype  - data type forwarded to the new node's constructor.
// span   - source span forwarded to the new node and used when reporting an unknown kind.
//
// Returns the new node. For an unhandled kind the internal check is raised; if that check
// ever returns control, nullptr is returned instead of flowing off the end of the function.
ExprPtr ReconstructBinaryExpr(ObjectKind kind, ExprPtr left, ExprPtr right, DataType dtype, const Span& span)
{
    switch (kind) {
        case ObjectKind::Add:
            return std::make_shared<const Add>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Sub:
            return std::make_shared<const Sub>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Mul:
            return std::make_shared<const Mul>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::FloorDiv:
            return std::make_shared<const FloorDiv>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::FloorMod:
            return std::make_shared<const FloorMod>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::FloatDiv:
            return std::make_shared<const FloatDiv>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Min:
            return std::make_shared<const Min>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Max:
            return std::make_shared<const Max>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Pow:
            return std::make_shared<const Pow>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Eq:
            return std::make_shared<const Eq>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Ne:
            return std::make_shared<const Ne>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Lt:
            return std::make_shared<const Lt>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Le:
            return std::make_shared<const Le>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Gt:
            return std::make_shared<const Gt>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Ge:
            return std::make_shared<const Ge>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::And:
            return std::make_shared<const And>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Or:
            return std::make_shared<const Or>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Xor:
            return std::make_shared<const Xor>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitAnd:
            return std::make_shared<const BitAnd>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitOr:
            return std::make_shared<const BitOr>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitXor:
            return std::make_shared<const BitXor>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitShiftLeft:
            return std::make_shared<const BitShiftLeft>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitShiftRight:
            return std::make_shared<const BitShiftRight>(std::move(left), std::move(right), dtype, span);
        default:
            break;
    }
    INTERNAL_CHECK_SPAN(false, span) << "Unknown binary expression kind in ReconstructBinaryExpr";
    // Explicit return so that no path flows off the end of a value-returning function.
    return nullptr;
}
// Builds a new unary expression node of the concrete class selected by `kind`.
//
// kind    - node kind to construct; must be one of the unary kinds listed in the switch.
// operand - operand expression, moved into the new node.
// dtype   - data type forwarded to the new node's constructor.
// span    - source span forwarded to the new node and used when reporting an unknown kind.
//
// Returns the new node. For an unhandled kind the internal check is raised; if that check
// ever returns control, nullptr is returned instead of flowing off the end of the function.
ExprPtr ReconstructUnaryExpr(ObjectKind kind, ExprPtr operand, DataType dtype, const Span& span)
{
    switch (kind) {
        case ObjectKind::Abs:
            return std::make_shared<const Abs>(std::move(operand), dtype, span);
        case ObjectKind::Neg:
            return std::make_shared<const Neg>(std::move(operand), dtype, span);
        case ObjectKind::Not:
            return std::make_shared<const Not>(std::move(operand), dtype, span);
        case ObjectKind::BitNot:
            return std::make_shared<const BitNot>(std::move(operand), dtype, span);
        case ObjectKind::Cast:
            return std::make_shared<const Cast>(std::move(operand), dtype, span);
        default:
            break;
    }
    INTERNAL_CHECK_SPAN(false, span) << "Unknown unary expression kind in ReconstructUnaryExpr";
    // Explicit return so that no path flows off the end of a value-returning function.
    return nullptr;
}
// Converts `expr` to a VarPtr when it is a Var or a MemRef.
//
// expr    - expression to convert.
// span    - source span used when reporting a failed conversion.
// context - message streamed into the internal check when `expr` is neither kind.
//
// Returns the converted pointer, or nullptr after raising the internal check.
VarPtr AsVarLikeExpr(const ExprPtr& expr, const Span& span, const std::string& context)
{
    VarPtr resolved = As<Var>(expr);
    if (!resolved) {
        // A MemRef is handed back through a pointer-to-Var as well.
        if (auto memref = As<MemRef>(expr)) {
            resolved = std::static_pointer_cast<const Var>(memref);
        }
    }
    if (!resolved) {
        INTERNAL_CHECK_SPAN(false, span) << context;
    }
    return resolved;
}
// Creates a Call node from the given name, argument list, keyword arguments, type and span.
// `args` is moved into the new node; `kwargs` is passed through by reference.
ExprPtr ReconstructCallWithKwargs(
    const std::string& name, std::vector<ExprPtr> args, const std::vector<std::pair<std::string, std::any>>& kwargs,
    const TypePtr& type, const Span& span)
{
    ExprPtr rebuilt_call = std::make_shared<Call>(name, std::move(args), kwargs, type, span);
    return rebuilt_call;
}
}
// Visits every function stored in `program->functions_`.
// Returns `program` itself when no function was replaced; otherwise returns a new Program
// built from the visited functions with the original name and span.
ProgramPtr IRMutator::VisitProgram(const ProgramPtr& program)
{
    std::vector<FunctionPtr> visited_functions;
    bool any_replaced = false;
    for (const auto& entry : program->functions_) {
        const auto& original_func = entry.second;
        auto visited_func = VisitFunction(original_func);
        any_replaced = any_replaced || (visited_func != original_func);
        visited_functions.emplace_back(visited_func);
    }
    if (any_replaced) {
        return std::make_shared<const Program>(std::move(visited_functions), program->name_, program->span_);
    }
    return program;
}
// Visits the function body.
// Returns `func` itself when the body pointer is unchanged; otherwise returns a new Function
// that copies name, params, return types, span and function type and uses the new body.
FunctionPtr IRMutator::VisitFunction(const FunctionPtr& func)
{
    auto visited_body = VisitStmt(func->body_);
    const bool body_replaced = (visited_body.get() != func->body_.get());
    if (body_replaced) {
        return std::make_shared<const Function>(
            func->name_, func->params_, func->returnTypes_, std::move(visited_body), func->span_, func->funcType_);
    }
    return func;
}
// Forwards `expr` to ExprFunctor<ExprPtr>::VisitExpr and returns its result.
ExprPtr IRMutator::VisitExpr(const ExprPtr& expr)
{
    return ExprFunctor<ExprPtr>::VisitExpr(expr);
}
// Forwards `stmt` to StmtFunctor<StmtPtr>::VisitStmt and returns its result.
StmtPtr IRMutator::VisitStmt(const StmtPtr& stmt)
{
    return StmtFunctor<StmtPtr>::VisitStmt(stmt);
}
// Looks the variable up in var_remap_ (keyed by raw pointer).
// Returns the mapped replacement when an entry exists, otherwise the variable itself.
ExprPtr IRMutator::VisitExpr_(const VarPtr& op)
{
    const auto remap_entry = var_remap_.find(op.get());
    if (remap_entry == var_remap_.end()) {
        return op;
    }
    return remap_entry->second;
}
// Base behaviour for MemRef nodes: the node is returned unchanged.
ExprPtr IRMutator::VisitExpr_(const MemRefPtr& op)
{
    return op;
}
// Base behaviour for ConstInt nodes: the node is returned unchanged.
ExprPtr IRMutator::VisitExpr_(const ConstIntPtr& op)
{
    return op;
}
// Base behaviour for ConstFloat nodes: the node is returned unchanged.
ExprPtr IRMutator::VisitExpr_(const ConstFloatPtr& op)
{
    return op;
}
// Base behaviour for ConstBool nodes: the node is returned unchanged.
ExprPtr IRMutator::VisitExpr_(const ConstBoolPtr& op)
{
    return op;
}
// Visits every argument of the call in order.
// Returns `op` itself when no argument pointer changed; otherwise returns a new Call that
// keeps the original name, kwargs, type and span and uses the visited arguments.
ExprPtr IRMutator::VisitExpr_(const CallPtr& op)
{
    const size_t arg_count = op->args_.size();
    std::vector<ExprPtr> new_args;
    new_args.reserve(arg_count);
    bool any_arg_replaced = false;
    for (size_t i = 0; i < arg_count; ++i) {
        INTERNAL_CHECK_SPAN(op->args_[i], op->span_) << "Call has null argument at index " << i;
        auto new_arg = ExprFunctor<ExprPtr>::VisitExpr(op->args_[i]);
        INTERNAL_CHECK_SPAN(new_arg, op->span_) << "Call argument at index " << i << " mutated to null";
        any_arg_replaced = any_arg_replaced || (new_arg.get() != op->args_[i].get());
        new_args.push_back(new_arg);
    }
    if (!any_arg_replaced) {
        return op;
    }
    return ReconstructCallWithKwargs(op->name_, std::move(new_args), op->kwargs_, op->GetType(), op->span_);
}
// Visits every tuple element in order.
// Returns `op` itself when no element pointer changed; otherwise returns a new MakeTuple
// built from the visited elements with the original span.
ExprPtr IRMutator::VisitExpr_(const MakeTuplePtr& op)
{
    std::vector<ExprPtr> visited_elements;
    visited_elements.reserve(op->elements_.size());
    bool any_element_replaced = false;
    for (const auto& elem : op->elements_) {
        INTERNAL_CHECK_SPAN(elem, op->span_) << "MakeTuple has null element";
        auto new_elem = ExprFunctor<ExprPtr>::VisitExpr(elem);
        INTERNAL_CHECK_SPAN(new_elem, op->span_) << "MakeTuple element mutated to null";
        any_element_replaced = any_element_replaced || (new_elem.get() != elem.get());
        visited_elements.push_back(new_elem);
    }
    if (!any_element_replaced) {
        return op;
    }
    return std::make_shared<const MakeTuple>(std::move(visited_elements), op->span_);
}
// Visits the indexed value and then the slice expression.
// Returns `op` itself when both pointers are unchanged; otherwise returns a new GetItemExpr
// built from the visited children with the original span.
ExprPtr IRMutator::VisitExpr_(const GetItemExprPtr& op)
{
    INTERNAL_CHECK_SPAN(op->value_, op->span_) << "GetItemExpr has null value";
    INTERNAL_CHECK_SPAN(op->slice_, op->span_) << "GetItemExpr has null slice";

    auto new_value = ExprFunctor<ExprPtr>::VisitExpr(op->value_);
    auto new_slice = ExprFunctor<ExprPtr>::VisitExpr(op->slice_);
    INTERNAL_CHECK_SPAN(new_value, op->span_) << "GetItemExpr value mutated to null";
    INTERNAL_CHECK_SPAN(new_slice, op->span_) << "GetItemExpr slice mutated to null";

    const bool value_same = (new_value.get() == op->value_.get());
    const bool slice_same = (new_slice.get() == op->slice_.get());
    if (value_same && slice_same) {
        return op;
    }
    return std::make_shared<const GetItemExpr>(std::move(new_value), std::move(new_slice), op->span_);
}
// Base behaviour for ScalarExpr nodes: the node is returned unchanged.
ExprPtr IRMutator::VisitExpr_(const ScalarExprPtr& op)
{
    return op;
}
// Shared handler for all binary expression kinds.
// Visits the left operand, then the right operand. Returns `op` itself when neither pointer
// changed; otherwise rebuilds a node of the same kind with the visited operands, the dtype
// taken from the node's ScalarType, and the original span.
ExprPtr IRMutator::VisitBinaryExpr_(const BinaryExprPtr& op)
{
    INTERNAL_CHECK_SPAN(op->left_, op->span_) << "BinaryExpr has null left operand";
    INTERNAL_CHECK_SPAN(op->right_, op->span_) << "BinaryExpr has null right operand";

    auto new_left = ExprFunctor<ExprPtr>::VisitExpr(op->left_);
    auto new_right = ExprFunctor<ExprPtr>::VisitExpr(op->right_);
    INTERNAL_CHECK_SPAN(new_left, op->span_) << "BinaryExpr left operand mutated to null";
    INTERNAL_CHECK_SPAN(new_right, op->span_) << "BinaryExpr right operand mutated to null";

    const bool left_same = (new_left.get() == op->left_.get());
    const bool right_same = (new_right.get() == op->right_.get());
    if (left_same && right_same) {
        return op;
    }

    // The dtype for the rebuilt node is read from the original node's scalar type.
    auto scalar_type = As<ScalarType>(op->GetType());
    INTERNAL_CHECK_SPAN(scalar_type, op->span_) << "BinaryExpr has null type";
    return ReconstructBinaryExpr(
        op->GetKind(), std::move(new_left), std::move(new_right), scalar_type->dtype_, op->span_);
}
// Shared handler for all unary expression kinds.
// Visits the operand. Returns `op` itself when the operand pointer is unchanged; otherwise
// rebuilds a node of the same kind with the visited operand, the dtype taken from the node's
// ScalarType, and the original span.
ExprPtr IRMutator::VisitUnaryExpr_(const UnaryExprPtr& op)
{
    INTERNAL_CHECK_SPAN(op->operand_, op->span_) << "UnaryExpr has null operand";
    auto new_operand = ExprFunctor<ExprPtr>::VisitExpr(op->operand_);
    INTERNAL_CHECK_SPAN(new_operand, op->span_) << "UnaryExpr operand mutated to null";

    if (new_operand.get() == op->operand_.get()) {
        return op;
    }

    // The dtype for the rebuilt node is read from the original node's scalar type.
    auto scalar_type = As<ScalarType>(op->GetType());
    INTERNAL_CHECK_SPAN(scalar_type, op->span_) << "UnaryExpr has null type";
    return ReconstructUnaryExpr(op->GetKind(), std::move(new_operand), scalar_type->dtype_, op->span_);
}
// Generates, for each listed binary node type, a VisitExpr_ overload that forwards to
// VisitBinaryExpr_. The helper macro is local to this list and undefined right after it.
#define FORWARD_BINARY_VISIT(NodeType) \
    ExprPtr IRMutator::VisitExpr_(const NodeType##Ptr& op) { return VisitBinaryExpr_(op); }
FORWARD_BINARY_VISIT(Add)
FORWARD_BINARY_VISIT(Sub)
FORWARD_BINARY_VISIT(Mul)
FORWARD_BINARY_VISIT(FloorDiv)
FORWARD_BINARY_VISIT(FloorMod)
FORWARD_BINARY_VISIT(FloatDiv)
FORWARD_BINARY_VISIT(Min)
FORWARD_BINARY_VISIT(Max)
FORWARD_BINARY_VISIT(Pow)
FORWARD_BINARY_VISIT(Eq)
FORWARD_BINARY_VISIT(Ne)
FORWARD_BINARY_VISIT(Lt)
FORWARD_BINARY_VISIT(Le)
FORWARD_BINARY_VISIT(Gt)
FORWARD_BINARY_VISIT(Ge)
FORWARD_BINARY_VISIT(And)
FORWARD_BINARY_VISIT(Or)
FORWARD_BINARY_VISIT(Xor)
FORWARD_BINARY_VISIT(BitAnd)
FORWARD_BINARY_VISIT(BitOr)
FORWARD_BINARY_VISIT(BitXor)
FORWARD_BINARY_VISIT(BitShiftLeft)
FORWARD_BINARY_VISIT(BitShiftRight)
#undef FORWARD_BINARY_VISIT
// Generates, for each listed unary node type, a VisitExpr_ overload that forwards to
// VisitUnaryExpr_. The helper macro is local to this list and undefined right after it.
#define FORWARD_UNARY_VISIT(NodeType) \
    ExprPtr IRMutator::VisitExpr_(const NodeType##Ptr& op) { return VisitUnaryExpr_(op); }
FORWARD_UNARY_VISIT(Abs)
FORWARD_UNARY_VISIT(Neg)
FORWARD_UNARY_VISIT(Not)
FORWARD_UNARY_VISIT(BitNot)
FORWARD_UNARY_VISIT(Cast)
#undef FORWARD_UNARY_VISIT
// Visits the assigned variable and then the value expression.
// The visited variable must still be a Var or MemRef. Returns `op` itself when both pointers
// are unchanged; otherwise returns a new AssignStmt with the original span.
StmtPtr IRMutator::VisitStmt_(const AssignStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->var_, op->span_) << "AssignStmt has null var";
    INTERNAL_CHECK_SPAN(op->value_, op->span_) << "AssignStmt has null value";

    auto new_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->var_);
    auto new_value = ExprFunctor<ExprPtr>::VisitExpr(op->value_);
    INTERNAL_CHECK_SPAN(new_var_expr, op->span_) << "AssignStmt var mutated to null";
    INTERNAL_CHECK_SPAN(new_value, op->span_) << "AssignStmt value mutated to null";

    auto new_var = AsVarLikeExpr(new_var_expr, op->span_, "AssignStmt var is not a Var after mutation");

    const bool var_same = (new_var.get() == op->var_.get());
    const bool value_same = (new_value.get() == op->value_.get());
    if (var_same && value_same) {
        return op;
    }
    return std::make_shared<const AssignStmt>(std::move(new_var), std::move(new_value), op->span_);
}
// Visits, in order: the condition, the then-body, the optional else-body, and each return
// variable. Returns `op` itself when none of them changed; otherwise returns a new IfStmt
// with the original span (passing std::nullopt when there is no else-body).
StmtPtr IRMutator::VisitStmt_(const IfStmtPtr& op)
{
    // Condition.
    INTERNAL_CHECK_SPAN(op->condition_, op->span_) << "IfStmt has null condition";
    auto new_condition = ExprFunctor<ExprPtr>::VisitExpr(op->condition_);
    INTERNAL_CHECK_SPAN(new_condition, op->span_) << "IfStmt condition mutated to null";
    const bool condition_changed = (new_condition.get() != op->condition_.get());

    // Then-branch.
    INTERNAL_CHECK_SPAN(op->thenBody_, op->span_) << "IfStmt has null then_body";
    auto new_then_body = StmtFunctor<StmtPtr>::VisitStmt(op->thenBody_);
    INTERNAL_CHECK_SPAN(new_then_body, op->span_) << "IfStmt then_body mutated to null";
    const bool then_changed = (new_then_body.get() != op->thenBody_.get());

    // Else-branch, visited only when the original statement has one.
    std::optional<StmtPtr> new_else_body;
    bool else_changed = false;
    if (op->elseBody_.has_value()) {
        INTERNAL_CHECK_SPAN(*op->elseBody_, op->span_) << "IfStmt has null else_body";
        auto new_stmt = StmtFunctor<StmtPtr>::VisitStmt(*op->elseBody_);
        INTERNAL_CHECK_SPAN(new_stmt, op->span_) << "IfStmt else_body mutated to null";
        else_changed = (new_stmt.get() != op->elseBody_->get());
        new_else_body = new_stmt;
    }

    // Return variables; each visited result must still be a Var or MemRef.
    const size_t return_var_count = op->returnVars_.size();
    std::vector<VarPtr> new_return_vars;
    new_return_vars.reserve(return_var_count);
    bool return_vars_changed = false;
    for (size_t i = 0; i < return_var_count; ++i) {
        INTERNAL_CHECK_SPAN(op->returnVars_[i], op->span_) << "IfStmt has null return_vars at index " << i;
        auto new_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->returnVars_[i]);
        INTERNAL_CHECK_SPAN(new_var_expr, op->span_) << "IfStmt return_vars at index " << i << " mutated to null";
        auto new_var = AsVarLikeExpr(
            new_var_expr, op->span_,
            "IfStmt return_vars at index " + std::to_string(i) + " is not a Var after mutation");
        return_vars_changed = return_vars_changed || (new_var.get() != op->returnVars_[i].get());
        new_return_vars.push_back(new_var);
    }

    if (!condition_changed && !then_changed && !else_changed && !return_vars_changed) {
        return op;
    }

    if (!new_else_body.has_value()) {
        return std::make_shared<const IfStmt>(
            std::move(new_condition), std::move(new_then_body), std::nullopt, std::move(new_return_vars),
            op->span_);
    }
    return std::make_shared<const IfStmt>(
        std::move(new_condition), std::move(new_then_body), *new_else_body, std::move(new_return_vars),
        op->span_);
}
// Visits every yielded expression in order.
// Returns `op` itself when no expression pointer changed; otherwise returns a new YieldStmt
// built from the visited expressions with the original span.
StmtPtr IRMutator::VisitStmt_(const YieldStmtPtr& op)
{
    const size_t value_count = op->value_.size();
    std::vector<ExprPtr> visited_values;
    visited_values.reserve(value_count);
    bool any_value_replaced = false;
    for (size_t i = 0; i < value_count; ++i) {
        INTERNAL_CHECK_SPAN(op->value_[i], op->span_) << "YieldStmt has null value at index " << i;
        auto new_expr = ExprFunctor<ExprPtr>::VisitExpr(op->value_[i]);
        INTERNAL_CHECK_SPAN(new_expr, op->span_) << "YieldStmt value at index " << i << " mutated to null";
        any_value_replaced = any_value_replaced || (new_expr.get() != op->value_[i].get());
        visited_values.push_back(new_expr);
    }
    if (!any_value_replaced) {
        return op;
    }
    return std::make_shared<const YieldStmt>(std::move(visited_values), op->span_);
}
// Visits every returned expression in order.
// Returns `op` itself when no expression pointer changed; otherwise returns a new ReturnStmt
// built from the visited expressions with the original span.
StmtPtr IRMutator::VisitStmt_(const ReturnStmtPtr& op)
{
    const size_t value_count = op->value_.size();
    std::vector<ExprPtr> visited_values;
    visited_values.reserve(value_count);
    bool any_value_replaced = false;
    for (size_t i = 0; i < value_count; ++i) {
        INTERNAL_CHECK_SPAN(op->value_[i], op->span_) << "ReturnStmt has null value at index " << i;
        auto new_expr = ExprFunctor<ExprPtr>::VisitExpr(op->value_[i]);
        INTERNAL_CHECK_SPAN(new_expr, op->span_) << "ReturnStmt value at index " << i << " mutated to null";
        any_value_replaced = any_value_replaced || (new_expr.get() != op->value_[i].get());
        visited_values.push_back(new_expr);
    }
    if (!any_value_replaced) {
        return op;
    }
    return std::make_shared<const ReturnStmt>(std::move(visited_values), op->span_);
}
// Visits a for-loop: loop variable, start, stop, step, the init value of every iter arg,
// the body, and finally every return variable (in that order).
// Returns `op` itself when nothing changed; otherwise returns a new ForStmt built from the
// visited parts with the original span.
StmtPtr IRMutator::VisitStmt_(const ForStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->loopVar_, op->span_) << "ForStmt has null loop_var";
    INTERNAL_CHECK_SPAN(op->start_, op->span_) << "ForStmt has null start";
    INTERNAL_CHECK_SPAN(op->stop_, op->span_) << "ForStmt has null stop";
    INTERNAL_CHECK_SPAN(op->step_, op->span_) << "ForStmt has null step";

    auto new_loop_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->loopVar_);
    INTERNAL_CHECK_SPAN(new_loop_var_expr, op->span_) << "ForStmt loop_var mutated to null";
    auto new_loop_var = AsVarLikeExpr(new_loop_var_expr, op->span_, "ForStmt loop_var is not a Var after mutation");
    auto new_start = ExprFunctor<ExprPtr>::VisitExpr(op->start_);
    INTERNAL_CHECK_SPAN(new_start, op->span_) << "ForStmt start mutated to null";
    auto new_stop = ExprFunctor<ExprPtr>::VisitExpr(op->stop_);
    INTERNAL_CHECK_SPAN(new_stop, op->span_) << "ForStmt stop mutated to null";
    auto new_step = ExprFunctor<ExprPtr>::VisitExpr(op->step_);
    INTERNAL_CHECK_SPAN(new_step, op->span_) << "ForStmt step mutated to null";

    // Iter args: only the init value is visited; a changed init value yields a new IterArg
    // that keeps the original iterVar_.
    std::vector<IterArgPtr> new_iter_args;
    bool iter_args_changed = false;
    new_iter_args.reserve(op->iterArgs_.size());
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        // Guard the IterArg pointer itself before it is dereferenced here and in the loops below.
        INTERNAL_CHECK_SPAN(op->iterArgs_[i], op->span_) << "ForStmt has null iter_args at index " << i;
        INTERNAL_CHECK_SPAN(op->iterArgs_[i]->initValue_, op->span_)
            << "ForStmt has null iter_args initValue at index " << i;
        auto new_init_value = ExprFunctor<ExprPtr>::VisitExpr(op->iterArgs_[i]->initValue_);
        INTERNAL_CHECK_SPAN(new_init_value, op->span_) << "ForStmt iter_args at index " << i << " mutated to null";
        if (new_init_value.get() != op->iterArgs_[i]->initValue_.get()) {
            new_iter_args.push_back(
                std::make_shared<const IterArg>(op->iterArgs_[i]->iterVar_, std::move(new_init_value)));
            iter_args_changed = true;
        } else {
            new_iter_args.push_back(op->iterArgs_[i]);
        }
    }

    // Register a remap entry for each rebuilt iter arg while the body is visited.
    // NOTE(review): the rebuilt IterArg reuses the original iterVar_, so key and mapped value
    // refer to the same variable here - confirm whether a fresh variable was intended.
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        if (new_iter_args[i].get() != op->iterArgs_[i].get()) {
            var_remap_[op->iterArgs_[i]->iterVar_.get()] = new_iter_args[i]->iterVar_;
        }
    }

    INTERNAL_CHECK_SPAN(op->body_, op->span_) << "ForStmt has null body";
    auto new_body = StmtFunctor<StmtPtr>::VisitStmt(op->body_);
    INTERNAL_CHECK_SPAN(new_body, op->span_) << "ForStmt body mutated to null";
    bool body_changed = (new_body.get() != op->body_.get());

    // Drop the remap entry of every iter-arg variable once the body has been visited.
    // NOTE(review): this erases the key for all iter args, including keys this call did not
    // insert - verify that no caller relies on a pre-existing entry.
    for (const auto& old_iter_arg : op->iterArgs_) {
        var_remap_.erase(old_iter_arg->iterVar_.get());
    }

    // Return variables; each visited result must still be a Var or MemRef.
    std::vector<VarPtr> new_return_vars;
    bool return_vars_changed = false;
    new_return_vars.reserve(op->returnVars_.size());
    for (size_t i = 0; i < op->returnVars_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->returnVars_[i], op->span_) << "ForStmt has null return_vars at index " << i;
        auto new_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->returnVars_[i]);
        INTERNAL_CHECK_SPAN(new_var_expr, op->span_) << "ForStmt return_vars at index " << i << " mutated to null";
        auto new_var = AsVarLikeExpr(
            new_var_expr, op->span_,
            "ForStmt return_vars at index " + std::to_string(i) + " is not a Var after mutation");
        new_return_vars.push_back(new_var);
        if (new_var.get() != op->returnVars_[i].get()) {
            return_vars_changed = true;
        }
    }

    if (new_loop_var.get() != op->loopVar_.get() || new_start.get() != op->start_.get() ||
        new_stop.get() != op->stop_.get() || new_step.get() != op->step_.get() || iter_args_changed || body_changed ||
        return_vars_changed) {
        return std::make_shared<const ForStmt>(
            std::move(new_loop_var), std::move(new_start), std::move(new_stop), std::move(new_step),
            std::move(new_iter_args), std::move(new_body), std::move(new_return_vars), op->span_);
    }
    return op;
}
// Visits a while-loop: the init value of every iter arg, the condition, the body, and
// finally every return variable (in that order).
// Returns `op` itself when nothing changed; otherwise returns a new WhileStmt built from the
// visited parts with the original span.
StmtPtr IRMutator::VisitStmt_(const WhileStmtPtr& op)
{
    // Iter args: only the init value is visited; a changed init value yields a new IterArg
    // that keeps the original iterVar_.
    std::vector<IterArgPtr> new_iter_args;
    bool iter_args_changed = false;
    new_iter_args.reserve(op->iterArgs_.size());
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        // Guard the IterArg pointer itself before it is dereferenced here and in the loops below.
        INTERNAL_CHECK_SPAN(op->iterArgs_[i], op->span_) << "WhileStmt has null iter_args at index " << i;
        INTERNAL_CHECK_SPAN(op->iterArgs_[i]->initValue_, op->span_)
            << "WhileStmt has null iter_args initValue at index " << i;
        auto new_init_value = ExprFunctor<ExprPtr>::VisitExpr(op->iterArgs_[i]->initValue_);
        INTERNAL_CHECK_SPAN(new_init_value, op->span_) << "WhileStmt iter_args at index " << i << " mutated to null";
        if (new_init_value.get() != op->iterArgs_[i]->initValue_.get()) {
            new_iter_args.push_back(
                std::make_shared<const IterArg>(op->iterArgs_[i]->iterVar_, std::move(new_init_value)));
            iter_args_changed = true;
        } else {
            new_iter_args.push_back(op->iterArgs_[i]);
        }
    }

    // Register a remap entry for each rebuilt iter arg while condition and body are visited.
    // NOTE(review): the rebuilt IterArg reuses the original iterVar_, so key and mapped value
    // refer to the same variable here - confirm whether a fresh variable was intended.
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        if (new_iter_args[i].get() != op->iterArgs_[i].get()) {
            var_remap_[op->iterArgs_[i]->iterVar_.get()] = new_iter_args[i]->iterVar_;
        }
    }

    INTERNAL_CHECK_SPAN(op->condition_, op->span_) << "WhileStmt has null condition";
    auto new_condition = ExprFunctor<ExprPtr>::VisitExpr(op->condition_);
    INTERNAL_CHECK_SPAN(new_condition, op->span_) << "WhileStmt condition mutated to null";
    bool condition_changed = (new_condition.get() != op->condition_.get());

    INTERNAL_CHECK_SPAN(op->body_, op->span_) << "WhileStmt has null body";
    auto new_body = StmtFunctor<StmtPtr>::VisitStmt(op->body_);
    INTERNAL_CHECK_SPAN(new_body, op->span_) << "WhileStmt body mutated to null";
    bool body_changed = (new_body.get() != op->body_.get());

    // Drop the remap entry of every iter-arg variable once condition and body are visited.
    // NOTE(review): this erases the key for all iter args, including keys this call did not
    // insert - verify that no caller relies on a pre-existing entry.
    for (const auto& old_iter_arg : op->iterArgs_) {
        var_remap_.erase(old_iter_arg->iterVar_.get());
    }

    // Return variables; each visited result must still be a Var or MemRef.
    std::vector<VarPtr> new_return_vars;
    bool return_vars_changed = false;
    new_return_vars.reserve(op->returnVars_.size());
    for (size_t i = 0; i < op->returnVars_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->returnVars_[i], op->span_) << "WhileStmt has null return_vars at index " << i;
        auto new_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->returnVars_[i]);
        INTERNAL_CHECK_SPAN(new_var_expr, op->span_) << "WhileStmt return_vars at index " << i << " mutated to null";
        auto new_var = AsVarLikeExpr(
            new_var_expr, op->span_,
            "WhileStmt return_vars at index " + std::to_string(i) + " is not a Var after mutation");
        new_return_vars.push_back(new_var);
        if (new_var.get() != op->returnVars_[i].get()) {
            return_vars_changed = true;
        }
    }

    if (condition_changed || iter_args_changed || body_changed || return_vars_changed) {
        return std::make_shared<const WhileStmt>(
            std::move(new_condition), std::move(new_iter_args), std::move(new_body), std::move(new_return_vars),
            op->span_);
    }
    return op;
}
// Visits every statement of the sequence in order.
// Returns `op` itself when no statement pointer changed; otherwise hands the visited
// statements and the original span to SeqStmts::Flatten and returns its result.
StmtPtr IRMutator::VisitStmt_(const SeqStmtsPtr& op)
{
    const size_t stmt_count = op->stmts_.size();
    std::vector<StmtPtr> visited_stmts;
    visited_stmts.reserve(stmt_count);
    bool any_stmt_replaced = false;
    for (size_t i = 0; i < stmt_count; ++i) {
        INTERNAL_CHECK_SPAN(op->stmts_[i], op->span_) << "SeqStmts has null statement at index " << i;
        auto new_stmt = StmtFunctor<StmtPtr>::VisitStmt(op->stmts_[i]);
        INTERNAL_CHECK_SPAN(new_stmt, op->span_) << "SeqStmts statement at index " << i << " mutated to null";
        any_stmt_replaced = any_stmt_replaced || (new_stmt.get() != op->stmts_[i].get());
        visited_stmts.push_back(new_stmt);
    }
    if (!any_stmt_replaced) {
        return op;
    }
    return SeqStmts::Flatten(std::move(visited_stmts), op->span_);
}
// Visits the section body.
// Returns `op` itself when the body pointer is unchanged; otherwise returns a new SectionStmt
// with the original section kind and span.
StmtPtr IRMutator::VisitStmt_(const SectionStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->body_, op->span_) << "SectionStmt has null body";
    auto new_body = StmtFunctor<StmtPtr>::VisitStmt(op->body_);
    INTERNAL_CHECK_SPAN(new_body, op->span_) << "SectionStmt body mutated to null";

    if (new_body.get() == op->body_.get()) {
        return op;
    }
    return std::make_shared<const SectionStmt>(op->sectionKind_, std::move(new_body), op->span_);
}
// Visits the evaluated expression.
// Returns `op` itself when the expression pointer is unchanged; otherwise returns a new
// EvalStmt with the original span.
StmtPtr IRMutator::VisitStmt_(const EvalStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->expr_, op->span_) << "EvalStmt has null expr";
    auto new_expr = ExprFunctor<ExprPtr>::VisitExpr(op->expr_);
    INTERNAL_CHECK_SPAN(new_expr, op->span_) << "EvalStmt expr mutated to null";

    if (new_expr.get() == op->expr_.get()) {
        return op;
    }
    return std::make_shared<const EvalStmt>(std::move(new_expr), op->span_);
}
// Base behaviour for BreakStmt nodes: the node is returned unchanged.
StmtPtr IRMutator::VisitStmt_(const BreakStmtPtr& op)
{
    return op;
}
// Base behaviour for ContinueStmt nodes: the node is returned unchanged.
StmtPtr IRMutator::VisitStmt_(const ContinueStmtPtr& op)
{
    return op;
}
// Visits, in order: the result variable, the result token, and every argument.
// The visited result and token must still be a Var or MemRef (checked via AsVarLikeExpr, as
// the other statement visitors in this file do). Returns `op` itself when nothing changed;
// otherwise returns a new ScalarOpStmt with the original opcode and span.
StmtPtr IRMutator::VisitStmt_(const ScalarOpStmtPtr& op)
{
    bool changed = false;

    INTERNAL_CHECK_SPAN(op->result_, op->span_) << "ScalarOpStmt has null result";
    auto new_result_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_);
    INTERNAL_CHECK_SPAN(new_result_expr, op->span_) << "ScalarOpStmt result mutated to null";
    auto new_result = AsVarLikeExpr(new_result_expr, op->span_, "ScalarOpStmt result is not a Var after mutation");
    if (new_result.get() != op->result_.get()) {
        changed = true;
    }

    INTERNAL_CHECK_SPAN(op->result_token_, op->span_) << "ScalarOpStmt has null result_token";
    auto new_token_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_token_);
    INTERNAL_CHECK_SPAN(new_token_expr, op->span_) << "ScalarOpStmt result_token mutated to null";
    auto new_token =
        AsVarLikeExpr(new_token_expr, op->span_, "ScalarOpStmt result_token is not a Var after mutation");
    if (new_token.get() != op->result_token_.get()) {
        changed = true;
    }

    std::vector<ExprPtr> new_args;
    new_args.reserve(op->args_.size());
    for (size_t i = 0; i < op->args_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->args_[i], op->span_) << "ScalarOpStmt has null arg at index " << i;
        auto new_arg = ExprFunctor<ExprPtr>::VisitExpr(op->args_[i]);
        INTERNAL_CHECK_SPAN(new_arg, op->span_) << "ScalarOpStmt arg at index " << i << " mutated to null";
        if (new_arg.get() != op->args_[i].get()) {
            changed = true;
        }
        new_args.push_back(std::move(new_arg));
    }

    if (changed) {
        // The opcode is copied, not moved: `op` is still referenced elsewhere and must stay intact.
        return std::make_shared<const ScalarOpStmt>(
            std::move(new_result), std::move(new_token), op->opcode_, std::move(new_args), op->span_);
    }
    return op;
}
// Visits, in order: every result variable, the result token, every argument, and every
// input token. Each visited result/token must still be a Var or MemRef (checked via
// AsVarLikeExpr), so a failed conversion is reported instead of storing a null pointer.
// Returns `op` itself when nothing changed; otherwise returns a new TensorOpStmt with the
// original opcode, attrs and span.
StmtPtr IRMutator::VisitStmt_(const TensorOpStmtPtr& op)
{
    bool changed = false;

    std::vector<VarPtr> new_results;
    new_results.reserve(op->result_.size());
    for (size_t i = 0; i < op->result_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->result_[i], op->span_) << "TensorOpStmt has null result at index " << i;
        auto new_result = ExprFunctor<ExprPtr>::VisitExpr(op->result_[i]);
        INTERNAL_CHECK_SPAN(new_result, op->span_) << "TensorOpStmt result at index " << i << " mutated to null";
        auto new_result_var = AsVarLikeExpr(
            new_result, op->span_,
            "TensorOpStmt result at index " + std::to_string(i) + " is not a Var after mutation");
        if (new_result_var.get() != op->result_[i].get()) {
            changed = true;
        }
        new_results.push_back(std::move(new_result_var));
    }

    INTERNAL_CHECK_SPAN(op->result_token_, op->span_) << "TensorOpStmt has null result_token";
    auto new_token_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_token_);
    INTERNAL_CHECK_SPAN(new_token_expr, op->span_) << "TensorOpStmt result_token mutated to null";
    auto new_token =
        AsVarLikeExpr(new_token_expr, op->span_, "TensorOpStmt result_token is not a Var after mutation");
    if (new_token.get() != op->result_token_.get()) {
        changed = true;
    }

    std::vector<ExprPtr> new_args;
    new_args.reserve(op->args_.size());
    for (size_t i = 0; i < op->args_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->args_[i], op->span_) << "TensorOpStmt has null arg at index " << i;
        auto new_arg = ExprFunctor<ExprPtr>::VisitExpr(op->args_[i]);
        INTERNAL_CHECK_SPAN(new_arg, op->span_) << "TensorOpStmt arg at index " << i << " mutated to null";
        if (new_arg.get() != op->args_[i].get()) {
            changed = true;
        }
        new_args.push_back(std::move(new_arg));
    }

    std::vector<VarPtr> new_tokens;
    new_tokens.reserve(op->tokens_.size());
    for (size_t i = 0; i < op->tokens_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->tokens_[i], op->span_) << "TensorOpStmt has null token at index " << i;
        auto new_tok = ExprFunctor<ExprPtr>::VisitExpr(op->tokens_[i]);
        INTERNAL_CHECK_SPAN(new_tok, op->span_) << "TensorOpStmt token at index " << i << " mutated to null";
        auto new_tok_var = AsVarLikeExpr(
            new_tok, op->span_,
            "TensorOpStmt token at index " + std::to_string(i) + " is not a Var after mutation");
        if (new_tok_var.get() != op->tokens_[i].get()) {
            changed = true;
        }
        new_tokens.push_back(std::move(new_tok_var));
    }

    if (changed) {
        // Opcode and attrs are copied, not moved: `op` is still referenced elsewhere and must
        // stay intact.
        return std::make_shared<const TensorOpStmt>(
            std::move(new_results), std::move(new_token), op->opcode_, std::move(new_args),
            std::move(new_tokens), op->attrs_, op->span_);
    }
    return op;
}
// Overload taking a plain StmtPtr: the statement is returned unchanged.
StmtPtr IRMutator::VisitStmt_(const StmtPtr& op)
{
    return op;
}
}
}