from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data.distributed import DistributedSampler
# Dataset names accepted by ``dataloader``.
_SUPPORTED_DATASETS = ('mnist', 'fashion-mnist', 'cifar10', 'svhn', 'stl10', 'lsun-bed')


def to_three_channels(img):
    """Return ``img`` with a leading channel axis of size 3.

    A single-channel tensor of shape (1, H, W) is tiled to (3, H, W). A tensor
    with any other leading size is returned unchanged. This is a module-level
    function rather than a lambda so the transform pipeline can be pickled
    when DataLoader worker processes are started with the spawn method.
    """
    if img.size(0) == 1:
        return img.repeat(3, 1, 1)
    return img


def dataloader(dataset, input_size, batch_size, args, split='train'):
    """Build a DataLoader for one of the supported image datasets.

    Each image is resized to ``input_size`` x ``input_size`` and converted to
    a tensor. Single-channel images are expanded to 3 channels. Every channel
    is then normalized with mean 0.5 / std 0.5.

    Args:
        dataset: one of 'mnist', 'fashion-mnist', 'cifar10', 'svhn', 'stl10',
            'lsun-bed'.
        input_size: side length of the square resize.
        batch_size: number of samples per batch.
        args: object exposing ``is_distributed``; only read for 'mnist'.
        split: dataset split; only used for 'svhn' and 'stl10'. MNIST,
            Fashion-MNIST and CIFAR-10 always load their training set.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested dataset.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    # Fail fast with a clear message instead of an UnboundLocalError at return.
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError('unknown dataset {!r}; expected one of {}'.format(
            dataset, _SUPPORTED_DATASETS))

    transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        # Tile grayscale tensors to 3 channels. RGB tensors pass through as-is:
        # repeating them would produce 9 channels, and the 3-value Normalize
        # below would then fail.
        transforms.Lambda(to_three_channels),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    if dataset == 'mnist':
        distributed = args.is_distributed
        mnist = datasets.MNIST('data/mnist', train=True, download=True, transform=transform)
        if distributed:
            # Give the sampler and the loader the same dataset instance. The
            # sampler controls ordering, so DataLoader's ``shuffle`` is not set.
            train_sampler = DistributedSampler(mnist)
            return DataLoader(dataset=mnist, sampler=train_sampler, num_workers=1,
                              batch_size=batch_size, pin_memory=False, drop_last=True)
        return DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=1)

    if dataset == 'fashion-mnist':
        data = datasets.FashionMNIST('data/fashion-mnist', train=True, download=True,
                                     transform=transform)
    elif dataset == 'cifar10':
        data = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform)
    elif dataset == 'svhn':
        data = datasets.SVHN('data/svhn', split=split, download=True, transform=transform)
    elif dataset == 'stl10':
        data = datasets.STL10('data/stl10', split=split, download=True, transform=transform)
    else:
        # Only 'lsun-bed' remains after the validation above. No download flag
        # is passed here (unlike the other datasets), so the data is expected
        # to already exist under data/lsun.
        data = datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform)
    return DataLoader(data, batch_size=batch_size, shuffle=True)