'''
A new dataloader using NVIDIA DALI in order to speed up the dataloader in pytorch
Ref: https://github.com/d-li14/mobilenetv2.pytorch/blob/master/utils/dataloaders.py
https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/resnet50/main.py
'''
import os
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from math import ceil
try:
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
except ImportError:
print("Please install DALI from https://www.github.com/NVIDIA/DALI to run DataLoader.")
class TinyImageNetHybridTrainPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed, dali_cpu=False):
super(TinyImageNetHybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed)
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
self.input = ops.FileReader(
file_root=data_dir,
shard_id=local_rank,
num_shards=world_size,
pad_last_batch=True,
random_shuffle=False,
shuffle_after_epoch=True)
dali_device = 'cpu' if dali_cpu else 'gpu'
decoder_device = 'cpu' if dali_cpu else 'mixed'
self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
self.res = ops.RandomResizedCrop(device=dali_device, size=crop, random_aspect_ratio=[0.75, 4./3],
random_area=[0.08, 1.0], num_attempts=100, interp_type=types.INTERP_TRIANGULAR)
self.cmnp = ops.CropMirrorNormalize(device='gpu',
output_dtype=types.FLOAT,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=types.RGB,
mean=[0.485*255, 0.456*255, 0.406*255],
std=[0.229*255, 0.224*255, 0.225*255])
self.coin = ops.CoinFlip(probability=0.5)
def define_graph(self):
rng = self.coin()
self.jpegs, self.labels = self.input(name='Reader')
images = self.decode(self.jpegs)
images = self.res(images)
output = self.cmnp(images.gpu(), mirror = rng)
return [output, self.labels]
class TinyImageNetHybridValPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed):
super(TinyImageNetHybridValPipe, self).__init__(batch_size, num_threads, device_id, seed)
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
self.input = ops.FileReader(
file_root=data_dir,
shard_id=local_rank,
num_shards=world_size,
pad_last_batch=True,
random_shuffle=False)
self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
self.cmnp = ops.CropMirrorNormalize(device='gpu',
output_dtype=types.FLOAT,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=types.RGB,
mean=[0.485*255, 0.456*255, 0.406*255],
std=[0.229*255, 0.224*255, 0.225*255])
def define_graph(self):
self.jpegs, self.labels = self.input(name='Reader')
images = self.decode(self.jpegs)
output = self.cmnp(images)
return [output, self.labels]
class ImageNetHybridTrainPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed, dali_cpu=False):
super(ImageNetHybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed = seed)
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
self.input = ops.FileReader(
file_root=data_dir,
shard_id=local_rank,
num_shards=world_size,
pad_last_batch=True,
random_shuffle=False,
shuffle_after_epoch=True)
dali_device = 'cpu' if dali_cpu else 'gpu'
decoder_device = 'cpu' if dali_cpu else 'mixed'
device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
'''
self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
device_memory_padding=device_memory_padding,
host_memory_padding=host_memory_padding,
random_aspect_ratio=[0.75, 1.25],
random_area=[0.08, 1.0],
num_attempts=100)
self.res = ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, interp_type=types.INTERP_TRIANGULAR)
'''
self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB,
device_memory_padding=device_memory_padding,
host_memory_padding=host_memory_padding,)
self.res = ops.RandomResizedCrop(device=dali_device, size=crop, random_aspect_ratio=[0.75, 4./3],
random_area=[0.08, 1.0], num_attempts=100, interp_type=types.INTERP_TRIANGULAR)
self.cmnp = ops.CropMirrorNormalize(device='gpu',
output_dtype=types.FLOAT,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=types.RGB,
mean=[0.485 * 255,0.456 * 255,0.406 * 255],
std=[0.229 * 255,0.224 * 255,0.225 * 255])
self.coin = ops.CoinFlip(probability=0.5)
def define_graph(self):
rng = self.coin()
self.jpegs, self.labels = self.input(name='Reader')
images = self.decode(self.jpegs)
images = self.res(images)
output = self.cmnp(images.gpu(), mirror = rng)
return [output, self.labels]
class ImageNetHybridValPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed):
super(ImageNetHybridValPipe, self).__init__(batch_size, num_threads, device_id, seed = seed)
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
self.input = ops.FileReader(
file_root=data_dir,
shard_id=local_rank,
num_shards=world_size,
pad_last_batch=True,
random_shuffle=False)
self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
self.res = ops.Resize(device='gpu', resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
self.cmnp = ops.CropMirrorNormalize(device='gpu',
output_dtype=types.FLOAT,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=types.RGB,
mean=[0.485*255, 0.456*255, 0.406*255],
std=[0.229*255, 0.224*255, 0.225*255])
def define_graph(self):
self.jpegs, self.labels = self.input(name='Reader')
images = self.decode(self.jpegs)
images = self.res(images)
output = self.cmnp(images)
return [output, self.labels]
class DALIWrapper(object):
def gen_wrapper(dali_pipeline):
for data in dali_pipeline:
input = data[0]['data']
target = data[0]['label'].squeeze().cuda().long()
yield input, target
def __init__(self, dali_pipeline):
self.dali_pipeline = dali_pipeline
def __iter__(self):
return DALIWrapper.gen_wrapper(self.dali_pipeline)
def get_dali_tinyImageNet_train_loader(data_path, batch_size, seed, num_threads=4, dali_cpu=False):
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
train_dir = os.path.join(data_path, 'train')
pipe = TinyImageNetHybridTrainPipe(batch_size=batch_size, num_threads=num_threads,
device_id=local_rank, data_dir=train_dir,
crop=56, seed=seed, dali_cpu=dali_cpu)
pipe.build()
train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size('Reader') / world_size), fill_last_batch=False, last_batch_padded=True, auto_reset=True)
return DALIWrapper(train_loader), ceil(pipe.epoch_size('Reader') / (world_size*batch_size))
def get_dali_tinyImageNet_val_loader(data_path, batch_size, seed, num_threads=4):
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
val_dir = os.path.join(data_path, 'val')
pipe = TinyImageNetHybridValPipe(batch_size=batch_size, num_threads=num_threads,
device_id=local_rank, data_dir=val_dir,
crop=56, seed=seed)
pipe.build()
val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size('Reader')/world_size), fill_last_batch=False, last_batch_padded=True, auto_reset=True)
return DALIWrapper(val_loader), ceil(pipe.epoch_size('Reader') / (world_size * batch_size))
def get_dali_imageNet_train_loader(data_path, batch_size, seed, num_threads=4, dali_cpu=False):
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
train_dir = os.path.join(data_path, 'ILSVRC2012_img_train')
pipe = ImageNetHybridTrainPipe(batch_size=batch_size, num_threads=num_threads,
device_id=local_rank, data_dir=train_dir,
crop=224, seed=seed, dali_cpu=dali_cpu)
pipe.build()
train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size('Reader') / world_size), fill_last_batch=False, last_batch_padded=True, auto_reset=True)
return DALIWrapper(train_loader), ceil(pipe.epoch_size('Reader') / (world_size*batch_size))
def get_dali_imageNet_val_loader(data_path, batch_size, seed, num_threads=4):
if torch.distributed.is_initialized():
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
else:
local_rank = 0
world_size = 1
val_dir = os.path.join(data_path, 'ILSVRC2012_img_val')
pipe = ImageNetHybridValPipe(batch_size=batch_size, num_threads=num_threads,
device_id=local_rank, data_dir=val_dir,
crop=224, size=256, seed=seed)
pipe.build()
val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size('Reader')/world_size), fill_last_batch=False, last_batch_padded=True, auto_reset=True)
return DALIWrapper(val_loader), ceil(pipe.epoch_size('Reader') / (world_size * batch_size))