from torch.distributed.pipelining import Schedule1F1B
from typing import Optional, Union, Any, Mapping
import torch.distributed as dist
import torch
def _batch_p2p(
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> list[dist.Work]:
"""
Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
"""
if len(p2p_ops) == 0:
return []
desc_str = f"{desc}, " if desc else ""
return dist.batch_isend_irecv(p2p_ops)
def _wait_batch_p2p(work: list[dist.Work]):
"""
Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p).
"""
for w in work:
w.wait()
class IterSchedule1F1B(Schedule1F1B):
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
"""
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
"""
if isinstance(data, Mapping):
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = {"device": self._stage.device}
return data.to(**kwargs)
return data
def step(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None
):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Args:
microbatches: list of microbatch args.
"""
self._stage.clear_runtime_states()
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
target_mbs = self._prepare_input(target_mbs)
arg_mbs[0], kwarg_mbs[0] = self._prepare_input((arg_mbs[0], kwarg_mbs[0]))
if not self._stage_initialized:
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
warmup_chunks = min(
self._n_microbatches,
self._num_stages - self._stage.stage_index,
)
fwd_mb_index = 0
bwd_mb_index = 0
send_work: list[dist.Work] = []
fwd_sends = []
for _ in range(warmup_chunks):
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
_wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv"))
arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index] = self._prepare_input(
(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])
)
output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
)
arg_mbs[fwd_mb_index] = None
kwarg_mbs[fwd_mb_index] = None
_wait_batch_p2p(send_work)
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
if fwd_mb_index != warmup_chunks - 1:
send_work = _batch_p2p(fwd_sends, desc="fwd_send")
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
fwd_mb_index += 1
while True:
bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
_wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"))
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(
bwd_mb_index,
loss=loss,
last_backward=bwd_mb_index == self._n_microbatches - 1,
)
bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
bwd_mb_index += 1
if fwd_mb_index == self._n_microbatches:
break
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
_wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"))
arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index] = self._prepare_input(
(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])
)
output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
)
arg_mbs[fwd_mb_index] = None
kwarg_mbs[fwd_mb_index] = None
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
fwd_mb_index += 1
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
while bwd_mb_index < self._n_microbatches:
bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
_wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv"))
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(
bwd_mb_index,
loss=loss,
last_backward=bwd_mb_index == self._n_microbatches - 1,
)
_wait_batch_p2p(send_work)
bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
bwd_mb_index += 1
self._stage.scale_grads(
grad_scale_factor=self._n_microbatches if self.scale_grads else 1
)
_wait_batch_p2p(send_work)
self._update_losses(self._stage, losses)
if self._stage.is_last:
return self._merge_outputs(
self._stage.output_chunks
)
else:
return None