from distutils.version import LooseVersion
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION
from .scatter_gather import scatter_kwargs
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
class MMDistributedDataParallel(DistributedDataParallel):
"""The DDP module that supports DataContainer.
MMDDP has two main differences with PyTorch DDP:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data.
- It implement two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs, kwargs, device_id):
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log('Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if getattr(self, 'require_forward_param_sync', True):
if TORCH_MAJOR == 1 and TORCH_MINOR <= 8:
self._sync_params()
else:
self._sync_params_and_buffers()
if self.device_ids and False:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
output = self.module.train_step(*inputs[0], **kwargs[0])
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log('Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if getattr(self, 'require_forward_param_sync', True):
if TORCH_MAJOR == 1 and TORCH_MINOR <= 8:
self._sync_params()
else:
self._sync_params_and_buffers()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output