"""
Block Attention Residuals Implementation Module
This module implements the core computation functions for Block Attention Residuals.
It provides both forward and backward passes, with optional RMSNorm support and cache outputs.
Main Functions:
- ai_infra_block_attn_res: Forward pass for block attention residuals
- ai_infra_block_attn_res_backward: Backward pass for block attention residuals
Kernels:
- ai_infra_block_attn_res_forward_kernel: forward kernel
- ai_infra_block_attn_res_backward_kernel: backward kernel
"""
import typing
from typing import List, Optional
import pypto
import torch
@pypto.frontend.function
def ai_infra_block_attn_res_forward_kernel(
    v_flat, proj_weight, gamma_fp32, h_out, rms_cache, alpha_cache_3d,
    scale, rms_norm_eps, enable_rmsnorm, l_max
):
    """PyPTO function implementing the Block Attention Residuals forward pass.

    For every tile of rows along the first axis of ``v_flat`` the kernel:

    1. optionally RMS-normalises the values
       (``k = v / sqrt(mean(v^2) + eps) * gamma``) and stores the RMS
       denominator in ``rms_cache``;
    2. computes logits as ``proj_weight @ k^T`` and multiplies them by
       ``scale`` when ``scale != 1.0``;
    3. applies a softmax over the last axis and stores it in
       ``alpha_cache_3d``;
    4. writes the softmax-weighted sum of the *un-normalised* values to
       ``h_out``.

    Args:
        v_flat: [B*T, L, D] bf16/fp16
        proj_weight: [1, 1, D] fp32
        gamma_fp32: [1, 1, D] fp32
        h_out: [B*T, 1, D] bf16/fp16 (output)
        rms_cache: [B*T, L, 1] fp32 (output, written only when
            ``enable_rmsnorm`` is true)
        alpha_cache_3d: [B*T, L] fp32 (output)
        scale: softmax scaling factor
        rms_norm_eps: RMSNorm epsilon
        enable_rmsnorm: whether RMSNorm is applied before the projection
        l_max: size of the L bucket; views are created with ``l_max`` entries
            along L while ``valid_shape`` carries the real ``l``
    """
    bt, l, d = v_flat.shape
    dtype = v_flat.dtype
    # Candidate unroll lengths for the row loop; the length chosen for an
    # iteration is used below as the row-tile size.
    unroll_list = [1, 64, 128]
    # Number of length-d rows that fit in 64 * 1024 bytes at 4 bytes per
    # element, never less than one.
    # NOTE(review): presumably sized to an on-chip buffer - confirm.
    norm_m_tile = 64 * 1024 // (4 * d)
    norm_m_tile = max(norm_m_tile, 1)
    # Pass options are only set for D == 512 and D == 1536; no options are
    # set here for any other D.
    if d == 512:
        pypto.set_pass_options(
            vec_nbuffer_setting={"DEFAULT": 8},
            cube_nbuffer_setting={"DEFAULT": 8, "func8_0": 16},
            cube_l1_reuse_setting={"DEFAULT": 2}
        )
    elif d == 1536:
        pypto.set_pass_options(
            vec_nbuffer_setting={"DEFAULT": 8, "func8_2": 48},
            cube_nbuffer_setting={"DEFAULT": 16},
            cube_l1_reuse_setting={"DEFAULT": 2}
        )
    pypto.experimental.set_operation_options(combine_axis=True)
    for bt_idx, unroll_length in pypto.loop_unroll(
        0, bt, 1,
        name="Block_Attn_Res_T_Loop",
        idx_name="bt_idx",
        unroll_list=unroll_list,
    ):
        tile_bt = unroll_length
        # View padded to l_max along L; valid_shape records the real extent l.
        v = pypto.view(v_flat, [tile_bt, l_max, d], [bt_idx, 0, 0],
                       valid_shape=[tile_bt, l, d])
        if enable_rmsnorm:
            pypto.set_semantic_label("RmsNorm")
            pypto.set_vec_tile_shapes(1, norm_m_tile, d)
            v_fp32 = pypto.cast(v, pypto.DT_FP32)
            # rms_val = sqrt(mean(v^2 over D) + eps)
            v_sq = pypto.mul(v_fp32, v_fp32)
            v_sq_sum = pypto.sum(v_sq, dim=-1, keepdim=True)
            mean_val = pypto.div(v_sq_sum, d, pypto.PrecisionType.INTRINSIC)
            rms_val = pypto.sqrt(pypto.add(mean_val, rms_norm_eps))
            # k = v / rms_val * gamma
            k_fp32 = pypto.div(v_fp32, rms_val, pypto.PrecisionType.INTRINSIC)
            k_fp32 = pypto.mul(k_fp32, gamma_fp32)
            # Store the RMS denominator for this row tile in rms_cache.
            pypto.assemble(rms_val, [bt_idx, 0, 0], rms_cache)
        pypto.set_semantic_label("Attention Scores")
        if enable_rmsnorm:
            matmul_lhs = k_fp32
        else:
            # Without RMSNorm the raw values (cast to fp32) are projected.
            pypto.set_vec_tile_shapes(1, norm_m_tile, d)
            matmul_lhs = pypto.cast(v, pypto.DT_FP32)
        pypto.set_vec_tile_shapes(1, 1, d)
        if l_max == 32:
            pypto.set_cube_tile_shapes([l_max, l_max], [512, 512], [16, 16])
        else:
            pypto.set_cube_tile_shapes([l_max, l_max], [128, 512], [16, 16])
        # logits = proj_weight @ matmul_lhs^T (b_trans=True), then flattened
        # to [tile_bt, l_max].
        logits_3d = pypto.matmul(proj_weight, matmul_lhs, pypto.DT_FP32, a_trans=False, b_trans=True)
        logits = pypto.reshape(logits_3d, [tile_bt, l_max], valid_shape=[tile_bt, l])
        pypto.set_semantic_label("Attention Softmax")
        pypto.set_vec_tile_shapes(128, l_max)
        # The multiply is skipped entirely when scale is exactly 1.0.
        if scale != 1.0:
            logits = pypto.mul(logits, scale)
        alpha = pypto.softmax(logits, dim=-1)
        # Store the softmax weights for this row tile in alpha_cache_3d.
        pypto.assemble(alpha, [bt_idx, 0], alpha_cache_3d)
        alpha_3d = pypto.reshape(alpha, [tile_bt, 1, l_max], valid_shape=[tile_bt, 1, l])
        if not enable_rmsnorm:
            # v_fp32 was only produced in the RMSNorm branch above, so it is
            # created here for the weighted sum.
            pypto.set_vec_tile_shapes(1, norm_m_tile, d)
            v_fp32 = pypto.cast(v, pypto.DT_FP32)
        pypto.set_semantic_label("Weighted Summation")
        if l_max == 32:
            pypto.set_cube_tile_shapes([1, 1], [l_max, l_max], [256, 256])
        else:
            pypto.set_cube_tile_shapes([1, 1], [128, 128], [128, 128])
        # h = alpha @ v, produced in the input dtype and written to h_out.
        h = pypto.matmul(alpha_3d, v_fp32, dtype)
        pypto.assemble(h, [bt_idx, 0, 0], h_out)
# JIT entry point for the forward kernel specialised to the l_max = 32 bucket.
# It only forwards its arguments to ai_infra_block_attn_res_forward_kernel with
# l_max fixed to 32; results are written into h_out, rms_cache and
# alpha_cache_3d.
@pypto.frontend.jit(
    runtime_options={"stitch_function_max_num": 64, "device_sched_mode": 0},
)
def ai_infra_block_attn_res_forward_kernel_l_max_32(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    # gamma is passed as [1, 1, D] (same layout as proj_weight), so it is
    # declared with three static dims, matching the backward entry points.
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    h_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC]),
    scale: float,
    rms_norm_eps: float,
    enable_rmsnorm: bool,
):
    ai_infra_block_attn_res_forward_kernel(
        v_flat, proj_weight, gamma_fp32, h_out, rms_cache, alpha_cache_3d,
        scale, rms_norm_eps, enable_rmsnorm, 32)
# JIT entry point for the forward kernel specialised to the l_max = 64 bucket.
# It only forwards its arguments to ai_infra_block_attn_res_forward_kernel with
# l_max fixed to 64; results are written into h_out, rms_cache and
# alpha_cache_3d.
@pypto.frontend.jit(
    runtime_options={"stitch_function_max_num": 64, "device_sched_mode": 0},
)
def ai_infra_block_attn_res_forward_kernel_l_max_64(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    # gamma is passed as [1, 1, D] (same layout as proj_weight), so it is
    # declared with three static dims, matching the backward entry points.
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    h_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC]),
    scale: float,
    rms_norm_eps: float,
    enable_rmsnorm: bool,
):
    ai_infra_block_attn_res_forward_kernel(
        v_flat, proj_weight, gamma_fp32, h_out, rms_cache, alpha_cache_3d,
        scale, rms_norm_eps, enable_rmsnorm, 64)
# JIT entry point for the forward kernel specialised to the l_max = 96 bucket.
# It only forwards its arguments to ai_infra_block_attn_res_forward_kernel with
# l_max fixed to 96; results are written into h_out, rms_cache and
# alpha_cache_3d.
@pypto.frontend.jit(
    runtime_options={"stitch_function_max_num": 64, "device_sched_mode": 0},
)
def ai_infra_block_attn_res_forward_kernel_l_max_96(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    # gamma is passed as [1, 1, D] (same layout as proj_weight), so it is
    # declared with three static dims, matching the backward entry points.
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    h_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC]),
    scale: float,
    rms_norm_eps: float,
    enable_rmsnorm: bool,
):
    ai_infra_block_attn_res_forward_kernel(
        v_flat, proj_weight, gamma_fp32, h_out, rms_cache, alpha_cache_3d,
        scale, rms_norm_eps, enable_rmsnorm, 96)
# JIT entry point for the forward kernel specialised to the l_max = 128 bucket.
# It only forwards its arguments to ai_infra_block_attn_res_forward_kernel with
# l_max fixed to 128; results are written into h_out, rms_cache and
# alpha_cache_3d.
@pypto.frontend.jit(
    runtime_options={"stitch_function_max_num": 64, "device_sched_mode": 0},
)
def ai_infra_block_attn_res_forward_kernel_l_max_128(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    # gamma is passed as [1, 1, D] (same layout as proj_weight), so it is
    # declared with three static dims, matching the backward entry points.
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    h_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC]),
    scale: float,
    rms_norm_eps: float,
    enable_rmsnorm: bool,
):
    ai_infra_block_attn_res_forward_kernel(
        v_flat, proj_weight, gamma_fp32, h_out, rms_cache, alpha_cache_3d,
        scale, rms_norm_eps, enable_rmsnorm, 128)
def _validate_forward_inputs(
    blocks: List[torch.Tensor],
    proj_weight: torch.Tensor,
    partial_block: Optional[torch.Tensor],
    rmsnorm_gamma: Optional[torch.Tensor],
    enable_rmsnorm: bool,
    rms_out_flag: bool,
):
    """Check the arguments of the forward entry point before any kernel runs.

    ``blocks[0]`` provides the reference shape ``(B, T, D)`` and dtype. Every
    other block, the optional ``partial_block``, ``proj_weight`` (shape
    ``(D,)``) and - whenever it is supplied - ``rmsnorm_gamma`` (shape
    ``(D,)``) must agree with that reference.

    Raises:
        ValueError: on the first violated constraint. Checks run in this
            order: block count, per-block shape/dtype, partial block,
            projection weight, gamma, and finally the flag combination
            ``rms_out_flag`` without ``enable_rmsnorm``.
    """
    if not blocks:
        raise ValueError("blocks must not be empty")

    n = len(blocks)
    if not 1 <= n <= 127:
        raise ValueError(f"N (number of blocks) must be in [1, 127], got {n}")

    # The first block is the reference every other tensor is compared against.
    reference = blocks[0]
    b, t, d = reference.shape
    dtype = reference.dtype
    expected_shape = (b, t, d)
    expected_vector_shape = (d,)

    for i, block in enumerate(blocks):
        if block.shape != expected_shape:
            raise ValueError(f"blocks[{i}] shape {block.shape} != expected {(b, t, d)}")
        if block.dtype != dtype:
            raise ValueError(f"blocks[{i}] dtype {block.dtype} != expected {dtype}")

    if partial_block is not None:
        if partial_block.shape != expected_shape:
            raise ValueError(f"partial_block shape {partial_block.shape} != expected {(b, t, d)}")
        if partial_block.dtype != dtype:
            raise ValueError(f"partial_block dtype {partial_block.dtype} != expected {dtype}")

    if proj_weight.shape != expected_vector_shape:
        raise ValueError(f"proj_weight shape must be [D], got {proj_weight.shape}")
    if proj_weight.dtype != dtype:
        raise ValueError(f"proj_weight dtype {proj_weight.dtype} != expected {dtype}")

    if rmsnorm_gamma is None:
        # Gamma may only be omitted when RMSNorm is switched off.
        if enable_rmsnorm:
            raise ValueError("rmsnorm_gamma is required when enable_rmsnorm=True")
    else:
        # A supplied gamma is validated even when RMSNorm is disabled.
        if rmsnorm_gamma.shape != expected_vector_shape:
            raise ValueError(f"rmsnorm_gamma shape must be [D], got {rmsnorm_gamma.shape}")
        if rmsnorm_gamma.dtype != dtype:
            raise ValueError(f"rmsnorm_gamma dtype {rmsnorm_gamma.dtype} != expected {dtype}")

    if rms_out_flag and not enable_rmsnorm:
        raise ValueError("rms_out_flag=True is invalid when enable_rmsnorm=False (no RMSNorm output to cache)")
def ai_infra_block_attn_res(
    blocks: List[torch.Tensor],
    proj_weight: torch.Tensor,
    partial_block: Optional[torch.Tensor] = None,
    scale: float = 1.0,
    rmsnorm_eps: float = 1e-6,
    rmsnorm_gamma: Optional[torch.Tensor] = None,
    enable_rmsnorm: bool = True,
    rms_out_flag: bool = False,
    alpha_out_flag: bool = False,
):
    """Host-side entry point of the Block Attention Residuals forward operator.

    Args:
        blocks: N completed block representations, each of shape [B, T, D].
        proj_weight: pseudo-query projection weight of shape [D].
        partial_block: optional partial sum of the current block, [B, T, D].
            When given it is appended after ``blocks`` so that L = N + 1.
        scale: softmax scaling factor.
        rmsnorm_eps: RMSNorm epsilon.
        rmsnorm_gamma: RMSNorm scale parameter of shape [D]; a tensor of ones
            is substituted when it is ``None``.
        enable_rmsnorm: whether RMSNorm is applied.
        rms_out_flag: also return ``rms_cache`` for reuse by the backward
            pass (default ``False``).
        alpha_out_flag: also return ``alpha_cache`` for reuse by the backward
            pass (default ``False``).

    Returns:
        - ``output`` [B, T, D] alone when both flags are ``False``;
        - ``(output, rms_cache [B, T, L, 1])`` when only ``rms_out_flag``;
        - ``(output, alpha_cache [B, T, L])`` when only ``alpha_out_flag``;
        - ``(output, rms_cache, alpha_cache)`` when both flags are ``True``.

    Raises:
        ValueError: if input validation fails or L exceeds 128.
    """
    _validate_forward_inputs(blocks, proj_weight, partial_block,
                             rmsnorm_gamma, enable_rmsnorm, rms_out_flag)

    first = blocks[0]
    b, t, d = first.shape
    dtype, device = first.dtype, first.device

    # Stack all sources along a new L axis, then merge B and T into one axis.
    sources = blocks if partial_block is None else blocks + [partial_block]
    l = len(sources)
    v = torch.stack(sources, dim=2).reshape(b * t, l, d)

    # The kernel consumes the projection weight and gamma as fp32 [1, 1, D].
    q = proj_weight.reshape(1, 1, d).to(torch.float32)
    if rmsnorm_gamma is None:
        gamma_fp32 = torch.ones(1, 1, d, dtype=dtype, device=device)
    else:
        gamma_fp32 = rmsnorm_gamma.reshape(1, 1, d)
    gamma_fp32 = gamma_fp32.to(torch.float32)

    # Output buffers filled in place by the kernel.
    h_out = torch.zeros(b * t, 1, d, dtype=dtype, device=device)
    rms_cache_flat = torch.zeros(b * t, l, 1, dtype=torch.float32, device=device)
    alpha_cache_flat = torch.zeros(b * t, l, dtype=torch.float32, device=device)

    # Use the smallest L bucket that can hold l.
    kernels_by_bucket = (
        (32, ai_infra_block_attn_res_forward_kernel_l_max_32),
        (64, ai_infra_block_attn_res_forward_kernel_l_max_64),
        (96, ai_infra_block_attn_res_forward_kernel_l_max_96),
        (128, ai_infra_block_attn_res_forward_kernel_l_max_128),
    )
    for bucket, kernel in kernels_by_bucket:
        if l <= bucket:
            kernel(v, q, gamma_fp32, h_out, rms_cache_flat, alpha_cache_flat,
                   scale, rmsnorm_eps, enable_rmsnorm)
            break
    else:
        raise ValueError(f"Unsupported l={l}, expected l <= 128")

    output = h_out.reshape(b, t, d)
    if not (rms_out_flag or alpha_out_flag):
        return output

    results = [output]
    if rms_out_flag:
        results.append(rms_cache_flat.reshape(b, t, l, 1))
    if alpha_out_flag:
        results.append(alpha_cache_flat.reshape(b, t, l))
    return tuple(results)
# JIT entry point for the backward kernel specialised to the l_max = 32 bucket.
# It picks the buffer configuration for the feature size D and forwards all
# arguments to ai_infra_block_attn_res_backward_kernel, which writes the
# gradients into grad_v_flat, grad_proj_weight and grad_gamma.
@pypto.frontend.jit(
    runtime_options={
        "stitch_function_max_num": 32,
        "device_sched_mode": 1
    }
)
def ai_infra_block_attn_res_backward_kernel_l_max_32(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_h_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.DYNAMIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_gamma: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    scale: float,
    enable_rmsnorm: bool,
):
    l_max = 32
    bt, l, d = v_flat.shape
    if d == 512:
        # Per-function buffer counts tuned for D == 512 in this bucket.
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16, "func8_7": 2, "func8_1": 8, "func8_0": 32, "func8_8": 32},
            "cube_nbuffer_setting": {"DEFAULT": 16}
        }
    elif d == 1536:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    else:
        # Generic fallback so that `configs` is always bound: without it any D
        # other than 512 / 1536 reached the call below with `configs`
        # undefined. These are the same untuned values used for D == 1536.
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    ai_infra_block_attn_res_backward_kernel(v_flat, grad_h_3d, rms_cache, alpha_cache_3d, proj_weight, gamma_fp32,
                                            grad_v_flat, grad_proj_weight, grad_gamma, scale, enable_rmsnorm, l_max, configs)
# JIT entry point for the backward kernel specialised to the l_max = 64 bucket.
# It picks the buffer configuration for the feature size D and forwards all
# arguments to ai_infra_block_attn_res_backward_kernel, which writes the
# gradients into grad_v_flat, grad_proj_weight and grad_gamma.
@pypto.frontend.jit(
    runtime_options={
        "stitch_function_max_num": 32,
        "device_sched_mode": 1
    }
)
def ai_infra_block_attn_res_backward_kernel_l_max_64(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_h_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.DYNAMIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_gamma: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    scale: float,
    enable_rmsnorm: bool,
):
    l_max = 64
    bt, l, d = v_flat.shape
    if d == 512:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16}
        }
    elif d == 1536:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    else:
        # Generic fallback so that `configs` is always bound: without it any D
        # other than 512 / 1536 reached the call below with `configs`
        # undefined. These are the same values used by the branches above.
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    ai_infra_block_attn_res_backward_kernel(v_flat, grad_h_3d, rms_cache, alpha_cache_3d, proj_weight, gamma_fp32,
                                            grad_v_flat, grad_proj_weight, grad_gamma, scale, enable_rmsnorm, l_max, configs)
# JIT entry point for the backward kernel specialised to the l_max = 96 bucket.
# It picks the buffer configuration for the feature size D and forwards all
# arguments to ai_infra_block_attn_res_backward_kernel, which writes the
# gradients into grad_v_flat, grad_proj_weight and grad_gamma.
@pypto.frontend.jit(
    runtime_options={
        "stitch_function_max_num": 32,
        "device_sched_mode": 1
    }
)
def ai_infra_block_attn_res_backward_kernel_l_max_96(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_h_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.DYNAMIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_gamma: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    scale: float,
    enable_rmsnorm: bool,
):
    l_max = 96
    bt, l, d = v_flat.shape
    if d == 512:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16}
        }
    elif d == 1536:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    else:
        # Generic fallback so that `configs` is always bound: without it any D
        # other than 512 / 1536 reached the call below with `configs`
        # undefined. These are the same values used by the branches above.
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    ai_infra_block_attn_res_backward_kernel(v_flat, grad_h_3d, rms_cache, alpha_cache_3d, proj_weight, gamma_fp32,
                                            grad_v_flat, grad_proj_weight, grad_gamma, scale, enable_rmsnorm, l_max, configs)
# JIT entry point for the backward kernel specialised to the l_max = 128
# bucket. It picks the buffer configuration for the feature size D and forwards
# all arguments to ai_infra_block_attn_res_backward_kernel, which writes the
# gradients into grad_v_flat, grad_proj_weight and grad_gamma.
@pypto.frontend.jit(
    runtime_options={
        "stitch_function_max_num": 32,
        "device_sched_mode": 1
    }
)
def ai_infra_block_attn_res_backward_kernel_l_max_128(
    v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_h_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    rms_cache: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    alpha_cache_3d: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.DYNAMIC]),
    proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    gamma_fp32: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_v_flat: pypto.Tensor([pypto.DYNAMIC, pypto.DYNAMIC, pypto.STATIC]),
    grad_proj_weight: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    grad_gamma: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    scale: float,
    enable_rmsnorm: bool,
):
    l_max = 128
    bt, l, d = v_flat.shape
    if d == 512:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16}
        }
    elif d == 1536:
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    else:
        # Generic fallback so that `configs` is always bound: without it any D
        # other than 512 / 1536 reached the call below with `configs`
        # undefined. These are the same values used by the branches above.
        configs = {
            "vec_nbuffer_setting": {"DEFAULT": 16},
            "cube_nbuffer_setting": {"DEFAULT": 16},
        }
    ai_infra_block_attn_res_backward_kernel(v_flat, grad_h_3d, rms_cache, alpha_cache_3d, proj_weight, gamma_fp32,
                                            grad_v_flat, grad_proj_weight, grad_gamma, scale, enable_rmsnorm, l_max, configs)
@pypto.frontend.function
def ai_infra_block_attn_res_backward_kernel(v_flat, grad_h_3d, rms_cache, alpha_cache_3d, proj_weight, gamma_fp32,
                                            grad_v_flat, grad_proj_weight, grad_gamma, scale, enable_rmsnorm, l_max, configs):
    """PyPTO function implementing the Block Attention Residuals backward pass.

    For every tile of rows the kernel back-propagates through the weighted
    sum, the softmax, the projection and (optionally) RMSNorm. Per-tile
    contributions to the projection-weight and gamma gradients are
    accumulated in fp32 and cast to the input dtype once, after the loop.

    Args:
        v_flat: [B*T, L, D] bf16/fp16
        grad_h_3d: [B*T, 1, D] bf16/fp16
        rms_cache: [B*T, L, 1] fp32 (only read when enable_rmsnorm is true)
        alpha_cache_3d: [B*T, 1, L] fp32
        proj_weight: [1, 1, D] bf16/fp16
        gamma_fp32: [1, 1, D] fp32
        grad_v_flat: [B*T, L, D] bf16/fp16 (output)
        grad_proj_weight: [1, 1, D] bf16/fp16 (output)
        grad_gamma: [1, 1, D] bf16/fp16 (output, written only when
            enable_rmsnorm is true)
        scale: softmax scaling factor
        enable_rmsnorm: whether RMSNorm was applied in the forward pass
        l_max: size of the L bucket; views are created with ``l_max`` entries
            along L while ``valid_shape`` carries the real ``l``
        configs: dict providing the ``vec_nbuffer_setting`` and
            ``cube_nbuffer_setting`` pass options
    """
    bt, l, d = v_flat.shape
    dtype = v_flat.dtype
    # Candidate unroll lengths for the row loop; the chosen length is the
    # row-tile size (tile_bt) of an iteration.
    unroll_list = [1, 64]
    # Tile extent along L: 16384 // d, rounded down to a multiple of 4 and
    # never smaller than 4.
    l_tile = max(16384 // d, 1)
    l_tile = max((l_tile // 4) * 4, 4)
    pypto.experimental.set_operation_options(combine_axis=True)
    pypto.set_pass_options(
        vec_nbuffer_setting=configs["vec_nbuffer_setting"],
        cube_nbuffer_setting=configs["cube_nbuffer_setting"],
    )
    pypto.set_vec_tile_shapes(1, 1, d)
    # fp32 accumulators for the parameter gradients, summed over all tiles.
    grad_proj_weight_acc = pypto.zeros([1, 1, d], dtype=pypto.DT_FP32)
    if enable_rmsnorm:
        grad_gamma_acc = pypto.zeros([1, 1, d], dtype=pypto.DT_FP32)
    # NOTE(review): this call passes `idx=` while the forward kernel's loop
    # passes `idx_name=` - confirm which keyword pypto.loop_unroll expects.
    for bt_idx, tile_bt in pypto.loop_unroll(0, bt, 1, name="BT_LOOP", idx="bt_idx", unroll_list=unroll_list):
        pypto.set_semantic_label("Weighted Summation Backward")
        pypto.set_vec_tile_shapes(1, l_tile, d)
        grad_h_tile_3d = pypto.view(grad_h_3d, [tile_bt, 1, d], [bt_idx, 0, 0])
        v_tile = pypto.view(v_flat, [tile_bt, l_max, d], [bt_idx, 0, 0],
                            valid_shape=[tile_bt, l, d])
        pypto.set_cube_tile_shapes([1, 1], [32768 // l_max, 32768 // l_max], [l_max, l_max])
        # grad_alpha = grad_h @ v^T (b_trans=True)
        grad_alpha_fp32 = pypto.matmul(grad_h_tile_3d, v_tile, pypto.DT_FP32, b_trans=True)
        alpha_tile_3d = pypto.view(alpha_cache_3d, [tile_bt, 1, l_max], [bt_idx, 0, 0],
                                   valid_shape=[tile_bt, 1, l])
        pypto.set_cube_tile_shapes([l_max, l_max], [16, 16], [16384 // l_max, 16384 // l_max])
        pypto.set_vec_tile_shapes(l_tile, 1, d)
        grad_h_tile_3d_fp32 = pypto.cast(grad_h_tile_3d, pypto.DT_FP32)
        # Gradient reaching v through the weighted sum: alpha^T @ grad_h.
        grad_v_agg_fp32 = pypto.matmul(alpha_tile_3d, grad_h_tile_3d_fp32, pypto.DT_FP32, a_trans=True)
        pypto.set_semantic_label("Attention Softmax Backward")
        pypto.set_vec_tile_shapes(128, 1, 128)
        grad_alpha_2d_fp32 = pypto.reshape(grad_alpha_fp32, [tile_bt, l_max],
                                           valid_shape=[tile_bt, l])
        alpha_2d = pypto.reshape(alpha_tile_3d, [tile_bt, l_max],
                                 valid_shape=[tile_bt, l])
        pypto.set_vec_tile_shapes(128, 128)
        alpha_fp32_2d = pypto.cast(alpha_2d, pypto.DT_FP32)
        # Softmax backward:
        # grad_logits = alpha * (grad_alpha - sum(grad_alpha * alpha)).
        dot_fp32 = pypto.sum(grad_alpha_2d_fp32 * alpha_fp32_2d, dim=-1, keepdim=True)
        grad_logits_fp32 = alpha_fp32_2d * (grad_alpha_2d_fp32 - dot_fp32)
        grad_logits_reshaped = pypto.reshape(grad_logits_fp32, [tile_bt, 1, l_max],
                                             valid_shape=[tile_bt, 1, l])
        pypto.set_semantic_label("Attention Scores Backward")
        pypto.set_vec_tile_shapes(1, l_tile, d)
        v_tile_fp32 = pypto.cast(v_tile, pypto.DT_FP32)
        proj_weight_fp32 = pypto.cast(proj_weight, pypto.DT_FP32)
        # The softmax scale is folded into the projection weight used for
        # grad_k.
        proj_weight_scale = scale * proj_weight_fp32
        proj_weight_scale_fp32 = proj_weight_scale
        if enable_rmsnorm:
            pypto.set_vec_tile_shapes(1, l_tile, d)
            rms_tile = pypto.view(rms_cache, [tile_bt, l_max, 1], [bt_idx, 0, 0],
                                  valid_shape=[tile_bt, l, 1])
            rms_tile_fp32 = pypto.cast(rms_tile, pypto.DT_FP32)
            # Recompute the normalised keys from the cached RMS denominator.
            k_tile_fp32 = v_tile_fp32 / rms_tile_fp32 * gamma_fp32
        else:
            k_tile_fp32 = v_tile_fp32
        pypto.set_cube_tile_shapes([l_max, l_max], [16, 16], [16384 // l_max, 16384 // l_max])
        # grad_k = grad_logits^T @ (scale * proj_weight)
        grad_k_fp32 = pypto.matmul(grad_logits_reshaped, proj_weight_scale_fp32, pypto.DT_FP32, a_trans=True)
        pypto.set_vec_tile_shapes(1, l_tile, d)
        # NOTE(review): the `+ 0.0` presumably materialises a fresh tensor
        # before the second reshape of grad_logits_fp32 - confirm.
        grad_logits_reshaped = pypto.reshape(grad_logits_fp32 + 0.0, [tile_bt, l_max, 1],
                                             valid_shape=[tile_bt, l, 1])
        # Projection-weight gradient for this tile: grad_logits * k reduced
        # over the row axis (dim 0) and then over L (dim 1).
        grad_proj_weight_tile_fp32 = pypto.mul(grad_logits_reshaped, k_tile_fp32)
        pypto.set_vec_tile_shapes(tile_bt, 1, 16384 // tile_bt)
        grad_proj_weight_tile_fp32 = pypto.sum(grad_proj_weight_tile_fp32, dim=0, keepdim=True)
        # NOTE(review): sg_set_scope is switched to 1 around the second
        # reduction and back to -1 afterwards; its exact meaning is not
        # visible in this file.
        pypto.set_pass_options(sg_set_scope=1)
        pypto.set_vec_tile_shapes(1, l_max, 16384 // l_max)
        grad_proj_weight_tile_fp32 = pypto.sum(grad_proj_weight_tile_fp32, dim=1, keepdim=True)
        pypto.set_pass_options(sg_set_scope=-1)
        pypto.set_vec_tile_shapes(1, 1, d)
        # Apply the softmax scale and add the tile to the running total.
        grad_proj_weight_scaled_fp32 = grad_proj_weight_tile_fp32 * scale
        grad_proj_weight_acc[:] = grad_proj_weight_acc + grad_proj_weight_scaled_fp32
        if enable_rmsnorm:
            pypto.set_semantic_label("RmsNorm Backward")
            pypto.set_vec_tile_shapes(1, l_tile, d)
            # c = sum over D of (gamma * grad_k * v)
            c_fp32 = pypto.sum(gamma_fp32 * grad_k_fp32 * v_tile_fp32, dim=-1, keepdim=True)
            rms_sq_fp32 = rms_tile_fp32 * rms_tile_fp32
            # grad_v = (gamma * grad_k - v * c / (D * rms^2)) / rms
            grad_v_rms_fp32 = (gamma_fp32 * grad_k_fp32 - v_tile_fp32 * c_fp32 / (d * rms_sq_fp32)) / rms_tile_fp32
            # Gamma gradient for this tile: grad_k * (v / rms), reduced over
            # the row axis (dim 0) and then over L (dim 1).
            grad_gamma_tile_fp32 = grad_k_fp32 * (v_tile_fp32 / rms_tile_fp32)
            pypto.set_vec_tile_shapes(tile_bt, 1, 16384 // tile_bt)
            grad_gamma_tile_fp32 = pypto.sum(grad_gamma_tile_fp32, dim=0, keepdim=True)
            pypto.set_pass_options(sg_set_scope=1)
            pypto.set_vec_tile_shapes(1, l_max, 16384 // l_max)
            grad_gamma_tile_fp32 = pypto.sum(grad_gamma_tile_fp32, dim=1, keepdim=True)
            pypto.set_pass_options(sg_set_scope=-1)
            pypto.set_vec_tile_shapes(1, 1, d)
            grad_gamma_acc[:] = grad_gamma_acc + grad_gamma_tile_fp32
        else:
            # Without RMSNorm the key gradient is the value gradient.
            grad_v_rms_fp32 = grad_k_fp32
        pypto.set_vec_tile_shapes(1, l_tile, d)
        # Total value gradient: weighted-sum path plus projection path.
        grad_v_fp32 = grad_v_agg_fp32 + grad_v_rms_fp32
        del grad_v_agg_fp32, grad_alpha_fp32, grad_k_fp32
        grad_v_tile = pypto.cast(grad_v_fp32, dtype)
        pypto.assemble(grad_v_tile, [bt_idx, 0, 0], grad_v_flat)
    # Cast the fp32 accumulators to the input dtype and write the outputs.
    grad_proj_weight[:] = pypto.cast(grad_proj_weight_acc, dtype)
    if enable_rmsnorm:
        grad_gamma[:] = pypto.cast(grad_gamma_acc, dtype)
def _validate_backward_inputs(
    grad_h: torch.Tensor,
    blocks: List[torch.Tensor],
    proj_weight: torch.Tensor,
    alpha_cache: torch.Tensor,
    partial_block: Optional[torch.Tensor],
    rmsnorm_gamma: Optional[torch.Tensor],
    rms_cache: Optional[torch.Tensor],
    enable_rmsnorm: bool,
):
    """Check the arguments of the backward entry point before any kernel runs.

    ``blocks[0]`` provides the reference shape ``(B, T, D)``, dtype and
    device. ``grad_h``, every block, the optional ``partial_block`` and
    ``proj_weight`` (shape ``(D,)``) must match that reference, and
    ``alpha_cache`` must have shape ``(B, T, L)`` where L is the number of
    blocks plus one when ``partial_block`` is given. ``rmsnorm_gamma`` and
    ``rms_cache`` are only inspected when ``enable_rmsnorm`` is true.

    Raises:
        ValueError: on the first violated constraint.
    """
    n = len(blocks)
    if not n:
        raise ValueError("blocks must not be empty")

    # The first block is the reference every other tensor is compared against.
    reference = blocks[0]
    b, t, d = reference.shape
    dtype, device = reference.dtype, reference.device
    has_partial = partial_block is not None
    l = n + int(has_partial)

    if grad_h.shape != (b, t, d):
        raise ValueError(f"grad_h shape must be {(b, t, d)}, but got {tuple(grad_h.shape)}")
    if not (grad_h.dtype == dtype and grad_h.device == device):
        raise ValueError(f"grad_h dtype/device must be {dtype}/{device}, but got {grad_h.dtype}/{grad_h.device}")

    for i, block in enumerate(blocks):
        if block.shape != (b, t, d):
            raise ValueError(f"blocks[{i}] shape must be {(b, t, d)}, but got {tuple(block.shape)}")
        if not (block.dtype == dtype and block.device == device):
            raise ValueError(f"blocks[{i}] dtype/device must be {dtype}/{device}, but got {block.dtype}/{block.device}")

    if has_partial:
        if partial_block.shape != (b, t, d):
            raise ValueError(f"partial_block shape must be {(b, t, d)}, but got {tuple(partial_block.shape)}")
        if not (partial_block.dtype == dtype and partial_block.device == device):
            raise ValueError(
                f"partial_block dtype/device must be {dtype}/{device}, "
                f"but got {partial_block.dtype}/{partial_block.device}"
            )

    if proj_weight.shape != (d,):
        raise ValueError(f"proj_weight shape must be {(d,)}, but got {tuple(proj_weight.shape)}")
    if not (proj_weight.dtype == dtype and proj_weight.device == device):
        raise ValueError(
            f"proj_weight dtype/device must be {dtype}/{device}, "
            f"but got {proj_weight.dtype}/{proj_weight.device}"
        )

    if alpha_cache is None:
        raise ValueError("alpha_cache must be provided and cannot be None")
    if alpha_cache.shape != (b, t, l):
        raise ValueError(f"alpha_cache shape must be {(b, t, l)}, but got {tuple(alpha_cache.shape)}")

    # Everything below concerns RMSNorm only.
    if not enable_rmsnorm:
        return

    if rmsnorm_gamma is None:
        raise ValueError("when enable_rmsnorm is True, rmsnorm_gamma must be provided")
    if rmsnorm_gamma.shape != (d,):
        raise ValueError(f"rmsnorm_gamma shape must be {(d,)}, but got {tuple(rmsnorm_gamma.shape)}")
    if not (rmsnorm_gamma.dtype == dtype and rmsnorm_gamma.device == device):
        raise ValueError(
            f"rmsnorm_gamma dtype/device must be {dtype}/{device}, "
            f"but got {rmsnorm_gamma.dtype}/{rmsnorm_gamma.device}"
        )
    if rms_cache is None:
        raise ValueError("when enable_rmsnorm is True, rms_cache must be provided")
    if rms_cache.shape != (b, t, l, 1):
        raise ValueError(f"rms_cache shape must be {(b, t, l, 1)}, but got {tuple(rms_cache.shape)}")
def ai_infra_block_attn_res_backward(
    grad_h: torch.Tensor,
    blocks: List[torch.Tensor],
    proj_weight: torch.Tensor,
    alpha_cache: torch.Tensor,
    partial_block: Optional[torch.Tensor] = None,
    rmsnorm_gamma: Optional[torch.Tensor] = None,
    rms_cache: Optional[torch.Tensor] = None,
    scale: float = 1.0,
    enable_rmsnorm: bool = True,
):
    """Host-side entry point of the Block Attention Residuals backward operator.

    Args:
        grad_h: upstream gradient of the forward output, [B, T, D].
        blocks: the completed block representations used in the forward
            pass, each of shape [B, T, D].
        proj_weight: pseudo-query projection weight of shape [D].
        alpha_cache: softmax weights cached by the forward pass, [B, T, L].
        partial_block: partial sum of the current block used in the forward
            pass, [B, T, D], or ``None``.
        rmsnorm_gamma: RMSNorm scale parameter of shape [D]; a tensor of ones
            is substituted when it is ``None``.
        rms_cache: RMSNorm denominator cached by the forward pass,
            [B, T, L, 1]; an uninitialised placeholder is substituted when
            it is ``None``.
        scale: softmax scaling factor (default 1.0).
        enable_rmsnorm: whether RMSNorm was applied (default ``True``).

    Returns:
        - ``(grad_blocks, grad_partial_block, grad_proj_weight)`` when
          ``enable_rmsnorm`` is ``False``;
        - ``(grad_blocks, grad_partial_block, grad_proj_weight,
          grad_rmsnorm_gamma)`` when ``enable_rmsnorm`` is ``True``.

        ``grad_blocks`` is a list with one [B, T, D] tensor per input block
        and ``grad_partial_block`` is ``None`` when no partial block was
        given.

    Raises:
        ValueError: if input validation fails or L exceeds 128.
    """
    _validate_backward_inputs(grad_h, blocks, proj_weight, alpha_cache,
                              partial_block, rmsnorm_gamma, rms_cache,
                              enable_rmsnorm)

    n = len(blocks)
    has_partial = partial_block is not None
    first = blocks[0]
    b, t, d = first.shape
    dtype, device = first.dtype, first.device
    l = n + int(has_partial)
    rows = b * t

    # Stack all sources along a new L axis, then merge B and T into one axis.
    tensors = blocks + ([partial_block] if has_partial else [])
    v_flat = torch.stack(tensors, dim=2).reshape(rows, l, d)

    if rmsnorm_gamma is None:
        gamma_reshaped = torch.ones(1, 1, d, dtype=dtype, device=device).to(torch.float32)
    else:
        gamma_reshaped = rmsnorm_gamma.reshape(1, 1, d).to(torch.float32)

    proj_weight_reshaped = proj_weight.reshape(1, 1, d)
    grad_h_3d = grad_h.reshape(rows, 1, d)
    alpha_cache_3d = alpha_cache.reshape(rows, 1, l)

    if rms_cache is None:
        # Placeholder so the kernel always receives an rms_cache argument.
        rms_cache_flat = torch.empty(rows, l, 1, dtype=dtype, device=device)
    else:
        rms_cache_flat = rms_cache.reshape(rows, l, 1)

    # Output buffers filled in place by the kernel.
    grad_v_flat = torch.zeros(rows, l, d, dtype=dtype, device=device)
    grad_proj_weight_out = torch.zeros(1, 1, d, dtype=dtype, device=device)
    grad_gamma_out = torch.zeros(1, 1, d, dtype=dtype, device=device)

    # Use the smallest L bucket that can hold l.
    kernels_by_bucket = (
        (32, ai_infra_block_attn_res_backward_kernel_l_max_32),
        (64, ai_infra_block_attn_res_backward_kernel_l_max_64),
        (96, ai_infra_block_attn_res_backward_kernel_l_max_96),
        (128, ai_infra_block_attn_res_backward_kernel_l_max_128),
    )
    for bucket, kernel in kernels_by_bucket:
        if l <= bucket:
            kernel(v_flat, grad_h_3d, rms_cache_flat, alpha_cache_3d,
                   proj_weight_reshaped, gamma_reshaped,
                   grad_v_flat, grad_proj_weight_out, grad_gamma_out,
                   scale, enable_rmsnorm)
            break
    else:
        raise ValueError(f"Unsupported l={l}, expected l <= 128")

    # Split the L axis back into per-source gradients: the first n slices
    # belong to `blocks`, slice n (if present) to `partial_block`.
    slices = grad_v_flat.reshape(b, t, l, d).unbind(dim=2)
    grad_blocks = [piece.contiguous() for piece in slices[:n]]
    grad_partial_block = slices[n].contiguous() if has_partial else None

    grads = (grad_blocks, grad_partial_block, grad_proj_weight_out.reshape(d))
    if enable_rmsnorm:
        grads += (grad_gamma_out.reshape(d),)
    return grads