import re
from typing import List, Optional
from bridge.models.conversion.param_mapping import MegatronParamMapping
class MegatronMappingRegistry:
    """Ordered collection of parameter mappings with wildcard-aware lookup.

    Each mapping is expected to expose ``megatron_param`` (a string),
    ``hf_param`` (a string, or a dict-like whose values are strings) and a
    ``resolve(captures)`` method. Names may contain wildcards:

    * ``*``  matches one or more digits (e.g. a layer index),
    * ``**`` matches any run of characters, dots included.

    Wildcard names are compiled to anchored regular expressions once, at
    construction time, so lookups never recompile anything.
    """

    def _convert_pattern_to_regex(self, pattern: str) -> str:
        """Translate a wildcard pattern into an (unanchored) regex string.

        Args:
            pattern: Pattern string that may contain ``*`` and ``**``.

        Returns:
            Regex source in which every other character of ``pattern`` is
            matched literally, ``**`` has become ``(.*)`` and ``*`` has
            become ``(\\d+)``.
        """
        # Escape first so dots and other metacharacters are matched literally;
        # after escaping, each "*" appears as the two characters "\*".
        regex_pattern = re.escape(pattern)
        # "**" has to be rewritten before "*", otherwise the single-star rule
        # would consume both halves of every double star.
        regex_pattern = regex_pattern.replace(r"\*\*", r"(.*)")
        regex_pattern = regex_pattern.replace(r"\*", r"(\d+)")
        return regex_pattern

    def _compile_wildcard(self, name: str) -> Optional["re.Pattern[str]"]:
        """Return the anchored compiled regex for ``name``, or None if it has no ``*``."""
        if "*" not in name:
            # None marks "compare by plain string equality".
            return None
        return re.compile(f"^{self._convert_pattern_to_regex(name)}$")

    def __init__(self, *mappings: "MegatronParamMapping"):
        """Build the registry and pre-compile every wildcard pattern.

        Args:
            *mappings: Mapping objects, kept in the order given. That order is
                also the order in which lookups try them.
        """
        self.mappings = list(mappings)

        # One (compiled_regex_or_None, mapping) pair per mapping, keyed on the
        # Megatron-side name.
        self._compiled_patterns = []
        # One (entry, mapping) pair per mapping, keyed on the HF-side name(s):
        # entry is compiled_regex_or_None for a string hf_param, or a dict
        # {key: compiled_regex_or_None} for a dict-like hf_param.
        # NOTE(review): no method visible in this class reads this list;
        # presumably it serves HF -> Megatron lookups - confirm before removing.
        self._reverse_patterns = []

        for mapping in mappings:
            self._compiled_patterns.append(
                (self._compile_wildcard(mapping.megatron_param), mapping)
            )

            hf_param = mapping.hf_param
            if isinstance(hf_param, str):
                reverse_entry = self._compile_wildcard(hf_param)
            else:
                # Anything that is not a string is treated as dict-like.
                reverse_entry = {
                    key: self._compile_wildcard(hf_pattern)
                    for key, hf_pattern in hf_param.items()
                }
            self._reverse_patterns.append((reverse_entry, mapping))

    def megatron_to_hf_lookup(self, megatron_param_name: str) -> Optional["MegatronParamMapping"]:
        """Find the first registered mapping that matches a Megatron name.

        Mappings are tried in registration order and the first hit wins,
        whether it is an exact name or a wildcard pattern. Wildcard patterns
        must match the whole name.

        Args:
            megatron_param_name: Megatron parameter name to look up, e.g.
                "decoder.layers.0.self_attention.linear_qkv.weight".

        Returns:
            For an exact-name mapping, the stored mapping object itself. For a
            wildcard mapping, whatever ``mapping.resolve(captures)`` returns,
            where ``captures`` is the tuple of strings captured by the
            wildcards. None if nothing matches.
        """
        for pattern, mapping in self._compiled_patterns:
            if pattern is None:
                if mapping.megatron_param == megatron_param_name:
                    return mapping
                continue
            # fullmatch rather than match: "$" alone would also accept a name
            # followed by a trailing newline, which the exact branch rejects.
            match = pattern.fullmatch(megatron_param_name)
            if match:
                return mapping.resolve(match.groups())
        return None

    def __len__(self) -> int:
        """Return number of mappings."""
        return len(self.mappings)

    def __iter__(self):
        """Iterate over mappings in registration order."""
        return iter(self.mappings)

    def __repr__(self) -> str:
        """String representation of MegatronMappingRegistry."""
        return f"MegatronMappingRegistry({len(self.mappings)} mappings)"