"""Unit tests for SequenceParallelPass ordered rewrites.
P1: all_reduce -> [begin?] -> rms_norm | add_rms_norm
P2: all_reduce -> add_rms_norm2, selective all_gather on getitems
P3: getitem[1] + all_reduce -> add -> [end? -> copy*] -> norm
Edges: world_size=1, non-shardable candidates, and mixed-shape graphs
"""
import operator
import unittest
import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp
from tensor_cast import config
import tensor_cast.ops
from tensor_cast.compilation.passes.sequence_parallel_pass import (
Pattern1Rewriter,
Pattern2Rewriter,
Pattern3Rewriter,
SequenceParallelPass,
)
INPUT_SHAPE = (1, 128, 4096)
LOCAL_INPUT_SHAPE = (1, 64, 4096)
WEIGHT_SHAPE = (4096,)
RANK_GROUP = [0, 1]
BAD_RANK_GROUP = [0, 1, 2]
EPS = 1e-5
def _meta_tensor(shape):
return torch.empty(shape, device="meta")
def _make_meta_value(tensor_meta):
if hasattr(tensor_meta, "shape"):
dtype = getattr(tensor_meta, "dtype", torch.float32)
return torch.empty(tuple(tensor_meta.shape), dtype=dtype, device="meta")
if isinstance(tensor_meta, (tuple, list)):
if not tensor_meta:
return None
tensor_meta = tensor_meta[0]
return _make_meta_value(tensor_meta)
return None
def _populate_val_metadata(graph_module):
for node in graph_module.graph.nodes:
if "val" in node.meta:
continue
meta_value = _make_meta_value(node.meta.get("tensor_meta"))
if meta_value is not None:
node.meta["val"] = meta_value
return graph_module
def _trace_graph_module(module, *meta_inputs):
graph_module = fx.symbolic_trace(module)
ShapeProp(graph_module).propagate(*meta_inputs)
return _populate_val_metadata(graph_module)
def _trace_program(module, *input_shapes):
return _trace_graph_module(module, *(_meta_tensor(shape) for shape in input_shapes))
def _run_pass(graph_module):
saved = config.compilation.passes.enable_sequence_parallel
config.compilation.passes.enable_sequence_parallel = True
try:
graph_module = SequenceParallelPass()(graph_module)
finally:
config.compilation.passes.enable_sequence_parallel = saved
return graph_module
def _count_calls(graph_module, target):
return sum(1 for node in graph_module.graph.nodes if node.op == "call_function" and node.target is target)
def _find_calls(graph_module, target):
return [node for node in graph_module.graph.nodes if node.op == "call_function" and node.target is target]
def _has_user(node, target):
return any(user.op == "call_function" and user.target is target for user in node.users)
def _find_getitems(graph_module, index):
return [
node
for node in graph_module.graph.nodes
if node.op == "call_function" and node.target is operator.getitem and node.args[1] == index
]
def _mark_add_rms_norm2_sp_local(graph_module):
[norm2] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default)
norm2.meta["tensor_cast_sp_local"] = True
return norm2
def _get_node_index(graph_module, node):
for index, current_node in enumerate(graph_module.graph.nodes):
if current_node is node:
return index
raise AssertionError("node not found in graph")
def _trace_p1_graph(
*,
marker=False,
add_rms_norm=False,
rank_group=RANK_GROUP,
input_shape=INPUT_SHAPE,
residual_shape=INPUT_SHAPE,
):
class Program(torch.nn.Module):
def forward(self, x, residual, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, rank_group)
current = torch.ops.tensor_cast._internal_mark_region_begin.default(all_reduce, 0) if marker else all_reduce
if add_rms_norm:
return torch.ops.tensor_cast.add_rms_norm.default(current, residual, w, EPS)
return torch.ops.tensor_cast.rms_norm.default(current, w, EPS)
return _trace_program(Program(), input_shape, residual_shape, WEIGHT_SHAPE)
def _trace_p2_p3_graph(
*,
begin_on_res=False,
with_end=True,
with_view=False,
num_copies=0,
tail_rank_group=RANK_GROUP,
tail_input_shape=INPUT_SHAPE,
):
class Program(torch.nn.Module):
def forward(self, x, res, x2, w, w2):
all_reduce_1 = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
res_input = torch.ops.tensor_cast._internal_mark_region_begin.default(res, 0) if begin_on_res else res
getitem0, getitem1 = torch.ops.tensor_cast.add_rms_norm2.default(
all_reduce_1,
res_input,
w,
EPS,
)
all_reduce_2 = torch.ops.tensor_cast.all_reduce.default(x2, 0, tail_rank_group)
comm_output = torch.ops.aten.reshape.default(all_reduce_2, list(INPUT_SHAPE)) if with_view else all_reduce_2
current = torch.ops.aten.add.Tensor(getitem1, comm_output)
if with_end:
current = torch.ops.tensor_cast._internal_mark_region_end.default(current, 0)
for index in range(num_copies):
current = torch.ops.tensor_cast._internal_copy_region.default(current, index)
tail = torch.ops.tensor_cast.rms_norm.default(current, w2, EPS)
return getitem0, tail
return _trace_program(
Program(),
INPUT_SHAPE,
LOCAL_INPUT_SHAPE,
tail_input_shape,
WEIGHT_SHAPE,
WEIGHT_SHAPE,
)
class Pattern1RewriterTestCase(unittest.TestCase):
def test_pattern1_rewriter_apply(self):
graph_module = _trace_p1_graph()
rewritten = Pattern1Rewriter().apply(graph_module.graph)
self.assertEqual(rewritten, 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 1)
def test_p1_with_marker(self):
graph_module = _run_pass(_trace_p1_graph(marker=True))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 1)
[reduce_scatter] = _find_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default)
[begin_node] = _find_calls(graph_module, torch.ops.tensor_cast._internal_mark_region_begin.default)
[norm_node] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
[all_gather] = _find_calls(graph_module, torch.ops.tensor_cast.all_gather.default)
self.assertIs(begin_node.args[0], reduce_scatter)
self.assertIs(norm_node.args[0], begin_node)
self.assertIs(all_gather.args[0], norm_node)
def test_p1_without_marker(self):
graph_module = _run_pass(_trace_p1_graph())
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
[reduce_scatter] = _find_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default)
[norm_node] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
[all_gather] = _find_calls(graph_module, torch.ops.tensor_cast.all_gather.default)
self.assertIs(norm_node.args[0], reduce_scatter)
self.assertIs(all_gather.args[0], norm_node)
def test_p1_add_rms_norm(self):
graph_module = _run_pass(_trace_p1_graph(marker=True, add_rms_norm=True))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
[norm_node] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm.default)
self.assertTrue(_has_user(norm_node, torch.ops.tensor_cast.all_gather.default))
def test_p1_dim_mismatch_view_uses_comm_output_shard_dim(self):
graph_module = _trace_p1_graph(input_shape=(1, 256, 4096))
[all_reduce] = _find_calls(graph_module, torch.ops.tensor_cast.all_reduce.default)
[norm] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
all_reduce.meta["val"] = _meta_tensor((256, 4096))
norm.meta["val"] = _meta_tensor((256, 4096))
graph_module = _run_pass(graph_module)
[view] = _find_calls(graph_module, torch.ops.aten.view.default)
self.assertEqual(view.args[1], [128, 4096])
class Pattern2RewriterTestCase(unittest.TestCase):
def test_pattern2_rewriter_apply(self):
class Program(torch.nn.Module):
def forward(self, x, res, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
return torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)[0]
graph_module = _trace_program(Program(), INPUT_SHAPE, LOCAL_INPUT_SHAPE, WEIGHT_SHAPE)
rewritten = Pattern2Rewriter().apply(graph_module.graph)
self.assertEqual(rewritten, 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 1)
def test_p2_skips_when_other_input_is_full_shape(self):
class Program(torch.nn.Module):
def forward(self, x, full_residual, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, full_residual, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
[norm2_new] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default)
[all_reduce] = _find_calls(graph_module, torch.ops.tensor_cast.all_reduce.default)
full_residual = next(node for node in graph_module.graph.nodes if node.name == "full_residual")
self.assertIs(norm2_new.args[0], all_reduce)
self.assertIs(norm2_new.args[1], full_residual)
def test_p2_skips_when_other_input_is_all_gathered_full_shape(self):
class Program(torch.nn.Module):
def forward(self, x, local_src, w):
full_residual = torch.ops.tensor_cast.all_gather.default(local_src, 1, 0, RANK_GROUP)
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, full_residual, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, LOCAL_INPUT_SHAPE, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 1)
[norm2_new] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default)
[all_reduce] = _find_calls(graph_module, torch.ops.tensor_cast.all_reduce.default)
[full_residual] = _find_calls(graph_module, torch.ops.tensor_cast.all_gather.default)
self.assertIs(norm2_new.args[0], all_reduce)
self.assertIs(norm2_new.args[1], full_residual)
def test_p2_skips_bias_like_residual_shape(self):
class Program(torch.nn.Module):
def forward(self, x, bias, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, bias, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, WEIGHT_SHAPE, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
def test_p3_skips_tail_when_source_p2_was_not_localized(self):
class Program(torch.nn.Module):
def forward(self, x, full_residual, x2, w, w2):
all_reduce_1 = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
_, residual = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce_1, full_residual, w, EPS)
all_reduce_2 = torch.ops.tensor_cast.all_reduce.default(x2, 0, RANK_GROUP)
add_node = torch.ops.aten.add.Tensor(residual, all_reduce_2)
return torch.ops.tensor_cast.rms_norm.default(add_node, w2, EPS)
graph_module = _run_pass(
_trace_program(
Program(),
INPUT_SHAPE,
INPUT_SHAPE,
INPUT_SHAPE,
WEIGHT_SHAPE,
WEIGHT_SHAPE,
)
)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 2)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
def _assert_p2_p3_middle(self, graph_module):
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
self.assertGreaterEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 2)
for getitem1 in _find_getitems(graph_module, 1):
self.assertFalse(
_has_user(getitem1, torch.ops.tensor_cast.all_gather.default),
"gi[1] must NOT be gathered",
)
self.assertTrue(
_has_user(getitem1, torch.ops.aten.add.Tensor),
"gi[1] must directly feed an add node",
)
norms = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
self.assertTrue(
any(_has_user(norm_node, torch.ops.tensor_cast.all_gather.default) for norm_node in norms),
"tail rms_norm must feed all_gather",
)
def test_p2_middle_no_marker(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_end=True))
self._assert_p2_p3_middle(graph_module)
def test_p2_middle_with_marker(self):
graph_module = _run_pass(_trace_p2_p3_graph(begin_on_res=True, with_end=True))
self._assert_p2_p3_middle(graph_module)
def test_p2_residual_fanout(self):
class Program(torch.nn.Module):
def forward(self, x, res, other, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, residual = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)
add_node = torch.ops.aten.add.Tensor(residual, other)
return out, add_node
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, LOCAL_INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE))
for node in _find_getitems(graph_module, 1):
self.assertTrue(_has_user(node, torch.ops.tensor_cast.all_gather.default))
def test_p2_dual_all_reduce(self):
class Program(torch.nn.Module):
def forward(self, x1, x2, w):
all_reduce_1 = torch.ops.tensor_cast.all_reduce.default(x1, 0, RANK_GROUP)
all_reduce_2 = torch.ops.tensor_cast.all_reduce.default(x2, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce_1, all_reduce_2, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 2)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
[norm2_new] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default)
all_reduce_nodes = _find_calls(graph_module, torch.ops.tensor_cast.all_reduce.default)
self.assertIn(norm2_new.args[0], all_reduce_nodes)
self.assertIn(norm2_new.args[1], all_reduce_nodes)
self.assertIsNot(norm2_new.args[0], norm2_new.args[1])
def test_p2_marker_wrapped_comm(self):
class Program(torch.nn.Module):
def forward(self, x, res, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
begin = torch.ops.tensor_cast._internal_mark_region_begin.default(all_reduce, 0)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(begin, res, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default), 1)
[norm2_new] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default)
[begin_new] = _find_calls(graph_module, torch.ops.tensor_cast._internal_mark_region_begin.default)
[all_reduce_new] = _find_calls(graph_module, torch.ops.tensor_cast.all_reduce.default)
self.assertIs(norm2_new.args[0], begin_new)
self.assertIs(begin_new.args[0], all_reduce_new)
def test_p2_2d_shard_dim(self):
shape_2d = (256, 4096)
local_shape_2d = (128, 4096)
class Program(torch.nn.Module):
def forward(self, x, res, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, residual = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)
return out, residual
graph_module = _run_pass(_trace_program(Program(), shape_2d, local_shape_2d, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
[reduce_scatter] = _find_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default)
self.assertEqual(reduce_scatter.args[1], 0)
for node in _find_getitems(graph_module, 1):
self.assertTrue(_has_user(node, torch.ops.tensor_cast.all_gather.default))
for node in _find_getitems(graph_module, 0):
self.assertTrue(_has_user(node, torch.ops.tensor_cast.all_gather.default))
all_gathers = _find_calls(graph_module, torch.ops.tensor_cast.all_gather.default)
self.assertGreaterEqual(len(all_gathers), 2)
for all_gather in all_gathers:
self.assertEqual(all_gather.args[1], 0)
def test_p2_last(self):
class Program(torch.nn.Module):
def forward(self, x, res, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, LOCAL_INPUT_SHAPE, WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
for node in _find_getitems(graph_module, 0):
self.assertTrue(_has_user(node, torch.ops.tensor_cast.all_gather.default))
def test_p2_reduce_scatter_inserted_after_all_reduce(self):
class Program(torch.nn.Module):
def forward(self, x, res, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)
return out
graph_module = _trace_program(Program(), INPUT_SHAPE, LOCAL_INPUT_SHAPE, WEIGHT_SHAPE)
[all_reduce] = _find_calls(graph_module, torch.ops.tensor_cast.all_reduce.default)
Pattern2Rewriter().apply(graph_module.graph)
[reduce_scatter] = _find_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default)
self.assertLess(
_get_node_index(graph_module, all_reduce),
_get_node_index(graph_module, reduce_scatter),
)
def test_markerless_shared_entry_all_reduce_and_fused_tail(self):
class Program(torch.nn.Module):
def forward(self, x, x2, x3, w, w2, w3):
all_reduce_0 = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
torch.ops.tensor_cast.rms_norm.default(all_reduce_0, w, EPS)
all_reduce_1 = torch.ops.tensor_cast.all_reduce.default(x2, 0, RANK_GROUP)
getitem0, getitem1 = torch.ops.tensor_cast.add_rms_norm2.default(
all_reduce_0,
all_reduce_1,
w2,
EPS,
)
all_reduce_2 = torch.ops.tensor_cast.all_reduce.default(x3, 0, RANK_GROUP)
tail = torch.ops.tensor_cast.add_rms_norm.default(getitem1, all_reduce_2, w3, EPS)
return getitem0, tail
graph_module = _run_pass(
_trace_program(
Program(),
INPUT_SHAPE,
INPUT_SHAPE,
INPUT_SHAPE,
WEIGHT_SHAPE,
WEIGHT_SHAPE,
WEIGHT_SHAPE,
)
)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 3)
self.assertGreaterEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 2)
for node in _find_getitems(graph_module, 0):
self.assertTrue(_has_user(node, torch.ops.tensor_cast.all_gather.default))
for node in _find_getitems(graph_module, 1):
self.assertFalse(_has_user(node, torch.ops.tensor_cast.all_gather.default))
[tail_new] = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm.default)
self.assertTrue(_has_user(tail_new, torch.ops.tensor_cast.all_gather.default))
def test_p2_residual_chained_into_next_p2_stays_local(self):
class Program(torch.nn.Module):
def forward(self, res0, x1, x2, w1, w2):
all_reduce_1 = torch.ops.tensor_cast.all_reduce.default(x1, 0, RANK_GROUP)
_, residual = torch.ops.tensor_cast.add_rms_norm2.default(res0, all_reduce_1, w1, EPS)
all_reduce_2 = torch.ops.tensor_cast.all_reduce.default(x2, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(residual, all_reduce_2, w2, EPS)
return out
graph_module = _run_pass(
_trace_program(
Program(),
LOCAL_INPUT_SHAPE,
INPUT_SHAPE,
INPUT_SHAPE,
WEIGHT_SHAPE,
WEIGHT_SHAPE,
)
)
chained_getitem1 = _find_getitems(graph_module, 1)[0]
self.assertFalse(
_has_user(chained_getitem1, torch.ops.tensor_cast.all_gather.default),
"intermediate residual feeding next P2 must stay local",
)
norm2_2_new = _find_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default)[1]
self.assertIs(norm2_2_new.args[0], chained_getitem1)
def test_p2_gathers_residual_when_downstream_p2_is_not_shardable(self):
class Program(torch.nn.Module):
def forward(self, x, res, bad_x, w1, w2):
all_reduce_1 = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out1, residual = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce_1, res, w1, EPS)
all_reduce_2 = torch.ops.tensor_cast.all_reduce.default(bad_x, 0, RANK_GROUP)
out2, _ = torch.ops.tensor_cast.add_rms_norm2.default(residual, all_reduce_2, w2, EPS)
return out1, out2
graph_module = _run_pass(
_trace_program(
Program(),
INPUT_SHAPE,
LOCAL_INPUT_SHAPE,
(1, 127, 4096),
WEIGHT_SHAPE,
WEIGHT_SHAPE,
)
)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 1)
first_residual = _find_getitems(graph_module, 1)[0]
self.assertTrue(_has_user(first_residual, torch.ops.tensor_cast.all_gather.default))
def test_p2_gathers_residual_when_downstream_p3_is_not_shardable(self):
graph_module = _run_pass(_trace_p2_p3_graph(tail_rank_group=BAD_RANK_GROUP))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 1)
[residual] = _find_getitems(graph_module, 1)
self.assertTrue(_has_user(residual, torch.ops.tensor_cast.all_gather.default))
class Pattern3RewriterTestCase(unittest.TestCase):
def _assert_p3_base(self, graph_module):
self.assertEqual(
_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default),
0,
"no all_reduce should remain",
)
self.assertGreaterEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 2)
for getitem1 in _find_getitems(graph_module, 1):
self.assertFalse(
_has_user(getitem1, torch.ops.tensor_cast.all_gather.default),
"gi[1] must NOT be gathered",
)
self.assertTrue(_has_user(getitem1, torch.ops.aten.add.Tensor), "gi[1] should feed add")
norms = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
self.assertTrue(
any(_has_user(node, torch.ops.tensor_cast.all_gather.default) for node in norms),
"tail norm must feed all_gather",
)
def _get_add_comm_input(self, graph_module):
[add_node] = _find_calls(graph_module, torch.ops.aten.add.Tensor)
getitem_ids = {id(node) for node in _find_getitems(graph_module, 1)}
for arg in add_node.args:
if id(arg) not in getitem_ids:
return arg
self.fail("add node has no non-gi1 input")
def test_p3_with_marker(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_end=True))
self._assert_p3_base(graph_module)
comm = self._get_add_comm_input(graph_module)
self.assertIs(comm.target, torch.ops.tensor_cast.reduce_scatter.default)
[add_node] = _find_calls(graph_module, torch.ops.aten.add.Tensor)
end_users = [
user
for user in add_node.users
if hasattr(user, "target") and user.target is torch.ops.tensor_cast._internal_mark_region_end.default
]
self.assertEqual(len(end_users), 1, "add must feed exactly one region_end")
end_node = end_users[0]
self.assertIs(end_node.args[0], add_node)
[norm] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
self.assertIs(norm.args[0], end_node, "rms_norm must consume region_end")
def test_p3_without_marker(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_end=False))
self._assert_p3_base(graph_module)
comm = self._get_add_comm_input(graph_module)
self.assertIs(comm.target, torch.ops.tensor_cast.reduce_scatter.default)
[add_node] = _find_calls(graph_module, torch.ops.aten.add.Tensor)
[norm] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
self.assertIs(norm.args[0], add_node, "norm must consume add directly")
def test_p3_with_view(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_view=True))
self._assert_p3_base(graph_module)
comm = self._get_add_comm_input(graph_module)
self.assertIs(comm.target, torch.ops.aten.reshape.default)
self.assertIs(comm.args[0].target, torch.ops.tensor_cast.reduce_scatter.default)
def test_p3_no_marker_copy_region(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_end=False, num_copies=2))
self._assert_p3_base(graph_module)
self.assertEqual(
_count_calls(graph_module, torch.ops.tensor_cast._internal_mark_region_end.default),
0,
)
[norm] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
copy1 = norm.args[0]
self.assertIs(copy1.target, torch.ops.tensor_cast._internal_copy_region.default)
copy0 = copy1.args[0]
self.assertIs(copy0.target, torch.ops.tensor_cast._internal_copy_region.default)
[add_node] = _find_calls(graph_module, torch.ops.aten.add.Tensor)
self.assertIs(copy0.args[0], add_node)
def test_p3_with_copy_region_chain(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_end=True, num_copies=2))
self._assert_p3_base(graph_module)
[norm] = _find_calls(graph_module, torch.ops.tensor_cast.rms_norm.default)
copy1 = norm.args[0]
self.assertIs(copy1.target, torch.ops.tensor_cast._internal_copy_region.default)
copy0 = copy1.args[0]
self.assertIs(copy0.target, torch.ops.tensor_cast._internal_copy_region.default)
end_node = copy0.args[0]
self.assertIs(end_node.target, torch.ops.tensor_cast._internal_mark_region_end.default)
[add_node] = _find_calls(graph_module, torch.ops.aten.add.Tensor)
self.assertIs(end_node.args[0], add_node)
def test_p3_reduce_scatter_inserted_after_all_reduce(self):
graph_module = _run_pass(_trace_p2_p3_graph(with_end=True))
reduce_scatter_nodes = _find_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default)
add_comm = self._get_add_comm_input(graph_module)
p3_reduce_scatter = (
add_comm if add_comm.target is torch.ops.tensor_cast.reduce_scatter.default else add_comm.args[0]
)
reduce_scatter_positions = [_get_node_index(graph_module, node) for node in reduce_scatter_nodes]
p3_position = _get_node_index(graph_module, p3_reduce_scatter)
entry_position = min(pos for pos in reduce_scatter_positions if pos != p3_position)
self.assertGreater(p3_position, entry_position)
class SequenceParallelPassEdgeTestCase(unittest.TestCase):
def test_sequence_parallel_pass_owns_rewriters(self):
sequence_parallel_pass = SequenceParallelPass()
self.assertIsInstance(sequence_parallel_pass._p1_rewriter, Pattern1Rewriter)
self.assertIsInstance(sequence_parallel_pass._p2_rewriter, Pattern2Rewriter)
self.assertIsInstance(sequence_parallel_pass._p3_rewriter, Pattern3Rewriter)
def test_no_tp(self):
graph_module = _run_pass(_trace_p1_graph(rank_group=[0]))
self.assertEqual(
_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default),
1,
"should not rewrite when world_size=1",
)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
def test_skip_when_shard_dim_not_divisible(self):
class Program(torch.nn.Module):
def forward(self, x, res, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
out, _ = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)
return out
graph_module = _run_pass(_trace_program(Program(), (1, 127, 4096), (1, 127, 4096), WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
def test_mixed_shardability_rewrites_valid_p1_candidate(self):
class Program(torch.nn.Module):
def forward(self, x, bad_x, w):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
norm = torch.ops.tensor_cast.rms_norm.default(all_reduce, w, EPS)
bad_all_reduce = torch.ops.tensor_cast.all_reduce.default(bad_x, 0, RANK_GROUP)
return norm, bad_all_reduce
graph_module = _run_pass(_trace_program(Program(), INPUT_SHAPE, (1, 127, 4096), WEIGHT_SHAPE))
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_reduce.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 1)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 1)
class SequenceParallelPassNegativeTestCase(unittest.TestCase):
def test_p1_all_reduce_without_norm_consumer(self):
class Program(torch.nn.Module):
def forward(self, x, y):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
return torch.ops.aten.add.Tensor(all_reduce, y)
graph_module = _trace_program(Program(), INPUT_SHAPE, INPUT_SHAPE)
matched = Pattern1Rewriter().apply(graph_module.graph)
self.assertEqual(matched, 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
def test_p1_rms_norm_without_all_reduce_source(self):
class Program(torch.nn.Module):
def forward(self, x, w):
return torch.ops.tensor_cast.rms_norm.default(x, w, EPS)
graph_module = _trace_program(Program(), INPUT_SHAPE, WEIGHT_SHAPE)
matched = Pattern1Rewriter().apply(graph_module.graph)
self.assertEqual(matched, 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.all_gather.default), 0)
def test_p2_add_rms_norm2_without_all_reduce(self):
class Program(torch.nn.Module):
def forward(self, x, res, w):
out, residual = torch.ops.tensor_cast.add_rms_norm2.default(x, res, w, EPS)
return out, residual
graph_module = _trace_program(Program(), INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE)
matched = Pattern2Rewriter().apply(graph_module.graph)
self.assertEqual(matched, 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.reduce_scatter.default), 0)
self.assertEqual(_count_calls(graph_module, torch.ops.tensor_cast.add_rms_norm2.default), 1)
def test_p3_add_sibling_is_plain_placeholder(self):
class Program(torch.nn.Module):
def forward(self, x, res, plain, w, w2):
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
_, residual = torch.ops.tensor_cast.add_rms_norm2.default(all_reduce, res, w, EPS)
add_node = torch.ops.aten.add.Tensor(residual, plain)
return torch.ops.tensor_cast.rms_norm.default(add_node, w2, EPS)
graph_module = _trace_program(Program(), INPUT_SHAPE, INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE, WEIGHT_SHAPE)
_mark_add_rms_norm2_sp_local(graph_module)
matched = Pattern3Rewriter().apply(graph_module.graph)
self.assertEqual(matched, 0)
def test_p3_add_residual_not_from_add_rms_norm2(self):
class Program(torch.nn.Module):
def forward(self, x, y, w):
parts = torch.ops.aten.split.Tensor(y, 64, 1)
residual = parts[1]
all_reduce = torch.ops.tensor_cast.all_reduce.default(x, 0, RANK_GROUP)
add_node = torch.ops.aten.add.Tensor(residual, all_reduce)
return torch.ops.tensor_cast.rms_norm.default(add_node, w, EPS)
graph_module = _trace_program(Program(), LOCAL_INPUT_SHAPE, INPUT_SHAPE, WEIGHT_SHAPE)
[residual] = _find_getitems(graph_module, 1)
self.assertIsNot(residual.args[0].target, torch.ops.tensor_cast.add_rms_norm2.default)
matched = Pattern3Rewriter().apply(graph_module.graph)
self.assertEqual(matched, 0)
if __name__ == "__main__":
unittest.main()