1. 硬件背景
硬件支持在 VF 内手动插入同步(barrier),用于约束向量/标量 load/store 指令之间的执行顺序(具体类型见 2.1)。
2. 接口说明
class SYNC_IN_VF(enum.Enum):
    """Barrier kinds accepted by ``debug_barrier``.

    Member names read ``<waited-on>_<blocked>``: instructions of the class named
    after the underscore are blocked until all instructions of the class named
    before it have completed. ``V`` = vector, ``S`` = scalar, ``LD``/``ST`` =
    load/store; in the ``*_ALL`` members the first letter is the waited-on side,
    the second the blocked side, and both loads and stores are covered.
    """

    # vector load/store -> vector load/store
    VV_ALL = enum.auto()   # block vector ld/st until all vector ld/st complete
    VST_VLD = enum.auto()  # block vector loads until all vector stores complete
    VLD_VST = enum.auto()  # block vector stores until all vector loads complete
    VST_VST = enum.auto()  # block vector stores until all vector stores complete
    # vector load/store -> scalar load/store
    VS_ALL = enum.auto()   # block scalar ld/st until all vector ld/st complete
    VST_LD = enum.auto()   # block scalar loads until all vector stores complete
    VLD_ST = enum.auto()   # block scalar stores until all vector loads complete
    VST_ST = enum.auto()   # block scalar stores until all vector stores complete
    # scalar load/store -> vector load/store
    SV_ALL = enum.auto()   # block vector ld/st until all scalar ld/st complete
    ST_VLD = enum.auto()   # block vector loads until all scalar stores complete
    LD_VST = enum.auto()   # block vector stores until all scalar loads complete
    ST_VST = enum.auto()   # block vector stores until all scalar stores complete


@builtin
def debug_barrier(
    sync_mode: SYNC_IN_VF,
    _builder=None,
) -> None:
    """Insert a VF-level synchronization barrier of the requested kind.

    Per this document, only valid inside a scope; that restriction is not
    currently enforced.

    Args:
        sync_mode: which instruction class is blocked and which one is waited
            on. Only the member's ``name`` string (e.g. ``"VST_VLD"``) is
            forwarded to the semantic layer.
        _builder: IR builder; presumably injected by the ``@builtin``/JIT
            machinery rather than passed by users — confirm.
    """
    # NOTE(review): annotated ``-> None`` yet returns whatever
    # semantic.debug_barrier produces; confirm which one is intended.
    mode_name = sync_mode.name
    return semantic.debug_barrier(mode_name, _builder)
2.1 入参
- `sync_mode`:指定 barrier 的类型,取值为 `al.SYNC_IN_VF` 枚举的成员(见下表)。
| 类型 | 说明 |
| --- | --- |
| VV_ALL | blocks the execution of vector load/store instructions until all the vector load/store instructions have been completed. |
| VST_VLD | blocks the execution of vector load instructions until all the vector store instructions have been completed. |
| VLD_VST | blocks the execution of vector store instructions until all the vector load instructions have been completed. |
| VST_VST | blocks the execution of vector store instructions until all the vector store instructions have been completed. |
| VS_ALL | blocks the execution of scalar load/store instructions until all the vector load/store instructions have been completed. |
| VST_LD | blocks the execution of scalar load instructions until all the vector store instructions have been completed. |
| VLD_ST | blocks the execution of scalar store instructions until all the vector load instructions have been completed. |
| VST_ST | blocks the execution of scalar store instructions until all the vector store instructions have been completed. |
| SV_ALL | blocks the execution of vector load/store instructions until all the scalar load/store instructions have been completed. |
| ST_VLD | blocks the execution of vector load instructions until all the scalar store instructions have been completed. |
| LD_VST | blocks the execution of vector store instructions until all the scalar load instructions have been completed. |
| ST_VST | blocks the execution of vector store instructions until all the scalar store instructions have been completed. |
3. 约束
- 仅可在 scope 内使用。目前编译器尚未对 scope 外的调用进行拦截(不会报错),需由使用者自行保证。
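4. 用例说明

说明:以下用例调用的是 Triton 通用接口 `tl.debug_barrier()`,未传入 `sync_mode`;它在生成的 TTIR 中对应 `gpu.barrier`(见输出中的 `gpu.barrier loc(#loc16)`)。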
"""Compile a small kernel that calls ``tl.debug_barrier()`` and print its TTIR.

Run directly with ``python debug_barrier.py`` or collect ``test_debug_barrier``
with pytest.
"""
import os

# Set this before the triton imports below: torch consults
# TORCH_DEVICE_BACKEND_AUTOLOAD when it is itself imported, so assigning it
# after an import that (transitively) loads torch would have no effect.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton  # noqa: E402
import triton.language as tl  # noqa: E402
import triton.extension.buffer.language as bl  # noqa: E402,F401
import triton.language.extra.cann.extension as al  # noqa: E402,F401
from triton.compiler.compiler import ASTSource  # noqa: E402
from triton.compiler.code_generator import ast_to_ttir  # noqa: E402
from triton._C.libtriton import ir, buffer_ir  # noqa: E402
from triton._C.libtriton.ascend import ir as ascend_ir  # noqa: E402


class Options:
    """Minimal compile-options object handed to ``ast_to_ttir``.

    Plain class attributes only. Which of these fields the Ascend frontend
    actually reads is not visible here — NOTE(review): confirm against the backend.
    """

    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False
    arch = "Ascend910_95"


def compile_kernel(kernel, signature, constants):
    """Helper to compile a kernel to MLIR.

    Args:
        kernel: the ``@triton.jit`` function to lower.
        signature: maps each runtime argument name to its Triton type string
            (e.g. ``"*fp32"``).
        constants: values for the kernel's ``tl.constexpr`` parameters.

    Returns:
        The module produced by ``ast_to_ttir``, rendered as text.
    """
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    # Load the core Triton, buffer-extension and Ascend dialects before lowering.
    ir.load_dialects(context)
    buffer_ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)


@triton.jit
def triton_sub(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    """Compute ``out_ptr0[i] = in_ptr0[i] - in_ptr1[i]`` for one XBLOCK-wide tile.

    The tile owned by this program is walked in ceil(XBLOCK / XBLOCK_SUB)
    chunks of XBLOCK_SUB elements, with ``tl.debug_barrier()`` placed between
    the subtraction and the store. Accesses are unmasked, so this presumably
    assumes XBLOCK is a multiple of XBLOCK_SUB and that the buffers cover every
    program's tile — TODO confirm before reusing with other sizes.
    """
    offset = tl.program_id(0) * XBLOCK
    base1 = tl.arange(0, XBLOCK_SUB)
    # ceil(XBLOCK / XBLOCK_SUB); 2 for the sizes used in test_debug_barrier.
    loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
    for loop1 in range(loops1):
        # NOTE(review): x0_prime is never used; it only adds a dead copy of the
        # index computation to the generated TTIR.
        x0_prime = offset + (loop1 * XBLOCK_SUB) + base1
        x0 = offset + (loop1 * XBLOCK_SUB) + base1
        tmp0 = tl.load(in_ptr0 + (x0), None)
        tmp1 = tl.load(in_ptr1 + (x0), None)
        tmp2 = tmp0 - tmp1
        # Generic Triton barrier (no sync_mode); lowers to `gpu.barrier` in TTIR.
        tl.debug_barrier()
        tl.store(out_ptr0 + (x0), tmp2, None)


def test_debug_barrier():
    """Compile ``triton_sub`` and check that its barrier reaches the TTIR."""
    print("=" * 60)
    print("Test 1: debug_barrier ")
    print("=" * 60)
    mlir = compile_kernel(
        triton_sub,
        {"in_ptr0": "*fp32", "in_ptr1": "*fp32", "out_ptr0": "*fp32"},
        {"XBLOCK": 16, "XBLOCK_SUB": 8},
    )
    # Without this check the test could never fail, and the success marker
    # below would be printed even if the barrier were dropped during lowering.
    assert "gpu.barrier" in mlir, f"gpu.barrier not found in generated TTIR:\n{mlir}"
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)


# ============== Main for manual testing ==============
if __name__ == "__main__":
    test_debug_barrier()
输出(`ast_to_ttir` 生成的 TTIR):
```mlir
module {
  tt.func public @triton_sub(%arg0: !tt.ptr<f32> loc("/home/linxin/triton-test/debug_barrier.py":35:0), %arg1: !tt.ptr<f32> loc("/home/linxin/triton-test/debug_barrier.py":35:0), %arg2: !tt.ptr<f32> loc("/home/linxin/triton-test/debug_barrier.py":35:0)) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32 loc(#loc1)
    %c16_i32 = arith.constant 16 : i32 loc(#loc2)
    %c16_i32_0 = arith.constant 16 : i32 loc(#loc2)
    %1 = arith.muli %0, %c16_i32_0 : i32 loc(#loc2)
    %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc3)
    %c0_i32 = arith.constant 0 : i32 loc(#loc4)
    %c2_i32 = arith.constant 2 : i32 loc(#loc4)
    %c1_i32 = arith.constant 1 : i32 loc(#loc4)
    %3 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc4)
    %4 = arith.bitcast %c2_i32 : i32 to i32 loc(#loc4)
    %5 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc4)
    %6 = ub.poison : i32 loc(#loc4)
    scf.for %arg3 = %3 to %4 step %5 : i32 {
      %c8_i32 = arith.constant 8 : i32 loc(#loc5)
      %c8_i32_1 = arith.constant 8 : i32 loc(#loc5)
      %7 = arith.muli %arg3, %c8_i32_1 : i32 loc(#loc5)
      %8 = arith.addi %1, %7 : i32 loc(#loc6)
      %9 = tt.splat %8 : i32 -> tensor<8xi32> loc(#loc7)
      %10 = arith.addi %9, %2 : tensor<8xi32> loc(#loc7)
      %c8_i32_2 = arith.constant 8 : i32 loc(#loc8)
      %c8_i32_3 = arith.constant 8 : i32 loc(#loc8)
      %11 = arith.muli %arg3, %c8_i32_3 : i32 loc(#loc8)
      %12 = arith.addi %1, %11 : i32 loc(#loc9)
      %13 = tt.splat %12 : i32 -> tensor<8xi32> loc(#loc10)
      %14 = arith.addi %13, %2 : tensor<8xi32> loc(#loc10)
      %15 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>> loc(#loc11)
      %16 = tt.addptr %15, %14 : tensor<8x!tt.ptr<f32>>, tensor<8xi32> loc(#loc11)
      %17 = tt.load %16 : tensor<8x!tt.ptr<f32>> loc(#loc12)
      %18 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>> loc(#loc13)
      %19 = tt.addptr %18, %14 : tensor<8x!tt.ptr<f32>>, tensor<8xi32> loc(#loc13)
      %20 = tt.load %19 : tensor<8x!tt.ptr<f32>> loc(#loc14)
      %21 = arith.subf %17, %20 : tensor<8xf32> loc(#loc15)
      gpu.barrier loc(#loc16)
      %22 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>> loc(#loc17)
      %23 = tt.addptr %22, %14 : tensor<8x!tt.ptr<f32>>, tensor<8xi32> loc(#loc17)
      tt.store %23, %21 : tensor<8x!tt.ptr<f32>> loc(#loc18)
    } loc(#loc4)
    tt.return loc(#loc19)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home/linxin/triton-test/debug_barrier.py":36:27)
#loc2 = loc("/home/linxin/triton-test/debug_barrier.py":36:32)
#loc3 = loc("/home/linxin/triton-test/debug_barrier.py":37:25)
#loc4 = loc("/home/linxin/triton-test/debug_barrier.py":39:23)
#loc5 = loc("/home/linxin/triton-test/debug_barrier.py":40:37)
#loc6 = loc("/home/linxin/triton-test/debug_barrier.py":40:29)
#loc7 = loc("/home/linxin/triton-test/debug_barrier.py":40:51)
#loc8 = loc("/home/linxin/triton-test/debug_barrier.py":41:31)
#loc9 = loc("/home/linxin/triton-test/debug_barrier.py":41:23)
#loc10 = loc("/home/linxin/triton-test/debug_barrier.py":41:45)
#loc11 = loc("/home/linxin/triton-test/debug_barrier.py":42:34)
#loc12 = loc("/home/linxin/triton-test/debug_barrier.py":42:39)
#loc13 = loc("/home/linxin/triton-test/debug_barrier.py":43:34)
#loc14 = loc("/home/linxin/triton-test/debug_barrier.py":43:39)
#loc15 = loc("/home/linxin/triton-test/debug_barrier.py":44:22)
#loc16 = loc("/home/linxin/triton-test/debug_barrier.py":45:8)
#loc17 = loc("/home/linxin/triton-test/debug_barrier.py":46:29)
#loc18 = loc("/home/linxin/triton-test/debug_barrier.py":46:40)
#loc19 = loc("/home/linxin/triton-test/debug_barrier.py":39:4)
```