import torch
import math
from utils import *
def tsoftmax(x):
    """Numerically stable softmax over the last dimension.

    Returns a tuple ``(probs, row_max, row_sum)`` where ``row_max`` is the
    per-row maximum of ``x`` and ``row_sum`` is the per-row sum of
    ``exp(x - row_max)``; both keep the reduced dimension with size 1.
    """
    row_max = torch.max(x, dim=-1, keepdim=True)[0]
    # Shift by the row maximum so the largest exponent is exp(0) == 1.
    exps = torch.exp(x - row_max)
    row_sum = torch.sum(exps, dim=-1, keepdim=True)
    return exps / row_sum, row_max, row_sum
def _fix_invalid_rows(softmax_res, x_max, x_sum):
b, n, sq, _ = softmax_res.shape
for i in range(b):
for j in range(n):
for k in range(sq):
if x_max[i, j, k, :] == -40000.:
softmax_res[i, j, k, :] = 0
x_max[i, j, k, :] = torch.finfo(torch.float).min
x_sum[i, j, k, :] = torch.finfo(torch.float).max
return softmax_res, x_max, x_sum
def _attend(q, k, v, atten_mask, scale, need_fix_invalid):
    """Compute scaled dot-product attention in float32.

    Args:
        q, k, v: 4-D tensors; ``k`` is transposed on its last two dimensions
            before being multiplied with ``q``.
        atten_mask: optional mask; positions where it is truthy are filled
            with the sentinel score ``-40000`` before the softmax.
        scale: factor applied to the raw ``q @ k^T`` scores.
        need_fix_invalid: when true, rows whose maximum score equals the
            sentinel are post-processed by ``_fix_invalid_rows``.

    Returns:
        ``(out, x_max, x_sum)``: the attention output plus the per-row
        softmax maximum and exponent sum (trailing dimension of size 1).
    """
    q = q.float()
    k = k.float()
    v = v.float()
    qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale)
    if atten_mask is not None:
        qk = qk.masked_fill_(atten_mask.cpu().bool(), -40000.0)

    if qk.shape[-1] == 0:
        # No keys to attend to: max() over an empty dimension would fail, so
        # return empty probabilities with zeroed statistics instead.
        b, n, sq, _ = qk.shape
        softmax_res = torch.zeros((b, n, sq, 0), dtype=qk.dtype)
        x_max = torch.zeros(b, n, sq, 1)
        x_sum = torch.zeros(b, n, sq, 1)
    else:
        softmax_res, x_max, x_sum = tsoftmax(qk)
        # One pass is enough: the fix rewrites the sentinel in x_max, so a
        # second pass could never match any row.
        if need_fix_invalid:
            softmax_res, x_max, x_sum = _fix_invalid_rows(softmax_res, x_max, x_sum)

    out = torch.matmul(softmax_res, v)
    return out, x_max, x_sum
def get_cu_seqlens(seqlens_list):
    """Return cumulative sequence offsets as an int64 tensor.

    The result has ``len(seqlens_list) + 1`` entries: entry 0 is 0 and entry
    ``i`` is the sum of the first ``i`` lengths.
    """
    offsets = torch.zeros(len(seqlens_list) + 1, dtype=torch.int64)
    for idx, length in enumerate(seqlens_list):
        offsets[idx + 1] = offsets[idx] + length
    return offsets
def broadcastKV(n1, n2, kv_tensor, dtype):
    """Expand or subsample the head dimension of a K/V tensor to ``n1`` heads.

    Args:
        n1: number of heads in the result.
        n2: number of heads to read from ``kv_tensor``.
        kv_tensor: 4-D tensor with heads on dimension 1.
        dtype: dtype of the returned tensor.

    Returns:
        A new tensor of shape ``(kv.shape[0], n1, kv.shape[2], kv.shape[3])``.
        When ``n1 >= n2`` each source head is repeated ``n1 // n2`` times in
        consecutive output heads; otherwise output head ``i`` is source head
        ``i * (n2 // n1)``.

    Raises:
        RuntimeError: if ``n1 >= n2`` and ``n1`` is not a multiple of ``n2``
            (the head mapping would index past the last source head).
    """
    kv_shape = kv_tensor.shape
    b = kv_shape[0]
    s = kv_shape[2]
    d = kv_shape[3]
    if n1 >= n2:
        if n1 % n2 != 0:
            raise RuntimeError(
                f"broadcastKV: n1 ({n1}) must be a multiple of n2 ({n2})"
            )
        factor = n1 // n2
        head_index = [i // factor for i in range(n1)]
    else:
        factor = n2 // n1
        head_index = [i * factor for i in range(n1)]
    # Allocate directly in the target dtype and fill every head with a single
    # indexed copy; copy_ performs the dtype conversion.
    kv_res = torch.empty(b, n1, s, d, dtype=dtype)
    kv_res.copy_(kv_tensor[:, head_index])
    return kv_res
def _should_fix_invalid_rows(mask_mode, win_left, win_right, act_q_len, act_kv_len, prefix):
if mask_mode is None or mask_mode == 0:
return False
if mask_mode == 3:
return act_q_len > act_kv_len
if mask_mode == 4:
return win_left < 0 or win_right + act_kv_len < act_q_len
return False
def tforward_tnd(q, k, v, **kwargs):
    """Reference attention over sequences packed along the token dimension.

    ``q``, ``k`` and ``v`` are sliced as ``t[:, :, start:start + length]``,
    i.e. tokens of all batch entries are concatenated on dimension 2.
    NOTE(review): presumably the tensors are laid out as
    ``(1, heads, total_tokens, head_dim)`` -- confirm against callers.

    Required kwargs:
        cu_seqlens_q: cumulative query offsets, ``batch + 1`` entries.
        seqused_q / seqused_kv: per-batch number of query / key-value tokens
            actually used.

    Optional kwargs:
        N1, N2: query / key-value head counts (``N2`` defaults to ``N1``).
        D, DV: head dims of q and v. ``DV`` defaults to ``D``; if neither is
            given the last dimension of ``v`` is used.
        cu_seqlens_kv: cumulative key/value offsets; when absent the offsets
            are the running sum of ``seqused_kv``.
        mask_mode, win_left, win_right, prefix: forwarded to
            ``generate_cpu_mask``; no mask is built when ``mask_mode`` is None.
        scale: score scale; defaults to ``1 / sqrt(D)`` (or the last
            dimension of ``q`` when ``D`` is not given).

    Returns:
        ``(out, x_max, x_sum)`` with ``out`` of shape
        ``[1, N1, total_q, DV]`` in ``q.dtype`` and the float32 softmax
        statistics of shape ``[N1, total_q]``. Rows belonging to batch
        entries with zero query or key/value length are left at zero.
    """
    cu_q = kwargs["cu_seqlens_q"]
    seqused_q = kwargs["seqused_q"]
    seqused_kv = kwargs["seqused_kv"]
    b = len(cu_q) - 1
    n1 = kwargs.get("N1")
    n2 = kwargs.get("N2", n1)
    d = kwargs.get("D")
    d_v = kwargs.get("DV", d)
    if d_v is None:
        d_v = v.shape[-1]
    s1_total = cu_q[-1]
    cu_kv = kwargs.get("cu_seqlens_kv", None)
    mask_mode = kwargs.get("mask_mode", None)
    win_left = kwargs.get("win_left", 2147483647)
    win_right = kwargs.get("win_right", 2147483647)
    prefix = kwargs.get("prefix", [])
    # Compute the default lazily so an explicit scale works without "D".
    scale = kwargs.get("scale")
    if scale is None:
        head_dim = d if d is not None else q.shape[-1]
        scale = 1 / (head_dim ** 0.5)
    band_index = 0

    out_golden = torch.zeros([1, n1, s1_total, d_v], dtype=q.dtype)
    x_max = torch.zeros([n1, s1_total], dtype=torch.float32)
    x_sum = torch.zeros([n1, s1_total], dtype=torch.float32)

    # Running key/value offset, used only when cu_seqlens_kv is not supplied.
    kv_running = 0
    for i in range(b):
        act_q_len = seqused_q[i]
        act_kv_len = seqused_kv[i]
        kv_start = cu_kv[i] if cu_kv is not None else kv_running
        # Advance before the skip below so empty entries still count.
        kv_running = kv_running + act_kv_len
        if act_q_len == 0 or act_kv_len == 0:
            continue
        q_start = cu_q[i]
        qi = q[:, :, q_start:q_start + act_q_len]
        ki = k[:, :, kv_start:kv_start + act_kv_len]
        vi = v[:, :, kv_start:kv_start + act_kv_len]
        if n1 != n2:
            ki = broadcastKV(n1, n2, ki, ki.dtype)
            vi = broadcastKV(n1, n2, vi, vi.dtype)
        if mask_mode is None:
            atten_maski = None
        else:
            atten_maski = generate_cpu_mask(1, act_q_len, act_kv_len, mask_mode, win_left, win_right, prefix, i, band_index)
        need_fix = _should_fix_invalid_rows(mask_mode, win_left, win_right, act_q_len, act_kv_len, prefix)
        outi, x_maxi, x_sumi = _attend(qi, ki, vi, atten_maski, scale, need_fix)
        out_golden[:, :, q_start:q_start + act_q_len] = outi
        # Drop the leading batch dim and the trailing keepdim axis.
        x_max[:, q_start:q_start + act_q_len] = x_maxi.squeeze(0).squeeze(-1)
        x_sum[:, q_start:q_start + act_q_len] = x_sumi.squeeze(0).squeeze(-1)
    return out_golden, x_max, x_sum
def tforward(q, k, v, **kwargs):
    """Run the reference attention for either packed or per-batch inputs.

    When ``input_layout`` is ``"TND"`` the call is forwarded unchanged to
    ``tforward_tnd``. For any other layout the inputs are reshaped using the
    ``B``, ``N1``, ``N2`` (default ``N1``), ``S1`` and ``S2`` (default ``S1``)
    kwargs into the packed form, uniform cumulative offsets are generated,
    ``seqused_q`` / ``seqused_kv`` default to the full lengths, and the
    results are reshaped back so the batch dimension comes first.
    """
    if kwargs.get("input_layout") == "TND":
        return tforward_tnd(q, k, v, **kwargs)

    batch = kwargs.get("B")
    q_heads = kwargs.get("N1")
    kv_heads = kwargs.get("N2", q_heads)
    q_len = kwargs.get("S1")
    kv_len = kwargs.get("S2", q_len)

    def _pack(tensor, heads, seq_len):
        # (B, heads, S, W) viewed as (1, heads, B * S, W): batches are
        # concatenated along the token axis.
        width = tensor.shape[-1]
        split = tensor.reshape(1, batch, heads, seq_len, width)
        return split.transpose(1, 2).reshape(1, heads, batch * seq_len, width)

    q = _pack(q, q_heads, q_len)
    k = _pack(k, kv_heads, kv_len)
    v = _pack(v, kv_heads, kv_len)

    kwargs["cu_seqlens_q"] = [idx * q_len for idx in range(batch + 1)]
    kwargs["cu_seqlens_kv"] = [idx * kv_len for idx in range(batch + 1)]
    kwargs.setdefault("seqused_q", [q_len] * batch)
    kwargs.setdefault("seqused_kv", [kv_len] * batch)

    out, x_max, x_sum = tforward_tnd(q, k, v, **kwargs)

    # Undo the packing: split the token axis back into (batch, seq).
    out = out.reshape(q_heads, batch, q_len, out.shape[-1]).transpose(0, 1)
    x_max = x_max.reshape(q_heads, batch, q_len).transpose(1, 0)
    x_sum = x_sum.reshape(q_heads, batch, q_len).transpose(1, 0)
    return out, x_max, x_sum