import torch
import torch.utils._pytree as pytree
from torch.fx.node import Argument, Target
from torch._inductor.utils import IndentedBuffer
from .op_emitter import DVM_OP_REGISTRY, load, store, view_load
from .fx_pass import annotate_mm_transpose_flags
# Short alias for the ATen operator namespace; used below to match FX node targets.
aten = torch.ops.aten
def is_fx_dynamic(graph):
for node in graph.graph.nodes:
if node.op == "placeholder" or node.op == "call_function":
if isinstance(node.meta["val"], torch.Tensor):
if any(isinstance(dim, torch.SymInt) for dim in node.meta["val"].shape):
return True
elif isinstance(node.meta["val"], (torch.SymInt, torch.SymFloat)):
return True
return False
class DvmCodegenInterpreter(torch.fx.Interpreter):
    """Emit the Python source of a DVM kernel by interpreting an FX graph.

    Node handlers do not compute values; each returns a source-code
    expression string. ``run_node`` writes ``<node> = <expr>`` into
    ``self.code`` and returns the node's name, so later nodes receive names
    of earlier results as their arguments. The emitted text forms the body of
    ``def __DVM_KERNEL_NAME__(k):``.
    """

    KERNEL_NAME_PLACEHOLDER = "__DVM_KERNEL_NAME__"

    # Reduction targets that, when their result feeds another call_function
    # node, are recorded in ``spec_nodes`` (see ``need_spec``).
    _REDUCTION_TARGETS = (
        torch.ops.aten.sum.default,
        torch.ops.aten.sum.dim_IntList,
        torch.ops.aten.amax.default,
        torch.ops.aten.amin.default,
    )
    # Matmul targets that receive the two transpose flags as extra arguments.
    _MATMUL_TARGETS = (
        torch.ops.aten.mm.default,
        torch.ops.aten.bmm.default,
    )

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        ktype: str,
        uncont_policy: str = "fuse",
    ):
        """Prepare state and write the kernel header into ``self.code``.

        Args:
            gm: Graph module to translate.
            ktype: Kernel type written into the ``dvm.kernel`` decorator.
                ``"vector"`` is replaced by ``"spec"`` when ``need_spec()``
                reports True.
            uncont_policy: ``"fuse"`` allows eligible non-contiguous inputs to
                be emitted through ``view_load``; any other value disables it.
        """
        super().__init__(gm)
        self.gm = gm
        self.ktype = ktype
        # Annotation must run before the graph is inspected any further.
        self.is_mix_kernel = annotate_mm_transpose_flags(gm)
        self.is_dynamic = is_fx_dynamic(gm)
        self.current_node = None
        # One flag per placeholder, in visit order.
        self.cont_flag_input = []
        # One flag per *tensor* placeholder, in visit order.
        self.need_trans_input = []
        self.use_view = uncont_policy == "fuse"
        self.code = IndentedBuffer()
        self.spec_nodes = set()
        if self.ktype == "vector" and self.need_spec():
            self.ktype = "spec"

        # Header: readable graph dump as a docstring-like literal, the
        # decorator line, then the function signature.
        self.code.splice(
            f'\n"""\n{self.gm.print_readable(print_output=False)}\n"""')
        # chr(64) is the "@" character.
        decorator = (
            f"{chr(64)}dvm.kernel(ktype={self.ktype!r}, dyn_shape={self.is_dynamic})"
        )
        self.code.splice(decorator)
        self.code.splice(f"def {self.KERNEL_NAME_PLACEHOLDER}(k):")
        self.code.do_indent()

    @staticmethod
    def _static_dims(dims):
        """Return *dims* as a list with every ``torch.SymInt`` replaced by -1."""
        return [-1 if isinstance(d, torch.SymInt) else d for d in dims]

    def need_spec(self) -> bool:
        """Collect reduction nodes whose result is consumed by another op.

        ``self.spec_nodes`` is reset and refilled with every ``call_function``
        node that targets one of ``_REDUCTION_TARGETS`` and is an input of
        another ``call_function`` node.

        Returns:
            True if at least one such node exists.
        """
        self.spec_nodes.clear()
        for node in self.gm.graph.nodes:
            if node.op != "call_function":
                continue
            for producer in node.all_input_nodes:
                if (
                    producer.op == "call_function"
                    and producer.target in self._REDUCTION_TARGETS
                ):
                    self.spec_nodes.add(producer)
        return bool(self.spec_nodes)

    def run_node(self, n: torch.fx.Node) -> Argument:
        """Emit the source line(s) for node *n* and return its name.

        Output nodes splice every leaf expression as-is; every other node is
        emitted as ``<name> = <expr>``, followed by ``k.spec_next()`` when the
        node was recorded by ``need_spec``.
        """
        self.current_node = n
        expr = super().run_node(n)
        if n.op == "output":
            for leaf in pytree.tree_leaves(expr):
                self.code.splice(f"{leaf}")
        else:
            self.code.splice(f"{n} = {expr}")
            if n in self.spec_nodes:
                self.code.splice("k.spec_next()")
        return f"{n}"

    def placeholder(
        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Argument]
    ) -> Argument:
        """Return the load expression for the current graph input.

        Appends one entry to ``cont_flag_input`` for every placeholder and one
        entry to ``need_trans_input`` for every tensor placeholder.

        Raises:
            KeyError: If the current node has no ``"val"`` meta entry.
        """
        meta = self.current_node.meta
        val = meta["val"]

        # Symbolic scalars become kernel scalar arguments.
        # NOTE(review): these two paths add to cont_flag_input but not to
        # need_trans_input, so the two lists differ in length when scalar
        # inputs exist — confirm the consumer expects that.
        if isinstance(val, torch.SymInt):
            self.cont_flag_input.append(True)
            return "k.scalar(dvm.int64)"
        if isinstance(val, torch.SymFloat):
            self.cont_flag_input.append(True)
            return "k.scalar(dvm.float32)"

        is_contiguous = val.is_contiguous()
        dtype = val.dtype
        is_symbolic = any(
            isinstance(s, torch.SymInt) and s.node.is_symbolic() for s in val.shape
        )
        # Symbolic extents are emitted as -1.
        shape = self._static_dims(val.shape)
        stride = self._static_dims(val.stride())
        needs_trans = meta.get("trans", False)
        self.need_trans_input.append(needs_trans)

        if self.is_mix_kernel:
            if needs_trans:
                self.cont_flag_input.append(True)
                # Fix: sanitise the transposed shape as well, so SymInt dims
                # are emitted as -1 exactly like on every other path.
                return load(self._static_dims(val.mT.shape), dtype)
            self.cont_flag_input.append(is_contiguous)
            return load(shape, dtype)

        if is_contiguous:
            self.cont_flag_input.append(True)
            return load(shape, dtype)

        # A strided view is only emitted for static shapes, when the policy
        # allows it, and when the innermost dimension is unit-stride and not
        # of extent 1.
        can_view = (
            not is_symbolic
            and self.use_view
            and stride[-1] == 1
            and shape[-1] != 1
        )
        if can_view:
            self.cont_flag_input.append(True)
            return view_load(shape, stride, dtype)
        self.cont_flag_input.append(False)
        return load(shape, dtype)

    def call_function(
        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Argument]
    ) -> Argument:
        """Return the expression produced by the registered emitter of *target*.

        Matmul-style targets additionally receive the transpose flags (and,
        for ``addmm``, the ``use_bias`` flag) read from the node meta.

        Raises:
            NotImplementedError: If *target* is not in ``DVM_OP_REGISTRY``.
        """
        if target not in DVM_OP_REGISTRY:
            raise NotImplementedError(f"{target} not implemented in DVM")
        func, _ = DVM_OP_REGISTRY.get(target)
        meta = self.current_node.meta
        if target in self._MATMUL_TARGETS:
            args = (
                *args,
                meta.get("trans_a", False),
                meta.get("trans_b", False),
            )
        elif target is torch.ops.aten.addmm.default:
            args = (
                *args,
                meta.get("trans_a", False),
                meta.get("trans_b", False),
                meta.get("use_bias", False),
            )
        return func(*args, **kwargs)

    def output(
        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Argument]
    ) -> Argument:
        """Return store expressions mirroring the structure of the graph output.

        Each output that is an FX node becomes ``store(<expr>, <dtype>)``;
        any other output leaf becomes an empty string.
        """
        outs = super().output(target, args, kwargs)

        def emit_store(out, node):
            if isinstance(node, torch.fx.Node):
                return store(out, node.meta["val"].dtype)
            return ""

        # The raw output args are walked in parallel to tell node outputs
        # apart from constant outputs.
        return pytree.tree_map(emit_store, outs, self.current_node.args[0])