import torch
from .swap_utils import print_with_rank, PrintLevel
def get_module_name(module: torch.nn.Module):
    """Return the label ``<defining module path>.<class name>`` for ``module``.

    The label is built from ``module.__module__`` and the name of
    ``module.__class__``; it is used by the hook code for log messages.
    """
    owner_path = module.__module__
    class_name = module.__class__.__name__
    return ".".join((owner_path, class_name))
class SwapHookRegister:
    """Install forward/backward boundary hooks on a module tree and dispatch callbacks.

    Forward hooks are attached to a module and all of its descendants. The first
    module whose forward pre-hook fires becomes the "begin"/"end" module of the
    forward pass (and the "begin" module of the backward pass). Backward
    boundaries are detected by attaching hooks to the ``grad_fn`` of tensors
    seen by the forward hooks.

    Every user callback is invoked as ``custom_func(instance_id, fwd_or_bwd_idx)``.

    NOTE: the installed hook closures keep a reference to this object, so relying
    on garbage collection (``__del__``) to detach the hooks is not sufficient
    while the hooks are still installed; call :meth:`remove_hooks` explicitly.
    """

    # Class-wide counter; each new instance takes the current value as its ``id``.
    id = 0

    def __init__(self):
        self.id = SwapHookRegister.id
        SwapHookRegister.id += 1

        # Most recently registered module-level handles (kept for backward
        # compatibility with code that reads these attributes).
        self.fwd_pre_hook_handle = None
        self.fwd_post_hook_handle = None
        self.bwd_pre_hook_handle = None
        self.bwd_post_hook_handle = None
        # Every forward hook handle registered on any module of the tree, so
        # that all of them can be removed and not only the last pair.
        self.module_hook_handles = []

        # Boundary modules discovered during the forward pass.
        self.fwd_begin_module: torch.nn.Module = None
        self.fwd_end_module: torch.nn.Module = None
        self.bwd_begin_module: torch.nn.Module = None
        self.bwd_end_module: torch.nn.Module = None
        # Counters passed to the custom callbacks.
        self.fwd_idx = 0
        self.bwd_idx = 0

        # Handles of hooks attached to autograd ``grad_fn`` nodes.
        self.prehook_handles = []
        self.posthook_handls = []  # (sic) attribute name kept for compatibility

        self.fwd_pre_hook_custom_func = None
        self.fwd_post_hook_custom_func = None
        self.bwd_pre_hook_custom_func = None
        self.bwd_post_hook_custom_func = None

    def __del__(self):
        r"""if not need swap hook to module, del it."""
        self.remove_hooks()

    def remove_hooks(self):
        """Detach every hook installed through this object.

        Calls :meth:`reset` and then removes all module-level forward hooks.
        Safe to call more than once: handles are forgotten once removed.
        """
        handles = list(self.module_hook_handles)
        # Also honour handles that were assigned directly to the single-handle
        # attributes without going through ``module_hook_handles``.
        for single in (
            self.fwd_pre_hook_handle,
            self.fwd_post_hook_handle,
            self.bwd_pre_hook_handle,
            self.bwd_post_hook_handle,
        ):
            if single is not None and not any(single is known for known in handles):
                handles.append(single)

        self.reset()
        self.module_hook_handles.clear()
        self.fwd_pre_hook_handle = None
        self.fwd_post_hook_handle = None
        self.bwd_pre_hook_handle = None
        self.bwd_post_hook_handle = None
        for handle in handles:
            handle.remove()

    def reset(self):
        """Forget the detected boundary modules, zero the counters and drop grad_fn hooks.

        Module-level forward hooks stay installed.
        """
        self.fwd_begin_module = None
        self.fwd_end_module = None
        self.bwd_begin_module = None
        self.bwd_end_module = None
        self.fwd_idx = 0
        self.bwd_idx = 0

        for hdl in self.prehook_handles:
            hdl.remove()
        for hdl in self.posthook_handls:
            hdl.remove()
        self.prehook_handles.clear()
        self.posthook_handls.clear()

    def register_custom_func(
        self, fwd_pre_hook_custom_func, fwd_post_hook_custom_func, bwd_pre_hook_custom_func, bwd_post_hook_custom_func
    ):
        r"""
        custom_func(instance_id, fwd_or_bwd_idx)
        """
        self.fwd_pre_hook_custom_func = fwd_pre_hook_custom_func
        self.fwd_post_hook_custom_func = fwd_post_hook_custom_func
        self.bwd_pre_hook_custom_func = bwd_pre_hook_custom_func
        self.bwd_post_hook_custom_func = bwd_post_hook_custom_func

    def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
        """Forward ``message`` to the module-level ``print_with_rank`` with the "SwapHook" prefix."""
        print_with_rank(message, prefix="SwapHook", print_level=print_level)

    @staticmethod
    def _first_tensor_with_grad_fn(inputs):
        """Return the first non-leaf tensor requiring grad that has a ``grad_fn``, else ``None``.

        ``inputs`` may be a single tensor or a tuple; any other container type
        is not inspected and yields ``None``.
        """
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)
        if not isinstance(inputs, tuple):
            return None
        for item in inputs:
            if not isinstance(item, torch.Tensor):
                continue
            if item.requires_grad and not item.is_leaf and item.grad_fn is not None:
                return item
        return None

    def register_hook_to_grad_fn(self, input_tensor, position, is_bwd_pre):
        """Attach a backward boundary hook to ``input_tensor.grad_fn``.

        Args:
            input_tensor: tensor whose ``grad_fn`` receives the hook.
            position: label of the calling forward hook, used in log messages only.
            is_bwd_pre: if true, register the "backward begin" pre-hook
                (increments ``bwd_idx``); otherwise register the "backward end" hook.
        """

        def grad_fn_bwd_pre_hook(grad_outputs):
            self.bwd_idx += 1
            self.print_with_rank(f"grad_fn_bwd_pre_hook: bwd begin, id[{self.id}], bwd_idx[{self.bwd_idx}]")
            if self.bwd_pre_hook_custom_func:
                self.bwd_pre_hook_custom_func(self.id, self.bwd_idx)
            return grad_outputs

        def grad_fn_bwd_post_hook(grad_inputs, _):
            self.print_with_rank(f"grad_fn_bwd_post_hook: bwd end, id[{self.id}], bwd_idx[{self.bwd_idx}]")
            if self.bwd_post_hook_custom_func:
                self.bwd_post_hook_custom_func(self.id, self.bwd_idx)
            return grad_inputs

        if is_bwd_pre:
            self.print_with_rank(f"{position}, register grad_fn_bwd_pre_hook to grad_fn: {input_tensor.grad_fn}")
            self.prehook_handles.append(input_tensor.grad_fn.register_prehook(grad_fn_bwd_pre_hook))
        else:
            self.print_with_rank(f"{position}, register grad_fn_bwd_post_hook to grad_fn: {input_tensor.grad_fn}")
            self.posthook_handls.append(input_tensor.grad_fn.register_hook(grad_fn_bwd_post_hook))

    def register_hook_to_bwd_end_module(self, module, inputs, position):
        """Register the "backward end" hook on the first suitable tensor of ``inputs``.

        Only acts while no backward-end module has been chosen yet, or when
        ``module`` is the one already chosen. The first module that provides a
        suitable tensor becomes ``bwd_end_module``.
        """
        # Compare against None explicitly: container modules define __len__,
        # so an empty container would be falsy under a plain truthiness test.
        if self.bwd_end_module is not None and module is not self.bwd_end_module:
            return

        grad_tensor = self._first_tensor_with_grad_fn(inputs)
        if grad_tensor is None:
            return

        if self.bwd_end_module is None:
            self.bwd_end_module = module
            self.print_with_rank(f"{position}, set bwd_end_module: {get_module_name(module)}")
        self.register_hook_to_grad_fn(grad_tensor, position, is_bwd_pre=False)

    def register_hook_to_bwd_begin_module(self, module, inputs, position):
        """Register the "backward begin" pre-hook if ``module`` is the backward-begin module."""
        if self.bwd_begin_module is None or module is not self.bwd_begin_module:
            return

        grad_tensor = self._first_tensor_with_grad_fn(inputs)
        if grad_tensor is not None:
            self.register_hook_to_grad_fn(grad_tensor, position, is_bwd_pre=True)

    def fwd_pre_hook(self, module, args):
        """Forward pre-hook body shared by every module of the tree.

        The first module seen after a reset is recorded as the forward begin/end
        module and the backward begin module. When that module starts its
        forward, ``fwd_idx`` is incremented and the forward-pre callback is called.
        """
        self.print_with_rank(f"fwd_pre_hook, {get_module_name(module)}")

        if self.fwd_begin_module is None:
            self.fwd_begin_module = module
            self.fwd_end_module = module
            self.bwd_begin_module = module
            self.print_with_rank(
                f"fwd_pre_hook: set fwd_begin_module, fwd_end_module and bwd_begin_module: {get_module_name(module)}"
            )

        if module is self.fwd_begin_module:
            self.fwd_idx += 1
            self.print_with_rank(
                f"fwd_pre_hook: fwd begin, id[{self.id}], fwd_idx[{self.fwd_idx}], {get_module_name(module)}"
            )
            if self.fwd_pre_hook_custom_func:
                self.fwd_pre_hook_custom_func(self.id, self.fwd_idx)

        self.register_hook_to_bwd_end_module(module, args, "fwd_pre_hook")

        return None

    def fwd_post_hook(self, module, _, outputs):
        """Forward post-hook body shared by every module of the tree.

        Calls the forward-post callback when the forward-end module finishes,
        then tries to attach the backward begin/end hooks to ``outputs``.
        """
        self.print_with_rank(f"fwd_post_hook, {get_module_name(module)}")

        if self.fwd_end_module is not None and module is self.fwd_end_module:
            self.print_with_rank(
                f"fwd_post_hook: fwd end, id[{self.id}], fwd_idx[{self.fwd_idx}], {get_module_name(module)}"
            )
            if self.fwd_post_hook_custom_func:
                self.fwd_post_hook_custom_func(self.id, self.fwd_idx)

        self.register_hook_to_bwd_begin_module(module, outputs, "fwd_post_hook")
        self.register_hook_to_bwd_end_module(module, outputs, "fwd_post_hook")

        return None

    def register_hooks_to_modules_recursively(self, module, name=""):
        """Register the forward pre/post hooks on ``module`` and all of its descendants.

        Args:
            module: root of the (sub)tree to instrument.
            name: accumulated child-name prefix; only passed down the recursion.
        """
        self.print_with_rank(f"register_hooks_to_modules_recursively, {get_module_name(module)}")

        for child_name, child in module.named_children():
            self.register_hooks_to_modules_recursively(child, name + child_name)

        def module_fwd_pre_hook(module, args):
            return self.fwd_pre_hook(module, args)

        def module_fwd_post_hook(module, args, outputs):
            return self.fwd_post_hook(module, args, outputs)

        self.fwd_pre_hook_handle = module.register_forward_pre_hook(module_fwd_pre_hook)
        self.fwd_post_hook_handle = module.register_forward_hook(module_fwd_post_hook)
        # Remember every handle: the two attributes above are overwritten on each
        # visited module and therefore only track the last registration.
        self.module_hook_handles.append(self.fwd_pre_hook_handle)
        self.module_hook_handles.append(self.fwd_post_hook_handle)
def register_swap_hooks_to_modules(
    module,
    fwd_pre_hook_custom_func=None,
    fwd_post_hook_custom_func=None,
    bwd_pre_hook_custom_func=None,
    bwd_post_hook_custom_func=None,
):
    r"""
    Create a SwapHookRegister, hook it onto ``module`` and all of its
    sub-modules, attach the four optional callbacks and return the register.

    usage:
        # before training
        models = [model_1, model_2, ...]
        swap_hook_registers = []

        def fwd_pre_hook_custom_func(swap_hook_register_id, fwd_idx):
            ...
        def fwd_post_hook_custom_func(swap_hook_register_id, fwd_idx):
            ...
        def bwd_pre_hook_custom_func(swap_hook_register_id, bwd_idx):
            ...
        def bwd_post_hook_custom_func(swap_hook_register_id, bwd_idx):
            ...

        for model in models:
            import smart_swap
            swap_hook_register = smart_swap.xxx.register_swap_hooks_to_modules(
                model,
                fwd_pre_hook_custom_func, fwd_post_hook_custom_func,
                bwd_pre_hook_custom_func, bwd_post_hook_custom_func)
            swap_hook_registers.append(swap_hook_register)

        # when training
        for step in range(train_steps):
            for swap_hook_register in swap_hook_registers:
                swap_hook_register.reset()
            train_step(xxx)

        # after training
        for swap_hook_register in swap_hook_registers:
            del swap_hook_register
    """
    register = SwapHookRegister()
    # Hooks go on first; the callbacks are stored afterwards.
    register.register_hooks_to_modules_recursively(module)
    callbacks = (
        fwd_pre_hook_custom_func,
        fwd_post_hook_custom_func,
        bwd_pre_hook_custom_func,
        bwd_post_hook_custom_func,
    )
    register.register_custom_func(*callbacks)
    return register
class SwapOptimizerHookRegister:
    """Install a step pre-hook on one or more optimizers and dispatch a user callback.

    The callback is stored with :meth:`register_optimizer_custom_func` and is
    invoked without arguments each time a hooked optimizer's step pre-hook fires.

    NOTE: the registered hook closures reference this object, so as long as an
    optimizer keeps them, garbage collection will not trigger ``__del__``;
    call :meth:`remove_hooks` explicitly to detach.
    """

    def __init__(self):
        # Most recently registered handle (kept for backward compatibility).
        self.optimizer_pre_hook_handle = None
        # Handles for every optimizer hooked through this object, so that all
        # of them can be removed and not only the last one.
        self.optimizer_pre_hook_handles = []
        self.optimizer_pre_hook_custom_func = None

    def __del__(self):
        r"""if not need swap hook to module, del it."""
        self.remove_hooks()

    def remove_hooks(self):
        """Remove every step pre-hook registered through this object (idempotent)."""
        handles = list(self.optimizer_pre_hook_handles)
        last = self.optimizer_pre_hook_handle
        # Honour a handle that was assigned to the attribute directly.
        if last is not None and not any(last is known for known in handles):
            handles.append(last)

        self.optimizer_pre_hook_handles.clear()
        self.optimizer_pre_hook_handle = None
        for handle in handles:
            handle.remove()

    def register_optimizer_custom_func(self, optimizer_pre_hook_custom_func):
        """Store the zero-argument callback to run before each optimizer step."""
        self.optimizer_pre_hook_custom_func = optimizer_pre_hook_custom_func

    def optim_pre_hook(self, optimizer, args, kwargs):
        """Step pre-hook body: run the custom callback, if any.

        Returns ``None``; the hook's ``optimizer``, ``args`` and ``kwargs``
        arguments are not used.
        """
        if self.optimizer_pre_hook_custom_func:
            self.optimizer_pre_hook_custom_func()
        return None

    def register_hooks_to_optimizer(self, optimizer):
        """Register the step pre-hook on ``optimizer`` and remember its handle."""

        def optimizer_pre_hook(optimizer, args, kwargs):
            return self.optim_pre_hook(optimizer, args, kwargs)

        handle = optimizer.register_step_pre_hook(optimizer_pre_hook)
        self.optimizer_pre_hook_handle = handle
        self.optimizer_pre_hook_handles.append(handle)
def register_swap_hooks_to_optimizers(
    optimizers,
    optimizer_pre_hook_custom_func=None,
):
    """Hook a single SwapOptimizerHookRegister onto every optimizer in ``optimizers``.

    Args:
        optimizers: iterable of optimizers that receive the step pre-hook.
        optimizer_pre_hook_custom_func: optional zero-argument callback stored
            on the register after the hooks are installed.

    Returns:
        The SwapOptimizerHookRegister shared by all given optimizers.
    """
    register = SwapOptimizerHookRegister()
    for current_optimizer in optimizers:
        register.register_hooks_to_optimizer(current_optimizer)
    register.register_optimizer_custom_func(optimizer_pre_hook_custom_func)
    return register