from typing import Optional, Dict, Any
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
class BaseRandomBatchSampler(StatefulDistributedSampler):
    """Resumable distributed sampler that yields per-rank micro-batches of indices.

    Each pass over the sampler yields lists of ``batch_size`` dataset indices
    for the current rank. Progress is tracked as ``consumed_samples`` (counted
    across all replicas), so the position can be saved with
    :meth:`state_dict` and restored with :meth:`load_state_dict`. The epoch
    number is derived from ``consumed_samples`` at the start of every pass.

    Only whole global batches (``batch_size * num_replicas`` samples) are
    used; the remainder of the dataset is dropped.

    Args:
        dataset: Dataset used for sampling.
        batch_size (int, optional): Number of indices in every yielded
            micro-batch. Default: ``1``.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within
            :attr:`num_replicas`. By default, :attr:`rank` is retrieved from
            the current distributed group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle
            the indices.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. It is combined with the epoch number, and
            should be identical across all processes in the distributed
            group. Default: ``0``.
        drop_last (bool, optional): Must be ``True`` (default); the tail of
            the data is dropped so it divides evenly across replicas.
            ``drop_last=False`` is not implemented.
        data_sharding (bool, optional): If ``True``, each rank samples only
            from its own contiguous slice of the dataset. If ``False``
            (default), one global index order is built and ranks take
            interleaved entries from it.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``drop_last`` is
            ``False``.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 1,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = True,
        data_sharding: bool = False,
    ):
        # Fail fast with a clear message instead of a ZeroDivisionError (or
        # nonsensical modulo result) further down.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not drop_last:
            raise ValueError("It is not implemented that the drop_last is false.")

        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.total_samples = len(dataset)
        self.micro_batch_size = batch_size
        # Samples consumed so far, summed over all replicas.
        self.consumed_samples = 0
        # Value restored by load_state_dict(); applied on the next __iter__.
        self.next_consumed_samples = None
        self.data_sharding = data_sharding
        self.epoch = 0
        # Size of one global batch (one micro-batch on every replica).
        self.micro_batch_times_data_parallel_size = (
            self.micro_batch_size * self.num_replicas
        )
        # Tail samples that do not fill a whole global batch; never sampled.
        self.last_batch_size = (
            self.total_samples % self.micro_batch_times_data_parallel_size
        )

    def __len__(self):
        # NOTE: this is the dataset size, not the number of micro-batches
        # yielded per pass.
        return self.total_samples

    def _epoch_order(self, size: int) -> list:
        """Return indices ``0..size-1``, permuted for the current epoch if shuffling."""
        if not self.shuffle:
            return list(range(size))
        g = torch.Generator()
        # Combine the user seed with the epoch (as DistributedSampler does) so
        # that ``seed`` actually influences the order; with the default
        # seed=0 this is the epoch number alone.
        g.manual_seed(self.seed + self.epoch)
        return torch.randperm(size, generator=g).tolist()

    def __iter__(self):
        """Yield lists of ``batch_size`` indices for this rank until the epoch ends.

        Raises:
            ValueError: If the dataset holds fewer samples than one global
                batch, so no batch can be formed.
        """
        # Apply a position restored through load_state_dict(), exactly once.
        if self.next_consumed_samples is not None:
            self.consumed_samples = self.next_consumed_samples
            self.next_consumed_samples = None

        # Number of samples in one epoch: whole global batches only.
        active_total_samples = self.total_samples - self.last_batch_size
        if active_total_samples <= 0:
            raise ValueError(
                f"Dataset has {self.total_samples} samples, fewer than one global "
                f"batch of {self.micro_batch_times_data_parallel_size} "
                "(batch_size * num_replicas); nothing can be sampled."
            )
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples

        if self.data_sharding:
            # Each rank owns a contiguous bucket and permutes within it.
            bucket_size = (
                self.total_samples // self.micro_batch_times_data_parallel_size
            ) * self.micro_batch_size
            bucket_offset = current_epoch_samples // self.num_replicas
            start_idx = self.rank * bucket_size
            local_order = self._epoch_order(bucket_size)
            idx_range = [start_idx + x for x in local_order[bucket_offset:]]
        else:
            full_bucket_size = (
                self.total_samples // self.micro_batch_size
            ) * self.micro_batch_size
            global_order = self._epoch_order(full_bucket_size)
            # Cap at the epoch size. Without the cap the strided split below
            # can hand low ranks one more full batch than high ranks when
            # batch_size < num_replicas, which desynchronizes the replicas
            # and pushes consumed_samples past the epoch boundary.
            idx_range_active = global_order[current_epoch_samples:active_total_samples]
            idx_range = idx_range_active[self.rank::self.num_replicas]

        # Emit only complete micro-batches; a trailing partial one is dropped.
        step = self.micro_batch_size
        for start in range(0, len(idx_range) - step + 1, step):
            # Every replica yields a batch in lockstep, so advance the global
            # counter by a full global batch.
            self.consumed_samples += self.micro_batch_times_data_parallel_size
            yield idx_range[start:start + step]

    def state_dict(self) -> Dict[str, Any]:
        """Return the resumable position (globally consumed sample count)."""
        return {self._YIELDED: self.consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Schedule a restored position to take effect on the next ``__iter__``.

        Raises:
            ValueError: If the expected key is missing or its value is negative.
        """
        if self._YIELDED not in state_dict:
            raise ValueError("Invalid state_dict")
        if state_dict[self._YIELDED] < 0:
            raise ValueError("Cannot load state_dict with negative yielded value")
        self.next_consumed_samples = state_dict[self._YIELDED]