from functools import wraps
from typing import Any
from typing_extensions import Self
def version_wrapper(func):
@wraps(func)
def wrapper(name, *args, **kwargs):
if name == 'transformer-engine':
return '2.2.0'
elif name == 'flash_attn':
return '1.0'
return func(name, *args, **kwargs)
return wrapper
def _load_from_state_dict_wrapper(fn):
def wrapper(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
fn(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
from megatron.training import get_args
args = get_args()
if not getattr(args, 'save_lora_to_hf', False) and hasattr(args, 'lora_target_modules') and args.lora_target_modules:
if not any(('lora_a' in key.lower() or 'lora_b' in key.lower()) and key.endswith('weight') for key in state_dict):
import warnings
warnings.warn("The lora weights is missing from the checkpoint and will be randomly initialized.", RuntimeWarning)
missing_keys[:] = [key for key in missing_keys if 'lora_a' not in key.lower() and 'lora_b' not in key.lower()]
filter_keys = ('absmax', 'quant_map', 'bitsandbytes__nf4', '_extra_state')
unexpected_keys[:] = [key for key in unexpected_keys if not key.endswith(filter_keys)]
return wrapper