from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._inductor.freezing_utils import maybe_set_is_frozen_param
from torch.fx import GraphModule, Node
def replace_node_with_constant(
gm: torch.fx.GraphModule,
node: torch.fx.Node,
constant: Any,
name: Optional[str] = None,
) -> None:
"""
Replaces a node in the graph with a 'get_attr' node pointing to
a new constant value registered on the GraphModule.
"""
g = gm.graph
if name:
qualname = name
else:
if not hasattr(gm, "_frozen_param_count"):
gm._frozen_param_count = 0
i = gm._frozen_param_count
while True:
qualname = f"_frozen_param{i}"
if not hasattr(gm, qualname):
break
i += 1
gm._frozen_param_count = i + 1
with g.inserting_before(node):
new_input_node = g.create_node("get_attr", qualname, (), {})
node.replace_all_uses_with(new_input_node)
new_input_node.meta.update(node.meta)
g.erase_node(node)
new_input_node.name = node.name
if isinstance(constant, torch.Tensor):
gm.register_buffer(qualname, constant)
maybe_set_is_frozen_param(constant)
else:
setattr(gm, qualname, constant)
class MetaConstantFolder(torch.fx.Interpreter):
"""
A simplified FX Interpreter that executes operations on 'meta' tensors
and records the results.
It identifies nodes that can be pre-computed (folded) because
all of their inputs are constants (either 'get_attr' nodes
pointing to meta tensors or other primitive constants like ints).
"""
def __init__(self, gm: GraphModule):
super().__init__(gm)
self.node_replacements: dict[Node, Any] = {}
self.unknown_value = object()
def run(self) -> None:
"""
Runs the interpreter over the graph.
"""
env: dict[Node, Any] = {}
for n in self.module.graph.find_nodes(op="placeholder"):
env[n] = self.unknown_value
super().run(initial_env=env)
def run_node(self, node: Node) -> Any:
"""
Executes a single node.
"""
if node.op == "placeholder":
return self.unknown_value
if node.op == "output":
return super().run_node(node)
args, kwargs = self.fetch_args_kwargs_from_env(node)
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
inputs_are_constants = True
for inp in flattened_inputs:
if inp is self.unknown_value:
inputs_are_constants = False
break
if isinstance(inp, torch.Tensor) and inp.device.type != "meta":
inputs_are_constants = False
break
if node.op == "get_attr":
val = super().run_node(node)
if isinstance(val, torch.Tensor) and val.device.type != "meta":
return self.unknown_value
return val
if not inputs_are_constants:
return self.unknown_value
try:
out = super().run_node(node)
except Exception:
return self.unknown_value
flattened_outputs = pytree.arg_tree_leaves(out)
for o in flattened_outputs:
if isinstance(o, torch.Tensor) and o.device.type != "meta":
return self.unknown_value
if node.op == "call_function":
self.node_replacements[node] = out
return out
def fold_meta_constants(gm: GraphModule) -> GraphModule:
    """
    Performs constant folding on a GraphModule with 'meta' device tensors.

    This function will:
    1. Run the MetaConstantFolder to find all foldable nodes.
    2. Replace foldable nodes with 'get_attr' nodes pointing to new
       attributes on ``gm``.
    3. Drop the attributes of replacements that ended up unused (for
       example an intermediate result whose only users were folded too).
    4. Clean up the graph by removing dead nodes, then lint and recompile.

    Args:
        gm: The GraphModule to process. It is modified in place.

    Returns:
        The same ``gm`` object, after folding and recompilation.
    """
    graph = gm.graph
    with torch.utils._python_dispatch._disable_current_modes():
        folder = MetaConstantFolder(gm)
        folder.run()

        # Snapshot the nodes present before rewriting so that the
        # 'get_attr' nodes created by this pass can be told apart from
        # the ones the caller already had.
        preexisting = set(graph.nodes)
        for node, constant in folder.node_replacements.items():
            replace_node_with_constant(gm, node, constant)

        # A replacement without users would be removed by dead-code
        # elimination anyway, but the attribute registered for it would
        # stay on the module. Remove both the node and the attribute.
        orphans = [
            n
            for n in graph.nodes
            if n.op == "get_attr" and n not in preexisting and not n.users
        ]
        orphan_targets = [n.target for n in orphans]
        for orphan in orphans:
            graph.erase_node(orphan)

        # Never delete an attribute that a surviving node still reads.
        live_targets = {n.target for n in graph.nodes if n.op == "get_attr"}
        for target in orphan_targets:
            if target not in live_targets and hasattr(gm, target):
                delattr(gm, target)

        graph.eliminate_dead_code()
        graph.lint()
        gm.recompile()
    return gm