/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "ir/transforms/base/mutator.h"

#include <any>
#include <cstddef>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "core/dtype.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/kind_traits.h"
#include "ir/memref.h"
#include "ir/program.h"
#include "ir/scalar_expr.h"
#include "ir/span.h"
#include "ir/stmt.h"
#include "ir/transforms/base/functor.h"
#include "ir/type.h"

namespace pypto {
namespace ir {

namespace {

/// Reconstruct a binary expression with new children, preserving the concrete type.
/// All binary ops share the constructor signature (ExprPtr, ExprPtr, DataType, Span).
///
/// @param kind  ObjectKind of the original node; selects which concrete class is rebuilt.
/// @param left  New left operand (moved into the result).
/// @param right New right operand (moved into the result).
/// @param dtype Data type attached to the rebuilt node.
/// @param span  Source span attached to the rebuilt node.
/// @return The rebuilt node. An unlisted kind trips an internal check.
ExprPtr ReconstructBinaryExpr(ObjectKind kind, ExprPtr left, ExprPtr right, DataType dtype, const Span& span)
{
    switch (kind) {
        case ObjectKind::Add:
            return std::make_shared<const Add>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Sub:
            return std::make_shared<const Sub>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Mul:
            return std::make_shared<const Mul>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::FloorDiv:
            return std::make_shared<const FloorDiv>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::FloorMod:
            return std::make_shared<const FloorMod>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::FloatDiv:
            return std::make_shared<const FloatDiv>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Min:
            return std::make_shared<const Min>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Max:
            return std::make_shared<const Max>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Pow:
            return std::make_shared<const Pow>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Eq:
            return std::make_shared<const Eq>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Ne:
            return std::make_shared<const Ne>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Lt:
            return std::make_shared<const Lt>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Le:
            return std::make_shared<const Le>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Gt:
            return std::make_shared<const Gt>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Ge:
            return std::make_shared<const Ge>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::And:
            return std::make_shared<const And>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Or:
            return std::make_shared<const Or>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::Xor:
            return std::make_shared<const Xor>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitAnd:
            return std::make_shared<const BitAnd>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitOr:
            return std::make_shared<const BitOr>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitXor:
            return std::make_shared<const BitXor>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitShiftLeft:
            return std::make_shared<const BitShiftLeft>(std::move(left), std::move(right), dtype, span);
        case ObjectKind::BitShiftRight:
            return std::make_shared<const BitShiftRight>(std::move(left), std::move(right), dtype, span);
        default:
            INTERNAL_CHECK_SPAN(false, span) << "Unknown binary expression kind in ReconstructBinaryExpr";
    }
    // Reached only if the check above does not transfer control; without this return the function
    // would flow off the end of a value-returning function (undefined behavior).
    return nullptr;
}

/// Reconstruct a unary expression with a new operand, preserving the concrete type.
/// All unary ops share the constructor signature (ExprPtr, DataType, Span).
///
/// @param kind    ObjectKind of the original node; selects which concrete class is rebuilt.
/// @param operand New operand (moved into the result).
/// @param dtype   Data type attached to the rebuilt node.
/// @param span    Source span attached to the rebuilt node.
/// @return The rebuilt node. An unlisted kind trips an internal check.
ExprPtr ReconstructUnaryExpr(ObjectKind kind, ExprPtr operand, DataType dtype, const Span& span)
{
    switch (kind) {
        case ObjectKind::Abs:
            return std::make_shared<const Abs>(std::move(operand), dtype, span);
        case ObjectKind::Neg:
            return std::make_shared<const Neg>(std::move(operand), dtype, span);
        case ObjectKind::Not:
            return std::make_shared<const Not>(std::move(operand), dtype, span);
        case ObjectKind::BitNot:
            return std::make_shared<const BitNot>(std::move(operand), dtype, span);
        case ObjectKind::Cast:
            return std::make_shared<const Cast>(std::move(operand), dtype, span);
        default:
            INTERNAL_CHECK_SPAN(false, span) << "Unknown unary expression kind in ReconstructUnaryExpr";
    }
    // Reached only if the check above does not transfer control; without this return the function
    // would flow off the end of a value-returning function (undefined behavior).
    return nullptr;
}

/// Narrow a (possibly mutated) expression back to a VarPtr.
///
/// A plain Var is returned as-is. A MemRef, which is a Var subclass, is up-cast to VarPtr so it is not
/// rejected. Anything else trips an internal check whose message is @p context.
///
/// @param expr    Expression to narrow.
/// @param span    Span reported if the check fails.
/// @param context Diagnostic text reported if @p expr is neither a Var nor a MemRef.
/// @return The expression as a VarPtr, or nullptr if the check does not abort.
VarPtr AsVarLikeExpr(const ExprPtr& expr, const Span& span, const std::string& context)
{
    auto plain_var = As<Var>(expr);
    if (plain_var) {
        return plain_var;
    }

    auto memref_var = As<MemRef>(expr);
    if (memref_var) {
        return std::static_pointer_cast<const Var>(memref_var);
    }

    INTERNAL_CHECK_SPAN(false, span) << context;
    return nullptr;
}

/// Rebuild a Call with new arguments while preserving its kwargs_.
///
/// The node is created as `const Call`, like every other node rebuilt in this file, so the freshly
/// built IR is immutable from the start.
///
/// @param name   Callee name copied from the original Call.
/// @param args   New argument list (moved into the result).
/// @param kwargs Keyword arguments copied from the original Call.
/// @param type   Result type copied from the original Call.
/// @param span   Source span copied from the original Call.
/// @return The rebuilt Call.
ExprPtr ReconstructCallWithKwargs(
    const std::string& name, std::vector<ExprPtr> args, const std::vector<std::pair<std::string, std::any>>& kwargs,
    const TypePtr& type, const Span& span)
{
    return std::make_shared<const Call>(name, std::move(args), kwargs, type, span);
}

} // namespace

// Top-level entry points
/// Mutate every function of @p program.
///
/// Returns @p program itself when no VisitFunction call produced a different pointer. Otherwise a new
/// Program is built from the visited functions, reusing the original name and span.
ProgramPtr IRMutator::VisitProgram(const ProgramPtr& program)
{
    bool any_function_changed = false;
    std::vector<FunctionPtr> visited_functions;

    for (const auto& entry : program->functions_) {
        const auto& original_function = entry.second;
        auto visited = VisitFunction(original_function);
        any_function_changed = any_function_changed || (visited != original_function);
        visited_functions.emplace_back(std::move(visited));
    }

    if (any_function_changed) {
        return std::make_shared<const Program>(std::move(visited_functions), program->name_, program->span_);
    }
    return program;
}

/// Mutate the body of @p func.
///
/// Returns @p func unchanged when the visited body is the same object. Otherwise a new Function is
/// built that copies every other field (name, params, return types, span, function type).
FunctionPtr IRMutator::VisitFunction(const FunctionPtr& func)
{
    auto visited_body = VisitStmt(func->body_);
    const bool body_changed = visited_body.get() != func->body_.get();
    if (!body_changed) {
        return func;
    }
    return std::make_shared<const Function>(
        func->name_, func->params_, func->returnTypes_, std::move(visited_body), func->span_, func->funcType_);
}

/// Expression entry point: forwards to the ExprFunctor base implementation.
ExprPtr IRMutator::VisitExpr(const ExprPtr& expr)
{
    return ExprFunctor<ExprPtr>::VisitExpr(expr);
}

/// Statement entry point: forwards to the StmtFunctor base implementation.
StmtPtr IRMutator::VisitStmt(const StmtPtr& stmt)
{
    return StmtFunctor<StmtPtr>::VisitStmt(stmt);
}

/// Substitute a variable reference.
///
/// If var_remap_ holds an entry keyed by this Var's address, that replacement is returned; otherwise
/// the original Var is kept.
ExprPtr IRMutator::VisitExpr_(const VarPtr& op)
{
    if (auto replacement = var_remap_.find(op.get()); replacement != var_remap_.end()) {
        return replacement->second;
    }
    return op;
}

/// MemRef references are returned unchanged by the base mutator.
ExprPtr IRMutator::VisitExpr_(const MemRefPtr& op)
{
    return op;
}

/// Integer constants are returned unchanged by the base mutator.
ExprPtr IRMutator::VisitExpr_(const ConstIntPtr& op)
{
    return op;
}

/// Floating-point constants are returned unchanged by the base mutator.
ExprPtr IRMutator::VisitExpr_(const ConstFloatPtr& op)
{
    return op;
}

/// Boolean constants are returned unchanged by the base mutator.
ExprPtr IRMutator::VisitExpr_(const ConstBoolPtr& op)
{
    return op;
}

/// Mutate each argument of a Call, in order.
///
/// If any argument comes back as a different object, the Call is rebuilt with the same name, kwargs,
/// type and span; otherwise @p op is returned. Null arguments, before or after mutation, trip internal
/// checks.
ExprPtr IRMutator::VisitExpr_(const CallPtr& op)
{
    const auto& original_args = op->args_;
    std::vector<ExprPtr> visited_args;
    visited_args.reserve(original_args.size());
    bool any_arg_changed = false;

    for (size_t idx = 0; idx < original_args.size(); ++idx) {
        const auto& original = original_args[idx];
        INTERNAL_CHECK_SPAN(original, op->span_) << "Call has null argument at index " << idx;
        auto visited = ExprFunctor<ExprPtr>::VisitExpr(original);
        INTERNAL_CHECK_SPAN(visited, op->span_) << "Call argument at index " << idx << " mutated to null";
        any_arg_changed = any_arg_changed || visited.get() != original.get();
        visited_args.push_back(std::move(visited));
    }

    if (!any_arg_changed) {
        return op;
    }
    return ReconstructCallWithKwargs(op->name_, std::move(visited_args), op->kwargs_, op->GetType(), op->span_);
}

/// Mutate each element of a MakeTuple, in order.
///
/// The tuple is rebuilt (same span) only when at least one element comes back as a different object.
ExprPtr IRMutator::VisitExpr_(const MakeTuplePtr& op)
{
    std::vector<ExprPtr> visited_elements;
    visited_elements.reserve(op->elements_.size());
    bool any_element_changed = false;

    for (const auto& original : op->elements_) {
        INTERNAL_CHECK_SPAN(original, op->span_) << "MakeTuple has null element";
        auto visited = ExprFunctor<ExprPtr>::VisitExpr(original);
        INTERNAL_CHECK_SPAN(visited, op->span_) << "MakeTuple element mutated to null";
        any_element_changed = any_element_changed || visited.get() != original.get();
        visited_elements.push_back(std::move(visited));
    }

    if (!any_element_changed) {
        return op;
    }
    return std::make_shared<const MakeTuple>(std::move(visited_elements), op->span_);
}

/// Mutate the value and then the slice of a GetItemExpr.
///
/// The node is rebuilt (same span) only if either child comes back as a different object.
ExprPtr IRMutator::VisitExpr_(const GetItemExprPtr& op)
{
    INTERNAL_CHECK_SPAN(op->value_, op->span_) << "GetItemExpr has null value";
    INTERNAL_CHECK_SPAN(op->slice_, op->span_) << "GetItemExpr has null slice";
    auto visited_value = ExprFunctor<ExprPtr>::VisitExpr(op->value_);
    auto visited_slice = ExprFunctor<ExprPtr>::VisitExpr(op->slice_);
    INTERNAL_CHECK_SPAN(visited_value, op->span_) << "GetItemExpr value mutated to null";
    INTERNAL_CHECK_SPAN(visited_slice, op->span_) << "GetItemExpr slice mutated to null";

    const bool value_same = visited_value.get() == op->value_.get();
    const bool slice_same = visited_slice.get() == op->slice_.get();
    if (value_same && slice_same) {
        return op;
    }
    return std::make_shared<const GetItemExpr>(std::move(visited_value), std::move(visited_slice), op->span_);
}

/// ScalarExpr nodes reaching this overload are returned unchanged by the base mutator.
ExprPtr IRMutator::VisitExpr_(const ScalarExprPtr& op)
{
    return op;
}

/// Shared mutation logic for every binary expression kind.
///
/// Visits the left operand, then the right operand. If either comes back as a different object, a node
/// of the same concrete kind is rebuilt with the original dtype and span; otherwise @p op is returned.
ExprPtr IRMutator::VisitBinaryExpr_(const BinaryExprPtr& op)
{
    INTERNAL_CHECK_SPAN(op->left_, op->span_) << "BinaryExpr has null left operand";
    INTERNAL_CHECK_SPAN(op->right_, op->span_) << "BinaryExpr has null right operand";
    auto new_left = ExprFunctor<ExprPtr>::VisitExpr(op->left_);
    auto new_right = ExprFunctor<ExprPtr>::VisitExpr(op->right_);
    INTERNAL_CHECK_SPAN(new_left, op->span_) << "BinaryExpr left operand mutated to null";
    INTERNAL_CHECK_SPAN(new_right, op->span_) << "BinaryExpr right operand mutated to null";

    if (new_left.get() == op->left_.get() && new_right.get() == op->right_.get()) {
        return op;
    }

    // As<> yields null both when the type is missing and when it is not a ScalarType, so the
    // diagnostic names both possibilities.
    auto scalar_type = As<ScalarType>(op->GetType());
    INTERNAL_CHECK_SPAN(scalar_type, op->span_) << "BinaryExpr type is null or not a ScalarType";
    return ReconstructBinaryExpr(
        op->GetKind(), std::move(new_left), std::move(new_right), scalar_type->dtype_, op->span_);
}

/// Shared mutation logic for every unary expression kind.
///
/// Visits the operand. If it comes back as a different object, a node of the same concrete kind is
/// rebuilt with the original dtype and span; otherwise @p op is returned.
ExprPtr IRMutator::VisitUnaryExpr_(const UnaryExprPtr& op)
{
    INTERNAL_CHECK_SPAN(op->operand_, op->span_) << "UnaryExpr has null operand";
    auto new_operand = ExprFunctor<ExprPtr>::VisitExpr(op->operand_);
    INTERNAL_CHECK_SPAN(new_operand, op->span_) << "UnaryExpr operand mutated to null";

    if (new_operand.get() == op->operand_.get()) {
        return op;
    }

    // As<> yields null both when the type is missing and when it is not a ScalarType, so the
    // diagnostic names both possibilities.
    auto scalar_type = As<ScalarType>(op->GetType());
    INTERNAL_CHECK_SPAN(scalar_type, op->span_) << "UnaryExpr type is null or not a ScalarType";
    return ReconstructUnaryExpr(op->GetKind(), std::move(new_operand), scalar_type->dtype_, op->span_);
}

// Every binary node kind forwards to the shared VisitBinaryExpr_ implementation. The kinds are listed
// once in an X-macro and expanded through DEFINE_BINARY_MUTATOR; both macros are local to this block.
#define PYPTO_IR_MUTATOR_BINARY_OPS(X) \
    X(Add)                             \
    X(Sub)                             \
    X(Mul)                             \
    X(FloorDiv)                        \
    X(FloorMod)                        \
    X(FloatDiv)                        \
    X(Min)                             \
    X(Max)                             \
    X(Pow)                             \
    X(Eq)                              \
    X(Ne)                              \
    X(Lt)                              \
    X(Le)                              \
    X(Gt)                              \
    X(Ge)                              \
    X(And)                             \
    X(Or)                              \
    X(Xor)                             \
    X(BitAnd)                          \
    X(BitOr)                           \
    X(BitXor)                          \
    X(BitShiftLeft)                    \
    X(BitShiftRight)

#define DEFINE_BINARY_MUTATOR(NodeType) \
    ExprPtr IRMutator::VisitExpr_(const NodeType##Ptr& op) { return VisitBinaryExpr_(op); }

PYPTO_IR_MUTATOR_BINARY_OPS(DEFINE_BINARY_MUTATOR)

#undef DEFINE_BINARY_MUTATOR
#undef PYPTO_IR_MUTATOR_BINARY_OPS

// Every unary node kind forwards to the shared VisitUnaryExpr_ implementation. The kinds are listed
// once in an X-macro and expanded through DEFINE_UNARY_MUTATOR; both macros are local to this block.
#define PYPTO_IR_MUTATOR_UNARY_OPS(X) \
    X(Abs)                            \
    X(Neg)                            \
    X(Not)                            \
    X(BitNot)                         \
    X(Cast)

#define DEFINE_UNARY_MUTATOR(NodeType) \
    ExprPtr IRMutator::VisitExpr_(const NodeType##Ptr& op) { return VisitUnaryExpr_(op); }

PYPTO_IR_MUTATOR_UNARY_OPS(DEFINE_UNARY_MUTATOR)

#undef DEFINE_UNARY_MUTATOR
#undef PYPTO_IR_MUTATOR_UNARY_OPS

/// Mutate an AssignStmt: the target variable first, then the assigned value.
///
/// The target must still be Var-like (Var or MemRef) after mutation. The statement is rebuilt only if
/// the target or the value comes back as a different object.
StmtPtr IRMutator::VisitStmt_(const AssignStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->var_, op->span_) << "AssignStmt has null var";
    INTERNAL_CHECK_SPAN(op->value_, op->span_) << "AssignStmt has null value";
    auto visited_target = ExprFunctor<ExprPtr>::VisitExpr(op->var_);
    auto visited_value = ExprFunctor<ExprPtr>::VisitExpr(op->value_);
    INTERNAL_CHECK_SPAN(visited_target, op->span_) << "AssignStmt var mutated to null";
    INTERNAL_CHECK_SPAN(visited_value, op->span_) << "AssignStmt value mutated to null";

    auto target_var = AsVarLikeExpr(visited_target, op->span_, "AssignStmt var is not a Var after mutation");
    const bool target_same = target_var.get() == op->var_.get();
    const bool value_same = visited_value.get() == op->value_.get();
    if (target_same && value_same) {
        return op;
    }
    return std::make_shared<const AssignStmt>(std::move(target_var), std::move(visited_value), op->span_);
}

/// Mutate an IfStmt.
///
/// Visit order: condition, then-branch, optional else-branch, return variables. Return variables must
/// still be Var-like after mutation. The statement is rebuilt only if some part comes back as a
/// different object.
StmtPtr IRMutator::VisitStmt_(const IfStmtPtr& op)
{
    // Condition.
    INTERNAL_CHECK_SPAN(op->condition_, op->span_) << "IfStmt has null condition";
    auto visited_cond = ExprFunctor<ExprPtr>::VisitExpr(op->condition_);
    INTERNAL_CHECK_SPAN(visited_cond, op->span_) << "IfStmt condition mutated to null";
    bool modified = visited_cond.get() != op->condition_.get();

    // Then-branch.
    INTERNAL_CHECK_SPAN(op->thenBody_, op->span_) << "IfStmt has null then_body";
    auto visited_then = StmtFunctor<StmtPtr>::VisitStmt(op->thenBody_);
    INTERNAL_CHECK_SPAN(visited_then, op->span_) << "IfStmt then_body mutated to null";
    modified = modified || visited_then.get() != op->thenBody_.get();

    // Optional else-branch.
    std::optional<StmtPtr> visited_else;
    if (op->elseBody_.has_value()) {
        const auto& original_else = *op->elseBody_;
        INTERNAL_CHECK_SPAN(original_else, op->span_) << "IfStmt has null else_body";
        auto else_result = StmtFunctor<StmtPtr>::VisitStmt(original_else);
        INTERNAL_CHECK_SPAN(else_result, op->span_) << "IfStmt else_body mutated to null";
        modified = modified || else_result.get() != original_else.get();
        visited_else = else_result;
    }

    // Return variables.
    std::vector<VarPtr> visited_returns;
    visited_returns.reserve(op->returnVars_.size());
    for (size_t idx = 0; idx < op->returnVars_.size(); ++idx) {
        const auto& original_var = op->returnVars_[idx];
        INTERNAL_CHECK_SPAN(original_var, op->span_) << "IfStmt has null return_vars at index " << idx;
        auto visited_expr = ExprFunctor<ExprPtr>::VisitExpr(original_var);
        INTERNAL_CHECK_SPAN(visited_expr, op->span_) << "IfStmt return_vars at index " << idx << " mutated to null";
        auto visited_var = AsVarLikeExpr(
            visited_expr, op->span_,
            "IfStmt return_vars at index " + std::to_string(idx) + " is not a Var after mutation");
        modified = modified || visited_var.get() != original_var.get();
        visited_returns.push_back(visited_var);
    }

    if (!modified) {
        return op;
    }
    if (visited_else.has_value()) {
        return std::make_shared<const IfStmt>(
            std::move(visited_cond), std::move(visited_then), *visited_else, std::move(visited_returns), op->span_);
    }
    return std::make_shared<const IfStmt>(
        std::move(visited_cond), std::move(visited_then), std::nullopt, std::move(visited_returns), op->span_);
}

/// Mutate each yielded value of a YieldStmt, in order.
///
/// The statement is rebuilt (same span) only when at least one value comes back as a different object.
StmtPtr IRMutator::VisitStmt_(const YieldStmtPtr& op)
{
    const auto& yielded = op->value_;
    std::vector<ExprPtr> visited_values;
    visited_values.reserve(yielded.size());
    bool any_value_changed = false;

    for (size_t idx = 0; idx < yielded.size(); ++idx) {
        INTERNAL_CHECK_SPAN(yielded[idx], op->span_) << "YieldStmt has null value at index " << idx;
        auto visited = ExprFunctor<ExprPtr>::VisitExpr(yielded[idx]);
        INTERNAL_CHECK_SPAN(visited, op->span_) << "YieldStmt value at index " << idx << " mutated to null";
        any_value_changed = any_value_changed || visited.get() != yielded[idx].get();
        visited_values.push_back(std::move(visited));
    }

    if (!any_value_changed) {
        return op;
    }
    return std::make_shared<const YieldStmt>(std::move(visited_values), op->span_);
}

/// Mutate each returned value of a ReturnStmt, in order.
///
/// The statement is rebuilt (same span) only when at least one value comes back as a different object.
StmtPtr IRMutator::VisitStmt_(const ReturnStmtPtr& op)
{
    const auto& returned = op->value_;
    std::vector<ExprPtr> visited_values;
    visited_values.reserve(returned.size());
    bool any_value_changed = false;

    for (size_t idx = 0; idx < returned.size(); ++idx) {
        INTERNAL_CHECK_SPAN(returned[idx], op->span_) << "ReturnStmt has null value at index " << idx;
        auto visited = ExprFunctor<ExprPtr>::VisitExpr(returned[idx]);
        INTERNAL_CHECK_SPAN(visited, op->span_) << "ReturnStmt value at index " << idx << " mutated to null";
        any_value_changed = any_value_changed || visited.get() != returned[idx].get();
        visited_values.push_back(std::move(visited));
    }

    if (!any_value_changed) {
        return op;
    }
    return std::make_shared<const ReturnStmt>(std::move(visited_values), op->span_);
}

/// Mutate a ForStmt.
///
/// Visit order: loop variable, start, stop, step, each iter_arg's initial value, body, return variables.
/// An iter_arg whose initial value changes is rebuilt around the same iteration variable.
///
/// var_remap_ entries keyed by this loop's iteration variables are scoped to the loop:
/// - A remapping is installed only when a rebuilt IterArg carries a different iteration variable.
/// - After the body is visited, each such key is restored to whatever var_remap_ held before the loop,
///   being re-inserted or erased as needed.
/// - Restoring rather than unconditionally erasing keeps remappings installed by an enclosing scope or
///   a subclass intact.
StmtPtr IRMutator::VisitStmt_(const ForStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->loopVar_, op->span_) << "ForStmt has null loop_var";
    INTERNAL_CHECK_SPAN(op->start_, op->span_) << "ForStmt has null start";
    INTERNAL_CHECK_SPAN(op->stop_, op->span_) << "ForStmt has null stop";
    INTERNAL_CHECK_SPAN(op->step_, op->span_) << "ForStmt has null step";
    auto new_loop_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->loopVar_);
    INTERNAL_CHECK_SPAN(new_loop_var_expr, op->span_) << "ForStmt loop_var mutated to null";
    auto new_loop_var = AsVarLikeExpr(new_loop_var_expr, op->span_, "ForStmt loop_var is not a Var after mutation");

    auto new_start = ExprFunctor<ExprPtr>::VisitExpr(op->start_);
    INTERNAL_CHECK_SPAN(new_start, op->span_) << "ForStmt start mutated to null";

    auto new_stop = ExprFunctor<ExprPtr>::VisitExpr(op->stop_);
    INTERNAL_CHECK_SPAN(new_stop, op->span_) << "ForStmt stop mutated to null";

    auto new_step = ExprFunctor<ExprPtr>::VisitExpr(op->step_);
    INTERNAL_CHECK_SPAN(new_step, op->span_) << "ForStmt step mutated to null";

    std::vector<IterArgPtr> new_iter_args;
    bool iter_args_changed = false;
    new_iter_args.reserve(op->iterArgs_.size());
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        const auto& old_iter_arg = op->iterArgs_[i];
        INTERNAL_CHECK_SPAN(old_iter_arg, op->span_) << "ForStmt has null iter_args at index " << i;
        INTERNAL_CHECK_SPAN(old_iter_arg->initValue_, op->span_)
            << "ForStmt has null iter_args initValue at index " << i;
        auto new_init_value = ExprFunctor<ExprPtr>::VisitExpr(old_iter_arg->initValue_);
        INTERNAL_CHECK_SPAN(new_init_value, op->span_) << "ForStmt iter_args at index " << i << " mutated to null";
        if (new_init_value.get() != old_iter_arg->initValue_.get()) {
            // Only the initial value is replaced; the iteration variable is carried over unchanged.
            new_iter_args.push_back(
                std::make_shared<const IterArg>(old_iter_arg->iterVar_, std::move(new_init_value)));
            iter_args_changed = true;
        } else {
            new_iter_args.push_back(old_iter_arg);
        }
    }

    // Snapshot what var_remap_ currently holds for each iteration variable so it can be restored after
    // the body, instead of discarding entries this function did not install.
    using RemapTable = decltype(var_remap_);
    std::vector<std::pair<RemapTable::key_type, std::optional<RemapTable::mapped_type>>> saved_remaps;
    saved_remaps.reserve(op->iterArgs_.size());
    for (const auto& old_iter_arg : op->iterArgs_) {
        auto key = old_iter_arg->iterVar_.get();
        std::optional<RemapTable::mapped_type> previous;
        if (auto found = var_remap_.find(key); found != var_remap_.end()) {
            previous = found->second;
        }
        saved_remaps.emplace_back(key, std::move(previous));
    }

    // Redirect body references only when a rebuilt IterArg actually carries a different variable; an
    // identity mapping would be a no-op that merely overwrites an existing entry.
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        const auto& old_var = op->iterArgs_[i]->iterVar_;
        const auto& new_var = new_iter_args[i]->iterVar_;
        if (new_var.get() != old_var.get()) {
            var_remap_[old_var.get()] = new_var;
        }
    }

    INTERNAL_CHECK_SPAN(op->body_, op->span_) << "ForStmt has null body";
    auto new_body = StmtFunctor<StmtPtr>::VisitStmt(op->body_);
    INTERNAL_CHECK_SPAN(new_body, op->span_) << "ForStmt body mutated to null";
    bool body_changed = (new_body.get() != op->body_.get());

    // Restore var_remap_ for this loop's iteration variables to its pre-loop state.
    for (auto& [key, previous] : saved_remaps) {
        if (previous.has_value()) {
            var_remap_[key] = std::move(*previous);
        } else {
            var_remap_.erase(key);
        }
    }

    std::vector<VarPtr> new_return_vars;
    bool return_vars_changed = false;
    new_return_vars.reserve(op->returnVars_.size());
    for (size_t i = 0; i < op->returnVars_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->returnVars_[i], op->span_) << "ForStmt has null return_vars at index " << i;
        auto new_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->returnVars_[i]);
        INTERNAL_CHECK_SPAN(new_var_expr, op->span_) << "ForStmt return_vars at index " << i << " mutated to null";
        auto new_var = AsVarLikeExpr(
            new_var_expr, op->span_,
            "ForStmt return_vars at index " + std::to_string(i) + " is not a Var after mutation");
        new_return_vars.push_back(new_var);
        if (new_var.get() != op->returnVars_[i].get()) {
            return_vars_changed = true;
        }
    }

    if (new_loop_var.get() != op->loopVar_.get() || new_start.get() != op->start_.get() ||
        new_stop.get() != op->stop_.get() || new_step.get() != op->step_.get() || iter_args_changed || body_changed ||
        return_vars_changed) {
        return std::make_shared<const ForStmt>(
            std::move(new_loop_var), std::move(new_start), std::move(new_stop), std::move(new_step),
            std::move(new_iter_args), std::move(new_body), std::move(new_return_vars), op->span_);
    }
    return op;
}

/// Mutate a WhileStmt.
///
/// Visit order: each iter_arg's initial value (definitions) first, then the condition and the body
/// (uses), then the return variables. An iter_arg whose initial value changes is rebuilt around the
/// same iteration variable.
///
/// var_remap_ entries keyed by this loop's iteration variables are scoped to the condition and body:
/// - A remapping is installed only when a rebuilt IterArg carries a different iteration variable.
/// - Afterwards each such key is restored to whatever var_remap_ held before the loop, so remappings
///   installed by an enclosing scope or a subclass are not lost.
StmtPtr IRMutator::VisitStmt_(const WhileStmtPtr& op)
{
    std::vector<IterArgPtr> new_iter_args;
    bool iter_args_changed = false;
    new_iter_args.reserve(op->iterArgs_.size());
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        const auto& old_iter_arg = op->iterArgs_[i];
        INTERNAL_CHECK_SPAN(old_iter_arg, op->span_) << "WhileStmt has null iter_args at index " << i;
        INTERNAL_CHECK_SPAN(old_iter_arg->initValue_, op->span_)
            << "WhileStmt has null iter_args initValue at index " << i;
        auto new_init_value = ExprFunctor<ExprPtr>::VisitExpr(old_iter_arg->initValue_);
        INTERNAL_CHECK_SPAN(new_init_value, op->span_) << "WhileStmt iter_args at index " << i << " mutated to null";
        if (new_init_value.get() != old_iter_arg->initValue_.get()) {
            // Only the initial value is replaced; the iteration variable is carried over unchanged.
            new_iter_args.push_back(
                std::make_shared<const IterArg>(old_iter_arg->iterVar_, std::move(new_init_value)));
            iter_args_changed = true;
        } else {
            new_iter_args.push_back(old_iter_arg);
        }
    }

    // Snapshot what var_remap_ currently holds for each iteration variable so it can be restored after
    // the condition and body, instead of discarding entries this function did not install.
    using RemapTable = decltype(var_remap_);
    std::vector<std::pair<RemapTable::key_type, std::optional<RemapTable::mapped_type>>> saved_remaps;
    saved_remaps.reserve(op->iterArgs_.size());
    for (const auto& old_iter_arg : op->iterArgs_) {
        auto key = old_iter_arg->iterVar_.get();
        std::optional<RemapTable::mapped_type> previous;
        if (auto found = var_remap_.find(key); found != var_remap_.end()) {
            previous = found->second;
        }
        saved_remaps.emplace_back(key, std::move(previous));
    }

    // Redirect condition/body references only when a rebuilt IterArg actually carries a different
    // variable; an identity mapping would be a no-op that merely overwrites an existing entry.
    for (size_t i = 0; i < op->iterArgs_.size(); ++i) {
        const auto& old_var = op->iterArgs_[i]->iterVar_;
        const auto& new_var = new_iter_args[i]->iterVar_;
        if (new_var.get() != old_var.get()) {
            var_remap_[old_var.get()] = new_var;
        }
    }

    INTERNAL_CHECK_SPAN(op->condition_, op->span_) << "WhileStmt has null condition";
    auto new_condition = ExprFunctor<ExprPtr>::VisitExpr(op->condition_);
    INTERNAL_CHECK_SPAN(new_condition, op->span_) << "WhileStmt condition mutated to null";
    bool condition_changed = (new_condition.get() != op->condition_.get());

    INTERNAL_CHECK_SPAN(op->body_, op->span_) << "WhileStmt has null body";
    auto new_body = StmtFunctor<StmtPtr>::VisitStmt(op->body_);
    INTERNAL_CHECK_SPAN(new_body, op->span_) << "WhileStmt body mutated to null";
    bool body_changed = (new_body.get() != op->body_.get());

    // Restore var_remap_ for this loop's iteration variables to its pre-loop state.
    for (auto& [key, previous] : saved_remaps) {
        if (previous.has_value()) {
            var_remap_[key] = std::move(*previous);
        } else {
            var_remap_.erase(key);
        }
    }

    std::vector<VarPtr> new_return_vars;
    bool return_vars_changed = false;
    new_return_vars.reserve(op->returnVars_.size());
    for (size_t i = 0; i < op->returnVars_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->returnVars_[i], op->span_) << "WhileStmt has null return_vars at index " << i;
        auto new_var_expr = ExprFunctor<ExprPtr>::VisitExpr(op->returnVars_[i]);
        INTERNAL_CHECK_SPAN(new_var_expr, op->span_) << "WhileStmt return_vars at index " << i << " mutated to null";
        auto new_var = AsVarLikeExpr(
            new_var_expr, op->span_,
            "WhileStmt return_vars at index " + std::to_string(i) + " is not a Var after mutation");
        new_return_vars.push_back(new_var);
        if (new_var.get() != op->returnVars_[i].get()) {
            return_vars_changed = true;
        }
    }

    if (condition_changed || iter_args_changed || body_changed || return_vars_changed) {
        return std::make_shared<const WhileStmt>(
            std::move(new_condition), std::move(new_iter_args), std::move(new_body), std::move(new_return_vars),
            op->span_);
    }
    return op;
}

/// Mutate each statement of a SeqStmts, in order.
///
/// When any statement comes back as a different object, the result is built with SeqStmts::Flatten
/// rather than the constructor. That is presumably so that nested sequences produced by the mutation
/// are merged; confirm against stmt.h.
StmtPtr IRMutator::VisitStmt_(const SeqStmtsPtr& op)
{
    const auto& originals = op->stmts_;
    std::vector<StmtPtr> visited_stmts;
    visited_stmts.reserve(originals.size());
    bool any_stmt_changed = false;

    for (size_t idx = 0; idx < originals.size(); ++idx) {
        INTERNAL_CHECK_SPAN(originals[idx], op->span_) << "SeqStmts has null statement at index " << idx;
        auto visited = StmtFunctor<StmtPtr>::VisitStmt(originals[idx]);
        INTERNAL_CHECK_SPAN(visited, op->span_) << "SeqStmts statement at index " << idx << " mutated to null";
        any_stmt_changed = any_stmt_changed || visited.get() != originals[idx].get();
        visited_stmts.push_back(std::move(visited));
    }

    if (!any_stmt_changed) {
        return op;
    }
    return SeqStmts::Flatten(std::move(visited_stmts), op->span_);
}

/// Mutate the body of a SectionStmt.
///
/// The section is rebuilt with the same sectionKind_ and span only if the body comes back as a
/// different object.
StmtPtr IRMutator::VisitStmt_(const SectionStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->body_, op->span_) << "SectionStmt has null body";
    auto visited_body = StmtFunctor<StmtPtr>::VisitStmt(op->body_);
    INTERNAL_CHECK_SPAN(visited_body, op->span_) << "SectionStmt body mutated to null";

    if (visited_body.get() == op->body_.get()) {
        return op;
    }
    return std::make_shared<const SectionStmt>(op->sectionKind_, std::move(visited_body), op->span_);
}

/// Mutate the expression evaluated by an EvalStmt.
///
/// The statement is rebuilt (same span) only if the expression comes back as a different object.
StmtPtr IRMutator::VisitStmt_(const EvalStmtPtr& op)
{
    INTERNAL_CHECK_SPAN(op->expr_, op->span_) << "EvalStmt has null expr";
    auto visited_expr = ExprFunctor<ExprPtr>::VisitExpr(op->expr_);
    INTERNAL_CHECK_SPAN(visited_expr, op->span_) << "EvalStmt expr mutated to null";

    if (visited_expr.get() == op->expr_.get()) {
        return op;
    }
    return std::make_shared<const EvalStmt>(std::move(visited_expr), op->span_);
}

/// The base mutator returns break statements unchanged.
StmtPtr IRMutator::VisitStmt_(const BreakStmtPtr& op)
{
    return op;
}

/// The base mutator returns continue statements unchanged.
StmtPtr IRMutator::VisitStmt_(const ContinueStmtPtr& op)
{
    return op;
}

/// Mutate a ScalarOpStmt.
///
/// Visit order: result variable, result token, then each argument. The result and result token must
/// still be Var-like (Var or MemRef) after mutation. This matches how AssignStmt/For/If/While narrow
/// their variables, so a MemRef is not misreported as "mutated to null". The statement is rebuilt only
/// if some part comes back as a different object.
StmtPtr IRMutator::VisitStmt_(const ScalarOpStmtPtr& op)
{
    bool changed = false;

    INTERNAL_CHECK_SPAN(op->result_, op->span_) << "ScalarOpStmt has null result";
    auto new_result_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_);
    INTERNAL_CHECK_SPAN(new_result_expr, op->span_) << "ScalarOpStmt result mutated to null";
    auto new_result =
        AsVarLikeExpr(new_result_expr, op->span_, "ScalarOpStmt result is not a Var after mutation");
    if (new_result.get() != op->result_.get()) {
        changed = true;
    }

    INTERNAL_CHECK_SPAN(op->result_token_, op->span_) << "ScalarOpStmt has null result_token";
    auto new_token_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_token_);
    INTERNAL_CHECK_SPAN(new_token_expr, op->span_) << "ScalarOpStmt result_token mutated to null";
    auto new_token =
        AsVarLikeExpr(new_token_expr, op->span_, "ScalarOpStmt result_token is not a Var after mutation");
    if (new_token.get() != op->result_token_.get()) {
        changed = true;
    }

    std::vector<ExprPtr> new_args;
    new_args.reserve(op->args_.size());
    for (size_t i = 0; i < op->args_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->args_[i], op->span_) << "ScalarOpStmt has null arg at index " << i;
        auto new_arg = ExprFunctor<ExprPtr>::VisitExpr(op->args_[i]);
        INTERNAL_CHECK_SPAN(new_arg, op->span_) << "ScalarOpStmt arg at index " << i << " mutated to null";
        if (new_arg.get() != op->args_[i].get()) {
            changed = true;
        }
        new_args.push_back(std::move(new_arg));
    }

    if (!changed) {
        return op;
    }
    // opcode_ is copied, never moved: op is the original node and must remain intact.
    return std::make_shared<const ScalarOpStmt>(
        std::move(new_result), std::move(new_token), op->opcode_, std::move(new_args), op->span_);
}

/// Mutate a TensorOpStmt.
///
/// Visit order: each result variable, the result token, each argument, then each input token. Result
/// variables and tokens are narrowed with AsVarLikeExpr (Var or MemRef). A non-Var-like mutation
/// therefore trips an internal check instead of silently storing a null VarPtr in the rebuilt
/// statement. The statement is rebuilt only if some part comes back as a different object.
StmtPtr IRMutator::VisitStmt_(const TensorOpStmtPtr& op)
{
    bool changed = false;

    std::vector<VarPtr> new_results;
    new_results.reserve(op->result_.size());
    for (size_t i = 0; i < op->result_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->result_[i], op->span_) << "TensorOpStmt has null result at index " << i;
        auto new_result_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_[i]);
        INTERNAL_CHECK_SPAN(new_result_expr, op->span_)
            << "TensorOpStmt result at index " << i << " mutated to null";
        auto new_result = AsVarLikeExpr(
            new_result_expr, op->span_,
            "TensorOpStmt result at index " + std::to_string(i) + " is not a Var after mutation");
        if (new_result.get() != op->result_[i].get()) {
            changed = true;
        }
        new_results.push_back(std::move(new_result));
    }

    INTERNAL_CHECK_SPAN(op->result_token_, op->span_) << "TensorOpStmt has null result_token";
    auto new_token_expr = ExprFunctor<ExprPtr>::VisitExpr(op->result_token_);
    INTERNAL_CHECK_SPAN(new_token_expr, op->span_) << "TensorOpStmt result_token mutated to null";
    auto new_token =
        AsVarLikeExpr(new_token_expr, op->span_, "TensorOpStmt result_token is not a Var after mutation");
    if (new_token.get() != op->result_token_.get()) {
        changed = true;
    }

    std::vector<ExprPtr> new_args;
    new_args.reserve(op->args_.size());
    for (size_t i = 0; i < op->args_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->args_[i], op->span_) << "TensorOpStmt has null arg at index " << i;
        auto new_arg = ExprFunctor<ExprPtr>::VisitExpr(op->args_[i]);
        INTERNAL_CHECK_SPAN(new_arg, op->span_) << "TensorOpStmt arg at index " << i << " mutated to null";
        if (new_arg.get() != op->args_[i].get()) {
            changed = true;
        }
        new_args.push_back(std::move(new_arg));
    }

    std::vector<VarPtr> new_tokens;
    new_tokens.reserve(op->tokens_.size());
    for (size_t i = 0; i < op->tokens_.size(); ++i) {
        INTERNAL_CHECK_SPAN(op->tokens_[i], op->span_) << "TensorOpStmt has null token at index " << i;
        auto new_tok_expr = ExprFunctor<ExprPtr>::VisitExpr(op->tokens_[i]);
        INTERNAL_CHECK_SPAN(new_tok_expr, op->span_) << "TensorOpStmt token at index " << i << " mutated to null";
        auto new_tok = AsVarLikeExpr(
            new_tok_expr, op->span_,
            "TensorOpStmt token at index " + std::to_string(i) + " is not a Var after mutation");
        if (new_tok.get() != op->tokens_[i].get()) {
            changed = true;
        }
        new_tokens.push_back(std::move(new_tok));
    }

    if (!changed) {
        return op;
    }
    // opcode_ and attrs_ are copied, never moved: op is the original node and must remain intact.
    return std::make_shared<const TensorOpStmt>(
        std::move(new_results), std::move(new_token), op->opcode_, std::move(new_args), std::move(new_tokens),
        op->attrs_, op->span_);
}

/// Generic StmtPtr overload: the base mutator returns the statement unchanged.
StmtPtr IRMutator::VisitStmt_(const StmtPtr& op)
{
    return op;
}

} // namespace ir
} // namespace pypto