import functools
import logging
from typing import Any, Callable, Optional, Sequence
import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import fake_tensor_prop
from torch._inductor.decomposition import select_decomp_table
from torch._inductor.freezing import freeze
from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
from .. import config, ops
from . import patterns
from .constant_folding import fold_meta_constants
from .freezing_passes import patterns as freezing_patterns
from .freezing_passes.dispatch_ffn_combine_pass import DispatchFFNCombinePass
from .freezing_passes.grouped_matmul_swiglu_pass import GroupedMatmulSwigluPass
from .freezing_passes.sink_split_pass import SinkSplitPass
from .passes.lift_quant_pass import LiftCombineQuantPass
from .passes.merge_linear_pass import MergeLinearPass
from .passes.multistream_pass import MultiStreamSchedulePass
from .passes.peep_hole_pass import PeepHolePass
from .passes.redundant_node_elimination_pass import ReduandantNodeEliminationPass
from .passes.sequence_parallel_pass import SequenceParallelPass
# Module-level logger; the backend emits per-pass graph dumps through it at DEBUG level.
logger = logging.getLogger(__name__)
class CompilerBackend:
    """
    The compilation backend for 'torch.compile'.

    It lowers the Dynamo-captured FX graph through AOT Autograd. It then runs
    the custom rewrite passes (operation fusing etc.) on the resulting forward
    graph.
    """

    def __init__(self, device_name: Optional[str] = None):
        """
        Args:
            device_name: Forwarded unchanged (including ``None``) to
                ``MultiStreamSchedulePass`` when the multistream pass runs.
        """
        self._multistream_device_name = device_name

    def __call__(self, gm: fx.GraphModule, example_inputs) -> Callable:
        """
        Compile a Dynamo-captured graph with the custom pass pipeline.

        Args:
            gm (fx.GraphModule): The FX graph module captured by Dynamo.
            example_inputs: Example inputs for the graph.

        Returns:
            Callable: The compiled callable produced by ``compile``.
        """
        return self.compile(gm, example_inputs)

    def compile(
        self,
        gm: fx.GraphModule,
        example_inputs,
        **kwargs,
    ) -> Callable:
        """
        Lower ``gm`` via ``aot_autograd`` and rewrite each forward graph.

        For every forward graph, the pipeline is:
          1. ``graph_rewrite_before_freezing``: peephole, redundant-node
             elimination, quantization, pattern matching and sequence parallel.
          2. If ``config.compilation.enable_freezing`` is set, ``freeze`` the
             graph. The returned wrapper forwards only the arguments at
             ``preserved_arg_indices``.
          3. ``graph_rewrite_after_freezing``: merge-linear, meta constant
             folding, redundant-node elimination, freezing passes, multistream
             and auto_functionalized decomposition. This step runs whether or
             not freezing is enabled.

        Args:
            gm: The Dynamo-captured graph module. It is also passed to
                ``freeze`` as its first argument.
            example_inputs: Example inputs for ``gm``.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            The callable returned by ``aot_autograd(...)(gm, example_inputs)``.
        """

        def freezing_compile(post_freeze_compiler, aot_autograd_gm, aot_example_inputs):
            frozen_gm, preserved_arg_indices = freeze(gm, aot_autograd_gm, aot_example_inputs)
            preserved_inputs = [aot_example_inputs[i] for i in preserved_arg_indices]
            optimized_function = post_freeze_compiler(frozen_gm, preserved_inputs)

            def wrapper(args: list[object]) -> Sequence[torch.Tensor]:
                # Boxed-call convention: take the surviving inputs, then clear
                # the caller's list so it no longer holds the input references.
                args_new = [args[i] for i in preserved_arg_indices]
                args.clear()
                return optimized_function(*args_new)

            wrapper._boxed_call = True
            return wrapper

        def graph_rewrite_before_freezing(fx_graph, inputs):
            self._log_graph("Graph before compiling:", fx_graph)
            self.apply_peep_hole_pass(fx_graph, inputs)
            self.apply_redundant_node_elimination_pass(fx_graph, inputs)
            self.apply_quantization_passes(fx_graph, inputs)
            self.apply_pattern_match_passes(fx_graph, inputs)
            self.apply_sequence_parallel_pass(fx_graph, inputs)
            return fx_graph

        def graph_rewrite_after_freezing(fx_graph, inputs):
            self.apply_merge_linear_pass(fx_graph, inputs)
            fold_meta_constants(fx_graph)
            self.apply_redundant_node_elimination_pass(fx_graph, inputs)
            self.apply_freezing_passes(fx_graph, inputs)
            self.apply_multistream_pass(fx_graph, inputs)
            self.apply_decompose_auto_functionalized_pass(fx_graph)
            self._log_graph("Graph after compiling:", fx_graph)
            return fx_graph

        def compile_inner(fx_graph, inputs):
            graph_rewrite_before_freezing(fx_graph, inputs)
            if config.compilation.enable_freezing:
                return freezing_compile(graph_rewrite_after_freezing, fx_graph, inputs)
            return graph_rewrite_after_freezing(fx_graph, inputs)

        decompositions = select_decomp_table()
        return aot_autograd(
            fw_compiler=compile_inner,
            decompositions=decompositions,
        )(gm, example_inputs)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_observer(subsystem: str) -> Callable[..., Any]:
        """Return a ``GraphTransformObserver`` factory bound to ``subsystem``.

        The returned callable takes ``(gm, pass_name)``. Its ``log_url`` is
        read from ``config.compilation.debug.graph_log_url`` when this helper
        is called.
        """
        return functools.partial(
            torch.fx.passes.graph_transform_observer.GraphTransformObserver,
            subsystem=subsystem,
            log_url=config.compilation.debug.graph_log_url,
        )

    @staticmethod
    def _log_graph(title: str, gm: fx.GraphModule) -> None:
        """Log ``title`` and, when DEBUG is enabled, the readable form of ``gm``."""
        logger.debug(title)
        # Only render print_readable when the message will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(gm.print_readable(print_output=False))

    @staticmethod
    def _propagate_fake_tensors(gm: fx.GraphModule, inputs) -> None:
        """Re-run fake tensor propagation on ``gm`` after a graph rewrite."""
        # NOTE(review): presumably refreshes node metadata for later passes;
        # force_allow_non_fake_inputs tolerates non-fake tensors in ``inputs``.
        fake_tensor_prop(gm, inputs, force_allow_non_fake_inputs=True)

    # ------------------------------------------------------------------- passes

    def apply_redundant_node_elimination_pass(self, gm: fx.GraphModule, inputs):
        """Run ``ReduandantNodeEliminationPass`` on ``gm``. This pass is always enabled."""
        observer = self._make_observer("redundant_node_elimination_pass")
        observer(gm, "redundant_node_elimination_pass").apply_gm_pass(ReduandantNodeEliminationPass())
        self._log_graph("Graph after redundant node elimination pass:", gm)

    def apply_quantization_passes(self, gm: fx.GraphModule, inputs):
        """Run ``LiftCombineQuantPass`` if ``passes.enable_life_combine_quant`` is set."""
        observer = self._make_observer("quantization_passes")
        if config.compilation.passes.enable_life_combine_quant:
            observer(gm, "life_combine_quant_pass").apply_gm_pass(LiftCombineQuantPass())
        self._log_graph("Graph after quantization passes:", gm)

    def apply_pattern_match_passes(self, gm: fx.GraphModule, inputs):
        """Run every pass in ``patterns.all_passes`` in order, after lazy initialization."""
        patterns.lazy_init()
        observer = self._make_observer("pattern_match_passes")
        for i, pattern_match_pass in enumerate(patterns.all_passes):
            observer(gm, f"pattern_match_pass_{i}").apply_gm_pass(pattern_match_pass)
        self._log_graph("Graph after pattern matching:", gm)

    def apply_sequence_parallel_pass(self, gm: fx.GraphModule, inputs):
        """Rewrite TP `all_reduce + norm` patterns into SP
        `reduce_scatter + norm + all_gather`.

        The norm then runs on a seq-dim-sharded tensor, so each rank processes
        only 1/TP of the tokens, cutting norm compute and memory. The trade-off
        is an all_gather after the norm to restore the full sequence.

        Must run before freezing; otherwise CSE breaks the all_reduce→norm
        match chain. Gated by ``passes.enable_sequence_parallel``.
        """
        observer = self._make_observer("sequence_parallel_pass")
        if config.compilation.passes.enable_sequence_parallel:
            observer(gm, "sequence_parallel_pass").apply_gm_pass(SequenceParallelPass())
        self._log_graph("Graph after sequence parallel pass:", gm)

    def apply_decompose_auto_functionalized_pass(self, gm: fx.GraphModule):
        """Apply Inductor's ``decompose_auto_functionalized`` to ``gm.graph``."""
        observer = self._make_observer("decompose_auto_functionalized_pass")
        observer(gm, "decompose_auto_functionalized").apply_graph_pass(decompose_auto_functionalized)

    def apply_merge_linear_pass(self, gm: fx.GraphModule, inputs):
        """Run ``MergeLinearPass`` (and refresh fake tensors) if ``passes.enable_merge_linear`` is set."""
        observer = self._make_observer("merge_linear_pass")
        if config.compilation.passes.enable_merge_linear:
            observer(gm, "merge_linear_pass").apply_gm_pass(MergeLinearPass())
            self._propagate_fake_tensors(gm, inputs)
        self._log_graph("Graph after the merge linear pass:", gm)

    def apply_freezing_passes(self, gm: fx.GraphModule, inputs):
        """Run the config-gated post-freezing passes.

        The passes run in this order: sink-split, matmul-allreduce pattern
        passes, grouped-matmul-swiglu, then dispatch-ffn-combine. Fake tensors
        are re-propagated after each pass that runs.
        """
        observer = self._make_observer("freezing_passes")
        if config.compilation.passes.enable_sink_split:
            observer(gm, "sink_split_pass").apply_gm_pass(SinkSplitPass())
            self._propagate_fake_tensors(gm, inputs)
        if config.compilation.fusion_patterns.enable_matmul_allreduce:
            self.apply_freezing_pattern_passes(gm, inputs)
        if config.compilation.fusion_patterns.enable_grouped_matmul_swiglu:
            observer(gm, "grouped_matmul_swiglu_fusion_pass").apply_gm_pass(GroupedMatmulSwigluPass())
            self._propagate_fake_tensors(gm, inputs)
        if config.compilation.fusion_patterns.enable_dispatch_ffn_combine:
            observer(gm, "dispatch_ffn_combine_fusion_pass").apply_gm_pass(DispatchFFNCombinePass())
            self._propagate_fake_tensors(gm, inputs)
        self._log_graph("Graph after freezing passes:", gm)

    def apply_freezing_pattern_passes(self, gm: fx.GraphModule, inputs):
        """Run each pass in ``freezing_patterns.all_passes``, re-propagating fake tensors after each."""
        freezing_patterns.lazy_init()
        observer = self._make_observer("freezing_pattern_passes")
        for i, freezing_pattern_pass in enumerate(freezing_patterns.all_passes):
            observer(gm, f"freezing_pattern_pass_{i}").apply_gm_pass(freezing_pattern_pass)
            self._propagate_fake_tensors(gm, inputs)

    def apply_peep_hole_pass(self, gm: fx.GraphModule, inputs):
        """Run ``PeepHolePass`` on ``gm``. This pass is always enabled."""
        observer = self._make_observer("peep_hole_pass")
        observer(gm, "peep_hole_pass").apply_gm_pass(PeepHolePass())
        self._log_graph("Graph after peep hole pass:", gm)

    def apply_multistream_pass(self, gm: fx.GraphModule, inputs):
        """Run ``MultiStreamSchedulePass`` if ``multistream.enable`` is set.

        Fake tensors are propagated first, before the pass runs.
        """
        if not config.compilation.multistream.enable:
            return
        observer = self._make_observer("multistream_pass")
        self._propagate_fake_tensors(gm, inputs)
        observer(gm, "multistream_pass").apply_gm_pass(
            MultiStreamSchedulePass(device_name=self._multistream_device_name)
        )
        self._log_graph("Graph after multistream pass:", gm)