import logging
import math
import pytest
import torch
import torch_npu
import ops_multimodal_fusion
# Skip every test in this module at collection time when the custom
# 'polygamma' operator is absent from the torch.ops.ops_multimodal_fusion
# namespace.
if getattr(torch.ops.ops_multimodal_fusion, "polygamma", None) is None:
    pytest.skip(
        "ops_multimodal_fusion.polygamma not registered for current NPU_ARCH; skipping module",
        allow_module_level=True,
    )
def test_polygamma_interface_exist():
    """Test that the 'ops_multimodal_fusion.polygamma' operator is registered in torch.ops.

    Raises:
        AssertionError: if the operator is missing from the namespace.
    """
    namespace = torch.ops.ops_multimodal_fusion
    # Check registration BEFORE reading the attribute: accessing a missing
    # operator raises AttributeError, which would pre-empt the descriptive
    # assertion message below.
    assert hasattr(namespace, "polygamma"), \
        "The 'polygamma' operator is not registered in the 'torch.ops.ops_multimodal_fusion' namespace."
    # Only log the operator handle once it is known to exist.
    logging.info(namespace.polygamma)
# Input shapes fed to the parametrized comparison test. The order is kept
# stable because it determines the generated pytest parameter ids.
SHAPES = [
    # 1-D vectors, from a single element up to 10k elements.
    (1,), (7,), (1024,), (10000,),
    # 2-D matrices, square and rectangular.
    (10, 10), (32, 32), (100, 100), (10, 100), (256, 512),
    # 3-D and 4-D tensors.
    (16, 32, 64), (1, 3, 32, 32), (4, 3, 64, 64),
    # Larger element counts.
    (100000,), (1000000,), (2048, 2048), (64, 128, 256),
]

# Only float32 inputs are exercised.
DTYPES = [torch.float32]

# Polygamma orders under test: 1 through 6 inclusive.
ORDERS = list(range(1, 7))
def _sample_positive(shape, dtype, low, high=20.0, seed=0):
    """Draw a reproducible CPU tensor of positive test inputs.

    Values are sampled uniformly between ``low`` and ``high`` using a
    dedicated generator seeded with ``seed``, so the same arguments always
    yield the same tensor. Callers are expected to pass ``low > 0``: per the
    test-suite notes the kernel requires x > 0 (no reflection is applied for
    x <= 0).

    Args:
        shape: Iterable of dimension sizes for the result.
        dtype: torch dtype of the result.
        low: Lower end of the sampling interval.
        high: Upper end of the sampling interval.
        seed: Seed for the private random generator.

    Returns:
        A tensor of the requested shape and dtype filled with the samples.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    samples = torch.empty(*shape, dtype=dtype)
    samples.uniform_(low, high, generator=rng)
    return samples
def _tolerances_for(n):
    """Return the ``torch.allclose`` tolerances used for polygamma order ``n``.

    Orders up to 2 use the tightest pair (noted in this suite as mirroring the
    digamma tolerances). Higher orders are relaxed slightly because
    |psi^(n)(x)| scales roughly as n!/x^(n+1), and float32 loses resolution in
    the large-magnitude region near x ~ 0.5.

    Args:
        n: Polygamma order.

    Returns:
        A new dict with keys ``rtol`` and ``atol``.
    """
    if n <= 2:
        relative, absolute = 1e-4, 1e-5
    elif n <= 4:
        relative, absolute = 3e-4, 1e-4
    else:
        relative, absolute = 5e-4, 1e-4
    return {"rtol": relative, "atol": absolute}
def _low_bound_for(n):
    """Return the smallest input value sampled for polygamma order ``n``.

    The bound grows with the order: 0.1 for n == 1, 0.25 for n == 2, 0.5 for
    any other n up to 4, and 0.75 beyond that.
    NOTE(review): presumably this keeps |psi^(n)(x)| within a range float32
    resolves well at higher orders -- confirm with the kernel authors.

    Args:
        n: Polygamma order.

    Returns:
        The lower sampling bound as a float.
    """
    if n == 1:
        bound = 0.1
    elif n == 2:
        bound = 0.25
    elif n <= 4:
        bound = 0.5
    else:
        bound = 0.75
    return bound
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("n", ORDERS)
def test_polygamma_operator(shape, dtype, n):
    """Check the NPU polygamma(x, n) against torch.special.polygamma on positive inputs.

    Random inputs are drawn on the CPU, evaluated on the NPU, and compared
    with a float64 CPU reference rounded back to ``dtype``.
    """
    # Give every (order, shape) combination its own reproducible seed.
    shape_salt = abs(hash(shape)) % 997
    x = _sample_positive(
        shape,
        dtype,
        low=_low_bound_for(n),
        high=20.0,
        seed=n * 97 + shape_salt,
    )

    polygamma_op = torch.ops.ops_multimodal_fusion.polygamma
    y = polygamma_op(x.npu(), n).cpu()

    # The operator must preserve both dtype and shape of its input.
    assert y.dtype == dtype, f"dtype mismatch: {y.dtype} vs {dtype}"
    assert y.shape == x.shape, f"shape mismatch: {y.shape} vs {x.shape}"

    # Reference evaluated in float64, then cast down for the comparison.
    x_double = x.to(torch.float64)
    expected = torch.special.polygamma(n, x_double).to(dtype)

    limits = _tolerances_for(n)
    assert torch.allclose(y, expected, **limits), (
        f"polygamma(n={n}) mismatch: "
        f"max abs diff = {(y - expected).abs().max().item()}, "
        f"max rel diff = "
        f"{((y - expected) / expected.abs().clamp_min(1e-30)).abs().max().item()}"
    )
    logging.info(f"Test passed: n={n}, shape={shape}, dtype={dtype}")
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("n", ORDERS)
def test_polygamma_named_values(n):
    """Compare the NPU result at a fixed set of points with a float64 reference.

    The reference comes from torch.special.polygamma evaluated in float64 and
    rounded to float32.
    """
    named_points = [1.0, 2.0, 3.0, 0.5, 1.5, 5.0, 10.0, 100.0]
    xs = torch.tensor(named_points, dtype=torch.float32)

    reference = torch.special.polygamma(n, xs.to(torch.float64))
    expected = reference.to(torch.float32)

    polygamma_op = torch.ops.ops_multimodal_fusion.polygamma
    y = polygamma_op(xs.npu(), n).cpu()

    limits = _tolerances_for(n)
    assert torch.allclose(y, expected, **limits), (
        f"polygamma(n={n}) named-value mismatch; got {y.tolist()} "
        f"expected {expected.tolist()}"
    )
    logging.info(f"Named-values test passed (n={n}).")
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
def test_polygamma_trigamma_identities():
    """Check order n=1 against closed-form trigamma values.

    Uses psi'(1) = pi^2/6, psi'(2) = pi^2/6 - 1 and psi'(1/2) = pi^2/2.
    """
    pi_squared = math.pi * math.pi
    zeta_two = pi_squared / 6.0

    # Inputs and their closed-form results, listed in matching order.
    xs = torch.tensor([1.0, 2.0, 0.5], dtype=torch.float32)
    closed_form = [zeta_two, zeta_two - 1.0, pi_squared / 2.0]
    expected = torch.tensor(closed_form, dtype=torch.float32)

    polygamma_op = torch.ops.ops_multimodal_fusion.polygamma
    y = polygamma_op(xs.npu(), 1).cpu()

    assert torch.allclose(y, expected, rtol=1e-4, atol=1e-5), (
        f"trigamma identity mismatch: got {y.tolist()} expected {expected.tolist()}"
    )
    logging.info("Trigamma identity test passed.")
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("n", ORDERS)
def test_polygamma_wide_range(n):
    """Exercise small, mid and large positive inputs for each supported order.

    Only the relative error is asserted. The denominator is clamped at 1e-6,
    so for reference values of smaller magnitude the check behaves like an
    absolute bound.
    """
    floor = _low_bound_for(n)
    candidates = torch.tensor(
        [floor, 0.5, 1.0, 1.5, 2.718281828, 3.14159265,
         5.0, 10.0, 50.0, 100.0, 500.0, 1000.0, 10000.0],
        dtype=torch.float32,
    )
    # Drop any fixed point that lies below the per-order lower bound.
    xs = candidates[candidates >= floor]

    polygamma_op = torch.ops.ops_multimodal_fusion.polygamma
    y = polygamma_op(xs.npu(), n).cpu()
    expected = torch.special.polygamma(n, xs.to(torch.float64)).to(torch.float32)

    error = y - expected
    max_abs = error.abs().max().item()
    max_rel = (error / expected.abs().clamp_min(1e-6)).abs().max().item()

    rel_limit = _tolerances_for(n)["rtol"]
    assert max_rel < rel_limit, (
        f"polygamma(n={n}) wide-range mismatch: max abs={max_abs}, max rel={max_rel}\n"
        f"got = {y.tolist()}\n"
        f"expected= {expected.tolist()}"
    )
    logging.info(f"Wide-range test passed (n={n}, max rel={max_rel:.3e}).")
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
def test_polygamma_rejects_unsupported_orders():
    """n=0 (digamma) and n>6 are out of scope on dav-3510 and must error cleanly.

    Each probed order must raise a RuntimeError whose message matches
    "polygamma".
    """
    xs = torch.tensor([1.0, 2.0], dtype=torch.float32).npu()
    polygamma_op = torch.ops.ops_multimodal_fusion.polygamma
    # One order below and one order above the tested range of 1..6.
    for unsupported_order in (0, 7):
        with pytest.raises(RuntimeError, match="polygamma"):
            polygamma_op(xs, unsupported_order)
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
def test_polygamma_empty_tensor():
    """A zero-element input must yield a zero-element float32 output.

    NOTE(review): the suite's original note says no kernel launch happens for
    empty input; that is not observable from this test -- confirm in the op.
    """
    empty_input = torch.empty((0,), dtype=torch.float32)
    polygamma_op = torch.ops.ops_multimodal_fusion.polygamma
    y = polygamma_op(empty_input.npu(), 2).cpu()
    assert y.shape == (0,)
    assert y.dtype == torch.float32