al.sync_block_wait 接口文档

1. 硬件背景

面向分离模式的核间同步控制接口。

该接口需与 sync_block_set 接口配合使用。调用时需传入核间同步的标记ID(flagId,即入参 event_id),每个ID对应一个初始值为0的计数器:执行 CrossCoreSetFlag 后,该ID对应的计数器加1;执行 CrossCoreWaitFlag 时,若对应的计数器为0,则阻塞等待;若对应的计数器大于0,则计数器减1,随后后续指令开始执行。

2. 接口说明

Python
def sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None):

class PIPE(enum.Enum):
    # Each member forwards to the same-named value of ascend_ir.PIPE.
    PIPE_S = ascend_ir.PIPE.PIPE_S  # scalar pipeline (used by Tensor GetValue)
    PIPE_V = ascend_ir.PIPE.PIPE_V  # vector compute and L0C->UB data movement
    PIPE_M = ascend_ir.PIPE.PIPE_M  # matrix compute
    PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1  # L1->L0A, L1->L0B data movement
    PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2  # GM->L1, GM->L0A, GM->L0B, GM->UB data movement
    PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3  # UB->GM, UB->L1 data movement
    PIPE_ALL = ascend_ir.PIPE.PIPE_ALL  # all pipelines
    PIPE_FIX = ascend_ir.PIPE.PIPE_FIX  # L0C->GM, L0C->L1 data movement

返回值

无返回值

3. 入参说明

参数名 类型 说明
sender str 发送端,仅支持 "cube" / "vector"
receiver str 接收端,仅支持 "cube" / "vector"
event_id int 同步标记ID,取值范围 [0,15]
sender_pipe al.PIPE 发送端流水线类型
receiver_pipe al.PIPE 接收端流水线类型
_builder - 由JIT编译器自动传入,调用时无需显式传参

PIPE 枚举说明

流水类型 含义
PIPE_S 标量流水线,使用Tensor GetValue函数时为此流水
PIPE_V 矢量计算流水及L0C->UB数据搬运流水
PIPE_M 矩阵计算流水
PIPE_MTE1 L1->L0A、L1->L0B数据搬运流水
PIPE_MTE2 GM->L1、GM->L0A、GM->L0B、GM->UB数据搬运流水
PIPE_MTE3 UB->GM、UB->L1数据搬运流水
PIPE_ALL 所有流水
PIPE_FIX L0C->GM、L0C->L1数据搬运流水

4. 约束说明

  • sender 与 receiver 不能相同(sender != receiver)

5. 用例示例

Python
import os

os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import pytest

import triton

import triton.language as tl

import triton.language.extra.cann.extension as al

from triton.compiler.compiler import ASTSource

from triton.compiler.code_generator import ast_to_ttir

from triton._C.libtriton import ir, buffer_ir

from triton._C.libtriton.ascend import ir as ascend_ir

class Options:
    """Compile options object handed to ``ast_to_ttir`` by ``compile_kernel``.

    Only plain class attributes are defined; an instance carries no state of
    its own.
    """

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False
    # Target architecture string used for this example.
    arch = "Ascend910_95"

def compile_kernel(kernel, signature, constants):
    """Lower a ``@triton.jit`` kernel with ``ast_to_ttir`` and return the IR text.

    Args:
        kernel: The JIT-decorated kernel to lower.
        signature: Passed unchanged to ``ASTSource``.
        constants: Passed unchanged to ``ASTSource``.

    Returns:
        ``str(module)`` of the module produced by ``ast_to_ttir``.
    """
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    # Load the core, buffer and Ascend dialects into the context before lowering.
    ir.load_dialects(context)
    buffer_ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)

# Cube -> vector: the cube scope sets flag 0 and the vector scope waits on the
# same flag, using the same sender/receiver pipes (PIPE_MTE1 -> PIPE_MTE3).
@triton.jit
def kernel_sync_cube_to_vector():
    with al.scope(core_mode="cube"):
        al.sync_block_set("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)
    with al.scope(core_mode="vector"):
        al.sync_block_wait("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)

# Vector -> cube: the vector scope sets flag 1 and the cube scope waits on the
# same flag, using the same sender/receiver pipes (PIPE_V -> PIPE_FIX).
@triton.jit
def kernel_sync_vector_to_cube():
    with al.scope(core_mode="vector"):
        al.sync_block_set("vector", "cube", 1, al.PIPE.PIPE_V, al.PIPE.PIPE_FIX)
    with al.scope(core_mode="cube"):
        al.sync_block_wait("vector", "cube", 1, al.PIPE.PIPE_V, al.PIPE.PIPE_FIX)

# Multiple flags: the cube scope sets flags 0 and 1 (with different sender
# pipes) and the vector scope waits on both, in the same order.
@triton.jit
def kernel_sync_multi_id():
    with al.scope(core_mode="cube"):
        al.sync_block_set("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)
        al.sync_block_set("cube", "vector", 1, al.PIPE.PIPE_MTE2, al.PIPE.PIPE_MTE3)
    with al.scope(core_mode="vector"):
        al.sync_block_wait("cube", "vector", 0, al.PIPE.PIPE_MTE1, al.PIPE.PIPE_MTE3)
        al.sync_block_wait("cube", "vector", 1, al.PIPE.PIPE_MTE2, al.PIPE.PIPE_MTE3)

if __name__ == "__main__":
    # (title, kernel) pairs, lowered and printed in this order.
    cases = [
        ("Test 1: Cube -> Vector Sync", kernel_sync_cube_to_vector),
        ("Test 2: Vector -> Cube Sync", kernel_sync_vector_to_cube),
        ("Test 3: Multi-ID Sync", kernel_sync_multi_id),
    ]
    banner = "=" * 60
    for index, (title, kernel) in enumerate(cases):
        # Every section after the first is separated from the previous one by
        # a leading newline in front of its opening banner.
        print(banner if index == 0 else "\n" + banner)
        print(title)
        print(banner)
        mlir = compile_kernel(kernel, {}, {})
        print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
        print(mlir)

输出:

Plain Text
============================================================

Test 1: Cube -> Vector Sync

============================================================

✅ Generated MLIR (1275 chars):

module {

tt.func public @kernel_sync_cube_to_vector() attributes {noinline = false} {

scope.scope : () -> () {

%c0_i32 = arith.constant 0 : i32 loc(#loc2)

%0 = arith.extui %c0_i32 : i32 to i64 loc(#loc2)

hivm.hir.sync_block_set[<CUBE>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc2)

scope.return loc(#loc2)

} {hivm.tcore_type = #hivm.tcore_type<CUBE>, noinline} loc(#loc1)

scope.scope : () -> () {

%c0_i32 = arith.constant 0 : i32 loc(#loc4)

%0 = arith.extui %c0_i32 : i32 to i64 loc(#loc4)

hivm.hir.sync_block_wait[<VECTOR>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc4)

scope.return loc(#loc4)

} {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc3)

tt.return loc(#loc5)

} loc(#loc)

} loc(#loc)

#loc = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":36:0)

#loc1 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":37:9)

#loc2 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":38:66)

#loc3 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":39:9)

#loc4 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":40:67)

#loc5 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":39:4)

============================================================

Test 2: Vector -> Cube Sync

============================================================

✅ Generated MLIR (1267 chars):

module {

tt.func public @kernel_sync_vector_to_cube() attributes {noinline = false} {

scope.scope : () -> () {

%c1_i32 = arith.constant 1 : i32 loc(#loc2)

%0 = arith.extui %c1_i32 : i32 to i64 loc(#loc2)

hivm.hir.sync_block_set[<VECTOR>, <PIPE_V>, <PIPE_FIX>] flag = %0 loc(#loc2)

scope.return loc(#loc2)

} {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc1)

scope.scope : () -> () {

%c1_i32 = arith.constant 1 : i32 loc(#loc4)

%0 = arith.extui %c1_i32 : i32 to i64 loc(#loc4)

hivm.hir.sync_block_wait[<CUBE>, <PIPE_V>, <PIPE_FIX>] flag = %0 loc(#loc4)

scope.return loc(#loc4)

} {hivm.tcore_type = #hivm.tcore_type<CUBE>, noinline} loc(#loc3)

tt.return loc(#loc5)

} loc(#loc)

} loc(#loc)

#loc = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":44:0)

#loc1 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":45:9)

#loc2 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":46:63)

#loc3 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":47:9)

#loc4 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":48:64)

#loc5 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":47:4)

============================================================

Test 3: Multi-ID Sync

============================================================

✅ Generated MLIR (1818 chars):

module {

tt.func public @kernel_sync_multi_id() attributes {noinline = false} {

scope.scope : () -> () {

%c0_i32 = arith.constant 0 : i32 loc(#loc2)

%0 = arith.extui %c0_i32 : i32 to i64 loc(#loc2)

hivm.hir.sync_block_set[<CUBE>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc2)

%c1_i32 = arith.constant 1 : i32 loc(#loc3)

%1 = arith.extui %c1_i32 : i32 to i64 loc(#loc3)

hivm.hir.sync_block_set[<CUBE>, <PIPE_MTE2>, <PIPE_MTE3>] flag = %1 loc(#loc3)

scope.return loc(#loc3)

} {hivm.tcore_type = #hivm.tcore_type<CUBE>, noinline} loc(#loc1)

scope.scope : () -> () {

%c0_i32 = arith.constant 0 : i32 loc(#loc5)

%0 = arith.extui %c0_i32 : i32 to i64 loc(#loc5)

hivm.hir.sync_block_wait[<VECTOR>, <PIPE_MTE1>, <PIPE_MTE3>] flag = %0 loc(#loc5)

%c1_i32 = arith.constant 1 : i32 loc(#loc6)

%1 = arith.extui %c1_i32 : i32 to i64 loc(#loc6)

hivm.hir.sync_block_wait[<VECTOR>, <PIPE_MTE2>, <PIPE_MTE3>] flag = %1 loc(#loc6)

scope.return loc(#loc6)

} {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc4)

tt.return loc(#loc7)

} loc(#loc)

} loc(#loc)

#loc = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":52:0)

#loc1 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":53:9)

#loc2 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":54:66)

#loc3 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":55:66)

#loc4 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":56:9)

#loc5 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":57:67)

#loc6 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":58:67)

#loc7 = loc("/home/ganpengfei/workspace/triton-test/sync_block_set_wait.py":56:4)