import torch
import torch_npu
from mindspeed.optimizer.adamw import adamw
def virtual_optimizer_replace(optimizer, virtual_allocator):
    """
    Re-home the existing Adam moment tensors of ``optimizer`` via ``virtual_allocator``.

    For every parameter listed in ``optimizer.param_groups`` that already has
    state, ``exp_avg`` and ``exp_avg_sq`` are replaced by the tensor returned
    from ``virtual_allocator.copy2swap``. Parameters without state (or without
    one of the two keys) are left untouched.

    Args:
        optimizer: object exposing ``state`` (mapping from param to state dict)
            and ``param_groups`` (iterable of dicts with a ``'params'`` list).
        virtual_allocator: object exposing ``copy2swap(tensor)``.
    """
    if not optimizer.state:
        return
    for group in optimizer.param_groups:
        for p in group['params']:
            # Use ``.get`` rather than indexing: the state mapping may be a
            # defaultdict (as in torch.optim.Optimizer), where indexing would
            # insert an empty entry, and a plain dict would raise KeyError for
            # parameters that never received optimizer state.
            state = optimizer.state.get(p)
            if not state:
                continue
            for key in ("exp_avg", "exp_avg_sq"):
                if key in state:
                    state[key] = virtual_allocator.copy2swap(state[key])
def virtual_optimizer_step_impl(self, closure=None):
    """
    Perform one AdamW step, allocating new moment tensors through ``self.virtual_allocator``.

    Written as a replacement ``step`` for an optimizer object: ``self`` must
    expose ``param_groups``, ``state``, ``virtual_allocator`` and
    ``print_swap_flag``.

    Args:
        closure (callable, optional): re-evaluates the model and returns the
            loss; it is invoked with gradients enabled.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is given.

    Raises:
        RuntimeError: if any parameter has a sparse gradient.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        amsgrad = group['amsgrad']
        beta1, beta2 = group['betas']

        # The step counter is kept per group as a tensor.
        if 'step' in group:
            group['step'] += 1
            # NOTE(review): a CPU-resident counter presumably comes from a
            # restored checkpoint; it is moved so it lives on the device.
            if group['step'].is_cpu:
                group['step'] = group['step'].cuda()
        else:
            group['step'] = torch.tensor(1, dtype=torch.int64, device=torch.cuda.current_device())

        for p in group['params']:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError('AdamW does not support sparse gradients')
            grads.append(p.grad)

            state = self.state[p]
            # Lazy state initialisation: moment buffers come from the
            # allocator, which decides where each tensor is placed.
            if len(state) == 0:
                state['exp_avg'], state['exp_avg_sq'] = self.virtual_allocator.init_exp(p)
                if amsgrad:
                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            if amsgrad:
                max_exp_avg_sqs.append(state['max_exp_avg_sq'])

        # One update call per parameter group.
        adamw(params_with_grad,
              grads,
              exp_avgs,
              exp_avg_sqs,
              max_exp_avg_sqs,
              group['step'],
              amsgrad=amsgrad,
              beta1=beta1,
              beta2=beta2,
              lr=group['lr'],
              weight_decay=group['weight_decay'],
              eps=group['eps'],
              maximize=group['maximize'])

    self.virtual_allocator.print_swap_size(self.print_swap_flag)
    return loss
class VirtualAllocator:
    """
    Allocate optimizer-state tensors either in swapped memory or in device memory.

    Each pipeline rank has a byte budget for swapped memory. While budget
    remains, new tensors are requested from
    ``torch_npu.empty_with_swapped_memory``; otherwise (or when that API is
    missing or fails) ordinary zero tensors are created with ``torch.zeros_like``.
    """

    def __init__(self, pp_rank, pp_stages, virtual_optimizer_size=None):
        """
        Initializes a VirtualAllocator instance.

        Args:
            pp_rank (int): pipeline model rank number; used as the index into
                the per-stage swap sizes.
            pp_stages (int): number of pipeline stages.
            virtual_optimizer_size (float, List[float], Optional): swap size
                per pipeline stage, multiplied by ``1024 ** 3`` to get bytes.
                A scalar or a single-element list applies to every stage; a
                list of length ``pp_stages`` gives one size per stage;
                ``None`` means a budget of 0 for every stage.

        Raises:
            ValueError: if a list is given whose length is neither 1 nor
                ``pp_stages``.
        """
        self.pp_stages = pp_stages
        self.pp_rank = pp_rank
        self.virtual_optimizer_size = virtual_optimizer_size
        # Remaining swap budget of this rank, in bytes.
        self.swap_size_this_pp_rank = self.get_swap_memory_size()[self.pp_rank] * (1024 ** 3)
        # Running total of swapped allocations, in MB (see get_swap_memory).
        self.actually_swap_size = 0
        self.print_flag = False

    def get_swap_memory_size(self):
        """
        Obtain the swap size of each pipeline rank based on the given parameters.

        Returns:
            A sequence with ``pp_stages`` entries, one swap size per stage.

        Raises:
            ValueError: if the configured sizes have a length other than 1 or
                ``pp_stages``.
        """
        swap_sizes = self.virtual_optimizer_size
        if swap_sizes is None:
            return [0 for _ in range(self.pp_stages)]
        # A bare number is treated like a single-element list so that the
        # documented ``float`` form does not fail on ``len()``.
        if isinstance(swap_sizes, (int, float)):
            swap_sizes = [swap_sizes]
        if len(swap_sizes) == 1:
            return [swap_sizes[0] for _ in range(self.pp_stages)]
        if len(swap_sizes) == self.pp_stages:
            return swap_sizes
        raise ValueError("Virtual_optimizer swap size do not match pp_stages or swap ALL.")

    def init_exp(self, p: torch.Tensor):
        """
        Create exp_avg and exp_avg_sq based on input param.

        Args:
            p (tensor): param tensor.

        Returns:
            Tuple ``(exp_avg, exp_avg_sq)`` of two separately created tensors.
        """
        exp_avg = self.create(p)
        exp_avg_sq = self.create(p)
        return exp_avg, exp_avg_sq

    def create(self, p: torch.Tensor):
        """
        Select the tensor construction mode based on the size of each PP swap.

        Swapped memory is used while the remaining budget is positive;
        otherwise a regular zero tensor is returned.

        Args:
            p (tensor): param tensor.
        """
        if self.swap_size_this_pp_rank > 0:
            return self.get_swap_memory(p)
        return self.get_npu_memory(p)

    def copy2swap(self, input_: torch.Tensor):
        """
        Copy input_ to a tensor obtained from ``create``.

        Args:
            input_ (torch.Tensor): tensor whose contents are copied.

        Returns:
            The newly created tensor holding a copy of ``input_``.
        """
        swap = self.create(input_)
        swap.copy_(input_)
        return swap

    def get_swap_memory(self, p: torch.Tensor):
        """
        Construct tensors in virtual memory.

        Falls back to ``get_npu_memory`` when ``torch_npu`` has no
        ``empty_with_swapped_memory`` or when the allocation raises.

        Args:
            p (tensor): param tensor.
        """
        if not hasattr(torch_npu, "empty_with_swapped_memory"):
            return self.get_npu_memory(p)
        try:
            # Pass the dtype explicitly so the swapped tensor matches ``p``,
            # consistent with the zeros_like fallback and with the byte
            # accounting below, which uses p.element_size().
            swap_tensor = torch_npu.empty_with_swapped_memory(p.size(), dtype=p.dtype, device=p.device)
            # Tag the tensor so other code can recognise swapped tensors.
            swap_tensor.swap_tensor = True
            swap_tensor.data.swap_tensor = True
            swap_tensor.zero_()
            tensor_size = p.numel() * p.element_size()
            # Summary counter is kept in MB, the budget in bytes.
            self.actually_swap_size += tensor_size / 1024 / 1024
            self.swap_size_this_pp_rank -= tensor_size
            return swap_tensor
        except Exception as e:
            # Best effort: any allocation failure degrades to device memory.
            print(f"[Warning] Swap memory alloc failed: {e}")
            return self.get_npu_memory(p)

    def get_npu_memory(self, p: torch.Tensor):
        """
        Construct tensors in hbm memory.

        Args:
            p (tensor): param tensor.

        Returns:
            A zero tensor created with ``torch.zeros_like(p)``.
        """
        return torch.zeros_like(p, memory_format=torch.preserve_format)

    def print_swap_size(self, print_swap_flag):
        """
        Print swap size when `print_swap_flag` is true.

        Args:
            print_swap_flag (bool): Whether to print the swap size.
        """
        if print_swap_flag:
            print(f"[Swap virtual-optimizer Summary: Rank {torch.distributed.get_rank()}] Swap {self.actually_swap_size:.5f} MB")