from datetime import timedelta
from typing import Any
import torch
import torch.distributed as dist
from packaging.version import parse as V
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
# Module-level cache for the Gloo process group. It stays None until
# init_gloo_group() creates the group; get_gloo_group() reads it afterwards.
GLOO_GROUP = None
def init_gloo_group():
    """Return the shared Gloo process group, creating it on the first call.

    The group is cached in the module-level ``GLOO_GROUP`` variable, so later
    calls hand back the cached object instead of creating another group.
    """
    global GLOO_GROUP
    # Fast path: a group was already created (or assigned) earlier.
    if GLOO_GROUP is not None:
        return GLOO_GROUP
    GLOO_GROUP = dist.new_group(backend="gloo")
    return GLOO_GROUP
def get_gloo_group():
    """Return the cached Gloo process group.

    Returns:
        The object stored in the module-level ``GLOO_GROUP`` variable.

    Raises:
        RuntimeError: If ``GLOO_GROUP`` is still ``None``, i.e. the group has
            not been created via ``init_gloo_group()`` yet.
    """
    # Read-only access to the module global, so no ``global`` statement is needed.
    if GLOO_GROUP is None:
        # The message names the real public initializer, init_gloo_group().
        raise RuntimeError("Gloo group has not been initialized. Call init_gloo_group() first.")
    return GLOO_GROUP
def init_process_group(
backend: str | Backend = None,
init_method: str | None = None,
timeout: timedelta | None = None,
world_size: int = -1,
rank: int = -1,
store: Store | None = None,
group_name: str = None,
pg_options: Any | None = None,
):
assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
if backend:
backend = Backend(backend)
else:
backend = Backend("undefined")
if timeout is None:
timeout = default_pg_timeout
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
store = PrefixStore(group_name, store)
pg_options_param_name = "backend_options" if V(torch.__version__) >= V("2.6") else "pg_options"
pg, _ = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
group_name=group_name,
**{pg_options_param_name: pg_options},
timeout=timeout,
)
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
return pg
def distributed_masked_whiten(
values: torch.Tensor,
mask: torch.Tensor,
process_group: dist.ProcessGroup | None = None,
shift_mean: bool = True,
epsilon: float = 1e-8,
):
"""
Performs whitening on a tensor using global statistics from all participating GPUs.
It calculates the global mean and variance across all ranks in the default
process group (the WORLD) and uses these global statistics to normalize the
local data on each rank.
Args:
values (torch.Tensor): The local tensor of values to whiten.
mask (torch.Tensor): The local mask corresponding to the values.
process_group: The process group for all_reduce.
If None, uses the default world group.
shift_mean (bool): If True, the output is zero-mean. Defaults to True.
epsilon (float): A small value for numerical stability.
Returns:
torch.Tensor: The locally whitened tensor using global statistics.
"""
local_sum = (values * mask).sum()
local_sum_sq = ((values**2) * mask).sum()
local_mask_sum = mask.sum()
stats_tensor = torch.tensor(
[local_sum, local_sum_sq, local_mask_sum],
device=values.device,
dtype=torch.float32,
)
dist.all_reduce(stats_tensor, group=process_group)
global_sum, global_sum_sq, global_mask_sum = stats_tensor
if global_mask_sum.item() == 0:
raise ValueError("The global mask sum across all participating GPUs is zero.")
global_mean = global_sum / global_mask_sum
global_mean_sq = global_sum_sq / global_mask_sum
global_var = global_mean_sq - global_mean**2
if global_mask_sum.item() >= 2:
bessel_correction = global_mask_sum / (global_mask_sum - 1)
global_var = global_var * bessel_correction
whitened_values = (values - global_mean) * torch.rsqrt(global_var + epsilon)
if not shift_mean:
whitened_values += global_mean
return whitened_values