"""
本脚本有 2 种执行模式:
1. CI批跑时, 由 cmake/scripts/golden_ctrl.py 调用, 为避免日志过多, 此时 logging 级别为 logging.INFO;
2. 单独调试时, 本脚本单独被调用, 此时 logging 级别为 logging.DEBUG;
"""
import sys
import logging
from pathlib import Path
from typing import List
import numpy as np
if __name__ == "__main__":
""" 单独调试时配置 """
logging.basicConfig(format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
level=logging.DEBUG)
g_src_root: Path = Path(Path(__file__).parent, "../../../../../").resolve()
logging.debug("SrcRoot: %s", g_src_root)
g_ctrl_path: Path = Path(g_src_root, "cmake/scripts")
if str(g_ctrl_path) not in sys.path:
sys.path.append(str(g_ctrl_path))
from golden_register import GoldenRegister
else:
from golden_register import GoldenRegister
np.random.seed(0)
b = 1
n = 1
s = 128
d_q = 128
d_v = 128
dtype_f16 = np.float16
dtype_f32 = np.float32
def softmax(x, axis=None):
x_max = x.max(axis=-1, keepdims=True)
x_sub = x - x_max
y = np.exp(x_sub)
x_sum = y.sum(axis=-1, keepdims=True)
ans = y / x_sum
return ans, x_max, x_sum
def softmax_grad(dp, softmax_res):
muls_r = dp * softmax_res
muls_r = muls_r.sum(axis=-1, keepdims=True)
sub_r = dp - muls_r
muls_r2 = softmax_res
res = sub_r * muls_r2
return res
def softmax_grad_flash_1(dx, y):
d = (dx * y).sum(axis=-1, keepdims=True)
return d
def softmax_grad_flash_2(dp, p, d):
sub = dp - d
ds = sub * p
return ds
def softmax_flash(x, max_front=None, sum_front=None, update=None):
"""
Compute the softmax function for each channel of the input x.
"""
if update is None:
x_max = np.max(x, axis=-1, keepdims=True)
x_sub = x - x_max
x_exp = np.exp(x_sub)
x_sum = np.sum(x_exp, axis=-1, keepdims=True)
out = x_exp / x_sum
exp_max = None
return out, x_max, x_sum, exp_max
else:
x_max_tmp = np.max(x, axis=-1, keepdims=True)
x_sub = x - x_max_tmp
x_exp = np.exp(x_sub)
x_sum = np.sum(x_exp, axis=-1, keepdims=True)
x_max = np.max(np.concatenate((max_front, x_max_tmp), axis=-1), axis=-1, keepdims=True)
x_exp_new = np.exp(x_max_tmp - x_max)
exp_max = np.exp(max_front - x_max)
exp_max = exp_max * sum_front
reduce_tmp = x_exp_new * x_sum
x_sum = exp_max + reduce_tmp
exp_max = exp_max / x_sum
out = x_exp * x_exp_new / x_sum
return out, x_max, x_sum, exp_max
def forward(q, k, v, drop_mask):
q, k, v = q.astype(dtype_f32), k.astype(dtype_f32), v.astype(dtype_f32)
qk = np.matmul(q, k.transpose(0, 1, 3, 2))
softmax_res, x_max, x_sum = softmax(qk)
drop_res = softmax_res * drop_mask
y = np.matmul(drop_res, v)
return y, softmax_res, x_max, x_sum
def backward(dx, q, k, v, softmax_res, drop_mask):
dx, q, k, v = dx.astype(dtype_f32), q.astype(dtype_f32), k.astype(dtype_f32), v.astype(dtype_f32)
drop_res = softmax_res * drop_mask
dv = np.matmul(drop_res.transpose(1, 0), dx)
dp = np.matmul(dx, v.transpose(1, 0))
dp_drop = dp * drop_mask
softmax_grad_res = softmax_grad(dp_drop, softmax_res)
dq = np.matmul(softmax_grad_res, k)
dk = np.matmul(softmax_grad_res.transpose(1, 0), q)
dq, dk, dv = dq.astype(dtype_f16), dk.astype(dtype_f16), dv.astype(dtype_f16)
return dq, dk, dv
def gen_golden_np(output: Path):
qkv_shape = [b, n, s, d_q]
v_shape = [b, n, s, d_v]
drop_shape = [b, n, s, s]
dx = np.random.uniform(-1, 1, qkv_shape).astype(dtype_f16)
q = np.random.uniform(0, 0.5, qkv_shape).astype(dtype_f16)
k = np.random.uniform(0, 0.5, qkv_shape).astype(dtype_f16)
v = np.random.uniform(0, 0.5, v_shape).astype(dtype_f16)
drop_mask = np.random.uniform(1, 2, drop_shape).astype(np.uint8)
dx.tofile(Path(output, 'dx.bin'))
q.tofile(Path(output, 'q.bin'))
k.tofile(Path(output, 'k.bin'))
v.tofile(Path(output, 'v.bin'))
drop_mask.tofile(Path(output, 'drop_mask.bin'))
y, softmax_res, x_max, x_sum = forward(q, k, v, drop_mask)
dq, dk, dv = q, k, v
return q, k, v, dx, drop_mask, y, dq, dk, dv, softmax_res, x_max, x_sum
@GoldenRegister.reg_golden_func(
case_names=[
"OnBoardTestAstApi.test_fa_all2all_ast_api_mode",
]
)
def fa_main(case_name: str, output: Path) -> bool:
q, k, v, dx, drop_mask, y, dq_golden, dk_golden, dv_golden, softmax_res_golden, x_max, x_sum \
= gen_golden_np(output)
q, k, v = q.astype(dtype_f32), k.astype(dtype_f32), v.astype(dtype_f32)
m = np.full([b, n, s, 1], 100000000000000.0, dtype=dtype_f32)
_l = np.full([b, n, s, 1], 100000000000000.0, dtype=dtype_f32)
o = np.full([b, n, s, d_v], 100000000000000.0, dtype=dtype_f32)
one_loop_size = 128
i_loop = s // one_loop_size
j_loop = s // one_loop_size
for b_idx in range(b):
for n_idx in range(n):
for j in range(j_loop):
data_start = one_loop_size * j
data_end = one_loop_size * (j + 1)
kj = k[b_idx, n_idx, data_start:data_end, :]
vj = v[b_idx, n_idx, data_start:data_end, :]
drop_mask_i = drop_mask[b_idx, n_idx, :, data_start:data_end]
for i in range(i_loop):
data_start_i = one_loop_size * i
data_end_i = one_loop_size * (i + 1)
qi = q[b_idx, n_idx, data_start_i:data_end_i, :]
oi = o[b_idx, n_idx, data_start_i:data_end_i, :]
mi = m[b_idx, n_idx, data_start_i:data_end_i, :]
li = _l[b_idx, n_idx, data_start_i:data_end_i, :]
qk_j = np.matmul(qi, kj.transpose(1, 0))
tilda_mij = np.max(qk_j, axis=-1, keepdims=True)
tsub = qk_j - tilda_mij
tilda_pij = np.exp(tsub)
tilda_lij = np.sum(tilda_pij, axis=-1, keepdims=True)
if j == 0:
q6 = 1 / tilda_lij
q1 = np.matmul(tilda_pij, vj)
o[b_idx, n_idx, data_start_i:data_end_i, :] = q6 * q1
_l[b_idx, n_idx, data_start_i:data_end_i, :] = tilda_lij
m[b_idx, n_idx, data_start_i:data_end_i, :] = tilda_mij
else:
mi_new = np.max(np.concatenate((mi, tilda_mij), axis=-1), axis=-1, keepdims=True)
t1 = mi - mi_new
t2 = np.exp(t1)
t3 = tilda_mij - mi_new
t4 = np.exp(t3)
t5 = t4 * tilda_lij
t6 = t2 * li
li_new = t6 + t5
q6 = 1 / li_new
q3 = t2 * oi
q4 = li * q3
q1 = np.matmul(tilda_pij, vj)
q2 = t4 * q1
q7 = q4 + q2
o[b_idx, n_idx, data_start_i:data_end_i, :] = q6 * q7
_l[b_idx, n_idx, data_start_i:data_end_i, :] = li_new
m[b_idx, n_idx, data_start_i:data_end_i, :] = mi_new
error = np.array(y - o)
o.tofile(Path(output, 'res_golden.bin'))
m.tofile(Path(output, 'max_golden.bin'))
_l.tofile(Path(output, 'sum_golden.bin'))
return True
def main() -> bool:
"""
单独调试 入口函数
"""
case_name_list: List[str] = [
"ExpandOnBoardTest.test_expand_32_1_to_32_32",
]
ret: bool = True
for cs in case_name_list:
output: Path = Path(g_src_root, "build/output/bin/golden", cs).resolve()
output.mkdir(parents=True, exist_ok=True)
ret = fa_main(case_name=cs, output=output)
return ret
if __name__ == "__main__":
exit(0 if main() else 1)