sub_vec_id
1. 硬件背景
昇腾硬件上AIC与AIV的核数配比并非1:1,而是1:N。Triton编程抽象屏蔽了Cube核与Vector核的硬件细节,因此Triton算子开发者无法直接控制数据如何切分到N个Vector核上并行处理,该切分由编译器通过AutoSubTiling Pass自动完成。
sub_vec_id编程接口返回当前Vector核在N个Vector核中的sub id,算子开发者可以根据该sub id决定每个Vector核处理哪些数据。
2. 接口说明
```python
def sub_vec_id() -> i64
```
- 返回值:返回范围为 [0, N) 的 Sub Vector ID,算子开发者可根据该ID决定N个并行Vector核中每个核处理的数据分片
- 入参:无
3. 约束说明
仅在AIC和AIV核混合使用场景中有效,不可在纯Cube类算子或者纯Vector类算子中使用,否则会触发编译报错。
4. 用例示例
```python
import os
import triton
import triton.language as tl
import triton.language.extra.cann.extension as al
from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir, buffer_ir
from triton._C.libtriton.ascend import ir as ascend_ir

os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"


class Options:
    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False
    arch = "Ascend910_95"


def compile_kernel(kernel, signature, constants):
    """Helper to compile a kernel to MLIR."""
    src = ASTSource(kernel, signature, constants)
    context = ir.context()
    ir.load_dialects(context)
    buffer_ir.load_dialects(context)
    ascend_ir.load_dialects(context)
    module = ast_to_ttir(kernel, src, context, Options(), {}, {})
    return str(module)


@triton.jit
def verify_sub_vec_id_kernel(
    out_ptr,
    N: tl.constexpr,
):
    with al.scope(core_mode="vector"):
        sub_id = al.sub_vec_id()

        offs = sub_id * N + tl.arange(0, N)
        out_ptrs = out_ptr + offs

        tl.store(out_ptrs, sub_id.to(tl.int32))


def test_sub_vec_id_1to2():
    print("=" * 60)
    print("Test: Verify sub_vec_id (Simplified)")
    print("=" * 60)
    mlir = compile_kernel(
        kernel=verify_sub_vec_id_kernel,
        signature={"out_ptr": "*i32"},
        constants={"N": 8},
    )
    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)


# ============== Main ==============
if __name__ == "__main__":
    test_sub_vec_id_1to2()
```
输出:
```text
============================================================
Test: Verify sub_vec_id (Simplified)
============================================================
✅ Generated MLIR (1893 chars):

#loc = loc("/home/linxin/triton-test/sub_vec_id.py":35:0)
module attributes {hivm.disable_auto_tile_and_bind_subblock} {
  tt.func public @verify_sub_vec_id_kernel(%arg0: !tt.ptr<i32> loc("/home/linxin/triton-test/sub_vec_id.py":35:0)) attributes {noinline = false} {
    %0:3 = scope.scope : () -> (i64, tensor<8xi64>, tensor<8x!tt.ptr<i32>>) {
      %1 = hivm.hir.get_sub_block_idx -> i64 loc(#loc2)
      %c8_i32 = arith.constant 8 : i32 loc(#loc3)
      %c8_i64 = arith.constant 8 : i64 loc(#loc3)
      %2 = arith.muli %1, %c8_i64 : i64 loc(#loc3)
      %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4)
      %4 = arith.extsi %3 : tensor<8xi32> to tensor<8xi64> loc(#loc5)
      %5 = tt.splat %2 : i64 -> tensor<8xi64> loc(#loc5)
      %6 = arith.addi %5, %4 : tensor<8xi64> loc(#loc5)
      %7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<8x!tt.ptr<i32>> loc(#loc6)
      %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr<i32>>, tensor<8xi64> loc(#loc6)
      %9 = arith.trunci %1 : i64 to i32 loc(#loc7)
      %10 = tt.splat %9 : i32 -> tensor<8xi32> loc(#loc8)
      tt.store %8, %10 : tensor<8x!tt.ptr<i32>> loc(#loc8)
      scope.return %1, %6, %8 : i64, tensor<8xi64>, tensor<8x!tt.ptr<i32>> loc(#loc8)
    } {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc1)
    tt.return loc(#loc9)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home/linxin/triton-test/sub_vec_id.py":39:9)
#loc2 = loc("/home/linxin/triton-test/sub_vec_id.py":40:17)
#loc3 = loc("/home/linxin/triton-test/sub_vec_id.py":42:24)
#loc4 = loc("/home/linxin/triton-test/sub_vec_id.py":42:41)
#loc5 = loc("/home/linxin/triton-test/sub_vec_id.py":42:28)
#loc6 = loc("/home/linxin/triton-test/sub_vec_id.py":43:29)
#loc7 = loc("/home/linxin/triton-test/sub_vec_id.py":45:37)
#loc8 = loc("/home/linxin/triton-test/sub_vec_id.py":45:27)
#loc9 = loc("/home/linxin/triton-test/sub_vec_id.py":39:4)
```