1. 硬件背景
当多个核操作同一块全局内存，且可能存在写后读（RAW）、读后写（WAR）、写后写（WAW）等数据依赖时，可调用该函数插入核间同步语句，以避免因这些数据依赖导致的数据读写错误。
2. 接口说明
```python
def sync_block_all(mode, event_id, _builder=None):
```
2.1 入参
| 参数名 | 类型 | 必需 | 说明 |
| --- | --- | --- | --- |
| mode | str | 是 | 同步模式，可选值：all_cube / all_vector / all / all_sub_vector。<br>all_cube：同步所有 Cube 核<br>all_vector：同步所有 Vector 核<br>all：同步所有 Cube 核和 Vector 核<br>all_sub_vector：Vector 子块间同步 |
| event_id | int | 是 | 同步事件 ID，取值范围为 [0, 15]。 |
2.2 返回值
无
3. 约束
- mode 的可选值为：all_cube、all_vector、all、all_sub_vector。
- event_id 的取值范围为 [0, 15]。
4. 示例
"""Example: emit ``sync_block_all`` barriers from a Triton kernel and print the MLIR."""

import os

# TORCH_DEVICE_BACKEND_AUTOLOAD is read when torch is first imported, so it
# must be set *before* importing triton / the Ascend extension, which may
# import torch indirectly. Setting it afterwards can be silently ignored.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton  # noqa: E402
import triton.language.extra.cann.extension as al  # noqa: E402
from triton.compiler.compiler import ASTSource  # noqa: E402
from triton.compiler.code_generator import ast_to_ttir  # noqa: E402
from triton._C.libtriton import ir  # noqa: E402
from triton._C.libtriton.ascend import ir as ascend_ir  # noqa: E402


class Options:
    """Minimal compile-options object handed to ``ast_to_ttir``."""

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False


def compile_kernel(kernel, signature, constants):
    """Lower a ``@triton.jit`` kernel to TTIR and return it as MLIR text.

    Args:
        kernel: The JIT-compiled kernel to lower.
        signature: Argument signature forwarded to ``ASTSource``.
        constants: Compile-time constants forwarded to ``ASTSource``.

    Returns:
        The textual MLIR of the generated module.
    """
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    # Register both the core Triton dialects and the Ascend (hivm) dialects
    # so the sync_block ops can be materialized.
    ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)


@triton.jit
def test_sync_block_all():
    # Each call lowers to one hivm.hir.sync_block op (see the output below).
    al.sync_block_all("all_cube", 8)  # cube cores only
    al.sync_block_all("all_vector", 9)  # vector cores only
    al.sync_block_all("all", 10)  # cube and vector cores
    al.sync_block_all("all_sub_vector", 11)  # between vector sub-blocks


if __name__ == "__main__":
    print("=" * 60)
    print("Test 1: test_sync_block_all")
    print("=" * 60)
    mlir = compile_kernel(test_sync_block_all, {}, {})
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)
输出（其中生成的 MLIR 如下）：
```mlir
module {
  tt.func public @test_sync_block_all() attributes {noinline = false} {
    hivm.hir.sync_block[<ALL_CUBE>, 8 : index] tcube_pipe = <PIPE_ALL> loc(#loc1)
    hivm.hir.sync_block[<ALL_VECTOR>, 9 : index] tvector_pipe = <PIPE_ALL> loc(#loc2)
    hivm.hir.sync_block[<ALL>, 10 : index] tcube_pipe = <PIPE_ALL> tvector_pipe = <PIPE_ALL> loc(#loc3)
    hivm.hir.sync_block[<ALL_SUB_VECTOR>, 11 : index] tvector_pipe = <PIPE_ALL> loc(#loc4)
    tt.return loc(#loc5)
  } loc(#loc)
} loc(#loc)
#loc = loc("/home/linxin/triton-test/sync_block_all.py":37:0)
#loc1 = loc("/home/linxin/triton-test/sync_block_all.py":38:34)
#loc2 = loc("/home/linxin/triton-test/sync_block_all.py":39:36)
#loc3 = loc("/home/linxin/triton-test/sync_block_all.py":40:29)
#loc4 = loc("/home/linxin/triton-test/sync_block_all.py":41:40)
#loc5 = loc("/home/linxin/triton-test/sync_block_all.py":41:4)
```