sub_vec_id

1. 硬件背景

昇腾硬件中AIC(Cube核)与AIV(Vector核)的核数配比为1:N。Triton编程抽象屏蔽了Cube核与Vector核的硬件细节,因此默认情况下,Triton算子开发者无法控制数据如何切分到N个Vector核上并行处理,该切分由编译器通过AutoSubTiling Pass自动完成。

sub_vec_id编程接口返回当前Vector核在N个Vector核中的sub id,允许算子开发者根据该sub id决定每个Vector核处理哪些数据。

2. 接口说明

Python
def sub_vec_id() -> i64
  • 返回值:当前Vector核的Sub Vector ID,取值范围为 [0, N)(此处N为一个AIC对应的AIV核数,与下文用例中的常量N无关),算子开发者可根据该ID决定N个并行Vector核中每个核处理的数据分片

  • 入参:无

3. 约束说明

仅在AIC和AIV核混合使用场景中有效,不可在纯Cube类算子或者纯Vector类算子中使用,否则会触发编译报错。

4. 用例示例

Python

import os

# PyTorch reads TORCH_DEVICE_BACKEND_AUTOLOAD when `torch` is first imported,
# so it has to be set *before* importing triton: the triton / Ascend backend
# modules below may pull in torch transitively (TODO confirm for this build),
# in which case setting it afterwards would have no effect.
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"

import triton  # noqa: E402
import triton.language as tl  # noqa: E402
import triton.language.extra.cann.extension as al  # noqa: E402
from triton.compiler.compiler import ASTSource  # noqa: E402
from triton.compiler.code_generator import ast_to_ttir  # noqa: E402
from triton._C.libtriton import ir, buffer_ir  # noqa: E402
from triton._C.libtriton.ascend import ir as ascend_ir  # noqa: E402

class Options:
    """Fixed compiler option set handed to ``ast_to_ttir``.

    All settings are plain class attributes, so a bare ``Options()`` instance
    exposes them without any constructor arguments.
    """

    # Parallelism settings.
    num_warps: int = 4
    num_stages: int = 3
    num_ctas: int = 1
    cluster_dims: tuple = (1, 1, 1)

    # Code-generation flags.
    enable_fp_fusion: bool = True
    debug: bool = False

    # Target SoC identifier.
    arch: str = "Ascend910_95"

def compile_kernel(kernel, signature, constants):
    """Lower a ``@triton.jit`` kernel to Triton IR and return it as text.

    Args:
        kernel: The JIT kernel function to lower.
        signature: Mapping of non-constexpr argument names to Triton type
            strings, e.g. ``{"out_ptr": "*i32"}``.
        constants: Mapping of constexpr argument names to their values,
            e.g. ``{"N": 8}``.

    Returns:
        The string form of the generated MLIR module.
    """
    source = ASTSource(kernel, signature, constants)

    mlir_context = ir.context()
    # Register the core Triton dialects first, then the buffer and Ascend
    # extension dialects, all into the same context.
    for dialect_lib in (ir, buffer_ir, ascend_ir):
        dialect_lib.load_dialects(mlir_context)

    # The two trailing empty dicts are presumably codegen_fns / module_map in
    # ast_to_ttir's signature (TODO confirm against this triton version).
    ttir_module = ast_to_ttir(kernel, source, mlir_context, Options(), {}, {})
    return str(ttir_module)

# Demo kernel for al.sub_vec_id(): every sub vector core writes its own sub id
# into its own slice of the output buffer.
#
# Args:
#   out_ptr: pointer to int32 output. It must hold at least
#            (number of sub vector cores) * N elements -- TODO confirm the
#            sub core count for the target SoC.
#   N:       compile-time number of elements written per sub core. NOTE: this
#            is unrelated to the N in the AIC:AIV = 1:N ratio from the prose.
@triton.jit
def verify_sub_vec_id_kernel(
    out_ptr,
    N: tl.constexpr,
):
    # Vector-core scope; in the sample output this becomes a scope.scope
    # region tagged #hivm.tcore_type<VECTOR>.
    # NOTE: names bound inside this `with` body (sub_id, offs, out_ptrs) are
    # returned from the scope region (scope.return), so adding or removing
    # locals here changes the generated IR.
    with al.scope(core_mode="vector"):
        # Id of the current sub vector core; lowers to
        # hivm.hir.get_sub_block_idx, an i64 value in the generated IR.
        sub_id = al.sub_vec_id()

        # This sub core's contiguous slice: [sub_id * N, sub_id * N + N).
        offs = sub_id * N + tl.arange(0, N)
        out_ptrs = out_ptr + offs

        # Fill the slice with the sub id truncated to int32 (arith.trunci in
        # the IR), presumably so the host can check which core wrote where.
        tl.store(out_ptrs, sub_id.to(tl.int32))

def test_sub_vec_id_1to2():
    """Compile ``verify_sub_vec_id_kernel`` with N=8 and print its MLIR.

    Fails if the generated module does not contain
    ``hivm.hir.get_sub_block_idx``, the op that ``al.sub_vec_id()`` lowers to.
    Without that check a pytest run of this function could never fail. The
    assertion message includes the generated MLIR for debugging.
    """
    banner = "=" * 60
    print(banner)
    print("Test: Verify sub_vec_id (Simplified)")
    print(banner)

    mlir = compile_kernel(
        kernel=verify_sub_vec_id_kernel,
        signature={"out_ptr": "*i32"},
        constants={"N": 8},
    )

    # Check before printing the success marker, so "✅" only appears when
    # sub_vec_id was actually lowered.
    assert "hivm.hir.get_sub_block_idx" in mlir, (
        f"hivm.hir.get_sub_block_idx not found in generated MLIR:\n{mlir}"
    )

    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
    print(mlir)

# ---------------------------- Script entry point ----------------------------
# Run the demo only when this file is executed directly, never on import.
if __name__ == "__main__":
    test_sub_vec_id_1to2()

输出:

Plain Text
============================================================

Test: Verify sub_vec_id (Simplified)

============================================================

✅ Generated MLIR (1893 chars):

#loc = loc("/home/linxin/triton-test/sub_vec_id.py":35:0)

module attributes {hivm.disable_auto_tile_and_bind_subblock} {

tt.func public @verify_sub_vec_id_kernel(%arg0: !tt.ptr<i32> loc("/home/linxin/triton-test/sub_vec_id.py":35:0)) attributes {noinline = false} {

%0:3 = scope.scope : () -> (i64, tensor<8xi64>, tensor<8x!tt.ptr<i32>>) {

%1 = hivm.hir.get_sub_block_idx -> i64 loc(#loc2)

%c8_i32 = arith.constant 8 : i32 loc(#loc3)

%c8_i64 = arith.constant 8 : i64 loc(#loc3)

%2 = arith.muli %1, %c8_i64 : i64 loc(#loc3)

%3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4)

%4 = arith.extsi %3 : tensor<8xi32> to tensor<8xi64> loc(#loc5)

%5 = tt.splat %2 : i64 -> tensor<8xi64> loc(#loc5)

%6 = arith.addi %5, %4 : tensor<8xi64> loc(#loc5)

%7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<8x!tt.ptr<i32>> loc(#loc6)

%8 = tt.addptr %7, %6 : tensor<8x!tt.ptr<i32>>, tensor<8xi64> loc(#loc6)

%9 = arith.trunci %1 : i64 to i32 loc(#loc7)

%10 = tt.splat %9 : i32 -> tensor<8xi32> loc(#loc8)

tt.store %8, %10 : tensor<8x!tt.ptr<i32>> loc(#loc8)

scope.return %1, %6, %8 : i64, tensor<8xi64>, tensor<8x!tt.ptr<i32>> loc(#loc8)

} {hivm.tcore_type = #hivm.tcore_type<VECTOR>, noinline} loc(#loc1)

tt.return loc(#loc9)

} loc(#loc)

} loc(#loc)

#loc1 = loc("/home/linxin/triton-test/sub_vec_id.py":39:9)

#loc2 = loc("/home/linxin/triton-test/sub_vec_id.py":40:17)

#loc3 = loc("/home/linxin/triton-test/sub_vec_id.py":42:24)

#loc4 = loc("/home/linxin/triton-test/sub_vec_id.py":42:41)

#loc5 = loc("/home/linxin/triton-test/sub_vec_id.py":42:28)

#loc6 = loc("/home/linxin/triton-test/sub_vec_id.py":43:29)

#loc7 = loc("/home/linxin/triton-test/sub_vec_id.py":45:37)

#loc8 = loc("/home/linxin/triton-test/sub_vec_id.py":45:27)

#loc9 = loc("/home/linxin/triton-test/sub_vec_id.py":39:4)