Best Practice Cases
Performance Optimization Cases
Tiling Strategy
Case Description
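Triton operators written for GPUs usually launch far more logical cores than there are physical cores when they are migrated to the NPU. This incurs heavy launch and scheduling overhead. When migrating, adjust the tiling strategy to reduce the number of launched cores, ideally so that the number of logical cores equals the number of physical cores. This case implements the following operation in Triton: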
out = torch.gather(x, dim=1, index=idx)
Inputs:
| Input | Shape |
|---|---|
| x | (B, C) |
| idx | (B, K) |
Output:
| Output | Shape |
|---|---|
| out | (B, K) |
Differences Explained
@triton.jit
def gather_dim1_kernel(
x_ptr, # *x [B, C]
idx_ptr, # *idx[B, K]
out_ptr, # *out[B, K]
stride_xb, stride_xc,
stride_ib, stride_ik,
stride_ob, stride_ok,
B, K,
BLOCK_B: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_b = tl.program_id(0) # GPU: one program per batch row; NPU: one program per BLOCK_B rows
- # GPU implementation
- pid_k = tl.program_id(1) # 1 block per K-tile
- k_off = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
- mask = k_off < K
- idx = tl.load(idx_ptr + pid_b * stride_ib + k_off * stride_ik, mask=mask) # [BLOCK_K]
- x_val = tl.load(x_ptr + pid_b * stride_xb + idx * stride_xc, mask=mask)
- tl.store(out_ptr + pid_b * stride_ob + k_off * stride_ok, x_val, mask=mask)
+ # NPU implementation
+ b_idx = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
+ b_mask = b_idx < B
+ # Loop over the K dimension inside the kernel
+ for k_start in range(0, K, BLOCK_K):
+ ks = tl.arange(0, BLOCK_K)
+ k_mask = ks < K - k_start
+ idx_off = (b_idx[:, None] * stride_ib +
+ (k_start + ks)[None, :] * stride_ik)
+ col_idx = tl.load(idx_ptr + idx_off, mask=b_mask[:, None] & k_mask)
+ x_off = (b_idx[:, None] * stride_xb +
+ col_idx * stride_xc)
+ x_val = tl.load(x_ptr + x_off, mask=b_mask[:, None] & k_mask)
+ out_off = (b_idx[:, None] * stride_ob +
+ (k_start + ks)[None, :] * stride_ok)
+ tl.store(out_ptr + out_off, x_val, mask=b_mask[:, None] & k_mask)
# Launch
B = 128 # batch dim
K = 64
BLOCK_B = 4
BLOCK_K = 128
- # GPU
- grid = (B, triton.cdiv(K, BLOCK_K))
+ # NPU
+ grid = (triton.cdiv(B, BLOCK_B),)
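The pure-Python sketch below makes the core-count reduction concrete. It counts the programs each grid launches for the sizes above. It also picks a BLOCK_B that fits a given core budget. The caller supplies the physical core count passed to pick_block_b, because this guide does not state one. Rounding BLOCK_B up to a power of two is an assumption based on tl.arange requiring power-of-two ranges.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring triton.cdiv."""
    return (a + b - 1) // b


def next_pow2(n: int) -> int:
    """Smallest power of two >= n (tl.arange needs power-of-two sizes)."""
    return 1 << (n - 1).bit_length() if n > 1 else 1


def gpu_grid(B: int, K: int, BLOCK_K: int) -> tuple:
    # GPU version: one program per batch row and per K tile
    return (B, cdiv(K, BLOCK_K))


def npu_grid(B: int, BLOCK_B: int) -> tuple:
    # NPU version: one program per BLOCK_B rows; K is looped inside the kernel
    return (cdiv(B, BLOCK_B),)


def num_programs(grid: tuple) -> int:
    total = 1
    for g in grid:
        total *= g
    return total


def pick_block_b(B: int, num_physical_cores: int) -> int:
    """Hypothetical helper: smallest power-of-two BLOCK_B whose grid fits the core budget."""
    return next_pow2(cdiv(B, num_physical_cores))


B, K, BLOCK_B, BLOCK_K = 128, 64, 4, 128
print(num_programs(gpu_grid(B, K, BLOCK_K)))  # 128 logical cores
print(num_programs(npu_grid(B, BLOCK_B)))     # 32 logical cores
```

The key design choice is moving the K dimension into a loop inside the kernel. As a result, the grid size depends only on B and BLOCK_B.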
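Ascend-Friendly Kernel Rewrite
Case Description
In the original GPU flow, i64/i32 compare (cmp) operations cannot be vectorized on the NPU. They fall back to scalar computation, which lowers efficiency. Converting the operands to fp32 lets the compiler use vec_cast and vec_cmp, so the comparison runs as a vector operation. The masks in tl.load and tl.store also use comparisons, but in most cases the compiler optimizes them into vector operations automatically. The tl.where in this example, however, must be converted by hand. This case uses LayerNorm to show how vectorizing the comparison speeds up the NPU computation; the comparison handles the tail block in LayerNorm.
Differences Explained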
cols = tl.arange(0, BLOCK_N) # cols is int64
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
# calculate mean & rstd
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
- xbar = tl.where(cols < N, x - mean, 0.0)
+ # change cols(i64) into cols_cmp(f32) to enable vector processing
+ cols_cmp = cols.to(tl.float32)
+ xbar = tl.where(cols_cmp < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
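The rewrite relies on one caveat. Converting an integer to fp32 is exact only for magnitudes up to 2^24 (16,777,216). So cols_cmp < N matches cols < N only while both values stay below that bound, which holds for typical BLOCK_N and row widths.

The pure-Python sketch below checks this with struct, which rounds to the nearest float32 the way a cast does. It assumes N is also converted to fp32 in the comparison; that promotion is an assumption, not something this guide states.

```python
import struct


def to_f32(v: int) -> float:
    """Round an integer to the nearest float32 value, like .to(tl.float32)."""
    return struct.unpack("f", struct.pack("f", float(v)))[0]


def masks_agree(n: int, block: int) -> bool:
    """True if the integer tail mask (cols < n) equals the fp32 mask for cols in [0, block)."""
    f_n = to_f32(n)
    return all((c < n) == (to_f32(c) < f_n) for c in range(block))


# Typical LayerNorm sizes: the fp32 comparison gives the same tail mask
print(masks_agree(1000, 1024))  # True

# Above 2**24, neighbouring integers collapse to the same float32 value
big = 2 ** 24
print(big < big + 1, to_f32(big) < to_f32(big + 1))  # True False
```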
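Functionality and Accuracy Cases
This section covers common functionality and accuracy cases.
Hang Issues
Fault Isolation
- Symptom: the operator hangs (see also "Using Operator Options to Avoid Timeout Errors"). Hangs are partly caused by hardware synchronization, which may be intra-core, inter-core, or pipeline synchronization. If an operator hangs, try passing the following arguments when launching the kernel. They change the synchronization logic of the generated binary and can work around the hang.
- Example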
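| Compile option | Value | Description |
|---|---|---|
| inject_barrier_all | false (default) | Try setting it to true at the kernel call. If the hang disappears, intra-core synchronization is at fault. Applies to mix, aic, and aiv kernels. |
| inject_block_all | false (default) | Try setting it to true at the kernel call. If the hang disappears, inter-core synchronization is at fault. Applies to mix kernels. |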
Take the chunk_gated_delta_rule_fwd_kernel_h_blockdim64 operator from the GDN network as an example. The original call is:
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
)
With both synchronization options enabled, the call becomes:
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
inject_block_all = True, # enable inter-core synchronization
inject_barrier_all = True, # enable intra-core synchronization
)
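The table above implies a simple isolation order, which the sketch below encodes as a hypothetical helper. run_kernel is a wrapper you would write yourself. It launches the kernel with the given keyword options and returns True if the launch finishes before a timeout, for example by running it in a separate process. Only inject_barrier_all and inject_block_all come from this guide; the rest is illustrative.

```python
from typing import Callable


def diagnose_hang(run_kernel: Callable[..., bool], is_mix_kernel: bool) -> str:
    """Hypothetical helper following the isolation table above.

    run_kernel(**options) must return True if the launch completes in time.
    """
    if run_kernel():
        return "no hang reproduced"
    # inject_barrier_all applies to mix, aic and aiv kernels
    if run_kernel(inject_barrier_all=True):
        return "intra-core synchronization"
    # inject_block_all only applies to mix kernels
    if is_mix_kernel and run_kernel(inject_block_all=True):
        return "inter-core synchronization"
    return "not isolated by the sync options; try the pipeline options"


# Fake runner for illustration: hangs unless inter-core sync is injected
def fake_run(**options) -> bool:
    return options.get("inject_block_all", False)


print(diagnose_hang(fake_run, is_mix_kernel=True))  # inter-core synchronization
```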
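Invalid Kernel Arguments
Varlen operators usually sample indices at random within seqlen. Make sure these index arguments are valid, for example strictly increasing and within [0, seqlen].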
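A small host-side check can catch bad index arguments before launch. The pure-Python sketch below validates the two conditions stated above. It also shows one way to sample valid indices. The sampler assumes the common cu_seqlens layout that starts at 0 and ends at seqlen; that layout is an assumption, not a requirement stated in this guide.

```python
import random
from typing import List, Sequence


def check_varlen_indices(indices: Sequence[int], seqlen: int) -> None:
    """Raise ValueError unless indices are strictly increasing and within [0, seqlen]."""
    for pos, v in enumerate(indices):
        if not 0 <= v <= seqlen:
            raise ValueError(f"index {v} at position {pos} is outside [0, {seqlen}]")
    for pos in range(1, len(indices)):
        if indices[pos] <= indices[pos - 1]:
            raise ValueError(
                f"indices not strictly increasing at position {pos}: "
                f"{indices[pos - 1]} -> {indices[pos]}"
            )


def sample_varlen_indices(seqlen: int, num_seqs: int, rng: random.Random) -> List[int]:
    """Sample distinct interior cut points so every sequence is non-empty."""
    cuts = sorted(rng.sample(range(1, seqlen), num_seqs - 1))
    return [0] + cuts + [seqlen]


indices = sample_varlen_indices(1000, 8, random.Random(0))
check_varlen_indices(indices, 1000)  # passes without raising
```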
UB Overflow Issues
triton argmax: 32-byte alignment is applied before axes are collapsed, wasting a large amount of UB
The MLIR is as follows:
%reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [0], sizes: [256, 9, 11], strides: [99, 11, 1] : memref<?xi8, #hivm.address_space<gm>> to memref<256x9x11xi8, strided<[99, 11, 1]>, #hivm.address_space<gm>>
%2 = hivm.hir.pointer_cast(%c0_i64) : memref<256x32x11x1xi8, #hivm.address_space<ub>>
%subview = memref.subview %2[0, 0, 0, 0] [256, 9, 11, 1] [1, 1, 1, 1] : memref<256x32x11x1xi8, #hivm.address_space<ub>> to memref<256x9x11xi8, strided<[352, 11, 1]>, #hivm.address_space<ub>>
%collapse_shape = memref.collapse_shape %reinterpret_cast [[0], [1, 2]] : memref<256x9x11xi8, strided<[99, 11, 1]>, #hivm.address_space<gm>> into memref<256x99xi8, strided<[99, 1]>, #hivm.address_space<gm>>
%collapse_shape_0 = memref.collapse_shape %subview [[0], [1, 2]] : memref<256x9x11xi8, strided<[352, 11, 1]>, #hivm.address_space<ub>> into memref<256x99xi8, strided<[352, 1]>, #hivm.address_space<ub>>
hivm.hir.load ins(%collapse_shape : memref<256x99xi8, strided<[99, 1]>, #hivm.address_space<gm>>) outs(%collapse_shape_0 : memref<256x99xi8, strided<[352, 1]>, #hivm.address_space<ub>>) init_out_buffer = false may_implicit_transpose_with_last_axis = false
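Analysis:
Line 1: the original data, 256x9x11xi8, resides in GM (kernel argument %arg3).
Line 2: a 256x32x11x1xi8 UB buffer is allocated to receive the copy from GM. Axis 1 is padded for 32-byte alignment (9 -> 32), and an extra trailing axis is added.
Line 3: a 256x9x11xi8 subview is taken from the 256x32x11x1xi8 UB buffer of line 2.
Line 4: collapse_shape merges the dimensions of the 256x9x11xi8 GM view from line 1 into 256x99xi8.
Line 5: collapse_shape merges the dimensions of the 256x9x11xi8 UB view from line 3 into 256x99xi8.
Line 6: the 256x99xi8 data in GM (line 4) is copied into the 256x99xi8 UB view (line 5).
Summary:
The original 256x9x11xi8 data occupies 25344 B. Once loaded from GM into UB, it occupies 90112 B (256x32x11x1xi8), more than 3.5 times the original size.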
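The arithmetic in the summary can be reproduced as follows. The sketch models the observed allocation: axis 1 padded to 32 bytes, plus a trailing unit axis. For comparison only, it also computes the reverse order, collapsing axes 1 and 2 first and then aligning the collapsed axis. That second figure is an illustration, not something the compiler is claimed to produce.

```python
from math import prod

ALIGN_BYTES = 32
ELEM_BYTES = 1  # i8


def align_up(n: int, multiple: int) -> int:
    return (n + multiple - 1) // multiple * multiple


gm_shape = (256, 9, 11)
# Observed: align axis 1 to 32 bytes first, then collapse axes 1 and 2
ub_shape = (256, align_up(9 * ELEM_BYTES, ALIGN_BYTES) // ELEM_BYTES, 11, 1)
# Illustration only: collapse axes 1 and 2 first (99 elements), then align
collapse_first = (256, align_up(9 * 11 * ELEM_BYTES, ALIGN_BYTES) // ELEM_BYTES)

gm_bytes = prod(gm_shape) * ELEM_BYTES         # 25344
ub_bytes = prod(ub_shape) * ELEM_BYTES         # 90112
alt_bytes = prod(collapse_first) * ELEM_BYTES  # 32768
print(ub_bytes / gm_bytes)  # about 3.56, i.e. 32/9
```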
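triton not: an inefficient lowering occupies extra memory
In NPU IR, the Triton not op is lowered into a sequence of VOR, VAND, VNOT, and VAND operations, although a single VNOT would suffice:
The MLIR is as follows: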
%2 = hivm.hir.pointer_cast(%c0_i64) : memref<65536xi8, #hivm.address_space<ub>>
hivm.hir.load ins(%reinterpret_cast : memref<65536xi8, strided<[1]>, #hivm.address_space<gm>>) outs(%2 : memref<65536xi8, #hivm.address_space<ub>>) init_out_buffer = false may_implicit_transpose_with_last_axis = false
%3 = hivm.hir.pointer_cast(%c131072_i64) : memref<65536xi8, #hivm.address_space<ub>>
hivm.hir.vbrc ins(%c-1_i8 : i8) outs(%3 : memref<65536xi8, #hivm.address_space<ub>>)
%4 = hivm.hir.pointer_cast(%c65536_i64) : memref<65536xi8, #hivm.address_space<ub>>
hivm.hir.vor ins(%2, %3 : memref<65536xi8, #hivm.address_space<ub>>, memref<65536xi8, #hivm.address_space<ub>>) outs(%4 : memref<65536xi8, #hivm.address_space<ub>>)
%5 = hivm.hir.pointer_cast(%c0_i64) : memref<65536xi8, #hivm.address_space<ub>>
hivm.hir.vand ins(%2, %3 : memref<65536xi8, #hivm.address_space<ub>>, memref<65536xi8, #hivm.address_space<ub>>) outs(%5 : memref<65536xi8, #hivm.address_space<ub>>)
hivm.hir.vnot ins(%5 : memref<65536xi8, #hivm.address_space<ub>>) outs(%5 : memref<65536xi8, #hivm.address_space<ub>>)
%6 = hivm.hir.pointer_cast(%c65536_i64) : memref<65536xi8, #hivm.address_space<ub>>
hivm.hir.vand ins(%5, %4 : memref<65536xi8, #hivm.address_space<ub>>, memref<65536xi8, #hivm.address_space<ub>>) outs(%6 : memref<65536xi8, #hivm.address_space<ub>>)
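Analysis (line numbers refer to the snippet above; the original 65536xi8 data resides in GM as kernel argument %arg3, whose reinterpret_cast is not shown):
Line 1: a 65536xi8 UB buffer (%2) is allocated.
Line 2: the 65536xi8 data in GM is copied into the UB buffer from line 1.
Line 3: a 65536xi8 UB buffer (%3) is allocated.
Line 4: the buffer from line 3 is filled with -1.
Line 5: a 65536xi8 UB buffer (%4) is allocated.
Line 6: "input" OR "-1" is computed and stored in the buffer from line 5.
Line 7: a 65536xi8 UB buffer (%5) is allocated.
Line 8: "input" AND "-1" is computed and stored in the buffer from line 7.
Line 9: NOT is applied to the result of line 8, writing back into the buffer from line 7.
Line 10: a 65536xi8 UB buffer (%6) is allocated.
Line 11: the results of line 6 and line 9 are ANDed and stored in the buffer from line 10.
Summary:
MLIR translates the not operation on input_data into (input_data | (-1)) & (~(input_data & (-1))), which simplifies to ~input_data.
The original data is 65536 B, yet five 65536 B UB buffers (5 * 65536 B) are requested to evaluate (input_data | (-1)) & (~(input_data & (-1))).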
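Both points in the summary can be checked in a few lines of Python. The sketch replays the lowered op sequence on every signed 8-bit value and confirms that it equals a plain bitwise NOT. It then counts the UB buffers using the pointer_cast offsets copied from the IR above. Those offsets show that %5 reuses the offset of %2 and %6 reuses the offset of %4. The five requested buffers therefore occupy three distinct 64 KiB regions, while the in-place vnot in the IR shows the operation could run on a single buffer.

```python
def lowered_not(x: int) -> int:
    """Replay the op sequence from the IR on one signed 8-bit value."""
    ones = -1                  # vbrc: broadcast -1 (all bits set)
    or_res = x | ones          # vor
    and_res = x & ones         # vand
    not_res = ~and_res         # vnot
    return or_res & not_res    # final vand


# The sequence is equivalent to a single bitwise NOT for every i8 value
assert all(lowered_not(x) == ~x for x in range(-128, 128))

BUF = 65536  # bytes per 65536xi8 buffer
# pointer_cast offsets taken from the IR above
offsets = {"%2": 0, "%3": 131072, "%4": 65536, "%5": 0, "%6": 65536}
requested = len(offsets) * BUF               # 5 x 65536 B
distinct = len(set(offsets.values())) * BUF  # 3 x 65536 B
print(requested, distinct)  # 327680 196608
```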
triton max_dim0: for int64 inputs, PlanMemory runs before HIVMLowerToLoops, wasting a large amount of UB
The MLIR is as follows:
%2 = hivm.hir.pointer_cast(%c0_i64) : memref<2x4912xi64, #hivm.address_space<ub>>
%3 = hivm.hir.pointer_cast(%c78592_i64) : memref<1x4912xi64, #hivm.address_space<ub>>
%4 = hivm.hir.pointer_cast(%c117888_i64) : memref<9824xi64, #hivm.address_space<ub>>
hivm.hir.vreduce {already_initialize_init} <max> ins(%2 : memref<2x4912xi64, #hivm.address_space<ub>>) outs(%3 : memref<1x4912xi64, #hivm.address_space<ub>>) temp_buffer(%4 : memref<9824xi64, #hivm.address_space<ub>>) reduce_dims = [0]
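Analysis:
Line 1: the 2x4912xi64 input is allocated in UB; its data comes from GM.
Line 2: the 1x4912xi64 output is allocated in UB to hold the result, which is finally stored to GM.
Line 3: a 9824xi64 UB buffer is allocated as the temporary buffer for the vreduce operation.
Line 4: for int64 inputs, vreduce is later lowered to a scalar loop, and the temp_buffer is removed.
Summary: PlanMemory accounts for a temp_buffer that the final computation never uses, which causes a false UB overflow report. The fix is to change the temporary-buffer allocation rule in the step that allocates temp_buffer before PlanMemory.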
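The buffer sizes in this IR can be tallied directly. In the sketch below the sizes come from the memref shapes, and the resulting offsets match the pointer_cast constants (%c78592, %c117888). The 192 KB UB capacity is the figure quoted elsewhere in this guide and may differ by device. The planned footprint leaves only 128 B of headroom, while dropping the unused temp buffer would free 78,592 B.

```python
I64_BYTES = 8
UB_BYTES = 192 * 1024  # UB limit quoted in this guide; device-dependent

buffers = [
    ("input %2, 2x4912xi64", 2 * 4912 * I64_BYTES),
    ("output %3, 1x4912xi64", 1 * 4912 * I64_BYTES),
    ("temp_buffer %4, 9824xi64", 9824 * I64_BYTES),
]

offsets = []
cursor = 0
for _, size in buffers:
    offsets.append(cursor)  # back-to-back layout, matching the IR's offsets
    cursor += size

planned = cursor                         # what PlanMemory reserves
actually_used = planned - buffers[2][1]  # temp_buffer is dropped after lowering
print(offsets)                       # [0, 78592, 117888]
print(planned, UB_BYTES - planned)   # 196480 128
print(actually_used)                 # 117888
```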
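D-cache Issues
Invalid Address Access
- Symptom: the operator inputs are valid and are expected to be on the same device ID, but the tensors' device is not actually set correctly (for example, they are left on the CPU). The kernel cannot fetch the data, and a D-cache read/write error occurs.
- Example. Incorrect:
A = torch.empty(shape, dtype=dtype)
Correct:
A = torch.empty(shape, dtype=dtype).npu()
or
DEVICE = "npu:0"
A = torch.empty(shape, dtype=dtype, device=DEVICE)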
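A quick host-side guard before launching can catch this class of error. The sketch below is a hypothetical helper that only reads each tensor's .device attribute. It works with torch tensors, whose str(device) is for example "npu:0" once torch_npu is imported, and with any stand-in object that has a device attribute.

```python
def check_same_npu_device(*tensors, expected_type: str = "npu") -> str:
    """Hypothetical helper: raise ValueError unless all tensors share one device of the expected type."""
    devices = {str(t.device) for t in tensors}
    if len(devices) != 1:
        raise ValueError(f"tensors are on different devices: {sorted(devices)}")
    (device,) = devices
    if device.split(":")[0] != expected_type:
        raise ValueError(f"tensors are on {device!r}, expected an {expected_type!r} device")
    return device
```

In a real script you would call check_same_npu_device(A, ...) on every tensor argument right before kernel[grid](...).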
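Use Non-negative Loop Variables as Memory-Access Indices
- Symptom: the compiler analyzes memory-access operations to optimize the generated code. If an access index involves complex control flow (for example, out-of-bounds accesses introduced by a for-loop index), the compiler may not yet be able to handle it fully. Use a non-negative for-loop iteration variable as the access index instead.
- Example
Take causal_conv1d_fwd_kernel from the GDN network: in the original code, i_w can be negative. Incorrect: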
for i_w in tl.static_range(-W+1, 1):
p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
Correct:
for i_w in tl.static_range(W):
p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w - W + 1, i_d * BD), (BT, BD), (1, 0))
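The two loops compute exactly the same row offsets; only the range of the loop variable changes, from [-W+1, 0] to [0, W-1]. The pure-Python check below confirms this for a few sizes. For the first time block (i_t = 0) the computed offsets are still negative. The rewrite therefore changes what the compiler sees as the iteration variable, not the offsets themselves.

```python
from typing import List


def offsets_negative_iter(i_t: int, BT: int, W: int) -> List[int]:
    # Original loop: i_w in [-W+1, 0]
    return [i_t * BT + i_w for i_w in range(-W + 1, 1)]


def offsets_nonnegative_iter(i_t: int, BT: int, W: int) -> List[int]:
    # Rewritten loop: i_w in [0, W-1], shifted by -W+1 inside the offset
    return [i_t * BT + i_w - W + 1 for i_w in range(W)]


for W in range(1, 6):
    for i_t in range(4):
        assert offsets_negative_iter(i_t, 16, W) == offsets_nonnegative_iter(i_t, 16, W)

print(offsets_nonnegative_iter(0, 16, 4))  # [-3, -2, -1, 0]
```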
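Memory Access
Implicit Transpose on Load
- Description: an "implicit transpose" completes a matrix transpose while the data is being loaded or stored. It avoids a separate transpose kernel or an explicit data rearrangement. It is usually achieved by adjusting the pointer strides and shape so that the memory access pattern swaps the dimensions implicitly. This saves global memory bandwidth, reduces kernel launch overhead, and improves computational efficiency.
tl.make_block_ptr(base, shape, strides, offsets, block_shape, order)
The order parameter specifies the order in which elements are traversed in memory and can be used to express the transpose. Alternatively, set the strides parameter to the strides of the transposed view.
Consider a transpose with input matrix A of shape (M, K) and output matrix B of shape (K, M). Each program can handle one block of B and load the corresponding transposed block of A. The load can use make_block_ptr on A with the shape and strides swapped, so the block arrives already transposed; the kernel below takes this approach.
A more common alternative is to load a normal block of A, transpose it with tl.trans, and then store it to B.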
import torch
import torch_npu  # registers the 'npu' device with PyTorch
import triton
import triton.language as tl
@triton.jit
def transpose_kernel(
x_ptr, y_ptr,
M, N,
stride_xm, stride_xn,
stride_ym, stride_yn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
"""
矩阵转置内核:Y = X^T, 其中 X 形状 (M, N),Y 形状 (N, M)。
每个程序块处理 Y 的一个 (BLOCK_N, BLOCK_M) 子块。
通过交换输入指针的步长,实现隐式转置加载。
"""
pid_n = tl.program_id(0) # 输出矩阵的行块索引(原列块)
pid_m = tl.program_id(1) # 输出矩阵的列块索引(原行块)
bn = pid_n * BLOCK_N # 输出矩阵的行起始 = 原列起始
bm = pid_m * BLOCK_M # 输出矩阵的列起始 = 原行起始
# 构建输入指针:使用交换后的步长,形状 (N, M) 以匹配转置访问
x_ptr_t = tl.make_block_ptr(
base=x_ptr,
shape=(N, M),
strides=(stride_xn, stride_xm),
offsets=(bn, bm),
block_shape=(BLOCK_N, BLOCK_M),
order=(1, 0)
)
# Output block pointer: regular row-major strides of Y, shape (N, M)
y_ptr_b = tl.make_block_ptr(
base=y_ptr,
shape=(N, M),
strides=(stride_ym, stride_yn),
offsets=(bn, bm),
block_shape=(BLOCK_N, BLOCK_M),
order=(1, 0)
)
# Load the input block (already implicitly transposed); boundary checks prevent out-of-bounds access
x_tile = tl.load(x_ptr_t, boundary_check=(0, 1))
# Store to the output matrix
tl.store(y_ptr_b, x_tile, boundary_check=(0, 1))
def transpose(x, y=None, BLOCK_M=64, BLOCK_N=32):
"""
使用 Triton 内核计算矩阵转置。
Args:
x: torch.Tensor 形状 (M, N)
y: 可选输出张量,形状 (N, M),如果为 None 则自动创建
BLOCK_M: 块大小(沿 M 维度)
BLOCK_N: 块大小(沿 N 维度)
Returns:
y: 转置后的张量
"""
M, N = x.shape
if y is None:
y = torch.empty(N, M, dtype=x.dtype, device=x.device)
else:
assert y.shape == (N, M), f"y 的形状应为 ({N}, {M}),但得到 {y.shape}"
# 计算网格大小
grid = (triton.cdiv(N, BLOCK_N), triton.cdiv(M, BLOCK_M))
# 调用内核
transpose_kernel[grid](
x, y,
M, N,
x.stride(0), x.stride(1),
y.stride(0), y.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
)
return y
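# Create a random matrix
x = torch.randn(512, 1024, device='npu')
# Call the transpose function
y = transpose(x)
# Compare against PyTorch's own transpose
assert torch.equal(y.cpu(), x.cpu().t())
If the script finishes without errors, the kernel ran and produced the correct transpose.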
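The stride swap used by transpose_kernel can be illustrated without an NPU. The pure-Python sketch below reads the same flat buffer twice. The first read uses X's shape and strides; the second swaps both, as x_ptr_t does. The second view is exactly X^T, which is why the kernel's load needs no separate transpose step.

```python
from typing import List, Sequence, Tuple


def strided_view(flat: Sequence[int], shape: Tuple[int, int],
                 strides: Tuple[int, int], offset: int = 0) -> List[List[int]]:
    """Read a 2-D view of a flat buffer: element [i][j] = flat[offset + i*s0 + j*s1]."""
    rows, cols = shape
    s0, s1 = strides
    return [[flat[offset + i * s0 + j * s1] for j in range(cols)] for i in range(rows)]


M, N = 3, 4
flat = list(range(M * N))                 # row-major X: stride_xm = N, stride_xn = 1
x = strided_view(flat, (M, N), (N, 1))    # X, shape (M, N)
x_t = strided_view(flat, (N, M), (1, N))  # shape (N, M), strides (stride_xn, stride_xm)

assert x_t == [list(col) for col in zip(*x)]  # the swapped view is X^T
print(x_t[0])  # [0, 4, 8]
```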
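Using mayDiscretememaccess to Avoid UB Overflow
- Symptom: UB overflow has various causes. Besides a tensor simply being too large for the 192 KB UB, another possible cause is non-contiguous data movement, which expands an axis inside UB. Take <Nx1xf32> as an example. The hardware requires the last axis to be 32-byte aligned, but 1xf32 is only 4 bytes, so <Nx1xf32> is expanded to <Nx8xf32> on the hardware. Whatever causes the UB overflow, adding the mayDiscretememaccess compile hint degrades the tensor operation into scalar operations and avoids the overflow, at the cost of scalar rather than vector data movement.
- Example
When rewriting an operator, just add compile_hint on the data of the load/store operation, as in the following snippet. For triton-ascend versions before 3.2.0: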
# For a load, put compile_hint on the loaded value
value = tl.load(pointer)
tl.compile_hint(value, "mayDiscretememaccess")
# For a store, put compile_hint on the value being stored
tl.compile_hint(value, "mayDiscretememaccess")
tl.store(pointer, value)
For triton-ascend versions after 3.4.0, change it to:
# For a load, put compile_hint on the loaded value
value = tl.load(pointer)
tl.extra.cann.extension.compile_hint(value, "mayDiscretememaccess")
# For a store, put compile_hint on the value being stored
tl.extra.cann.extension.compile_hint(value, "mayDiscretememaccess")
tl.store(pointer, value)
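The size effect of last-axis padding described above can be estimated with a few lines of arithmetic. This sketch assumes that only the last axis is padded up to a multiple of 32 bytes, as in the <Nx1xf32> to <Nx8xf32> example. It uses the 192 KB UB figure quoted above; real layouts may pad further.

```python
UB_BYTES = 192 * 1024  # UB limit quoted in this guide


def padded_last_dim(last_dim: int, elem_bytes: int, align_bytes: int = 32) -> int:
    """Round the last axis up so that it spans a multiple of align_bytes."""
    per_block = align_bytes // elem_bytes
    return -(-last_dim // per_block) * per_block


def ub_footprint(shape, elem_bytes: int, align_bytes: int = 32) -> int:
    """UB bytes for a tensor whose last axis is padded to align_bytes."""
    *lead, last = shape
    count = 1
    for d in lead:
        count *= d
    return count * padded_last_dim(last, elem_bytes, align_bytes) * elem_bytes


raw = 1024 * 1 * 4                   # <1024x1xf32> as stored in GM: 4096 B
padded = ub_footprint((1024, 1), 4)  # laid out as <1024x8xf32> in UB
print(raw, padded)                   # 4096 32768
# A single fp32 column of 6144 rows already fills the whole 192 KB UB
print(ub_footprint((6144, 1), 4) == UB_BYTES)  # True
```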
- Example 1
b_x = tl.load(x + o_t * D + o_d[:, None], mask=(m_t & m_d[:, None]), other=0)
Adding the compile hint degrades the tensor memory access into scalar accesses and avoids the UB overflow, as in the following snippet:
b_x = tl.load(x + o_t * D + o_d[:, None], mask=(m_t & m_d[:, None]), other=0)
tl.extra.cann.extension.compile_hint(b_x, "mayDiscretememaccess")
- Example 2
import triton
import triton.language as tl
+ import triton.language.extra.cann.extension as extension
@triton.jit
def copy_column_major_to_row_major(
A_ptr, B_ptr,
M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
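# Get the program IDs
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# Compute the block start positions
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
# Block pointer for A (column-major: strides=(1, M)); the last dimension is non-contiguous here, so an axis is expanded automatically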
A_block_ptr = tl.make_block_ptr(
base=A_ptr,
shape=(M, N),
strides=(1, M),
offsets=(start_m, start_n),
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
order=(0, 1), # innermost dimension is the row index (dim 0) because A is column-major
)
# Block pointer for B (row-major: strides=(N, 1))
B_block_ptr = tl.make_block_ptr(
base=B_ptr,
shape=(M, N),
strides=(N, 1),
offsets=(start_m, start_n),
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
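order=(1, 0), # innermost dimension is the column index (dim 1) because B is row-major
)
# Load a block of A with boundary checks (out-of-range elements are padded with 0)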
a = tl.load(A_block_ptr, boundary_check=(0, 1))
+ # npu
+ extension.compile_hint(a, "mayDiscretememaccess")
# Store to B
tl.store(B_block_ptr, a, boundary_check=(0, 1))
- IR of Example 2 before and after adding the compile hint
// before using tl.compile_hint(a, "mayDiscretememaccess")
module attributes {hacc.target = #hacc.target<"Ascend910B3">} {
func.func @copy_column_major_to_row_major(%arg0: memref<?xi8> , %arg1: memref<?xi8> , %arg2: memref<?xf32> {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32} , %arg3: memref<?xf32> {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32} , %arg4: i32 {tt.divisibility = 16 : i32} , %arg5: i32 {tt.divisibility = 16 : i32} , %arg6: i32 , %arg7: i32 , %arg8: i32 , %arg9: i32 , %arg10: i32 , %arg11: i32 ) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv", parallel_mode = "simd"} {
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%c64_i32 = arith.constant 64 : i32
%0 = arith.muli %arg9, %c64_i32 : i32
%1 = arith.muli %arg10, %c64_i32 : i32
%2 = arith.maxsi %0, %c0_i32 : i32
%3 = arith.index_cast %2 : i32 to index
%4 = arith.maxsi %1, %c0_i32 : i32
%5 = arith.index_cast %4 : i32 to index
%6 = arith.index_cast %arg5 : i32 to index
%7 = arith.muli %3, %6 : index
%8 = arith.index_cast %arg4 : i32 to index
%9 = arith.addi %7, %5 : index
%reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%9], sizes: [64, 64], strides: [%6, 1] : memref<?xf32> to memref<64x64xf32, strided<[?, 1], offset: ?>>
%10 = arith.muli %5, %8 : index
%11 = arith.addi %10, %3 : index
%reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%11], sizes: [64, 64], strides: [%8, 1] : memref<?xf32> to memref<64x64xf32, strided<[?, 1], offset: ?>>
%alloc = memref.alloc() : memref<64x64xf32>
%12 = arith.divsi %11, %8 : index
%13 = arith.subi %6, %12 : index
%14 = arith.maxsi %13, %c0 : index
%15 = arith.minsi %14, %c64 : index
%16 = arith.remsi %11, %8 : index
%17 = arith.subi %8, %16 : index
%18 = arith.maxsi %17, %c0 : index
%19 = arith.minsi %18, %c64 : index
%20 = arith.subi %c0_i32, %1 : i32
%21 = arith.maxsi %20, %c0_i32 : i32
%22 = arith.index_cast %21 : i32 to index
%23 = arith.minsi %22, %15 : index
%24 = arith.subi %15, %23 : index
%25 = arith.subi %c0_i32, %0 : i32
%26 = arith.maxsi %25, %c0_i32 : i32
%27 = arith.index_cast %26 : i32 to index
%28 = arith.minsi %27, %19 : index
%29 = arith.subi %19, %28 : index
%subview = memref.subview %reinterpret_cast_0[0, 0] [%24, %29] [1, 1] : memref<64x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%subview_1 = memref.subview %alloc[%23, %28] [%24, %29] [1, 1] : memref<64x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%30 = bufferization.to_tensor %alloc restrict writable : memref<64x64xf32>
%31 = tensor.empty() : tensor<64x64xf32>
%transposed = linalg.transpose ins(%30 : tensor<64x64xf32>) outs(%31 : tensor<64x64xf32>) permutation = [1, 0]
%32 = arith.divsi %9, %6 : index
%33 = arith.subi %8, %32 : index
%34 = arith.maxsi %33, %c0 : index
%35 = arith.minsi %34, %c64 : index
%36 = arith.remsi %9, %6 : index
%37 = arith.subi %6, %36 : index
%38 = arith.maxsi %37, %c0 : index
%39 = arith.minsi %38, %c64 : index
%40 = arith.minsi %27, %35 : index
%41 = arith.subi %35, %40 : index
%42 = arith.minsi %22, %39 : index
%43 = arith.subi %39, %42 : index
%extracted_slice = tensor.extract_slice %transposed[%40, %42] [%41, %43] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
%subview_2 = memref.subview %reinterpret_cast[0, 0] [%41, %43] [1, 1] : memref<64x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
bufferization.materialize_in_destination %extracted_slice in writable %subview_2 : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
return
}
}
// after using tl.compile_hint(a, "mayDiscretememaccess")
module attributes {hacc.target = #hacc.target<"Ascend910B3">} {
func.func @copy_column_major_to_row_major(%arg0: memref<?xi8> , %arg1: memref<?xi8> , %arg2: memref<?xf32> {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32} , %arg3: memref<?xf32> {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32} , %arg4: i32 {tt.divisibility = 16 : i32} , %arg5: i32 {tt.divisibility = 16 : i32} , %arg6: i32 , %arg7: i32 , %arg8: i32 , %arg9: i32 , %arg10: i32 , %arg11: i32 ) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv", parallel_mode = "simd"} {
%c0_i32 = arith.constant 0 : i32
%c64 = arith.constant 64 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c64_i32 = arith.constant 64 : i32
%0 = arith.muli %arg9, %c64_i32 : i32
%1 = arith.muli %arg10, %c64_i32 : i32
%2 = arith.extsi %arg5 : i32 to i64
%3 = arith.maxsi %1, %c0_i32 : i32
%4 = arith.index_cast %3 : i32 to index
%5 = arith.maxsi %0, %c0_i32 : i32
%6 = arith.index_cast %5 : i32 to index
%7 = arith.index_cast %arg4 : i32 to index
%8 = arith.muli %4, %7 : index
%9 = arith.index_cast %arg5 : i32 to index
%10 = arith.addi %8, %6 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%10], sizes: [64, 64], strides: [%7, 1] : memref<?xf32> to memref<64x64xf32, strided<[?, 1], offset: ?>>
%alloc = memref.alloc() : memref<64x64xf32>
%11 = arith.divsi %10, %7 : index
%12 = arith.subi %9, %11 : index
%13 = arith.maxsi %12, %c0 : index
%14 = arith.minsi %13, %c64 : index
%15 = arith.remsi %10, %7 : index
%16 = arith.subi %7, %15 : index
%17 = arith.maxsi %16, %c0 : index
%18 = arith.minsi %17, %c64 : index
%19 = arith.subi %c0_i32, %1 : i32
%20 = arith.maxsi %19, %c0_i32 : i32
%21 = arith.index_cast %20 : i32 to index
%22 = arith.minsi %21, %14 : index
%23 = arith.subi %14, %22 : index
%24 = arith.subi %c0_i32, %0 : i32
%25 = arith.maxsi %24, %c0_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%27 = arith.minsi %26, %18 : index
%28 = arith.subi %18, %27 : index
%subview = memref.subview %reinterpret_cast[0, 0] [%23, %28] [1, 1] : memref<64x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%subview_0 = memref.subview %alloc[%22, %27] [%23, %28] [1, 1] : memref<64x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview, %subview_0 : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%29 = bufferization.to_tensor %alloc restrict writable : memref<64x64xf32>
%30 = tensor.empty() : tensor<64x64xf32>
%transposed = linalg.transpose ins(%29 : tensor<64x64xf32>) outs(%30 : tensor<64x64xf32>) permutation = [1, 0]
%31 = arith.index_cast %arg4 : i32 to index
%32 = arith.minsi %31, %c64 : index
scf.for %arg12 = %c0 to %32 step %c1 {
%33 = arith.index_cast %arg5 : i32 to index
%34 = arith.minsi %33, %c64 : index
scf.for %arg13 = %c0 to %34 step %c1 {
%35 = arith.index_cast %arg12 : index to i64
%36 = arith.extsi %0 : i32 to i64
%37 = arith.muli %2, %36 : i64
%38 = arith.muli %2, %35 : i64
%39 = arith.addi %37, %38 : i64
%40 = arith.index_cast %arg13 : index to i64
%41 = arith.extsi %1 : i32 to i64
%42 = arith.addi %39, %41 : i64
%43 = arith.addi %42, %40 : i64
%44 = arith.index_cast %43 : i64 to index
%extracted = tensor.extract %transposed[%arg12, %arg13] {DiscreteMemAccess} : tensor<64x64xf32>
%45 = tensor.empty() : tensor<1xf32>
%inserted = tensor.insert %extracted into %45[%c0] : tensor<1xf32>
%reinterpret_cast_1 = memref.reinterpret_cast %arg3 to offset: [%44], sizes: [1], strides: [1] : memref<?xf32> to memref<1xf32, strided<[1], offset: ?>>
bufferization.materialize_in_destination %inserted in writable %reinterpret_cast_1 : (tensor<1xf32>, memref<1xf32, strided<[1], offset: ?>>) -> ()
} {ExtractedLoadOrStore}
} {ExtractedLoadOrStore}
return
}
}
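Scenario-Based Debugging Examples
This section provides performance optimization guidance for Triton NPU operators.
Using bitwise_mask to Optimize Memory-Access Masks
Problem Description
On Ascend hardware, a boolean (i1) tensor is stored in global memory (GM) as i8, one byte per element. When Triton Ascend processes an operation that takes an i1 tensor as input, it moves the i1 data in as i8. In some cases, for example when the tensor is the condition mask of tl.where, the data must then be converted back to i1. This causes unnecessary type conversions and costs performance.
To address this, the compile hint "bitwise_mask" is provided. With this hint, the compiler recognizes that the i1 tensor is used as a bitmask and operates on it bit by bit. This avoids the intermediate type conversion and improves performance.
To use it, add compile_hint("bitwise_mask") right after the tl.where call, as in the following snippet: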
mask = tl.where(cond, value1, value2)
tl.compile_hint(cond, "bitwise_mask")
Note: because the mask is expressed as a bitmask, the offsets into the mask pointer must be computed accordingly.
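The reference mask-building loop later in this guide implies a specific layout: bit `bit` of byte `i` holds element i*8 + bit. The pure-Python sketch below models that layout as a pack/unpack pair. The pack function and mask_location are illustrative assumptions derived from that loop, not part of the triton-ascend API. They show why a mask of n elements needs only n/8 bytes, and why offsets into the mask pointer differ from offsets into the data.

```python
from typing import List, Sequence


def pack_bitmask(bits: Sequence[bool]) -> List[int]:
    """Pack booleans into bytes, least significant bit first: bit b of byte i is element i*8 + b."""
    if len(bits) % 8 != 0:
        raise ValueError("number of mask elements must be a multiple of 8")
    packed = []
    for i in range(0, len(bits), 8):
        byte = 0
        for b in range(8):
            if bits[i + b]:
                byte |= 1 << b
        packed.append(byte)
    return packed


def unpack_bitmask(packed: Sequence[int], numel: int) -> List[bool]:
    """Same logic as the reference mask-building loop in this guide."""
    out = [False] * numel
    for i in range(numel // 8):
        byte_value = packed[i]
        for bit in range(8):
            out[i * 8 + bit] = (byte_value & (1 << bit)) != 0
    return out


def mask_location(element_index: int) -> tuple:
    """(byte offset, bit index) of a logical mask element in the packed bitmask."""
    return element_index // 8, element_index % 8


bits = [True, False, False, True, False, False, False, True] + [False] * 7 + [True]
packed = pack_bitmask(bits)
print(packed)                          # [137, 128]
assert unpack_bitmask(packed, 16) == bits
print(mask_location(13))               # (1, 5)
```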


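When using compile_hint, check the locally installed triton-ascend (TA) version.
triton-ascend versions before 3.2.0: tl.compile_hint(cond, "bitwise_mask")
triton-ascend versions after 3.4.0: tl.extra.cann.extension.compile_hint(cond, "bitwise_mask")
The bitmask feature is only available in CANN 9.0 and later, so install CANN 9.0 or a later version.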
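Operator Example
This example is adapted from the Ascend where operator. If you need to pass a bitwise i8 mask as an operator argument, just add compile_hint after the tl.where call.
The test depends on a helper script (the "triton testcommon script" link, which provides test_common). Download it, place it in the same directory as the test script, and run python3 test_bitmask.py.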
# test_bitmask.py
import triton
import triton.language as tl
import torch
import torch_npu
import test_common
@triton.jit
def triton_where_lt_case1(in_ptr0, in_ptr1, cond_ptr, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB):
xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:]
xmask = xindex < xnumel
in0 = tl.load(in_ptr0 + xindex, xmask)
in1 = tl.load(in_ptr1 + xindex, xmask)
cond = tl.load(cond_ptr + xindex, xmask)
res = tl.where(cond, in1, in0)
# versions after triton-ascend 3.4.0
# tl.extra.cann.extension.compile_hint(cond, "bitwise_mask")
# versions before triton-ascend 3.2.0
tl.compile_hint(cond, "bitwise_mask")
tl.store(out_ptr0 + (xindex), res, xmask)
def test_where_lt_case1():
dtype = "float32"
shape = (1, 1024, 8)
ncore = 1
xblock = 8192
xblock_sub = 1024
if shape[-1] %8 != 0:
raise ValueError("The last dimension should be a multiple of 8")
x0 = test_common.generate_tensor(shape, dtype).npu()
x1 = test_common.generate_tensor(shape, dtype).npu()
# Run triton with i8 bitwise mask
cond_i8 = test_common.generate_tensor(shape, 'uint8').npu()
y_cal = test_common.generate_tensor(shape, dtype).npu()
triton_where_lt_case1[ncore, 1, 1](x0, x1, cond_i8, y_cal, x0.numel(), xblock, xblock_sub)
test_where_lt_case1()
If the script finishes without errors, the run succeeded.
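Tiling Logic
The bitmask is tied to the tiling logic, and the operator tiles differently in different scenarios. These include, but are not limited to: (1) the 1:2 performance optimization enabled in CV scenarios, (2) axis fusion, (3) broadcast scenarios, (4) data types the hardware does not natively support, and (5) Triton operators launched with a grid other than 1. Because the tiling differs across scenarios, we provide a generic example of assembling the reference mask, that is, building the reference i1 mask from the i8 bitmask. This mask-assembly logic does not depend on the scenario; it is derived from the error pattern observed in the bitmask results.
Suppose the original mask-assembly logic is: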
for i in range(numel // 8):
byte_value = flatten_cond_i8[i]
for bit in range(8):
flatten_cond_i1[..., i*8 + bit] = (byte_value & (1 << bit)) != 0
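Suppose that, in a certain scenario, vimdiff is used to compare the results for shape (2, X, X, X), and, in the same scenario, for shape (3, X, X, X). [vimdiff screenshots not reproduced]
These comparisons show that, for shape (A, X, X, X), this scenario tiles along the first axis (A). With the incorrect mask-assembly logic, only the first tile along the first axis matches, while the remaining (A-1)/A of the data deviates. The reference mask-assembly logic used for accuracy verification must therefore take A into account, as in the following code: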
for sub_A in range(A):
# The offset calculation depends on the logic of the kernel
offset_sub_A = D * B * sub_A
for i in range(min(numel, B * D) // 8):
byte_value = flatten_cond_i8[offset_sub_A + i]
for bit in range(8):
flatten_cond_i1[..., offset_sub_A + i*8 + bit] = (byte_value & (1 << bit)) != 0
Following this practice, the bitmask reference can be built correctly in a highly generic way.
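The per-slice reference logic above can be packaged as a reusable function. In this pure-Python sketch, slice_numel stands for B * D in the snippet, that is, the number of elements per first-axis slice. It is an assumption that must match the kernel's tiling. Each slice reads its bits from the start of its own byte region, which is what distinguishes it from the single-loop version.

```python
from typing import List, Sequence


def unpack_bitmask_per_slice(cond_i8: Sequence[int], num_slices: int,
                             slice_numel: int) -> List[bool]:
    """Build the reference i1 mask slice by slice along the first axis.

    Mirrors the loop above: slice s starts at offset slice_numel * s, and its
    bits come from the first slice_numel // 8 bytes of that region.
    """
    if slice_numel % 8 != 0:
        raise ValueError("slice_numel must be a multiple of 8")
    out = [False] * (num_slices * slice_numel)
    for s in range(num_slices):
        base = slice_numel * s
        for i in range(slice_numel // 8):
            byte_value = cond_i8[base + i]
            for bit in range(8):
                out[base + i * 8 + bit] = (byte_value & (1 << bit)) != 0
    return out


# Two slices of 16 elements; only the first 2 bytes of each slice region carry bits
cond_i8 = [0xFF, 0x00] + [0] * 14 + [0x0F, 0xF0] + [0] * 14
mask = unpack_bitmask_per_slice(cond_i8, num_slices=2, slice_numel=16)
print(mask[16:24])  # [True, True, True, True, False, False, False, False]
```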
The logic for tiling along multiple axes is also provided below for reference:
# test_bitmask_tile.py
import triton
import triton.language as tl
import torch
import torch_npu
import pytest
import test_common
from itertools import product
def torch_where_lt_case1(x0, x1, cond):
res = torch.where(cond, x0, x1)
return res
@triton.jit
def triton_bitmask(in_ptr0, in_ptr1, cond_ptr, out_ptr0,
X_BLOCK_SIZE: tl.constexpr, Y_BLOCK_SIZE: tl.constexpr, Z_BLOCK_SIZE: tl.constexpr,
X_STRIDE: tl.constexpr, Y_STRIDE: tl.constexpr, Z_STRIDE: tl.constexpr):
# Calculate the offset according to the grid
xoffset = tl.program_id(0) * X_BLOCK_SIZE
yoffset = tl.program_id(1) * Y_BLOCK_SIZE
zoffset = tl.program_id(2) * Z_BLOCK_SIZE
xindex = X_STRIDE * (xoffset + tl.arange(0, X_BLOCK_SIZE))[:, None, None]
yindex = Y_STRIDE * (yoffset + tl.arange(0, Y_BLOCK_SIZE))[None, :, None]
zindex = Z_STRIDE * (zoffset + tl.arange(0, Z_BLOCK_SIZE))[None, None, :]
offset = xindex + yindex + zindex
# Load in0 and in1
in0 = tl.load(in_ptr0 + offset)
in1 = tl.load(in_ptr1 + offset)
cond = tl.load(cond_ptr + offset)
# bitwise where and store
mask = tl.where(cond, in0, in1)
# versions after triton-ascend 3.4.0
# tl.extra.cann.extension.compile_hint(mask, "bitwise_mask")
# versions before triton-ascend 3.2.0
tl.compile_hint(mask, "bitwise_mask")
tl.store(out_ptr0 + offset, mask)
@pytest.mark.parametrize('param_list',
[
['float32', (16, 16, 32), (2, 2, 2)],
['int32', (16, 32, 16), (2, 2, 2)],
['int16', (32, 16, 16), (2, 2, 2)],
['float16', (8, 8, 64), (8, 8, 8)],
['float32', (8, 8, 24), (4, 4, 3)],
['int32', (1, 1, 1024), (1, 1, 16)],
['int16', (1, 1, 16), (1, 1, 2)],
['float16', (8, 80, 16), (1, 80, 2)],
]
)
def test_where_lt_case1(param_list):
# Checking and constant value creation
dtype, shape, grid = param_list
if shape[0] % grid[0] != 0 or \
shape[1] % grid[1] != 0 or \
shape[2] % grid[2] != 0 :
raise ValueError("Shape is not divisible by grid")
x_block_size = shape[0] // grid[0]
y_block_size = shape[1] // grid[1]
z_block_size = shape[2] // grid[2]
if z_block_size%8 != 0:
raise ValueError("The last dimension should be a multiple of 8")
if grid[-1] == 1:
raise ValueError("Please tile the last dim")
if(dtype in ["bool", "int8", "uint8", "int64"]):
raise ValueError(f"The torch mask tiling logic is not applicable with {dtype} type")
x_stride = shape[-1] * shape[-2]
y_stride = shape[-1]
z_stride = 1
# Run triton with i8 bitwise mask
x0 = test_common.generate_tensor(shape, dtype).npu()
x1 = test_common.generate_tensor(shape, dtype).npu()
cond_i8 = test_common.generate_tensor(shape, 'uint8').npu()
y_cal = test_common.generate_tensor(shape, dtype).npu()
triton_bitmask[grid](x0, x1, cond_i8, y_cal, x_block_size, y_block_size, z_block_size, x_stride, y_stride, z_stride)
# Run torch with i1 mask
flatten_cond_bool = torch.zeros(cond_i8.flatten().shape, dtype=torch.bool).npu()
for x_block_id, y_block_id, z_block_id in product(range(grid[0]), range(grid[1]), range(grid[2])):
flatten_subview_cond_i8 = cond_i8[x_block_id * x_block_size: (x_block_id+1) * x_block_size,
y_block_id * y_block_size: (y_block_id+1) * y_block_size,
z_block_id * z_block_size: (z_block_id+1) * z_block_size].flatten()
for i in range(flatten_subview_cond_i8.shape[-1]// 8):
# Get the corresponding i8 value
i8_z_block_offset = i % (z_block_size // 8)
i8_y_block_offset = i // (z_block_size // 8) % y_block_size * z_block_size
i8_x_block_offset = i // (z_block_size // 8) // y_block_size * y_block_size * z_block_size
i8_offset = i8_z_block_offset + i8_y_block_offset + i8_x_block_offset
byte_value = flatten_subview_cond_i8[i8_offset]
# Set the corresponding i1 value
i1_z_block_offset = (z_block_id * z_block_size + (i * 8) % z_block_size) * z_stride
i1_y_block_offset = (y_block_id * y_block_size + (i * 8) // z_block_size % y_block_size) * y_stride
i1_x_block_offset = (x_block_id * x_block_size + (i * 8) // z_block_size // y_block_size) * x_stride
i1_offset = i1_x_block_offset + i1_y_block_offset + i1_z_block_offset
for bit in range(8):
flatten_cond_bool[..., i1_offset + bit] = (byte_value & (1 << bit)) != 0
cond_bool = flatten_cond_bool.view(shape)
y_ref = torch_where_lt_case1(x0, x1, cond_bool)
# Precision test
print("y_cal: ", y_cal)
print("y_ref: ", y_ref)
test_common.validate_cmp(dtype, y_cal, y_ref)
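Limitations
- Because the Triton frontend converts i1 to i8, applying bitwise_mask to other types such as i16/i32 would actually cost performance. This feature therefore supports only i8 masks.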
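Using Manual Alignment to Help Compiler Optimization When the Last Axis Is Unaligned
Problem Description
In Triton operator development, the last axis of a tensor may be small (e.g. 4) and not aligned to the 32 bytes (8 float32 elements) recommended by the hardware. The compiler backend often struggles to generate optimal contiguous memory accesses and vector instructions for such unaligned shapes, so performance falls short. For better compiler optimization, pad the last axis explicitly to a suitable width in the frontend kernel, either by manual padding or by masked loads. This gives the compiler an alignment-friendly data layout, simplifies the backend's optimization decisions, and can significantly improve execution efficiency.
Operator Example
The following two kernels handle a last axis of size 4. Version 1 uses 4 directly as the last-axis size without alignment and performs poorly. Version 2 aligns the last axis to 8 through masked loads and is the recommended approach.
Version 1: unaligned last axis (optimization bottleneck)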
@triton.jit
def kernel(in_ptr, out_ptr, batch_size,
D: tl.constexpr, iters: tl.constexpr,
eps: tl.constexpr, group: tl.constexpr):
lin = tl.arange(0, D * D)
pid0 = tl.program_id(0) * group
pids = pid0 + tl.arange(0, group)
mask = pids < batch_size
off = pids[:, None] * (D * D)
# Load the D×D matrix directly, without alignment padding
mat = tl.load(in_ptr + off + lin[None, :], mask=mask[:, None])
mat = mat.reshape(group, D, D)
row_max = tl.max(mat, axis=2)
mat = tl.exp(mat - row_max[:, :, None])
for _ in range(iters):
row_sum = tl.sum(mat, axis=2)
mat = mat / (row_sum[:, :, None] + eps)
col_sum = tl.sum(mat, axis=1)
mat = mat / (col_sum[:, None, :] + eps)
mat_flat = tl.reshape(mat, (group, D * D))
tl.store(out_ptr + off + lin[None, :], mat_flat, mask=mask[:, None])
Version 2: manual alignment (recommended)
@triton.jit
def kernel_opt(in_ptr, out_ptr, batch_size,
D: tl.constexpr, iters: tl.constexpr,
eps: tl.constexpr, group: tl.constexpr,
ALIGN: tl.constexpr = 8):
pid0 = tl.program_id(0) * group
pids = pid0 + tl.arange(0, group)
p_mask = pids < batch_size
# Based on the original D×D layout, load ALIGN elements per row
off_base = pids[:, None, None] * (D * D)
row_idx = tl.arange(0, D)[:, None]
col_idx = tl.arange(0, ALIGN)[None, :]
offs = row_idx * D + col_idx
valid_cols = col_idx < D
# Masked-out (invalid) columns are filled with -inf, which implements the manual alignment
# Shape (group, D, ALIGN)
mat = tl.load(
in_ptr + off_base + offs[None, :, :],
mask=p_mask[:, None, None] & valid_cols[None, :, :],
other=float('-inf')
)
# Normalization (invalid columns become 0 after exp and do not affect the result)
row_max = tl.max(mat, axis=2)
mat = tl.exp(mat - row_max[:, :, None])
for _ in range(iters):
row_sum = tl.sum(mat, axis=2)
mat = mat / (row_sum[:, :, None] + eps)
col_sum = tl.sum(mat, axis=1)
mat = mat / (col_sum[:, None, :] + eps)
# Write back using the aligned width ALIGN
out_flat = tl.reshape(mat, (group, D * ALIGN))
tl.store(out_ptr + pids[:, None] * (D * ALIGN)
+ tl.arange(0, D * ALIGN)[None, :],
out_flat, mask=p_mask[:, None])
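Version 2 manually aligns the last axis to 8. This lets the compiler use contiguous, aligned memory access patterns and generate efficient instructions, avoiding the extra handling an unaligned last axis may introduce and improving overall performance. Note that version 2 writes its output in the padded layout (D * ALIGN elements per matrix). The output buffer must be allocated with that size, and the valid D x D results are the first D columns of each padded row.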
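Limitations
- Manual alignment requires ALIGN to be a compile-time constant equal to the alignment width recommended by the hardware.
- The padding value (e.g. -inf) must be compatible with the subsequent computation so that it does not affect the final result (e.g. exp(-inf) = 0).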
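Why the -inf padding leaves the valid columns untouched can be checked outside Triton. The pure-Python sketch below models the normalization in both kernels: a row-wise exp after subtracting the row max, then alternating row and column normalization. It runs once on a D x D matrix and once on the same matrix padded to ALIGN columns with -inf. The sizes and values are illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]


def normalize(mat: Matrix, iters: int, eps: float) -> Matrix:
    """Pure-Python model of the kernels' computation for one matrix."""
    rows, cols = len(mat), len(mat[0])
    m = []
    for row in mat:
        row_max = max(row)
        m.append([math.exp(v - row_max) for v in row])  # exp(-inf) == 0.0
    for _ in range(iters):
        for i in range(rows):
            s = sum(m[i])
            m[i] = [v / (s + eps) for v in m[i]]
        col_sums = [sum(m[i][j] for i in range(rows)) for j in range(cols)]
        m = [[m[i][j] / (col_sums[j] + eps) for j in range(cols)] for i in range(rows)]
    return m


def pad_columns(mat: Matrix, align: int, fill: float = -math.inf) -> Matrix:
    """Pad each row on the right to `align` columns, like the masked load with other=fill."""
    return [row + [fill] * (align - len(row)) for row in mat]


D, ALIGN, ITERS, EPS = 4, 8, 3, 1e-6
ref = [[(i * D + j) / 10 for j in range(D)] for i in range(D)]
plain = normalize(ref, ITERS, EPS)
padded = normalize(pad_columns(ref, ALIGN), ITERS, EPS)

# Valid columns agree; the padded columns end up exactly 0
assert all(math.isclose(plain[i][j], padded[i][j], rel_tol=1e-12)
           for i in range(D) for j in range(D))
assert all(v == 0.0 for row in padded for v in row[D:])
```

Padding with 0.0 instead would put exp(-row_max) into the pad columns and change the row sums, which is why the fill value must suit the computation.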
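CV Issues
Using hivm.tile_mix_cube_num to Avoid L1 Out-of-Bounds Access
Problem Description
The compiler currently analyzes tiling requirements for each matmul individually and ignores the lifetimes of other matmuls. When matmuls are triggered multiple times (e.g. the execution order is cube -> vector -> cube), the previous matmul's lifetime may overlap with the current one. In that case the operator may access L1 out of bounds at runtime. The compiler's lifetime analysis for tiling will be enhanced later. For now, add the hivm.tile_mix_cube_num compile hint so the compiler knows whether to sub-tile the relevant matmul operations.
Operator Example
When rewriting the operator, add the hivm.tile_mix_cube_num compile hint on the result of the dot operation, as in the following snippet: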
res = tl.dot(lhs, rhs)
tl.compile_hint(res, "hivm.tile_mix_cube_num", 2)
Take the _attn_fwd_inner operator of Flash Attention as an example. The QKV matrix multiplication logic in the original code is roughly:
qk = tl.dot(q, trans_k)
# softmax calculation in between
qk = ...
p = tl.math.exp(qk)
pv = tl.dot(p, v)
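In the code above, qk is a cube operation and softmax and related computations are vector operations. The vector result is then fed into a second cube operation for matrix multiplication. In this scenario, the compiler cannot track the tiling logic of the second cube operation, and the code may go out of bounds in the L1 buffer. Therefore, add the tile_mix_cube_num compile hint to the result of the second dot operation so the compiler sub-tiles it: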
qk = tl.dot(q, trans_k)
# softmax calculation in between
qk = ...
p = tl.math.exp(qk)
pv = tl.dot(p, v)
tl.compile_hint(pv, "hivm.tile_mix_cube_num", 2)
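Reference: Compile Optimization Options
| Compile option | Meaning | Value range |
|---|---|---|
| multibuffer | Whether to enable ping-pong pipelining | False (default), True |
| limit_auto_multi_buffer_of_local_buffer | Scope of ping-pong pipelining within on-chip buffers (L1, L0, and UB). "no-limit": no restriction on the scope. "no-l0c": ping-pong pipelining is enabled only outside the L0 cache. | "no-limit", "no-l0c" (default) |
| unit_flag | Whether cube results are moved out block by block; use only when the data is aligned | False (default), True |
| limit_auto_multi_buffer_only_for_local_buffer | Whether to enable CV pipeline parallelism in the GM workspace; False means enabled. (The interface will be revised later to provide a more readable option.) | False (default), True |
| set_workspace_multibuffer | Takes effect only when limit_auto_multi_buffer_only_for_local_buffer=False. Sets the degree of CV parallelism: with N, N CV operations run in parallel. Make sure the data has no dependencies between them. | 2 (default), 4 |
| tile_mix_vector_loop | Takes effect only when limit_auto_multi_buffer_only_for_local_buffer=False. Sets the number of vector tiles; find the best value by autotuning, since any of the values may turn out optimal. | 1 (default), 2, 4 |
| tile_mix_cube_loop | Takes effect only when limit_auto_multi_buffer_only_for_local_buffer=False. Sets the number of cube tiles; find the best value by autotuning, since any of the values may turn out optimal. | 1 (default), 2, 4 |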
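The value ranges and dependencies in this table can be captured in a small validation helper. The sketch below is hypothetical and encodes only what the table states: the allowed values, and which options take effect only when limit_auto_multi_buffer_only_for_local_buffer=False. Options not listed in the table, such as inject_barrier_all, are ignored rather than rejected.

```python
from typing import Dict, List

# Allowed values, copied from the table above
ALLOWED = {
    "multibuffer": (False, True),
    "limit_auto_multi_buffer_of_local_buffer": ("no-limit", "no-l0c"),
    "unit_flag": (False, True),
    "limit_auto_multi_buffer_only_for_local_buffer": (False, True),
    "set_workspace_multibuffer": (2, 4),
    "tile_mix_vector_loop": (1, 2, 4),
    "tile_mix_cube_loop": (1, 2, 4),
}
# Options that only take effect when GM-workspace CV pipelining is enabled
NEEDS_GM_CV_PIPELINE = ("set_workspace_multibuffer", "tile_mix_vector_loop", "tile_mix_cube_loop")


def check_compile_options(opts: Dict[str, object]) -> List[str]:
    """Hypothetical helper: return human-readable warnings for the given launch options."""
    warnings = []
    for name, value in opts.items():
        # Compare type as well, so that True is not accepted where 1 is expected
        if name in ALLOWED and not any(value == a and type(value) is type(a) for a in ALLOWED[name]):
            warnings.append(f"{name}={value!r} is outside {ALLOWED[name]}")
    # Default is False, i.e. CV pipelining in the GM workspace is enabled
    if opts.get("limit_auto_multi_buffer_only_for_local_buffer", False):
        for name in NEEDS_GM_CV_PIPELINE:
            if name in opts:
                warnings.append(f"{name} has no effect when "
                                "limit_auto_multi_buffer_only_for_local_buffer=True")
    return warnings
```

For example, check_compile_options({"tile_mix_vector_loop": 3}) returns one warning, because 3 is not among 1, 2, 4.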
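Using Operator Options to Avoid Timeout Errors
Problem Description
Operator hangs are partly caused by hardware synchronization, which may be intra-core, inter-core, or pipeline synchronization. If an operator hangs, try passing the following arguments when launching the kernel. They change the synchronization logic of the generated binary and can work around the hang.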
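# Core synchronization options
inject_block_all = True # enable inter-core synchronization
inject_barrier_all = True # enable intra-core synchronization
# Pipeline options
limit_auto_multi_buffer_only_for_local_buffer = True # disable CV pipelining in the GM workspace
multibuffer = False # disable ping-pong pipelining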
Operator Example
Take the chunk_gated_delta_rule_fwd_kernel_h_blockdim64 operator from the GDN network as an example. The original call is:
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
)
With CV pipelining disabled, the call becomes:
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
limit_auto_multi_buffer_only_for_local_buffer = True,
)
Triton NPU Programming Examples
For Triton NPU programming, see: https://github.com/Ascend/triton-ascend-ops/blob/main/tutorial/README.zh.md