import logging
import operator
from typing import Any, Dict, Tuple
import torch
import torch.fx as fx
from ... import ops
from ..pass_base import TensorCastGraphModulePass
logger = logging.getLogger(__name__)
class LiftCombineQuantPass(TensorCastGraphModulePass):
    """
    An FX graph pass that lifts quantize operations as early as possible and
    combines identical quantization operations into one.

    This makes later fusions possible such as rms_norm+quant fusion.

    The pass works as follows:
    1. For every op in `_QUANTIZE_OPS`, finds all calls of that op in the graph.
    2. For each quantize call, it traces its input backward through a chain
       of single-user "swappable" ops (the ops in `_SWAPPABLE_OPS`, i.e. view
       and reshape).
    3. It determines the true, pre-view-change input tensor.
    4. It uses a cache to see if this tensor has already been quantized with the
       same remaining args/kwargs (e.g. scale/offset).
        - If yes, it reuses the existing quantized tensor.
        - If no, it inserts a new quantize op after the true input tensor (and
          after any other input of the quantize call that is defined later in
          the graph) and adds it to the cache.
    5. It rebuilds the chain of swappable ops on top of the (new or cached)
       quantized tensor, reusing identical rebuilt ops through the same cache.
    6. It replaces the original quantize node's uses with the final node of
       the rebuilt chain. For multi-output quantize ops (consumed through
       `operator.getitem`), output 0 is routed through the rebuilt chain and
       every other output is read directly from the lifted quantize node.
    7. Finally, it removes all the old, now-unused nodes.
    """

    # Ops a quantize call may be hoisted above.
    _SWAPPABLE_OPS = {
        torch.ops.aten.view.default,
        torch.ops.aten.reshape.default,
    }
    # Quantize overloads handled by this pass; each is processed with its own cache.
    _QUANTIZE_OPS = [
        torch.ops.tensor_cast.quantize.default,
        torch.ops.tensor_cast.dynamic_quantize_asymmetric.default,
        torch.ops.tensor_cast.dynamic_quantize_symmetric.default,
        torch.ops.tensor_cast.dynamic_quantize_mxfp4.default,
    ]

    @staticmethod
    def _is_getitem(node: fx.Node) -> bool:
        """Return True if `node` is a `call_function` of `operator.getitem`."""
        return node.op == "call_function" and node.target == operator.getitem

    def _is_swappable(self, node: Any) -> bool:
        """Return True if `node` is a graph node calling one of `_SWAPPABLE_OPS`."""
        return (
            isinstance(node, fx.Node)
            and node.op == "call_function"
            and node.target in self._SWAPPABLE_OPS
        )

    @classmethod
    def _getitem_users(cls, node: fx.Node) -> Dict[Any, fx.Node]:
        """Map each output index read from `node` to its first getitem user."""
        found: Dict[Any, fx.Node] = {}
        for user in node.users:
            if cls._is_getitem(user):
                found.setdefault(user.args[1], user)
        return found

    @staticmethod
    def _other_inputs(node: fx.Node) -> list:
        """Return the input nodes of `node` other than its first positional arg."""
        primary = node.args[0]
        return [dep for dep in node.all_input_nodes if dep is not primary]

    @staticmethod
    def _insertion_anchor(start: fx.Node, deps, stop: fx.Node) -> fx.Node:
        """
        Pick the node after which a new node consuming `start` and `deps` can go.

        Walks forward from `start` (never past `stop`) and returns the last
        node of `deps` met on the way, or `start` when none of them is defined
        after it. Inserting right after `start` unconditionally would place
        the new node before any input that is defined later in the graph.
        """
        pending = set(deps)
        pending.discard(start)
        anchor = start
        cursor = start.next
        # Only walk while there is still an input whose position is unknown.
        while pending and cursor is not stop and cursor.op not in ("output", "root"):
            if cursor in pending:
                pending.discard(cursor)
                anchor = cursor
            cursor = cursor.next
        return anchor

    def _lift_quantize(
        self,
        graph: fx.Graph,
        original_quant_node: fx.Node,
        node_cache: Dict[Tuple[Any, Any, Any], fx.Node],
    ) -> None:
        """
        Lift one quantize call above its swappable producers and rewire its users.

        Args:
            graph: The graph that owns `original_quant_node`; new nodes are
                inserted into it.
            original_quant_node: The quantize call to lift. It is left in the
                graph (without users) for dead-code elimination to remove.
            node_cache: Maps `(target, args, kwargs)` of nodes created by this
                pass to the created node, so identical nodes are shared.
        """
        # A quantize op is treated as multi-output when it is read via getitem.
        is_multi_output = any(
            self._is_getitem(user) for user in original_quant_node.users
        )

        # Walk up through view/reshape producers that feed nothing else.
        current_input = original_quant_node.args[0]
        swappable_ops_chain = []
        while self._is_swappable(current_input) and len(current_input.users) <= 1:
            swappable_ops_chain.append(current_input)
            current_input = current_input.args[0]
        if not isinstance(current_input, fx.Node):
            # Nothing to anchor a lifted node to; leave this call untouched.
            return

        lifted_args = (current_input, *original_quant_node.args[1:])
        kwargs = original_quant_node.kwargs
        cache_key = (original_quant_node.target, lifted_args, kwargs)
        lifted_quant_node = node_cache.get(cache_key)
        if lifted_quant_node is None:
            anchor = self._insertion_anchor(
                current_input,
                self._other_inputs(original_quant_node),
                original_quant_node,
            )
            with graph.inserting_after(anchor):
                lifted_quant_node = graph.call_function(
                    original_quant_node.target,
                    args=lifted_args,
                    kwargs=kwargs,
                )
            node_cache[cache_key] = lifted_quant_node

        lifted_outputs: Dict[Any, fx.Node] = {}
        if is_multi_output:
            # Reuse getitem nodes already hanging off the lifted node and create
            # any that are missing. A cached lifted node may have been built for
            # a call that read fewer outputs than this one does.
            lifted_outputs = self._getitem_users(lifted_quant_node)
            wanted = [0]  # output 0 is always needed as the base of the chain
            wanted.extend(
                user.args[1]
                for user in original_quant_node.users
                if self._is_getitem(user)
            )
            for idx in wanted:
                if idx not in lifted_outputs:
                    with graph.inserting_after(lifted_quant_node):
                        lifted_outputs[idx] = graph.call_function(
                            operator.getitem,
                            args=(lifted_quant_node, idx),
                        )
            quantized_tensor_node = lifted_outputs[0]
        else:
            quantized_tensor_node = lifted_quant_node

        # Re-apply the skipped ops, outermost producer first, on the quantized tensor.
        final_node = quantized_tensor_node
        for swappable_node in reversed(swappable_ops_chain):
            rebuilt_args = (final_node, *swappable_node.args[1:])
            key = (swappable_node.target, rebuilt_args, swappable_node.kwargs)
            rebuilt = node_cache.get(key)
            if rebuilt is None:
                anchor = self._insertion_anchor(
                    final_node,
                    self._other_inputs(swappable_node),
                    original_quant_node,
                )
                with graph.inserting_after(anchor):
                    rebuilt = graph.call_function(
                        swappable_node.target,
                        args=rebuilt_args,
                        kwargs=swappable_node.kwargs,
                    )
                node_cache[key] = rebuilt
            final_node = rebuilt

        if is_multi_output:
            for user in list(original_quant_node.users):
                if not self._is_getitem(user):
                    continue
                idx = user.args[1]
                # Only output 0 goes through the rebuilt view/reshape chain.
                replacement = final_node if idx == 0 else lifted_outputs[idx]
                user.replace_all_uses_with(replacement)
        else:
            original_quant_node.replace_all_uses_with(final_node)

    def __call__(self, gm: fx.GraphModule) -> fx.GraphModule:
        """
        Run the pass on `gm` in place.

        Args:
            gm: The graph module whose graph is rewritten.

        Returns:
            The same `gm`, recompiled after dead nodes were eliminated.
        """
        logger.debug("Running LiftCombineQuantPass.........")
        graph = gm.graph
        for quantize_op in self._QUANTIZE_OPS:
            # The cache is per quantize op; entries are never shared across ops.
            node_cache: Dict[Tuple[Any, Any, Any], fx.Node] = {}
            for original_quant_node in graph.find_nodes(
                op="call_function", target=quantize_op
            ):
                self._lift_quantize(graph, original_quant_node, node_cache)
        graph.eliminate_dead_code()
        gm.recompile()
        return gm