"""Test for sequence parallel pass.
The pass rewrites the pattern   all_reduce → rms_norm
into the sequence                reduce_scatter → rms_norm(local) → all_gather
This matches the expected sequence parallel communication rewrite on NPU.
"""
import unittest
from dataclasses import asdict
import pytest
import torch
from parameterized import parameterized
from tensor_cast import config
from tensor_cast.core.input_generator import generate_inputs
from tensor_cast.core.model_runner import ModelRunner, ModelRunnerMetrics
from tensor_cast.core.quantization.datatypes import QuantizeLinearAction
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.model_config import WordEmbeddingTPMode
@pytest.mark.nightly
class SequenceParallelPassTestCase(unittest.TestCase):
    """End-to-end checks for the sequence parallel compilation pass.

    Each case runs a single-layer model through ``ModelRunner`` with the pass
    enabled and inspects the reported ``table_result`` for the expected
    collective ops and the expected rms_norm input shape.
    """

    def setUp(self):
        """Reset torch.compile state and remember the global pass flag."""
        torch.compiler.reset()
        passes = config.compilation.passes
        self._orig_enable_sequence_parallel = passes.enable_sequence_parallel

    def tearDown(self):
        """Restore the global pass flag to the value captured in ``setUp``."""
        passes = config.compilation.passes
        passes.enable_sequence_parallel = self._orig_enable_sequence_parallel

    # Same (tp_size, expected_local_seq) pair, with repetition first left on
    # (disable_repetition=False) and then disabled (disable_repetition=True).
    @parameterized.expand([(2, 64, flag) for flag in (False, True)])
    def test_sp_reduces_rms_norm_seq_dim(self, tp_size: int, expected_local_seq: int, disable_repetition: bool):
        """Check that the pass shards rms_norm's sequence dim and swaps the collectives.

        Args:
            tp_size: Value used for both ``world_size`` and ``tp_size``.
            expected_local_seq: Sequence length expected in the rms_norm
                input shape entry of the result table.
            disable_repetition: Forwarded unchanged to ``UserInputConfig``.
        """
        # The flag is process-global; tearDown restores the saved value.
        config.compilation.passes.enable_sequence_parallel = True

        sp_request = UserInputConfig(
            model_id="Qwen/Qwen3-32B",
            num_queries=1,
            query_len=128,
            context_length=0,
            do_compile=True,
            dump_input_shapes=True,
            enable_sequence_parallel=True,
            disable_repetition=disable_repetition,
            num_mtp_tokens=0,
            num_hidden_layers_override=1,
            world_size=tp_size,
            tp_size=tp_size,
            word_embedding_tp=WordEmbeddingTPMode.row.value,
            quantize_linear_action=QuantizeLinearAction.DISABLED,
        )
        outcome = ModelRunner(sp_request).run_inference(generate_inputs_func=generate_inputs)
        # Metrics objects are flattened to a dict so both result forms are
        # read through the same key lookup.
        report = asdict(outcome) if isinstance(outcome, ModelRunnerMetrics) else outcome
        op_table = report["table_result"]

        self.assertIn("tensor_cast.rms_norm.default", op_table)
        sharded_norm_shapes = f"[1, {expected_local_seq}, 5120], [5120]"
        self.assertIn(
            sharded_norm_shapes,
            op_table,
            "Sequence parallel should shard the entry rms_norm sequence dimension",
        )

        if tp_size <= 1:
            # Without tensor parallelism there is nothing to scatter, so the
            # plain all_reduce is expected to remain.
            self.assertIn(
                "tensor_cast.all_reduce.default",
                op_table,
                "Non-sequence-parallel mode should have all_reduce",
            )
            return

        required_collectives = (
            (
                "tensor_cast.reduce_scatter.default",
                "Sequence parallel mode should have reduce_scatter",
            ),
            (
                "tensor_cast.all_gather.default",
                "Sequence parallel mode should have all_gather",
            ),
        )
        for op_name, failure_msg in required_collectives:
            self.assertIn(op_name, op_table, failure_msg)

        self.assertNotIn(
            "tensor_cast.all_reduce.default",
            op_table,
            "Sequence parallel mode should replace all_reduce",
        )
# Allow running this file directly (outside pytest) via the unittest runner.
if __name__ == "__main__":
    unittest.main()