"""Shared pytest helpers for Ascend NPU tests."""
from __future__ import annotations
import socket
from typing import Optional
import torch
def dtype_tols(dtype: torch.dtype) -> dict[str, float]:
    """Look up comparison tolerances for a floating-point dtype.

    The values follow the tolerances used by torch.testing.assert_close.

    Args:
        dtype: PyTorch datatype

    Returns:
        Fresh dictionary with the keys "rtol" and "atol"

    Raises:
        ValueError: If dtype is not float16, bfloat16, float32 or float64.
    """
    # (dtype, rtol, atol) rows, checked in order with ``==``.
    known_tolerances = (
        (torch.float16, 1e-3, 1e-5),
        (torch.bfloat16, 1.6e-2, 1e-5),
        (torch.float32, 1.3e-6, 1e-5),
        (torch.float64, 1e-7, 1e-7),
    )
    for candidate, rel_tol, abs_tol in known_tolerances:
        if dtype == candidate:
            return {"rtol": rel_tol, "atol": abs_tol}
    raise ValueError(f"Unsupported dtype ({dtype})")
def quantization_tols(name: str) -> dict[str, float]:
    """Look up comparison tolerances for a quantization scheme.

    NOTE: Not used in NPU tests since FP8 quantization is not supported.
    Retained for API compatibility when NPU FP8 support is added.

    Args:
        name: Quantization scheme name

    Returns:
        Fresh dictionary with the keys "rtol" and "atol"

    Raises:
        ValueError: If name is not one of the recognized scheme names.
    """
    eight_bit_schemes = (
        "fp8",
        "fp8_delayed_scaling",
        "fp8_current_scaling",
        "mxfp8",
        "mxfp8_block_scaling",
    )
    four_bit_schemes = ("nvfp4", "mxfp4", "mxfp4_block_scaling")

    if name in eight_bit_schemes:
        return {"rtol": 0.125, "atol": 0.0675}
    if name in four_bit_schemes:
        return {"rtol": 0.25, "atol": 0.125}
    raise ValueError(f"Unsupported quantization scheme ({name})")
# Use the call form ``torch.no_grad()``: it matches the other decorated helper
# in this file and does not rely on ``no_grad`` accepting a function directly.
@torch.no_grad()
def assert_close(
    actual: Optional[torch.Tensor],
    expected: Optional[torch.Tensor],
    *,
    check_device: bool = False,
    check_dtype: bool = False,
    check_layout: bool = False,
    **kwargs,
) -> None:
    """Assert that two tensors are close

    This is a wrapper around torch.testing.assert_close. It changes the defaults
    for device, dtype and layout checks to False (useful when the reference
    implementation is computed in high precision on CPU). The comparison runs
    with gradient tracking disabled.

    Args:
        actual: Actual tensor
        expected: Expected tensor
        check_device: Whether to check device
        check_dtype: Whether to check dtype
        check_layout: Whether to check layout
        **kwargs: Additional arguments passed to torch.testing.assert_close
            (e.g. rtol and atol)

    Raises:
        AssertionError: Propagated from torch.testing.assert_close when the
            inputs do not match.
    """
    torch.testing.assert_close(
        actual,
        expected,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        **kwargs,
    )
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | tuple[int, ...],
*,
min_val: float = 0.0,
max_val: float = 1.0,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "npu",
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch operations
in high precision on CPU. The test tensor is intended for use in
TransformerEngine operations on NPU.
Args:
shape: Tensor shape
min_val: Minimum random value
max_val: Maximum random value
ref_dtype: Reference tensor dtype
ref_device: Reference tensor device
test_dtype: Test tensor dtype
test_device: Test tensor device
requires_grad: Whether gradients are required
Returns:
Tuple of (reference tensor, test tensor)
"""
ref = torch.empty(shape, dtype=ref_dtype, device=ref_device)
ref.uniform_(min_val, max_val)
test = ref.to(device=test_device, dtype=test_dtype)
ref.copy_(test.to(dtype=ref.dtype))
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return a detached FP64 copy of a tensor on the CPU.

    Args:
        tensor: Input tensor, or None

    Returns:
        Detached FP64 CPU tensor, or None when the input is None
    """
    if tensor is not None:
        # Detach first so the result carries no autograd history.
        detached = tensor.detach()
        return detached.to(dtype=torch.float64, device="cpu")
    return None
def npu_available() -> bool:
    """Report whether an NPU device can be used.

    Returns:
        False when ``torch`` has no ``npu`` attribute; otherwise the result
        of ``torch.npu.is_available()``.
    """
    if not hasattr(torch, "npu"):
        return False
    return torch.npu.is_available()
def npu_device_count() -> int:
    """Count the NPU devices that can be used.

    Returns:
        0 when npu_available() is false; otherwise the result of
        ``torch.npu.device_count()``.
    """
    return torch.npu.device_count() if npu_available() else 0
def get_free_port() -> str:
    """Ask the OS for an unused TCP port and return it as a string.

    A temporary IPv4 socket is bound to port 0 so that the OS chooses the
    port. The socket is closed before returning, so the port is no longer
    held by this function once the caller receives it.

    Returns:
        The port number, formatted as a decimal string
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        _host, port = probe.getsockname()
    return str(port)