# 05360171 created on 2022-03-18 (commit history)
# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

# ============================================================================

""" Datasets

    This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets

"""

import os

import os.path

import sys

import numpy as np

import torch.utils.data as data

import torchvision.datasets as dset

from PIL import Image

from tqdm import tqdm



# Lower-case file extensions (with leading dot) that is_image_file accepts.
IMG_EXTENSIONS = [
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
]





def is_image_file(filename):
    """Return True when *filename* carries one of the known image extensions.

    The check is case-insensitive, so ``photo.JPG`` counts as an image.

    Args:
        filename (string): path to a file

    Returns:
        bool: whether the lower-cased name ends with an entry of IMG_EXTENSIONS
    """
    # str.endswith accepts a tuple and tests every suffix in one call.
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))





def find_classes(dir):
    """List the class sub-directories of *dir* and assign each an index.

    Args:
        dir (string): dataset root; every immediate sub-directory is a class.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of sub-directory names and ``class_to_idx`` maps each name to its
        position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx





def make_dataset(dir, class_to_idx):
    """Collect ``(image_path, class_index)`` pairs for every image under *dir*.

    Each immediate sub-directory of *dir* is treated as one class. Every image
    file found recursively inside it is labelled with
    ``class_to_idx[sub_directory_name]``. Class directories, walked directories
    and file names are all visited in sorted order, so the result is
    deterministic. Non-directory entries directly under *dir* are skipped.

    Args:
        dir (string): dataset root; ``~`` is expanded.
        class_to_idx (dict): mapping from sub-directory name to class index.

    Returns:
        list: ``(path, class_index)`` tuples.
    """
    base_dir = os.path.expanduser(dir)
    samples = []
    for class_name in tqdm(sorted(os.listdir(base_dir))):
        class_dir = os.path.join(base_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for walk_root, _, file_names in sorted(os.walk(class_dir)):
            # The label is looked up per image, so a class directory without
            # images never touches class_to_idx.
            samples.extend(
                (os.path.join(walk_root, name), class_to_idx[class_name])
                for name in sorted(file_names)
                if is_image_file(name)
            )
    return samples





def pil_loader(path):
    """Open the image at *path* with PIL and return it converted to RGB."""
    # Hand PIL an explicitly opened file object so the handle is closed
    # deterministically (avoids ResourceWarning, see
    # https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image





def accimage_loader(path):
    """Load *path* with accimage; fall back to PIL if accimage raises IOError."""
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; let PIL have a go instead.
        image = pil_loader(path)
    return image





def default_loader(path):
    """Load *path* using the image backend torchvision is configured for.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'``. Otherwise uses ``pil_loader``.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)





class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dogball/xxx.png
        root/dogball/xxy.png
        root/dogball/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    The result of walking ``root`` is cached in ``index_filename`` (written with
    ``np.savez_compressed``). Later constructions reuse that file instead of
    rescanning the directory tree.

    Args:
        root (string): Root directory path (``~`` is expanded).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        load_in_mem (bool, optional): If True, load every image up front and
            keep it in memory. ``transform`` is applied at load time when given.
        index_filename (string, optional): Path of the cached index. It is read
            if it exists and created otherwise. NOTE: the cache is not tied to
            ``root``, so a file left over from another dataset is reused as-is.
            ``np.savez_compressed`` appends ``.npz`` to names lacking it, so
            such names are never found on the next run.
        **kwargs: Ignored; lets callers pass a shared set of dataset options.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples. When read back
            from ``index_filename`` this is a NumPy array, likely holding
            strings; hence the ``str``/``int`` conversions in ``__getitem__``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, load_in_mem=False,
                 index_filename='imagenet_imgs.npz', **kwargs):
        # Expand ``~`` before find_classes, matching what make_dataset does,
        # so both see the same directory.
        root = os.path.expanduser(root)
        classes, class_to_idx = find_classes(root)
        imgs = self._load_or_build_index(root, class_to_idx, index_filename)
        # len() rather than truthiness: imgs may be a NumPy array.
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.load_in_mem = load_in_mem

        if self.load_in_mem:
            print('Loading all images into memory...')
            self.data, self.labels = [], []
            for path, target in tqdm(self.imgs, total=len(self.imgs)):
                img = self.loader(str(path))
                # Guard the transform: it is optional (defaults to None).
                if self.transform is not None:
                    img = self.transform(img)
                self.data.append(img)
                self.labels.append(target)

    @staticmethod
    def _load_or_build_index(root, class_to_idx, index_filename):
        """Return the (path, class_index) index, from the cache file if present."""
        if os.path.exists(index_filename):
            print('Loading pre-saved Index file %s...' % index_filename)
            # NpzFile holds an open file handle; the context manager closes it.
            with np.load(index_filename) as index:
                return index['imgs']
        # First time: walk the folder tree and persist the result.
        print('Generating  Index file %s...' % index_filename)
        imgs = make_dataset(root, class_to_idx)
        np.savez_compressed(index_filename, imgs=imgs)
        return imgs

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if self.load_in_mem:
            img = self.data[index]
            target = self.labels[index]
        else:
            path, target = self.imgs[index]
            # Entries may be NumPy scalars when the index came from the cache.
            img = self.loader(str(path))
            if self.transform is not None:
                img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, int(target)

    def __len__(self):
        return len(self.imgs)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str





''' ILSVRC_HDF5: A dataset that reads from a single HDF5 file, to avoid
    having to load individual image files every time. '''

import h5py as h5

import torch





class ILSVRC_HDF5(data.Dataset):
    """Dataset backed by one HDF5 file with parallel ``'imgs'``/``'labels'`` datasets.

    Reading from a single HDF5 file avoids opening many individual image files.

    NOTE: ``transform`` is stored but NOT applied. ``__getitem__`` always
    converts an image to a float tensor via ``(x / 255 - 0.5) * 2``. That maps
    [0, 255] to [-1, 1], assuming 8-bit pixel values (TODO confirm the dtype
    of the ``'imgs'`` dataset).

    Args:
        root (string): Path to the HDF5 file.
        transform: Stored on the instance, ignored (see NOTE above).
        target_transform (callable, optional): Applied to each label.
        load_in_mem (bool): If True, read all images and labels into memory now.
        train, download, validate_seed, val_split, **kwargs: Ignored; accepted
            for interface compatibility with the other datasets in this file.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 load_in_mem=False, train=True, download=False, validate_seed=0,
                 val_split=0, **kwargs):  # last four are dummies
        self.root = root
        self.target_transform = target_transform
        # Kept for interface compatibility; __getitem__ does not use it.
        self.transform = transform
        # Load the entire dataset into memory?
        self.load_in_mem = load_in_mem

        # Open the file once and close it deterministically. Only the length
        # is needed unless everything is being loaded into memory.
        with h5.File(root, 'r') as f:
            self.num_imgs = len(f['labels'])
            if self.load_in_mem:
                print('Loading %s into memory...' % root)
                self.data = f['imgs'][:]
                self.labels = f['labels'][:]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if self.load_in_mem:
            # The entire dataset is already in RAM.
            img = self.data[index]
            target = self.labels[index]
        else:
            # Open per access; no handle is held on the instance between calls.
            with h5.File(self.root, 'r') as f:
                img = f['imgs'][index]
                target = f['labels'][index]

        # Fixed normalisation in place of self.transform: [0, 255] -> [-1, 1].
        img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, int(target)

    def __len__(self):
        return self.num_imgs





import pickle





class CIFAR10(dset.CIFAR10):
    """CIFAR-10 with an optional class-balanced train/validation split.

    Args:
        root (string): Directory that contains (or will receive) ``base_folder``.
        train (bool or str): ``True`` gives the training batches minus any
            validation samples, ``False`` the test batch, and ``'validate'``
            the held-out validation samples (requires ``val_split > 0``).
        transform (callable, optional): Applied to the PIL image in ``__getitem__``.
        target_transform (callable, optional): Applied to the label in ``__getitem__``.
        download (bool): If True, call ``self.download()`` first.
        validate_seed (int): Passed to ``np.random.seed`` before choosing the
            validation samples. Note this reseeds NumPy's global RNG.
        val_split (float): Fraction of the training data to hold out,
            divided evenly (rounded down) across classes. ``0`` disables it.
        load_in_mem, **kwargs: Ignored; the data is always held in memory.

    Raises:
        RuntimeError: if the files fail ``_check_integrity()``.
        ValueError: if ``train == 'validate'`` but ``val_split <= 0``.
    """

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=True, validate_seed=0,
                 val_split=0, load_in_mem=True, **kwargs):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set, test set, or 'validate'
        self.val_split = val_split

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Load every training batch. The validation split is drawn from these.
        self.data = []
        self.labels = []
        for fentry in self.train_list:
            entry = self._load_pickled_batch(
                os.path.join(self.root, self.base_folder, fentry[0]))
            self.data.append(entry['data'])
            self.labels += self._batch_labels(entry)
        self.data = np.concatenate(self.data)

        # Randomly select indices for validation.
        if self.val_split > 0:
            self.val_indices = self._pick_val_indices(validate_seed)

        if self.train == 'validate':
            if self.val_split <= 0:
                raise ValueError("train='validate' requires val_split > 0")
            self.data = self.data[self.val_indices]
            self.labels = list(np.asarray(self.labels)[self.val_indices])
        elif self.train:
            print(np.shape(self.data))
            if self.val_split > 0:
                self.data = np.delete(self.data, self.val_indices, axis=0)
                self.labels = list(np.delete(np.asarray(self.labels), self.val_indices, axis=0))
        else:
            entry = self._load_pickled_batch(
                os.path.join(self.root, self.base_folder, self.test_list[0][0]))
            self.data = entry['data']
            self.labels = self._batch_labels(entry)

        # Each row is a flat CHW vector. Infer N from the data instead of
        # hard-coding it: the per-class floor division in _pick_val_indices
        # can select fewer than int(50e3 * val_split) samples.
        self.data = self.data.reshape((-1, 3, 32, 32))
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    @staticmethod
    def _load_pickled_batch(path):
        """Unpickle one batch file and return the resulting dict.

        NOTE(security): pickle can execute arbitrary code. In this class it is
        only reached after ``_check_integrity()`` has succeeded.
        """
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            return pickle.load(fo, encoding='latin1')

    @staticmethod
    def _batch_labels(entry):
        """Return a batch's label list ('labels', else 'fine_labels', e.g. CIFAR-100)."""
        return entry['labels'] if 'labels' in entry else entry['fine_labels']

    def _pick_val_indices(self, validate_seed):
        """Choose ``val_split`` of the loaded training samples, equally per class."""
        num_classes = max(self.labels) + 1
        label_indices = [[] for _ in range(num_classes)]
        for i, label in enumerate(self.labels):
            label_indices[label].append(i)
        per_class = int(len(self.data) * self.val_split) // num_classes

        np.random.seed(validate_seed)
        val_indices = []
        for class_indices in label_indices:
            # Convert per class: a ragged list of lists cannot become one array.
            class_indices = np.asarray(class_indices)
            picks = np.random.choice(len(class_indices), per_class, replace=False)
            val_indices += list(class_indices[picks])
        return val_indices

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)





class CIFAR100(CIFAR10):
    """CIFAR-100 variant of CIFAR10: same loading and splitting logic.

    Only the download source, archive metadata, and batch file lists differ
    from the parent class.
    """

    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    base_folder = 'cifar-100-python'

    # Entries are [file name inside base_folder, checksum]; the checksum is
    # presumably md5, consumed by the parent's integrity check.
    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]