al.sync_block_wait 接口文档
1. 硬件背景
面向分离模式的核间同步控制接口,用于 Cube 核与 Vector 核之间的同步(同步方向由 sender/receiver 参数指定)。
该接口需与 sync_block_set 接口配合使用。使用时需传入核间同步的标记ID(即参数 event_id),每个ID对应一个初始值为0的计数器:执行 sync_block_set(对应 Ascend C 的 CrossCoreSetFlag)后,该ID对应的计数器加1;执行 sync_block_wait(对应 Ascend C 的 CrossCoreWaitFlag)时,若对应计数器的值为0,则阻塞等待;若计数器大于0,则将计数器减1,并继续执行后续指令。
2. 接口说明
def sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None):
    """Block the receiver core until the paired ``sync_block_set`` has signalled.

    Each ``event_id`` owns a counter that starts at 0. ``sync_block_set``
    increments it by 1. ``sync_block_wait`` blocks while the counter is 0.
    Once the counter is greater than 0, the call decrements it by 1 and lets
    the following instructions execute.

    Args:
        sender: Sending core, only ``"cube"`` or ``"vector"``.
        receiver: Receiving core, only ``"cube"`` or ``"vector"``; must
            differ from ``sender``.
        event_id: Cross-core synchronization flag ID, in ``[0, 15]``.
        sender_pipe: Pipeline on the sending side (an ``al.PIPE`` member).
        receiver_pipe: Pipeline on the receiving side (an ``al.PIPE`` member).
        _builder: Supplied automatically by the JIT compiler; do not pass it.

    Returns:
        None.
    """
    # Implementation omitted; kernels call it as
    # ``triton.language.extra.cann.extension.sync_block_wait`` (``al.sync_block_wait``).


class PIPE(enum.Enum):
    """Hardware pipelines that can be named as ``sender_pipe`` / ``receiver_pipe``."""

    PIPE_S = ascend_ir.PIPE.PIPE_S  # scalar pipe; used e.g. by Tensor GetValue
    PIPE_V = ascend_ir.PIPE.PIPE_V  # vector compute, plus L0C->UB data moves
    PIPE_M = ascend_ir.PIPE.PIPE_M  # matrix compute
    PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1  # L1->L0A, L1->L0B data moves
    PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2  # GM->L1, GM->L0A, GM->L0B, GM->UB data moves
    PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3  # UB->GM, UB->L1 data moves
    PIPE_ALL = ascend_ir.PIPE.PIPE_ALL  # all pipelines
    PIPE_FIX = ascend_ir.PIPE.PIPE_FIX  # L0C->GM, L0C->L1 data moves
返回值
无返回值
3. 入参说明
| 参数名 | 类型 | 说明 |
| sender | str | 发送端,仅支持 "cube" / "vector" |
| receiver | str | 接收端,仅支持 "cube" / "vector" |
| event_id | int | 同步标记ID,取值范围 [0,15] |
| sender_pipe | al.PIPE | 发送端流水线类型 |
| receiver_pipe | al.PIPE | 接收端流水线类型 |
| _builder | - | JIT编译器自动传参 |
PIPE 枚举说明
| 流水类型 | 含义 |
| PIPE_S | 标量流水线,使用Tensor GetValue函数时为此流水 |
| PIPE_V | 矢量计算流水及L0C->UB数据搬运流水 |
| PIPE_M | 矩阵计算流水 |
| PIPE_MTE1 | L1->L0A、L1->L0B数据搬运流水 |
| PIPE_MTE2 | GM->L1、GM->L0A、GM->L0B、GM->UB数据搬运流水 |
| PIPE_MTE3 | UB->GM、UB->L1数据搬运流水 |
| PIPE_ALL | 所有流水 |
| PIPE_FIX | L0C->GM、L0C->L1数据搬运流水 |
4. 约束说明
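- sender 与 receiver 必须不同(sender != receiver),即仅支持 cube -> vector 或 vector -> cube 方向的同步。
- 配对的 sync_block_set 与 sync_block_wait 须使用相同的 event_id;如用例示例所示,两者的 sender、receiver、sender_pipe、receiver_pipe 参数也应保持一致。
- 如用例示例所示,sync_block_wait 应在 receiver 对应核类型的作用域中调用(例如 al.scope(core_mode="vector"))。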
5. 用例示例
"""Example: compile kernels that use al.sync_block_set / al.sync_block_wait to TTIR.

Each kernel issues a ``sync_block_set`` inside the sender's core scope and a
``sync_block_wait`` with the same arguments inside the receiver's core scope.
Running the script prints the generated MLIR for every case.
"""
import os

# Set before the triton imports below; presumably so it is in effect before
# torch gets loaded (TODO confirm which import pulls torch in).
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton
import triton.language.extra.cann.extension as al
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir, buffer_ir
from triton._C.libtriton.ascend import ir as ascend_ir


class Options:
    """Minimal compile-options object handed to ``ast_to_ttir``."""

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False
    arch = "Ascend910_95"


def compile_kernel(kernel, signature, constants):
    """Lower a ``@triton.jit`` kernel to TTIR and return the module as MLIR text."""
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    # Load the core, buffer and Ascend dialects into the context before lowering.
    ir.load_dialects(context)
    buffer_ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)


@triton.jit
def kernel_sync_cube_to_vector():
    # Cube signals flag 0; vector waits on it.
    with al.scope(core_mode="cube"):
        al.sync_block_set("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)
    with al.scope(core_mode="vector"):
        al.sync_block_wait("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)


@triton.jit
def kernel_sync_vector_to_cube():
    # Vector signals flag 1; cube waits on it.
    with al.scope(core_mode="vector"):
        al.sync_block_set("vector", "cube", 1, al.PIPE.PIPE_V, al.PIPE.PIPE_FIX)
    with al.scope(core_mode="cube"):
        al.sync_block_wait("vector", "cube", 1, al.PIPE.PIPE_V, al.PIPE.PIPE_FIX)


@triton.jit
def kernel_sync_multi_id():
    # Two independent flags (0 and 1) from cube to vector.
    with al.scope(core_mode="cube"):
        al.sync_block_set("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)
        al.sync_block_set("cube", "vector", 1, al.PIPE.PIPE_MTE2, al.PIPE.PIPE_MTE3)
    with al.scope(core_mode="vector"):
        al.sync_block_wait("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)
        al.sync_block_wait("cube", "vector", 1, al.PIPE.PIPE_MTE2, al.PIPE.PIPE_MTE3)


def main():
    """Compile every example kernel and print its MLIR."""
    cases = (
        ("Test 1: Cube -> Vector Sync", kernel_sync_cube_to_vector),
        ("Test 2: Vector -> Cube Sync", kernel_sync_vector_to_cube),
        ("Test 3: Multi-ID Sync", kernel_sync_multi_id),
    )
    separator = "=" * 60
    for index, (title, kernel) in enumerate(cases):
        # Every banner after the first is preceded by a blank line.
        print(("\n" if index else "") + separator)
        print(title)
        print(separator)
        mlir = compile_kernel(kernel, {}, {})
        print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
        print(mlir)


if __name__ == "__main__":
    main()
输出(其中 loc 的文件路径与行号、MLIR 字符数均取决于运行环境,仅供参考):
| Plain Text ============================================================ Test 1: Cube -> Vector Sync ============================================================ ✅ Generated MLIR (1275 chars): module { tt.func public @kernel_sync_cube_to_vector() attributes {noinline = false} { scope.scope : () -> () { %c0_i32 = arith.constant 0 : i32 loc(#loc2) %0 = arith.extui %c0_i32 : i32 to i64 loc(#loc2) hivm.hir.sync_block_set[<CUBE>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc2) scope.return loc(#loc2) } {hivm.tcore_type = #hivm.tcore_type<CUBE>, noinline} loc(#loc1) scope.scope : () -> () { %c0_i32 = arith.constant 0 : i32 loc(#loc4) %0 = arith.extui %c0_i32 : i32 to i64 loc(#loc4) hivm.hir.sync_block_wait[<VECTOR>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc4) scope.return loc(#loc4) } {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc3) tt.return loc(#loc5) } loc(#loc) } loc(#loc) #loc = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":36:0) #loc1 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":37:9) #loc2 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":38:66) #loc3 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":39:9) #loc4 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":40:67) #loc5 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":39:4) ============================================================ Test 2: Vector -> Cube Sync ============================================================ ✅ Generated MLIR (1267 chars): module { tt.func public @kernel_sync_vector_to_cube() attributes {noinline = false} { scope.scope : () -> () { %c1_i32 = arith.constant 1 : i32 loc(#loc2) %0 = arith.extui %c1_i32 : i32 to i64 loc(#loc2) hivm.hir.sync_block_set[<VECTOR>, <PIPE_V>, <PIPE_FIX>] flag = %0 loc(#loc2) scope.return loc(#loc2) } {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc1) scope.scope : () -> () { %c1_i32 = 
arith.constant 1 : i32 loc(#loc4) %0 = arith.extui %c1_i32 : i32 to i64 loc(#loc4) hivm.hir.sync_block_wait[<CUBE>, <PIPE_V>, <PIPE_FIX>] flag = %0 loc(#loc4) scope.return loc(#loc4) } {hivm.tcore_type = #hivm.tcore_type<CUBE>, noinline} loc(#loc3) tt.return loc(#loc5) } loc(#loc) } loc(#loc) #loc = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":44:0) #loc1 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":45:9) #loc2 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":46:63) #loc3 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":47:9) #loc4 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":48:64) #loc5 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":47:4) ============================================================ Test 3: Multi-ID Sync ============================================================ ✅ Generated MLIR (1818 chars): module { tt.func public @kernel_sync_multi_id() attributes {noinline = false} { scope.scope : () -> () { %c0_i32 = arith.constant 0 : i32 loc(#loc2) %0 = arith.extui %c0_i32 : i32 to i64 loc(#loc2) hivm.hir.sync_block_set[<CUBE>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc2) %c1_i32 = arith.constant 1 : i32 loc(#loc3) %1 = arith.extui %c1_i32 : i32 to i64 loc(#loc3) hivm.hir.sync_block_set[<CUBE>, <PIPE_MTE2>, <PIPE_MTE3>] flag = %1 loc(#loc3) scope.return loc(#loc3) } {hivm.tcore_type = #hivm.tcore_type<CUBE>, noinline} loc(#loc1) scope.scope : () -> () { %c0_i32 = arith.constant 0 : i32 loc(#loc5) %0 = arith.extui %c0_i32 : i32 to i64 loc(#loc5) hivm.hir.sync_block_wait[<VECTOR>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc5) %c1_i32 = arith.constant 1 : i32 loc(#loc6) %1 = arith.extui %c1_i32 : i32 to i64 loc(#loc6) hivm.hir.sync_block_wait[<VECTOR>, <PIPE_MTE2>, <PIPE_MTE3>] flag = %1 loc(#loc6) scope.return loc(#loc6) } {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc4) tt.return loc(#loc7) } 
loc(#loc) } loc(#loc) #loc = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":52:0) #loc1 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":53:9) #loc2 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":54:66) #loc3 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":55:66) #loc4 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":56:9) #loc5 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":57:67) #loc6 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":58:67) #loc7 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":56:4) |