* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <memory>
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/program.h"
#include "ir/scalar_expr.h"
#include "ir/stmt.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
#define DEFINE_KIND_TRAIT(TypeName, KindValue) \
template <> \
struct KindTrait<TypeName> { \
static constexpr ObjectKind kind = KindValue; \
}
DEFINE_KIND_TRAIT(Var, ObjectKind::Var);
DEFINE_KIND_TRAIT(MemRef, ObjectKind::MemRef);
DEFINE_KIND_TRAIT(Call, ObjectKind::Call);
DEFINE_KIND_TRAIT(MakeTuple, ObjectKind::MakeTuple);
DEFINE_KIND_TRAIT(GetItemExpr, ObjectKind::GetItemExpr);
DEFINE_KIND_TRAIT(ConstInt, ObjectKind::ConstInt);
DEFINE_KIND_TRAIT(ConstFloat, ObjectKind::ConstFloat);
DEFINE_KIND_TRAIT(ConstBool, ObjectKind::ConstBool);
DEFINE_KIND_TRAIT(Add, ObjectKind::Add);
DEFINE_KIND_TRAIT(Sub, ObjectKind::Sub);
DEFINE_KIND_TRAIT(Mul, ObjectKind::Mul);
DEFINE_KIND_TRAIT(FloorDiv, ObjectKind::FloorDiv);
DEFINE_KIND_TRAIT(FloorMod, ObjectKind::FloorMod);
DEFINE_KIND_TRAIT(FloatDiv, ObjectKind::FloatDiv);
DEFINE_KIND_TRAIT(Min, ObjectKind::Min);
DEFINE_KIND_TRAIT(Max, ObjectKind::Max);
DEFINE_KIND_TRAIT(Pow, ObjectKind::Pow);
DEFINE_KIND_TRAIT(Eq, ObjectKind::Eq);
DEFINE_KIND_TRAIT(Ne, ObjectKind::Ne);
DEFINE_KIND_TRAIT(Lt, ObjectKind::Lt);
DEFINE_KIND_TRAIT(Le, ObjectKind::Le);
DEFINE_KIND_TRAIT(Gt, ObjectKind::Gt);
DEFINE_KIND_TRAIT(Ge, ObjectKind::Ge);
DEFINE_KIND_TRAIT(And, ObjectKind::And);
DEFINE_KIND_TRAIT(Or, ObjectKind::Or);
DEFINE_KIND_TRAIT(Xor, ObjectKind::Xor);
DEFINE_KIND_TRAIT(BitAnd, ObjectKind::BitAnd);
DEFINE_KIND_TRAIT(BitOr, ObjectKind::BitOr);
DEFINE_KIND_TRAIT(BitXor, ObjectKind::BitXor);
DEFINE_KIND_TRAIT(BitShiftLeft, ObjectKind::BitShiftLeft);
DEFINE_KIND_TRAIT(BitShiftRight, ObjectKind::BitShiftRight);
DEFINE_KIND_TRAIT(Abs, ObjectKind::Abs);
DEFINE_KIND_TRAIT(Neg, ObjectKind::Neg);
DEFINE_KIND_TRAIT(Not, ObjectKind::Not);
DEFINE_KIND_TRAIT(BitNot, ObjectKind::BitNot);
DEFINE_KIND_TRAIT(Cast, ObjectKind::Cast);
DEFINE_KIND_TRAIT(ScalarExpr, ObjectKind::ScalarExpr);
DEFINE_KIND_TRAIT(AssignStmt, ObjectKind::AssignStmt);
DEFINE_KIND_TRAIT(IfStmt, ObjectKind::IfStmt);
DEFINE_KIND_TRAIT(YieldStmt, ObjectKind::YieldStmt);
DEFINE_KIND_TRAIT(ReturnStmt, ObjectKind::ReturnStmt);
DEFINE_KIND_TRAIT(ForStmt, ObjectKind::ForStmt);
DEFINE_KIND_TRAIT(SeqStmts, ObjectKind::SeqStmts);
DEFINE_KIND_TRAIT(EvalStmt, ObjectKind::EvalStmt);
DEFINE_KIND_TRAIT(WhileStmt, ObjectKind::WhileStmt);
DEFINE_KIND_TRAIT(SectionStmt, ObjectKind::SectionStmt);
DEFINE_KIND_TRAIT(BreakStmt, ObjectKind::BreakStmt);
DEFINE_KIND_TRAIT(ContinueStmt, ObjectKind::ContinueStmt);
DEFINE_KIND_TRAIT(ScalarOpStmt, ObjectKind::ScalarOpStmt);
DEFINE_KIND_TRAIT(TensorOpStmt, ObjectKind::TensorOpStmt);
DEFINE_KIND_TRAIT(UnknownType, ObjectKind::UnknownType);
DEFINE_KIND_TRAIT(ScalarType, ObjectKind::ScalarType);
DEFINE_KIND_TRAIT(TensorType, ObjectKind::TensorType);
DEFINE_KIND_TRAIT(TileType, ObjectKind::TileType);
DEFINE_KIND_TRAIT(TupleType, ObjectKind::TupleType);
DEFINE_KIND_TRAIT(MemRefType, ObjectKind::MemRefType);
DEFINE_KIND_TRAIT(PtrType, ObjectKind::PtrType);
DEFINE_KIND_TRAIT(TokenType, ObjectKind::TokenType);
DEFINE_KIND_TRAIT(LogicalTensorType, ObjectKind::LogicalTensorType);
DEFINE_KIND_TRAIT(Function, ObjectKind::Function);
DEFINE_KIND_TRAIT(Program, ObjectKind::Program);
DEFINE_KIND_TRAIT(Op, ObjectKind::Op);
#undef DEFINE_KIND_TRAIT
template <>
struct KindTrait<Stmt> {
static constexpr ObjectKind kinds[] = {ObjectKind::AssignStmt, ObjectKind::IfStmt, ObjectKind::YieldStmt,
ObjectKind::ReturnStmt, ObjectKind::ForStmt, ObjectKind::WhileStmt,
ObjectKind::SeqStmts, ObjectKind::EvalStmt, ObjectKind::BreakStmt,
ObjectKind::ContinueStmt, ObjectKind::ScalarOpStmt, ObjectKind::TensorOpStmt,
ObjectKind::SectionStmt};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};
template <>
struct KindTrait<Expr> {
static constexpr ObjectKind kinds[] = {
ObjectKind::Var,
ObjectKind::MemRef,
ObjectKind::Call,
ObjectKind::MakeTuple,
ObjectKind::GetItemExpr,
ObjectKind::ConstInt,
ObjectKind::ConstFloat,
ObjectKind::ConstBool,
ObjectKind::Add,
ObjectKind::Sub,
ObjectKind::Mul,
ObjectKind::FloorDiv,
ObjectKind::FloorMod,
ObjectKind::FloatDiv,
ObjectKind::Min,
ObjectKind::Max,
ObjectKind::Pow,
ObjectKind::Eq,
ObjectKind::Ne,
ObjectKind::Lt,
ObjectKind::Le,
ObjectKind::Gt,
ObjectKind::Ge,
ObjectKind::And,
ObjectKind::Or,
ObjectKind::Xor,
ObjectKind::BitAnd,
ObjectKind::BitOr,
ObjectKind::BitXor,
ObjectKind::BitShiftLeft,
ObjectKind::BitShiftRight,
ObjectKind::Abs,
ObjectKind::Neg,
ObjectKind::Not,
ObjectKind::BitNot,
ObjectKind::Cast,
ObjectKind::ScalarExpr,
};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};
template <>
struct KindTrait<BinaryExpr> {
static constexpr ObjectKind kinds[] = {
ObjectKind::Add, ObjectKind::Sub, ObjectKind::Mul, ObjectKind::FloorDiv,
ObjectKind::FloorMod, ObjectKind::FloatDiv, ObjectKind::Min, ObjectKind::Max,
ObjectKind::Pow, ObjectKind::Eq, ObjectKind::Ne, ObjectKind::Lt,
ObjectKind::Le, ObjectKind::Gt, ObjectKind::Ge, ObjectKind::And,
ObjectKind::Or, ObjectKind::Xor, ObjectKind::BitAnd, ObjectKind::BitOr,
ObjectKind::BitXor, ObjectKind::BitShiftLeft, ObjectKind::BitShiftRight};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};
template <>
struct KindTrait<UnaryExpr> {
static constexpr ObjectKind kinds[] = {
ObjectKind::Abs, ObjectKind::Neg, ObjectKind::Not, ObjectKind::BitNot, ObjectKind::Cast};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};
template <>
struct KindTrait<Type> {
static constexpr ObjectKind kinds[] = {
ObjectKind::UnknownType, ObjectKind::ScalarType, ObjectKind::ShapedType,
ObjectKind::TensorType, ObjectKind::TileType, ObjectKind::TupleType,
ObjectKind::TokenType, ObjectKind::LogicalTensorType, ObjectKind::PtrType};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};
template <>
struct KindTrait<ShapedType> {
static constexpr ObjectKind kinds[] = {ObjectKind::ShapedType, ObjectKind::TensorType, ObjectKind::TileType};
static constexpr size_t count = sizeof(kinds) / sizeof(ObjectKind);
};
* \brief Check if an IR node is of a specific type (supports inheritance)
*
* \tparam T The target type (concrete or base class)
* \param node The IR node pointer to check
* \return true if node is of type T or inherits from T
*
* \example
* // Concrete type check
* if (IsA<Var>(expr)) {
* // expr is a Var
* }
*
* // Base class check (NEW)
* if (IsA<Stmt>(node)) { ... } // True for any statement type
* if (IsA<BinaryExpr>(expr)) { ... } // True for Add, Sub, Mul, etc.
*/
template <typename T, typename Base, typename = std::enable_if_t<std::is_base_of_v<Base, T>>>
bool IsA(const std::shared_ptr<const Base>& base)
{
if (!base)
return false;
if constexpr (detail::HasSingleKind<T>::value) {
return base->GetKind() == KindTrait<T>::kind;
} else if constexpr (detail::HasKindArray<T>::value) {
return detail::IsKindInArray<T>(base->GetKind());
}
return false;
}
* \brief Safely cast an IR node to a specific type (supports inheritance)
*
* Uses static_pointer_cast for zero runtime overhead after Kind check.
*
* \tparam T The target type (concrete or base class)
* \param node The IR node pointer to cast
* \return Shared pointer to T if cast succeeds, nullptr otherwise
*
* \example
* // Concrete cast
* if (auto var = As<Var>(expr)) {
* // Use var safely
* std::cout << var->name_;
* }
*
* // Base class cast (NEW)
* if (auto stmt = As<Stmt>(node)) { ... } // Cast any statement type
* if (auto binop = As<BinaryExpr>(expr)) { ... } // Cast any binary op
*/
template <typename T, typename Base, typename = std::enable_if_t<std::is_base_of_v<Base, T>>>
std::shared_ptr<const T> As(const std::shared_ptr<const Base>& base)
{
return IsA<T>(base) ? std::static_pointer_cast<const T>(base) : nullptr;
}
template <typename T, typename Base, typename = std::enable_if_t<std::is_base_of_v<Base, T>>>
std::shared_ptr<T> AsMut(const std::shared_ptr<const Base>& base)
{
return IsA<T>(base) ? std::static_pointer_cast<T>(std::const_pointer_cast<Base>(base)) : nullptr;
}
}
}