import os
import numpy as np
import time
import cfg
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class RawDataset(Dataset):
    """Dataset pairing each listed image with two precomputed ``.npy`` arrays.

    The split file ``cfg.train_fname`` (or ``cfg.val_fname`` when ``is_val``),
    located in ``cfg.data_dir``, is read line by line. The first
    comma-separated field of each non-empty line is an image filename,
    resolved under ``cfg.data_dir/cfg.train_image_dir_name``. Both splits use
    the ``train_*`` directory names.

    For an image ``<stem>.xxx`` two arrays are loaded from
    ``cfg.data_dir/cfg.train_label_dir_name``:

    * ``<stem>_gt.npy``: a 3-D array whose last axis is moved to the front
      before returning (presumably HxWxC -> CxHxW; confirm against the
      preprocessing script);
    * ``<stem>.npy``: returned unchanged.

    Samples are ordered by sorted image path.
    """

    def __init__(self, is_val=False):
        """Index the train split, or the validation split when ``is_val`` is true."""
        self.img_h, self.img_w = cfg.max_train_img_size, cfg.max_train_img_size
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = transforms.ToTensor()

        split_fname = cfg.val_fname if is_val else cfg.train_fname
        with open(os.path.join(cfg.data_dir, split_fname), 'r') as f_split:
            f_list = f_split.readlines()

        self.image_path_list = []
        self.labels_path_dic = {}
        self.gt_xy_list_path_dic = {}
        for f_line in f_list:
            img_filename = f_line.strip().split(',')[0]
            if not img_filename:
                # Blank lines would otherwise yield the image directory itself
                # as an "image" and a label path of "_gt.npy".
                continue
            img_path = os.path.join(cfg.data_dir, cfg.train_image_dir_name, img_filename)
            self.image_path_list.append(img_path)
            # NOTE: [:-4] assumes a 3-letter extension such as ".jpg"; kept as-is
            # so label names match the files the preprocessing step produced.
            stem = img_filename[:-4]
            label_dir = os.path.join(cfg.data_dir, cfg.train_label_dir_name)
            self.labels_path_dic[img_path] = os.path.join(label_dir, stem + '_gt.npy')
            self.gt_xy_list_path_dic[img_path] = os.path.join(label_dir, stem + '.npy')
        self.image_path_list.sort()
        self.nSamples = len(self.image_path_list)

    def __len__(self):
        """Return the number of indexed images."""
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img_tensor, label, gt_xy_list)`` for the image at ``index``.

        Any ``IOError``/``OSError`` while opening or decoding the image
        (including a missing file) is reported with ``print`` and replaced
        by an all-zero RGB image of size ``(img_w, img_h)``. Errors loading
        the ``.npy`` files are not caught.
        """
        img_path = self.image_path_list[index]
        label = np.load(self.labels_path_dic[img_path])
        gt_xy_list = np.load(self.gt_xy_list_path_dic[img_path])
        try:
            # The context manager closes the underlying file. convert() returns
            # a fully loaded, independent image, so it stays valid afterwards.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')
        except IOError:
            print(f'Corrupted image for {index}')
            img = Image.new('RGB', (self.img_w, self.img_h))
        img_tensor = self._to_tensor(img)
        label = np.transpose(label, (2, 0, 1))
        return (img_tensor, label, gt_xy_list)
def data_collate(batch):
    """Merge a list of ``(img_tensor, label, gt_xy_list)`` samples into one batch.

    Returns a 3-tuple:

    * the image tensors stacked along a new leading dimension;
    * the label arrays combined with ``np.array`` and converted via
      ``torch.tensor`` (dtype preserved);
    * the per-sample ``gt_xy_list`` objects as a plain Python list, left
      unbatched (presumably because their lengths vary between samples).
    """
    images = [sample[0] for sample in batch]
    label_arrays = [sample[1] for sample in batch]
    quads = [sample[2] for sample in batch]
    stacked_images = torch.stack(images, 0)
    stacked_labels = torch.tensor(np.array(label_arrays))
    return stacked_images, stacked_labels, quads
if __name__ == '__main__':
    # Smoke test: index the training split, iterate one full epoch while
    # printing each batch's image/label shapes, then print the total elapsed
    # wall-clock seconds (dataset indexing included).
    tick = time.perf_counter()  # monotonic clock, suited to measuring durations
    train_dataset = RawDataset(is_val=False)
    data_loader_A = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,
        num_workers=int(cfg.workers),
        # Pinned memory only speeds up host->GPU copies; without CUDA there
        # is nothing to pin for, and PyTorch just emits a warning.
        pin_memory=torch.cuda.is_available(),
    )
    for image_tensors, labels, gt_xy_list in data_loader_A:
        print(image_tensors.shape, labels.shape)
    tock = time.perf_counter()
    print(tock - tick)