from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.optim import AdamW
class SwapOptimizerOperate():
    """Mixin that keeps optimizer states in pinned host memory and swaps them to the device on demand.

    Must be combined with a ``torch.optim.Optimizer`` subclass and initialized after it:
    ``__init__`` reads ``self.param_groups`` and ``self.state``. While a state is "on host",
    its device tensor is kept with a zero-sized storage; swapping in re-grows the storage
    (``storage().resize_``) and copies the host data back.

    Parameters are assumed to be DTensors (``.to_local()`` is called on them and on the
    states created with ``zeros_like``) — TODO confirm for other tensor types.

    NOTE(review): the maps and the ``swap_state_keys`` set below are class attributes, so they
    are shared by every instance in the process. The streams are shared on purpose (created
    once, lazily, in ``__init__``); confirm whether sharing the per-parameter maps across
    optimizer instances is also intended.
    """

    # Side streams, created lazily by the first instance and shared afterwards.
    swap_to_device_stream = None
    swap_to_host_stream = None
    # param -> event recorded after its last host->device / device->host swap (None once waited on).
    swap_to_device_events_map = {}
    swap_to_host_events_map = {}
    # param -> {state key: pinned CPU tensor, or None when the key is not used}.
    param_to_cpu_states_map = {}
    # param -> {state key: local device tensor (zero-sized storage while swapped out), or None}.
    param_to_device_states_map = {}
    # Optimizer state keys that are swapped; can be overridden per instance via ``__init__``.
    state_keys = ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq']
    # Keys for which a state tensor was actually allocated and swapped.
    swap_state_keys = set()

    def __init__(self, mem_fraction_static=0.8, state_keys=None):
        """Set up swapping and move every swapped state to pinned host memory.

        Args:
            mem_fraction_static: fraction of currently unused device memory that a single
                swap-in batch may use.
            state_keys: state keys to swap; defaults to the class-level ``state_keys``.
        """
        if state_keys is not None:
            self.state_keys = state_keys
        if SwapOptimizerOperate.swap_to_device_stream is None:
            SwapOptimizerOperate.swap_to_device_stream = torch.accelerator.Stream()
            SwapOptimizerOperate.swap_to_host_stream = torch.accelerator.Stream()
        self.param_to_group_map = {}
        for group in self.param_groups:
            for p in group['params']:
                self.param_to_group_map[p] = group
        self.opt_states_initialization()
        self.mem_fraction_static = mem_fraction_static
        self.memory_data_initialization()

    def opt_states_initialization(self):
        """Create each swapped state on the device, mirror it to pinned host memory, then release the device storage.

        ``max_exp_avg_sq`` is set to ``None`` (and not swapped) for groups without ``amsgrad``.
        """
        for group in self.param_groups:
            # .get keeps the mixin usable with optimizers whose groups have no 'amsgrad' option.
            amsgrad = group.get('amsgrad', False)
            for param in group["params"]:
                device_state_dtensor = self.state[param]
                device_state_tensor = {}
                cpu_state = {}
                for key in self.state_keys:
                    if key == 'max_exp_avg_sq' and not amsgrad:
                        device_state_dtensor[key] = None
                        device_state_tensor[key] = None
                        cpu_state[key] = None
                        continue
                    self.swap_state_keys.add(key)
                    device_state_dtensor[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    local_tensor = device_state_dtensor[key].to_local()
                    device_state_tensor[key] = local_tensor
                    cpu_state[key] = torch.empty_like(local_tensor, pin_memory=True, device='cpu')
                    cpu_state[key].copy_(local_tensor, non_blocking=True)
                    # Drop the device memory; the pinned host copy is the live state until swapped in.
                    local_tensor.storage().resize_(0)
                self.param_to_device_states_map[param] = device_state_tensor
                self.param_to_cpu_states_map[param] = cpu_state

    def memory_data_initialization(self):
        """Record the per-element state size and the device's total memory for batch sizing."""
        # States are created with zeros_like(param), so they share the parameter's dtype.
        # Use this instance's own parameters (the state maps are shared class-wide).
        first_param = next(iter(self.param_to_group_map))
        self.byte_param = first_param.element_size()
        # NOTE(review): other code in this file uses torch.accelerator.current_device_index();
        # confirm current_device() is available in the targeted torch build.
        self.total_memory = torch.accelerator.get_device_properties(torch.accelerator.current_device()).total_memory

    def get_swap_numel_from_unused_memory(self):
        """Return how many parameter elements' worth of states may be swapped in at once.

        The byte budget is ``(total_memory - allocated) * mem_fraction_static``, divided by the
        bytes each parameter element needs across all swapped state keys. When no state key is
        swapped, nothing occupies device memory during a swap-in, so the budget is unlimited
        (``float('inf')``) instead of dividing by zero.
        """
        if not self.swap_state_keys:
            return float('inf')
        used_memory = torch.accelerator.memory_allocated()
        unused_memory = self.total_memory - used_memory
        return unused_memory * self.mem_fraction_static // (self.byte_param * len(self.swap_state_keys))

    def swap_all_to_host(self):
        """Swap every parameter's states to host and make the current stream wait for the copies."""
        for param in self.param_to_cpu_states_map:
            self.swap_tensors_to_host(param)
        for param in self.param_to_cpu_states_map:
            event = self.swap_to_host_events_map.get(param, None)
            if event is not None:
                torch.accelerator.current_stream().wait_event(event)
                self.swap_to_host_events_map[param] = None

    def swap_all_to_device(self):
        """Swap every parameter's states to the device and make the current stream wait for the copies."""
        for param in self.param_to_cpu_states_map:
            self.swap_tensors_to_device(param)
        for param in self.param_to_cpu_states_map:
            self.wait_swap_to_device_event(param)

    def swap_tensors_to_device(self, param):
        """Re-grow released device storages of ``param``'s states and copy the host data in.

        The copies are non-blocking and enqueued on the current stream; an event recorded on
        that stream afterwards is stored in ``swap_to_device_events_map``.
        """
        cpu_state = self.param_to_cpu_states_map[param]
        if param in self.param_to_device_states_map:
            device_state = self.param_to_device_states_map[param]
            for key in self.state_keys:
                if device_state[key] is not None and device_state[key].storage().size() == 0:
                    device_state[key].storage().resize_(cpu_state[key].storage().size())
                    device_state[key].copy_(cpu_state[key], non_blocking=True)
        self.swap_to_device_events_map[param] = torch.accelerator.current_stream().record_event()

    def wait_swap_to_device_event(self, param):
        """Make the current stream wait for ``param``'s pending host->device copy, if any."""
        event = self.swap_to_device_events_map.get(param, None)
        if event is not None:
            torch.accelerator.current_stream().wait_event(event)
            self.swap_to_device_events_map[param] = None

    def swap_tensors_to_host(self, param):
        """Copy ``param``'s resident device states to their pinned host buffers and release the device storage.

        The copies are non-blocking and enqueued on the current stream; an event recorded on
        that stream afterwards is stored in ``swap_to_host_events_map``.
        """
        cpu_state = self.param_to_cpu_states_map[param]
        if param in self.param_to_device_states_map:
            device_state = self.param_to_device_states_map[param]
            for key in self.state_keys:
                if key in device_state and device_state[key] is not None and device_state[key].storage().size() != 0:
                    cpu_state[key].copy_(device_state[key], non_blocking=True)
                    device_state[key].storage().resize_(0)
        self.swap_to_host_events_map[param] = torch.accelerator.current_stream().record_event()

    def swap_batch_tensor_to_device(self, params_list, index):
        """Swap the states of consecutive parameters, starting at ``params_list[index]``, to the device.

        Copies run on ``swap_to_device_stream`` after all pending device->host copies. The batch
        stops before the parameter that would exceed ``get_swap_numel_from_unused_memory()``.

        Returns:
            The total number of parameter elements in the batch.

        Raises:
            AssertionError: if not even the first parameter fits in the budget.
        """
        torch.accelerator.current_stream().wait_stream(self.swap_to_host_stream)
        swap_count = 0
        with torch.accelerator.stream(self.swap_to_device_stream):
            # Reuse of released device memory must follow the outstanding device->host copies.
            torch.accelerator.current_stream().wait_stream(self.swap_to_host_stream)
            self.swap_numel = self.get_swap_numel_from_unused_memory()
            while index < len(params_list) and (swap_count + params_list[index].to_local().numel() <= self.swap_numel):
                self.swap_tensors_to_device(params_list[index])
                swap_count += params_list[index].to_local().numel()
                index += 1
        if swap_count == 0:
            raise AssertionError(
                "OOM, the amount of data transferred for optimizer states "
                "from host to device is 0. You can try increasing "
                "mem_fraction_static."
            )
        return swap_count
class AdamWSwap(AdamW, SwapOptimizerOperate):
    """AdamW (fused kernel) whose optimizer states live in pinned host memory between steps."""

    def __init__(
        self,
        params,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
        mem_fraction_static: Optional[float] = 0.8,
        state_keys: Optional[List[str]] = None,
    ):
        """
        AdamW optimizer that swaps optimizer states between the device and the host
        to reduce peak device memory usage.

        Outside of ``step``, optimizer states are kept on the host side so they do not occupy
        device memory. The initial device-to-host swap happens in ``SwapOptimizerOperate.__init__``.
        During ``step``, states are swapped back to the device in batches sized by the available
        device memory, updated, and swapped back to the host.

        Args:
            mem_fraction_static (float): fraction of the currently unused device memory that one
                batch of swapped-in optimizer states may occupy.
            state_keys (list): optimizer state keys to swap. States not listed are created lazily
                in ``step`` and stay resident on the device.
            Other parameters: see ``torch.optim.AdamW``. ``foreach``, ``capturable``,
                ``differentiable`` and ``fused`` are accepted for signature compatibility but
                ignored: the optimizer always uses the fused, non-capturable, non-differentiable path.

        Examples:
            >>> AdamWSwap(params, mem_fraction_static=0.9, state_keys=['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])
        """
        super().__init__(params,
                         lr=lr,
                         betas=betas,
                         eps=eps,
                         weight_decay=weight_decay,
                         amsgrad=amsgrad,
                         foreach=False,
                         maximize=maximize,
                         capturable=False,
                         differentiable=False,
                         fused=True,)
        SwapOptimizerOperate.__init__(self, mem_fraction_static=mem_fraction_static, state_keys=state_keys)

    # Parameters are updated in place; like torch's built-in optimizers, run without autograd
    # (the closure re-enables it explicitly below).
    @torch.no_grad()
    def step(self, closure=None):
        """Perform one AdamW step, swapping each parameter's states in and out around its update.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if any gradient is sparse (checked before anything is modified).
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Only parameters with a gradient are updated. Filtering before batching keeps the
        # `swap_count` accounting exact: a grad-less parameter swapped in as part of a batch
        # would otherwise never be swapped out nor subtracted, so no new batch would be loaded
        # and later parameters would be updated on released (zero-size) state storage.
        params_list = [p for p in self.param_to_group_map if p.grad is not None]
        for param in params_list:
            if param.grad.is_sparse:
                raise RuntimeError('AdamW does not support sparse gradients')

        device = torch.accelerator.current_device_index()
        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
                if group['step'].is_cpu:
                    # Same device spec as a freshly created step tensor (not hard-coded CUDA).
                    group['step'] = group['step'].to(device=device)
            else:
                group['step'] = torch.tensor(1, dtype=torch.int64, device=device)

        # Elements of the current swapped-in batch that have not been updated yet.
        swap_count = 0
        for i, param in enumerate(params_list):
            group = self.param_to_group_map[param]
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            state = self.state[param]
            # States excluded from `state_keys` are not created by the swap initializer;
            # create them here and keep them resident on the device.
            for key in ('exp_avg', 'exp_avg_sq'):
                if key not in state:
                    state[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
            if 'max_exp_avg_sq' not in state:
                state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) if amsgrad else None

            if swap_count == 0:
                # Previous batch fully consumed: swap in as many following states as fit.
                swap_count = self.swap_batch_tensor_to_device(params_list, i)
            self.wait_swap_to_device_event(param)

            max_exp_avg_sqs = [state['max_exp_avg_sq'].to_local()] if amsgrad else []
            torch._fused_adamw_([param.to_local()], [param.grad.to_local()], [state['exp_avg'].to_local()],
                                [state['exp_avg_sq'].to_local()], max_exp_avg_sqs,
                                [group['step']], lr=group['lr'], beta1=beta1, beta2=beta2,
                                weight_decay=group['weight_decay'], eps=group['eps'], amsgrad=amsgrad,
                                maximize=group['maximize'])

            # The device->host copy runs on a side stream; order it after the update that was
            # just enqueued on the compute stream so the updated states are what gets copied.
            self.swap_to_host_stream.wait_stream(torch.accelerator.current_stream())
            with torch.accelerator.stream(self.swap_to_host_stream):
                swap_count -= param.to_local().numel()
                self.swap_tensors_to_host(param)
        return loss