# Public names exported by ``from <this module> import *``.
__all__ = [
    "compile_fx",
    "register_replacement",
]
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence
from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union, Match, List
try:
from torch._inductor.pattern_matcher import fwd_only, SearchFn, ReplaceFn, TraceFn, PatternExpr
except ImportError:
from torch._inductor.pattern_matcher import inference_graph as fwd_only
from . import inference
from . import scope
def compile_fx(gm, example_inputs=None, options=None):
    """Compile ``gm`` by delegating to ``npugraph_ex.compile_fx``.

    ``npugraph_ex`` is imported inside the function body, i.e. at call time
    rather than when this module is imported (presumably so the module can be
    imported without that backend being loaded -- TODO confirm).

    Args:
        gm: The object to compile, forwarded unchanged (presumably a
            ``torch.fx.GraphModule`` -- verify against callers).
        example_inputs: Forwarded positionally; defaults to ``None``.
        options: Forwarded positionally; defaults to ``None``.

    Returns:
        Whatever ``npugraph_ex.compile_fx`` returns.
    """
    import npugraph_ex as _backend

    compiled = _backend.compile_fx(gm, example_inputs, options)
    return compiled
def _return_true(match: Match):
return True
def register_replacement(search_fn: "SearchFn", replace_fn: "ReplaceFn", example_inputs: Iterable[Any],
                         trace_fn: "TraceFn" = fwd_only, extra_check: Callable[[Match], bool] = _return_true,
                         search_fn_pattern: "Union[PatternExpr, None]" = None,
                         scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
                         skip_duplicates: bool = False):
    """Register a search/replace pattern rewrite with ``npugraph_ex``.

    This is a thin wrapper. Every argument is forwarded unchanged to
    ``npugraph_ex.patterns.pattern_pass_manager.register_replacement``, and
    that call's result is returned. ``npugraph_ex`` is imported at call time,
    not at module import time.

    Args:
        search_fn: Forwarded unchanged (presumably the function whose traced
            graph is the pattern to find -- TODO confirm).
        replace_fn: Forwarded unchanged (presumably the function producing
            the replacement graph -- TODO confirm).
        example_inputs: Forwarded unchanged (presumably used to trace
            ``search_fn`` / ``replace_fn`` -- TODO confirm).
        trace_fn: Defaults to ``fwd_only`` as bound at module top. When the
            primary ``torch._inductor.pattern_matcher`` import fails, that
            name is ``inference_graph``.
        extra_check: Predicate forwarded to the backend. The default accepts
            every match.
        search_fn_pattern: Optional pattern forwarded as-is; ``None`` by
            default.
        scalar_workaround: Optional ``str -> float | int`` mapping forwarded
            as-is; ``None`` by default.
        skip_duplicates: Flag forwarded as-is; ``False`` by default.

    Returns:
        Whatever the ``npugraph_ex`` ``register_replacement`` returns.

    Note:
        The annotations naming ``SearchFn``, ``ReplaceFn``, ``TraceFn`` and
        ``PatternExpr`` are string forward references on purpose. Those names
        are bound only when the primary ``torch._inductor.pattern_matcher``
        import succeeds; the ``ImportError`` fallback binds only ``fwd_only``.
        Evaluated eagerly, these annotations would raise ``NameError`` while
        this function is being defined, breaking module import in exactly the
        case the fallback exists for.

        NOTE(review): ``Match`` in the ``extra_check`` annotation is the name
        imported from ``typing`` (the regex-match alias). The pattern matcher
        presumably passes its own match objects -- TODO confirm and consider
        annotating with torch's pattern-matcher ``Match`` instead.
    """
    import npugraph_ex
    return npugraph_ex.patterns.pattern_pass_manager.register_replacement(search_fn, replace_fn, example_inputs,
                                                                          trace_fn=trace_fn, extra_check=extra_check,
                                                                          search_fn_pattern=search_fn_pattern,
                                                                          scalar_workaround=scalar_workaround,
                                                                          skip_duplicates=skip_duplicates)