import triton
import triton.language as tl
import torch
import math
@triton.jit
def add_helper(x, y):
return x + y
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
n_elements,
out_ptr,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = add_helper(x, y)
tl.store(out_ptr + offsets, output, mask=mask)
def test_module_walk(device):
"""
Test the MLIR bindings exposed for the out-of-tree walk.
"""
def walk_fn(op):
name = op.get_name()
for i in range(op.get_num_results()):
op.get_result(i).id()
for i in range(op.get_num_operands()):
op.get_operand(i).id()
for i in range(op.get_num_regions()):
op.get_region(i).id()
block = op.get_block()
if block is not None:
block.id()
for i in range(block.get_num_arguments()):
block.get_argument(i)
if name == "tt.func":
op.get_str_attr("sym_name")
if name == "tt.call":
op.get_flat_symbol_ref_attr("callee")
kernel = add_kernel
args = [
torch.empty((32, 32), device=device),
torch.empty((32, 32), device=device),
1024,
torch.empty((32, 32), device=device),
16,
]
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
src = triton.compiler.compiler.ASTSource(
fn=kernel,
signature={kernel.arg_names[i]: triton.runtime.jit.mangle_type(arg)
for i, arg in enumerate(args)},
constexprs={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
if not isinstance(arg, torch.Tensor)},
)
context = triton._C.libtriton.ir.context()
options = backend.parse_options(dict())
codegen_fns = dict()
module_map = backend.get_module_map()
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)
ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
ttir_module.walk(walk_fn)
def test_python_func_in_visit_call(device):
@triton.jit
def test_py_call_const_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
log2e: tl.constexpr = math.log2(math.e)
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = x * log2e
tl.store(out_ptr + offsets, output, mask=mask)
x = torch.randn(4, device=device)
out = torch.zeros_like(x)
test_py_call_const_kernel[(4, )](x, out, 4, 4)