import threading
# Set by activate_pattern_once() after pattern registration has finished;
# while it is unset, the next call will attempt the registration.
once_flag = threading.Event()
# Held by activate_pattern_once() around the registration so that concurrent
# first callers do not run it at the same time.
_once_lock = threading.Lock()
def activate_pattern_once():
    """Register the enabled fusion pattern groups, skipping if already done.

    Each flag read from ``CompilationConfig.fusion_patterns`` is mapped to a
    pattern-group attribute of the ``..patterns`` module.  For every flag that
    is truthy, the module is imported and the group is handed to
    ``register_pattern_to_pass``.

    ``once_flag`` is set only after every enabled group has been registered.
    If an import or a registration raises, the exception propagates, the flag
    stays unset, and a later call will run the registration again.
    """
    # Fast path: registration already completed, no need to take the lock.
    if once_flag.is_set():
        return

    with _once_lock:
        # Another thread may have finished while this one waited for the lock.
        if once_flag.is_set():
            return

        # Imports are deferred to the first real activation.
        # NOTE: the module name "compiliation_config" is reproduced exactly as
        # the original import spells it.
        import importlib
        from ..compiliation_config import CompilationConfig
        from .register_pattern_to_pass import register_pattern_to_pass

        # config flag name -> (pattern group attribute, module to import)
        groups_by_flag = {
            "enable_rms_norm": ("RMSNormPatternGroup", "..patterns"),
            "enable_rope": ("RopePatternGroup", "..patterns"),
            "enable_adalayernorm": ("AdaLayerNormPatternGroup", "..patterns"),
            "enable_fast_gelu": ("GELUPatternGroup", "..patterns"),
            "enable_mul_add": ("MulAddPatternGroup", "..patterns"),
        }

        flags = CompilationConfig.fusion_patterns
        for flag_name, (group_attr, module_path) in groups_by_flag.items():
            # A flag that is missing from the config counts as disabled.
            if not getattr(flags, flag_name, False):
                continue
            module = importlib.import_module(module_path, package=__package__)
            register_pattern_to_pass(getattr(module, group_attr))

        # Mark completion only after all enabled groups were registered.
        once_flag.set()