import heapq
import operator
import torch
import torch.fx
def stable_topo_sort(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Similar to torch.fx.passes.tools_common.legalize_graph but preserve
the original order in the graph as much as possible.
"""
PRIORITIZED_OPS = [
operator.add,
operator.mul,
operator.sub,
operator.floordiv,
operator.truediv,
operator.mod,
operator.le,
operator.lt,
operator.ge,
operator.gt,
operator.eq,
operator.ne,
torch.ops.aten.sym_constrain_range.default,
torch.ops.aten.sym_constrain_range_for_size.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten.scalar_tensor.default,
torch.ops.aten._assert_scalar.default,
]
PRIORITIZED_OPS_SET = set(PRIORITIZED_OPS)
original_order = {node: i for i, node in enumerate(gm.graph.nodes)}
indeg = dict.fromkeys(gm.graph.nodes, 0)
new_graph = torch.fx.Graph()
for node in gm.graph.nodes:
for user in node.users:
indeg[user] += 1
queue: list[tuple[int, int, torch.fx.Node]] = []
for node in gm.graph.nodes:
if indeg[node] == 0:
op_priority = 0 if node.op == "call_function" and node.target in PRIORITIZED_OPS_SET else 1
original_index = original_order[node]
heapq.heappush(queue, (op_priority, original_index, node))
env: dict[torch.fx.Node, torch.fx.Node] = {}
while len(queue) > 0:
op_prio, orig_idx, cur = heapq.heappop(queue)
if cur in env:
continue
env[cur] = new_graph.node_copy(cur, lambda x: env[x])
for user in cur.users:
indeg[user] -= 1
if indeg[user] == 0:
op_priority = 0 if user.op == "call_function" and user.target in PRIORITIZED_OPS_SET else 1
original_index = original_order[user]
heapq.heappush(queue, (op_priority, original_index, user))
if len(new_graph.nodes) < len(gm.graph.nodes):
raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
new_graph._codegen = gm.graph._codegen
gm.graph = new_graph
return gm