from dataclasses import dataclass
import torch
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
from torch._subclasses import FakeTensor
from torch.fx import GraphModule, Node
from .op_emitter import _is_last2_transpose_tensor
# Short aliases for the operator namespaces referenced by the graph passes
# in this module.
aten = torch.ops.aten
prims = torch.ops.prims
def need_fallback_gm(gm: torch.fx.GraphModule) -> bool:
    """Report whether every ``call_function`` node is a pure shape op.

    Only ``aten.reshape.default`` and ``aten.expand.default`` count as shape
    ops here. Nodes of any other kind (placeholders, outputs, ...) are
    ignored, so a graph without ``call_function`` nodes also yields ``True``.

    Args:
        gm: Graph module whose nodes are inspected; it is not modified.

    Returns:
        ``False`` as soon as a ``call_function`` node with a different target
        is encountered, ``True`` otherwise.
    """
    shape_only_targets = (aten.reshape.default, aten.expand.default)
    return all(
        node.target in shape_only_targets
        for node in gm.graph.nodes
        if node.op == "call_function"
    )
def expand_dvm_mm_to_explicit_transpose_for_inductor(gm: torch.fx.GraphModule) -> bool:
    """Rewrite ``aten.mm`` + ``dvm_trans_*`` into explicit 2D ``permute(1,0)`` + ``mm``.

    MFusion IR roundtrip can emit a single ``torch.aten.mm`` with ``dvm_trans_a`` / ``dvm_trans_b``
    (fused ``aclnn.mm`` transpose flags). FakeTensor propagation uses a stub so ``meta`` is
    consistent, but Inductor lowering calls ``mm_args`` on **storage** shapes and fails
    ``guard_equals`` on the inner dimension. Inserting ``aten.permute`` (last-two dims swap for 2D)
    restores layouts Inductor expects. We use ``permute`` rather than ``transpose.int`` because
    some NPU Inductor builds register both a fallback and a decomposition for ``transpose.int``,
    which triggers ``AssertionError: both a fallback and a decomp for same op``.

    Only ``aten.mm.default`` nodes are rewritten, and only when both operands
    are graph nodes whose ``meta["val"]`` is a 2D tensor. Every inserted
    ``permute`` node receives its own ``meta["val"]`` (the permuted operand
    value) so that later passes reading operand metadata keep working.

    Args:
        gm: Graph module to rewrite in place.

    Returns:
        ``True`` if at least one ``mm`` node was rewritten (the graph is then
        linted and the module recompiled), ``False`` otherwise.
    """
    g = gm.graph
    changed = False
    swap_2d = (1, 0)

    def _insert_permute(operand: Node, operand_val: torch.Tensor) -> Node:
        # Build the explicit transpose and give it metadata describing the
        # swapped layout, mirroring what make_cast_node does for casts.
        permuted = g.call_function(aten.permute.default, (operand, list(swap_2d)))
        permuted.meta["val"] = operand_val.permute(swap_2d)
        return permuted

    # Snapshot the node list: new permute nodes are inserted while iterating.
    for n in list(g.nodes):
        if n.op != "call_function" or n.target != aten.mm.default:
            continue
        ta = bool(n.meta.get("dvm_trans_a", False))
        tb = bool(n.meta.get("dvm_trans_b", False))
        if not (ta or tb):
            continue
        if len(n.args) < 2:
            continue
        lhs, rhs = n.args[0], n.args[1]
        if not isinstance(lhs, Node) or not isinstance(rhs, Node):
            continue
        lv = lhs.meta.get("val")
        rv = rhs.meta.get("val")
        if not isinstance(lv, torch.Tensor) or not isinstance(rv, torch.Tensor):
            continue
        # The (1, 0) permutation is only meaningful for matrices.
        if lv.dim() != 2 or rv.dim() != 2:
            continue
        new_lhs, new_rhs = lhs, rhs
        with g.inserting_before(n):
            if ta:
                new_lhs = _insert_permute(lhs, lv)
            if tb:
                new_rhs = _insert_permute(rhs, rv)
        n.args = (new_lhs, new_rhs)
        # The transposes are explicit now; drop the fused flags so they are
        # not applied a second time.
        n.meta.pop("dvm_trans_a", None)
        n.meta.pop("dvm_trans_b", None)
        changed = True
    if changed:
        g.lint()
        gm.recompile()
    return changed
def annotate_mm_transpose_flags(gm: torch.fx.GraphModule):
    """Tag matmul-like nodes with ``trans_a`` / ``trans_b`` (and ``use_bias``).

    Handles ``aten.mm.default``, ``aten.bmm.default`` and
    ``aten.addmm.default`` nodes. For each one, ``node.meta["trans_a"]`` and
    ``node.meta["trans_b"]`` are written. An operand counts as transposed when

    * it is a placeholder whose ``meta["val"]`` satisfies
      ``_is_last2_transpose_tensor`` (the placeholder is then also marked
      with ``meta["trans"] = True``), or
    * its ``meta`` already carries a truthy ``"trans"`` entry, or
    * the matmul node itself carries a truthy ``dvm_trans_a`` /
      ``dvm_trans_b`` entry.

    ``addmm`` nodes additionally get ``meta["use_bias"] = True`` when the
    addend value is 1-D and both ``beta`` and ``alpha`` are 1 (their defaults
    when absent from ``kwargs``).

    Args:
        gm: Graph module whose node metadata is updated in place.

    Returns:
        ``True`` if at least one mm/bmm/addmm node was seen, else ``False``.
    """

    def _operand_is_transposed(operand) -> bool:
        # Placeholders are judged by their fake value and remembered via
        # meta["trans"]; everything else relies on an existing "trans" mark.
        if operand.op == "placeholder" and _is_last2_transpose_tensor(
            operand.meta["val"]
        ):
            operand.meta["trans"] = True
            return True
        return bool(getattr(operand, "meta", None) and operand.meta.get("trans"))

    plain_matmuls = (aten.mm.default, aten.bmm.default)
    found_matmul = False

    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue

        target = node.target
        if target in plain_matmuls:
            lhs, rhs = node.args[0], node.args[1]
        elif target is aten.addmm.default:
            addend, lhs, rhs = node.args[0], node.args[1], node.args[2]
            is_plain_bias = (
                addend.meta["val"].dim() == 1
                and node.kwargs.get("beta", 1) == 1
                and node.kwargs.get("alpha", 1) == 1
            )
            if is_plain_bias:
                node.meta["use_bias"] = True
        else:
            continue
        found_matmul = True

        # Start from "not transposed" for both operands, then refine.
        node.meta["trans_a"] = False
        node.meta["trans_b"] = False
        if _operand_is_transposed(lhs):
            node.meta["trans_a"] = True
        if _operand_is_transposed(rhs):
            node.meta["trans_b"] = True

        # Explicit dvm flags on the node always force the transpose flag on.
        for dvm_key, flag_key in (
            ("dvm_trans_a", "trans_a"),
            ("dvm_trans_b", "trans_b"),
        ):
            if node.meta.get(dvm_key):
                node.meta[flag_key] = True

    return found_matmul
def make_cast_node(g, src: Node, target_dtype: torch.dtype) -> Node:
    """Create a ``prims.convert_element_type`` node casting ``src``.

    The node is created at the graph's current insertion point, so callers
    are expected to select the position (e.g. ``g.inserting_before``) first.

    Args:
        g: Graph that receives the new node.
        src: Node whose value is cast; its ``meta["val"]`` must be present.
        target_dtype: Dtype to convert to.

    Returns:
        The new cast node, with ``meta["val"]`` set to the source value
        converted to ``target_dtype``.
    """
    converted = g.call_function(
        prims.convert_element_type.default, args=(src, target_dtype)
    )
    source_val = src.meta["val"]
    converted.meta["val"] = source_val.to(dtype=target_dtype)
    return converted
def decompose_k1_matmul_to_mul(gm: GraphModule) -> GraphModule:
    """Replace ``mm`` / ``bmm`` nodes whose contraction size is 1 by ``mul``.

    A node is rewritten only if both operands are graph nodes carrying
    ``FakeTensor`` values, neither value satisfies
    ``_is_last2_transpose_tensor``, and the contracted dimensions
    (``lhs.shape[-1]`` and ``rhs.shape[-2]``) are both the static integer 1.
    Symbolic sizes are left untouched. The replacement is an
    ``aten.mul.Tensor`` of the same operands, which relies on broadcasting,
    and it inherits the matmul's ``meta["val"]``.

    Args:
        gm: Graph module to rewrite in place.

    Returns:
        The same ``gm``; it is linted and recompiled only if something changed.
    """
    graph = gm.graph
    matmul_targets = (aten.mm.default, aten.bmm.default)
    rewrote_any = False

    # Walk a snapshot because matching nodes are erased along the way.
    for mm_node in list(graph.nodes):
        if mm_node.op != "call_function" or mm_node.target not in matmul_targets:
            continue

        lhs, rhs = mm_node.args[:2]
        if not (isinstance(lhs, Node) and isinstance(rhs, Node)):
            continue

        lhs_val = lhs.meta.get("val")
        rhs_val = rhs.meta.get("val")
        if not (isinstance(lhs_val, FakeTensor) and isinstance(rhs_val, FakeTensor)):
            continue
        if _is_last2_transpose_tensor(lhs_val) or _is_last2_transpose_tensor(rhs_val):
            continue

        contraction_dims = (lhs_val.shape[-1], rhs_val.shape[-2])
        # Only statically known sizes qualify; symbolic ones are skipped.
        if any(isinstance(dim, torch.SymInt) for dim in contraction_dims):
            continue
        if not all(dim == 1 for dim in contraction_dims):
            continue

        with graph.inserting_before(mm_node):
            product = graph.call_function(aten.mul.Tensor, args=(lhs, rhs))
            product.meta["val"] = mm_node.meta["val"]
        mm_node.replace_all_uses_with(product)
        graph.erase_node(mm_node)
        rewrote_any = True

    if rewrote_any:
        graph.lint()
        gm.recompile()
    return gm
def insert_sum_fp32_prepost_cast_prims(gm: GraphModule):
    """Make ``aten.sum`` consume float32 input and cast the result back.

    For every ``aten.sum.default`` / ``aten.sum.dim_IntList`` node whose
    input value is a ``FakeTensor`` with a dtype other than ``float32``:

    1. a ``prims.convert_element_type`` cast to ``float32`` is inserted in
       front of the sum and becomes its first argument;
    2. the sum node's ``meta["val"]`` is refreshed to ``float32`` unless an
       explicit ``dtype`` keyword argument pins the output dtype, so the
       metadata stays in sync with the rewritten input;
    3. if the original output dtype was not ``float32``, a cast back to that
       dtype is inserted after the sum and all former users are redirected
       to it.

    Nodes without a ``FakeTensor`` in ``meta["val"]`` (on the sum or on its
    input) are skipped.

    Args:
        gm: Graph module to rewrite in place.

    Returns:
        The same ``gm`` after linting and recompiling.
    """
    g = gm.graph
    sum_targets = (aten.sum.default, aten.sum.dim_IntList)

    # Iterate over a snapshot: cast nodes are inserted while walking the
    # graph, same as the other mutating passes in this module.
    for node in list(g.nodes):
        if node.op != "call_function" or node.target not in sum_targets:
            continue
        out_val = node.meta.get("val", None)
        if not isinstance(out_val, FakeTensor):
            continue
        orig_out_dtype = out_val.dtype
        if not node.args:
            continue
        x = node.args[0]
        if not isinstance(x, Node):
            continue
        in_val = x.meta.get("val", None)
        if not isinstance(in_val, FakeTensor):
            continue
        if in_val.dtype == torch.float32:
            continue

        with g.inserting_before(node):
            x_fp32 = make_cast_node(g, x, torch.float32)
        node.args = (x_fp32, *node.args[1:])

        # Without an explicit ``dtype`` kwarg the sum now yields float32;
        # keep its metadata consistent with the new input.
        if node.kwargs.get("dtype") is None:
            node.meta["val"] = out_val.to(dtype=torch.float32)

        if orig_out_dtype != torch.float32:
            with g.inserting_after(node):
                y = make_cast_node(g, node, orig_out_dtype)
            node.replace_all_uses_with(y)
            # replace_all_uses_with also rewired the cast's own input to
            # itself; point it back at the sum node.
            y.args = (node, orig_out_dtype)

    g.lint()
    gm.recompile()
    return gm
@dataclass(frozen=True)
class PromoteRule:
    """Immutable description of how an operator's arguments are promoted."""

    # Positional argument indices that take part in dtype promotion.
    pos: tuple[int, ...]
    # Promotion kind handed to ``elementwise_dtypes``.
    kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
# Operators whose tensor operands must share one dtype, mapped to the rule
# describing which positional arguments are promoted and with which kind.
PROMOTE_TYPE_OP = {
    # Arithmetic: default elementwise promotion.
    aten.add.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    aten.sub.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    aten.mul.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    aten.div.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    aten.pow.Tensor_Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    # Comparisons: promotion kind ALWAYS_BOOL.
    aten.lt.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    aten.le.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    aten.gt.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    aten.ge.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    aten.eq.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    aten.ne.Tensor: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    # Elementwise min/max: default elementwise promotion.
    aten.maximum.default: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    aten.minimum.default: PromoteRule(
        pos=(0, 1), kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
}
def insert_promote_cast_by_pos_prims(gm: GraphModule) -> GraphModule:
    """Insert explicit casts so promoted operands share one compute dtype.

    For each ``call_function`` node whose target has a rule in
    ``PROMOTE_TYPE_OP``, the positional arguments named by the rule are
    collected when they are graph nodes carrying a ``FakeTensor`` value. If
    at least two such operands exist and their dtypes differ, the compute
    dtype is obtained from ``elementwise_dtypes`` with the rule's promotion
    kind, and every operand with another dtype is routed through a
    ``prims.convert_element_type`` cast inserted before the node.

    Args:
        gm: Graph module to rewrite in place.

    Returns:
        The same ``gm`` after linting and recompiling.
    """
    graph = gm.graph

    for node in graph.nodes:
        if node.op != "call_function":
            continue
        rule = PROMOTE_TYPE_OP.get(node.target)
        if rule is None:
            continue

        # Gather the tensor operands taking part in promotion, by position.
        operands_by_pos = {}
        fake_vals = []
        for pos in rule.pos:
            if pos >= len(node.args):
                continue
            candidate = node.args[pos]
            if not isinstance(candidate, Node):
                continue
            fake = candidate.meta.get("val")
            if not isinstance(fake, FakeTensor):
                continue
            fake_vals.append(fake)
            operands_by_pos[pos] = candidate

        # Nothing to reconcile with fewer than two tensors or a single dtype.
        if len(fake_vals) <= 1:
            continue
        if len({fake.dtype for fake in fake_vals}) == 1:
            continue

        compute_dtype = elementwise_dtypes(
            *fake_vals, type_promotion_kind=rule.kind
        )[0]

        promoted_args = list(node.args)
        with graph.inserting_before(node):
            for pos, operand in operands_by_pos.items():
                if operand.meta["val"].dtype != compute_dtype:
                    promoted_args[pos] = make_cast_node(graph, operand, compute_dtype)
        node.args = tuple(promoted_args)

    graph.lint()
    gm.recompile()
    return gm
def expand_to_reshape(gm: GraphModule) -> GraphModule:
for node in gm.graph.find_nodes(op="call_function", target=aten.expand.default):
x = node.args[0]
in_val = x.meta.get("val")
out_val = node.meta.get("val")
if tuple(in_val.shape) == tuple(out_val.shape):
node.target = aten.reshape.default