/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <any>
#include <atomic>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>
#include "core/any_cast.h"
#include "core/dtype.h"
#include "core/error.h"
#include "core/logging.h"
#include "ir/core.h"
#include "ir/pipe.h"
#include "ir/reflection/field_traits.h"
#include "ir/span.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
* \brief Base class for all expressions in the IR
*
* This is the root base class for all expression types (scalar, tensor, etc).
* Expressions represent computations that produce values.
* All expressions are immutable.
*/
class Expr : public IRNode {
protected:
TypePtr type_;
public:
* \brief Create an expression
*
* \param span Source location
* \param type Type of the expression result (defaults to UnknownType)
*/
explicit Expr(Span s, TypePtr type = GetUnknownType()) : IRNode(std::move(s)), type_(std::move(type)) {}
~Expr() override = default;
* \brief Get the type name of this expression
*
* \return Human-readable type name (e.g., "ScalarExpr", "Var", "Call")
*/
[[nodiscard]] std::string TypeName() const override { return "Expr"; }
* \brief Get the type of this expression
*
* \return Type pointer of the expression result
*/
[[nodiscard]] const TypePtr& GetType() const { return type_; }
static constexpr auto GetFieldDescriptors()
{
return std::tuple_cat(
IRNode::GetFieldDescriptors(), std::make_tuple(reflection::UsualField(&Expr::type_, "type")));
}
};
using ExprPtr = std::shared_ptr<const Expr>;
enum class MemorySpace;
* \brief Base class for operations/functions
*
* Represents callable operations in the IR.
* Stores the schema of allowed kwargs (key -> expected type mapping).
* Actual kwarg values are stored per-Call instance in Call::kwargs_.
*/
class Op {
public:
std::string name_;
explicit Op(std::string name) : name_(std::move(name)) {}
virtual ~Op() = default;
* \brief Register an allowed kwarg with its expected type
*
* Defines that this operator accepts a kwarg with the given key and type.
* This is used for validation when creating Call expressions.
*
* Only specific types are allowed: bool, int, std::string, double, DataType, MemorySpace,
* std::vector<int>. This is enforced at compile-time via static_assert.
*
* \tparam T Expected type of the kwarg value (must be one of the allowed types)
* \param key Kwarg key (string identifier)
*/
template <typename T>
void SetAttrType(const std::string& key) const
{
static_assert(
std::is_same_v<T, bool> || std::is_same_v<T, int> || std::is_same_v<T, std::string> ||
std::is_same_v<T, double> || std::is_same_v<T, DataType> || std::is_same_v<T, MemorySpace> ||
std::is_same_v<T, std::vector<int>>,
"SetAttrType only accepts: bool, int, std::string, double, DataType, MemorySpace, std::vector<int>");
attrs_.emplace(key, std::type_index(typeid(T)));
}
* \brief Get the expected type for a kwarg
*
* \param key Kwarg key
* \return type_index of the expected type
* \throws pypto::ir::ValueError if kwarg is not registered
*/
[[nodiscard]] std::type_index GetAttrType(const std::string& key, const Span& span = Span::Unknown()) const
{
auto it = attrs_.find(key);
IRCHECK(it != attrs_.end()) << "Attribute '" << key << "' not found in operator '" << name_ << "'"
<< " at " << span.ToString();
return it->second;
}
* \brief Check if a kwarg is registered
*
* \param key Kwarg key
* \return true if the kwarg is registered
*/
[[nodiscard]] bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); }
* \brief Get all registered kwarg keys
*
* \return Vector of all kwarg keys
*/
[[nodiscard]] std::vector<std::string> GetAttrKeys() const
{
std::vector<std::string> keys;
keys.reserve(attrs_.size());
for (const auto& pair : attrs_) {
keys.push_back(pair.first);
}
return keys;
}
* \brief Get all registered kwargs as a map
*
* \return Map of kwarg keys to expected types
*/
[[nodiscard]] const std::unordered_map<std::string, std::type_index>& GetAttrs() const { return attrs_; }
* \brief Set the pipeline type for this operator
*
* \param pipe Pipeline type (e.g., MTE2, V)
*/
void SetPipe(PipeType pipe) const { pipe_ = pipe; }
* \brief Get the pipeline type for this operator
*
* \return Optional pipeline type (nullopt if not set)
*/
[[nodiscard]] std::optional<PipeType> GetPipe() const { return pipe_; }
[[nodiscard]] virtual ObjectKind GetKind() const { return ObjectKind::Op; }
[[nodiscard]] virtual std::string TypeName() const { return "Op"; }
private:
mutable std::unordered_map<std::string, std::type_index> attrs_;
mutable std::optional<PipeType> pipe_;
};
using OpPtr = std::shared_ptr<const Op>;
* \brief Variable reference expression
*
* Represents a reference to a named variable.
* Can represent both scalar and tensor variables based on its type.
*/
class Var : public Expr {
public:
std::string name_;
* \brief Create a variable reference
*
* \param name Variable name
* \param type Type of the variable (ScalarType, TensorType, or TileType)
* Memory reference information is stored in ShapedType for Tensor/Tile types
* \param span Source location
* \return Shared pointer to const Var expression
*/
Var(std::string name, TypePtr type, Span span) : Expr(std::move(span), std::move(type)), name_(std::move(name)) {}
[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::Var; }
[[nodiscard]] std::string TypeName() const override { return "Var"; }
* \brief Get field descriptors for reflection-based visitation
*
* \return Tuple of field descriptors (name_ as IGNORE field)
*/
static constexpr auto GetFieldDescriptors()
{
return std::tuple_cat(
Expr::GetFieldDescriptors(), std::make_tuple(reflection::IgnoreField(&Var::name_, "name")));
}
};
using VarPtr = std::shared_ptr<const Var>;
* \brief Iteration argument variable
*
* Represents an iteration argument (loop-carried value) in for loops.
* IterArgs implement SSA-style loop-carried dependencies where values are
* carried from one iteration to the next via yield statements.
*
* **Scoping Rules:**
* - IterArg variables are scoped to the loop body only
* - Cannot be directly accessed outside the loop
* - Must use return_vars to expose final values after the loop
*
* **Usage Pattern:**
* 1. Create IterArg with initial value
* 2. Use in ForStmt's iter_args list
* 3. Update via YieldStmt in loop body
* 4. Capture final value in ForStmt's return_vars
*
* \example
* // for i, (sum,) in pl.range(0, n, 1, init_values=[0]):
* // sum = pl.yield_(sum + i)
* // sum_final = sum
* auto sum_iter = std::make_shared<IterArg>("sum", type, init_val, span);
* auto sum_final = std::make_shared<Var>("sum_final", type, span);
* auto for_stmt = std::make_shared<ForStmt>(
* i, start, stop, step,
* std::vector{sum_iter}, // iter_args (loop-scoped)
* body,
* std::vector{sum_final}, // return_vars (accessible after loop)
* span
* );
*/
class IterArg {
public:
VarPtr iterVar_;
ExprPtr initValue_;
* \brief Create an iteration argument
*
* \param name Variable name (scoped to loop body)
* \param type Type of the variable (ScalarType, TensorType, or TileType)
* Memory reference information is stored in ShapedType for Tensor/Tile types
* \param initValue Initial value expression for first iteration
* \param span Source location
*/
IterArg(std::string name, TypePtr type, ExprPtr initValue, Span span)
: iterVar_(std::make_shared<Var>(name, std::move(type), std::move(span))), initValue_(std::move(initValue))
{}
* \brief Create an iteration argument with existing variable reference
*
* \param var Variable reference for the iteration argument
* \param initValue Initial value expression for first iteration
*/
IterArg(VarPtr var, ExprPtr initValue) : iterVar_(std::move(var)), initValue_(std::move(initValue)) {}
* \brief Get field descriptors for reflection-based visitation
*
* \return Tuple of field descriptors (initValue_ as USUAL field)
*/
static constexpr auto GetFieldDescriptors()
{
return std::make_tuple(
reflection::UsualField(&IterArg::initValue_, "initValue"),
reflection::UsualField(&IterArg::iterVar_, "iterVar"));
}
};
using IterArgPtr = std::shared_ptr<const IterArg>;
* \brief Function call expression
*
* Represents a function call with an operation and arguments.
* Can accept any Expr as arguments, not just scalar expressions.
* Supports keyword arguments (kwargs) for operator metadata.
*/
class Call : public Expr {
public:
std::string name_;
std::vector<ExprPtr> args_;
std::vector<std::pair<std::string, std::any>> kwargs_;
* \brief Create a function call expression
*
* \param op Operation/function to call
* \param args List of argument expressions
* \param span Source location
*/
Call(std::string name, std::vector<ExprPtr> args, Span span)
: Expr(std::move(span)), name_(std::move(name)), args_(std::move(args)), kwargs_()
{}
* \brief Create a function call expression with explicit type
*
* \param op Operation/function to call
* \param args List of argument expressions
* \param type Result type of the call
* \param span Source location
*/
Call(std::string name, std::vector<ExprPtr> args, TypePtr type, Span span)
: Expr(std::move(span), std::move(type)), name_(std::move(name)), args_(std::move(args)), kwargs_()
{}
Call(std::string name, std::vector<ExprPtr> args, std::vector<std::pair<std::string, std::any>> kwargs, Span span)
: Expr(std::move(span)), name_(std::move(name)), args_(std::move(args)), kwargs_(std::move(kwargs))
{}
Call(
std::string name, std::vector<ExprPtr> args, std::vector<std::pair<std::string, std::any>> kwargs, TypePtr type,
Span span)
: Expr(std::move(span), std::move(type)),
name_(std::move(name)),
args_(std::move(args)),
kwargs_(std::move(kwargs))
{}
* \brief Get a kwarg value with type checking
*
* \tparam T Type of the kwarg value
* \param key Kwarg key
* \param default_value Default value if key doesn't exist
* \return The kwarg value or default
*/
template <typename T>
T GetKwarg(const std::string& key, const T& default_value = T{}) const
{
for (const auto& [k, v] : kwargs_) {
if (k == key) {
return AnyCast<T>(v, "kwarg key: " + key);
}
}
return default_value;
}
* \brief Check if a kwarg exists
*
* \param key Kwarg key
* \return true if the kwarg exists
*/
[[nodiscard]] bool HasKwarg(const std::string& key) const
{
for (const auto& kwarg : kwargs_) {
if (kwarg.first == key) {
return true;
}
}
return false;
}
[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::Call; }
[[nodiscard]] std::string TypeName() const override { return "Call"; }
* \brief Get field descriptors for reflection-based visitation
*
* \return Tuple of field descriptors (op, args, and kwargs as USUAL fields)
*/
static constexpr auto GetFieldDescriptors()
{
return std::tuple_cat(
Expr::GetFieldDescriptors(),
std::make_tuple(
reflection::UsualField(&Call::name_, "name"), reflection::UsualField(&Call::args_, "args"),
reflection::UsualField(&Call::kwargs_, "kwargs")));
}
};
using CallPtr = std::shared_ptr<const Call>;
* \brief Expression to create a tuple from multiple expressions
*
* Takes a list of expressions and creates a tuple value.
* The result type is TupleType containing the types of all input expressions.
*/
class MakeTuple : public Expr {
public:
std::vector<ExprPtr> elements_;
* \brief Create a tuple construction expression
*
* \param elements Expressions to be tuple elements
* \param span Source location
*/
MakeTuple(std::vector<ExprPtr> elements, Span span);
[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::MakeTuple; }
[[nodiscard]] std::string TypeName() const override { return "MakeTuple"; }
* \brief Get field descriptors for reflection-based visitation
*
* \return Tuple of field descriptors
*/
static constexpr auto GetFieldDescriptors()
{
return std::tuple_cat(
Expr::GetFieldDescriptors(), std::make_tuple(reflection::UsualField(&MakeTuple::elements_, "elements")));
}
};
using MakeTuplePtr = std::shared_ptr<const MakeTuple>;
* \brief Unified subscript expression: value[slice]
*
* Represents Python subscript syntax for both tuple element access and tile offset.
* The concrete semantics is determined by the static type of `value_`:
* - `value_` is `TupleType`: tuple element access. `slice_` must be a `ConstInt`
* whose value is the 0-based element index. Result type is the element type.
* - `value_` is `TileType`: tile element offset. `slice_` is an integer expression
* (static or dynamic). Result type is the same TileType as the base tile,
* with physical address shifted by `slice * sizeof(dtype)` bytes.
*/
class GetItemExpr : public Expr {
public:
ExprPtr value_;
ExprPtr slice_;
* \brief Create a subscript expression
*
* \param value Base expression (must have TupleType or TileType)
* \param slice Subscript expression (for TupleType, must be a ConstInt)
* \param span Source location
*/
GetItemExpr(ExprPtr value, ExprPtr slice, Span span);
[[nodiscard]] ObjectKind GetKind() const override { return ObjectKind::GetItemExpr; }
[[nodiscard]] std::string TypeName() const override { return "GetItemExpr"; }
static constexpr auto GetFieldDescriptors()
{
return std::tuple_cat(
Expr::GetFieldDescriptors(), std::make_tuple(
reflection::UsualField(&GetItemExpr::value_, "value"),
reflection::UsualField(&GetItemExpr::slice_, "slice")));
}
};
using GetItemExprPtr = std::shared_ptr<const GetItemExpr>;
}
}