import copy
import functools
import re
import types
from typing import Iterable, List, Optional, Tuple
import torch
from megatron.core.transformer.module import MegatronModule
from rich.table import Table
from transformers.configuration_utils import PretrainedConfig
def unwrap_model(model, module_instances=None):
    """Strip wrapper modules and return the innermost model(s).

    Each model is unwrapped by repeatedly following its ``.module`` attribute
    for as long as the current object is an instance of ``module_instances``.

    Args:
        model: A single model or a ``list`` of models.
        module_instances: Wrapper type (or tuple of types) to peel off. When
            ``None``, Megatron's ``DistributedDataParallel``,
            ``TorchFullyShardedDataParallel`` and ``Float16Module`` are used.

    Returns:
        A list of unwrapped models if ``model`` was a list, otherwise the
        single unwrapped model.
    """
    if module_instances is None:
        # The default wrapper types are only imported when the caller does
        # not supply their own.
        from megatron.core.distributed import DistributedDataParallel as DDP
        from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
        from megatron.core.transformer.module import Float16Module

        module_instances = (DDP, torch_FSDP, Float16Module)

    def _peel(wrapped):
        # Descend through nested wrappers until a non-wrapper is reached.
        while isinstance(wrapped, module_instances):
            wrapped = wrapped.module
        return wrapped

    if isinstance(model, list):
        return [_peel(entry) for entry in model]
    return _peel(model)
def weights_verification_table(bridge, megatron_model) -> Table:
    """Build a table checking exported weights against the original HF weights.

    Every ``(name, param)`` pair yielded by ``bridge.export_hf_weights`` is
    compared with ``bridge.hf_pretrained.state[name]`` using
    ``torch.allclose`` (``atol=1e-6``) after moving the original tensor to the
    exported tensor's device.

    Args:
        bridge (AutoBridge): The bridge object containing model information.
        megatron_model: The Megatron-LM model instance.

    Returns:
        Table: A rich Table with one row per exported weight, listing its
        name, shape, dtype, device and whether it matches the original.
    """
    table = Table(title="Hugging Face Weights Verification")

    # (header, extra keyword arguments) for each column, in display order.
    column_specs = (
        ("Weight Name", {"style": "cyan"}),
        ("Shape", {}),
        ("DType", {}),
        ("Device", {}),
        ("Matches Original", {"justify": "center"}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    exported = bridge.export_hf_weights(megatron_model, show_progress=True)
    for name, param in exported:
        reference = bridge.hf_pretrained.state[name].to(param.device)
        is_match = torch.allclose(param, reference, atol=1e-6)
        dtype_label = str(param.dtype).replace("torch.", "")
        table.add_row(
            name,
            str(tuple(param.shape)),
            dtype_label,
            str(param.device),
            "✅" if is_match else "❌",
        )

    return table
def get_module_and_param_from_name(
    models: MegatronModule | List[MegatronModule],
    param_name: str,
    vp_stage: Optional[int] = None,
) -> Tuple[torch.nn.Module, torch.Tensor] | Tuple[torch.nn.Module, torch.Tensor, Tuple]:
    """
    Get parameter from specific VP stage, ensuring that parameter
    attributes are preserved. Supports both absolute and relative parameter names.

    The name is resolved against the unwrapped model in this order:

    1. the name exactly as given;
    2. the name prefixed with ``"predictor."`` (skipped if it already starts
       with that prefix);
    3. progressively shorter suffixes of the name, dropping leading
       components one at a time.

    Args:
        models: List of Megatron model instances or a submodule
        param_name: Dot-separated parameter name (can be absolute or relative to models)
        vp_stage: Virtual pipeline stage index (None for single stage). Only
            consulted when ``models`` is a list; ``None`` selects ``models[0]``.

    Returns:
        Tuple of (module, parameter) where module owns the parameter

    Raises:
        ValueError: If vp_stage is out of range or parameter doesn't exist
    """
    if isinstance(models, list):
        if vp_stage is None:
            model = models[0]
        else:
            # Reject negative indices as well: Python would otherwise wrap
            # them around and silently pick a stage counted from the end.
            if not 0 <= vp_stage < len(models):
                raise ValueError(f"VP stage {vp_stage} out of range (max: {len(models) - 1})")
            model = models[vp_stage]
    else:
        model = models

    module = unwrap_model(model)
    splitted_name = param_name.split(".")

    # Sentinel so that attributes whose value is None are still "found".
    missing = object()

    def try_get_param(parts):
        """Resolve ``parts`` from ``module``; return (owner, attr) or None."""
        owner = module
        for part in parts[:-1]:
            owner = getattr(owner, part, missing)
            if owner is missing:
                return None
        leaf = getattr(owner, parts[-1], missing)
        if leaf is missing:
            return None
        return owner, leaf

    # 1. Name exactly as given.
    result = try_get_param(splitted_name)
    if result is not None:
        return result

    # 2. Same name under a "predictor." prefix.
    if not param_name.startswith("predictor."):
        result = try_get_param(["predictor", *splitted_name])
        if result is not None:
            return result

    # 3. Drop leading components until some suffix resolves.
    for start_idx in range(1, len(splitted_name)):
        result = try_get_param(splitted_name[start_idx:])
        if result is not None:
            return result

    raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}")
def remove_non_pickleables(obj, max_depth: int = 3, current_depth: int = 0):
    """Return a copy of ``obj`` with non-pickleable callables replaced by ``None``.

    Plain functions, bound methods, ``functools.partial`` objects and any
    other callable exposing ``__self__`` are replaced with ``None``. Classes
    are kept as they are. Objects with a ``__dict__`` are shallow-copied and
    their instance attributes cleaned recursively; lists, tuples and dict
    values are rebuilt with cleaned elements (dict keys are left untouched).

    Args:
        obj: The object to clean
        max_depth: Maximum recursion depth (default: 3). Anything reached at
            this depth is returned unchanged.
        current_depth: Current recursion depth (internal use)

    Returns:
        The cleaned object with non-pickleables removed
    """
    # Nothing to clean for None, and nothing is inspected past the depth limit.
    if obj is None or current_depth >= max_depth:
        return obj

    if callable(obj):
        if isinstance(obj, type):
            return obj
        dropped_kinds = (types.FunctionType, types.MethodType, functools.partial)
        if isinstance(obj, dropped_kinds) or hasattr(obj, "__self__"):
            return None
        # Other callables fall through and are treated like ordinary objects.

    child_depth = current_depth + 1

    def scrub(value):
        return remove_non_pickleables(value, max_depth, child_depth)

    if hasattr(obj, "__dict__"):
        # Work on a shallow copy so the caller's object is not modified.
        clone = copy.copy(obj)
        for attr in list(vars(clone)):
            setattr(clone, attr, scrub(getattr(clone, attr)))
        return clone

    if isinstance(obj, list):
        return [scrub(element) for element in obj]

    if isinstance(obj, tuple):
        return tuple(scrub(element) for element in obj)

    if isinstance(obj, dict):
        return {key: scrub(value) for key, value in obj.items()}

    return obj
def extract_sort_key(param_name: str):
    """Extract sorting key based on layer and expert numbers.

    The layer number is the integer following ``"layers."`` and the expert
    number is the integer immediately following ``"bias"`` or ``"weight"``
    (first occurrence of each). A missing number is reported as ``-1``.

    Args:
        param_name: Parameter name to derive the key from.

    Returns:
        A tuple ``([layer_number, expert_number], param_name)``.
    """
    layer_match = re.search(r"layers\.(\d+)", param_name)
    expert_match = re.search(r"(?:bias|weight)(\d+)", param_name)

    # Fill each slot independently so that an expert number found without a
    # layer number is not shifted into the layer position.
    layer_number = int(layer_match.group(1)) if layer_match else -1
    expert_number = int(expert_match.group(1)) if expert_match else -1

    return [layer_number, expert_number], param_name
def persistent_buffers(model: torch.nn.Module) -> Iterable[Tuple[str, torch.Tensor]]:
    """Iterate over the persistent buffers of ``model`` and its submodules.

    A buffer is skipped when its local name appears in the owning module's
    ``_non_persistent_buffers_set`` (treated as empty if the attribute is
    absent).

    Args:
        model: Root module whose ``named_modules()`` are walked.

    Yields:
        ``(full_name, buffer)`` pairs, where ``full_name`` is the buffer's
        local name prefixed with the dotted path of its owning module (no
        prefix for buffers owned by ``model`` itself).
    """
    for owner_path, owner in model.named_modules():
        excluded = getattr(owner, "_non_persistent_buffers_set", set())
        prefix = f"{owner_path}." if owner_path else ""
        # recurse=False: each module reports only the buffers it owns directly.
        for buffer_name, tensor in owner.named_buffers(recurse=False):
            if buffer_name in excluded:
                continue
            yield f"{prefix}{buffer_name}", tensor