import logging
import functools
from torch.utils.checkpoint import checkpoint
from mindspeed.fsdp.utils.log import print_rank
from mindspeed.fsdp.utils.str_match import module_name_match
logger = logging.getLogger(__name__)
def recompute_modules(model, plan):
modules = get_recompute_modules(model, plan.apply_modules)
for name, module in modules:
print_rank(logger.info, f'Applying recompute to module: {name}')
module.forward = recompute_wrapper(module.forward, plan.use_reentrant)
return model
def get_recompute_modules(modules, plan):
matched_modules = []
for plan_name in plan:
for name, module in modules.named_modules():
if module_name_match(plan_name, name):
matched_modules.append((name, module))
if len(matched_modules) == 0:
raise RuntimeError(f'[Recompute] No module named {plan}.')
return matched_modules
def recompute_wrapper(function, use_reentrant):
def wrapper(*args, **kwargs):
kwargs['past_key_values'] = None
if not use_reentrant:
return checkpoint(function, *args, use_reentrant=use_reentrant, **kwargs)
else:
bound_function = functools.partial(function, **kwargs)
return checkpoint(bound_function, *args, use_reentrant=use_reentrant)
return wrapper