""" Distributed training/validation utils
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import distributed as dist
from .model import unwrap_model
def reduce_tensor(tensor, n):
    """Return the cross-process sum of ``tensor`` divided by ``n``.

    The input is cloned first, so ``tensor`` itself is never modified. The
    clone is sum-reduced in place with ``dist.all_reduce`` (no explicit group
    is passed) and then divided in place by ``n``.

    Args:
        tensor: Tensor to reduce.
        n: Divisor applied to the summed tensor. Passing the number of
            participating processes yields the mean.

    Returns:
        A new tensor holding the summed values divided by ``n``.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    # In-place division, equivalent to ``summed /= n``; returns ``summed``.
    return summed.div_(n)
def distribute_bn(model, world_size, reduce=False):
    """Synchronise ``running_mean`` / ``running_var`` buffers across processes.

    Every buffer of ``unwrap_model(model)`` whose name contains
    ``'running_mean'`` or ``'running_var'`` is updated in place; all other
    buffers are left alone.

    Args:
        model: Model whose buffers are synchronised. It is passed through
            ``unwrap_model`` before its buffers are enumerated.
        world_size: Divisor used in reduce mode; only read when ``reduce`` is
            true and a matching buffer exists.
        reduce: If true, each matching buffer is sum-reduced across processes
            and divided by ``world_size``. If false, each matching buffer is
            broadcast from rank 0.
    """
    for name, buf in unwrap_model(model).named_buffers(recurse=True):
        # Skip anything that is not a running-statistics buffer.
        if 'running_mean' not in name and 'running_var' not in name:
            continue

        if not reduce:
            # Rank 0's value is copied to every other process.
            dist.broadcast(buf, 0)
            continue

        # Sum across processes, then divide in place to get the average.
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)
        buf /= float(world_size)