import torch
import torch.fx as fx
from torch._inductor.pattern_matcher import Arg as _BaseArg, CallFunction, Match
from .... import config, ops
def _arg(name: str):
    """Create a positional-argument placeholder for a pattern.

    The placeholder class is first constructed with ``name``. If that
    constructor rejects the argument with ``TypeError``, it is constructed
    with no arguments instead; in that case ``name`` only serves as a
    readable label at the call site.

    Args:
        name: Label for the captured argument.

    Returns:
        A ``_BaseArg`` instance.
    """
    try:
        placeholder = _BaseArg(name)
    except TypeError:
        # The constructor does not take a name; build an anonymous placeholder.
        placeholder = _BaseArg()
    return placeholder
# Pattern: all_reduce(mm(mat1, mat2), rank, rank_group).
# Placeholders are created left to right, in the order listed.
matmul_allreduce_pattern = CallFunction(
    torch.ops.tensor_cast.all_reduce.default,
    CallFunction(
        torch.ops.aten.mm.default,
        *map(_arg, ("mat1", "mat2")),
    ),
    *map(_arg, ("rank", "rank_group")),
)
# Pattern: all_reduce(static_quant_linear(<8 args>), rank, rank_group).
# Placeholders are created left to right, in the order listed.
static_quant_linear_allreduce_pattern = CallFunction(
    torch.ops.tensor_cast.all_reduce.default,
    CallFunction(
        torch.ops.tensor_cast.static_quant_linear.default,
        *map(
            _arg,
            (
                "x",
                "w",
                "w_scale",
                "w_offset",
                "x_scale",
                "x_offset",
                "bias",
                "out_dtype",
            ),
        ),
    ),
    *map(_arg, ("rank", "rank_group")),
)
# Pattern: all_reduce(static_quant_linear_int4(<8 args>), rank, rank_group).
# Placeholders are created left to right, in the order listed.
static_quant_linear_int4_allreduce_pattern = CallFunction(
    torch.ops.tensor_cast.all_reduce.default,
    CallFunction(
        torch.ops.tensor_cast.static_quant_linear_int4.default,
        *map(
            _arg,
            (
                "x",
                "w",
                "w_scale",
                "w_offset",
                "x_scale",
                "x_offset",
                "bias",
                "out_dtype",
            ),
        ),
    ),
    *map(_arg, ("rank", "rank_group")),
)
# Pattern: all_reduce(fp8_linear(<6 args>), rank, rank_group).
# Placeholders are created left to right, in the order listed.
fp8_linear_allreduce_pattern = CallFunction(
    torch.ops.tensor_cast.all_reduce.default,
    CallFunction(
        torch.ops.tensor_cast.fp8_linear.default,
        *map(
            _arg,
            ("x", "w", "x_scale", "w_scale", "bias", "out_dtype"),
        ),
    ),
    *map(_arg, ("rank", "rank_group")),
)
# Pattern: all_reduce(mxfp4_linear(<6 args>), rank, rank_group).
# Placeholders are created left to right, in the order listed.
mxfp4_linear_allreduce_pattern = CallFunction(
    torch.ops.tensor_cast.all_reduce.default,
    CallFunction(
        torch.ops.tensor_cast.mxfp4_linear.default,
        *map(
            _arg,
            ("x", "w", "x_scale", "w_scale", "bias", "out_dtype"),
        ),
    ),
    *map(_arg, ("rank", "rank_group")),
)
def _is_allreduce_after_target(match: Match, linear_target) -> bool:
    """Check that a match is an all_reduce fed directly by ``linear_target``.

    Args:
        match: Pattern match whose output node is inspected.
        linear_target: Call target the all_reduce's first argument must have.

    Returns:
        True only when all of the following hold:
        the output node targets ``tensor_cast.all_reduce.default``;
        it has exactly three positional args;
        its first arg is an FX node whose target is ``linear_target``;
        and that node has exactly one user.
    """
    reduce_node = match.output_node()
    if (
        reduce_node.target != torch.ops.tensor_cast.all_reduce.default
        or len(reduce_node.args) != 3
    ):
        return False

    producer = reduce_node.args[0]
    # The producer must be a graph node (not a constant) with the expected
    # target, and the all_reduce must be its only user.
    return (
        isinstance(producer, fx.Node)
        and producer.target == linear_target
        and len(producer.users) == 1
    )
def _is_matmul_allreduce(match: Match) -> bool:
    """Return whether ``match`` is an all_reduce fed by ``aten.mm.default``."""
    producer_target = torch.ops.aten.mm.default
    return _is_allreduce_after_target(match, linear_target=producer_target)
def _is_static_quant_linear_allreduce(match: Match) -> bool:
    """Return whether ``match`` is an all_reduce fed by ``static_quant_linear``."""
    producer_target = torch.ops.tensor_cast.static_quant_linear.default
    return _is_allreduce_after_target(match, linear_target=producer_target)
def _is_static_quant_linear_int4_allreduce(match: Match) -> bool:
    """Return whether ``match`` is an all_reduce fed by ``static_quant_linear_int4``."""
    producer_target = torch.ops.tensor_cast.static_quant_linear_int4.default
    return _is_allreduce_after_target(match, linear_target=producer_target)
def _is_fp8_linear_allreduce(match: Match) -> bool:
    """Return whether ``match`` is an all_reduce fed by ``fp8_linear``."""
    producer_target = torch.ops.tensor_cast.fp8_linear.default
    return _is_allreduce_after_target(match, linear_target=producer_target)
def _is_mxfp4_linear_allreduce(match: Match) -> bool:
    """Return whether ``match`` is an all_reduce fed by ``mxfp4_linear``."""
    producer_target = torch.ops.tensor_cast.mxfp4_linear.default
    return _is_allreduce_after_target(match, linear_target=producer_target)
def _replace_with_fused_op(match: Match, fused_target, fused_args) -> None:
    """Swap the matched subgraph for a single call to ``fused_target``.

    Args:
        match: Pattern match; its output node is the node being replaced.
        fused_target: Callable used as the target of the new node.
        fused_args: Positional args for the new node.
    """
    old_output = match.output_node()
    owning_graph = match.graph

    # Insert the new node just before the old output node in the graph.
    with owning_graph.inserting_before(old_output):
        replacement = owning_graph.call_function(fused_target, args=fused_args)

    # Carry the old output's metadata over to the replacement.
    replacement.meta.update(old_output.meta)

    # Redirect all users first, then erase the matched nodes.
    old_output.replace_all_uses_with(replacement)
    match.erase_nodes()
def _fuse_matmul_allreduce(match: Match, mat1, mat2, rank, rank_group):
    """Replace a matched mm + all_reduce with ``tensor_cast.matmul_all_reduce``.

    The positional parameters after ``match`` are presumably the captured
    pattern placeholders, in pattern order; confirm against ``register_pattern``.
    """
    # NOTE(review): the third slot is always None; presumably an optional
    # bias. Confirm against the fused op's schema.
    # NOTE(review): the target is the op packet, while the patterns use
    # `.default` overloads. Confirm downstream passes expect this.
    fused_args = (mat1, mat2, None, rank, rank_group)
    _replace_with_fused_op(
        match,
        fused_target=torch.ops.tensor_cast.matmul_all_reduce,
        fused_args=fused_args,
    )
def _fuse_static_quant_linear_allreduce(
    match: Match,
    x,
    w,
    w_scale,
    w_offset,
    x_scale,
    x_offset,
    bias,
    out_dtype,
    rank,
    rank_group,
):
    """Replace a matched static_quant_linear + all_reduce with the fused op.

    The linear op's eight arguments are forwarded unchanged and in order,
    followed by ``rank`` and ``rank_group``.
    """
    # NOTE(review): the target is the op packet, while the patterns use
    # `.default` overloads. Confirm downstream passes expect this.
    linear_args = (x, w, w_scale, w_offset, x_scale, x_offset, bias, out_dtype)
    fused_args = (*linear_args, rank, rank_group)
    _replace_with_fused_op(
        match,
        fused_target=torch.ops.tensor_cast.static_quant_linear_all_reduce,
        fused_args=fused_args,
    )
def _fuse_static_quant_linear_int4_allreduce(
    match: Match,
    x,
    w,
    w_scale,
    w_offset,
    x_scale,
    x_offset,
    bias,
    out_dtype,
    rank,
    rank_group,
):
    """Replace a matched static_quant_linear_int4 + all_reduce with the fused op.

    The linear op's eight arguments are forwarded unchanged and in order,
    followed by ``rank`` and ``rank_group``.
    """
    # NOTE(review): the target is the op packet, while the patterns use
    # `.default` overloads. Confirm downstream passes expect this.
    linear_args = (x, w, w_scale, w_offset, x_scale, x_offset, bias, out_dtype)
    fused_args = (*linear_args, rank, rank_group)
    _replace_with_fused_op(
        match,
        fused_target=torch.ops.tensor_cast.static_quant_linear_int4_all_reduce,
        fused_args=fused_args,
    )
def _fuse_fp8_linear_allreduce(
    match: Match,
    x,
    w,
    x_scale,
    w_scale,
    bias,
    out_dtype,
    rank,
    rank_group,
):
    """Replace a matched fp8_linear + all_reduce with the fused op.

    The linear op's six arguments are forwarded unchanged and in order,
    followed by ``rank`` and ``rank_group``.
    """
    # NOTE(review): the target is the op packet, while the patterns use
    # `.default` overloads. Confirm downstream passes expect this.
    linear_args = (x, w, x_scale, w_scale, bias, out_dtype)
    fused_args = (*linear_args, rank, rank_group)
    _replace_with_fused_op(
        match,
        fused_target=torch.ops.tensor_cast.fp8_linear_all_reduce,
        fused_args=fused_args,
    )
def _fuse_mxfp4_linear_allreduce(
    match: Match,
    x,
    w,
    x_scale,
    w_scale,
    bias,
    out_dtype,
    rank,
    rank_group,
):
    """Replace a matched mxfp4_linear + all_reduce with the fused op.

    The linear op's six arguments are forwarded unchanged and in order,
    followed by ``rank`` and ``rank_group``.
    """
    # NOTE(review): the target is the op packet, while the patterns use
    # `.default` overloads. Confirm downstream passes expect this.
    linear_args = (x, w, x_scale, w_scale, bias, out_dtype)
    fused_args = (*linear_args, rank, rank_group)
    _replace_with_fused_op(
        match,
        fused_target=torch.ops.tensor_cast.mxfp4_linear_all_reduce,
        fused_args=fused_args,
    )
def register_all_patterns():
    """Register every linear + all_reduce fusion pattern in this module.

    Nothing is registered when
    ``config.compilation.fusion_patterns.enable_matmul_allreduce`` is falsy;
    that single flag gates all five patterns.
    """
    # Imported inside the function, and before the flag check, as in the
    # original.
    from . import register_pattern

    if not config.compilation.fusion_patterns.enable_matmul_allreduce:
        return

    # (name, pattern, handler, extra_check), in registration order.
    registrations = (
        (
            "matmul_allreduce_pattern",
            matmul_allreduce_pattern,
            _fuse_matmul_allreduce,
            _is_matmul_allreduce,
        ),
        (
            "static_quant_linear_allreduce_pattern",
            static_quant_linear_allreduce_pattern,
            _fuse_static_quant_linear_allreduce,
            _is_static_quant_linear_allreduce,
        ),
        (
            "static_quant_linear_int4_allreduce_pattern",
            static_quant_linear_int4_allreduce_pattern,
            _fuse_static_quant_linear_int4_allreduce,
            _is_static_quant_linear_int4_allreduce,
        ),
        (
            "fp8_linear_allreduce_pattern",
            fp8_linear_allreduce_pattern,
            _fuse_fp8_linear_allreduce,
            _is_fp8_linear_allreduce,
        ),
        (
            "mxfp4_linear_allreduce_pattern",
            mxfp4_linear_allreduce_pattern,
            _fuse_mxfp4_linear_allreduce,
            _is_mxfp4_linear_allreduce,
        ),
    )
    for pattern_name, pattern, handler, extra_check in registrations:
        register_pattern(
            name=pattern_name,
            pattern=pattern,
            handler=handler,
            extra_check=extra_check,
        )