import os
import pytest
import triton
import triton.language as tl
import triton.extension.buffer.language as bl
import triton.language.extra.cann.extension as al
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir, buffer_ir
from triton._C.libtriton.ascend import ir as ascend_ir
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
class Options:
num_warps = 4
num_stages = 3
num_ctas = 1
cluster_dims = (1, 1, 1)
enable_fp_fusion = True
debug = False
arch = "Ascend910_95"
def compile_kernel(kernel, signature, constants):
"""Helper to compile a kernel to MLIR."""
src = ASTSource(kernel, signature, constants)
context = ir.context()
ir.load_dialects(context)
buffer_ir.load_dialects(context)
ascend_ir.load_dialects(context)
module = ast_to_ttir(kernel, src, context, Options(), {}, {})
return str(module)
@triton.jit
def copy(
A_ptr,
A1_ptr,
M: tl.constexpr,
N: tl.constexpr,
):
offs_a = tl.arange(0, M)[:, None]
offs_b = tl.arange(0, N)[None, :]
offs_c = (offs_a) * M + (offs_b)
a_ptr = A_ptr + offs_c
a_val = tl.load(a_ptr)
a1_ptr = A1_ptr + offs_c
a1_val = tl.load(a1_ptr)
add = tl.add(a_val, a1_val)
add_ub = bl.to_buffer(add, al.ascend_address_space.UB)
A_l1 = bl.alloc(tl.float32, [M, N], al.ascend_address_space.L1)
al.copy_from_ub_to_l1(add_ub, A_l1)
A_ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB)
al.copy(add_ub, A_ub)
def test_copy():
print("=" * 60)
print("Test 1: copy ")
print("=" * 60)
mlir = compile_kernel(
copy,
{"A_ptr": "*fp32", "A1_ptr": "*fp32"},
{"M": 16, "N": 16},
)
print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
print(mlir)
if __name__ == "__main__":
test_copy()