import torch
import torch.distributed as dist
from contextlib import contextmanager
import logging.config
import random
def generate_seeds(rng, size):
    """
    Draw a list of random seeds from the given generator.

    Each seed is an integer obtained from ``rng.randint(0, 2**32 - 1)``.

    :param rng: random number generator exposing ``randint(low, high)``
    :param size: number of seeds to draw (length of the returned list)
    """
    upper_bound = 2**32 - 1
    drawn = []
    for _ in range(size):
        drawn.append(rng.randint(0, upper_bound))
    return drawn
def broadcast_seeds(seeds, device):
    """
    Share the seeds held by the rank-0 worker with every distributed worker.

    When torch.distributed is unavailable or not initialized, the input list
    is returned untouched. Otherwise the seeds are packed into a LongTensor
    on ``device``, broadcast from rank 0, and returned as a plain list.

    :param seeds: list of seeds (integers)
    :param device: torch.device on which the broadcast tensor is placed
    """
    distributed_ready = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
    )
    if not distributed_ready:
        return seeds

    payload = torch.LongTensor(seeds).to(device)
    # Source rank 0: every worker ends up with rank 0's values.
    torch.distributed.broadcast(payload, 0)
    return payload.tolist()
def setup_seeds(master_seed, epochs, device):
    """
    Derive all per-worker and per-epoch seeds from a single master seed.

    Returns a tuple ``(worker_seeds, shuffling_seeds)``: one worker seed per
    distributed worker (``get_world_size()`` entries) and one shuffling seed
    per epoch. Both lists are passed through ``broadcast_seeds`` so that all
    workers end up with the values produced on rank 0.

    If ``master_seed`` is None, a master seed is drawn from
    ``random.SystemRandom()`` and reported via logging on rank 0 only.

    :param master_seed: master RNG seed used to initialize other generators,
        or None to pick one at random
    :param epochs: number of epochs (number of shuffling seeds to produce)
    :param device: torch.device (used for distributed.broadcast)
    """
    if master_seed is not None:
        # Seed supplied by the caller.
        logging.info(f'Using master seed from command line: {master_seed}')
    else:
        master_seed = random.SystemRandom().randint(0, 2**32 - 1)
        # Only rank 0 reports the random seed; its derived seeds are the
        # ones broadcast to the other workers below.
        if get_rank() == 0:
            logging.info(f'Using random master seed: {master_seed}')

    rng = random.Random(master_seed)

    # Draw order matters: worker seeds first, then shuffling seeds.
    local_worker_seeds = generate_seeds(rng, get_world_size())
    local_shuffling_seeds = generate_seeds(rng, epochs)

    synced_worker_seeds = broadcast_seeds(local_worker_seeds, device)
    synced_shuffling_seeds = broadcast_seeds(local_shuffling_seeds, device)
    return synced_worker_seeds, synced_shuffling_seeds
def barrier():
    """
    Blocks until all distributed workers reach this point.

    Does nothing when torch.distributed is unavailable or not initialized.

    When CUDA is available, the barrier is implemented as an all_reduce on a
    dummy CUDA tensor followed by a GPU synchronize. When CUDA is not
    available (CPU-only distributed run), it falls back to
    ``torch.distributed.barrier()`` instead of failing while trying to
    allocate a CUDA tensor.
    """
    if not (torch.distributed.is_available()
            and torch.distributed.is_initialized()):
        return

    if torch.cuda.is_available():
        # The tensor contents are irrelevant (the result is discarded); only
        # the collective call and the subsequent device sync matter.
        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
        torch.cuda.synchronize()
    else:
        # No GPU to allocate the dummy tensor on: use the built-in barrier.
        torch.distributed.barrier()
def get_rank():
    """
    Returns the distributed rank of this worker, or zero when
    torch.distributed is unavailable or not initialized.
    """
    distributed_ready = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
    )
    return torch.distributed.get_rank() if distributed_ready else 0
def get_world_size():
    """
    Gets total number of distributed workers or returns one if distributed is
    not available or not initialized.

    Also prints the availability / initialization status of
    torch.distributed to stdout on every call.
    """
    available = torch.distributed.is_available()
    # Only query is_initialized() when distributed is available, matching the
    # guard used by the other helpers in this module; builds without
    # distributed support may not expose is_initialized() at all.
    initialized = available and torch.distributed.is_initialized()

    if available:
        print("Torch distributed is available.")
    else:
        print("Torch distributed is not available.")

    if initialized:
        print("Torch distributed is initialized.")
    else:
        print("Torch distributed is not initialized.")

    if initialized:
        return torch.distributed.get_world_size()
    return 1
def set_device(cuda, local_rank):
    """
    Selects the compute device and returns it as a torch.device.

    With ``cuda`` truthy, the current CUDA device is set to ``local_rank``
    and ``torch.device('cuda')`` is returned; otherwise
    ``torch.device('cpu')`` is returned and ``local_rank`` is not used.

    :param cuda: if True: use cuda
    :param local_rank: local rank of the worker
    """
    if not cuda:
        return torch.device('cpu')

    torch.cuda.set_device(local_rank)
    return torch.device('cuda')
@contextmanager
def sync_workers():
    """
    Context manager that yields this worker's distributed rank and calls
    barrier() once the with-body finishes.

    Note: the barrier is reached only when the body completes without
    raising; an exception in the body propagates without synchronizing.
    """
    # The rank is looked up before control is handed to the with-body.
    yield get_rank()
    barrier()
def is_main_process():
    """
    Returns True when this worker's rank, as reported by get_rank(), is 0.
    """
    current_rank = get_rank()
    return current_rank == 0
def format_step(step):
    """
    Renders a step identifier as a human-readable string.

    A string is returned unchanged. Otherwise ``step`` is treated as an
    indexable sequence whose first three entries (when present) are labelled
    as training epoch, training iteration and validation iteration, in that
    order. Each rendered entry ends with a trailing space; an empty sequence
    yields an empty string.

    :param step: a string, or a sequence supporting ``len()`` and indexing
    """
    if isinstance(step, str):
        return step

    templates = (
        "Training Epoch: {} ",
        "Training Iteration: {} ",
        "Validation Iteration: {} ",
    )
    available = len(step)
    pieces = []
    for position, template in enumerate(templates):
        if available > position:
            pieces.append(template.format(step[position]))
    return "".join(pieces)