import logging
import math
import pytest
import torch
import torch_npu
import ops_multimodal_fusion
# Skip every test in this module when the `angle` operator is not present in
# the torch.ops.ops_multimodal_fusion namespace.
_angle_op_missing = not hasattr(torch.ops.ops_multimodal_fusion, "angle")
if _angle_op_missing:
    pytest.skip(
        "ops_multimodal_fusion.angle not registered for current NPU_ARCH; skipping module",
        allow_module_level=True,
    )
def test_angle_interface_exist():
    """Check that 'angle' is registered under torch.ops.ops_multimodal_fusion.

    The registration check runs before the operator object is touched, so a
    missing operator is reported through the assertion message below rather
    than through an AttributeError raised by the attribute lookup that feeds
    the logging call.
    """
    namespace = torch.ops.ops_multimodal_fusion
    assert hasattr(namespace, "angle"), \
        "The 'angle' operator is not registered in the 'torch.ops.ops_multimodal_fusion' namespace."
    # Only safe to dereference once the assertion above has passed.
    logging.info(namespace.angle)
# Input shapes swept by the parametrized comparison test. Order matters only
# for the generated test ids, and is kept stable.
SHAPES = [
    # 1-D, from a single element up to 10k elements
    (1,), (7,), (1024,), (10000,),
    # 2-D, square and rectangular
    (10, 10), (32, 32), (100, 100), (10, 100), (256, 512),
    # 3-D and 4-D
    (16, 32, 64), (1, 3, 32, 32), (4, 3, 64, 64),
    # large inputs (100k elements and above)
    (100000,), (1000000,), (2048, 2048), (64, 128, 256),
]

# Floating-point dtypes the dtype-parametrized tests run against.
DTYPES = [torch.float32, torch.float16]
def _dtype_min_abs(dtype):
    """Return the smallest input magnitude the tests use for *dtype*.

    At or above this magnitude the kernel's diff/(diff+tiny) ratio is expected
    to round to exactly 1.0, so negative inputs round to pi. The bound follows
    the TINY constant chosen inside the kernel:

        f32: TINY = 1e-35, so any |x| > ~1e-30 is safe
        f16: TINY = 1e-4, so |x| >= 1e-2 stays well inside fp16 resolution

    float32 gets 1e-4; every other dtype gets the fp16 bound of 1e-2.
    """
    if dtype == torch.float32:
        return 1e-4
    return 1e-2
def _dtype_tol(dtype):
    """Return the ``rtol``/``atol`` keyword arguments used against torch.angle.

    f32: one ulp of pi is about 4e-7, so 1e-6 is a safe bound.
    f16: pi is not exactly representable and ulp(pi) is about 2e-3; together
         with the rounding error of diff/(diff+tiny), 5e-3 is realistic.

    The same bound is used for both the relative and the absolute tolerance,
    and a new dict is built on every call.
    """
    bound = 1e-6 if dtype == torch.float32 else 5e-3
    return {"rtol": bound, "atol": bound}
def _sample_mixed_sign(shape, dtype, seed=0, low=-20.0, high=20.0):
    """Draw a CPU tensor of the given shape/dtype uniformly from [low, high].

    With the default bounds the interval straddles zero, so both the positive
    and the negative branch get exercised. A dedicated generator seeded with
    *seed* makes the draw reproducible without touching the global RNG state.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    sample = torch.empty(*shape, dtype=dtype)
    sample.uniform_(low, high, generator=rng)
    return sample
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_operator(shape, dtype):
    """Compare NPU angle against torch.angle on finite mixed-sign inputs.

    Inputs are drawn by ``_sample_mixed_sign`` and every element whose
    magnitude is below ``_dtype_min_abs(dtype)`` is pushed out to that bound
    (keeping its sign, with zero treated as positive). The NPU result must
    match the input's dtype and shape and agree with ``torch.angle`` within
    ``_dtype_tol(dtype)``.
    """
    # Derive the seed arithmetically from the dimensions instead of hash():
    # tuple hashing is an interpreter implementation detail, so hash-based
    # seeds do not give the same sample on every Python version.
    seed = sum((axis + 1) * dim for axis, dim in enumerate(shape)) % 997
    x = _sample_mixed_sign(shape, dtype, seed=seed)

    # Clamp near-zero magnitudes up to the per-dtype floor. sign() is 0 for
    # zero inputs, so those are mapped to +1 to land on +floor.
    floor = _dtype_min_abs(dtype)
    unit_sign = x.sign().where(x != 0, torch.ones_like(x))
    x = torch.where(x.abs() < floor, torch.full_like(x, floor) * unit_sign, x)

    y = torch.ops.ops_multimodal_fusion.angle(x.npu()).cpu()
    assert y.dtype == dtype, f"dtype mismatch: {y.dtype} vs {dtype}"
    assert y.shape == x.shape, f"shape mismatch: {y.shape} vs {x.shape}"

    expected = torch.angle(x)
    # Shape and seed are included so a failing case can be regenerated.
    assert torch.allclose(y, expected, **_dtype_tol(dtype)), (
        f"angle mismatch (shape={shape}, dtype={dtype}, seed={seed}): "
        f"max abs diff = {(y - expected).abs().max().item()}"
    )
    logging.info("Test passed: shape=%s, dtype=%s", shape, dtype)
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_named_values(dtype):
    """Check hand-picked inputs against torch.angle.

    The list holds a set of positive magnitudes, the same magnitudes negated,
    both signs of zero, and one larger positive/negative pair. The magnitudes
    are chosen per dtype so they stay within that dtype's resolution.
    """
    if dtype == torch.float32:
        magnitudes, large = [1.0, 2.0, 0.1, 1e-3, 1e-10], 1234.5
    else:
        magnitudes, large = [1.0, 2.0, 0.1, 0.01], 100.0
    # Order: positives, negatives, +0/-0, then the large pair.
    values = [*magnitudes, *(-m for m in magnitudes), 0.0, -0.0, large, -large]

    xs = torch.tensor(values, dtype=dtype)
    on_device = xs.npu()
    y = torch.ops.ops_multimodal_fusion.angle(on_device).cpu()
    expected = torch.angle(xs)

    tolerances = _dtype_tol(dtype)
    assert torch.allclose(y, expected, **tolerances), (
        f"angle named-values mismatch (dtype={dtype}); "
        f"got {y.tolist()} expected {expected.tolist()}"
    )
    logging.info(f"Named-values test passed (dtype={dtype}).")
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_positive_is_zero(dtype):
    """Strictly positive inputs must yield an angle of exactly 0 for each dtype.

    Inputs are 4096 evenly spaced points from ``_dtype_min_abs(dtype)`` up to
    1e4 (float32) or 1e3 (other dtypes); the comparison is exact, not
    tolerance based.
    """
    if dtype == torch.float32:
        upper = 1e4
    else:
        upper = 1e3
    lower = _dtype_min_abs(dtype)
    x = torch.linspace(lower, upper, 4096, dtype=dtype)

    y = torch.ops.ops_multimodal_fusion.angle(x.npu()).cpu()

    all_zero = torch.zeros_like(x)
    assert torch.equal(y, all_zero), (
        f"positive inputs produced nonzero angle (dtype={dtype}); "
        f"max |y|={y.abs().max().item()}"
    )
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_negative_is_pi(dtype):
    """Strictly negative inputs (magnitude well above tiny) must map to pi.

    Inputs are the negation of 4096 evenly spaced points from
    ``_dtype_min_abs(dtype)`` up to 1e4 (float32) or 1e3 (other dtypes); the
    result is compared to pi within ``_dtype_tol(dtype)``.
    """
    if dtype == torch.float32:
        upper = 1e4
    else:
        upper = 1e3
    magnitudes = torch.linspace(_dtype_min_abs(dtype), upper, 4096, dtype=dtype)
    x = -magnitudes

    y = torch.ops.ops_multimodal_fusion.angle(x.npu()).cpu()

    expected = torch.full_like(x, math.pi)
    tolerances = _dtype_tol(dtype)
    assert torch.allclose(y, expected, **tolerances), (
        f"negative inputs deviated from pi (dtype={dtype}): "
        f"max abs err={(y - expected).abs().max().item()}"
    )
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_nan_propagates(dtype):
    """NaN inputs must come back as NaN without disturbing their neighbours.

    The input is ``[nan, 1, -1, nan]``: the two NaN slots must stay NaN, the
    +1 slot must be exactly 0 and the -1 slot must be within the dtype's
    absolute tolerance of pi.
    """
    nan = float('nan')
    xs = torch.tensor([nan, 1.0, -1.0, nan], dtype=dtype)
    y = torch.ops.ops_multimodal_fusion.angle(xs.npu()).cpu()

    first, at_pos_one, at_neg_one, last = y.tolist()

    assert math.isnan(first) and math.isnan(last), \
        f"NaN did not propagate (dtype={dtype}): got {y.tolist()}"
    assert at_pos_one == 0.0, f"expected 0 for x=1 (dtype={dtype}), got {at_pos_one}"

    pi_tol = _dtype_tol(dtype)["atol"]
    assert abs(at_neg_one - math.pi) < pi_tol, \
        f"expected ~pi for x=-1 (dtype={dtype}), got {at_neg_one}"
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_empty_tensor(dtype):
    """A zero-element input must produce a zero-element output of the same dtype.

    NOTE(review): the operator presumably returns early without launching the
    kernel for empty input; this test only observes shape and dtype.
    """
    empty_in = torch.empty((0,), dtype=dtype)
    result = torch.ops.ops_multimodal_fusion.angle(empty_in.npu()).cpu()
    assert result.shape == (0,)
    assert result.dtype == dtype
@pytest.mark.skip(
    reason=(
        "torch_npu in the current CANN release lacks D2D strided copy support "
        "(aclnnInplaceCopy fails with error 561103 for any non-contiguous NPU "
        "tensor, regardless of dtype). The kernel itself only operates on "
        "contiguous buffers; callers holding a transposed/strided tensor must "
        "materialize it themselves before invoking the op. "
        "Re-enable once torch_npu ships strided D2D."
    )
)
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
@pytest.mark.parametrize("dtype", DTYPES)
def test_angle_non_contiguous_input(dtype):
    """Check the op on a transposed (non-contiguous) input against torch.angle.

    Currently skipped unconditionally; see the skip reason above. Per that
    reason the kernel only operates on contiguous buffers, so this test
    describes the behaviour expected once strided device-to-device copies are
    available, not something the kernel does on its own today.

    The NPU-availability guard matches the other device tests so that
    removing the unconditional skip does not make this test error on hosts
    without an NPU.
    """
    base = _sample_mixed_sign((32, 32), dtype, seed=7)

    # Clamp near-zero magnitudes up to the per-dtype floor. sign() is 0 for
    # zero inputs, so those are mapped to +1 to land on +floor.
    floor = _dtype_min_abs(dtype)
    unit_sign = base.sign().where(base != 0, torch.ones_like(base))
    base = torch.where(base.abs() < floor,
                       torch.full_like(base, floor) * unit_sign,
                       base)

    # Transposing a contiguous 2-D tensor yields a strided view.
    x = base.t()
    assert not x.is_contiguous()

    y = torch.ops.ops_multimodal_fusion.angle(x.npu()).cpu()
    expected = torch.angle(x.contiguous())
    assert torch.allclose(y, expected, **_dtype_tol(dtype)), (
        f"non-contiguous mismatch (dtype={dtype}): "
        f"max abs diff = {(y - expected).abs().max().item()}"
    )
@pytest.mark.skipif(not torch.npu.is_available(), reason="NPU device not found")
def test_angle_rejects_unsupported_dtype():
    """A float64 input must be rejected with a RuntimeError mentioning 'angle'.

    float64 stands in for the unsupported dtypes; the tensor is moved to the
    NPU before the expected-failure context so that only the operator call
    itself is allowed to raise.
    """
    unsupported = torch.tensor([1.0, -1.0], dtype=torch.float64)
    on_device = unsupported.npu()
    with pytest.raises(RuntimeError, match="angle"):
        torch.ops.ops_multimodal_fusion.angle(on_device)