# 05360171 created on 2022-03-18 (commit history)
# BSD 3-Clause License

#

# Copyright (c) 2017 xxxx

# All rights reserved.

# Copyright 2021 Huawei Technologies Co., Ltd

#

# Redistribution and use in source and binary forms, with or without

# modification, are permitted provided that the following conditions are met:

#

# * Redistributions of source code must retain the above copyright notice, this

#   list of conditions and the following disclaimer.

#

# * Redistributions in binary form must reproduce the above copyright notice,

#   this list of conditions and the following disclaimer in the documentation

#   and/or other materials provided with the distribution.

#

# * Neither the name of the copyright holder nor the names of its

#   contributors may be used to endorse or promote products derived from

#   this software without specific prior written permission.

#

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"

# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE

# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE

# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL

# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR

# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER

# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,

# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# ============================================================================

import importlib

import torch.utils.data

#from data.base_dataset import BaseDataset





def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its dataset class.

    The class is located by name: underscores are stripped from
    ``dataset_name``, "dataset" is appended, and the result is compared
    case-insensitively with every name defined in the imported module
    (for example "single_image" matches a class called ``SingleImageDataset``).
    Only classes are accepted, so a function, variable or imported module
    that happens to carry the same name is never returned.

    Args:
        dataset_name: the dataset mode, e.g. "aligned".

    Returns:
        The matching class object (the class itself, not an instance).

    Raises:
        ImportError: propagated from importlib when the module cannot be imported.
        NotImplementedError: when the module defines no class with the expected name.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    # Lower-case the target once instead of on every loop iteration.
    wanted = target_dataset_name.lower()

    dataset = None
    for name, cls in datasetlib.__dict__.items():
        # isinstance(cls, type) rejects non-class attributes with a matching name.
        if name.lower() == wanted and isinstance(cls, type):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset





def get_option_setter(dataset_name):
    """Resolve the dataset class for ``dataset_name`` and hand back its
    ``modify_commandline_options`` attribute without calling it."""
    return find_dataset_using_name(dataset_name).modify_commandline_options



 

def create_dataset(opt):
    """Create a dataset given the options and wrap it in a data loader.

    The dataset class is resolved from ``opt.dataset_mode``, instantiated with
    ``opt`` and handed to CustomDatasetDataLoader.

    Args:
        opt: parsed options; ``opt.dataset_mode`` is read here and the whole
            object is forwarded to the dataset class and to the loader.

    Returns:
        A tuple ``(loader, train_sampler)``. ``loader`` is the
        CustomDatasetDataLoader instance. ``train_sampler`` is a
        DistributedSampler over the dataset when a torch.distributed process
        group has been initialised, and ``None`` otherwise, so callers must
        check for ``None`` before using it (e.g. before ``set_epoch``).
    """
    # Function-scope import keeps this block self-contained.
    import torch.distributed as dist

    dataset_class = find_dataset_using_name(opt.dataset_mode)
    datasets = dataset_class(opt)

    # DistributedSampler asks the default process group for world size and
    # rank and raises when none exists, so only build it when one is set up.
    if dist.is_available() and dist.is_initialized():
        train_sampler = torch.utils.data.distributed.DistributedSampler(datasets)
    else:
        train_sampler = None

    data_loader = CustomDatasetDataLoader(opt, datasets, train_sampler)
    dataset = data_loader.load_data()
    return dataset, train_sampler





class CustomDatasetDataLoader:
    """Wrap an already-built dataset in a torch DataLoader and cap iteration
    at ``opt.max_dataset_size`` samples."""

    def __init__(self, opt, dataset, train_sampler):
        """Store the dataset and build the underlying DataLoader.

        Parameters:
            opt           -- options object; reads ngpus_per_node,
                             multiprocessing_distributed, batch_size,
                             num_threads and (single-process path only)
                             serial_batches
            dataset       -- dataset instance to load from
            train_sampler -- sampler used on the distributed path; may be None
        """
        self.opt = opt
        self.dataset = dataset
        print("dataset [%s] was created" % type(self.dataset).__name__)

        use_distributed = (
            opt.ngpus_per_node > 1 and opt.multiprocessing_distributed >= 1
        )

        # Arguments shared by both configurations.
        common = dict(
            batch_size=opt.batch_size,
            num_workers=int(opt.num_threads),
        )
        if use_distributed:
            # With a sampler, ordering is the sampler's job, so shuffling is
            # only enabled when no sampler was supplied.
            specific = dict(
                shuffle=train_sampler is None,
                pin_memory=False,
                sampler=train_sampler,
                drop_last=True,
            )
        else:
            specific = dict(shuffle=not opt.serial_batches)

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset, **common, **specific
        )

    def load_data(self):
        """Return this wrapper itself; it is the iterable handed to callers."""
        return self

    def __len__(self):
        """Return the dataset length, capped at ``opt.max_dataset_size``."""
        dataset_length = len(self.dataset)
        return min(dataset_length, self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches until ``opt.max_dataset_size`` samples have been started."""
        for batch_index, batch in enumerate(self.dataloader):
            samples_before = batch_index * self.opt.batch_size
            if samples_before < self.opt.max_dataset_size:
                yield batch
            else:
                return