import torch
import torch_npu
try:
import triton
import triton.language as tl
TRITON_AVAILABLE = True
except ImportError:
TRITON_AVAILABLE = False
pass
try:
import triton.language.extra.cann.extension as tle
except ImportError:
import triton.language as tle
if TRITON_AVAILABLE:
    # These kernels compute Y = H @ X and its gradients for a batch of 4x4
    # matrices H acting on [4, C] slabs X, one matrix per row of the flattened
    # leading dimension. The size 4 is unrolled by hand (four H rows/columns,
    # four X/Y rows), so H-like tensors must be [rows, 4, 4] and X/Y-like
    # tensors [rows, 4, C]. Loaded values are upcast to fp32 before any math.

    # Forward: Y[r, n, c] = sum_k H[r, n, k] * X[r, k, c].
    # Grid: axis 0 -> groups of GROUP consecutive rows, axis 1 -> BLOCK_C-wide
    # column tiles. Columns are masked against C but rows are NOT masked, so
    # GROUP must divide the row count. Masked-off X lanes are loaded without
    # `other=`; that is harmless here because the matching Y lanes are masked
    # on store and never mixed across columns.
    @triton.jit
    def hc_post_bmm2_fwd_kernel(
        H_ptr, X_ptr, Y_ptr,
        stride_h_bs: tl.constexpr, stride_h_n: tl.constexpr, stride_h_k: tl.constexpr,
        stride_x_bs: tl.constexpr, stride_x_k: tl.constexpr, stride_x_c: tl.constexpr,
        stride_y_bs: tl.constexpr, stride_y_n: tl.constexpr, stride_y_c: tl.constexpr,
        GROUP: tl.constexpr,
        BLOCK_C: tl.constexpr,
        C: tl.constexpr
    ):
        pid_bs_blk = tl.program_id(0)
        pid_c_blk = tl.program_id(1)
        pid0 = pid_bs_blk * GROUP
        pids = pid0 + tl.arange(0, GROUP)
        c = pid_c_blk * BLOCK_C + tl.arange(0, BLOCK_C)
        c_mask = c < C
        # x_k: row k of X for every row of the group, shape [GROUP, BLOCK_C].
        X_base = X_ptr + pids[:, None] * stride_x_bs + c[None, :] * stride_x_c
        x0 = tl.load(X_base + 0 * stride_x_k, mask=c_mask[None, :]).to(tl.float32)
        x1 = tl.load(X_base + 1 * stride_x_k, mask=c_mask[None, :]).to(tl.float32)
        x2 = tl.load(X_base + 2 * stride_x_k, mask=c_mask[None, :]).to(tl.float32)
        x3 = tl.load(X_base + 3 * stride_x_k, mask=c_mask[None, :]).to(tl.float32)
        # h_n: row n of H for every row of the group, shape [GROUP, 4].
        k = tl.arange(0, 4)
        h0 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 0 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        h1 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 1 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        h2 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 2 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        h3 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 3 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        # hnk = H[:, n, k] as a [GROUP, 1] column that broadcasts against the
        # [GROUP, BLOCK_C] x tiles (the [0, k] / [GROUP, 1] / [1, 1] arguments
        # read as offsets / sizes / strides).
        h00 = tle.extract_slice(h0, [0, 0], [GROUP, 1], [1, 1])
        h01 = tle.extract_slice(h0, [0, 1], [GROUP, 1], [1, 1])
        h02 = tle.extract_slice(h0, [0, 2], [GROUP, 1], [1, 1])
        h03 = tle.extract_slice(h0, [0, 3], [GROUP, 1], [1, 1])
        h10 = tle.extract_slice(h1, [0, 0], [GROUP, 1], [1, 1])
        h11 = tle.extract_slice(h1, [0, 1], [GROUP, 1], [1, 1])
        h12 = tle.extract_slice(h1, [0, 2], [GROUP, 1], [1, 1])
        h13 = tle.extract_slice(h1, [0, 3], [GROUP, 1], [1, 1])
        h20 = tle.extract_slice(h2, [0, 0], [GROUP, 1], [1, 1])
        h21 = tle.extract_slice(h2, [0, 1], [GROUP, 1], [1, 1])
        h22 = tle.extract_slice(h2, [0, 2], [GROUP, 1], [1, 1])
        h23 = tle.extract_slice(h2, [0, 3], [GROUP, 1], [1, 1])
        h30 = tle.extract_slice(h3, [0, 0], [GROUP, 1], [1, 1])
        h31 = tle.extract_slice(h3, [0, 1], [GROUP, 1], [1, 1])
        h32 = tle.extract_slice(h3, [0, 2], [GROUP, 1], [1, 1])
        h33 = tle.extract_slice(h3, [0, 3], [GROUP, 1], [1, 1])
        # y_n = sum_k H[n, k] * x_k, evaluated as a nested fma chain.
        y0 = tl.fma(x0, h00, tl.fma(x1, h01, tl.fma(x2, h02, x3 * h03)))
        y1 = tl.fma(x0, h10, tl.fma(x1, h11, tl.fma(x2, h12, x3 * h13)))
        y2 = tl.fma(x0, h20, tl.fma(x1, h21, tl.fma(x2, h22, x3 * h23)))
        y3 = tl.fma(x0, h30, tl.fma(x1, h31, tl.fma(x2, h32, x3 * h33)))
        Y_base = Y_ptr + pids[:, None] * stride_y_bs + c[None, :] * stride_y_c
        tl.store(Y_base + 0 * stride_y_n, y0, mask=c_mask[None, :])
        tl.store(Y_base + 1 * stride_y_n, y1, mask=c_mask[None, :])
        tl.store(Y_base + 2 * stride_y_n, y2, mask=c_mask[None, :])
        tl.store(Y_base + 3 * stride_y_n, y3, mask=c_mask[None, :])

    # Backward w.r.t. X: dX[r, k, c] = sum_n H[r, n, k] * dY[r, n, c]
    # (i.e. dX = H^T @ dY). Same grid, masking and GROUP constraint as the
    # forward kernel; masked-off dY lanes are never stored.
    @triton.jit
    def hc_post_bmm2_bwd_dx_kernel(
        H_ptr, dY_ptr, dX_ptr,
        stride_h_bs: tl.constexpr, stride_h_n: tl.constexpr, stride_h_k: tl.constexpr,
        stride_dy_bs: tl.constexpr, stride_dy_n: tl.constexpr, stride_dy_c: tl.constexpr,
        stride_dx_bs: tl.constexpr, stride_dx_k: tl.constexpr, stride_dx_c: tl.constexpr,
        GROUP: tl.constexpr,
        BLOCK_C: tl.constexpr,
        C: tl.constexpr
    ):
        pid_bs_blk = tl.program_id(0)
        pid_c_blk = tl.program_id(1)
        pid0 = pid_bs_blk * GROUP
        pids = pid0 + tl.arange(0, GROUP)
        c = pid_c_blk * BLOCK_C + tl.arange(0, BLOCK_C)
        c_mask = c < C
        # dy_n: row n of dY for every row of the group, shape [GROUP, BLOCK_C].
        dY_base = dY_ptr + pids[:, None] * stride_dy_bs + c[None, :] * stride_dy_c
        dy0 = tl.load(dY_base + 0 * stride_dy_n, mask=c_mask[None, :]).to(tl.float32)
        dy1 = tl.load(dY_base + 1 * stride_dy_n, mask=c_mask[None, :]).to(tl.float32)
        dy2 = tl.load(dY_base + 2 * stride_dy_n, mask=c_mask[None, :]).to(tl.float32)
        dy3 = tl.load(dY_base + 3 * stride_dy_n, mask=c_mask[None, :]).to(tl.float32)
        # h_n: row n of H for every row of the group, shape [GROUP, 4].
        k = tl.arange(0, 4)
        h0 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 0 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        h1 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 1 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        h2 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 2 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        h3 = tl.load(H_ptr + pids[:, None] * stride_h_bs + 3 * stride_h_n + k[None, :] * stride_h_k).to(tl.float32)
        # hnk = H[:, n, k] as a broadcastable [GROUP, 1] column.
        h00 = tle.extract_slice(h0, [0, 0], [GROUP, 1], [1, 1])
        h01 = tle.extract_slice(h0, [0, 1], [GROUP, 1], [1, 1])
        h02 = tle.extract_slice(h0, [0, 2], [GROUP, 1], [1, 1])
        h03 = tle.extract_slice(h0, [0, 3], [GROUP, 1], [1, 1])
        h10 = tle.extract_slice(h1, [0, 0], [GROUP, 1], [1, 1])
        h11 = tle.extract_slice(h1, [0, 1], [GROUP, 1], [1, 1])
        h12 = tle.extract_slice(h1, [0, 2], [GROUP, 1], [1, 1])
        h13 = tle.extract_slice(h1, [0, 3], [GROUP, 1], [1, 1])
        h20 = tle.extract_slice(h2, [0, 0], [GROUP, 1], [1, 1])
        h21 = tle.extract_slice(h2, [0, 1], [GROUP, 1], [1, 1])
        h22 = tle.extract_slice(h2, [0, 2], [GROUP, 1], [1, 1])
        h23 = tle.extract_slice(h2, [0, 3], [GROUP, 1], [1, 1])
        h30 = tle.extract_slice(h3, [0, 0], [GROUP, 1], [1, 1])
        h31 = tle.extract_slice(h3, [0, 1], [GROUP, 1], [1, 1])
        h32 = tle.extract_slice(h3, [0, 2], [GROUP, 1], [1, 1])
        h33 = tle.extract_slice(h3, [0, 3], [GROUP, 1], [1, 1])
        # dx_k = sum_n H[n, k] * dy_n: note the transposed H indexing.
        dx0 = tl.fma(dy0, h00, tl.fma(dy1, h10, tl.fma(dy2, h20, dy3 * h30)))
        dx1 = tl.fma(dy0, h01, tl.fma(dy1, h11, tl.fma(dy2, h21, dy3 * h31)))
        dx2 = tl.fma(dy0, h02, tl.fma(dy1, h12, tl.fma(dy2, h22, dy3 * h32)))
        dx3 = tl.fma(dy0, h03, tl.fma(dy1, h13, tl.fma(dy2, h23, dy3 * h33)))
        dX_base = dX_ptr + pids[:, None] * stride_dx_bs + c[None, :] * stride_dx_c
        tl.store(dX_base + 0 * stride_dx_k, dx0, mask=c_mask[None, :])
        tl.store(dX_base + 1 * stride_dx_k, dx1, mask=c_mask[None, :])
        tl.store(dX_base + 2 * stride_dx_k, dx2, mask=c_mask[None, :])
        tl.store(dX_base + 3 * stride_dx_k, dx3, mask=c_mask[None, :])

    # Backward w.r.t. H: dH[r, n, k] = sum_c dY[r, n, c] * X[r, k, c]
    # (i.e. dH = dY @ X^T). Grid: axis 0 -> one row r per program, axis 1 ->
    # BLOCK_C_R-wide column tiles. Each program reduces its own column tile
    # and atomically adds the 16 partial sums into dH, so several tiles of the
    # same row combine there; dH must therefore be zero-filled before launch.
    @triton.jit
    def hc_post_bmm2_bwd_dh_kernel(
        X_ptr, dY_ptr, dH_ptr,
        stride_x_bs: tl.constexpr, stride_x_k: tl.constexpr, stride_x_c: tl.constexpr,
        stride_dy_bs: tl.constexpr, stride_dy_n: tl.constexpr, stride_dy_c: tl.constexpr,
        stride_dh_bs: tl.constexpr, stride_dh_n: tl.constexpr, stride_dh_k: tl.constexpr,
        C: tl.constexpr,
        BLOCK_C_R: tl.constexpr,
    ):
        pid_bs = tl.program_id(0)
        pid_c_blk = tl.program_id(1)
        c = pid_c_blk * BLOCK_C_R + tl.arange(0, BLOCK_C_R)
        c_mask = c < C
        # `other=0.0` is essential here: unlike the forward/dX kernels, the
        # out-of-range lanes take part in the tl.sum reductions below, so they
        # must contribute exactly zero. Without it, undefined lane contents (or
        # 0 * NaN) corrupt dH whenever C is not a multiple of BLOCK_C_R.
        X_base = X_ptr + pid_bs * stride_x_bs + c * stride_x_c
        x0 = tl.load(X_base + 0 * stride_x_k, mask=c_mask, other=0.0).to(tl.float32)
        x1 = tl.load(X_base + 1 * stride_x_k, mask=c_mask, other=0.0).to(tl.float32)
        x2 = tl.load(X_base + 2 * stride_x_k, mask=c_mask, other=0.0).to(tl.float32)
        x3 = tl.load(X_base + 3 * stride_x_k, mask=c_mask, other=0.0).to(tl.float32)
        dY_base = dY_ptr + pid_bs * stride_dy_bs + c * stride_dy_c
        dy0 = tl.load(dY_base + 0 * stride_dy_n, mask=c_mask, other=0.0).to(tl.float32)
        dy1 = tl.load(dY_base + 1 * stride_dy_n, mask=c_mask, other=0.0).to(tl.float32)
        dy2 = tl.load(dY_base + 2 * stride_dy_n, mask=c_mask, other=0.0).to(tl.float32)
        dy3 = tl.load(dY_base + 3 * stride_dy_n, mask=c_mask, other=0.0).to(tl.float32)
        # Partial dH[n, k] over this column tile (fp32 scalars).
        dh00 = tl.sum(dy0 * x0, axis=0)
        dh01 = tl.sum(dy0 * x1, axis=0)
        dh02 = tl.sum(dy0 * x2, axis=0)
        dh03 = tl.sum(dy0 * x3, axis=0)
        dh10 = tl.sum(dy1 * x0, axis=0)
        dh11 = tl.sum(dy1 * x1, axis=0)
        dh12 = tl.sum(dy1 * x2, axis=0)
        dh13 = tl.sum(dy1 * x3, axis=0)
        dh20 = tl.sum(dy2 * x0, axis=0)
        dh21 = tl.sum(dy2 * x1, axis=0)
        dh22 = tl.sum(dy2 * x2, axis=0)
        dh23 = tl.sum(dy2 * x3, axis=0)
        dh30 = tl.sum(dy3 * x0, axis=0)
        dh31 = tl.sum(dy3 * x1, axis=0)
        dh32 = tl.sum(dy3 * x2, axis=0)
        dh33 = tl.sum(dy3 * x3, axis=0)
        # Every column tile of this row targets the same 16 dH entries, hence
        # atomics rather than plain stores.
        dH_bs = dH_ptr + pid_bs * stride_dh_bs
        tl.atomic_add(dH_bs + 0 * stride_dh_n + 0 * stride_dh_k, dh00)
        tl.atomic_add(dH_bs + 0 * stride_dh_n + 1 * stride_dh_k, dh01)
        tl.atomic_add(dH_bs + 0 * stride_dh_n + 2 * stride_dh_k, dh02)
        tl.atomic_add(dH_bs + 0 * stride_dh_n + 3 * stride_dh_k, dh03)
        tl.atomic_add(dH_bs + 1 * stride_dh_n + 0 * stride_dh_k, dh10)
        tl.atomic_add(dH_bs + 1 * stride_dh_n + 1 * stride_dh_k, dh11)
        tl.atomic_add(dH_bs + 1 * stride_dh_n + 2 * stride_dh_k, dh12)
        tl.atomic_add(dH_bs + 1 * stride_dh_n + 3 * stride_dh_k, dh13)
        tl.atomic_add(dH_bs + 2 * stride_dh_n + 0 * stride_dh_k, dh20)
        tl.atomic_add(dH_bs + 2 * stride_dh_n + 1 * stride_dh_k, dh21)
        tl.atomic_add(dH_bs + 2 * stride_dh_n + 2 * stride_dh_k, dh22)
        tl.atomic_add(dH_bs + 2 * stride_dh_n + 3 * stride_dh_k, dh23)
        tl.atomic_add(dH_bs + 3 * stride_dh_n + 0 * stride_dh_k, dh30)
        tl.atomic_add(dH_bs + 3 * stride_dh_n + 1 * stride_dh_k, dh31)
        tl.atomic_add(dH_bs + 3 * stride_dh_n + 2 * stride_dh_k, dh32)
        tl.atomic_add(dH_bs + 3 * stride_dh_n + 3 * stride_dh_k, dh33)
def hc_post_bmm2_forward(
    H_res: torch.Tensor,
    x: torch.Tensor,
) -> torch.Tensor:
    """Apply a per-token 4x4 mixing matrix: ``out[b, s] = H_res[b, s] @ x[b, s]``.

    Args:
        H_res: ``[B, S, 4, 4]`` mixing matrices (fp32 expected).
        x: ``[B, S, 4, C]`` inputs (bf16 expected); values are upcast to fp32
            inside the kernel.

    Returns:
        ``[B, S, 4, C]`` fp32 tensor on ``x.device``.

    Raises:
        ValueError: if ``H_res`` is not ``[B, S, 4, 4]`` or ``x`` is not
            ``[..., 4, C]``; the kernel hard-codes the size 4 and would
            otherwise read out of bounds.
    """
    B, S, N, N2 = H_res.shape
    _, _, Nx, C = x.shape
    if (N, N2, Nx) != (4, 4, 4):
        raise ValueError(
            "hc_post_bmm2_forward expects H_res [B,S,4,4] and x [B,S,4,C]; "
            f"got {tuple(H_res.shape)} and {tuple(x.shape)}"
        )
    BS = B * S
    H = H_res.contiguous().view(BS, N, N)
    X = x.contiguous().view(BS, N, C)
    Y = torch.empty((BS, N, C), device=x.device, dtype=torch.float32)
    if Y.numel() == 0:
        # Nothing to compute. This also avoids a zero-width tile
        # (triton.cdiv(0, 0)) and an empty kernel launch.
        return Y.view(B, S, N, C)
    # One row per program (the kernel does not mask rows, so GROUP must
    # divide BS).
    GROUP = 1
    # tl.arange requires a power-of-two extent; the kernel's column mask
    # handles the padded tail when C is not a power of two.
    BLOCK_C = triton.next_power_of_2(min(C, 4096))
    grid = (triton.cdiv(BS, GROUP), triton.cdiv(C, BLOCK_C))
    hc_post_bmm2_fwd_kernel[grid](
        H, X, Y,
        stride_h_bs=H.stride(0), stride_h_n=H.stride(1), stride_h_k=H.stride(2),
        stride_x_bs=X.stride(0), stride_x_k=X.stride(1), stride_x_c=X.stride(2),
        stride_y_bs=Y.stride(0), stride_y_n=Y.stride(1), stride_y_c=Y.stride(2),
        GROUP=GROUP, BLOCK_C=BLOCK_C, C=C,
    )
    return Y.view(B, S, N, C)
def hc_post_bmm2_backward(H_res: torch.Tensor, x: torch.Tensor, dY: torch.Tensor):
    """Gradients of ``Y = H_res @ x`` (batched over the leading B, S dims).

    Args:
        H_res: ``[B, S, 4, 4]`` mixing matrices used in the forward pass.
        x: ``[B, S, 4, C]`` forward inputs.
        dY: ``[B, S, 4, C]`` gradient w.r.t. the forward output.

    Returns:
        dH_res: ``[B,S,4,4]`` fp32, equal to ``dY @ x^T``
        dX    : ``[B,S,4,C]`` fp32, equal to ``H_res^T @ dY`` (or cast outside if you want bf16)

    Raises:
        ValueError: if ``H_res`` is not ``[B, S, 4, 4]`` or ``x`` is not
            ``[..., 4, C]``; the kernels hard-code the size 4.
    """
    B, S, N, N2 = H_res.shape
    _, _, Nx, C = x.shape
    if (N, N2, Nx) != (4, 4, 4):
        raise ValueError(
            "hc_post_bmm2_backward expects H_res [B,S,4,4] and x [B,S,4,C]; "
            f"got {tuple(H_res.shape)} and {tuple(x.shape)}"
        )
    BS = B * S
    H = H_res.contiguous().view(BS, N, N)
    X = x.contiguous().view(BS, N, C)
    dY_ = dY.contiguous().view(BS, N, C)
    dX_fp32 = torch.empty((BS, N, C), device=x.device, dtype=torch.float32)
    # Must start at zero: the dH kernel accumulates partial sums with
    # atomic adds, and an empty reduction over C must yield 0.
    dH = torch.zeros((BS, N, N), device=x.device, dtype=torch.float32)
    if BS == 0 or C == 0:
        # Nothing to launch (and triton.cdiv(0, 0) would divide by zero).
        return dH.view(B, S, N, N), dX_fp32.view(B, S, N, C)

    # Shared column tile for both kernels: a power of two as tl.arange
    # requires; the kernels mask the padded tail against C.
    BLOCK_C = triton.next_power_of_2(min(C, 4096))

    # dX = H^T @ dY, one row per program (the kernel does not mask rows).
    GROUP = 1
    grid_dx = (triton.cdiv(BS, GROUP), triton.cdiv(C, BLOCK_C))
    hc_post_bmm2_bwd_dx_kernel[grid_dx](
        H, dY_, dX_fp32,
        stride_h_bs=H.stride(0), stride_h_n=H.stride(1), stride_h_k=H.stride(2),
        stride_dy_bs=dY_.stride(0), stride_dy_n=dY_.stride(1), stride_dy_c=dY_.stride(2),
        stride_dx_bs=dX_fp32.stride(0), stride_dx_k=dX_fp32.stride(1), stride_dx_c=dX_fp32.stride(2),
        GROUP=GROUP, BLOCK_C=BLOCK_C, C=C
    )

    # dH = dY @ X^T, one program per (row, column tile).
    grid_dh = (BS, triton.cdiv(C, BLOCK_C))
    hc_post_bmm2_bwd_dh_kernel[grid_dh](
        X, dY_, dH,
        stride_x_bs=X.stride(0), stride_x_k=X.stride(1), stride_x_c=X.stride(2),
        stride_dy_bs=dY_.stride(0), stride_dy_n=dY_.stride(1), stride_dy_c=dY_.stride(2),
        stride_dh_bs=dH.stride(0), stride_dh_n=dH.stride(1), stride_dh_k=dH.stride(2),
        C=C,
        BLOCK_C_R=BLOCK_C,
    )
    return dH.view(B, S, N, N), dX_fp32.view(B, S, N, C)