import os
import threading
import itertools
import queue
from typing import Generic
import torch
import torch.distributed as dist
from torch.utils.data import _utils
from torch.utils.data.dataloader import _SingleProcessDataLoaderIter as SrcSingleProcessDataLoaderIter
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter as SrcMultiProcessingDataLoaderIter
from torch.utils.data.dataloader import DataLoader as SrcDataLoader
from torch.utils.data.dataloader import _DatasetKind, _share_dist_seed, _get_distributed_settings, _T_co
from torch.utils.data._utils.pin_memory import _pin_memory_loop
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
import torch.multiprocessing as multiprocessing
import torchvision
import torch_npu
MP_STATUS_CHECK_INTERVAL = 5.0
def npu_worker_loop(*args, device_id=None):
if device_id is None:
_utils.worker._worker_loop(*args)
return
os.environ["MIN_COMPILE_RESOURCE_USAGE_CTRL"] = "ub_fusion,coretype_check,op_compile"
torch_npu.npu.set_device(device_id)
torchvision.set_image_backend('npu')
torchvision.set_video_backend('npu')
torch_npu.npu.current_stream().set_data_preprocess_stream(True)
_utils.worker._worker_loop(*args)
class DataLoader(SrcDataLoader[_T_co], Generic[_T_co]):
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
def __setattr__(self, attr, val):
if self.__initialized and attr in (
'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
raise ValueError('{} attribute should not be set after {} is '
'initialized'.format(attr, self.__class__.__name__))
super(SrcDataLoader, self).__setattr__(attr, val)
class _SingleProcessDataLoaderIter(SrcSingleProcessDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
self._pin_memory = loader.pin_memory and torch_npu.npu.is_available()
if self._timeout != 0:
raise ValueError("self._timeout != 0")
if self._num_workers != 0:
raise ValueError("self._num_workers != 0")
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
class _MultiProcessingDataLoaderIter(SrcMultiProcessingDataLoaderIter):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
def __init__(self, loader):
try:
torch_npu.npu.synchronize()
except Exception as e:
print(e)
self._prefetch_factor = loader.prefetch_factor
self._dataset = loader.dataset
self._shared_seed = None
self._pg = None
if isinstance(self._dataset, IterDataPipe):
if dist.is_available() and dist.is_initialized():
self._pg = dist.new_group(backend="gloo")
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
shared_rng = torch.Generator()
shared_rng.manual_seed(self._shared_seed)
self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
ws, rank = _get_distributed_settings()
self._world_size = ws
self._rank = rank
self._pin_memory = loader.pin_memory and torch.npu.is_available()
self._pin_memory_device = loader.pin_memory_device
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
self._device_id = None
if self._num_workers <= 0:
raise ValueError("self._num_workers <= 0")
if self._prefetch_factor <= 0:
raise ValueError("self._prefetch_factor <= 0")
worker_loop = npu_worker_loop
daemon = True
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
if torchvision.get_video_backend() == 'npu' or torchvision.get_image_backend() == 'npu':
self._device_id = torch_npu.npu.current_device()
multiprocessing_context = multiprocessing.get_context('spawn')
daemon = False
else:
multiprocessing_context = loader.multiprocessing_context
self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
self._worker_result_queue = multiprocessing_context.Queue()
self._worker_pids_set = False
self._shutdown = False
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = []
self._workers = []
for i in range(self._num_workers):
index_queue = multiprocessing_context.Queue()
index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers, self._shared_seed),
kwargs={"device_id": self._device_id})
w.daemon = daemon
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
if self._pin_memory:
train_device_id = 0
if torch_npu.npu.is_available():
train_device_id = torch_npu.npu.current_device()
else:
train_device_id = torch.cuda.current_device()
self._pin_memory_thread_done_event = threading.Event()
self._data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
train_device_id,
self._pin_memory_thread_done_event, self._pin_memory_device))
pin_memory_thread.daemon = True
pin_memory_thread.start()
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
if self._persistent_workers and self._pin_memory:
import atexit
for w in self._workers:
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
self._reset(loader, first_iter=True)
@staticmethod
def _clean_up_worker(w):
try:
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
finally:
if w.is_alive():
w.terminate()
def __del__(self):
self._shutdown_workers()
def add_dataloader_method():
torch.utils.data.DataLoader = DataLoader
torch.utils.data.dataloader.DataLoader = DataLoader
SrcDataLoader.__setattr__ = DataLoader.__setattr__