from typing import Any, Callable, Dict, List, Tuple, Optional, Sequence
import re
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from .gm_pass_base import GraphModulePass
from .._custom_decomposition import select_custom_decomp_table
from ...utils.logs.logging import logger
# Numeric value of ``logging.DEBUG``; passed to ``logger.isEnabledFor`` to gate
# debug-only work.
DEBUG_LOG_LEVEL = 10

# "major.minor" prefix of the installed torch version (e.g. "2.1" for "2.1.0+cpu";
# "2.10.0" yields "2.10", so it is not mistaken for 2.1). Version strings that do
# not start with "<int>.<int>" fall back to "" instead of failing at import time.
_torch_version_match = re.match(r"(\d+\.\d+)", torch.__version__)
torch_version = _torch_version_match.group(1) if _torch_version_match is not None else ""
del _torch_version_match
# True only on torch 2.1, which needs the compatibility shims in this module.
IS_TORCH_21 = torch_version == "2.1"
if IS_TORCH_21:

    def mindie_inference_graph(fn, args):
        """Trace ``fn(*args)`` into a forward-only FX ``GraphModule`` (torch 2.1 shim).

        Stand-in for ``torch._inductor.pattern_matcher.fwd_only``. It is assigned
        to ``pm.fwd_only`` by ``PatternMatchPass.register_pattern`` when torch 2.1
        lacks that attribute. The function is traced with ``make_fx`` using the
        project's custom decomposition table plus an ``aten._to_copy`` override.
        Dead code is then eliminated and the module is recompiled.

        Args:
            fn: Callable to trace.
            args: Example inputs, passed positionally to ``fn``.

        Returns:
            The traced ``torch.fx.GraphModule``.
        """
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._subclasses.fake_tensor import FakeTensor

        # Work on a copy so the ``_to_copy`` override below cannot leak into a
        # table that ``select_custom_decomp_table`` may hand to other callers.
        decomp_table = dict(select_custom_decomp_table())

        def safe_to_copy(x, *_args, **_kwargs):
            # FakeTensor inputs: drop the copy/cast and pass the input through.
            if isinstance(x, FakeTensor):
                return x
            # Any other input: decline the decomposition so make_fx traces the
            # original ``aten._to_copy`` call. Calling ``aten._to_copy.default``
            # again from here would re-enter this same decomposition while
            # tracing. Forwarding dtype/layout/device/... positionally would also
            # violate the op's keyword-only schema.
            return NotImplemented

        decomp_table[torch.ops.aten._to_copy.default] = safe_to_copy
        gm = make_fx(fn, decomposition_table=decomp_table)(*args)
        gm.graph.eliminate_dead_code()
        gm.recompile()
        return gm
class PatternMatchPass(GraphModulePass):
    """FX graph pass that applies registered pattern -> replacement rewrites.

    Patterns are registered through :meth:`register_pattern` into a
    ``torch._inductor`` ``PatternMatcherPass``. Calling the pass applies every
    registered pattern repeatedly until a round produces no replacement.
    """

    def __init__(self):
        # NOTE(review): GraphModulePass.__init__ is not invoked here; confirm the
        # base class needs no initialization of its own.
        # Pattern name -> (pattern, replacement); used to reject duplicate names.
        self.pattern_replacements: Dict[str, Tuple[Callable[..., Any], Callable[..., Any]]] = {}
        try:
            pattern_pass = PatternMatcherPass(pass_name="pattern_match_pass")
        except TypeError:
            # Fallback for torch versions whose PatternMatcherPass does not accept
            # a ``pass_name`` keyword.
            pattern_pass = PatternMatcherPass()
        self.pattern_pass: PatternMatcherPass = pattern_pass

    def __call__(self, graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Apply all registered patterns to ``graph`` until none matches.

        Args:
            graph: Module handed to ``PatternMatcherPass.apply``. Any rewrite
                happens on this object, since the same object is returned.

        Returns:
            ``graph`` itself.
        """
        matched_cnt = 0
        # Repeat until a full round makes no replacement (a fixed point).
        # NOTE(review): a replacement that re-creates its own pattern would make
        # this loop forever; there is no iteration cap.
        while True:
            cnt = self.pattern_pass.apply(graph)
            if cnt == 0:
                break
            matched_cnt += cnt

        if logger.isEnabledFor(DEBUG_LOG_LEVEL):
            self._log_registered_patterns(matched_cnt)
        return graph

    def _log_registered_patterns(self, matched_cnt: int) -> None:
        """Debug-log the replacement count and a printable form of each registered pattern."""
        logger.debug("PatternMatchPass replace %d patterns.", matched_cnt)
        logger.debug("Patterns registered for replacement:")
        # Only the import is guarded: PatternPrettyPrinter is missing on some torch
        # versions, and other errors should not be mistaken for that case.
        try:
            from torch._inductor.pattern_matcher import PatternPrettyPrinter
        except ImportError:
            logger.debug("PatternPrettyPrinter not available, skipping pattern printing")
            return
        pattern_idx = 0
        for pattern_entry in self.pattern_pass.patterns.values():
            for p in pattern_entry:
                p_str = PatternPrettyPrinter.run(p.pattern)
                logger.debug("Pattern %d: %s", pattern_idx, p_str)
                pattern_idx += 1

    @staticmethod
    def _ensure_fwd_only() -> None:
        """Make ``pm.fwd_only`` available when possible.

        On torch 2.1 without ``fwd_only``, the module-level ``pm.fwd_only`` is
        replaced by ``mindie_inference_graph``. This assignment is process-wide
        on the torch module. On other versions without it, only a warning is
        logged, and the later ``register_replacement`` call is expected to fail.
        """
        if hasattr(pm, "fwd_only"):
            return
        if IS_TORCH_21:
            pm.fwd_only = mindie_inference_graph
        else:
            logger.warning(
                "[MindIE-SD/compilation] Pattern replacement preparation failed. "
                "issue=torch._inductor.pattern_matcher.fwd_only is unavailable, torch_version=%s, "
                "expected=fwd_only API exists or torch version is 2.1 for compatibility patch. "
                "possible_cause=current torch version does not provide the expected inductor API. "
                "Troubleshooting: verify torch version compatibility and pattern registration stack.",
                torch.__version__,
            )

    def register_pattern(
        self,
        name: str,
        pattern: Callable[..., Any],
        replacement: Callable[..., Any],
        example_inputs: List[torch.Tensor],
    ) -> None:
        """Register ``pattern`` -> ``replacement`` under a unique ``name``.

        The pattern is traced via ``pm.fwd_only`` with the project's custom
        decomposition table and added to this pass's pattern dictionary. The name
        is recorded only once torch has accepted the pattern (or reported it as
        a duplicate). A registration that raises therefore does not block a later
        retry under the same name.

        Args:
            name: Unique identifier for this pattern on this pass.
            pattern: Function whose traced graph is searched for.
            replacement: Function whose traced graph replaces each match.
            example_inputs: Inputs used to trace ``pattern`` and ``replacement``.

        Raises:
            ValueError: ``name`` is already registered on this pass.
            RuntimeError: Re-raised from ``pm.register_replacement`` unless its
                message contains "Duplicate pattern".
        """
        if name in self.pattern_replacements:
            logger.error(
                "[MindIE-SD/compilation] Pattern registration failed. "
                "issue=pattern name already registered, pattern_name=%s, expected=unique pattern name. "
                "possible_cause=activate_pattern_once or custom registration was called repeatedly with the same name. "
                "Troubleshooting: check pattern registration order and avoid duplicate names.",
                name,
            )
            raise ValueError(f"Pattern '{name}' is already registered.")
        logger.debug("Registering pattern: %s", name)
        self._ensure_fwd_only()

        def fwd_only_with_custom_decomp(
            fn: Callable[..., Any],
            args: Sequence[Any],
            *,
            run_functional_passes: bool = True,
            get_decomp_fn: Optional[Callable[..., Any]] = select_custom_decomp_table,
        ) -> torch.fx.GraphModule:
            # Trace function handed to register_replacement. The torch 2.1 path
            # forwards only fn/args (the 2.1 shim always uses the custom table).
            # Newer torch also receives the custom decomposition selector.
            if IS_TORCH_21:
                return pm.fwd_only(fn=fn, args=args)
            return pm.fwd_only(
                fn=fn, args=args, run_functional_passes=run_functional_passes, get_decomp_fn=get_decomp_fn
            )

        try:
            pm.register_replacement(
                pattern,
                replacement,
                example_inputs,
                fwd_only_with_custom_decomp,
                self.pattern_pass.patterns,
            )
        except RuntimeError as e:
            if "Duplicate pattern" not in str(e):
                logger.error(
                    "[MindIE-SD/compilation] Pattern registration failed. "
                    "issue=torch inductor register_replacement raised RuntimeError, pattern_name=%s, "
                    "actual_error=%s. possible_cause=pattern, replacement, or example_inputs are incompatible. "
                    "Troubleshooting: inspect the pattern definition, replacement function schema, example input "
                    "shape/dtype, and torch._inductor stack.",
                    name,
                    e,
                )
                raise
            logger.debug(
                "[MindIE-SD/compilation] Duplicate pattern registration skipped. "
                "pattern_name=%s, possible_cause=the same pattern was activated more than once.",
                name,
            )
        else:
            logger.debug("Successfully register pattern: %s", name)
        # Reserve the name only after torch accepted (or already held) the pattern.
        self.pattern_replacements[name] = (pattern, replacement)