"""Unit tests for sequential.py."""
import pytest
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from transformer_engine.pytorch.ops.sequential import Sequential
from transformer_engine.pytorch.ops.op import (
OperationContext,
BasicOperation,
)
from transformer_engine.pytorch.ops.fuser import OperationFuser
class SimpleLinearOp(BasicOperation):
    """A simple linear operation for testing: ``y = x @ weight.T + bias``.

    Mirrors ``torch.nn.functional.linear``: ``weight`` has shape
    ``(out_features, in_features)`` and ``bias`` has shape ``(out_features,)``.
    The input may carry any number of leading (batch) dimensions; the last
    dimension must be ``in_features``.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def num_quantizers(self, mode: str) -> int:
        """No quantization for this simple op."""
        return 0

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer=None,
        next_op_input_quantizer=None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass: ``y = x @ weight.T + bias``.

        The quantizer arguments are accepted and ignored. The input is saved
        on ``ctx`` because ``op_backward`` needs it for the weight gradient.
        """
        ctx.save_for_backward(input_)
        return torch.nn.functional.linear(input_, self.weight, self.bias)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ):
        """Backward pass.

        Returns:
            ``(grad_input, [grad_weight, grad_bias])``. The parameter
            gradients are listed in the same order the parameters are
            assigned in ``__init__`` (weight, then bias).
        """
        input_ = ctx.saved_tensors[0]
        # (..., out) @ (out, in) -> (..., in): already correct for N-D inputs.
        grad_input = grad_output @ self.weight
        # F.linear accepts N-D inputs, so collapse every leading dimension
        # before reducing. Otherwise `.T` on a >2-D tensor reverses all dims,
        # and sum(dim=0) would only reduce the first batch axis.
        grad_output_2d = grad_output.reshape(-1, self.out_features)
        input_2d = input_.reshape(-1, self.in_features)
        grad_weight = grad_output_2d.T @ input_2d
        grad_bias = grad_output_2d.sum(dim=0)
        return grad_input, [grad_weight, grad_bias]
class SimpleReLIOp(BasicOperation):
    """A simple ReLU operation for testing: ``y = max(x, 0)``.

    NOTE: "ReLI" looks like a typo for "ReLU"; the name is kept unchanged
    because the test classes in this module refer to it.
    """

    def num_quantizers(self, mode: str) -> int:
        """No quantization for this simple op."""
        return 0

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer=None,
        next_op_input_quantizer=None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass: ``y = ReLU(x)``.

        The quantizer arguments are accepted and ignored. The input is saved
        on ``ctx`` so the backward pass can rebuild the positive mask.
        """
        ctx.save_for_backward(input_)
        return torch.nn.functional.relu(input_)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ):
        """Backward pass: pass the gradient through where the input was positive.

        Returns:
            ``(grad_input, [])``; this op has no parameters.
        """
        input_ = ctx.saved_tensors[0]
        # Build the mask in grad_output's dtype. A float32 mask (`.float()`)
        # would promote fp16/bf16 gradients to float32.
        positive_mask = (input_ > 0).to(dtype=grad_output.dtype)
        grad_input = grad_output * positive_mask
        return grad_input, []
class TestSequentialBasic:
    """Construction of a Sequential from positional modules or from a dict."""

    @staticmethod
    def _assert_holds(container, expected):
        """Assert that ``container`` holds exactly the objects in ``expected``, in order."""
        assert len(container) == len(expected)
        for position, module in enumerate(expected):
            assert container[position] is module

    def test_empty_sequential(self):
        """A Sequential built with no arguments has no children."""
        empty = Sequential()
        assert len(empty) == 0
        assert list(empty) == []

    def test_single_module(self):
        """A lone module ends up at index 0."""
        fc = torch.nn.Linear(10, 5)
        self._assert_holds(Sequential(fc), [fc])

    def test_multiple_modules(self):
        """Positional modules keep the order they were passed in."""
        fc_in = torch.nn.Linear(10, 5)
        act = torch.nn.ReLU()
        fc_out = torch.nn.Linear(5, 3)
        self._assert_holds(Sequential(fc_in, act, fc_out), [fc_in, act, fc_out])

    def test_sequential_with_dict(self):
        """Dict keys become attribute names for the stored children."""
        fc = torch.nn.Linear(10, 5)
        act = torch.nn.ReLU()
        container = Sequential({'linear': fc, 'relu': act})
        assert len(container) == 2
        assert container.linear is fc
        assert container.relu is act
class TestSequentialOperations:
    """List-like mutation and concatenation of Sequential containers."""

    @staticmethod
    def _assert_holds(container, expected):
        """Assert that ``container`` holds exactly the objects in ``expected``, in order."""
        assert len(container) == len(expected)
        for position, module in enumerate(expected):
            assert container[position] is module

    def test_append(self):
        """append() adds at the end and returns the container itself."""
        container = Sequential()
        fc = torch.nn.Linear(10, 5)
        assert container.append(fc) is container
        self._assert_holds(container, [fc])

    def test_extend(self):
        """extend() adds every module in order and returns the container itself."""
        container = Sequential()
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        assert container.extend([fc_a, fc_b]) is container
        self._assert_holds(container, [fc_a, fc_b])

    def test_insert(self):
        """insert() places a module at the given index and returns the container."""
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        act = torch.nn.ReLU()
        container = Sequential(fc_a, fc_b)
        assert container.insert(1, act) is container
        self._assert_holds(container, [fc_a, act, fc_b])

    def test_pop(self):
        """pop() removes the module at an index and hands it back."""
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        container = Sequential(fc_a, fc_b)
        assert container.pop(0) is fc_a
        self._assert_holds(container, [fc_b])

    def test_getitem_slice(self):
        """Slicing yields a new Sequential holding the selected modules."""
        fc_a = torch.nn.Linear(10, 5)
        act = torch.nn.ReLU()
        fc_b = torch.nn.Linear(5, 3)
        container = Sequential(fc_a, act, fc_b)
        head = container[0:2]
        assert isinstance(head, Sequential)
        self._assert_holds(head, [fc_a, act])

    def test_setitem(self):
        """Item assignment replaces the module at that index."""
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        act = torch.nn.ReLU()
        container = Sequential(fc_a, fc_b)
        container[1] = act
        assert container[1] is act

    def test_delitem(self):
        """del removes the module at that index and shifts the rest down."""
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        container = Sequential(fc_a, fc_b)
        del container[0]
        self._assert_holds(container, [fc_b])

    def test_iadd(self):
        """+= appends the right-hand container's modules in place."""
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        left = Sequential(fc_a)
        right = Sequential(fc_b)
        left += right
        self._assert_holds(left, [fc_a, fc_b])

    def test_add(self):
        """+ builds a new container and leaves both operands untouched."""
        fc_a = torch.nn.Linear(10, 5)
        fc_b = torch.nn.Linear(5, 3)
        left = Sequential(fc_a)
        right = Sequential(fc_b)
        combined = left + right
        self._assert_holds(combined, [fc_a, fc_b])
        # Neither operand may be mutated by the binary operator.
        assert len(left) == 1
        assert len(right) == 1
class TestSequentialForward:
    """Test Sequential forward pass (and backward for fusible ops)."""

    def test_forward_with_torch_modules(self):
        """Test forward pass with standard PyTorch modules."""
        linear = torch.nn.Linear(10, 5)
        relu = torch.nn.ReLU()
        seq = Sequential(linear, relu)
        input_tensor = torch.randn(3, 10)
        output = seq(input_tensor)
        assert output.shape == (3, 5)
        expected = relu(linear(input_tensor))
        assert torch.allclose(output, expected, atol=1e-6)

    def test_forward_with_fusible_operations(self):
        """Test forward pass with FusibleOperation instances."""
        linear_op = SimpleLinearOp(in_features=10, out_features=5)
        relu_op = SimpleReLIOp()
        seq = Sequential(linear_op, relu_op)
        input_tensor = torch.randn(3, 10)
        output = seq(input_tensor)
        assert output.shape == (3, 5)
        expected_linear = torch.nn.functional.linear(input_tensor, linear_op.weight, linear_op.bias)
        expected = torch.nn.functional.relu(expected_linear)
        assert torch.allclose(output, expected, atol=1e-6)

    def test_forward_with_mixed_modules(self):
        """Test forward pass with mixed PyTorch and FusibleOperation modules."""
        linear_op = SimpleLinearOp(in_features=10, out_features=5)
        relu = torch.nn.ReLU()
        seq = Sequential(linear_op, relu)
        input_tensor = torch.randn(3, 10)
        output = seq(input_tensor)
        assert output.shape == (3, 5)
        expected_linear = torch.nn.functional.linear(input_tensor, linear_op.weight, linear_op.bias)
        expected = torch.nn.functional.relu(expected_linear)
        assert torch.allclose(output, expected, atol=1e-6)

    def test_forward_backward(self):
        """Test forward and backward pass against a plain-autograd reference."""
        linear_op = SimpleLinearOp(in_features=10, out_features=5)
        relu_op = SimpleReLIOp()
        seq = Sequential(linear_op, relu_op)
        input_tensor = torch.randn(3, 10, requires_grad=True)
        output = seq(input_tensor)
        grad_output = torch.randn_like(output)
        output.backward(grad_output)

        assert input_tensor.grad is not None
        assert input_tensor.grad.shape == input_tensor.shape
        assert linear_op.weight.grad is not None
        assert linear_op.weight.grad.shape == linear_op.weight.shape
        assert linear_op.bias.grad is not None
        assert linear_op.bias.grad.shape == linear_op.bias.shape

        # Shape checks alone would not catch a wrong custom backward. Recompute
        # the same graph with torch autograd on detached copies and compare values.
        ref_input = input_tensor.detach().clone().requires_grad_(True)
        ref_weight = linear_op.weight.detach().clone().requires_grad_(True)
        ref_bias = linear_op.bias.detach().clone().requires_grad_(True)
        ref_output = torch.nn.functional.relu(
            torch.nn.functional.linear(ref_input, ref_weight, ref_bias)
        )
        ref_output.backward(grad_output)

        assert torch.allclose(input_tensor.grad, ref_input.grad, atol=1e-6)
        assert torch.allclose(linear_op.weight.grad, ref_weight.grad, atol=1e-6)
        assert torch.allclose(linear_op.bias.grad, ref_bias.grad, atol=1e-6)
class TestSequentialModuleGrouping:
    """Test Sequential module grouping for fusion."""

    def test_module_groups_creation(self):
        """Adjacent fusible ops share one OperationFuser; torch modules stand alone."""
        linear_op1 = SimpleLinearOp(in_features=10, out_features=5)
        linear_op2 = SimpleLinearOp(in_features=5, out_features=3)
        relu = torch.nn.ReLU()
        seq = Sequential(linear_op1, linear_op2, relu)
        input_tensor = torch.randn(3, 10)
        output = seq(input_tensor)
        # Grouping must not change the end-to-end result shape: 10 -> 5 -> 3.
        assert output.shape == (3, 3)
        # The private group list is checked here because it is what this test targets.
        assert seq._module_groups is not None
        assert len(seq._module_groups) == 2
        assert isinstance(seq._module_groups[0], OperationFuser)
        assert seq._module_groups[1] is relu

    def test_module_groups_with_intervening_non_fusible(self):
        """A non-fusible module between fusible ops splits them into separate fusers."""
        linear_op1 = SimpleLinearOp(in_features=10, out_features=5)
        relu = torch.nn.ReLU()
        linear_op2 = SimpleLinearOp(in_features=5, out_features=3)
        seq = Sequential(linear_op1, relu, linear_op2)
        input_tensor = torch.randn(3, 10)
        output = seq(input_tensor)
        assert output.shape == (3, 3)
        assert seq._module_groups is not None
        assert len(seq._module_groups) == 3
        assert isinstance(seq._module_groups[0], OperationFuser)
        assert seq._module_groups[1] is relu
        assert isinstance(seq._module_groups[2], OperationFuser)
class TestSequentialEdgeCases:
    """Test edge cases and error handling."""

    def test_negative_index(self):
        """Test negative indexing."""
        linear1 = torch.nn.Linear(10, 5)
        linear2 = torch.nn.Linear(5, 3)
        seq = Sequential(linear1, linear2)
        assert seq[-1] is linear2
        assert seq[-2] is linear1

    def test_index_out_of_range(self):
        """Test index out of range error."""
        linear = torch.nn.Linear(10, 5)
        seq = Sequential(linear)
        with pytest.raises(IndexError):
            _ = seq[5]

    def test_iteration(self):
        """Test iteration over modules."""
        linear1 = torch.nn.Linear(10, 5)
        linear2 = torch.nn.Linear(5, 3)
        seq = Sequential(linear1, linear2)
        modules = list(seq)
        assert len(modules) == 2
        assert modules[0] is linear1
        assert modules[1] is linear2

    def test_len(self):
        """Test len operation."""
        linear1 = torch.nn.Linear(10, 5)
        linear2 = torch.nn.Linear(5, 3)
        seq = Sequential(linear1, linear2)
        assert len(seq) == 2

    def test_module_invalidation(self):
        """Module groups are dropped on modification and rebuilt on the next forward."""
        linear_op = SimpleLinearOp(in_features=10, out_features=5)
        seq = Sequential(linear_op)
        input_tensor = torch.randn(3, 10)
        output = seq(input_tensor)
        assert output.shape == (3, 5)
        assert seq._module_groups is not None

        relu = torch.nn.ReLU()
        seq.append(relu)
        assert seq._module_groups is None

        # The next forward must rebuild groups that include the appended module.
        output = seq(input_tensor)
        assert output.shape == (3, 5)
        assert seq._module_groups is not None
        assert len(seq._module_groups) == 2
        assert isinstance(seq._module_groups[0], OperationFuser)
        assert seq._module_groups[1] is relu
if __name__ == "__main__":
    # Propagate pytest's exit code so direct invocation (e.g. from CI) fails
    # when any test fails, instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))