1. 硬件背景

当不同核之间操作同一块全局内存，且可能存在读后写（WAR）、写后读（RAW）以及写后写（WAW）等数据依赖时，可调用该函数插入同步语句，以避免上述数据依赖导致的数据读写错误。

2. 接口说明

Plain Text
def sync_block_all(mode, event_id, _builder=None):

2.1 入参

| 参数名 | 类型 | 必需 | 说明 |
| --- | --- | --- | --- |
| mode | str | 是 | 同步的模式，可选字符串：all_cube/all_vector/all/all_sub_vector。<br>all_cube：同步所有 cube 核<br>all_vector：同步所有 vector 核<br>all：同步所有 cube 核和 vector 核<br>all_sub_vector：Vector 子块间同步 |
| event_id | int | 是 | 标记 id，取值范围为 [0, 15]。 |

2.2 返回值

无。

3. 约束

  • mode 仅支持以下字符串：all_cube、all_vector、all、all_sub_vector。

  • event_id 的取值范围为 [0, 15]。

4. 示例

Plain Text
import os
import pytest
import triton
import triton.language as tl
import triton.language.extra.cann.extension as al
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir
from triton._C.libtriton.ascend import ir as ascend_ir

# Presumably disables PyTorch's autoloading of out-of-tree device backends
# (e.g. torch_npu) for this example run — TODO confirm.
# NOTE(review): this assignment executes after the triton imports above; if any
# of them has already imported torch, the variable is read too late to have an
# effect. Consider moving it above those imports.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"


class Options:
    """Fixed compile options handed to ``ast_to_ttir`` by ``compile_kernel``.

    Only class attributes are used; an instance is created per compilation.
    NOTE(review): the exact set of fields a given backend expects may differ —
    confirm against the backend's own options class.
    """

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False


def compile_kernel(kernel, signature, constants):
    """Compile a ``@triton.jit`` kernel to Triton IR and return it as text.

    Args:
        kernel: The JIT-decorated kernel to lower.
        signature: Argument signature forwarded to ``ASTSource``.
        constants: Compile-time constants forwarded to ``ASTSource``.

    Returns:
        The string form of the generated MLIR module.
    """
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    # Register the core Triton dialects and the Ascend-specific ones; the latter
    # are presumably what allow the hivm.* ops seen in the sample output.
    ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    # The two trailing empty dicts are passed straight through to ast_to_ttir
    # (NOTE(review): presumably codegen-function / module-map overrides).
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)

@triton.jit
def test_sync_block_all():
    # Exercise every supported mode once, each with a distinct event id
    # inside the documented range [0, 15].
    al.sync_block_all("all_cube", 8)  # sync all cube cores
    al.sync_block_all("all_vector", 9)  # sync all vector cores
    al.sync_block_all("all", 10)  # sync all cube and vector cores
    al.sync_block_all("all_sub_vector", 11)  # sync between vector sub-blocks

if __name__ == "__main__":
    # Lower the example kernel (no arguments, no constants) and dump its IR.
    print("=" * 60)
    print("Test 1: test_sync_block_all")
    print("=" * 60)
    mlir = compile_kernel(test_sync_block_all, {}, {})
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)

输出:

Plain Text
module {
tt.func public @test_sync_block_all() attributes {noinline = false} {
hivm.hir.sync_block[<ALL_CUBE>, 8 : index] tcube_pipe = <PIPE_ALL> loc(#loc1)
hivm.hir.sync_block[<ALL_VECTOR>, 9 : index] tvector_pipe = <PIPE_ALL> loc(#loc2)
hivm.hir.sync_block[<ALL>, 10 : index] tcube_pipe = <PIPE_ALL> tvector_pipe = <PIPE_ALL> loc(#loc3)
hivm.hir.sync_block[<ALL_SUB_VECTOR>, 11 : index] tvector_pipe = <PIPE_ALL> loc(#loc4)
tt.return loc(#loc5)
} loc(#loc)
} loc(#loc)
#loc = loc("/home/linxin/triton-test/sync_block_all.py":37:0)
#loc1 = loc("/home/linxin/triton-test/sync_block_all.py":38:34)
#loc2 = loc("/home/linxin/triton-test/sync_block_all.py":39:36)
#loc3 = loc("/home/linxin/triton-test/sync_block_all.py":40:29)
#loc4 = loc("/home/linxin/triton-test/sync_block_all.py":41:40)
#loc5 = loc("/home/linxin/triton-test/sync_block_all.py":41:4)