* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef PYPTO_IR_TRANSFORMS_BASE_MUTATOR_H_
#define PYPTO_IR_TRANSFORMS_BASE_MUTATOR_H_
#include <unordered_map>
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/memref.h"
#include "ir/program.h"
#include "ir/scalar_expr.h"
#include "ir/stmt.h"
#include "ir/transforms/base/functor.h"
namespace pypto {
namespace ir {
* \brief IR mutator for immutable transformations
*
* Provides default implementations that recursively transform the IR tree.
* Returns new ExprPtr or StmtPtr for transformed IR nodes, respecting immutability.
* Uses copy-on-write: if children are unchanged, returns the original shared_ptr.
*/
class IRMutator : public ExprFunctor<ExprPtr>, public StmtFunctor<StmtPtr> {
public:
~IRMutator() override = default;
virtual ProgramPtr VisitProgram(const ProgramPtr& program);
virtual FunctionPtr VisitFunction(const FunctionPtr& func);
ExprPtr VisitExpr(const ExprPtr& expr) override;
StmtPtr VisitStmt(const StmtPtr& stmt) override;
protected:
ExprPtr VisitExpr_(const VarPtr& op) override;
ExprPtr VisitExpr_(const MemRefPtr& op) override;
ExprPtr VisitExpr_(const ConstIntPtr& op) override;
ExprPtr VisitExpr_(const ConstFloatPtr& op) override;
ExprPtr VisitExpr_(const ConstBoolPtr& op) override;
ExprPtr VisitExpr_(const CallPtr& op) override;
ExprPtr VisitExpr_(const MakeTuplePtr& op) override;
ExprPtr VisitExpr_(const ScalarExprPtr& op) override;
ExprPtr VisitExpr_(const GetItemExprPtr& op) override;
ExprPtr VisitExpr_(const AddPtr& op) override;
ExprPtr VisitExpr_(const SubPtr& op) override;
ExprPtr VisitExpr_(const MulPtr& op) override;
ExprPtr VisitExpr_(const FloorDivPtr& op) override;
ExprPtr VisitExpr_(const FloorModPtr& op) override;
ExprPtr VisitExpr_(const FloatDivPtr& op) override;
ExprPtr VisitExpr_(const MinPtr& op) override;
ExprPtr VisitExpr_(const MaxPtr& op) override;
ExprPtr VisitExpr_(const PowPtr& op) override;
ExprPtr VisitExpr_(const EqPtr& op) override;
ExprPtr VisitExpr_(const NePtr& op) override;
ExprPtr VisitExpr_(const LtPtr& op) override;
ExprPtr VisitExpr_(const LePtr& op) override;
ExprPtr VisitExpr_(const GtPtr& op) override;
ExprPtr VisitExpr_(const GePtr& op) override;
ExprPtr VisitExpr_(const AndPtr& op) override;
ExprPtr VisitExpr_(const OrPtr& op) override;
ExprPtr VisitExpr_(const XorPtr& op) override;
ExprPtr VisitExpr_(const BitAndPtr& op) override;
ExprPtr VisitExpr_(const BitOrPtr& op) override;
ExprPtr VisitExpr_(const BitXorPtr& op) override;
ExprPtr VisitExpr_(const BitShiftLeftPtr& op) override;
ExprPtr VisitExpr_(const BitShiftRightPtr& op) override;
ExprPtr VisitExpr_(const AbsPtr& op) override;
ExprPtr VisitExpr_(const NegPtr& op) override;
ExprPtr VisitExpr_(const NotPtr& op) override;
ExprPtr VisitExpr_(const BitNotPtr& op) override;
ExprPtr VisitExpr_(const CastPtr& op) override;
StmtPtr VisitStmt_(const AssignStmtPtr& op) override;
StmtPtr VisitStmt_(const IfStmtPtr& op) override;
StmtPtr VisitStmt_(const YieldStmtPtr& op) override;
StmtPtr VisitStmt_(const ReturnStmtPtr& op) override;
StmtPtr VisitStmt_(const ForStmtPtr& op) override;
StmtPtr VisitStmt_(const WhileStmtPtr& op) override;
StmtPtr VisitStmt_(const SeqStmtsPtr& op) override;
StmtPtr VisitStmt_(const SectionStmtPtr& op) override;
StmtPtr VisitStmt_(const EvalStmtPtr& op) override;
StmtPtr VisitStmt_(const BreakStmtPtr& op) override;
StmtPtr VisitStmt_(const ContinueStmtPtr& op) override;
StmtPtr VisitStmt_(const ScalarOpStmtPtr& op) override;
StmtPtr VisitStmt_(const TensorOpStmtPtr& op) override;
StmtPtr VisitStmt_(const StmtPtr& op) override;
virtual ExprPtr VisitBinaryExpr_(const BinaryExprPtr& op);
virtual ExprPtr VisitUnaryExpr_(const UnaryExprPtr& op);
std::unordered_map<const Expr*, ExprPtr> var_remap_;
};
}
}
#endif