"""Test fusible operations - migrated from TransformerEngine
Test content (all in TestBasicOps, following TE convention):
- BasicLinear forward and backward correctness (unquantized, FP8 current scaling, MXFP8)
- Forward and backward computation correctness for basic ops (RMSNorm)
- Forward and backward computation correctness for 11 activation functions
- SwiGLU forward + backward test
- Backward activation + bias tests
- Precision validation across multiple datatypes
- Compatibility tests for various input shapes
NOTE: Most Quantize/Bias coverage from the upstream suite is still omitted.
BasicLinear keeps focused FP8/MXFP8 coverage here because it does not
require distributed setup.
Pass criteria:
- Forward output error within tolerance
- Backward gradient error within tolerance
"""
from __future__ import annotations
from collections.abc import Iterable
import math
import pytest
import torch
import torch.nn.functional as F
import torch_npu
from transformer_engine.common import recipe
from transformer_engine.pytorch import autocast
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.constants import NPUVersion
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
is_fp8_available,
is_mxfp8_available,
)
from transformer_engine.pytorch.utils import check_npu_version
from utils import (
dtype_tols,
quantization_tols,
assert_close,
make_reference_and_test_tensors,
to_cpu,
)
# When True, shrink the activation shape sweep (``_OUT_SHAPES``) to keep runs short.
_FAST_MODE = True
# Datatypes exercised by every dtype-parametrized test. ``_DTYPE_NAMES`` below
# references ``torch.bfloat16`` unconditionally, so bf16 is always available
# here and is included in both fast and full modes.
_dtypes = [torch.float32, torch.float16, torch.bfloat16]
# Activation names understood by both ``pytorch_activation_forward`` and
# ``get_te_activation_op``.
_ACTIVATION_TYPES = (
    "gelu",
    "geglu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "glu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
)
# Activation *output* shapes; gated activations double the input's last dim.
_OUT_SHAPES = ((37,), (2, 13), (32, 1, 32))
if _FAST_MODE:
    _OUT_SHAPES = ((32, 32),)
# Short dtype labels used in progress logs.
_DTYPE_NAMES = {
    torch.float32: "fp32",
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
}
def _log(msg: str, end: str = "\n") -> None:
"""Print progress log with flush for real-time output"""
print(msg, end=end, flush=True)
def _bounded_tensor(shape, *, dtype=torch.bfloat16):
data = torch.arange(
torch.tensor(shape).prod().item(),
dtype=torch.float32,
device="cpu",
).reshape(shape)
data = data.remainder(17).sub(8).div(16)
return data.to(device="npu", dtype=dtype)
def pytorch_activation_forward(x: torch.Tensor, activation: str) -> torch.Tensor:
    """Compute an activation's forward pass with plain PyTorch (test reference).

    Gated variants (``geglu``, ``qgeglu``, ``reglu``, ``sreglu``, ``swiglu``)
    split ``x`` along its last dimension into halves ``(a, b)`` and return
    ``act(a) * b``. ``glu`` swaps the two halves before calling ``F.glu``, so
    its result is ``b * sigmoid(a)``.

    Args:
        x: Input tensor; gated variants need an even-sized last dimension.
        activation: Activation name.

    Returns:
        The activation output.

    Raises:
        ValueError: If ``activation`` is not a known name.
    """

    def gelu_tanh(t):
        return F.gelu(t, approximate="tanh")

    def quick_gelu(t):
        return t * torch.sigmoid(1.702 * t)

    def squared_relu(t):
        return F.relu(t) ** 2

    def swapped_glu(t):
        # Reorder [a | b] -> [b | a] along the last dim so F.glu gates with a.
        shape = t.shape
        halves = t.reshape(*shape[:-1], 2, shape[-1] // 2)
        return F.glu(halves.flip(-2).reshape(shape))

    def gated(act):
        def apply(t):
            first, second = t.chunk(2, dim=-1)
            return act(first) * second

        return apply

    reference_fns = {
        "gelu": gelu_tanh,
        "geglu": gated(gelu_tanh),
        "qgelu": quick_gelu,
        "qgeglu": gated(quick_gelu),
        "relu": F.relu,
        "reglu": gated(F.relu),
        "sigmoid": F.sigmoid,
        "glu": swapped_glu,
        "srelu": squared_relu,
        "sreglu": gated(squared_relu),
        "silu": F.silu,
        "swiglu": gated(F.silu),
    }
    fn = reference_fns.get(activation)
    if fn is None:
        raise ValueError(f"Unexpected activation function ({activation})")
    return fn(x)
def get_te_activation_op(activation: str, **kwargs):
    """Instantiate the TE fusible-op activation registered under ``activation``.

    Any extra keyword arguments are forwarded to the op's constructor.

    Raises:
        ValueError: If no TE op is mapped to ``activation``.
    """
    op_classes = {
        "gelu": te_ops.GELU,
        "geglu": te_ops.GEGLU,
        "glu": te_ops.GLU,
        "qgelu": te_ops.QGELU,
        "qgeglu": te_ops.QGEGLU,
        "relu": te_ops.ReLU,
        "reglu": te_ops.ReGLU,
        "srelu": te_ops.SReLU,
        "sreglu": te_ops.SReGLU,
        "silu": te_ops.SiLU,
        "swiglu": te_ops.SwiGLU,
    }
    op_cls = op_classes.get(activation)
    if op_cls is None:
        raise ValueError(f"Unsupported activation: {activation}")
    return op_cls(**kwargs)
def _run_activation_test(
    activation: str,
    out_shape: Iterable[int],
    dtype: torch.dtype,
    device: torch.device,
) -> None:
    """Compare one TE activation op against the PyTorch reference (fwd + bwd).

    Args:
        activation: Name accepted by both ``pytorch_activation_forward`` and
            ``get_te_activation_op``.
        out_shape: Shape of the activation output. Gated activations receive an
            input whose last dimension is twice as large.
        dtype: Datatype of the tensors fed to the TE op.
        device: Device the TE op runs on.

    The output and the input gradient are both checked with ``assert_close``
    using ``dtype_tols(dtype)``.
    """
    # Materialize once: ``out_shape`` may be a one-shot iterator, which the log
    # line below would otherwise exhaust before the tensors are built.
    out_shape = tuple(out_shape)
    dt_name = _DTYPE_NAMES.get(dtype, str(dtype))
    _log(f" [{activation:>7s} | {dt_name:>4s} | shape={list(out_shape)}] ", end="")
    in_shape = list(out_shape)
    # Gated activations split the input's last dim in half, so it is 2x the output's.
    if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
        in_shape[-1] *= 2
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )
    y_ref = pytorch_activation_forward(x_ref, activation)
    y_ref.backward(dy_ref)
    act_op = get_te_activation_op(activation)
    y_test = act_op(x_test)
    y_test.backward(dy_test)
    tols = dtype_tols(dtype)
    assert_close(to_cpu(y_test), to_cpu(y_ref), **tols)
    assert_close(to_cpu(x_test.grad), to_cpu(x_ref.grad), **tols)
    _log("OK")
class TestBasicOps:
    """Tests for individual basic operations.

    Migrated from TransformerEngine test_fusible_ops.py::TestBasicOps.
    Covers BasicLinear (unquantized, FP8 current scaling and MXFP8), basic
    ops (RMSNorm), activation functions, SwiGLU, and backward
    activation+bias tests, following TE's convention of placing all
    non-fused operations in a single class.
    """

    @staticmethod
    def _device() -> torch.device:
        """Return the NPU device when one is available, otherwise the CPU."""
        return torch.device("npu" if torch_npu.npu.is_available() else "cpu")

    @staticmethod
    def _check_quantized_basic_linear(fp8_recipe, tols_name: str) -> None:
        """Run BasicLinear fwd + bwd under ``fp8_recipe`` and compare to dense math.

        Inputs, weight and output gradient are bf16 ``_bounded_tensor`` values
        on the NPU. The reference is computed in fp32 from the same unquantized
        values, with tolerances from ``quantization_tols(tols_name)``. The
        global FP8 state is reset even when an assertion fails, so later tests
        are not affected.
        """
        seq_len = 32
        in_features = 32
        out_features = 32
        full_input = _bounded_tensor((seq_len, in_features))
        full_weight = _bounded_tensor((out_features, in_features))
        full_grad_output = _bounded_tensor((seq_len, out_features))
        local_input = full_input.detach().clone()
        local_input.requires_grad_(True)
        layer = BasicLinear(
            in_features,
            out_features,
            device="npu",
            dtype=torch.bfloat16,
        )
        with torch.no_grad():
            layer.weight.copy_(full_weight)
        tols = quantization_tols(tols_name)
        try:
            with autocast(enabled=True, recipe=fp8_recipe):
                output = layer(local_input)
            expected_output = torch.nn.functional.linear(full_input.float(), full_weight.float())
            torch.testing.assert_close(output.float(), expected_output, **tols)
            output.backward(full_grad_output)
            expected_grad_input = torch.matmul(full_grad_output.float(), full_weight.float())
            expected_grad_weight = torch.matmul(full_grad_output.float().t(), full_input.float())
            torch.testing.assert_close(
                local_input.grad.float(),
                expected_grad_input,
                **tols,
            )
            torch.testing.assert_close(
                layer.weight.grad.float(),
                expected_grad_weight,
                **tols,
            )
        finally:
            FP8GlobalStateManager.reset()

    def test_basic_linear_forward_backward_without_tensor_parallel(self) -> None:
        """BasicLinear without TP should match dense linear math."""
        device = self._device()
        seq_len = 8
        in_features = 8
        out_features = 12
        input_ref, input_test = make_reference_and_test_tensors(
            (seq_len, in_features),
            test_dtype=torch.float32,
            test_device=device,
        )
        weight_ref, weight_test = make_reference_and_test_tensors(
            (out_features, in_features),
            test_dtype=torch.float32,
            test_device=device,
        )
        grad_output_ref, grad_output_test = make_reference_and_test_tensors(
            (seq_len, out_features),
            test_dtype=torch.float32,
            test_device=device,
            requires_grad=False,
        )
        layer = BasicLinear(
            in_features,
            out_features,
            device=device,
            dtype=torch.float32,
        )
        with torch.no_grad():
            layer.weight.copy_(weight_test)
        tols = dtype_tols(torch.float32)
        output = layer(input_test)
        expected_output = torch.nn.functional.linear(input_ref, weight_ref)
        assert_close(to_cpu(output), expected_output, **tols)
        output.backward(grad_output_test)
        # y = x W^T  =>  dx = dy W,  dW = dy^T x
        expected_grad_input = torch.matmul(grad_output_ref, weight_ref)
        expected_grad_weight = torch.matmul(grad_output_ref.t(), input_ref)
        assert_close(to_cpu(input_test.grad), expected_grad_input, **tols)
        assert_close(to_cpu(layer.weight.grad), expected_grad_weight, **tols)

    @pytest.mark.skipif(
        not check_npu_version(NPUVersion.A5), reason="BasicLinear FP8 test requires Atlas A5"
    )
    def test_basic_linear_float8_current_scaling_forward_backward(self) -> None:
        """BasicLinear FP8 current scaling should stay close to dense linear math."""
        fp8_available, reason = is_fp8_available(return_reason=True)
        if not fp8_available:
            pytest.skip(reason)
        self._check_quantized_basic_linear(
            recipe.Float8CurrentScaling(), "fp8_current_scaling"
        )

    @pytest.mark.skipif(
        not check_npu_version(NPUVersion.A5), reason="BasicLinear MXFP8 test requires Atlas A5"
    )
    def test_basic_linear_mxfp8_forward_backward(self) -> None:
        """BasicLinear MXFP8 should stay close to dense linear math."""
        mxfp8_available, reason = is_mxfp8_available(return_reason=True)
        if not mxfp8_available:
            pytest.skip(reason)
        self._check_quantized_basic_linear(
            recipe.MXFP8BlockScaling(), "mxfp8_block_scaling"
        )

    @pytest.mark.parametrize("weight_shape", ((19,), (64,)))
    @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("zero_centered_gamma", (False, True))
    def test_rmsnorm(
        self,
        *,
        weight_shape: Iterable[int],
        in_shape: Iterable[int],
        dtype: torch.dtype,
        eps: float = 0.3,
        zero_centered_gamma: bool,
    ) -> None:
        """RMSNorm forward + backward test"""
        device = self._device()
        # The trailing -1 placeholder in ``in_shape`` is replaced by the weight shape.
        in_shape = list(in_shape)[:-1] + list(weight_shape)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        w_ref, w_test = make_reference_and_test_tensors(
            weight_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        # Normalize over the trailing dims covered by the weight.
        inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
        var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
        if zero_centered_gamma:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
        else:
            y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
        y_ref.backward(dy_ref)
        op = te_ops.RMSNorm(
            weight_shape,
            eps=eps,
            device=device,
            dtype=dtype,
            zero_centered_gamma=zero_centered_gamma,
        )
        with torch.no_grad():
            op.weight.copy_(w_test)
            del w_test
        y_test = op(x_test)
        y_test.backward(dy_test)
        tols = dtype_tols(dtype)
        assert_close(to_cpu(y_test), to_cpu(y_ref), **tols)
        assert_close(to_cpu(x_test.grad), to_cpu(x_ref.grad), **tols)
        assert_close(to_cpu(op.weight.grad), to_cpu(w_ref.grad), **tols)

    @pytest.mark.parametrize("activation", _ACTIVATION_TYPES)
    @pytest.mark.parametrize("out_shape", _OUT_SHAPES)
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
    ) -> None:
        """Activation function forward + backward test"""
        _run_activation_test(activation, out_shape, dtype, self._device())

    @pytest.mark.parametrize("dtype", _dtypes)
    def test_activation_shapes(self, dtype: torch.dtype) -> None:
        """Forward-only GELU check across 1D to 4D input shapes."""
        device = self._device()
        dt_name = _DTYPE_NAMES.get(dtype, str(dtype))
        test_shapes = [
            (128,),
            (32, 64),
            (16, 32, 64),
            (8, 16, 32, 64),
        ]
        for shape in test_shapes:
            _log(f" [GELU shape test | {dt_name:>4s} | shape={list(shape)}] ", end="")
            x_ref, x_test = make_reference_and_test_tensors(
                shape,
                test_dtype=dtype,
                test_device=device,
            )
            y_ref = F.gelu(x_ref, approximate="tanh")
            y_test = te_ops.GELU()(x_test)
            assert_close(to_cpu(y_test), to_cpu(y_ref), **dtype_tols(dtype))
            _log("OK")

    @pytest.mark.parametrize("dtype", _dtypes)
    def test_swiglu(self, dtype: torch.dtype) -> None:
        """SwiGLU forward + backward test"""
        device = self._device()
        dt_name = _DTYPE_NAMES.get(dtype, str(dtype))
        _log(f" [SwiGLU | {dt_name:>4s}] ", end="")
        out_shape = (32, 32)
        in_shape = (32, 64)
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        # Independent (not shared-helper) reference: silu(first half) * second half.
        x1, x2 = x_ref.chunk(2, dim=-1)
        y_ref = F.silu(x1) * x2
        y_ref.backward(dy_ref)
        y_test = te_ops.SwiGLU()(x_test)
        y_test.backward(dy_test)
        tols = dtype_tols(dtype)
        assert_close(to_cpu(y_test), to_cpu(y_ref), **tols)
        assert_close(to_cpu(x_test.grad), to_cpu(x_ref.grad), **tols)
        _log("OK")

    @pytest.mark.parametrize("activation", ("relu", "gelu"))
    @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_backward_activation_bias(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
    ) -> None:
        """Backward activation + bias test

        Since te_ops.Bias is not available on NPU, we test:
        y = Activation(x + Bias) using PyTorch for bias and TE for activation.
        """
        device = self._device()
        dt_name = _DTYPE_NAMES.get(dtype, str(dtype))
        _log(f" [{activation:>7s}+bias | {dt_name:>4s} | shape={list(out_shape)}] ", end="")
        hidden_size = out_shape[-1]
        in_shape = list(out_shape)
        # Bias varies only along the last (hidden) dim and broadcasts elsewhere.
        bias_shape = [1] * (len(in_shape) - 1) + [hidden_size]
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
        )
        b_ref, b_test = make_reference_and_test_tensors(
            (hidden_size,),
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )
        # Reference and TE op come from the same shared helpers as test_activation,
        # so an unsupported name fails loudly instead of mismatching silently.
        y_ref = pytorch_activation_forward(x_ref + b_ref.reshape(bias_shape), activation)
        y_ref.backward(dy_ref)
        z_test = x_test + b_test.reshape(bias_shape)
        y_test = get_te_activation_op(activation)(z_test)
        y_test.backward(dy_test)
        tols = dtype_tols(dtype)
        assert_close(to_cpu(y_test), to_cpu(y_ref), **tols)
        assert_close(to_cpu(x_test.grad), to_cpu(x_ref.grad), **tols)
        assert_close(to_cpu(b_test.grad), to_cpu(b_ref.grad), **tols)
        _log("OK")
if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to shells/CI when
    # this file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))