from collections import defaultdict
from types import SimpleNamespace
import torch
from torch.fx import Graph, GraphModule, Node
from torch.library import Library
from torch.utils._pytree import tree_map
from torch._inductor import config as inductor_config
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.virtualized import V
from torch._subclasses import FakeTensor
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.tools_common import legalize_graph
from torch.fx.passes.utils.fuser_utils import (
topo_sort,
fuse_as_graphmodule,
erase_nodes,
)
from .graph_build import DvmCodegenInterpreter
from .load_codegen import patch_gm_placeholder_strides_from_codegen_args
from .fx_test import generate_dvm_fx_case
from .op_emitter import DVM_OP_REGISTRY
from .fx_pass import (
decompose_k1_matmul_to_mul,
insert_promote_cast_by_pos_prims,
insert_sum_fp32_prepost_cast_prims,
expand_to_reshape,
need_fallback_gm,
)
dump_fx_test = False
uncont_policy = "fuse"
aten = torch.ops.aten
prims = torch.ops.prims
GRAPH_FUSION_SUPPORT_OP = [
aten.add.Tensor,
aten.add.Scalar,
aten.sub.Tensor,
aten.sub.Scalar,
aten.mul.Tensor,
aten.mul.Scalar,
aten.div.Tensor,
aten.div.Scalar,
aten.pow.Tensor_Tensor,
aten.pow.Tensor_Scalar,
aten.pow.Scalar,
aten.lt.Tensor,
aten.lt.Scalar,
aten.le.Tensor,
aten.le.Scalar,
aten.gt.Tensor,
aten.gt.Scalar,
aten.ge.Tensor,
aten.ge.Scalar,
aten.eq.Tensor,
aten.eq.Scalar,
aten.ne.Tensor,
aten.ne.Scalar,
aten.maximum.default,
aten.minimum.default,
aten.sqrt.default,
aten.rsqrt.default,
aten.abs.default,
aten.log.default,
aten.exp.default,
aten.reciprocal.default,
aten.isfinite.default,
prims.convert_element_type.default,
torch.ops.npu.npu_dtype_cast.default,
torch.ops.npu.npu_dtype_cast_backward.default,
torch.ops.npu._npu_dtype_cast.default,
torch.ops.npu._npu_dtype_cast_backward.default,
aten.sum.dim_IntList,
aten.sum.default,
aten.neg.default,
aten.relu.default,
aten.mm.default,
aten.bmm.default,
aten.addmm.default,
aten.where.default,
aten.where.self,
]
class UnionFind:
def __init__(self) -> None:
self.parent: dict[Node, Node] = {}
self.rank: dict[Node, int] = {}
def find(self, x: Node) -> Node:
p = self.parent.get(x, x)
if p != x:
self.parent[x] = self.find(p)
else:
self.parent[x] = x
return self.parent[x]
def union(self, a: Node, b: Node) -> None:
ra, rb = self.find(a), self.find(b)
if ra == rb:
return
rka = self.rank.get(ra, 0)
rkb = self.rank.get(rb, 0)
if rka < rkb:
ra, rb = rb, ra
rka, rkb = rkb, rka
self.parent[rb] = ra
if rka == rkb:
self.rank[ra] = rka + 1
class DvmOpSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node):
if node.op == "call_function" and node.target in GRAPH_FUSION_SUPPORT_OP:
_, rule = DVM_OP_REGISTRY.get(node.target)
return rule(node)
return False
def split_partition_with_union_find(
partition_nodes: dict[Node, None],
) -> list[dict[Node, None]]:
"""
Split a partition into connected components based on data-dependency edges
within the partition: if u is an input of v and both are in partition,
they belong to the same component.
"""
nodes = list(partition_nodes.keys())
node_set = set(nodes)
uf = UnionFind()
for n in nodes:
uf.find(n)
for v in nodes:
for u in v.all_input_nodes:
if u in node_set:
uf.union(u, v)
groups: dict[Node, dict[Node, None]] = defaultdict(dict)
for n in nodes:
root = uf.find(n)
groups[root][n] = None
return list(groups.values())
class _FusedMeta:
def __init__(self, name: str, gm: GraphModule, input_nodes: list[Node]):
self.name = name
self.gm = gm
def codegen(self):
"""
Return (codegen_interpreter, python_source_string).
"""
if dump_fx_test:
generate_dvm_fx_case(self.gm, fusion_type="graph")
cg = DvmCodegenInterpreter(self.gm, ktype="split")
cg.run()
code = cg.code.getvalue().replace(cg.KERNEL_NAME_PLACEHOLDER, self.name)
return cg, code
_fused_metas: dict[int, _FusedMeta] = {}
_fused_libs: dict[str, Library] = {}
def _fused_run_stub(*args, **kwargs):
raise AssertionError("This op should never run at eager runtime.")
def _fused_run_fake(*args, **kwargs):
"""
Fake implementation for custom_op. The last arg is fused_id.
"""
fused_id = int(args[-1])
meta = _fused_metas[fused_id]
out = meta.gm.forward(*args[:-1])
return tree_map(lambda t: t if t.is_contiguous() else t.contiguous(), out)
class GraphFusionPartitioner(CapabilityBasedPartitioner):
fused_op_map_: dict[str, object] = {}
fused_id_: int = 0
def _input_schema_types(self, input_nodes: list[Node]) -> list[str]:
input_types: list[str] = []
for node in input_nodes:
val = node.meta.get("val", None)
if isinstance(val, torch.SymInt):
input_types.append("SymInt")
elif isinstance(val, torch.SymFloat):
input_types.append("SymFloat")
else:
input_types.append("Tensor")
return input_types
def _build_schema(self, input_types: list[str], output_len: int) -> str:
extra_inputs = "int fused_id"
input_schema = ", ".join(
f"{input_type} x{i}" for i, input_type in enumerate(input_types)
)
input_schema = (
f"{input_schema}, {extra_inputs}" if input_schema else extra_inputs
)
if output_len == 1:
output_schema = "Tensor"
else:
output_schema = f'({", ".join(["Tensor"] * output_len)})'
return f"({input_schema}) -> {output_schema}"
def _get_or_create_custom_op(self, input_nodes: list[Node], output_len: int):
input_types = self._input_schema_types(input_nodes)
type_suffix = "_".join(
"t" if t == "Tensor" else "si" if t == "SymInt" else "sf"
for t in input_types
)
input_len = len(input_types)
op_def = f"{input_len}_{output_len}"
if any(t != "Tensor" for t in input_types):
op_def = f"{op_def}_{type_suffix}"
custom = self.fused_op_map_.get(op_def, None)
if custom is not None:
return custom
schema = self._build_schema(input_types, output_len)
op_name = "fused_graph_" + op_def
qualname = "dvm::" + op_name
lib = Library("dvm", "FRAGMENT")
lib.define(op_name + schema, tags=[torch._C.Tag.flexible_layout])
lib._register_fake(op_name, _fused_run_fake, _stacklevel=1)
lib.impl(op_name, _fused_run_stub, "CompositeExplicitAutograd")
opoverload = getattr(getattr(torch.ops.dvm, op_name), "default")
custom = SimpleNamespace(_opoverload=opoverload, _lib=lib, _qualname=qualname)
_fused_libs[qualname] = lib
self.fused_op_map_[op_def] = custom
return custom
def _should_fuse(
self,
sub_gm: GraphModule,
orig_outputs: list[Node],
orig_inputs: list[Node],
) -> bool:
if len(orig_outputs) == 0:
return False
if need_fallback_gm(sub_gm):
return False
return any(
isinstance(node.meta.get("val", None), FakeTensor) for node in orig_inputs
)
def partition_and_fuse(self) -> GraphModule:
partitions = self.propose_partitions()
partition_nodes_list: list[dict[Node, None]] = [
sp_nodes
for partition in partitions
for sp_nodes in split_partition_with_union_find(partition.nodes)
]
for partition_nodes in partition_nodes_list:
fused_id = self.fused_id_
self.fused_id_ += 1
fused_name = f"dvm_graph_fused_{fused_id}"
sorted_nodes = topo_sort(list(partition_nodes))
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
self.graph_module,
sorted_nodes,
fused_name,
partition_nodes,
)
decompose_k1_matmul_to_mul(sub_gm)
insert_promote_cast_by_pos_prims(sub_gm)
insert_sum_fp32_prepost_cast_prims(sub_gm)
expand_to_reshape(sub_gm)
if not self._should_fuse(sub_gm, orig_outputs, orig_inputs):
continue
_fused_metas[fused_id] = _FusedMeta(fused_name, sub_gm, orig_inputs)
output_len = len(orig_outputs)
custom = self._get_or_create_custom_op(orig_inputs, output_len)
args = (*orig_inputs, fused_id)
new_node = self.graph_module.graph.call_function(
custom._opoverload, tuple(args), None
)
new_meta_vals = []
with V.fake_mode:
for orig_output in orig_outputs:
meta_val = orig_output.meta["val"]
new_meta_vals.append(
torch.empty(
meta_val.size(),
dtype=meta_val.dtype,
device=meta_val.device,
requires_grad=meta_val.requires_grad,
)
)
if output_len == 1:
orig_outputs[0].replace_all_uses_with(new_node)
new_node.meta["val"] = new_meta_vals[0]
else:
for i, (orig_output, meta_val) in enumerate(
zip(orig_outputs, new_meta_vals)
):
proxy_out = torch.fx.Proxy(new_node)[i].node
proxy_out.meta["val"] = meta_val
orig_output.replace_all_uses_with(proxy_out)
new_node.meta["val"] = tuple(new_meta_vals)
erase_nodes(self.graph_module, sorted_nodes)
legalize_graph(self.graph_module)
return self.graph_module
def dvm_graph_fusion(graph: Graph):
gm: GraphModule = graph.owning_module
dvm_support = DvmOpSupport()
fusion_part = GraphFusionPartitioner(
gm,
dvm_support,
allows_single_node_partition=True,
)
fusion_part.partition_and_fuse()
def _dvm_generate_fallback_kernel(self, fallback_kernel, args):
"""
Patch point: PythonWrapperCodegen.generate_fallback_kernel
If it's our custom op, pop meta and emit dvm kernel code.
"""
if not fallback_kernel.op_overload._name.startswith("dvm::fused_graph_"):
self.generate_extern_kernel_alloc(fallback_kernel, args)
return
fused_id = int(args[-1])
meta = _fused_metas.pop(fused_id)
args_list = list(args[:-1])
patch_gm_placeholder_strides_from_codegen_args(meta.gm, args_list)
cg, code = meta.codegen()
self.header.splice(code)
buf_name = fallback_kernel.get_name()
for i, no_trans in enumerate(cg.cont_flag_input):
if not no_trans:
args_list[i] += ".contiguous()"
for i, trans in enumerate(cg.need_trans_input):
if trans:
args_list[i] += ".mT"
self.writeline(f"{buf_name} = {meta.name}({', '.join(args_list)})")
self.add_import_once("from torch_npu._inductor import dvm")
class DvmGraphFusionPatch:
_enabled = False
_orig_generate_fallback_kernel = None
_orig_post_grad_custom_post_pass = None
@staticmethod
def enable() -> None:
if not DvmGraphFusionPatch._enabled:
DvmGraphFusionPatch._orig_generate_fallback_kernel = (
PythonWrapperCodegen.generate_fallback_kernel
)
DvmGraphFusionPatch._orig_post_grad_custom_post_pass = (
inductor_config.post_grad_custom_post_pass
)
PythonWrapperCodegen.generate_fallback_kernel = (
_dvm_generate_fallback_kernel
)
inductor_config.post_grad_custom_post_pass = dvm_graph_fusion
DvmGraphFusionPatch._enabled = True
@staticmethod
def disable() -> None:
if not DvmGraphFusionPatch._enabled:
return
PythonWrapperCodegen.generate_fallback_kernel = (
DvmGraphFusionPatch._orig_generate_fallback_kernel
)
inductor_config.post_grad_custom_post_pass = (
DvmGraphFusionPatch._orig_post_grad_custom_post_pass
)
DvmGraphFusionPatch._enabled = False
def __enter__(self) -> "DvmGraphFusionPatch":
DvmGraphFusionPatch.enable()
return self
def __exit__(self, exc_type, exc, tb) -> bool:
DvmGraphFusionPatch.disable()
return False