"""Communication op for all-reduce."""
from __future__ import annotations
from typing import Optional
import torch
from ...distributed import allreduce
from ...tensor import Quantizer
from ..op import BasicOperation, OperationContext
class AllReduce(BasicOperation):
    """All-reduce a tensor over a process group.

    The forward pass hands a contiguous view of the input to the
    ``allreduce`` helper (presumably a sum-reduction across every rank in
    ``process_group`` -- TODO confirm against the helper) and returns the
    reduced tensor. The backward pass returns ``grad_output`` unchanged and
    reports no parameter gradients.

    Parameters
    ----------
    process_group: torch.distributed.ProcessGroup, optional
        Process group to communicate over. ``None`` is forwarded to
        ``allreduce`` as-is (presumably selecting the default group --
        TODO confirm against ``allreduce``).
    reduce_in_backward: bool, default = `True`
        Stored on the instance as ``_reduce_in_backward``.
        NOTE(review): nothing in this class reads the flag; ``op_backward``
        never communicates, whatever value is passed here.

    """

    def __init__(
        self,
        process_group: Optional[torch.distributed.ProcessGroup] = None,
        reduce_in_backward: bool = True,
    ) -> None:
        super().__init__()
        # Group handed to ``allreduce`` on every forward call.
        self.process_group: Optional[torch.distributed.ProcessGroup] = process_group
        # Stored but not read anywhere in this class (see class docstring).
        self._reduce_in_backward: bool = reduce_in_backward

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        """Return ``input_`` all-reduced across ``self.process_group``.

        ``ctx`` and both quantizer arguments are accepted to satisfy the
        operation interface but are not used here.
        """
        # ``.contiguous()`` returns ``input_`` itself when it is already
        # contiguous, so no copy is made in that case.
        # NOTE(review): if ``allreduce`` works in place (as
        # torch.distributed.all_reduce does), a contiguous ``input_`` gets
        # overwritten with the reduced values -- confirm callers expect this.
        send_buf = input_.contiguous()
        reduced, pending = allreduce(send_buf, group=self.process_group)
        # A non-None handle means the collective may still be in flight;
        # block on it before handing the tensor to the next op.
        if pending is not None:
            pending.wait()
        return reduced

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        """Pass ``grad_output`` through unchanged.

        Returns the incoming gradient as the input gradient, together with
        an empty tuple of parameter gradients (this class registers no
        parameters). No communication happens here, regardless of
        ``reduce_in_backward``.
        """
        param_grads: tuple[()] = ()
        return grad_output, param_grads