bl.to_tensor 接口文档

1. 硬件背景

将在Ascend上分配的buffer（bl.buffer对象）转换为tl.tensor并返回。

2. 接口说明

```python
def to_tensor(memref: bl.buffer, writable: bool = True, _builder=None) -> tl.tensor:
```

入参说明

| 参数名 | 类型 | 必需 | 说明 |
| --- | --- | --- | --- |
| memref | bl.buffer | 是 | 输入的bl.buffer对象 |
| writable | bool | 否 | 若为True，返回的tensor在bufferization过程中允许被原地修改；默认为True |
| _builder | - | 内部参数 | 由编译器自动传入，用户无需指定 |

3. 约束说明

该接口的约束与bl.allocate_local_buffer相同。

4. 使用示例

Python
import os

import triton
import triton.language as tl
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
import triton.extension.buffer.language as bl
import triton.language.extra.cann.extension as al
from triton._C.libtriton import ir, buffer_ir
from triton._C.libtriton.ascend import ir as ascend_ir

# Set TORCH_DEVICE_BACKEND_AUTOLOAD=0 for this process; per the PyTorch docs this
# presumably disables automatic loading of device-backend extensions on
# `import torch`. NOTE(review): this runs after the imports above — if any of
# them imports torch transitively, the flag may be read too late; consider
# moving this line above the triton imports.
os.environ.update({"TORCH_DEVICE_BACKEND_AUTOLOAD": "0"})


class Options:
    """Fixed compile options handed to ``ast_to_ttir`` by ``compile_kernel``.

    Every setting is a plain class attribute; instances carry no state of
    their own, so ``Options()`` simply exposes these values.
    """

    # Launch / pipelining configuration.
    num_warps: int = 4
    num_stages: int = 3
    num_ctas: int = 1
    cluster_dims: tuple = (1, 1, 1)

    # Code-generation switches.
    enable_fp_fusion: bool = True
    debug: bool = False


def compile_kernel(kernel, signature, constants):
    """Lower a ``@triton.jit`` kernel with ``ast_to_ttir`` and return the module as text.

    Args:
        kernel: The JIT kernel to lower.
        signature: Forwarded to ``ASTSource`` as the kernel signature
            (``{}`` for a kernel whose only parameters are constexpr).
        constants: Forwarded to ``ASTSource``; maps constexpr parameter
            names to compile-time values, e.g. ``{"XBLOCK": 256}``.

    Returns:
        ``str()`` of the module produced by ``ast_to_ttir``.
    """
    source = ASTSource(kernel, signature, constants)

    mlir_ctx = ir.context()
    # Register the core Triton dialects, then the buffer and Ascend extension
    # dialects, so the bl/Ascend ops used by the kernel can be built.
    for dialect_module in (ir, buffer_ir, ascend_ir):
        dialect_module.load_dialects(mlir_ctx)

    # Extra code-generator hook table (only the address-space builder from the
    # Ascend extension); the trailing ``{}`` is passed through unchanged.
    extra_codegen = {"create_address_space": al.semantic.create_address_space}
    ttir = ast_to_ttir(kernel, source, mlir_ctx, Options(), extra_codegen, {})
    return str(ttir)


# ============== Kernel definitions ==============

@triton.jit
def kernel_func(XBLOCK: tl.constexpr):
    """Show both spellings of converting an allocated buffer to a tensor.

    Two float32 buffers of ``XBLOCK`` elements are allocated with ``bl.alloc``.
    The first is converted through the ``buffer.to_tensor`` method and the
    second through the free function ``bl.to_tensor``; both request a
    writable tensor.
    """
    # Method form: call to_tensor on the buffer object itself.
    method_buf = bl.alloc(tl.float32, [XBLOCK])
    method_buf.to_tensor(writable=True)

    # Function form: pass the buffer to bl.to_tensor explicitly.
    func_buf = bl.alloc(tl.float32, [XBLOCK])
    bl.to_tensor(func_buf, writable=True)


# ============== Main for manual testing ==============

if __name__ == "__main__":
    # Manual smoke test: lower kernel_func with XBLOCK=256 and dump the IR text.
    print("=" * 60)
    print("Test 1: Nested Scopes")
    print("=" * 60)
    mlir = compile_kernel(kernel_func, {}, {"XBLOCK": 256})
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)
    # The example previously ended with `triton.compile(src=src, target=target)`,
    # but neither `src` nor `target` exists at this scope, so it raised NameError
    # right after printing the IR. A full triton.compile needs a source object
    # and a backend target to be constructed here first; this script only
    # demonstrates the IR produced by compile_kernel.

5. 编译输出结果

```text
============================================================
Test 1: Nested Scopes
============================================================
✅ Generated MLIR (941 chars):

module {
  tt.func public @kernel_func() attributes {noinline = false} {
    %alloc = memref.alloc() : memref<256xf32> loc(#loc1)
    annotation.mark %alloc {effects = ["write", "read"]} : memref<256xf32> loc(#loc1)
    %0 = bufferization.to_tensor %alloc restrict writable : memref<256xf32> loc(#loc2)
    %alloc_0 = memref.alloc() : memref<256xf32> loc(#loc3)
    annotation.mark %alloc_0 {effects = ["write", "read"]} : memref<256xf32> loc(#loc3)
    %1 = bufferization.to_tensor %alloc_0 restrict writable : memref<256xf32> loc(#loc4)
    tt.return loc(#loc5)
  } loc(#loc)
} loc(#loc)
#loc = loc("/home/linxin/triton-test/to_tensor.py":38:0)
#loc1 = loc("/home/linxin/triton-test/to_tensor.py":39:35)
#loc2 = loc("/home/linxin/triton-test/to_tensor.py":40:22)
#loc3 = loc("/home/linxin/triton-test/to_tensor.py":41:35)
#loc4 = loc("/home/linxin/triton-test/to_tensor.py":42:17)
#loc5 = loc("/home/linxin/triton-test/to_tensor.py":42:4)
```