/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <utility>
#include <vector>
#include "bindings.h"
#include "interface/tensor/irbuilder.h"
#include "ir/kind_traits.h"
#include "ir/scalar_expr.h"
#include "ir/transforms/structural_comparison.h"
namespace npu::tile_fwk {
// Registers the IRBuilder Python bindings on the `ir` attribute of module `m`.
//
// Exposes:
//   * ir.IRBuilder          - stateful Begin/End style builder plus the stateless create_* factories
//   * ir.InsertPoint        - constructible from a SeqStmts pointer
//   * ir.Operation, ir.SymbolicExpression, ir.SymbolicSymbol, ir.SymbolicImmediate
//                           - IR node subclasses whose __str__ forwards to Dump()
//   * ir.type_equal(a, b)   - type comparison helper for two expressions
//
// `m` must already carry an `ir` attribute; this function only adds to it.
void BindIRBuilder(py::module_& m)
{
    py::module_ ir = m.attr("ir");

    py::class_<IRBuilder>(
        ir, "IRBuilder",
        "IR Builder for incremental IR construction with context management.\n\n"
        "The IRBuilder provides a stateful API for building IR incrementally using\n"
        "Begin/End patterns. It maintains a context stack to track nested scopes\n"
        "and validates proper construction.")
        .def(py::init<>(), "Create a new IR builder")
        // ---- Function construction ----
        .def(
            "begin_function",
            // The lambda has no return statement, so Python always receives None.
            [](IRBuilder& self, const std::string& name, const ir::Span& span, ir::FunctionType type) {
                self.BeginFunction(name, span, type);
            },
            py::arg("name"), py::arg("span"), py::arg("type") = ir::FunctionType::OPAQUE,
            "Begin building a function.\n\n"
            "Creates a new function context. Must be closed with end_function().\n\n"
            "Args:\n"
            " name: Function name\n"
            " span: Source location for function definition\n"
            " type: Function type (default: Opaque)\n\n"
            "Raises:\n"
            " RuntimeError: If already inside a function (nested functions not allowed)")
        .def(
            "func_arg", &IRBuilder::FuncArg, py::arg("name"), py::arg("type"), py::arg("span"),
            "Add a function parameter.\n\n"
            "Must be called within a function context.\n\n"
            "Args:\n"
            " name: Parameter name\n"
            " type: Parameter type\n"
            " span: Source location for parameter\n\n"
            "Returns:\n"
            " Var: Variable representing the parameter\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a function context")
        .def(
            "return_type", &IRBuilder::ReturnType, py::arg("type"),
            "Add a return type to the current function.\n\n"
            "Can be called multiple times for multiple return types.\n\n"
            "Args:\n"
            " type: Return type\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a function context")
        .def(
            "end_function", &IRBuilder::EndFunction, py::arg("end_span"),
            "End building a function.\n\n"
            "Finalizes the function and returns it.\n\n"
            "Args:\n"
            " end_span: Source location for end of function\n\n"
            "Returns:\n"
            " Function: The built function\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a function context")
        // ---- For loops ----
        .def(
            "begin_for_loop",
            // The lambda has no return statement, so Python always receives None.
            [](IRBuilder& self, const ir::VarPtr& loop_var, const ir::ExprPtr& start, const ir::ExprPtr& stop,
               const ir::ExprPtr& step, const ir::Span& span) {
                self.BeginForLoop(loop_var, start, stop, step, span);
            },
            py::arg("loop_var"), py::arg("start"), py::arg("stop"), py::arg("step"), py::arg("span"),
            "Begin building a for loop.\n\n"
            "Creates a new for loop context. Must be closed with end_for_loop().\n\n"
            "Args:\n"
            " loop_var: Loop variable\n"
            " start: Start value expression\n"
            " stop: Stop value expression\n"
            " step: Step value expression\n"
            " span: Source location for loop definition\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "add_iter_arg", &IRBuilder::AddIterArg, py::arg("iter_arg"),
            "Add an iteration argument to the current for loop.\n\n"
            "Iteration arguments are loop-carried values (SSA-style).\n\n"
            "Args:\n"
            " iter_arg: Iteration argument with initial value\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a for loop context")
        .def(
            "add_return_var", &IRBuilder::AddReturnVar, py::arg("var"),
            "Add a return variable to the current for loop.\n\n"
            "Return variables capture the final values of iteration arguments.\n"
            "Must match the number of iteration arguments.\n\n"
            "Args:\n"
            " var: Return variable\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a for loop context")
        .def(
            "end_for_loop", &IRBuilder::EndForLoop, py::arg("end_span"),
            "End building a for loop.\n\n"
            "Finalizes the loop and returns it.\n\n"
            "Args:\n"
            " end_span: Source location for end of loop\n\n"
            "Returns:\n"
            " ForStmt: The built for statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a for loop context\n"
            " RuntimeError: If number of return variables doesn't match iteration arguments")
        // ---- While loops ----
        .def(
            "begin_while_loop", &IRBuilder::BeginWhileLoop, py::arg("condition"), py::arg("span"),
            "Begin building a while loop.\n\n"
            "Creates a new while loop context. Must be closed with end_while_loop().\n\n"
            "Args:\n"
            " condition: Condition expression\n"
            " span: Source location for loop definition\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "add_while_iter_arg", &IRBuilder::AddWhileIterArg, py::arg("iter_arg"),
            "Add an iteration argument to the current while loop.\n\n"
            "Iteration arguments are loop-carried values (SSA-style).\n\n"
            "Args:\n"
            " iter_arg: Iteration argument with initial value\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a while loop context")
        .def(
            "add_while_return_var", &IRBuilder::AddWhileReturnVar, py::arg("var"),
            "Add a return variable to the current while loop.\n\n"
            "Return variables capture the final values of iteration arguments.\n\n"
            "Args:\n"
            " var: Return variable\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a while loop context")
        .def(
            "set_while_loop_condition", &IRBuilder::SetWhileLoopCondition, py::arg("condition"),
            "Set the condition for the current while loop.\n\n"
            "Used to update the loop condition after setting up iter_args. This allows\n"
            "the condition to reference iter_arg variables that are defined in the loop.\n\n"
            "Args:\n"
            " condition: New condition expression\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a while loop context")
        .def(
            "end_while_loop", &IRBuilder::EndWhileLoop, py::arg("end_span"),
            "End building a while loop.\n\n"
            "Finalizes the loop and returns it.\n\n"
            "Args:\n"
            " end_span: Source location for end of loop\n\n"
            "Returns:\n"
            " WhileStmt: The built while statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a while loop context\n"
            " RuntimeError: If number of return variables doesn't match iteration arguments")
        // ---- If statements ----
        .def(
            "begin_if", &IRBuilder::BeginIf, py::arg("condition"), py::arg("span"),
            "Begin building an if statement.\n\n"
            "Creates a new if context. Must be closed with end_if().\n\n"
            "Args:\n"
            " condition: Condition expression\n"
            " span: Source location for if statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "begin_else", &IRBuilder::BeginElse, py::arg("span"),
            "Begin the else branch of the current if statement.\n\n"
            "Must be called after building the then branch.\n\n"
            "Args:\n"
            " span: Source location for else keyword\n\n"
            "Raises:\n"
            " RuntimeError: If not inside an if context\n"
            " RuntimeError: If else branch already begun")
        .def(
            "add_if_return_var", &IRBuilder::AddIfReturnVar, py::arg("var"),
            "Add a return variable to the current if statement.\n\n"
            "Return variables are used for SSA phi nodes.\n\n"
            "Args:\n"
            " var: Return variable\n\n"
            "Raises:\n"
            " RuntimeError: If not inside an if context")
        .def(
            "end_if", &IRBuilder::EndIf, py::arg("end_span"),
            "End building an if statement.\n\n"
            "Finalizes the if statement and returns it.\n\n"
            "Args:\n"
            " end_span: Source location for end of if\n\n"
            "Returns:\n"
            " IfStmt: The built if statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside an if context")
        // ---- Sections ----
        .def(
            "begin_section", &IRBuilder::BeginSection, py::arg("section_kind"), py::arg("span"),
            "Begin building a section statement.\n\n"
            "Creates a new section context. Must be closed with end_section().\n\n"
            "Args:\n"
            " section_kind: The kind of section (Vector or Cube)\n"
            " span: Source location for section statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a function or loop")
        .def(
            "end_section", &IRBuilder::EndSection, py::arg("end_span"),
            "End building a section statement.\n\n"
            "Finalizes the section statement and returns it.\n\n"
            "Args:\n"
            " end_span: Source location for end of section\n\n"
            "Returns:\n"
            " SectionStmt: The built section statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a section context")
        // ---- Statement emission into the current context ----
        .def(
            "emit", &IRBuilder::Emit, py::arg("stmt"),
            "Emit a statement in the current context.\n\n"
            "Adds a statement to the current context's statement list.\n\n"
            "Args:\n"
            " stmt: Statement to emit\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "assign", &IRBuilder::Assign, py::arg("var"), py::arg("value"), py::arg("span"),
            "Create an assignment statement and emit it.\n\n"
            "Convenience method that creates and emits an assignment.\n\n"
            "Args:\n"
            " var: Variable to assign to\n"
            " value: Expression value\n"
            " span: Source location for assignment\n\n"
            "Returns:\n"
            " AssignStmt: The created assignment statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "var", &IRBuilder::Var, py::arg("name"), py::arg("type"), py::arg("span"),
            "Create a variable (does not emit).\n\n"
            "Helper to create a variable. User must create assignment separately.\n\n"
            "Args:\n"
            " name: Variable name\n"
            " type: Variable type\n"
            " span: Source location\n\n"
            "Returns:\n"
            " Var: The created variable")
        // IRBuilder::Return is overloaded; both overloads are exposed under the same Python name.
        .def(
            "return_", py::overload_cast<const std::vector<ir::ExprPtr>&, const ir::Span&>(&IRBuilder::Return),
            py::arg("values"), py::arg("span"),
            "Create a return statement and emit it.\n\n"
            "Convenience method that creates and emits a return statement.\n\n"
            "Args:\n"
            " values: List of expressions to return\n"
            " span: Source location for return statement\n\n"
            "Returns:\n"
            " ReturnStmt: The created return statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "return_", py::overload_cast<const ir::Span&>(&IRBuilder::Return), py::arg("span"),
            "Create an empty return statement and emit it.\n\n"
            "Convenience method that creates and emits an empty return statement.\n\n"
            "Args:\n"
            " span: Source location for return statement\n\n"
            "Returns:\n"
            " ReturnStmt: The created return statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "break_", &IRBuilder::Break, py::arg("span"),
            "Create a break statement and emit it.\n\n"
            "Convenience method that creates and emits a break statement.\n\n"
            "Args:\n"
            " span: Source location for break statement\n\n"
            "Returns:\n"
            " BreakStmt: The created break statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        .def(
            "continue_", &IRBuilder::Continue, py::arg("span"),
            "Create a continue statement and emit it.\n\n"
            "Convenience method that creates and emits a continue statement.\n\n"
            "Args:\n"
            " span: Source location for continue statement\n\n"
            "Returns:\n"
            " ContinueStmt: The created continue statement\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a valid context")
        // ---- Context queries ----
        .def(
            "in_function", &IRBuilder::InFunction,
            "Check if currently inside a function.\n\n"
            "Returns:\n"
            " bool: True if inside a function context")
        .def(
            "in_loop", &IRBuilder::InLoop,
            "Check if currently inside a for loop.\n\n"
            "Returns:\n"
            " bool: True if inside a for loop context")
        .def(
            "in_if", &IRBuilder::InIf,
            "Check if currently inside an if statement.\n\n"
            "Returns:\n"
            " bool: True if inside an if statement context")
        .def(
            "in_program", &IRBuilder::InProgram,
            "Check if currently inside a program.\n\n"
            "Returns:\n"
            " bool: True if inside a program context")
        // ---- Program construction ----
        .def(
            "begin_program", &IRBuilder::BeginProgram, py::arg("name"), py::arg("span"),
            "Begin building a program.\n\n"
            "Creates a new program context. Must be closed with end_program().\n\n"
            "Args:\n"
            " name: Program name\n"
            " span: Source location for program definition\n\n"
            "Raises:\n"
            " RuntimeError: If already inside another program")
        .def(
            "add_function", &IRBuilder::AddFunction, py::arg("func"),
            "Add a completed function to the current program.\n\n"
            "Args:\n"
            " func: Completed function to add\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a program context")
        .def(
            "end_program", &IRBuilder::EndProgram, py::arg("end_span"),
            "End building a program.\n\n"
            "Finalizes the program and returns it.\n\n"
            "Args:\n"
            " end_span: Source location for end of program\n\n"
            "Returns:\n"
            " Program: The built program\n\n"
            "Raises:\n"
            " RuntimeError: If not inside a program context")
        .def(
            "get_function_return_types", &IRBuilder::GetFunctionReturnTypes, py::arg("func_name"),
            "Get return types for a function by name.\n\n"
            "Returns the return types for a function if it has been added to the program.\n"
            "Returns empty list if not inside a program or function not yet added.\n\n"
            "Args:\n"
            " func_name: Function name\n\n"
            "Returns:\n"
            " List[Type]: Vector of return types")
        // ---- Stateless factories ----
        .def(
            "create_tensor_var",
            [](IRBuilder& self, DataType t, std::vector<SymbolicScalar> shape, TileOpFormat format, std::string name) {
                // The static shape is derived from the symbolic shape before `shape` is moved from.
                // NOTE(review): -1 is presumably the placeholder for non-concrete dimensions - confirm
                // against SymbolicScalar::Concrete.
                auto staticShape = SymbolicScalar::Concrete(shape, -1);
                return self.CreateTensorVar(t, std::move(staticShape), std::move(shape), format, std::move(name));
            },
            py::arg("dtype"), py::arg("shape"), py::arg("format") = TileOpFormat::TILEOP_ND, py::arg("name") = "",
            "Create a tensor variable.\n\n"
            "Args:\n"
            " dtype: Data type of the tensor\n"
            " shape: Shape of the tensor\n"
            " format: Data format of the tensor\n"
            " name: Name of the tensor\n\n"
            "Returns:\n"
            " LogicalTensor: The created tensor variable")
        .def(
            "create_const_int", &IRBuilder::CreateConstInt, py::arg("value"),
            "Create a constant integer scalar.\n\n"
            "Args:\n"
            " value: Integer value\n\n"
            "Returns:\n"
            " SymbolicScalar: The created constant integer scalar")
        .def(
            "create_scalar_var", &IRBuilder::CreateScalarVar, py::arg("sym") = "",
            "Create a scalar variable.\n\n"
            "Args:\n"
            " sym: Symbol of the variable\n\n"
            "Returns:\n"
            " SymbolicScalar: The created scalar variable")
        .def(
            "create_tensor_op_stmt",
            [](IRBuilder& self, std::vector<ir::VarPtr> result, ir::VarPtr result_token, std::string opcode,
               std::vector<ir::ExprPtr> args, std::vector<ir::VarPtr> tokens, py::dict attrs, ir::Span span) {
                // The Python dict of attributes is converted before being handed to the builder.
                auto attr_list = pypto::ir::ConvertAttrDict(attrs);
                return self.CreateTensorOpStmt(result, result_token, opcode, args, tokens, attr_list, span);
            },
            py::arg("result"), py::arg("result_token"), py::arg("opcode"), py::arg("args"), py::arg("tokens"),
            py::arg("attrs"), py::arg("span"),
            "Create a tensor operation statement.\n\n"
            "Args:\n"
            " result: Result of the operation\n"
            " result_token: Result token of the operation\n"
            " opcode: Opcode of the operation\n"
            " args: Arguments of the operation\n"
            " tokens: Tokens of the operation\n"
            " attrs: Attributes of the operation\n"
            " span: Span of the operation\n\n"
            "Returns:\n"
            " TensorOpStmt: The created tensor operation statement")
        .def(
            "create_assign_stmt", &IRBuilder::CreateAssignStmt, py::arg("var"), py::arg("value"), py::arg("span"),
            "Create an assignment statement.\n\n"
            "Args:\n"
            " var: Variable of the assignment\n"
            " value: Right-hand side of the assignment\n"
            " span: Span of the assignment\n\n"
            "Returns:\n"
            " AssignStmt: The created assignment statement")
        .def(
            "create_seq_stmts", &IRBuilder::CreateSeqStmts, py::arg("stmts"), py::arg("span"),
            "Create a sequence of statements.\n\n"
            "Args:\n"
            " stmts: Statements to create\n"
            " span: Span of the sequence\n\n"
            "Returns:\n"
            " SeqStmts: The created sequence of statements")
        .def(
            "create_if_stmt", &IRBuilder::CreateIfStmt, py::arg("cond"), py::arg("then_body"), py::arg("else_body"),
            py::arg("return_vars"), py::arg("span"),
            "Create an if statement.\n\n"
            "Args:\n"
            " cond: Condition of the if statement\n"
            " then_body: Then body of the if statement\n"
            " else_body: Else body of the if statement\n"
            " return_vars: Return variables of the if statement\n"
            " span: Span of the if statement\n\n"
            "Returns:\n"
            " IfStmt: The created if statement")
        .def(
            "create_return_stmt", &IRBuilder::CreateReturnStmt, py::arg("return_vars"), py::arg("span"),
            "Create a return statement.\n\n"
            "Args:\n"
            " return_vars: Return variables of the return statement\n"
            " span: Span of the return statement\n\n"
            "Returns:\n"
            " ReturnStmt: The created return statement")
        .def(
            "create_yield_stmt", &IRBuilder::CreateYieldStmt, py::arg("return_vars"), py::arg("span"),
            "Create a yield statement.\n\n"
            "Args:\n"
            " return_vars: Return variables of the yield statement\n"
            " span: Span of the yield statement\n\n"
            "Returns:\n"
            " YieldStmt: The created yield statement")
        .def(
            "create_for_stmt", &IRBuilder::CreateForStmt, py::arg("loopVar"), py::arg("start"), py::arg("stop"),
            py::arg("step"), py::arg("iterArgs"), py::arg("body"), py::arg("returnVars"), py::arg("span"),
            "Create a for statement.\n\n"
            "Args:\n"
            " loopVar: Loop variable of the for statement\n"
            " start: Start value of the loop variable\n"
            " stop: Stop value of the loop variable\n"
            " step: Step value of the loop variable\n"
            " iterArgs: Iteration arguments of the loop\n"
            " body: Body of the for statement\n"
            " returnVars: Return variables of the for statement\n"
            " span: Span of the for statement\n\n"
            "Returns:\n"
            " ForStmt: The created for statement")
        .def(
            "create_while_stmt", &IRBuilder::CreateWhileStmt, py::arg("cond"), py::arg("iterArgs"), py::arg("body"),
            py::arg("returnVars"), py::arg("span"),
            "Create a while statement.\n\n"
            "Args:\n"
            " cond: Condition of the while statement\n"
            " iterArgs: Iteration arguments of the loop\n"
            " body: Body of the while statement\n"
            " returnVars: Return variables of the while statement\n"
            " span: Span of the while statement\n\n"
            "Returns:\n"
            " WhileStmt: The created while statement")
        .def(
            "create_break_stmt", &IRBuilder::CreateBreakStmt, py::arg("return_vars"), py::arg("span"),
            "Create a break statement.\n\n"
            "Args:\n"
            " return_vars: Return variables of the break statement\n"
            " span: Span of the break statement\n\n"
            "Returns:\n"
            " BreakStmt: The created break statement")
        .def(
            "create_continue_stmt", &IRBuilder::CreateContinueStmt, py::arg("return_vars"), py::arg("span"),
            "Create a continue statement.\n\n"
            "Args:\n"
            " return_vars: Return variables of the continue statement\n"
            " span: Span of the continue statement\n\n"
            "Returns:\n"
            " ContinueStmt: The created continue statement")
        .def(
            "create_function",
            [](IRBuilder& self, std::string name, std::vector<ir::VarPtr> params, std::vector<ir::TypePtr> returnTypes,
               ir::StmtPtr body, ir::Span span) { return self.CreateFunction(name, params, returnTypes, body, span); },
            py::arg("name"), py::arg("params"), py::arg("returnTypes"), py::arg("body"), py::arg("span"),
            "Create a function.\n\n"
            "Args:\n"
            " name: Name of the function\n"
            " params: Parameters of the function\n"
            " returnTypes: Return types of the function\n"
            " body: Body of the function\n"
            " span: Span of the function\n\n"
            "Returns:\n"
            " Function: The created function")
        .def(
            "create_program", &IRBuilder::CreateProgram, py::arg("functions"), py::arg("name"), py::arg("span"),
            "Create a program.\n\n"
            "Args:\n"
            " functions: Functions of the program\n"
            " name: Name of the program\n"
            " span: Span of the program\n\n"
            "Returns:\n"
            " Program: The created program")
        .def(
            "create_var_like", &IRBuilder::CreateVarLike, py::arg("name"), py::arg("value"),
            "Create a variable like the given value.\n\n"
            "Args:\n"
            " name: Name of the variable\n"
            " value: Value to base the variable on\n\n"
            "Returns:\n"
            " Var: The created variable")
        // create_iter_arg is registered twice: (var, initValue) and (name, type, initValue, span).
        .def(
            "create_iter_arg",
            [](IRBuilder& self, ir::VarPtr var, ir::ExprPtr initValue) { return self.CreateIterArg(var, initValue); },
            py::arg("var"), py::arg("initValue"),
            "Create an iteration argument.\n\n"
            "Args:\n"
            " var: Variable of the iteration argument\n"
            " initValue: Initial value of the iteration argument\n\n"
            "Returns:\n"
            " IterArg: The created iteration argument")
        .def(
            "create_iter_arg",
            [](IRBuilder& self, std::string name, ir::TypePtr type, ir::ExprPtr initValue, ir::Span span) {
                return self.CreateIterArg(name, type, initValue, span);
            },
            py::arg("name"), py::arg("type"), py::arg("initValue"), py::arg("span"),
            "Create an iteration argument.\n\n"
            "Args:\n"
            " name: Name of the iteration argument\n"
            " type: Type of the iteration argument\n"
            " initValue: Initial value of the iteration argument\n"
            " span: Span of the iteration argument\n\n"
            "Returns:\n"
            " IterArg: The created iteration argument")
        // ---- Miscellaneous ----
        .def("emit_tensor_stmts", &IRBuilder::EmitTensorStmts, "Emit tensor statements.\n\n")
        .def("none", &IRBuilder::None, "Get the None value.\n\n")
        .def(
            "set_insert_point", &IRBuilder::SetInsertPoint, py::arg("point"),
            "Set insert point.\n\n"
            "Args:\n"
            " point: Insert point\n\n"
            "Returns:\n"
            " None")
        .def("clear_insert_point", &IRBuilder::ClearInsertPoint, "Clear insert point.\n\n");

    py::class_<ir::InsertPoint, std::shared_ptr<ir::InsertPoint>>(ir, "InsertPoint").def(py::init<ir::SeqStmtsPtr>());

    // Node classes registered as subclasses of already-bound IR classes; each gets a __str__
    // that forwards to the node's Dump().
    py::class_<Operation, std::shared_ptr<Operation>, ir::TensorOpStmt>(ir, "Operation")
        .def("__str__", [](Operation& self) { return self.Dump(); });
    py::class_<RawSymbolicExpression, std::shared_ptr<RawSymbolicExpression>, ir::Expr>(ir, "SymbolicExpression")
        .def("__str__", [](RawSymbolicExpression& self) { return self.Dump(); });
    py::class_<RawSymbolicSymbol, std::shared_ptr<RawSymbolicSymbol>, ir::Var>(ir, "SymbolicSymbol")
        .def("__str__", [](RawSymbolicSymbol& self) { return self.Dump(); });
    py::class_<RawSymbolicImmediate, std::shared_ptr<RawSymbolicImmediate>, ir::ConstInt>(ir, "SymbolicImmediate")
        .def("__str__", [](RawSymbolicImmediate& self) { return self.Dump(); });

    ir.def(
        "type_equal",
        [](ir::ExprPtr a, ir::ExprPtr b) {
            // pybind11 passes Python None as an empty shared_ptr for holder-typed arguments;
            // reject it here instead of dereferencing a null pointer below.
            if (!a || !b) {
                throw py::value_error("type_equal: expressions must not be None");
            }
            const bool structurallyEqual = ir::structural_equal(a->GetType(), b->GetType());
            // Only structurally equal logical-tensor types need the extra tensor-level comparison.
            if (!structurallyEqual || !ir::As<ir::LogicalTensorType>(b->GetType())) {
                return structurallyEqual;
            }
            // NOTE(review): the casts yield null if an expression carries a LogicalTensorType without
            // being a LogicalTensor instance; whether TypeEqual tolerates null is not visible here.
            auto ta = std::dynamic_pointer_cast<LogicalTensor>(std::const_pointer_cast<ir::Expr>(a));
            auto tb = std::dynamic_pointer_cast<LogicalTensor>(std::const_pointer_cast<ir::Expr>(b));
            return TypeEqual(ta, tb);
        },
        py::arg("a"), py::arg("b"),
        "Check whether two expressions have equal types.\n\n"
        "The types are first compared structurally. If they match and are logical\n"
        "tensor types, the two tensors are additionally compared with TypeEqual.\n\n"
        "Args:\n"
        " a: First expression\n"
        " b: Second expression\n\n"
        "Returns:\n"
        " bool: True if the types are equal\n\n"
        "Raises:\n"
        " ValueError: If either expression is None");
}
} // namespace npu::tile_fwk
namespace pypto::ir {
// Entry point under the pypto::ir namespace; it simply forwards to the
// implementation in npu::tile_fwk, passing the same module object through.
void BindIRBuilder(py::module_& m)
{
    npu::tile_fwk::BindIRBuilder(m);
}
} // namespace pypto::ir