import importlib
import torch.utils.data
from torch.utils.data import Dataset
import numpy as np
import glob
def find_dataset_using_name(dataset_name):
dataset_filename = dataset_name
module, target = dataset_name.split('::')
datasetlib = importlib.import_module(module)
dataset = None
for name, cls in datasetlib.__dict__.items():
if name == target:
dataset = cls
if dataset is None:
raise ValueError("In %s.py, there should be a class "
"with class name that matches %s in lowercase." %
(dataset_filename, target))
return dataset
def get_option_setter(dataset_name):
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
def create_dataloader(opt, distributed, labels_required, is_inference):
dataset = find_dataset_using_name(opt.type)
instance = dataset(opt, is_inference, labels_required)
phase = 'val' if is_inference else 'training'
batch_size = opt.val.batch_size if is_inference else opt.train.batch_size
print("%s dataset [%s] of size %d was created" %
(phase, opt.type, len(instance)))
dataloader = torch.utils.data.DataLoader(
instance,
batch_size=batch_size,
sampler=data_sampler(instance, shuffle=not is_inference, distributed=distributed),
drop_last=not is_inference,
num_workers=getattr(opt, 'num_workers', 0),
)
return dataloader
def data_sampler(dataset, shuffle, distributed):
if distributed:
return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return torch.utils.data.RandomSampler(dataset)
else:
return torch.utils.data.SequentialSampler(dataset)
def get_dataloader(opt, distributed, is_inference):
dataset = create_dataloader(opt, distributed, is_inference)
return dataset
def get_train_val_dataloader(opt, labels_required=False, distributed = False):
val_dataset = create_dataloader(opt, distributed, labels_required = labels_required, is_inference=True,)
train_dataset = create_dataloader(opt, distributed, labels_required = labels_required, is_inference=False)
return val_dataset, train_dataset
def pad_images(img_tensor, pad_value):
b,c,h,w = img_tensor.size()
pad = torch.ones((b,c,h,(h-w)//2)).cuda()*pad_value
return torch.cat([pad, img_tensor, pad], 3)