bl.alloc 接口文档
1. 背景
为满足Ascend级编程的需要,需支持用户在指定地址空间上手动创建内存(buffer)。本接口与硬件无关,对接memref.alloc。
2. 接口说明
| Python def alloc( etype: tl.dtype, shape: List[tl.constexpr], _address_space: address_space = None, is_mem_unique: bool = False, _builder=None ) -> buffer: |
3. 返回值
返回一个buffer language下的buffer类型,与triton language下的tensor做语义上的隔离,不支持相互赋值,需要to_tensor和to_buffer来显式转换;表示一段分配在指定地址空间的内存,携带数据类型、形状和地址空间三部分信息。
4. 入参
| 参数名 | 类型 | 必需 | 说明 |
| --- | --- | --- | --- |
| etype | tl.dtype | 是 | 数据类型/element type |
| shape | List[tl.constexpr] | 是 | buffer的形状 |
| _address_space | bl.address_space | 否 | buffer所在的地址空间,默认为None |
| is_mem_unique | bool | 否 | 是否独占内存;生成的annotation.mark会在plan memory阶段用到。默认为False |
5. 昇腾平台数据类型支持
| | int8 | int16 | int32 | uint8 | uint16 | uint32 | uint64 | int64 | fp16 | fp32 | fp64 | bf16 | bool |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Ascend | √ | √ | √ | √ | √ | √ | × | √ |
6. 约束说明
- etype(数据类型)不支持tl.void
- shape每个元素必须是正整数
- 需自行保证符合指定的地址空间上的大小限制
- _address_space参数默认为空(None),表示不携带任何地址空间信息
7. 用例示例
| Python import os import triton import triton.language as tl from triton.compiler.compiler import ASTSource from triton.compiler.code_generator import ast_to_ttir import triton.extension.buffer.language as bl import triton.language.extra.cann.extension as al from triton._C.libtriton import ir, buffer_ir from triton._C.libtriton.ascend import ir as ascend_ir os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" class Options: num_warps = 4 num_stages = 3 num_ctas = 1 cluster_dims = (1, 1, 1) enable_fp_fusion = True debug = False def compile_kernel(kernel, signature, constants): """Helper to compile a kernel to MLIR.""" src = ASTSource(kernel, signature, constants) context = ir.context() ir.load_dialects(context) buffer_ir.load_dialects(context) ascend_ir.load_dialects(context) module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, {}) return str(module) # ============== Kernel definitions ============== @triton.jit def allocate_local_buffer(XBLOCK: tl.constexpr): # this statement has no effect, just to test the builder bl.alloc(tl.float32, [XBLOCK]) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L1) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0A) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0B) bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0C) bl.alloc( tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB, is_mem_unique=True ) # ============== Main for manual testing ============== if __name__ == "__main__": print("=" * 60) print("Test 1: Nested Scopes") print("=" * 60) mlir = compile_kernel( allocate_local_buffer, {}, {"XBLOCK": 256} ) print(f"✅ Generated MLIR ({len(mlir)} chars):\n") print(mlir) |
8. 编译输出结果
| Plain Text ============================================================ Test 1: Nested Scopes ============================================================ ✅ Generated MLIR (2103 chars): module { tt.func public @allocate_local_buffer() attributes {noinline = false} { %alloc = memref.alloc() : memref<256xf32> loc(#loc1) annotation.mark %alloc {effects = ["write", "read"]} : memref<256xf32> loc(#loc1) %alloc_0 = memref.alloc() : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc2) annotation.mark %alloc_0 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc2) %alloc_1 = memref.alloc() : memref<256x256xf32, #hivm.address_space<cbuf>> loc(#loc3) annotation.mark %alloc_1 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<cbuf>> loc(#loc3) %alloc_2 = memref.alloc() : memref<256x256xf32, #hivm.address_space<ca>> loc(#loc4) annotation.mark %alloc_2 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<ca>> loc(#loc4) %alloc_3 = memref.alloc() : memref<256x256xf32, #hivm.address_space<cb>> loc(#loc5) annotation.mark %alloc_3 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<cb>> loc(#loc5) %alloc_4 = memref.alloc() : memref<256x256xf32, #hivm.address_space<cc>> loc(#loc6) annotation.mark %alloc_4 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<cc>> loc(#loc6) %alloc_5 = memref.alloc() : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc7) annotation.mark %alloc_5 {mem_unique} : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc7) annotation.mark %alloc_5 {effects = ["write", "read"]} : memref<256x256xf32, #hivm.address_space<ub>> loc(#loc7) tt.return loc(#loc8) } loc(#loc) } loc(#loc) #loc = loc("/home/linxin/triton-test/alloc.py":41:0) #loc1 = loc("/home/linxin/triton-test/alloc.py":43:25) #loc2 = loc("/home/linxin/triton-test/alloc.py":44:43) #loc3 = loc("/home/linxin/triton-test/alloc.py":45:43) #loc4 = 
loc("/home/linxin/triton-test/alloc.py":46:43) #loc5 = loc("/home/linxin/triton-test/alloc.py":47:43) #loc6 = loc("/home/linxin/triton-test/alloc.py":48:43) #loc7 = loc("/home/linxin/triton-test/alloc.py":50:38) #loc8 = loc("/home/linxin/triton-test/alloc.py":49:4) |