import contextlib
import logging
from typing import Union
import torch
from torch import _C
from torch.cuda import _lazy_call, _lazy_init
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron.core.parallel_state import (
get_expert_model_parallel_rank,
get_expert_tensor_parallel_rank,
get_tensor_model_parallel_rank,
)
from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_data
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
try:
import transformer_engine
HAVE_TE = True
except ModuleNotFoundError:
HAVE_TE = False
# Keys under which the individual CUDA RNG states are registered in the RNG
# tracker (see model_parallel_cuda_manual_seed, which adds one state per key).
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'
def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda", clone: bool = False, graph_safe: bool = False
) -> torch.Tensor:
"""Return the random number generator state of the specified GPU.
Arguments:
device (int): The gpu to retrieve the rng state
clone (bool): Whether to also clone the retrieved RNG state
graph_safe (bool): Get the rng state in a graph safe manner.
This function is adapted from torch.cuda.random.get_rng_state()"""
if not graph_safe:
return torch.cuda.random.get_rng_state(device=device)
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if clone:
return default_generator.clone_state()
return default_generator.graphsafe_get_state()
def _set_cuda_rng_state(new_state: torch.Tensor, device: int = -1, graph_safe: bool = False):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
device (int): The gpu to retrieve the rng state
graph_safe (bool): Set the rng state in a graph safe manner.
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if graph_safe:
default_generator.graphsafe_set_state(new_state)
else:
default_generator.set_state(new_state)
_lazy_call(cb)
def get_expert_parallel_rng_tracker_name():
    """Return the tracker key used for the expert parallel rng state."""
    # Read-only access to the module constant; no `global` declaration needed.
    return _EXPERT_PARALLEL_RNG_TRACKER_NAME
def get_data_parallel_rng_tracker_name():
    """Return the tracker key used for the data parallel rng state."""
    # Read-only access to the module constant; no `global` declaration needed.
    return _DATA_PARALLEL_RNG_TRACKER_NAME
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self, use_cudagraphable_rng=False, is_inference_rng_tracker=False):
        """Create an empty tracker.

        Arguments:
            use_cudagraphable_rng (bool): Obtain and install states through
                the graph safe generator calls.
            is_inference_rng_tracker (bool): Stored on the instance as-is;
                not read by this class.
        """
        self.reset()
        self.use_cudagraphable_rng = use_cudagraphable_rng
        self.is_inference_rng_tracker = is_inference_rng_tracker

    def is_initialized(self):
        """Checks if the internal RNG state has been set with add() or set_states()."""
        return self._is_initialized

    def reset(self):
        """Set to the initial state (no tracker)."""
        self._is_initialized = False
        # Map from string names to RNG states.
        self.states_ = {}
        # Seeds are tracked only to reject duplicates in add().
        self.seeds_ = set()

    def get_states(self):
        """Get rng states.

        Returns a shallow copy of the name -> state mapping, so the caller
        holds references to the state objects rather than to the tracker's
        own dictionary.
        """
        return dict(self.states_)

    def set_states(self, states):
        """Set the rng states.

        The given mapping is stored by reference (not copied) and the tracker
        is marked as initialized.
        """
        self._is_initialized = True
        self.states_ = states

    def add(self, name, seed):
        """Track a new rng state seeded with `seed` under `name`.

        Raises:
            Exception: if `seed` or `name` has already been registered. The
                tracker is left untouched in that case.
        """
        # Validate everything before mutating, so a rejected call does not
        # leave a stray seed behind or flip the initialized flag.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))

        if self.use_cudagraphable_rng:
            new_state = _get_cuda_rng_state(clone=True, graph_safe=True)
            new_state.manual_seed(seed)
        else:
            # Seed the device, capture the resulting state, then put the
            # device state back to what it was before this call.
            orig_rng_state = torch.cuda.get_rng_state()
            torch.cuda.manual_seed(seed)
            new_state = torch.cuda.get_rng_state()
            _set_cuda_rng_state(orig_rng_state)

        self.seeds_.add(seed)
        self.states_[name] = new_state
        self._is_initialized = True

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state.

        Raises:
            Exception: if `name` has not been added to the tracker.
        """
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Remember the current device state and install the tracked one.
        orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
        _set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng)
        # Snapshot the CPU state only to detect changes inside the context.
        cpu_rng_state = torch.get_rng_state()
        try:
            yield
        finally:
            if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
                logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
            # Store the advanced state for the next fork of this name...
            self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
            # ...and restore the state the device had on entry.
            _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng)
# Module-level RNG tracker instance; assigned by initialize_rng_tracker().
_CUDA_RNG_STATE_TRACKER = None
# Set to True by initialize_rng_tracker() once the tracker has been created,
# which turns later initialize_rng_tracker() calls into no-ops.
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False
def initialize_rng_tracker(
    use_te_rng_tracker: bool = False,
    inference_rng_tracker: bool = False,
    use_cudagraphable_rng: bool = False,
):
    """Create the module-level RNG tracker if it does not exist yet.

    'use_te_rng_tracker' determines whether to use Megatron or
    TransformerEngine's implementation; the TransformerEngine one is only
    selected when TransformerEngine could be imported. In particular,
    TransformerEngine's implementation is cudagraphable and supports FP8.

    Arguments:
        use_te_rng_tracker (bool): Prefer TransformerEngine's tracker.
        inference_rng_tracker (bool): Wrap the chosen tracker in a subclass
            whose `add`, `set_states` and `fork` do nothing.
        use_cudagraphable_rng (bool): Forwarded to the Megatron tracker only.

    Raises:
        RuntimeError: if the TransformerEngine tracker is selected but the
            installed TransformerEngine is older than 1.5.0.
    """
    global _CUDA_RNG_STATE_TRACKER
    global _CUDA_RNG_STATE_TRACKER_INITIALIZED
    if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
        # Already created: the arguments of this call are ignored.
        return

    wants_te = HAVE_TE and use_te_rng_tracker
    if wants_te and not is_te_min_version("1.5.0"):
        raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")

    tracker_kwargs = {"is_inference_rng_tracker": inference_rng_tracker}
    if wants_te:
        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

        base_tracker = TECudaRNGStatesTracker
    else:
        base_tracker = CudaRNGStatesTracker
        tracker_kwargs["use_cudagraphable_rng"] = use_cudagraphable_rng

    if not inference_rng_tracker:
        tracker_class = base_tracker
    else:

        class InferenceCudaRNGStatesTracker(base_tracker):
            """RNG tracker for inference."""

            def add(self, name, seed):
                """Mirrors the interface from the training RNG tracker."""

            def set_states(self, states):
                """Mirrors the interface from the training RNG tracker."""

            def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
                """Mirrors the interface from the training RNG tracker."""
                return contextlib.nullcontext()

        tracker_class = InferenceCudaRNGStatesTracker

    _CUDA_RNG_STATE_TRACKER = tracker_class(**tracker_kwargs)
    _CUDA_RNG_STATE_TRACKER_INITIALIZED = True
def get_cuda_rng_tracker(
    use_te_rng_tracker: bool = False,
    inference_rng_tracker: bool = False,
    use_cudagraphable_rng: bool = False,
):
    """Return the module-level cuda rng tracker, creating it on first use.

    The arguments are passed to `initialize_rng_tracker`, which only acts on
    them if no tracker has been created yet.
    """
    initialize_rng_tracker(
        use_te_rng_tracker=use_te_rng_tracker,
        inference_rng_tracker=inference_rng_tracker,
        use_cudagraphable_rng=use_cudagraphable_rng,
    )
    return _CUDA_RNG_STATE_TRACKER
def get_all_rng_states():
    """Returns all generator states used by the current `CudaRNGStatesTracker`.

    Returns:
        The `states_` mapping of the Megatron tracker (the live dictionary,
        not a copy), the result of TransformerEngine's `get_all_rng_states()`
        when its tracker is in use, or an empty dict when no recognised
        tracker is active (for example before the tracker is initialized).
    """
    if isinstance(_CUDA_RNG_STATE_TRACKER, CudaRNGStatesTracker):
        return _CUDA_RNG_STATE_TRACKER.states_

    # If TE is installed, check if we are using TE's RNG tracker.
    if HAVE_TE and is_te_min_version("1.5.0"):
        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

        if isinstance(_CUDA_RNG_STATE_TRACKER, TECudaRNGStatesTracker):
            # Aliased so the import does not shadow this function's own name.
            from transformer_engine.pytorch.distributed import (
                get_all_rng_states as te_get_all_rng_states,
            )

            return te_get_all_rng_states()

    # No valid tracker. This also covers the case where TE is available but
    # the active tracker is not TE's, which previously fell through and
    # returned None instead of an empty dict.
    return {}
def model_parallel_cuda_manual_seed(
    seed: int,
    te_rng_tracker: bool = False,
    inference_rng_tracker: bool = False,
    use_cudagraphable_rng: bool = False,
):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is replacement for that
    function.

    Three sets of RNG states are tracked:
    default state: This is for data parallelism and is the same among a set of model parallel
        GPUs but different across different model parallel groups. This is used for example
        for dropout in the non-tensor-model-parallel regions.
    tensor-model-parallel state: This state is different among a set of model parallel GPUs,
        but the same across data parallel groups. This is used for example for dropout
        in model parallel regions.
    expert-parallel-seed: This state is only used for the expert layer of MoE models.
        It is different among expert-tensor and expert-model parallel GPUs, and the same
        across expert-data parallel groups.

    Arguments:
        seed (int): Base seed; also used unchanged as the data parallel seed.
        te_rng_tracker (bool): Passed to `initialize_rng_tracker`.
        inference_rng_tracker (bool): Passed to `initialize_rng_tracker`.
        use_cudagraphable_rng (bool): Passed to `initialize_rng_tracker`.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    tensor_model_parallel_seed = seed + 2718 + get_tensor_model_parallel_rank()

    initialize_rng_tracker(te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng)
    tracker = _CUDA_RNG_STATE_TRACKER
    tracker.reset()

    # Seed the default device state, then register the named states.
    torch.cuda.manual_seed(seed)
    tracker.add(_DATA_PARALLEL_RNG_TRACKER_NAME, seed)
    tracker.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)

    expert_model_rank = get_expert_model_parallel_rank()
    expert_tensor_rank = get_expert_tensor_parallel_rank()
    expert_parallel_seed = seed + 1024 + 100 * expert_model_rank + expert_tensor_rank
    tracker.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
def _get_all_rng_states():
    """Collect the CPU state, the CUDA state and the tracker's states.

    Returns:
        A 3-tuple ``(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker)``
        in the order expected by `_set_all_rng_states`.
    """
    return (
        torch.get_rng_state(),
        _get_cuda_rng_state(),
        get_cuda_rng_tracker().get_states(),
    )
def _set_all_rng_states(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker):
    """Install the CPU state, the CUDA state and the tracker's states.

    The arguments are the three values returned by `_get_all_rng_states`.
    """
    torch.set_rng_state(cpu_rng_state)
    _set_cuda_rng_state(cuda_rng_state)
    tracker = get_cuda_rng_tracker()
    tracker.set_states(cuda_rng_state_tracker)
@contextlib.contextmanager
def _fork_rng():
    """Context manager that restores every rng state on exit.

    The CPU, CUDA and tracker states are captured on entry and written back
    when the block finishes, whether it completes or raises.
    """
    cpu_state, cuda_state, tracker_states = _get_all_rng_states()
    try:
        yield
    finally:
        _set_all_rng_states(cpu_state, cuda_state, tracker_states)
class CheckpointFunction(torch.autograd.Function):
    """Checkpoint Function

    This function is adapted from torch.utils.checkpoint with two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly tracked/set/reset.
    """

    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        """Forward pass.

        Runs `run_function(*args)` without recording a graph and saves what
        is needed to re-run it in `backward`.

        Arguments:
            run_function: Callable to checkpoint.
            distribute_saved_activations (bool): If set, `args[0]` is replaced
                in place by the chunk returned from
                `split_tensor_into_1d_equal_chunks` before being saved.
            *args: Inputs of `run_function`.
        """
        from mindspeed.utils import get_actual_seq_len

        # Value is handed back through set_actual_seq_len() in backward.
        ctx.actual_seq_len = get_actual_seq_len()
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Captured before the forward run so the recompute can start from
        # the same rng states.
        ctx.rng_states = _get_all_rng_states()

        with torch.no_grad():
            outputs = run_function(*args)

        if distribute_saved_activations:
            # Keep the original shape so backward can undo the 1D split.
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
            )

        ctx.save_for_backward(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Backward pass.

        Re-runs the checkpointed function with grad enabled under the rng
        states captured in forward, then backpropagates `args` through the
        recomputed outputs.

        Raises:
            RuntimeError: if checkpointing is used with `.grad()`, or if none
                of the recomputed outputs is a tensor requiring grad.
        """
        from mindspeed.utils import set_actual_seq_len

        set_actual_seq_len(ctx.actual_seq_len)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            # Undo the split done in forward and restore the original shape.
            safely_set_viewless_tensor_data(
                inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
            )

        # _fork_rng restores the current rng states once recompute is done.
        with _fork_rng():
            _set_all_rng_states(*ctx.rng_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Keep only (output, grad) pairs whose output can be backpropagated.
        grad_pairs = [
            (out, grad)
            for out, grad in zip(outputs, args)
            if torch.is_tensor(out) and out.requires_grad
        ]
        if not grad_pairs:
            # Unpacking zip(*[]) would otherwise fail with an opaque
            # "not enough values to unpack" error.
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )
        outputs, args = zip(*grad_pairs)
        torch.autograd.backward(outputs, args)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        # The two leading None entries correspond to run_function and
        # distribute_saved_activations.
        return (None, None) + grads
def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.

    Runs `function(*args)` through `CheckpointFunction` and returns its
    result.

    Arguments:
        function: Callable to checkpoint.
        distribute_saved_activations (bool): Forwarded to
            `CheckpointFunction`.
        *args: Inputs of `function`.
    """
    forwarded = (function, distribute_saved_activations, *args)
    return CheckpointFunction.apply(*forwarded)
class CheckpointWithoutOutputFunction(torch.autograd.Function):
    """
    Checkpoint Function Helper for CheckpointWithoutOutput.
    Save context for recompute.
    """

    @staticmethod
    def forward(ctx, run_function, checkpoint_without_output_obj, *args):
        """Forward pass.

        Runs `run_function(*args)` without recording a graph, saves detached
        copies of the inputs and exposes `ctx` on
        `checkpoint_without_output_obj` so it can be reached later.
        """
        with torch.no_grad():
            result = run_function(*args)
        saved_inputs = detach_variable(args)
        ctx.save_for_backward(*saved_inputs)
        checkpoint_without_output_obj.ctx = ctx
        return result

    @staticmethod
    def backward(ctx, *args):
        """Backward pass.

        Backpropagates `args` through `ctx.outputs`, which must have been
        assigned on `ctx` beforehand, and returns the gradients accumulated
        on the saved inputs.
        """
        saved_inputs = ctx.saved_tensors
        recomputed = ctx.outputs
        torch.autograd.backward(recomputed, args)
        # Drop the reference to the recomputed outputs once consumed.
        ctx.outputs = None
        input_grads = []
        for tensor in saved_inputs:
            input_grads.append(tensor.grad if isinstance(tensor, torch.Tensor) else tensor)
        # Leading None entries match run_function and
        # checkpoint_without_output_obj.
        return (None, None, *input_grads)
class CheckpointWithoutOutput(object):
    """
    Checkpoint a model or part of the model and release the output.

    For the normal `checkpoint` function, the outputs of it may be cached by the following
    operations for its backward computation. However, the output of the checkpointed function is
    re-generated at recomputation, so the output store is not technically needed. This method can
    manually discard the output in the forward pass and restore it by recomputation in the
    backward pass to reduce the memory usage.
    """

    def __init__(self):
        self.run_function = None
        # Tuple from _get_all_rng_states(), captured in checkpoint() and
        # replayed in _recompute(). Declared here so the attribute exists
        # before checkpoint() is called.
        self.rng_states = None
        # Not read anywhere in this class; kept so existing attribute access
        # from outside keeps working.
        self.fwd_cpu_rng_state = None
        self.fwd_cuda_rng_state = None
        self.fwd_cuda_rng_state_tracker = None
        # Autograd ctx, assigned by CheckpointWithoutOutputFunction.forward.
        self.ctx = None
        # Tuple of forward outputs whose storages get released and refilled.
        self.outputs = None

    def checkpoint(self, run_function, *args):
        """Checkpoint function.

        Captures the rng states, runs `run_function(*args)` through
        `CheckpointWithoutOutputFunction` and returns its outputs unchanged.
        """
        self.run_function = run_function
        self.rng_states = _get_all_rng_states()
        outputs = CheckpointWithoutOutputFunction.apply(run_function, self, *args)
        self.outputs = outputs
        if isinstance(self.outputs, torch.Tensor):
            self.outputs = (self.outputs,)
        return outputs

    def _recompute(self, _):
        """Used as a hook to recompute the output.

        Raises:
            RuntimeError: if checkpointing is used with `.grad()`.
        """
        if self.ctx is None:
            # The recomputation has been triggered already (its state is
            # cleared at the end of this method). Just return instead of
            # failing on the cleared attributes.
            return

        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )

        # Re-run under the forward rng states; _fork_rng restores the
        # current ones afterwards.
        with _fork_rng():
            _set_all_rng_states(*self.rng_states)
            with torch.enable_grad():
                outputs = self.run_function(*self.ctx.saved_tensors)
        self.run_function = None
        self.rng_states = None

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Refill the storages released in discard_output_and_register_recompute
        # with the recomputed values.
        with torch.no_grad():
            for output, recomputation_output in zip(self.outputs, outputs):
                output_size = recomputation_output.untyped_storage().size()
                output.untyped_storage().resize_(output_size)
                output.untyped_storage().copy_(recomputation_output.untyped_storage())

        # CheckpointWithoutOutputFunction.backward reads ctx.outputs.
        self.ctx.outputs = outputs
        self.outputs = None
        self.ctx = None

    def discard_output_and_register_recompute(self, hook_tensor):
        """
        Release the output tensor storages and register the recompute function as a grad hook of
        the hook_tensor.

        Note: the caller should make sure that the output tensors are no longer used
        in the forward pass and the gradient of the hook_tensor is computed before the recomputed
        tensors are used.
        """
        # Resizing the storage to 0 frees the memory and keeps the tensor
        # objects themselves alive.
        for output in self.outputs:
            output.untyped_storage().resize_(0)

        # The hook is only registered when hook_tensor requires grad.
        if hook_tensor.requires_grad:
            hook_tensor.register_hook(self._recompute)