from functools import wraps
import torch
from megatron.core.parallel_state import get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size
from megatron.training import get_args
from mindspeed.core.optimizer.virtual_optimizer.virtual_adam import (
virtual_optimizer_step_impl,
virtual_optimizer_replace,
VirtualAllocator
)
def virtual_optimizer_step(self, closure=None):
    """Run one optimizer step through ``virtual_optimizer_step_impl``.

    Intended to be used as an optimizer's ``step`` method (it takes ``self``).

    Args:
        self: The optimizer instance being stepped.
        closure: Optional callable forwarded unchanged to
            ``virtual_optimizer_step_impl``.

    Returns:
        Whatever ``virtual_optimizer_step_impl`` returns.
    """
    # Attach the shared allocator lazily, the first time this instance steps.
    if not hasattr(self, "virtual_allocator"):
        self.virtual_allocator = get_global_virtual_allocator()

    # The flag is True only on the very first step of this instance (when the
    # attribute does not exist yet) and False on every later step.
    # Presumably the step implementation uses it to report swap information
    # only once -- confirm against virtual_optimizer_step_impl.
    if hasattr(self, "print_swap_flag"):
        self.print_swap_flag = False
    else:
        self.print_swap_flag = True

    # The whole step implementation runs with gradient tracking disabled.
    with torch.no_grad():
        return virtual_optimizer_step_impl(self, closure)
def get_global_virtual_allocator():
    """Return the shared ``VirtualAllocator``, creating it on first use.

    The allocator is cached as the ``virtual_allocator`` attribute of the
    object returned by ``get_args()``, so every later call returns that same
    instance. On first use it is constructed from the pipeline-model-parallel
    rank, the pipeline-model-parallel world size, and ``args.virtual_optimizer``.

    Returns:
        The cached ``VirtualAllocator`` instance.
    """
    args = get_args()

    # Fast path: an allocator was already created and cached on args.
    if hasattr(args, "virtual_allocator"):
        return args.virtual_allocator

    # NOTE(review): the allocator lives on the args object itself; confirm
    # this is safe anywhere args are copied or serialized.
    args.virtual_allocator = VirtualAllocator(
        get_pipeline_model_parallel_rank(),
        get_pipeline_model_parallel_world_size(),
        # Reuse the args object fetched above instead of calling get_args()
        # a second time.
        args.virtual_optimizer,
    )
    return args.virtual_allocator
def replace_swap_tensor_wrapper(fn):
    """Decorate a method so ``virtual_optimizer_replace`` runs after it.

    The returned wrapper calls ``fn`` first, then invokes
    ``virtual_optimizer_replace`` with ``self.optimizer`` and the global
    virtual allocator, and finally hands back ``fn``'s own return value.

    Args:
        fn: The method to wrap. Its ``self`` must expose an ``optimizer``
            attribute.

    Returns:
        The wrapping function, carrying ``fn``'s metadata via
        ``functools.wraps``.
    """

    @wraps(fn)
    def _call_then_replace(self, *args, **kwargs):
        result = fn(self, *args, **kwargs)
        # Look up the allocator only after the wrapped call has completed.
        allocator = get_global_virtual_allocator()
        virtual_optimizer_replace(self.optimizer, allocator)
        return result

    return _call_then_replace