import torch
import torch.distributed as dist
from typing import Union, Tuple, List, Literal
def _current_device() -> torch.device:
    """Return the device on which tensors should be staged for the collective.

    Uses the current accelerator when one is usable at runtime and falls back
    to the CPU otherwise (e.g. CPU-only jobs running on the gloo backend).
    """
    accelerator_api = getattr(torch, "accelerator", None)
    if accelerator_api is not None:
        # `current_accelerator()` may be None, and may also name a backend
        # that was compiled in but has no device present, so gate on
        # `is_available()` before touching the device index.
        if accelerator_api.is_available():
            accelerator = accelerator_api.current_accelerator()
            if accelerator is not None:
                return torch.device(
                    accelerator.type, accelerator_api.current_device_index()
                )
        return torch.device("cpu")
    # PyTorch builds that predate the `torch.accelerator` namespace.
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


def all_reduce(
    inputs: Union[float, torch.Tensor, Tuple, List],
    op: Literal["mean", "sum", "max", "min"] = "mean",
    group: Union[dist.ProcessGroup, None] = None,
) -> Union[float, Tuple[float, ...]]:
    """Reduce scalars or single-element tensors across a process group.

    Every item is copied to the current device (the caller's tensors are never
    modified), reduced with ``torch.distributed.all_reduce`` and converted back
    to a Python number.

    Args:
        inputs: A single scalar or tensor, or a tuple/list of them
            (e.g. ``(loss, grad_norm)``). Each tensor must hold exactly one
            element. Non-tensor items are converted to ``float32``.
        op: Reduction to apply: ``"mean"`` (sum divided by the group size),
            ``"sum"``, ``"max"`` or ``"min"``.
        group: Process group to reduce over; ``None`` uses the default group.

    Returns:
        The reduced Python scalar when ``inputs`` is a single item, otherwise
        a tuple with one reduced scalar per item, in input order.

    Raises:
        ValueError: If ``op`` is not a supported reduction, or if a tensor in
            ``inputs`` does not contain exactly one element.
    """
    reduce_ops = {
        "mean": dist.ReduceOp.SUM,
        "sum": dist.ReduceOp.SUM,
        "max": dist.ReduceOp.MAX,
        "min": dist.ReduceOp.MIN,
    }
    # Validate before touching the device or the process group so that a bad
    # argument is reported as such and never leaves a collective half-issued.
    if op not in reduce_ops:
        raise ValueError(f"Invalid op value: '{op}'. Must be one of: {list(reduce_ops.keys())}")

    is_sequence = isinstance(inputs, (tuple, list))
    items = inputs if is_sequence else [inputs]

    device = _current_device()
    staged = []
    for item in items:
        if isinstance(item, torch.Tensor):
            if item.numel() != 1:
                # Fail before any communication: raising after only some ranks
                # entered the collective could leave the others waiting.
                raise ValueError(
                    "all_reduce expects single-element tensors, "
                    f"got a tensor with shape {tuple(item.shape)}"
                )
            # clone() guarantees the in-place reduction never writes into the
            # caller's tensor, even when it already lives on `device`.
            staged.append(item.detach().clone().to(device))
        else:
            staged.append(torch.tensor(item, device=device, dtype=torch.float32))

    # The group size is only needed for averaging.
    group_size = dist.get_world_size(group) if op == "mean" else None

    results = []
    for t in staged:
        dist.all_reduce(t, op=reduce_ops[op], group=group)
        if group_size is not None:
            # Out-of-place true division: the in-place form raises for
            # integer tensors because the float result cannot be cast back.
            t = t / group_size
        results.append(t.item())

    return tuple(results) if is_sequence else results[0]