bl.alloc 接口文档

1. 背景

为满足Ascend级编程的需要,需要支持用户在指定地址空间上手动创建内存(buffer)。本接口本身是硬件无关的接口,底层对接MLIR的memref.alloc操作;具体可用的地址空间由后端扩展提供(如Ascend下的al.ascend_address_space)。

2. 接口说明

Python
def alloc(
    etype: tl.dtype,
    shape: List[tl.constexpr],
    _address_space: address_space = None,
    is_mem_unique: bool = False,
    _builder=None
) -> buffer:
    """Allocate a buffer of ``shape`` elements of type ``etype``.

    The interface is hardware-independent and lowers to ``memref.alloc``;
    the concrete address spaces come from a backend extension (e.g.
    ``al.ascend_address_space`` on Ascend).

    Args:
        etype: Element data type. ``tl.void`` is not supported.
        shape: Buffer shape; every element must be a positive integer
            compile-time constant.
        _address_space: Address space the buffer lives in. ``None`` (the
            default) attaches no address-space information, so the emitted
            memref type carries no address-space attribute.
        is_mem_unique: If True, the allocation is marked as owning its memory
            exclusively: an extra ``annotation.mark`` with ``mem_unique`` is
            emitted for the plan-memory pass. Defaults to False.
        _builder: Presumably injected by the Triton frontend; the example
            kernel never passes it explicitly — TODO confirm.

    Returns:
        A buffer-language ``buffer`` carrying element type, shape and address
        space. It is semantically separate from ``tl.tensor`` and cannot be
        assigned to/from one directly; convert explicitly with
        ``to_tensor`` / ``to_buffer``.

    Note:
        The caller is responsible for staying within the size limit of the
        chosen address space.
    """

3. 返回值

返回buffer language(bl)下的buffer类型对象。该类型与triton language下的tensor在语义上相互隔离,二者不支持相互赋值,需要通过to_tensor和to_buffer显式转换;buffer表示一段分配在指定地址空间上的内存,携带数据类型、形状和地址空间三部分信息。

4. 入参

参数名 类型 必需 说明
etype tl.dtype 是 元素数据类型(element type)
shape List[tl.constexpr] 是 buffer的形状
_address_space bl.address_space 否 buffer所在的地址空间,默认为None
is_mem_unique bool 否 是否独占内存。为True时会额外生成带mem_unique属性的annotation.mark,供plan memory阶段使用。默认为False

5. 昇腾平台数据类型支持

  int8 int16 int32 uint8 uint16 uint32 uint64 int64 fp16 fp32 fp64 bf16 bool
Ascend         ×  

6. 约束说明

  • etype不支持tl.void

  • shape的每个元素必须是正整数(编译期常量)

  • 用户需自行保证分配的大小符合指定地址空间的容量限制

  • _address_space参数默认为None,表示不携带任何地址空间信息(生成的memref类型不带地址空间属性)

7. 用例示例

Python
import os

# Disable torch's automatic loading of out-of-tree device backends.
# PyTorch documents this variable as being read when `torch` is imported, so
# it must be set before anything imports torch. The triton / CANN extension
# imports below may pull torch in transitively (TODO confirm), which is why
# this assignment comes before them rather than after.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton
import triton.language as tl
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
import triton.extension.buffer.language as bl
import triton.language.extra.cann.extension as al
from triton._C.libtriton import ir, buffer_ir
from triton._C.libtriton.ascend import ir as ascend_ir


class Options:
    """Bare-bones options object that ``compile_kernel`` hands to ``ast_to_ttir``.

    Only plain class attributes are provided; any option the code generator
    reads beyond these would raise AttributeError — TODO confirm this set is
    sufficient for kernels other than the sample one.
    """

    # Parallelism / pipelining knobs.
    num_ctas = 1
    num_warps = 4
    num_stages = 3
    cluster_dims = (1, 1, 1)

    # Code-generation flags.
    debug = False
    enable_fp_fusion = True


def compile_kernel(kernel, signature, constants):
    """Lower a ``@triton.jit`` kernel to Triton IR and return it as text.

    Args:
        kernel: The JIT-decorated kernel function.
        signature: Argument signature forwarded to ``ASTSource`` (empty in
            the example, since the sample kernel takes only constexprs).
        constants: Mapping of constexpr argument names to values,
            e.g. ``{"XBLOCK": 256}``.

    Returns:
        ``str()`` of the module produced by ``ast_to_ttir``.
    """
    source = ASTSource(kernel, signature, constants)
    ctx = ir.context()
    # Register the core, buffer-extension and Ascend dialects (in that order)
    # so ops such as memref.alloc / annotation.mark and #hivm address spaces
    # can be built during lowering.
    for load_dialects in (ir.load_dialects, buffer_ir.load_dialects, ascend_ir.load_dialects):
        load_dialects(ctx)
    # Backend hook the buffer language calls to materialize address spaces.
    codegen_fns = {"create_address_space": al.semantic.create_address_space}
    ttir_module = ast_to_ttir(kernel, source, ctx, Options(), codegen_fns, {})
    return str(ttir_module)


# ============== Kernel definitions ==============


@triton.jit
def allocate_local_buffer(XBLOCK: tl.constexpr):
    # Exercise bl.alloc once per address space exposed by the Ascend extension.
    # Each call lowers to a memref.alloc followed by an annotation.mark.
    # The returned buffers are intentionally discarded: these statements exist
    # only to check what the builder emits.
    # No address space -> memref<XBLOCKxf32> with no address-space attribute.
    bl.alloc(tl.float32, [XBLOCK])
    # UB -> #hivm.address_space<ub>
    bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB)
    # L1 -> #hivm.address_space<cbuf>
    bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L1)
    # L0A -> #hivm.address_space<ca>
    bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0A)
    # L0B -> #hivm.address_space<cb>
    bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0B)
    # L0C -> #hivm.address_space<cc>
    bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0C)
    # is_mem_unique=True additionally emits `annotation.mark {mem_unique}`.
    bl.alloc(
        tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB, is_mem_unique=True
    )


# ============== Main for manual testing ==============

if __name__ == "__main__":
    banner = "=" * 60
    # NOTE(review): "Nested Scopes" looks like a label copied from another
    # test; it is kept verbatim so the console output matches the sample
    # output in section 8.
    print(banner, "Test 1: Nested Scopes", banner, sep="\n")
    mlir = compile_kernel(allocate_local_buffer, {}, {"XBLOCK": 256})
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)

8. 编译输出结果

Plain Text
============================================================
Test 1: Nested Scopes
============================================================
✅ Generated MLIR (2103 chars):

module {
tt.func public @allocate_local_buffer() attributes {noinline = false} {
%alloc = memref.alloc() : memref<256xf32> loc(#loc1)
annotation.mark %alloc {effects = ["write", "read"]} : memref<256xf32> loc(#loc1)
%alloc_0 = memref.alloc() : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc2)
annotation.mark %alloc_0 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc2)
%alloc_1 = memref.alloc() : memref<256x256xf32, #hivm.address_space<cbuf>> loc(#loc3)
annotation.mark %alloc_1 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<cbuf>> loc(#loc3)
%alloc_2 = memref.alloc() : memref<256x256xf32, #hivm.address_space<ca>> loc(#loc4)
annotation.mark %alloc_2 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<ca>> loc(#loc4)
%alloc_3 = memref.alloc() : memref<256x256xf32, #hivm.address_space<cb>> loc(#loc5)
annotation.mark %alloc_3 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<cb>> loc(#loc5)
%alloc_4 = memref.alloc() : memref<256x256xf32, #hivm.address_space<cc>> loc(#loc6)
annotation.mark %alloc_4 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<cc>> loc(#loc6)
%alloc_5 = memref.alloc() : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc7)
annotation.mark %alloc_5 {mem_unique} : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc7)
annotation.mark %alloc_5 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc7)
tt.return loc(#loc8)
} loc(#loc)
} loc(#loc)
#loc = loc("/home/linxin/triton-test/alloc.py":41:0)
#loc1 = loc("/home/linxin/triton-test/alloc.py":43:25)
#loc2 = loc("/home/linxin/triton-test/alloc.py":44:43)
#loc3 = loc("/home/linxin/triton-test/alloc.py":45:43)
#loc4 = loc("/home/linxin/triton-test/alloc.py":46:43)
#loc5 = loc("/home/linxin/triton-test/alloc.py":47:43)
#loc6 = loc("/home/linxin/triton-test/alloc.py":48:43)
#loc7 = loc("/home/linxin/triton-test/alloc.py":50:38)
#loc8 = loc("/home/linxin/triton-test/alloc.py":49:4)