"""Unit tests for op.py and fuser.py migration."""
import pytest
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from transformer_engine.pytorch.ops.op import (
OperationContext,
BasicOperation,
FusedOperation,
)
from transformer_engine.pytorch.ops.fuser import (
OperationFuser,
_is_graph_capturing,
_split_tuple,
)
class SimpleLinearOp(BasicOperation):
    """Affine test fixture computing ``y = x @ weight.T + bias``.

    ``weight`` has shape ``(out_features, in_features)`` and ``bias`` has
    shape ``(out_features,)``; both are drawn from ``torch.randn``.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def num_quantizers(self, mode: str) -> int:
        """Return 0: this op uses no quantizers, whatever ``mode`` is."""
        return 0

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer=None,
        next_op_input_quantizer=None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the affine map and stash the input for the backward pass.

        The quantizer arguments and any extra keyword arguments are accepted
        but not used by this op.
        """
        ctx.save_for_backward(input_)
        return torch.nn.functional.linear(input_, self.weight, self.bias)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ):
        """Compute gradients with respect to the input, weight and bias.

        Returns:
            A pair ``(grad_input, [grad_weight, grad_bias])``.
        """
        input_ = ctx.saved_tensors[0]
        grad_input = grad_output @ self.weight

        # The forward accepts any number of leading dimensions, so collapse
        # them into one batch axis before reducing. For 2-D tensors the
        # reshape is a no-op and the result matches the plain
        # ``grad_output.T @ input_`` / ``sum(dim=0)`` formulas.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input_.reshape(-1, input_.shape[-1])
        grad_weight = grad_output_2d.T @ input_2d
        grad_bias = grad_output_2d.sum(dim=0)
        return grad_input, [grad_weight, grad_bias]
class SimpleReLIOp(BasicOperation):
    """Parameter-free ReLU test fixture: ``y = max(x, 0)`` element-wise.

    The class name spelling is kept as-is because the tests in this file
    refer to it by this name.
    """

    def __init__(self):
        super().__init__()

    def num_quantizers(self, mode: str) -> int:
        """Return 0: this op uses no quantizers, whatever ``mode`` is."""
        return 0

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer=None,
        next_op_input_quantizer=None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply ReLU and stash the input for the backward pass.

        The quantizer arguments and any extra keyword arguments are accepted
        but not used by this op.
        """
        ctx.save_for_backward(input_)
        return torch.nn.functional.relu(input_)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ):
        """Pass the gradient through where the saved input was positive.

        Returns:
            A pair ``(grad_input, [])``; the list is empty because the op
            has no parameters.
        """
        input_ = ctx.saved_tensors[0]
        # Cast the mask to the gradient's own dtype. A hard-coded float32
        # mask would promote half/bfloat16 gradients to float32 and produce
        # a grad_input whose dtype no longer matches the input.
        positive_mask = (input_ > 0).to(grad_output.dtype)
        grad_input = grad_output * positive_mask
        return grad_input, []
class TestSimpleOperation:
    """Run each toy op on its own through an OperationFuser, forward and backward."""

    def test_simple_linear_forward(self):
        """A lone SimpleLinearOp reproduces torch's functional linear."""
        linear = SimpleLinearOp(in_features=10, out_features=5)
        x = torch.randn(3, 10)

        output = OperationFuser([linear])(x)

        assert output.shape == (3, 5), f"Expected shape (3, 5), got {output.shape}"
        reference = torch.nn.functional.linear(x, linear.weight, linear.bias)
        assert torch.allclose(output, reference, atol=1e-6), "Forward pass result mismatch"

    def test_simple_linear_backward(self):
        """Backward through a lone SimpleLinearOp fills every gradient with the right shape."""
        linear = SimpleLinearOp(in_features=10, out_features=5)
        x = torch.randn(3, 10, requires_grad=True)

        output = OperationFuser([linear])(x)
        upstream = torch.randn_like(output)
        output.backward(upstream)

        # Gradient with respect to the input.
        assert x.grad is not None, "Input gradient is None"
        assert x.grad.shape == x.shape, "Input gradient shape mismatch"

        # Gradients with respect to the op's parameters.
        weight, bias = linear.weight, linear.bias
        assert weight.grad is not None, "Weight gradient is None"
        assert weight.grad.shape == weight.shape, "Weight gradient shape mismatch"
        assert bias.grad is not None, "Bias gradient is None"
        assert bias.grad.shape == bias.shape, "Bias gradient shape mismatch"

    def test_simple_relu_forward(self):
        """A lone SimpleReLIOp reproduces torch's functional relu."""
        relu = SimpleReLIOp()
        x = torch.randn(3, 10)

        output = OperationFuser([relu])(x)

        assert output.shape == x.shape, "Output shape mismatch"
        reference = torch.nn.functional.relu(x)
        assert torch.allclose(output, reference, atol=1e-6), "Forward pass result mismatch"

    def test_simple_relu_backward(self):
        """Backward through a lone SimpleReLIOp yields an input gradient of matching shape."""
        relu = SimpleReLIOp()
        x = torch.randn(3, 10, requires_grad=True)

        output = OperationFuser([relu])(x)
        upstream = torch.randn_like(output)
        output.backward(upstream)

        assert x.grad is not None, "Input gradient is None"
        assert x.grad.shape == x.shape, "Input gradient shape mismatch"
class TestOperationFusion:
    """Chain Linear and ReLU together and check the combined pipeline."""

    def test_linear_relu_fusion_forward(self):
        """Linear followed by ReLU in one fuser matches the same two torch calls."""
        linear = SimpleLinearOp(in_features=10, out_features=5)
        relu = SimpleReLIOp()
        x = torch.randn(3, 10)

        pipeline = OperationFuser([linear, relu])
        output = pipeline(x)

        assert output.shape == (3, 5), f"Expected shape (3, 5), got {output.shape}"
        reference = torch.nn.functional.relu(
            torch.nn.functional.linear(x, linear.weight, linear.bias)
        )
        assert torch.allclose(output, reference, atol=1e-6), "Fused forward pass result mismatch"

    def test_linear_relu_fusion_backward(self):
        """Backward through Linear + ReLU fills the input and weight gradients."""
        linear = SimpleLinearOp(in_features=10, out_features=5)
        relu = SimpleReLIOp()
        x = torch.randn(3, 10, requires_grad=True)

        pipeline = OperationFuser([linear, relu])
        output = pipeline(x)
        upstream = torch.randn_like(output)
        output.backward(upstream)

        assert x.grad is not None, "Input gradient is None"
        assert x.grad.shape == x.shape, "Input gradient shape mismatch"
        weight = linear.weight
        assert weight.grad is not None, "Weight gradient is None"
        assert weight.grad.shape == weight.shape, "Weight gradient shape mismatch"

    def test_fused_operation_class(self):
        """FusedOperation exposes its flag and basic ops, and can be called directly."""
        linear = SimpleLinearOp(in_features=10, out_features=5)
        relu = SimpleReLIOp()
        fused = FusedOperation([linear, relu])

        # Check the metadata before running anything through the op.
        assert fused.is_fused_op is True, "is_fused_op should be True"
        assert len(fused.basic_ops) == 2, "Should have 2 basic operations"

        x = torch.randn(3, 10)
        output = fused(x)
        assert output.shape == (3, 5), f"Expected shape (3, 5), got {output.shape}"
class TestHelperFunctions:
    """Checks for the private helpers imported from the fuser module."""

    def test_is_graph_capturing(self):
        """_is_graph_capturing reports False when called from this test."""
        assert _is_graph_capturing() is False

    def test_split_tuple(self):
        """_split_tuple cuts a tuple at index 0, in the middle, and at its full length."""
        values = (1, 2, 3, 4, 5)
        # (split index, expected left part, expected right part)
        cases = (
            (0, (), (1, 2, 3, 4, 5)),
            (2, (1, 2), (3, 4, 5)),
            (5, (1, 2, 3, 4, 5), ()),
        )
        for index, want_left, want_right in cases:
            left, right = _split_tuple(values, index)
            assert left == want_left
            assert right == want_right
class TestOperationContext:
    """Checks for OperationContext's tensor-saving behaviour."""

    def test_save_for_backward(self):
        """save_for_backward records the given tensors, in order, on ``to_save``."""
        context = OperationContext()
        first = torch.randn(3, 4)
        second = torch.randn(5, 6)

        context.save_for_backward(first, second)

        assert context.to_save is not None
        assert len(context.to_save) == 2
        # Identity, not equality: the very same tensor objects must be kept.
        assert context.to_save[0] is first
        assert context.to_save[1] is second
if __name__ == "__main__":
    # Pass pytest's return code on as the process exit status, so that
    # running this file directly fails (non-zero) when a test fails instead
    # of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))