import os
import sys
import pypto
import torch
import pytest
from hc_pre_impl import hc_pre_kernel, hc_pre_kernel_prefill, npu_hc_pre, check_input_output_shape_dtype
from utils.compare import compare
# Configuration shared by the golden-reference helpers and the tests below.
hc = 4  # width of the pre/post slices and side length of the comb matrix
d = 4096  # size of the last axis of the (t, hc, d) activation tensor
sinkhorn_iters = 20  # total number of row/column normalisation rounds
norm_eps = 1e-6  # added to the mean square before taking the square root
hc_eps = 1e-6  # added to the sigmoid output and to the normalisation sums
# Width of the mixing projection: hc "pre" + hc "post" + hc * hc "comb" entries.
mix_hc = hc * (hc + 2)
def gen_rms_norm_denom(x):
    """Return the RMS-norm denominator for every row of ``x``.

    Computes ``sqrt(sum(x ** 2, dim=-1) / width + norm_eps)`` where ``width``
    is the size of the last axis.

    Args:
        x: 2-D tensor; unpacking ``x.shape`` into two names fails for any
            other rank.

    Returns:
        Tensor holding one denominator per row, with the reduced axis kept
        at size 1.
    """
    # Named ``width`` (not ``len``) so the builtin is not shadowed.
    _, width = x.shape
    print("rms norm x.shape", x.shape)
    mean_square = x.square().sum(-1, True) / width
    return (mean_square + norm_eps).sqrt()
def gen_sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    ``x`` must provide an ``exp()`` method and support unary minus (the
    callers in this file pass torch tensors).
    """
    neg_exp = (-x).exp()
    return 1 / (1 + neg_exp)
def gen_hc_split_sinkhorn(x, hc_scale, hc_base):
    """Split mixing logits into ``pre`` / ``post`` / ``comb`` (row layout).

    The columns of ``x`` are read as ``hc`` "pre" entries, ``hc`` "post"
    entries and ``hc * hc`` "comb" entries. Each part is scaled by its own
    element of ``hc_scale`` and shifted by the matching columns of
    ``hc_base``.

    Args:
        x: 2-D tensor, one row of logits per sample.
        hc_scale: indexable holding three multipliers (pre, post, comb).
        hc_base: 2-D tensor of offsets, sliced along its last axis.

    Returns:
        Tuple ``(pre, post, comb)``:
        ``pre`` is ``sigmoid(...) + hc_eps``; ``post`` is ``2 * sigmoid(...)``;
        ``comb`` has shape ``(rows, hc, hc)`` and is a softmax over the last
        axis plus ``hc_eps``, followed by alternating normalisations over
        axis -2 and axis -1 for ``sinkhorn_iters`` rounds in total.
    """
    num_rows, _ = x.shape
    pre_cols = slice(None, hc)
    post_cols = slice(hc, 2 * hc)
    comb_cols = slice(2 * hc, None)

    pre_logits = x[:, pre_cols] * hc_scale[0] + hc_base[:, pre_cols]
    pre = gen_sigmoid(pre_logits) + hc_eps

    post_logits = x[:, post_cols] * hc_scale[1] + hc_base[:, post_cols]
    post = 2.0 * gen_sigmoid(post_logits)

    comb_logits = x[:, comb_cols] * hc_scale[2] + hc_base[:, comb_cols]
    comb = comb_logits.reshape(num_rows, hc, hc)

    # First round: max-subtracted softmax over the last axis. Note that hc_eps
    # is added to the quotient here, not to the denominator.
    comb = (comb - comb.amax(-1, keepdim=True)).exp()
    comb = comb / comb.sum(-1, keepdim=True) + hc_eps
    comb = comb / (comb.sum(-2, keepdim=True) + hc_eps)

    # Remaining rounds: normalise over axis -1, then over axis -2.
    for _ in range(1, sinkhorn_iters):
        comb = comb / (comb.sum(-1, keepdim=True) + hc_eps)
        comb = comb / (comb.sum(-2, keepdim=True) + hc_eps)

    return pre, post, comb
def gen_hc_split_sinkhorn_trans(x, hc_scale, hc_base):
    """Split mixing logits into ``pre`` / ``post`` / ``comb`` (column layout).

    Same computation as ``gen_hc_split_sinkhorn`` but ``x`` carries the
    ``mix_hc`` logits along axis 0, so every slice and reduction acts one
    axis earlier. The results are moved back to sample-major order before
    returning.

    Args:
        x: 2-D tensor whose rows are the ``mix_hc`` logits.
        hc_scale: indexable holding three multipliers (pre, post, comb).
        hc_base: offsets; reshaped here to ``(mix_hc, 1)``.

    Returns:
        Tuple ``(pre, post, comb, pre_logits)`` where ``pre`` and ``post`` are
        transposed views, ``comb`` is a view with the sample axis first, and
        ``pre_logits`` is the scaled-and-shifted "pre" slice before the
        sigmoid, still in the input layout.
    """
    _, num_cols = x.shape
    base_col = hc_base.reshape(mix_hc, 1)
    print("hc_split_sinkhorn_trans x ", x.shape)

    pre_rows = slice(None, hc)
    post_rows = slice(hc, 2 * hc)
    comb_rows = slice(2 * hc, None)

    pre_logits = x[pre_rows, :] * 1.0 * hc_scale[0] + base_col[pre_rows, :]
    print("pre ", pre_logits.shape)
    pre = gen_sigmoid(pre_logits) + hc_eps

    post_logits = x[post_rows, :] * hc_scale[1] + base_col[post_rows, :]
    post = 2.0 * gen_sigmoid(post_logits)

    comb_logits = x[comb_rows, :] * hc_scale[2] + base_col[comb_rows, :]
    comb = comb_logits.reshape(hc, hc, num_cols)

    # First round: max-subtracted softmax over axis -2. hc_eps is added to
    # the quotient here, not to the denominator.
    comb = (comb - comb.amax(-2, keepdim=True)).exp()
    comb = comb / comb.sum(-2, keepdim=True) + hc_eps
    comb = comb / (comb.sum(-3, keepdim=True) + hc_eps)

    # Remaining rounds: normalise over axis -2, then over axis -3.
    for _ in range(1, sinkhorn_iters):
        comb = comb / (comb.sum(-2, keepdim=True) + hc_eps)
        comb = comb / (comb.sum(-3, keepdim=True) + hc_eps)

    # permute(2, 0, 1) yields the same view as transpose(1, 2).transpose(0, 1).
    return (
        pre.transpose(0, 1),
        post.transpose(0, 1),
        comb.permute(2, 0, 1),
        pre_logits,
    )
def gen_hc_pre(x, hc_fn, hc_scale, hc_base):
    """Compute the golden ``hc_pre`` outputs on the host.

    Steps: flatten ``x`` to ``(t, hc * d)`` and cast to float32, project with
    ``hc_fn`` transposed, divide by the per-row RMS denominator, split the
    result with ``gen_hc_split_sinkhorn``, then reduce the ``hc`` axis of
    ``x`` weighted by ``pre``.

    Args:
        x: tensor reshapeable to ``(t, hc * d)``, ``t`` being ``x.shape[0]``.
        hc_fn: projection weights; used transposed in the matmul.
        hc_scale: forwarded to the split helper.
        hc_base: offsets; reshaped to ``(1, mix_hc)``.

    Returns:
        Tuple ``(y, post, comb, mixes)``: ``y`` is the bfloat16 weighted sum
        over the ``hc`` axis, ``post`` and ``comb`` come from the split
        helper, and ``mixes`` is the normalised projection fed to it.
    """
    num_tokens = x.shape[0]
    flat = x.reshape((num_tokens, hc * d))
    base_row = hc_base.reshape(1, mix_hc)
    x_f32 = flat.to(torch.float32)
    weight = hc_fn.to(torch.float32)

    projected = torch.matmul(x_f32, weight.transpose(0, 1))
    mixes = projected / gen_rms_norm_denom(x_f32)

    pre, post, comb = gen_hc_split_sinkhorn(mixes, hc_scale, base_row)

    weighted = pre.reshape(num_tokens, hc, 1) * x_f32.reshape(num_tokens, hc, d)
    y = weighted.sum(-2).to(torch.bfloat16)
    return y, post, comb, mixes
def gen_hc_pre_trans(x, hc_fn, hc_scale, hc_base):
    """Compute the golden ``hc_pre`` outputs using the transposed projection.

    The projection is evaluated as ``hc_fn @ x.T`` so the logits arrive with
    shape ``(rows of hc_fn, t)``; they are divided by the RMS denominator
    reshaped to ``(1, t)`` and split with ``gen_hc_split_sinkhorn_trans``.

    Args:
        x: tensor reshapeable to ``(t, hc * d)``, ``t`` being ``x.shape[0]``.
        hc_fn: projection weights; used as the left matmul operand.
        hc_scale: forwarded to the split helper.
        hc_base: offsets; reshaped to ``(1, mix_hc)`` before being forwarded.

    Returns:
        Tuple ``(y, post, comb, pre_logits)``: ``y`` is the bfloat16 weighted
        sum over the ``hc`` axis; the fourth item is the pre-sigmoid "pre"
        slice returned by the split helper (unlike ``gen_hc_pre``, which
        returns the whole normalised projection in that position).
    """
    num_tokens = x.shape[0]
    flat = x.reshape((num_tokens, hc * d))
    base_row = hc_base.reshape(1, mix_hc)
    x_f32 = flat.to(torch.float32)
    weight = hc_fn.to(torch.float32)

    projected = torch.matmul(weight, x_f32.transpose(0, 1))
    denom = gen_rms_norm_denom(x_f32)
    mixes = projected / (denom.reshape(1, num_tokens))

    pre, post, comb, pre_logits = gen_hc_split_sinkhorn_trans(
        mixes, hc_scale, base_row
    )

    weighted = pre.reshape(num_tokens, hc, 1) * x_f32.reshape(num_tokens, hc, d)
    y = weighted.sum(-2).to(torch.bfloat16)
    return y, post, comb, pre_logits
def gen_hc_pre_data(t=16, is_trans=False):
    """Create seeded random inputs and their golden ``hc_pre`` outputs.

    Args:
        t: size of the leading axis of ``x``.
        is_trans: when true the golden values come from ``gen_hc_pre_trans``,
            otherwise from ``gen_hc_pre``.

    Returns:
        Tuple ``(x, hc_fn, hc_scale, hc_base, res, post, comb, mm_res)``: the
        four inputs followed by the four values returned by the selected
        golden function.
    """
    torch.manual_seed(42)
    print("t is ", t)

    def _uniform(shape, dtype):
        # Samples are drawn in place from uniform_(-1, 1).
        return torch.empty(shape, dtype=dtype).uniform_(-1, 1)

    # Keep this draw order: it determines the values produced by the seed.
    x = _uniform((t, hc, d), torch.bfloat16)
    hc_fn = _uniform((mix_hc, hc * d), torch.float32)
    hc_scale = _uniform((3,), torch.float32)
    hc_base = _uniform((mix_hc,), torch.float32)

    golden_fn = gen_hc_pre_trans if is_trans else gen_hc_pre
    res, post, comb, mm_res = golden_fn(x, hc_fn, hc_scale, hc_base)
    return x, hc_fn, hc_scale, hc_base, res, post, comb, mm_res
class HC_PRE(torch.nn.Module):
    """Module whose forward pass calls ``torch.ops.pypto.hc_pre``."""

    def forward(self, x, hc_fn, hc_scale, hc_base):
        """Forward the four tensors plus three fixed scalars to the op."""
        # NOTE(review): the literals 4, 20 and 1e-6 presumably mirror the
        # module-level hc, sinkhorn_iters and epsilon settings -- confirm
        # against the op signature before replacing them with the constants.
        return torch.ops.pypto.hc_pre(
            x,
            hc_fn,
            hc_scale,
            hc_base,
            4,
            20,
            1e-6,
        )
@pytest.mark.skip(reason="large test case")
def test_hc_pre_inmodel(t=16):
    """Run ``HC_PRE`` through ``torch.compile`` and compare with the golden data.

    The device index is read from ``TILE_FWK_DEVICE_ID`` (default 0). The
    module is compiled with the torchair NPU backend in "reduce-overhead"
    mode, executed once, and its three outputs are checked with ``compare``.
    """
    torch.npu.set_device(int(os.environ.get('TILE_FWK_DEVICE_ID', 0)))
    torch.manual_seed(42)

    # The eighth value returned by the generator is not needed here.
    x, hc_fn, hc_scale, hc_base, y_gd, post_gd, comb_gd, _ = gen_hc_pre_data(t)
    print("gen golden success !!!")
    x, hc_fn, hc_scale, hc_base = (
        tensor.npu() for tensor in (x, hc_fn, hc_scale, hc_base)
    )

    # NOTE(review): imported inside the test, presumably so this file can be
    # imported where torchair is not installed -- confirm.
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig

    compiler_config = CompilerConfig()
    compiler_config.mode = "reduce-overhead"
    model = torch.compile(
        HC_PRE(),
        dynamic=False,
        fullgraph=True,
        backend=tng.get_npu_backend(compiler_config=compiler_config),
    )

    y, post, comb = model(x, hc_fn, hc_scale, hc_base)
    pypto.runtime._device_synchronize()

    # rtol for y is 2 ** -7; post and comb use tighter tolerances.
    compare(y.cpu(), y_gd, "y", atol=0.0001, rtol=0.0078125)
    print("y compare success!!!")
    compare(post.cpu(), post_gd, "post", atol=0.000025, rtol=0.005)
    print("post compare success!!!")
    compare(comb.cpu(), comb_gd, "comb", atol=0.000025, rtol=0.005)
    print("comb compare success!!!")
def test_hc_pre(t=16, is_trans=False):
    """Launch an ``hc_pre`` kernel directly and compare with the golden data.

    Args:
        t: size of the leading axis of the generated input.
        is_trans: selects ``hc_pre_kernel_prefill`` when true, otherwise
            ``hc_pre_kernel``.
    """
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(device_id)
    torch.manual_seed(42)

    # NOTE(review): is_trans is not forwarded to gen_hc_pre_data, so the
    # golden values always come from the non-transposed reference even when
    # the prefill kernel is selected -- confirm this is intended.
    x, hc_fn, hc_scale, hc_base, y_gd, post_gd, comb_gd, _ = gen_hc_pre_data(t)
    print("gen golden success !!!")
    check_input_output_shape_dtype(x, hc_fn, hc_scale, hc_base)

    # Zero-filled output buffers; they are read back after the kernel ran.
    y, post, comb = (
        torch.zeros_like(golden).npu() for golden in (y_gd, post_gd, comb_gd)
    )
    inputs = [
        tensor.npu()
        for tensor in (x, hc_fn, hc_scale, hc_base, y, post, comb)
    ]

    kernel = hc_pre_kernel_prefill if is_trans else hc_pre_kernel
    kernel(*inputs)
    pypto.runtime._device_synchronize()

    y, post, comb = y.cpu(), post.cpu(), comb.cpu()

    # rtol for y is 2 ** -7; post and comb use tighter tolerances.
    compare(y, y_gd, "y", atol=0.0001, rtol=0.0078125)
    print("y compare success!!!")
    compare(post, post_gd, "post", atol=0.000025, rtol=0.005)
    print("post compare success!!!")
    compare(comb, comb_gd, "comb", atol=0.000025, rtol=0.005)
    print("comb compare success!!!")
@pytest.mark.skip(reason="ci torch version")
def te_hc_pre_prefill(t=512):
    """Run ``test_hc_pre`` with ``is_trans=True`` for ``t`` rows (default 512).

    NOTE(review): the name lacks the ``test_`` prefix, so pytest's default
    discovery presumably does not collect it -- confirm whether that is
    deliberate in addition to the skip marker.
    """
    print("hc_pre_prefill ")
    test_hc_pre(t, is_trans=True)
if __name__ == "__main__":
    # Script entry point: run both cases directly with t = 16, in the same
    # order as before (compiled-model case first, then the raw kernel case).
    print("start test !!!")
    for run_case in (test_hc_pre_inmodel, test_hc_pre):
        run_case(16)