"""
Tensor swapping system for efficient memory management in neural network training.
This module provides functionality to swap tensors between device and host memory
to optimize memory usage during forward and backward passes.
"""
from collections import defaultdict
from enum import Enum
from typing import Optional, Callable, List, Dict, Tuple
import torch
class TensorLocation(Enum):
    """Where the live copy of a swappable tensor currently resides."""

    # The data is held in accelerator (device) memory.
    DEVICE = "device"
    # The data has been (or is being) copied to host memory.
    HOST = "host"
def is_valid_for_swap(tensor: torch.Tensor, custom_check_fn: Optional[Callable] = None) -> bool:
    """
    Decide whether ``tensor`` is eligible to be swapped.

    A tensor is rejected when it is a ``Parameter``, when its ``_base`` is a
    ``Parameter``, when its storage holds no elements, or when the optional
    ``custom_check_fn`` returns a falsy value for it.

    Args:
        tensor: The tensor to validate.
        custom_check_fn: Optional predicate called with the tensor; only
            consulted after the built-in checks have passed.

    Returns:
        bool: True if the tensor can be swapped, False otherwise.
    """
    parameter_cls = torch.nn.parameter.Parameter

    # Parameters themselves are never swapped.
    if isinstance(tensor, parameter_cls):
        return False
    # Neither are tensors whose base tensor is a parameter.
    if isinstance(getattr(tensor, '_base', None), parameter_cls):
        return False
    # Nothing to move when the backing storage is empty.
    if tensor.storage().size() <= 0:
        return False
    # With no custom predicate the tensor is accepted; otherwise defer to it.
    return custom_check_fn is None or bool(custom_check_fn(tensor))
class TensorKeyManager:
    """Generates ``(layer_idx, tensor_idx)`` keys for tensor swap operations.

    The manager remembers the most recently seen layer index and, per layer,
    how many keys have been handed out. Those counts are later used by
    :meth:`get_prefetch_keys` to map a tensor of one layer onto a range of
    tensors of the preceding layer.
    """

    def __init__(self):
        # Layer index of the last ``get_key`` call; -1 means "none yet".
        self._current_layer_idx = -1
        # layer_idx -> number of keys issued for that layer so far.
        self._layer_tensor_counts: Dict[int, int] = {}

    def get_key(self, layer_idx: int) -> Tuple[Tuple[int, int], bool]:
        """Issue the next key for ``layer_idx``.

        Three cases are distinguished by comparing ``layer_idx`` with the
        layer index of the previous call:

        * greater: a new layer starts, its counter is set to 1;
        * equal: another tensor of the same layer, its counter is incremented;
        * smaller: all counters are discarded and counting restarts with this
          layer (presumably the start of a new forward pass - confirm with
          the caller).

        Args:
            layer_idx (int): Index of the layer the key is generated for.

        Returns:
            A tuple containing:
                - tensor_key (Tuple[int, int]): ``(layer_idx, tensor_index)``
                  where ``tensor_index`` is zero-based within the layer.
                - prev_layer_completed (bool): True only when this call is the
                  first one for a layer whose index is non-zero and greater
                  than the previously seen layer index.
        """
        prev_layer_completed = False
        if layer_idx > self._current_layer_idx:
            self._layer_tensor_counts[layer_idx] = 1
            # Layer 0 has no predecessor to report as completed.
            prev_layer_completed = layer_idx != 0
        elif layer_idx == self._current_layer_idx:
            self._layer_tensor_counts[layer_idx] += 1
        else:
            # Layer index went backwards: drop every previous count.
            self._layer_tensor_counts = {layer_idx: 1}
        self._current_layer_idx = layer_idx

        tensor_index = self._layer_tensor_counts[layer_idx] - 1
        return (layer_idx, tensor_index), prev_layer_completed

    def get_prefetch_keys(self, layer_idx: int, tensor_idx: int) -> List[Tuple[int, int]]:
        """
        Get keys of the previous layer's tensors to prefetch for one tensor.

        The tensors of layer ``layer_idx - 1`` are split proportionally across
        the tensors of layer ``layer_idx``; the slice belonging to
        ``tensor_idx`` is returned.

        Args:
            layer_idx: Current layer index.
            tensor_idx: Current tensor index within ``layer_idx``.

        Returns:
            List of ``(layer_idx - 1, tensor_index)`` keys. Empty when
            ``layer_idx`` is below 1 or when no tensor count is recorded for
            either the current or the previous layer.
        """
        if layer_idx < 1:
            return []

        prefetch_layer_idx = layer_idx - 1
        prefetch_layer_tensor_nums = self._layer_tensor_counts.get(prefetch_layer_idx)
        layer_tensor_nums = self._layer_tensor_counts.get(layer_idx)
        # A layer without recorded keys (e.g. it produced no swappable tensor)
        # means there is nothing to prefetch; avoid a KeyError / division by zero.
        if not prefetch_layer_tensor_nums or not layer_tensor_nums:
            return []

        start_idx = tensor_idx * prefetch_layer_tensor_nums // layer_tensor_nums
        end_idx = (tensor_idx + 1) * prefetch_layer_tensor_nums // layer_tensor_nums
        return [(prefetch_layer_idx, idx) for idx in range(start_idx, end_idx)]
class SwapTensor:
    """A tensor paired with a pinned host buffer so it can move between device and host."""

    def __init__(self, tensor: torch.Tensor, key: tuple):
        """
        Wrap ``tensor`` and allocate its host-side buffer.

        Args:
            tensor: The original tensor to manage.
            key: Unique identifier for this swap tensor.
        """
        self.key = key
        self.device_tensor = tensor
        self.size = tensor.size()
        # Element count of the backing storage; used to restore it after release.
        self.storage_size = tensor.storage().size()
        # Pinned CPU buffer with the same shape and dtype as the device tensor.
        self.host_tensor = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device='cpu')
        # True when the tensor does not cover its whole storage; such tensors
        # are copied tensor-to-tensor instead of storage-to-storage.
        self.is_slice_tensor = tensor.storage().size() != tensor.numel()
        self.current_location = TensorLocation.DEVICE
        self.h2d_event = torch.accelerator.Event()

    def async_d2h(self, stream: torch.Stream) -> None:
        """
        Enqueue a non-blocking device-to-host copy on ``stream``.

        Does nothing unless the tensor is currently marked as being on the device.

        Args:
            stream: Stream to perform the operation on.
        """
        if self.current_location != TensorLocation.DEVICE:
            return

        # Recorded before entering ``stream``'s context; the copy stream waits
        # on it so the copy is ordered after the work enqueued so far.
        producer_done = torch.accelerator.Event()
        producer_done.record()

        with torch.no_grad(), torch.accelerator.stream(stream):
            stream.wait_event(producer_done)
            if not self.is_slice_tensor:
                self.host_tensor.storage().copy_(self.device_tensor.storage(), non_blocking=True)
            else:
                self.host_tensor.copy_(self.device_tensor, non_blocking=True)
            self.current_location = TensorLocation.HOST

    def wait_d2h_finished(self, stream: torch.Stream, should_wait_streams: bool) -> None:
        """
        Release the device storage once the device-to-host copy may be awaited.

        Does nothing unless the tensor is currently marked as being on the host.

        Args:
            stream: The stream used for copying.
            should_wait_streams: When True, the current and default streams are
                made to wait on ``stream`` before the storage is released.
        """
        if self.current_location != TensorLocation.HOST:
            return

        if should_wait_streams:
            for waiter in (torch.accelerator.current_stream(), torch.accelerator.default_stream()):
                waiter.wait_stream(stream)

        # Shrink the device storage to zero elements to free its memory.
        self.device_tensor.storage().resize_(0)
        # The location is already HOST here (see the guard); kept explicit.
        self.current_location = TensorLocation.HOST

    def async_h2d(self, h2d_stream: torch.Stream,
                  should_resize_storage: bool, working_stream: Optional[torch.Stream] = None) -> None:
        """
        Enqueue a non-blocking host-to-device copy on ``h2d_stream``.

        Does nothing unless the tensor is currently marked as being on the host.

        Args:
            h2d_stream: Stream for host-to-device transfer.
            should_resize_storage: Whether to grow the device storage back to
                its original element count before copying.
            working_stream: Optional stream that is made to wait on
                ``h2d_stream``; when omitted, ``h2d_stream`` is recorded on the
                device tensor via ``record_stream`` instead.
        """
        if self.current_location != TensorLocation.HOST:
            return

        # Recorded before entering ``h2d_stream``'s context; the transfer
        # stream waits on it before copying.
        consumer_ready = torch.accelerator.Event()
        consumer_ready.record()

        if should_resize_storage:
            self.device_tensor.storage().resize_(self.storage_size)

        with torch.no_grad(), torch.accelerator.stream(h2d_stream):
            h2d_stream.wait_event(consumer_ready)
            if not self.is_slice_tensor:
                self.device_tensor.storage().copy_(self.host_tensor.storage(), non_blocking=True)
            else:
                self.device_tensor.copy_(self.host_tensor, non_blocking=True)
            self.h2d_event.record()
            self.current_location = TensorLocation.DEVICE
            if working_stream is None:
                self.device_tensor.record_stream(h2d_stream)
            else:
                working_stream.wait_stream(h2d_stream)
class SingletonMeta(type):
    """
    Metaclass that makes every class using it a singleton.

    The first call to the class constructs the instance with the given
    arguments; every later call returns that same instance and ignores the
    arguments it was given.
    """

    # Shared registry: class object -> its single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Fast path: the instance was already created.
        if cls in cls._instances:
            return cls._instances[cls]

        created = super().__call__(*args, **kwargs)
        cls._instances[cls] = created
        return created
class AsyncSwapHandler(metaclass=SingletonMeta):
    """Singleton that coordinates swap-out, release and prefetch of swap tensors."""

    def __init__(self, custom_check_fn: Optional[Callable] = None):
        self.custom_check_fn = custom_check_fn
        self._key_manager = TensorKeyManager()
        # Copy streams are created lazily by the ``d2h_stream`` / ``h2d_stream`` properties.
        self._d2h_stream: Optional[torch.Stream] = None
        self._h2d_stream: Optional[torch.Stream] = None
        # (layer_idx, tensor_idx) -> SwapTensor whose device-to-host copy was launched.
        self.key_to_swap_tensor = {}
        self.tensor_key_to_ = {}

    def set_custom_check_fn(self, fn):
        """Replace the predicate used to decide whether a tensor may be swapped.

        Args:
            fn: The custom check function.
        """
        self.custom_check_fn = fn

    def swap_out_cur_and_release_prev(self, tensor: torch.Tensor, layer_idx: int, layer_nums: int):
        """Start swapping out ``tensor`` and release the previous layer if it just finished.

        Args:
            tensor: Tensor about to be saved.
            layer_idx: Index of the layer saving the tensor.
            layer_nums: Total number of layers; tensors of the last layer are
                wrapped but not copied to host.

        Returns:
            The unchanged ``tensor`` when it is not eligible for swapping,
            otherwise the ``SwapTensor`` wrapping it.
        """
        if not is_valid_for_swap(tensor, self.custom_check_fn):
            return tensor

        key, prev_layer_done = self.get_swap_key(layer_idx)
        if prev_layer_done:
            self.release_device_tensor(layer_idx - 1)

        wrapped = SwapTensor(tensor, key)
        if layer_idx >= layer_nums - 1:
            # Last layer: keep the data on the device and do not register it.
            return wrapped

        compute_stream = torch.accelerator.current_stream()
        self.d2h_stream.wait_stream(compute_stream)
        wrapped.async_d2h(self.d2h_stream)
        self.key_to_swap_tensor[key] = wrapped
        return wrapped

    def wait_cur_and_prefetch_next(self, tensor, prefetch=True):
        """Bring ``tensor`` back to the device and optionally prefetch related tensors.

        Args:
            tensor: Object previously returned by ``swap_out_cur_and_release_prev``.
            prefetch: Whether to launch prefetching based on the tensor's key.

        Returns:
            ``tensor`` itself when it is not a ``SwapTensor``, otherwise its
            device tensor.
        """
        if not isinstance(tensor, SwapTensor):
            return tensor

        load_stream = self.h2d_stream
        store_stream = self.d2h_stream
        compute_stream = torch.accelerator.current_stream()

        # Order the compute and load streams against each other in both directions.
        compute_stream.wait_stream(load_stream)
        load_stream.wait_stream(compute_stream)
        tensor.async_h2d(load_stream, True, compute_stream)

        if prefetch:
            cur_layer, cur_tensor = tensor.key
            self.prefetch_tensors(cur_layer, cur_tensor, load_stream, store_stream)
        return tensor.device_tensor

    def get_swap_key(self, layer_idx):
        """Return ``(tensor_key, prev_layer_completed)`` from the key manager."""
        return self._key_manager.get_key(layer_idx)

    def exist(self, key):
        """Return True when a swap tensor is registered under ``key``."""
        return key in self.key_to_swap_tensor

    def release_device_tensor(self, layer_idx):
        """Wait for and release the device storage of every registered tensor of ``layer_idx``."""
        for key, registered in self.key_to_swap_tensor.items():
            if key[0] != layer_idx:
                continue
            registered.wait_d2h_finished(self.d2h_stream, True)

    def prefetch_tensors(self, layer_idx, tensor_idx, h2d_stream, d2h_stream):
        """Launch host-to-device copies for the tensors selected by the key manager.

        Each prefetched tensor is removed from the registry.

        Args:
            layer_idx: Current layer index.
            tensor_idx: Current tensor index.
            h2d_stream: Stream for host-to-device transfers.
            d2h_stream: Stream for device-to-host transfers.
        """
        for candidate_key in self._key_manager.get_prefetch_keys(layer_idx, tensor_idx):
            if not self.exist(candidate_key):
                continue
            pending = self.key_to_swap_tensor.pop(candidate_key)
            d2h_stream.wait_stream(h2d_stream)
            pending.async_h2d(h2d_stream, True)
            pending.device_tensor.record_stream(h2d_stream)

    @property
    def d2h_stream(self) -> torch.Stream:
        """Device-to-host stream, created on first access.

        Returns:
            The device-to-host stream.
        """
        if self._d2h_stream is None:
            self._d2h_stream = torch.accelerator.Stream()
        return self._d2h_stream

    @property
    def h2d_stream(self) -> torch.Stream:
        """Host-to-device stream, created on first access.

        Returns:
            The host-to-device stream.
        """
        if self._h2d_stream is None:
            self._h2d_stream = torch.accelerator.Stream()
        return self._h2d_stream
class TensorSwapContext:
    """Context manager that installs saved-tensor hooks performing tensor swapping.

    Every instance created with a given ``module_tag`` receives the next
    layer index for that tag (starting at 0); the counter lives in the
    class-level ``context_map`` and is shared by all instances.
    """

    # module_tag -> highest layer index handed out so far (-1 when none).
    context_map = defaultdict(lambda: -1)

    def __init__(
            self,
            module_tag: str = 'default',
            custom_check_fn: Optional[Callable] = None,
            prefetch: bool = True) -> None:
        """Initialize the tensor swap context.

        Args:
            module_tag: A pattern string for identifying structurally equivalent components across layers,
                enabling consistent identification of the same sub-module role within different layers.
                The asterisk (`*`) acts as a wildcard, matching any layer index.
                Examples:
                    1) For layers 'model.layer.0', 'model.layer.1', ... 'model.layer.N',
                       the module_tag would be 'model.layer.*'.
                    2) For the component 'up_proj' of expert 0 in different layers, such as
                       'model.layer.2.mlp.experts.0.up_proj' and 'model.layer.5.mlp.experts.0.up_proj',
                       the module_tag would be 'model.layer.*.mlp.experts.0.up_proj'.
                    3) For the component 'up_proj' of expert 1 in different layers, such as
                       'model.layer.3.mlp.experts.1.up_proj' and 'model.layer.7.mlp.experts.1.up_proj',
                       the module_tag would be 'model.layer.*.mlp.experts.1.up_proj'.
            custom_check_fn: Custom function to check tensor eligibility. When
                given, it replaces the check function of the shared swap handler.
            prefetch: Whether to enable tensor prefetching.
        """
        self.module_tag = module_tag
        TensorSwapContext.context_map[self.module_tag] += 1
        self.layer_idx = TensorSwapContext.context_map[self.module_tag]
        self.prefetch = prefetch
        self.swap_handler = AsyncSwapHandler(custom_check_fn)
        # The handler is a singleton: once it exists, constructor arguments are
        # ignored, so an explicitly supplied check function must be set here.
        if custom_check_fn is not None:
            self.swap_handler.set_custom_check_fn(custom_check_fn)

    def __enter__(self):
        """Install the pack/unpack hooks for saved tensors.

        Returns:
            This context, so it can be bound with ``with ... as ctx``.
        """
        def _on_save_for_backward(tensor):
            # Pack hook: hand the tensor to the swap handler.
            return self.swap_handler.swap_out_cur_and_release_prev(tensor, self.layer_idx, self.layer_nums)

        def _on_get_saved_tensor(tensor) -> torch.Tensor:
            # Unpack hook: get the device tensor back from the swap handler.
            return self.swap_handler.wait_cur_and_prefetch_next(tensor, self.prefetch)

        self.pack_hook = _on_save_for_backward
        self.unpack_hook = _on_get_saved_tensor
        # NOTE: private torch API (leading underscores); kept as in the original.
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )
        return self

    def __exit__(self, *args):
        """Exit the context and remove the hooks pushed by ``__enter__``."""
        torch._C._autograd._pop_saved_tensors_default_hooks()

    def set_custom_check_fn(self, fn):
        """Set a custom function to check tensor eligibility for swapping.

        Args:
            fn: The custom check function.
        """
        self.swap_handler.set_custom_check_fn(fn)

    @property
    def layer_nums(self) -> int:
        """Number of contexts created so far for this ``module_tag``.

        Returns:
            The current counter value for the tag plus one.
        """
        return TensorSwapContext.context_map[self.module_tag] + 1