from functools import lru_cache
from typing import Any, Callable, List
from ..passes.pattern_match_pass import PatternMatchPass
from . import rms_norm, rotary_embedding, swiglu
# One PatternMatchPass per level. ``register_pattern`` picks the target pass
# by indexing this list with its ``level`` argument. Each entry is a separate
# instance, so patterns registered at different levels do not share a pass.
all_passes = [PatternMatchPass() for _ in range(3)]
def register_pattern(
    name: str,
    pattern: Callable[..., Any],
    replacement: Callable[..., Any],
    example_inputs: List[Any],
    level: int = 0,
    scalar_workaround: dict[str, Any] | None = None,
):
    """Register a pattern/replacement pair with the pass at index ``level``.

    Args:
        name: Identifier forwarded to ``PatternMatchPass.register_pattern``.
        pattern: Callable describing the pattern to match.
        replacement: Callable describing what a match is replaced with.
        example_inputs: Example inputs forwarded with the pattern. Presumably
            used to trace ``pattern``/``replacement``; confirm against
            PatternMatchPass.
        level: Index into ``all_passes`` selecting the pass that receives
            the pattern. Must satisfy ``0 <= level < len(all_passes)``.
        scalar_workaround: Optional mapping forwarded unchanged as the
            ``scalar_workaround`` keyword argument.

    Raises:
        ValueError: If ``level`` is negative or not less than
            ``len(all_passes)``.
    """
    # Reject negative levels explicitly. Plain list indexing would wrap
    # around and silently register the pattern into one of the last passes.
    if not 0 <= level < len(all_passes):
        raise ValueError(
            f"Invalid level {level}, must be non-negative and less than "
            f"{len(all_passes)}"
        )
    all_passes[level].register_pattern(
        name,
        pattern,
        replacement,
        example_inputs,
        scalar_workaround=scalar_workaround,
    )
@lru_cache(maxsize=None)
def lazy_init():
    """Register the rms_norm, rotary_embedding and swiglu pattern families.

    ``lru_cache`` memoizes the (``None``) result of the first call that
    returns normally, so later calls return immediately and do not register
    the patterns a second time.
    """
    # Order matters only insofar as it matches the original call sequence.
    for pattern_module in (rms_norm, rotary_embedding, swiglu):
        pattern_module.register_all_patterns()