from types import MethodType
import torch
from accelerate.utils.other import is_compiled_module
from accelerate.utils.imports import is_deepspeed_available
from accelerate.utils.transformer_engine import convert_model
def extract_model_from_parallel(
    model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
):
    """
    Peel the distributed wrappers off a model and return the underlying module.

    `DistributedDataParallel`, `DataParallel` and (when DeepSpeed is available) `DeepSpeedEngine` containers are
    removed by repeatedly following their `.module` attribute. A compiled model (as reported by
    `is_compiled_module`) is first replaced by its `_orig_mod` so the containers underneath can be reached.

    Note that the model is modified in place: child modules may be reassigned, `forward` may be rebound, and the
    compiled wrapper's `_orig_mod` may be repointed.

    Args:
        model (`torch.nn.Module`):
            The model to extract.
        keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):
            Whether to keep the mixed precision hooks on the model. When `False`, the `_original_forward` entry is
            removed from the model, `forward` is unwound through its `__wrapped__` chain and rebound, and a model
            flagged `_converted_to_transformer_engine` is converted back via `convert_model`.
        keep_torch_compile (`bool`, *optional*, defaults to `True`):
            Whether to keep the compiled wrapper. When `True` and the input was compiled, the compiled object is
            returned with its `_orig_mod` pointing at the extracted model; otherwise the bare model is returned.
        recursive (`bool`, *optional*, defaults to `False`):
            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child
            sublayers recursively, not just the top-level distributed containers.

    Returns:
        `torch.nn.Module`: The extracted model.
    """
    wrapper_types = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)

    # Set the compiled shell aside (if any) so the containers beneath it can be inspected.
    was_compiled = is_compiled_module(model)
    compiled_wrapper = None
    if was_compiled:
        compiled_wrapper, model = model, model._orig_mod

    if is_deepspeed_available():
        from accelerate.utils.deepspeed import DeepSpeedEngine

        wrapper_types = (*wrapper_types, DeepSpeedEngine)

    # Containers can be stacked, so keep descending until none of the known types remain.
    while isinstance(model, wrapper_types):
        model = model.module

    if recursive:

        def _peel(node):
            # Anything exposing a `module` attribute is treated as a wrapper around that attribute.
            core = _peel(node.module) if hasattr(node, "module") else node
            # Re-attach every child in its peeled form.
            for child_name, child in core.named_children():
                setattr(core, child_name, _peel(child))
            return core

        model = _peel(model)

    if not keep_fp32_wrapper:
        fwd = model.forward
        saved_forward = model.__dict__.pop("_original_forward", None)
        if saved_forward is not None:
            # Unwind decorator layers until the saved forward is reached or no layer is left.
            while hasattr(fwd, "__wrapped__"):
                fwd = fwd.__wrapped__
                if fwd == saved_forward:
                    break
            model.forward = MethodType(fwd, model)
        if getattr(model, "_converted_to_transformer_engine", False):
            convert_model(model, to_transformer_engine=False)

    if not (keep_torch_compile and was_compiled):
        return model

    # Put the extracted model back inside the compiled shell and hand the shell back.
    compiled_wrapper._orig_mod = model
    return compiled_wrapper