* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <any>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "bindings/ir/bindings.h"
#include "core/any_cast.h"
#include "core/common.h"
#include "core/dtype.h"
#include "core/error.h"
#include "ir/core.h"
#include "ir/expr.h"
#include "ir/function.h"
#include "ir/memref.h"
#include "ir/op_registry.h"
#include "ir/pipe.h"
#include "ir/program.h"
#include "ir/reflection/field_visitor.h"
#include "ir/scalar_expr.h"
#include "ir/scalar_expr_ops.h"
#include "ir/stmt.h"
#include "ir/transforms/op_conversion_registry.h"
#include "ir/transforms/printer.h"
#include "ir/transforms/structural_comparison.h"
#include "ir/type.h"
#include "tilefwk/symbolic_scalar.h"
namespace py = pybind11;
using npu::tile_fwk::SymbolicScalar;
namespace pypto {
namespace ir {
using pypto::ir::DataType;
std::string ToPythonCTypeString(const DataType& dtype)
{
if (dtype == DataType::BF16) {
return "bfloat16";
}
return dtype.ToCTypeString();
}
std::shared_ptr<PtrType> GetOuterCompatPtrType()
{
static auto ptr_type = std::make_shared<PtrType>(DataType::INT8);
return ptr_type;
}
std::vector<FunctionPtr> GetOuterProgramFunctions(const ProgramPtr& program)
{
std::vector<FunctionPtr> functions;
functions.reserve(program->functions_.size());
for (const auto& entry : program->functions_) {
functions.push_back(entry.second);
}
return functions;
}
template <typename T>
bool TryConvertAnyToPy(const std::any& value, py::object& out)
{
if (value.type() != typeid(T)) {
return false;
}
out = py::cast(AnyCastRef<T>(value, "converting to Python"));
return true;
}
template <typename... Ts>
py::object AnyToPyObject(const std::any& value, const std::string& key)
{
py::object out;
if ((TryConvertAnyToPy<Ts>(value, out) || ...)) {
return out;
}
throw pypto::ir::TypeError("Attribute '" + key + "' has unsupported type");
}
template <typename T>
std::vector<T> ConvertSequenceToVector(py::handle obj)
{
auto seq = py::reinterpret_borrow<py::sequence>(obj);
std::vector<T> ret;
ret.reserve(py::len(seq));
for (auto item : seq) {
ret.push_back(py::cast<T>(item));
}
return ret;
}
template <typename ClassType, typename PyClassType, typename FieldDesc>
void BindField(PyClassType& py_class, const FieldDesc& desc)
{
py_class.def_readonly(desc.name, desc.fieldPtr);
}
template <typename ClassType, typename PyClassType, typename DescTuple, std::size_t... Is>
void BindFieldsImpl(PyClassType& py_class, const DescTuple& descriptors, std::index_sequence<Is...>)
{
(BindField<ClassType>(py_class, std::get<Is>(descriptors)), ...);
}
template <typename ClassType, typename PyClassType>
void BindFields(PyClassType& py_class)
{
constexpr auto descriptors = ClassType::GetFieldDescriptors();
constexpr auto num_fields = std::tuple_size_v<decltype(descriptors)>;
BindFieldsImpl<ClassType>(py_class, descriptors, std::make_index_sequence<num_fields>{});
}
std::vector<std::pair<std::string, std::any>> ConvertKwargsDict(const py::dict& kwargs_dict)
{
std::vector<std::pair<std::string, std::any>> kwargs;
for (auto item : kwargs_dict) {
std::string key = py::cast<std::string>(item.first);
if (py::isinstance<DataType>(item.second)) {
kwargs.emplace_back(key, py::cast<DataType>(item.second));
} else if (py::isinstance<MemorySpace>(item.second)) {
kwargs.emplace_back(key, py::cast<MemorySpace>(item.second));
} else if (py::isinstance<TilePad>(item.second)) {
kwargs.emplace_back(key, static_cast<int>(py::cast<TilePad>(item.second)));
} else if (py::isinstance<PipeType>(item.second)) {
kwargs.emplace_back(key, static_cast<int>(py::cast<PipeType>(item.second)));
} else if (py::isinstance<CoreType>(item.second)) {
kwargs.emplace_back(key, static_cast<int>(py::cast<CoreType>(item.second)));
} else if (py::isinstance<py::bool_>(item.second)) {
kwargs.emplace_back(key, py::cast<bool>(item.second));
} else if (py::isinstance<py::int_>(item.second)) {
kwargs.emplace_back(key, py::cast<int>(item.second));
} else if (py::isinstance<py::tuple>(item.second) || py::isinstance<py::list>(item.second)) {
auto seq = py::cast<py::sequence>(item.second);
if (py::len(seq) > 0 && py::isinstance<py::str>(seq[0])) {
kwargs.emplace_back(key, ConvertSequenceToVector<std::string>(seq));
} else {
kwargs.emplace_back(key, ConvertSequenceToVector<int>(seq));
}
} else if (py::isinstance<py::str>(item.second)) {
kwargs.emplace_back(key, py::cast<std::string>(item.second));
} else if (py::isinstance<py::float_>(item.second)) {
kwargs.emplace_back(key, py::cast<double>(item.second));
} else {
throw pypto::ir::TypeError("Unsupported kwarg type for key: " + key);
}
}
return kwargs;
}
std::any ConvertAttr(const py::object& attr)
{
if (py::isinstance<DataType>(attr)) {
return py::cast<DataType>(attr);
} else if (py::isinstance<MemorySpace>(attr)) {
return py::cast<MemorySpace>(attr);
} else if (py::isinstance<TensorLayout>(attr)) {
return py::cast<TensorLayout>(attr);
} else if (py::isinstance<bool>(attr)) {
return py::cast<bool>(attr);
} else if (py::isinstance<py::int_>(attr)) {
return py::cast<int>(attr);
} else if (py::isinstance<py::str>(attr)) {
return py::cast<std::string>(attr);
} else if (py::isinstance<py::float_>(attr)) {
return py::cast<double>(attr);
} else if (py::isinstance<SymbolicScalar>(attr)) {
return py::cast<SymbolicScalar>(attr);
} else if (py::isinstance<ExprPtr>(attr)) {
return py::cast<ExprPtr>(attr);
} else {
throw TypeError("Unsupported attr type");
}
}
std::any ConvertListAttr(const py::list& list)
{
if (list.empty()) {
return std::any();
}
auto item0 = list[0];
if (py::isinstance<SymbolicScalar>(item0)) {
std::vector<SymbolicScalar> ret;
for (auto item : list) {
ret.push_back(py::cast<SymbolicScalar>(item));
}
return ret;
} else if (py::isinstance<ExprPtr>(item0)) {
return ConvertSequenceToVector<ExprPtr>(list);
} else if (py::isinstance<std::string>(item0)) {
return ConvertSequenceToVector<std::string>(list);
} else if (py::isinstance<py::int_>(item0)) {
return ConvertSequenceToVector<int>(list);
} else {
throw TypeError("Unsupported list attr");
}
}
std::vector<std::pair<std::string, std::any>> ConvertAttrDict(const py::dict& attrs)
{
std::vector<std::pair<std::string, std::any>> kwargs;
for (auto item : attrs) {
std::string key = py::cast<std::string>(item.first);
std::any value;
if (py::isinstance<py::list>(item.second)) {
value = ConvertListAttr(py::cast<py::list>(item.second));
} else {
value = ConvertAttr(py::cast<py::object>(item.second));
}
if (value.has_value()) {
kwargs.emplace_back(key, value);
}
}
return kwargs;
}
void BindDType(py::module_& ir)
{
py::class_<DataType>(ir, "DataType", "Data type representation for IR tensors and operations")
.def_readonly_static("BOOL", &DataType::BOOL, "Boolean (true/false)")
.def_readonly_static("INT4", &DataType::INT4, "4-bit signed integer")
.def_readonly_static("INT8", &DataType::INT8, "8-bit signed integer")
.def_readonly_static("INT16", &DataType::INT16, "16-bit signed integer")
.def_readonly_static("INT32", &DataType::INT32, "32-bit signed integer")
.def_readonly_static("INT64", &DataType::INT64, "64-bit signed integer")
.def_readonly_static("UINT4", &DataType::UINT4, "4-bit unsigned integer")
.def_readonly_static("UINT8", &DataType::UINT8, "8-bit unsigned integer")
.def_readonly_static("UINT16", &DataType::UINT16, "16-bit unsigned integer")
.def_readonly_static("UINT32", &DataType::UINT32, "32-bit unsigned integer")
.def_readonly_static("UINT64", &DataType::UINT64, "64-bit unsigned integer")
.def_readonly_static("FP4", &DataType::FP4, "4-bit floating point")
.def_readonly_static("FP8E4M3FN", &DataType::FP8E4M3FN, "8-bit floating point (E4M3FN format)")
.def_readonly_static("FP8E5M2", &DataType::FP8E5M2, "8-bit floating point (E5M2 format)")
.def_readonly_static("FP16", &DataType::FP16, "16-bit floating point (IEEE 754 half precision)")
.def_readonly_static("FP32", &DataType::FP32, "32-bit floating point (IEEE 754 single precision)")
.def_readonly_static("BF16", &DataType::BF16, "16-bit brain floating point")
.def_readonly_static("FP64", &DataType::FP64, "64-bit floating point (IEEE 754 double precision)")
.def_readonly_static("HF4", &DataType::HF4, "4-bit Hisilicon float")
.def_readonly_static("HF8", &DataType::HF8, "8-bit Hisilicon float")
.def_readonly_static("INDEX", &DataType::INDEX, "Machine-word sized integer type for index computations")
.def_readonly_static(
"DEFAULT_CONST_INT", &DataType::INT64, "Default dtype for bare integer constant literals (= INT64)")
.def_readonly_static(
"DEFAULT_CONST_FLOAT", &DataType::FP32, "Default dtype for bare float constant literals (= FP32)")
.def("get_bit", &DataType::GetBit, "Get the size in bits of this data type.")
.def("to_string", &DataType::ToString, "Get a human-readable string name.")
.def("to_c_type_string", &DataType::ToCTypeString, "Get C style type string for code generation.")
.def("is_float", &DataType::IsFloat, "Check if floating point type.")
.def("is_signed_int", &DataType::IsSignedInt, "Check if signed integer type.")
.def("is_unsigned_int", &DataType::IsUnsignedInt, "Check if unsigned integer type.")
.def("is_int", &DataType::IsInt, "Check if any integer type.")
.def("code", &DataType::Code, "Get the underlying type code.")
.def("bits", &DataType::GetBit, "Get the size in bits (alias for get_bit).")
.def("c_type", &ToPythonCTypeString, "Get C style type string with BF16 handling.")
.def("is_signed", &DataType::IsSignedInt, "Check if signed integer (alias for is_signed_int).")
.def("is_unsigned", &DataType::IsUnsignedInt, "Check if unsigned integer (alias for is_unsigned_int).")
.def("__int__", &DataType::Code, "Get the underlying type code as int.")
.def("__eq__", &DataType::operator==, py::arg("other"))
.def("__ne__", &DataType::operator!=, py::arg("other"))
.def("__repr__", &DataType::ToString)
.def("__str__", &DataType::ToString);
}
void BindSpan(py::module_& ir)
{
py::class_<Span>(
ir, "Span", "Source location information tracking file, line, and column positions")
.def(
py::init<std::string, int, int, int, int>(), py::arg("filename"), py::arg("begin_line"),
py::arg("begin_column"), py::arg("end_line") = -1, py::arg("end_column") = -1, "Create a source span")
.def("to_string", &Span::ToString, "Convert span to string representation")
.def("is_valid", &Span::IsValid, "Check if the span has valid coordinates")
.def_static(
"unknown", &Span::Unknown, "Create an unknown/invalid span for cases where source location is unavailable",
py::return_value_policy::reference)
.def(
"is_unknown", [](const Span& self) { return self.IsUnknown(); }, "Check if the span is unknown")
.def("__repr__", &Span::ToString)
.def("__str__", &Span::ToString)
.def_property_readonly("filename", &Span::Filename, "Source filename")
.def_property_readonly("begin_line", &Span::BeginLine, "Beginning line (1-indexed)")
.def_property_readonly("begin_column", &Span::BeginColumn, "Beginning column (1-indexed)")
.def_property_readonly("end_line", &Span::EndLine, "Ending line (1-indexed)")
.def_property_readonly("end_column", &Span::EndColumn, "Ending column (1-indexed)");
}
void BindOp(py::module_& ir)
{
py::class_<Op, std::shared_ptr<Op>>(
ir, "Op",
"Represents callable operations in the IR. Stores the schema of allowed kwargs (key -> type "
"mapping). Actual kwarg values are stored per-Call instance in Call.kwargs")
.def(py::init<std::string>(), py::arg("name"), "Create an operation with the given name")
.def_readonly("name", &Op::name_, "Operation name")
.def("has_attr", &Op::HasAttr, py::arg("key"), "Check if a kwarg is registered in the schema")
.def("get_attr_keys", &Op::GetAttrKeys, "Get all registered kwarg keys from the schema")
.def_property_readonly(
"pipe", [](const Op& self) -> std::optional<PipeType> { return self.GetPipe(); },
"Pipeline type (optional)");
}
void BindTypeClass(py::module_& ir)
{
auto type_class =
py::class_<Type, std::shared_ptr<Type>>(ir, "Type", "Base class for type representations");
BindFields<Type>(type_class);
type_class.def(
"__str__", [](const TypePtr& self) { return PythonPrint(self, "ir"); },
"IR string representation");
type_class.def(
"__eq__", [](const TypePtr& self, const TypePtr& other) { return structural_equal(self, other); },
"Equality comparison");
auto unknown_type_class = py::class_<UnknownType, Type, std::shared_ptr<UnknownType>>(
ir, "UnknownType", "Unknown or unspecified type representation");
unknown_type_class.def(py::init<>(), "Create an unknown type");
unknown_type_class.def_static("get", []() { return GetUnknownType(); }, "Get the singleton UnknownType instance");
BindFields<UnknownType>(unknown_type_class);
auto scalar_type_class = py::class_<ScalarType, Type, std::shared_ptr<ScalarType>>(
ir, "ScalarType", "Scalar type representation");
scalar_type_class.def(py::init<DataType>(), py::arg("dtype"), "Create a scalar type");
BindFields<ScalarType>(scalar_type_class);
auto ptr_type_class = py::class_<PtrType, Type, std::shared_ptr<PtrType>>(
ir, "PtrType", "Pointer type representation (!pto.ptr<dtype>)");
ptr_type_class.def(py::init<DataType>(), py::arg("dtype"), "Create a pointer type");
ptr_type_class.def(
py::init([]() { return GetOuterCompatPtrType(); }),
"Create the outer-IR compatibility pointer type (INT8-based)");
ptr_type_class.def_static("get", &GetOuterCompatPtrType, "Get the outer-IR compatibility pointer type singleton");
BindFields<PtrType>(ptr_type_class);
auto irnode_class =
py::class_<IRNode, std::shared_ptr<IRNode>>(ir, "IRNode", "Base class for all IR nodes");
BindFields<IRNode>(irnode_class);
irnode_class
.def(
"same_as", [](const IRNodePtr& self, const IRNodePtr& other) { return self == other; }, py::arg("other"),
"Check if this IR node is the same as another IR node.")
.def(
"__str__", [](const IRNodePtr& self) { return PythonPrint(self, "ir"); },
"IR string representation")
.def(
"as_python", [](const IRNodePtr& self, const std::string& prefix) { return PythonDslPrint(self, prefix); },
py::arg("prefix") = "pl",
"Convert to Python-style string representation.\n\n"
"Args:\n"
" prefix: Module prefix (default 'pl' for 'import pypto.language as pl')");
auto expr_class = py::class_<Expr, IRNode, std::shared_ptr<Expr>>(
ir, "Expr", "Base class for all expressions");
BindFields<Expr>(expr_class);
auto shaped_type_class = py::class_<ShapedType, Type, std::shared_ptr<ShapedType>>(
ir, "ShapedType", "Base class for shaped types (tensors and tiles)");
BindFields<ShapedType>(shaped_type_class);
shaped_type_class.def(
"shares_memref_with",
[](const ShapedTypePtr& self, const ShapedTypePtr& other) {
if (!self->memref_.has_value() || !other->memref_.has_value()) {
return false;
}
return self->memref_.value().get() == other->memref_.value().get();
},
py::arg("other"), "Check if this ShapedType shares the same MemRef object with another ShapedType");
py::enum_<TensorLayout>(ir, "TensorLayout", "Tensor layout enumeration")
.value("ND", TensorLayout::ND, "ND layout")
.value("DN", TensorLayout::DN, "DN layout")
.value("NZ", TensorLayout::NZ, "NZ layout")
.export_values();
py::class_<TensorView>(ir, "TensorView", "Tensor view representation with stride and layout")
.def(py::init<>(), "Create an empty tensor view")
.def(
py::init<const std::vector<ExprPtr>&, TensorLayout>(), py::arg("stride"), py::arg("layout"),
"Create a tensor view with stride and layout")
.def(
py::init<const std::vector<ExprPtr>&, const std::vector<ExprPtr>&, TensorLayout>(), py::arg("valid_shape"),
py::arg("stride"), py::arg("layout"), "Create a tensor view with valid_shape, stride and layout")
.def_readwrite("valid_shape", &TensorView::validShape, "Valid shape dimensions")
.def_readwrite("stride", &TensorView::stride, "Stride for each dimension")
.def_readwrite("layout", &TensorView::layout, "Tensor layout type")
.def_readwrite(
"ptr", &TensorView::ptr, "Source pointer ExprPtr (set for ptr.make_tensor-created views; None otherwise).");
auto tensor_type_class = py::class_<TensorType, ShapedType, std::shared_ptr<TensorType>>(
ir, "TensorType", "Tensor type representation");
tensor_type_class.def(
py::init<const std::vector<ExprPtr>&, DataType, std::optional<MemRefPtr>>(), py::arg("shape"), py::arg("dtype"),
py::arg("memref") = py::none(), "Create a tensor type");
tensor_type_class.def(
py::init<const std::vector<int64_t>&, DataType, std::optional<MemRefPtr>>(), py::arg("shape"), py::arg("dtype"),
py::arg("memref") = py::none(), "Create a tensor type");
tensor_type_class.def(
py::init<const std::vector<ExprPtr>&, DataType, std::optional<MemRefPtr>, std::optional<TensorView>>(),
py::arg("shape"), py::arg("dtype"), py::arg("memref") = py::none(), py::arg("tensor_view") = py::none(),
"Create a tensor type with optional memory reference and tensor view");
tensor_type_class.def(
py::init<const std::vector<int64_t>&, DataType, std::optional<MemRefPtr>, std::optional<TensorView>>(),
py::arg("shape"), py::arg("dtype"), py::arg("memref") = py::none(), py::arg("tensor_view") = py::none(),
"Create a tensor type with constant shape, optional memory reference and tensor view");
BindFields<TensorType>(tensor_type_class);
auto tile_type_class = py::class_<TileType, ShapedType, std::shared_ptr<TileType>>(
ir, "TileType", "Tile type representation (multi-dimensional tensor)");
tile_type_class.def(
py::init<
const std::vector<ExprPtr>&, DataType, std::optional<MemRefPtr>, std::optional<TileView>,
std::optional<HardwareInfo>>(),
py::arg("shape"), py::arg("dtype"), py::arg("memref") = py::none(), py::arg("tile_view") = py::none(),
py::arg("hardware_info") = py::none(),
"Create a tile type (supports multi-dimensional tensors; code generation has constraints)");
tile_type_class.def(
py::init<
const std::vector<int64_t>&, DataType, std::optional<MemRefPtr>, std::optional<TileView>,
std::optional<HardwareInfo>>(),
py::arg("shape"), py::arg("dtype"), py::arg("memref") = py::none(), py::arg("tile_view") = py::none(),
py::arg("hardware_info") = py::none(),
"Create a tile type (supports multi-dimensional tensors; code generation has constraints)");
BindFields<TileType>(tile_type_class);
auto tuple_type_class = py::class_<TupleType, Type, std::shared_ptr<TupleType>>(
ir, "TupleType", "Tuple type representation (contains multiple types)");
tuple_type_class.def(
py::init<const std::vector<TypePtr>&>(), py::arg("types"),
"Create a tuple type from a list of types");
BindFields<TupleType>(tuple_type_class);
py::class_<TokenType, Type, std::shared_ptr<TokenType>>(ir, "TokenType", "Opaque token type")
.def(py::init<>(), "Create a token type")
.def_static("get", GetTokenType, "Get the token type");
py::class_<LogicalTensorType, Type, std::shared_ptr<LogicalTensorType>>(
ir, "LogicalTensorType", "Logical tensor type")
.def(py::init<>(), "Create a logical tensor type");
py::enum_<MemorySpace>(ir, "MemorySpace", "Memory space enumeration")
.value("DDR", MemorySpace::DDR, "DDR memory (off-chip)")
.value("Vec", MemorySpace::Vec, "Vector/unified buffer (on-chip)")
.value("Mat", MemorySpace::Mat, "Matrix/L1 buffer")
.value("Left", MemorySpace::Left, "Left matrix operand buffer")
.value("Right", MemorySpace::Right, "Right matrix operand buffer")
.value("Scaling", MemorySpace::Scaling, "Scaling/FBuffer buffer")
.value("Acc", MemorySpace::Acc, "Accumulator buffer")
.value("Bias", MemorySpace::Bias, "Bias buffer")
.export_values();
ir.attr("Mem") = ir.attr("MemorySpace");
py::enum_<PipeType>(ir, "PipeType", py::arithmetic(), "Pipeline type enumeration")
.value("MTE1", PipeType::MTE1, "Memory Transfer Engine 1")
.value("MTE2", PipeType::MTE2, "Memory Transfer Engine 2")
.value("MTE3", PipeType::MTE3, "Memory Transfer Engine 3")
.value("M", PipeType::M, "Matrix Unit")
.value("V", PipeType::V, "Vector Unit")
.value("S", PipeType::S, "Scalar Unit")
.value("FIX", PipeType::FIX, "Fix Pipe")
.value("ALL", PipeType::ALL, "All Pipes")
.export_values();
py::enum_<CoreType>(ir, "CoreType", py::arithmetic(), "Core type enumeration")
.value("VECTOR", CoreType::VECTOR, "Vector Core")
.value("CUBE", CoreType::CUBE, "Cube Core")
.export_values();
py::enum_<TileLayout>(ir, "TileLayout", "Tile layout enumeration")
.value("none_box", TileLayout::none_box, "No layout constraint")
.value("row_major", TileLayout::row_major, "Row-major layout")
.value("col_major", TileLayout::col_major, "Column-major layout")
.export_values();
py::enum_<TilePad>(ir, "TilePad", "Tile pad mode enumeration")
.value("null", TilePad::null, "No padding")
.value("zero", TilePad::zero, "Zero padding")
.value("max", TilePad::max, "Max value padding")
.value("min", TilePad::min, "Min value padding")
.export_values();
py::enum_<CompactMode>(ir, "CompactMode", "Compact mode for tile buffer")
.value("null", CompactMode::null, "No compact mode")
.value("normal", CompactMode::normal, "Normal compact mode")
.value("row_plus_one", CompactMode::row_plus_one, "Row plus one compact mode")
.export_values();
py::class_<HardwareInfo>(
ir, "HardwareInfo", "Hardware-specific tile information (layout, fractal, pad, compact)")
.def(py::init<>(), "Create a default hardware info")
.def(
py::init<TileLayout, TileLayout, uint64_t, TilePad, CompactMode>(),
py::arg("blayout") = TileLayout::row_major, py::arg("slayout") = TileLayout::none_box,
py::arg("fractal") = static_cast<uint64_t>(512), py::arg("pad") = TilePad::null,
py::arg("compact") = CompactMode::null,
"Create hardware info with blayout, slayout, fractal, pad, and compact")
.def_readwrite("blayout", &HardwareInfo::blayout, "Block layout")
.def_readwrite("slayout", &HardwareInfo::slayout, "Scatter layout")
.def_readwrite("fractal", &HardwareInfo::fractal, "Fractal size")
.def_readwrite("pad", &HardwareInfo::pad, "Pad mode")
.def_readwrite("compact", &HardwareInfo::compact, "Compact mode");
py::class_<TileView>(
ir, "TileView", "Tile view representation with valid shape, stride, and start offset")
.def(py::init<>(), "Create an empty tile view")
.def(
py::init<const std::vector<ExprPtr>&, const std::vector<ExprPtr>&, ExprPtr>(), py::arg("valid_shape"),
py::arg("stride"), py::arg("start_offset"), "Create a tile view with valid_shape, stride, and start_offset")
.def_readwrite("valid_shape", &TileView::validShape, "Valid shape dimensions")
.def_readwrite("stride", &TileView::stride, "Stride for each dimension")
.def_readwrite("start_offset", &TileView::startOffset, "Starting offset");
ir.attr("DYNAMIC_DIM") = kDynamicDim;
ir.def(
"create_op_call",
[](const std::string& op_name, const std::vector<ExprPtr>& args, const Span& span) {
return OpRegistry::GetInstance().Create(op_name, args, span);
},
py::arg("op_name"), py::arg("args"), py::arg("span"), "Create a Call expression (backward compatibility)");
ir.def(
"create_op_call",
[](const std::string& op_name, const std::vector<ExprPtr>& args, const py::dict& kwargs_dict,
const Span& span) {
auto kwargs = ConvertKwargsDict(kwargs_dict);
return OpRegistry::GetInstance().Create(op_name, args, kwargs, span);
},
py::arg("op_name"), py::arg("args"), py::arg("kwargs"), py::arg("span"),
"Create a Call expression with args and kwargs");
ir.def(
"is_op_registered", [](const std::string& op_name) { return OpRegistry::GetInstance().IsRegistered(op_name); },
py::arg("op_name"), "Check if an operator is registered");
ir.def(
"get_op", [](const std::string& op_name) { return OpRegistry::GetInstance().GetOp(op_name); },
py::arg("op_name"), "Get an operator instance by name");
}
void BindExpr(py::module_& ir)
{
auto var_class =
py::class_<Var, Expr, std::shared_ptr<Var>>(ir, "Var", "Variable reference expression");
var_class.def(
py::init<const std::string&, const TypePtr&, const Span&>(), py::arg("name"), py::arg("type"), py::arg("span"),
"Create a variable reference (memory reference is stored in ShapedType for Tensor/Tile types)");
BindFields<Var>(var_class);
auto iterarg_class = py::class_<IterArg, std::shared_ptr<IterArg>>(
ir, "IterArg", "Iteration argument variable");
iterarg_class
.def(
py::init<const std::string&, const TypePtr&, const ExprPtr&, const Span&>(), py::arg("name"),
py::arg("type"), py::arg("initValue"), py::arg("span"),
"Create an iteration argument with initial value")
.def(
py::init<VarPtr, const ExprPtr&>(), py::arg("iterVar"), py::arg("initValue"),
"Create an iteration argument with initial value")
.def(
"__str__", [](const std::shared_ptr<const IterArg>& self) -> std::string { return self->iterVar_->name_; })
.def_property_readonly(
"name", [](const std::shared_ptr<const IterArg>& self) -> const std::string& {
return self->iterVar_->name_;
},
"Variable name");
BindFields<IterArg>(iterarg_class);
auto memref_class = py::class_<MemRef, Var, std::shared_ptr<MemRef>>(
ir, "MemRef", "Memory reference variable for shaped types (inherits from Var)");
memref_class
.def(
py::init<MemorySpace, ExprPtr, uint64_t, Span>(), py::arg("memory_space"), py::arg("addr"), py::arg("size"),
py::arg("span") = Span::Unknown(), "Create a memory reference with memory_space, addr, size, and span")
.def(
py::init<MemorySpace, ExprPtr, uint64_t, uint64_t, Span>(), py::arg("memory_space"), py::arg("addr"),
py::arg("size"), py::arg("id"), py::arg("span") = Span::Unknown(),
"Create a memory reference with memory_space, addr, size, id, and span")
.def_readonly("memory_space", &MemRef::memorySpace_, "Memory space (DDR, Vec, Mat, Left, Right, Scaling, Acc)")
.def_readonly("addr", &MemRef::addr_, "Starting address expression")
.def_readonly("size", &MemRef::size_, "Size in bytes (64-bit unsigned)")
.def_static(
"same_allocation", &MemRef::SameAllocation, py::arg("a"), py::arg("b"),
"Check if two MemRefs share the same allocation (same base pointer)")
.def_static(
"may_alias", &MemRef::MayAlias, py::arg("a"), py::arg("b"),
"Check if two MemRefs may alias (same base + overlapping byte ranges)");
auto constint_class = py::class_<ConstInt, Expr, std::shared_ptr<ConstInt>>(
ir, "ConstInt", "Constant integer expression");
constint_class.def(
py::init<int64_t, DataType, const Span&>(), py::arg("value"), py::arg("dtype"), py::arg("span"),
"Create a constant integer expression");
BindFields<ConstInt>(constint_class);
constint_class.def_property_readonly("dtype", &ConstInt::dtype, "Data type of the expression");
auto constfloat_class = py::class_<ConstFloat, Expr, std::shared_ptr<ConstFloat>>(
ir, "ConstFloat", "Constant float expression");
constfloat_class.def(
py::init<double, DataType, const Span&>(), py::arg("value"), py::arg("dtype"), py::arg("span"),
"Create a constant float expression");
BindFields<ConstFloat>(constfloat_class);
constfloat_class.def_property_readonly("dtype", &ConstFloat::dtype, "Data type of the expression");
auto constbool_class = py::class_<ConstBool, Expr, std::shared_ptr<ConstBool>>(
ir, "ConstBool", "Constant boolean expression");
constbool_class.def(
py::init<bool, const Span&>(), py::arg("value"), py::arg("span"), "Create a constant boolean expression");
BindFields<ConstBool>(constbool_class);
constbool_class.def_property_readonly("dtype", &ConstBool::dtype, "Data type of the expression (always BOOL)");
auto call_class =
py::class_<Call, Expr, std::shared_ptr<Call>>(ir, "Call", "Function call expression");
call_class.def(
py::init<std::string, const std::vector<ExprPtr>&, const Span&>(), py::arg("op"), py::arg("args"),
py::arg("span"), "Create a function call expression from op name string");
call_class.def(
py::init<std::string, const std::vector<ExprPtr>&, const TypePtr&, const Span&>(), py::arg("op"),
py::arg("args"), py::arg("type"), py::arg("span"), "Create a function call expression with explicit type");
call_class.def(
py::init([](const OpPtr& op, const std::vector<ExprPtr>& args, const Span& span) -> std::shared_ptr<Call> {
return std::make_shared<Call>(op->name_, args, span);
}),
py::arg("op"), py::arg("args"), py::arg("span"), "Create a function call expression");
call_class.def(
py::init(
[](const OpPtr& op, const std::vector<ExprPtr>& args, const TypePtr& type, const Span& span)
-> std::shared_ptr<Call> { return std::make_shared<Call>(op->name_, args, type, span); }),
py::arg("op"), py::arg("args"), py::arg("type"), py::arg("span"),
"Create a function call expression with explicit type");
call_class.def(
py::init(
[](const std::string& name, const std::vector<ExprPtr>& args, const py::dict& kwargs_dict,
const Span& span) -> std::shared_ptr<Call> {
auto kwargs = ConvertKwargsDict(kwargs_dict);
return std::make_shared<Call>(name, args, kwargs, span);
}),
py::arg("op"), py::arg("args"), py::arg("kwargs"), py::arg("span"),
"Create a function call expression with kwargs");
call_class.def(
py::init(
[](const std::string& name, const std::vector<ExprPtr>& args, const py::dict& kwargs_dict,
const TypePtr& type, const Span& span) -> std::shared_ptr<Call> {
auto kwargs = ConvertKwargsDict(kwargs_dict);
return std::make_shared<Call>(name, args, kwargs, type, span);
}),
py::arg("op"), py::arg("args"), py::arg("kwargs"), py::arg("type"), py::arg("span"),
"Create a function call expression with kwargs and explicit type");
call_class.def(
py::init(
[](const OpPtr& op, const std::vector<ExprPtr>& args, const py::dict& kwargs_dict,
const Span& span) -> std::shared_ptr<Call> {
auto kwargs = ConvertKwargsDict(kwargs_dict);
return std::make_shared<Call>(op->name_, args, kwargs, span);
}),
py::arg("op"), py::arg("args"), py::arg("kwargs"), py::arg("span"),
"Create a function call expression with kwargs");
call_class.def(
py::init(
[](const OpPtr& op, const std::vector<ExprPtr>& args, const py::dict& kwargs_dict, const TypePtr& type,
const Span& span) -> std::shared_ptr<Call> {
auto kwargs = ConvertKwargsDict(kwargs_dict);
return std::make_shared<Call>(op->name_, args, kwargs, type, span);
}),
py::arg("op"), py::arg("args"), py::arg("kwargs"), py::arg("type"), py::arg("span"),
"Create a function call expression with kwargs and explicit type");
BindFields<Call>(call_class);
call_class.def_property_readonly(
"kwargs",
[](const CallPtr& self) {
py::dict result;
for (const auto& [key, value] : self->kwargs_) {
if (value.type() == typeid(int)) {
result[key.c_str()] = AnyCast<int>(value, "converting to Python: " + key);
} else if (value.type() == typeid(int64_t)) {
result[key.c_str()] = AnyCast<int64_t>(value, "converting to Python: " + key);
} else if (value.type() == typeid(bool)) {
result[key.c_str()] = AnyCast<bool>(value, "converting to Python: " + key);
} else if (value.type() == typeid(std::string)) {
result[key.c_str()] = AnyCast<std::string>(value, "converting to Python: " + key);
} else if (value.type() == typeid(double)) {
result[key.c_str()] = AnyCast<double>(value, "converting to Python: " + key);
} else if (value.type() == typeid(float)) {
result[key.c_str()] = AnyCast<float>(value, "converting to Python: " + key);
} else if (value.type() == typeid(DataType)) {
result[key.c_str()] = AnyCast<DataType>(value, "converting to Python: " + key);
} else if (value.type() == typeid(TilePad)) {
result[key.c_str()] = AnyCast<TilePad>(value, "converting to Python: " + key);
}
}
return result;
},
"Keyword arguments (metadata) for this call");
auto make_tuple_class = py::class_<MakeTuple, Expr, std::shared_ptr<MakeTuple>>(
ir, "MakeTuple", "Tuple construction expression");
make_tuple_class.def(
py::init<const std::vector<ExprPtr>&, const Span&>(), py::arg("elements"), py::arg("span"),
"Create a tuple construction expression");
BindFields<MakeTuple>(make_tuple_class);
auto get_item_class = py::class_<GetItemExpr, Expr, std::shared_ptr<GetItemExpr>>(
ir, "GetItemExpr",
"Unified subscript expression: value[slice]. Dispatch on value's static type "
"(TupleType => tuple element access with ConstInt slice; TileType => tile element offset).");
get_item_class.def(
py::init<const ExprPtr&, const ExprPtr&, const Span&>(), py::arg("value"), py::arg("slice"), py::arg("span"),
"Create a subscript expression (value[slice])");
BindFields<GetItemExpr>(get_item_class);
auto binaryexpr_class = py::class_<BinaryExpr, Expr, std::shared_ptr<BinaryExpr>>(
ir, "BinaryExpr", "Base class for binary operations");
BindFields<BinaryExpr>(binaryexpr_class);
auto unaryexpr_class = py::class_<UnaryExpr, Expr, std::shared_ptr<UnaryExpr>>(
ir, "UnaryExpr", "Base class for unary operations");
BindFields<UnaryExpr>(unaryexpr_class);
#define BIND_BINARY_EXPR(OpName, Description) \
py::class_<OpName, BinaryExpr, std::shared_ptr<OpName>>(ir, #OpName, Description) \
.def( \
py::init<const ExprPtr&, const ExprPtr&, DataType, const Span&>(), py::arg("left"), py::arg("right"), \
py::arg("dtype"), py::arg("span"), "Create " Description)
BIND_BINARY_EXPR(Add, "Addition expression (left + right)");
BIND_BINARY_EXPR(Sub, "Subtraction expression (left - right)");
BIND_BINARY_EXPR(Mul, "Multiplication expression (left * right)");
BIND_BINARY_EXPR(FloorDiv, "Floor division expression (left // right)");
BIND_BINARY_EXPR(FloorMod, "Floor modulo expression (left % right)");
BIND_BINARY_EXPR(FloatDiv, "Float division expression (left / right)");
BIND_BINARY_EXPR(Min, "Minimum expression (min(left, right))");
BIND_BINARY_EXPR(Max, "Maximum expression (max(left, right))");
BIND_BINARY_EXPR(Pow, "Power expression (left ** right)");
BIND_BINARY_EXPR(Eq, "Equality expression (left == right)");
BIND_BINARY_EXPR(Ne, "Inequality expression (left != right)");
BIND_BINARY_EXPR(Lt, "Less than expression (left < right)");
BIND_BINARY_EXPR(Le, "Less than or equal to expression (left <= right)");
BIND_BINARY_EXPR(Gt, "Greater than expression (left > right)");
BIND_BINARY_EXPR(Ge, "Greater than or equal to expression (left >= right)");
BIND_BINARY_EXPR(And, "Logical and expression (left and right)");
BIND_BINARY_EXPR(Or, "Logical or expression (left or right)");
BIND_BINARY_EXPR(Xor, "Logical xor expression (left xor right)");
BIND_BINARY_EXPR(BitAnd, "Bitwise and expression (left & right)");
BIND_BINARY_EXPR(BitOr, "Bitwise or expression (left | right)");
BIND_BINARY_EXPR(BitXor, "Bitwise xor expression (left ^ right)");
BIND_BINARY_EXPR(BitShiftLeft, "Bitwise left shift expression (left << right)");
BIND_BINARY_EXPR(BitShiftRight, "Bitwise right shift expression (left >> right)");
#undef BIND_BINARY_EXPR
#define BIND_UNARY_EXPR(OpName, Description) \
py::class_<OpName, UnaryExpr, std::shared_ptr<OpName>>(ir, #OpName, Description) \
.def( \
py::init<const ExprPtr&, DataType, const Span&>(), py::arg("operand"), py::arg("dtype"), py::arg("span"), \
"Create " Description)
BIND_UNARY_EXPR(Abs, "Absolute value expression (abs(operand))");
BIND_UNARY_EXPR(Neg, "Negation expression (-operand)");
BIND_UNARY_EXPR(Not, "Logical not expression (not operand)");
BIND_UNARY_EXPR(BitNot, "Bitwise not expression (~operand)");
BIND_UNARY_EXPR(Cast, "Cast expression (cast operand to dtype)");
#undef BIND_UNARY_EXPR
ir.def(
"structural_hash", static_cast<uint64_t (*)(const IRNodePtr&, bool)>(&structural_hash_with_var_identity),
py::arg("node"), py::arg("enable_auto_mapping") = false,
"Compute deterministic structural hash of an IR node (ignores Span). "
"If enable_auto_mapping=True, variable names are ignored (e.g., x+1 and y+1 hash the same). "
"If enable_auto_mapping=False (default), different variable objects produce different hashes.");
ir.def(
"structural_hash", static_cast<uint64_t (*)(const TypePtr&, bool)>(&structural_hash_with_var_identity),
py::arg("type"), py::arg("enable_auto_mapping") = false,
"Compute deterministic structural hash of a type. "
"enable_auto_mapping only affects variables embedded in the type (e.g., shape expressions).");
ir.def(
"structural_equal", static_cast<bool (*)(const IRNodePtr&, const IRNodePtr&, bool)>(&structural_equal),
py::arg("lhs"), py::arg("rhs"), py::arg("enable_auto_mapping") = false,
"Check if two IR nodes are structurally equal. "
"Ignores source location (Span). Returns True if IR nodes have identical structure. "
"If enable_auto_mapping=True, automatically map variables (e.g., x+1 equals y+1). "
"If enable_auto_mapping=False (default), variable objects must be exactly the same (not just same "
"name).");
ir.def(
"structural_equal", static_cast<bool (*)(const TypePtr&, const TypePtr&, bool)>(&structural_equal),
py::arg("lhs"), py::arg("rhs"), py::arg("enable_auto_mapping") = false,
"Check if two types are structurally equal. "
"Ignores source location (Span). Returns True if types have identical structure. "
"If enable_auto_mapping=True, automatically map variables (e.g., x+1 equals y+1). "
"If enable_auto_mapping=False (default), variable objects must be exactly the same (not just same "
"name).");
ir.def(
"assert_structural_equal",
static_cast<void (*)(const IRNodePtr&, const IRNodePtr&, bool)>(&assert_structural_equal), py::arg("lhs"),
py::arg("rhs"), py::arg("enable_auto_mapping") = false,
"Assert two IR nodes are structurally equal. "
"Raises ValueError with detailed error message showing the first mismatch location if they differ. "
"Ignores source location (Span). "
"If enable_auto_mapping=True, automatically map variables (e.g., x+1 equals y+1). "
"If enable_auto_mapping=False (default), variable objects must be exactly the same (not just same "
"name).");
ir.def(
"assert_structural_equal",
static_cast<void (*)(const TypePtr&, const TypePtr&, bool)>(&assert_structural_equal), py::arg("lhs"),
py::arg("rhs"), py::arg("enable_auto_mapping") = false,
"Assert two types are structurally equal. "
"Raises ValueError with detailed error message showing the first mismatch location if they differ. "
"Ignores source location (Span). "
"If enable_auto_mapping=True, automatically map variables (e.g., x+1 equals y+1). "
"If enable_auto_mapping=False (default), variable objects must be exactly the same (not just same "
"name).");
}
void BindStmt(py::module_& ir)
{
auto stmt_class = py::class_<Stmt, IRNode, std::shared_ptr<Stmt>>(
ir, "Stmt", "Base class for all statements");
BindFields<Stmt>(stmt_class);
auto assign_stmt_class = py::class_<AssignStmt, Stmt, std::shared_ptr<AssignStmt>>(
ir, "AssignStmt", "Assignment statement: var = value");
assign_stmt_class.def(
py::init<const VarPtr&, const ExprPtr&, const Span&>(), py::arg("var"), py::arg("value"), py::arg("span"),
"Create an assignment statement");
BindFields<AssignStmt>(assign_stmt_class);
auto if_stmt_class = py::class_<IfStmt, Stmt, std::shared_ptr<IfStmt>>(
ir, "IfStmt", "Conditional statement: if condition then then_body else else_body");
if_stmt_class.def(
py::init<
const ExprPtr&, const StmtPtr&, const std::optional<StmtPtr>&, const std::vector<VarPtr>&, const Span&>(),
py::arg("condition"), py::arg("then_body"), py::arg("else_body") = py::none(), py::arg("return_vars"),
py::arg("span"), "Create a conditional statement with then and else branches (else_body can be None)");
BindFields<IfStmt>(if_stmt_class);
auto yield_stmt_class = py::class_<YieldStmt, Stmt, std::shared_ptr<YieldStmt>>(
ir, "YieldStmt", "Yield statement: yield value");
yield_stmt_class.def(
py::init<const std::vector<ExprPtr>&, const Span&>(), py::arg("value"), py::arg("span"),
"Create a yield statement with a list of expressions");
yield_stmt_class.def(py::init<const Span&>(), py::arg("span"), "Create a yield statement without values");
BindFields<YieldStmt>(yield_stmt_class);
auto return_stmt_class = py::class_<ReturnStmt, Stmt, std::shared_ptr<ReturnStmt>>(
ir, "ReturnStmt", "Return statement: return value");
return_stmt_class.def(
py::init<const std::vector<ExprPtr>&, const Span&>(), py::arg("value"), py::arg("span"),
"Create a return statement with a list of expressions");
return_stmt_class.def(py::init<const Span&>(), py::arg("span"), "Create a return statement without values");
BindFields<ReturnStmt>(return_stmt_class);
auto for_stmt_class = py::class_<ForStmt, Stmt, std::shared_ptr<ForStmt>>(
ir, "ForStmt", "For loop statement: for loop_var in range(start, stop, step): body");
for_stmt_class.def(
py::init<
const VarPtr&, const ExprPtr&, const ExprPtr&, const ExprPtr&, const std::vector<IterArgPtr>&,
const StmtPtr&, const std::vector<VarPtr>&, const Span&>(),
py::arg("loop_var"), py::arg("start"), py::arg("stop"), py::arg("step"), py::arg("iter_args"), py::arg("body"),
py::arg("return_vars"), py::arg("span"), "Create a for loop statement");
BindFields<ForStmt>(for_stmt_class);
auto while_stmt_class = py::class_<WhileStmt, Stmt, std::shared_ptr<WhileStmt>>(
ir, "WhileStmt", "While loop statement: while condition: body");
while_stmt_class.def(
py::init<
const ExprPtr&, const std::vector<IterArgPtr>&, const StmtPtr&, const std::vector<VarPtr>&, const Span&>(),
py::arg("condition"), py::arg("iter_args"), py::arg("body"), py::arg("return_vars"), py::arg("span"),
"Create a while loop statement");
BindFields<WhileStmt>(while_stmt_class);
py::enum_<SectionKind>(ir, "SectionKind", "Section kind classification")
.value("Vector", SectionKind::Vector, "Vector section for vector operations")
.value("Cube", SectionKind::Cube, "Cube section for cube operations")
.export_values();
auto section_stmt_class = py::class_<SectionStmt, Stmt, std::shared_ptr<SectionStmt>>(
ir, "SectionStmt",
"Section statement: marks a region with specific section context (Vector or Cube)");
section_stmt_class.def(
py::init<SectionKind, const StmtPtr&, const Span&>(), py::arg("section_kind"), py::arg("body"), py::arg("span"),
"Create a section statement");
BindFields<SectionStmt>(section_stmt_class);
auto seq_stmts_class = py::class_<SeqStmts, Stmt, std::shared_ptr<SeqStmts>>(
ir, "SeqStmts", "Sequence of statements: a sequence of statements");
seq_stmts_class.def(
py::init<const std::vector<StmtPtr>&, const Span&>(), py::arg("stmts"), py::arg("span"),
"Create a sequence of statements");
seq_stmts_class.def(py::init<const Span&>(), py::arg("span"), "Create a sequence of statements");
seq_stmts_class.def(
"__getitem__",
[](SeqStmtsPtr& self, int index) {
int size = static_cast<int>(self->stmts_.size());
if (index < -size || index >= size) {
throw IndexError(
"SeqStmts index " + std::to_string(index) + " out of range [" + std::to_string(-size) + ", " +
std::to_string(size - 1) + "]");
}
if (index < 0)
index += size;
return self->stmts_[index];
},
py::arg("index"), "Get statement by index, supports negative indexing");
BindFields<SeqStmts>(seq_stmts_class);
auto eval_stmt_class = py::class_<EvalStmt, Stmt, std::shared_ptr<EvalStmt>>(
ir, "EvalStmt", "Evaluation statement: expr");
eval_stmt_class.def(
py::init<const ExprPtr&, const Span&>(), py::arg("expr"), py::arg("span"), "Create an evaluation statement");
BindFields<EvalStmt>(eval_stmt_class);
auto break_stmt_class = py::class_<BreakStmt, Stmt, std::shared_ptr<BreakStmt>>(
ir, "BreakStmt", "Break statement: break");
break_stmt_class.def(py::init<const Span&>(), py::arg("span"), "Create a break statement");
break_stmt_class.def(
py::init<const std::vector<ExprPtr>&, const Span&>(), py::arg("operands"), py::arg("span"),
"Create a break statement with operands");
BindFields<BreakStmt>(break_stmt_class);
auto continue_stmt_class = py::class_<ContinueStmt, Stmt, std::shared_ptr<ContinueStmt>>(
ir, "ContinueStmt", "Continue statement: continue");
continue_stmt_class.def(py::init<const Span&>(), py::arg("span"), "Create a continue statement");
continue_stmt_class.def(
py::init<const std::vector<ExprPtr>&, const Span&>(), py::arg("operands"), py::arg("span"),
"Create a continue statement with operands");
BindFields<ContinueStmt>(continue_stmt_class);
auto scalarop_stmt_class = py::class_<ScalarOpStmt, Stmt, std::shared_ptr<ScalarOpStmt>>(
ir, "ScalarOpStmt",
"Scalar operation statement: result, result_token = opcode(args, tokens)");
scalarop_stmt_class.def(
py::init<VarPtr, VarPtr, std::string, const std::vector<ExprPtr>&, const Span&>(), py::arg("result"),
py::arg("result_token"), py::arg("opcode"), py::arg("args"), py::arg("span"),
"Create a scalar operation statement");
BindFields<ScalarOpStmt>(scalarop_stmt_class);
auto tensorop_stmt_class = py::class_<TensorOpStmt, Stmt, std::shared_ptr<TensorOpStmt>>(
ir, "TensorOpStmt",
"Tensor operation statement: results, result_token = opcode(args, attrs, tokens)");
tensorop_stmt_class.def(
py::init([](std::vector<VarPtr> results, VarPtr result_token, std::string opcode, std::vector<ExprPtr> args,
const std::vector<VarPtr>& tokens, py::dict attrs, Span span) {
auto attr_list = ConvertAttrDict(attrs);
return std::make_shared<TensorOpStmt>(results, result_token, opcode, args, tokens, attr_list, span);
}),
py::arg("results"), py::arg("result_token"), py::arg("opcode"), py::arg("args"), py::arg("tokens"),
py::arg("attrs"), py::arg("span"), "Create a tensor operation statement");
BindFields<TensorOpStmt>(tensorop_stmt_class);
}
void BindProgram(py::module_& ir)
{
py::enum_<FunctionType>(ir, "FunctionType", "Function type classification")
.value("Opaque", FunctionType::OPAQUE, "Unspecified function type (default)")
.value("Orchestration", FunctionType::ORCHESTRATION, "Host/AICPU control and coordination")
.value("InCore", FunctionType::IN_CORE, "AICore sub-graph execution")
.value("Helper", FunctionType::HELPER, "Scalar helper callable from kernels (generates func.call)")
.export_values();
auto function_class = py::class_<Function, IRNode, std::shared_ptr<Function>>(
ir, "Function", "Function definition with name, parameters, return types, and body");
function_class.def(
py::init(
[](const std::string& name, const py::list& params, const std::vector<TypePtr>& return_types,
const StmtPtr& body, const Span& span, FunctionType type) -> std::shared_ptr<Function> {
std::vector<VarPtr> param_vars;
param_vars.reserve(py::len(params));
for (auto item : params) {
param_vars.push_back(py::cast<VarPtr>(item));
}
return std::make_shared<Function>(name, std::move(param_vars), return_types, body, span, type);
}),
py::arg("name"), py::arg("params"), py::arg("return_types"), py::arg("body"), py::arg("span"),
py::arg("type") = FunctionType::OPAQUE, "Create a function definition");
BindFields<Function>(function_class);
auto program_class = py::class_<Program, IRNode, std::shared_ptr<Program>>(
ir, "Program",
"Program definition with functions mapped by name. "
"Functions are automatically sorted by name for deterministic ordering.");
program_class.def(
py::init<const std::vector<FunctionPtr>&, const std::string&, const Span&>(), py::arg("functions"),
py::arg("name"), py::arg("span"),
"Create a program from a list of functions. "
"Functions are keyed by their names automatically.");
program_class.def(
"get_function", &Program::GetFunction, py::arg("name"), "Get a function by name, returns None if not found");
program_class.def(
"__getitem__", [](const ProgramPtr& self, const std::string& name) { return self->GetFunction(name); },
py::arg("name"), "Get function by name (dict-like access), returns None if not found");
program_class.def_property_readonly(
"functions",
[](const std::shared_ptr<const Program>& self) {
py::dict result;
for (const auto& [name, func] : self->functions_) {
result[py::cast(name)] = py::cast(func);
}
return result;
},
"Map of function names to their corresponding functions, sorted by name");
program_class.def_readonly("name", &Program::name_, "Program name");
program_class.def_readonly("span", &Program::span_, "Source location");
ir.def(
"python_print", [](const IRNodePtr& node, const std::string& prefix) { return PythonDslPrint(node, prefix); },
py::arg("node"), py::arg("prefix") = "pl",
"Print IR node (Expr, Stmt, Function, or Program) in Python IR syntax.\n\n"
"Args:\n"
" node: IR node to print\n"
" prefix: Module prefix (default 'pl' for 'import pypto.language as pl')");
ir.def(
"python_print_type",
[](const TypePtr& type, const std::string& prefix) { return PythonDslPrint(type, prefix); }, py::arg("type"),
py::arg("prefix") = "pl",
"Print Type object in Python IR syntax.\n\n"
"Args:\n"
" type: Type to print\n"
" prefix: Module prefix (default 'pl' for 'import pypto.language as pl')");
}
void BindHelpers(py::module_& ir)
{
ir.def("add", &MakeAdd, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Addition operator");
ir.def("sub", &MakeSub, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Subtraction operator");
ir.def(
"mul", &MakeMul, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Multiplication operator");
ir.def(
"truediv", &MakeFloatDiv, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"True division operator");
ir.def(
"floordiv", &MakeFloorDiv, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Floor division operator");
ir.def("mod", &MakeFloorMod, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Modulo operator");
ir.def("pow", &MakePow, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Power operator");
ir.def("eq", &MakeEq, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Equality operator");
ir.def("ne", &MakeNe, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Inequality operator");
ir.def("lt", &MakeLt, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Less than operator");
ir.def(
"le", &MakeLe, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Less than or equal operator");
ir.def("gt", &MakeGt, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Greater than operator");
ir.def(
"ge", &MakeGe, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Greater than or equal operator");
ir.def("neg", &MakeNeg, py::arg("operand"), py::arg("span") = Span::Unknown(), "Negation operator");
ir.def("cast", &MakeCast, py::arg("operand"), py::arg("dtype"), py::arg("span") = Span::Unknown(), "Cast operator");
ir.def(
"bit_and", &MakeBitAnd, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Bitwise and operator");
ir.def(
"bit_or", &MakeBitOr, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Bitwise or operator");
ir.def(
"bit_xor", &MakeBitXor, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Bitwise xor operator");
ir.def(
"bit_shift_left", &MakeBitShiftLeft, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Bitwise left shift operator");
ir.def(
"bit_shift_right", &MakeBitShiftRight, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(),
"Bitwise right shift operator");
ir.def("bit_not", &MakeBitNot, py::arg("operand"), py::arg("span") = Span::Unknown(), "Bitwise not operator");
ir.def("min_", &MakeMin, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Minimum operator");
ir.def("max_", &MakeMax, py::arg("lhs"), py::arg("rhs"), py::arg("span") = Span::Unknown(), "Maximum operator");
ir.def(
"register_op_conversion",
[](const std::string& from_op, const std::string& to_op) {
OpConversionRegistry::GetInstance().RegisterSimple(from_op, to_op);
},
py::arg("from_op"), py::arg("to_op"),
"Register a simple tensor-to-block op name mapping.\n\n"
"Args:\n"
" from_op: Source op name (e.g., 'tensor.add')\n"
" to_op: Target op name (e.g., 'block.add')");
ir.def(
"register_op_conversion_custom",
[](const std::string& from_op, py::object func) {
OpConversionRegistry::GetInstance().RegisterCustom(
from_op,
[func](
const std::vector<ExprPtr>& args, const std::vector<std::pair<std::string, std::any>>& kwargs,
const Span& span) -> ConversionResult {
py::gil_scoped_acquire guard;
py::list py_kwargs_list;
for (const auto& [key, val] : kwargs) {
py::object py_val =
AnyToPyObject<DataType, MemorySpace, bool, int, std::string, double>(val, key);
py::tuple pair = py::make_tuple(py::cast(key), py_val);
py_kwargs_list.append(pair);
}
py::object result = func(py::cast(args), py_kwargs_list, py::cast(span));
if (py::isinstance<py::tuple>(result)) {
py::tuple result_tuple = py::cast<py::tuple>(result);
auto prologue = py::cast<std::vector<StmtPtr>>(result_tuple[0]);
auto expr = py::cast<ExprPtr>(result_tuple[1]);
return ConversionResult{std::move(prologue), std::move(expr)};
}
return ConversionResult{py::cast<ExprPtr>(result)};
});
},
py::arg("from_op"), py::arg("func"),
"Register a custom conversion function for a tensor op.\n\n"
"The function receives (args, kwargs, span) and should return either:\n"
"- An Expr (simple conversion)\n"
"- A tuple (list[Stmt], Expr) for complex conversions with prologue statements");
}
}
void BindIR(py::module_& m)
{
py::module_ ir_module = m.def_submodule("ir", "PyPTO IR (Intermediate Representation) module");
ir::BindDType(ir_module);
ir::BindSpan(ir_module);
ir::BindOp(ir_module);
ir::BindTypeClass(ir_module);
ir::BindExpr(ir_module);
ir::BindStmt(ir_module);
ir::BindProgram(ir_module);
ir::BindHelpers(ir_module);
ir::BindIRBuilder(m);
}
}