import torch
import torch_npu
import torch.nn.functional as F
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNpuMhcPre(TestCase):
    """Compare ``torch_npu.npu_mhc_pre`` against a plain-PyTorch reference."""

    def cpu_op_exec(
        self,
        x: torch.Tensor,
        phi: torch.Tensor,
        alpha: torch.Tensor,
        bias: torch.Tensor,
        gamma: torch.Tensor = None,
        norm_eps: float = 1e-6,
        hc_eps: float = 1e-6,
    ):
        """Compute the reference outputs in float32 with standard torch ops.

        Args:
            x: Input of shape ``(T, N, D)``; flattened to ``(T, N * D)``.
            phi: Projection weight applied with ``F.linear``; it must produce
                ``2 * N + N * N`` features from ``N * D`` inputs.
            alpha: Indexable with 0, 1, 2; scales the pre, post and res
                slices respectively.
            bias: Flat tensor of ``2 * N + N * N`` values laid out as
                ``[pre (N), post (N), res (N * N)]``.
            gamma: Optional element-wise scale with ``N * D`` elements that
                is applied to ``x`` before the projection.
            norm_eps: Added to the mean square before ``rsqrt``.
            hc_eps: Added to ``h_pre`` after the sigmoid.

        Returns:
            Tuple ``(h_in, h_post, h_res, inv_rms, h_mix, h_pre)`` where
            ``h_in`` is ``(T, D)`` in bfloat16, ``h_post`` and ``h_pre`` are
            ``(T, N)``, ``h_res`` is ``(T, N, N)``, ``inv_rms`` is ``(T,)``
            and ``h_mix`` is the raw projection ``(T, 2 * N + N * N)``.
        """
        T, N, D = x.shape
        ND = N * D
        x = x.reshape(T, ND).float()
        inv_rms = torch.rsqrt(x.square().mean(-1, keepdim=True) + norm_eps)

        # gamma only changes what is fed to the projection; the scaling by
        # inv_rms afterwards is the same for both cases.
        mix_input = x if gamma is None else x * gamma.reshape(ND).float()
        h_mix = F.linear(mix_input, phi.float())
        weight = h_mix * inv_rms

        h_pre, h_post, h_res = weight.split([N, N, N * N], dim=-1)
        h_res = h_res.unflatten(-1, (N, N))
        h_pre = torch.sigmoid(h_pre * alpha[0] + bias[:N].unsqueeze(0)) + hc_eps
        h_post = 2 * torch.sigmoid(h_post * alpha[1] + bias[N:2 * N].unsqueeze(0))
        h_res = h_res * alpha[2] + bias[2 * N:].view(N, N).unsqueeze(0)

        # Weighted sum over the N axis: (T, N, 1) * (T, N, D) -> (T, D).
        h_in = torch.sum(
            h_pre.unsqueeze(-1) * x.unflatten(dim=-1, sizes=(N, -1)),
            dim=1,
        ).bfloat16()
        return (h_in, h_post, h_res, inv_rms[:, 0], h_mix, h_pre)

    def custom_op_exec(self, x, phi, alpha, bias, gamma, out_flag):
        """Call ``torch_npu.npu_mhc_pre`` and return its result unchanged."""
        return torch_npu.npu_mhc_pre(
            x,
            phi,
            alpha,
            bias,
            gamma=gamma,
            out_flag=out_flag,
        )

    def build_input_tensors(self, T, n, D):
        """Create the inputs for one test shape.

        ``x`` (bfloat16) and ``phi`` (float32) are random; ``alpha``,
        ``gamma`` and ``bias`` are fixed constants.

        Returns:
            Tuple ``(x, phi, alpha, bias, gamma)``.
        """
        x = torch.randn(T, n, D, dtype=torch.bfloat16)
        phi = torch.randn(n * n + 2 * n, n * D, dtype=torch.float32)
        alpha = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32)
        gamma = torch.ones(n, D, dtype=torch.float32)
        bias_pre = torch.full((n,), 0.01, dtype=torch.float32)
        bias_post = torch.full((n,), 0.01, dtype=torch.float32)
        bias_res = torch.full((n, n), 0.01, dtype=torch.float32)
        # Same [pre, post, res] layout that cpu_op_exec slices apart.
        bias = torch.cat([bias_pre, bias_post, bias_res.reshape(-1)], dim=0)
        return x, phi, alpha, bias, gamma

    def run_and_check(self, T, n, D, out_flag, output_names, tol_map):
        """Run the reference and the NPU operator and compare named outputs.

        Args:
            T, n, D: Dimensions passed to ``build_input_tensors``.
            out_flag: Forwarded to ``torch_npu.npu_mhc_pre``.
            output_names: Names of the leading outputs to compare, in the
                order returned by ``cpu_op_exec``.
            tol_map: Maps each output name to the ``prec`` value handed to
                ``assertRtolEqual``.

        Raises:
            AssertionError: If the operator yields fewer outputs than
                ``output_names`` lists, or if any comparison fails.
        """
        expected_count = len(output_names)
        with torch.no_grad():
            x, phi, alpha, bias, gamma = self.build_input_tensors(T, n, D)
            expected_output = self.cpu_op_exec(x, phi, alpha, bias, gamma)[:expected_count]
            actual_output = self.custom_op_exec(
                x.npu(), phi.npu(), alpha.npu(), bias.npu(), gamma.npu(), out_flag=out_flag
            )[:expected_count]

            # zip() stops at the shortest sequence, so without this check a
            # missing operator output would simply never be compared.
            if len(actual_output) != expected_count:
                raise AssertionError(
                    f"Expected {expected_count} outputs {output_names} for shape "
                    f"(T={T}, n={n}, D={D}), out_flag={out_flag}, "
                    f"got {len(actual_output)}"
                )

            for name, exp, act in zip(output_names, expected_output, actual_output):
                try:
                    self.assertRtolEqual(
                        exp.float().numpy(),
                        act.float().cpu().numpy(),
                        prec=tol_map[name]
                    )
                except AssertionError as e:
                    # Re-raise with the output name and shape, keeping the
                    # original failure attached as the explicit cause.
                    raise AssertionError(
                        f"Output {name} compare failed for shape (T={T}, n={n}, D={D}), "
                        f"out_flag={out_flag}: {e}"
                    ) from e

    @SupportedDevices(['Ascend950'])
    def test_npu_mhc_pre_prefill_training(self, device="npu"):
        """Check all six outputs at T=4096, n=4, D=5120 with out_flag=1."""
        T, n, D = (4096, 4, 5120)
        out_flag = 1
        output_names = ["h_in", "h_post", "h_res", "inv_rms", "h_mix", "h_pre"]
        tol_map = {
            "h_in": 2 ** -7,
            "h_post": 1e-3,
            "h_res": 1e-3,
            "inv_rms": 1e-3,
            "h_mix": 1e-3,
            "h_pre": 1e-3,
        }
        self.run_and_check(T, n, D, out_flag, output_names, tol_map)

    @SupportedDevices(['Ascend950'])
    def test_npu_mhc_pre_prefill_inference(self, device="npu"):
        """Check the first three outputs at T=1024, n=4, D=2560 with out_flag=0."""
        T, n, D = (1024, 4, 2560)
        out_flag = 0
        output_names = ["h_in", "h_post", "h_res"]
        tol_map = {
            "h_in": 2 ** -7,
            "h_post": 1e-3,
            "h_res": 1e-3,
        }
        self.run_and_check(T, n, D, out_flag, output_names, tol_map)

    @SupportedDevices(['Ascend950'])
    def test_npu_mhc_pre_decode_inference(self, device="npu"):
        """Check the first three outputs at T=64, n=4, D=2560 with out_flag=0."""
        T, n, D = (64, 4, 2560)
        out_flag = 0
        output_names = ["h_in", "h_post", "h_res"]
        tol_map = {
            "h_in": 2 ** -7,
            "h_post": 1e-3,
            "h_res": 1e-3,
        }
        self.run_and_check(T, n, D, out_flag, output_names, tol_map)
if __name__ == "__main__":
    # Executed as a script: hand control to the torch_npu test runner.
    run_tests()