# Copyright 2022 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



#coding=utf-8



from __future__ import absolute_import

from __future__ import division

from __future__ import print_function



import cv2

import torch

from PIL import Image, ImageDraw

import torch.utils.data as data

import numpy as np

import random

from utils.augmentations import preprocess





class WIDERDetection(data.Dataset):
    """Face-detection dataset backed by a plain-text annotation list.

    Each non-blank line of ``list_file`` has the form::

        <image_path> <num_faces> x y w h c [x y w h c ...]

    Each face contributes five whitespace-separated tokens:
    - ``x``, ``y``: the top-left corner of the box;
    - ``w``, ``h``: the box width and height;
    - ``c``: an integer class label.

    Boxes are stored internally as ``[x, y, x + w, y + h]``. Faces with
    non-positive width or height are dropped. Images left with no faces
    are skipped entirely.

    Args:
        list_file: Path to the annotation list file.
        mode: Forwarded unchanged to ``preprocess``. The set of supported
            values is defined by that helper, not here.
    """

    def __init__(self, list_file, mode='train'):
        super(WIDERDetection, self).__init__()
        self.mode = mode
        self.fnames = []  # image path of each kept sample
        self.boxes = []   # per-sample list of [x1, y1, x2, y2] boxes (pixels)
        self.labels = []  # per-sample list of integer class labels

        with open(list_file) as f:
            for raw_line in f:
                tokens = raw_line.strip().split()
                # Blank lines (e.g. a trailing newline at EOF) carry no
                # annotation; indexing tokens[1] on them would raise.
                if not tokens:
                    continue

                num_faces = int(tokens[1])
                box = []
                label = []
                for i in range(num_faces):
                    base = 2 + 5 * i  # first token of the i-th face record
                    x = float(tokens[base])
                    y = float(tokens[base + 1])
                    w = float(tokens[base + 2])
                    h = float(tokens[base + 3])
                    c = int(tokens[base + 4])
                    if w <= 0 or h <= 0:
                        continue
                    box.append([x, y, x + w, y + h])
                    label.append(c)

                if box:
                    self.fnames.append(tokens[0])
                    self.boxes.append(box)
                    self.labels.append(label)

        self.num_samples = len(self.boxes)

    def __len__(self):
        """Return the number of images that have at least one valid box."""
        return self.num_samples

    def __getitem__(self, index):
        """Return ``(image_tensor, target)`` for ``index``.

        See ``pull_item`` for details; the original image size is dropped.
        """
        img, target, h, w = self.pull_item(index)
        return img, target

    def pull_item(self, index):
        """Load, normalise and augment one sample.

        If ``preprocess`` leaves no boxes for the image, a random other
        sample is drawn and the process repeats.

        NOTE(review): this loop has no retry limit. It never ends if
        ``preprocess`` keeps returning no boxes.

        Returns:
            A tuple ``(img, target, im_height, im_width)``:
            - ``img``: ``torch.from_numpy`` applied to the image returned
              by ``preprocess`` (presumably a numpy array; confirm in
              ``utils.augmentations``).
            - ``target``: one row per surviving box. Each row holds the box
              columns followed by the class label in the last column.
            - ``im_height``, ``im_width``: the pixel size of the image that
              was actually loaded, which may differ from ``index`` after a
              retry.
        """
        while True:
            image_path = self.fnames[index]
            # convert() returns an independent image, so the source file
            # can be closed immediately instead of leaking its handle.
            with Image.open(image_path) as raw_img:
                img = raw_img.convert('RGB')
            im_width, im_height = img.size

            boxes = self.annotransform(
                np.array(self.boxes[index]), im_width, im_height)
            label = np.array(self.labels[index])
            # preprocess receives rows of [label, x1, y1, x2, y2] with
            # coordinates normalised to the image size.
            bbox_labels = np.hstack((label[:, np.newaxis], boxes)).tolist()
            img, sample_labels = preprocess(
                img, bbox_labels, self.mode, image_path)
            sample_labels = np.array(sample_labels)

            if len(sample_labels) > 0:
                # Move the label column from first to last.
                target = np.hstack(
                    (sample_labels[:, 1:], sample_labels[:, 0][:, np.newaxis]))
                # NOTE(review): .any() only requires that some box has
                # x2 > x1 (and some box has y2 > y1), not that every box
                # does. Confirm whether .all() was intended.
                assert (target[:, 2] > target[:, 0]).any()
                assert (target[:, 3] > target[:, 1]).any()
                break

            # Augmentation removed every box; try a random sample instead.
            index = random.randrange(0, self.num_samples)

        return torch.from_numpy(img), target, im_height, im_width

    def annotransform(self, boxes, im_width, im_height):
        """Divide pixel box coordinates by the image size, in place.

        Args:
            boxes: Float array with columns ``[x1, y1, x2, y2]``. It is
                modified in place.
            im_width: Image width in pixels; divides columns 0 and 2.
            im_height: Image height in pixels; divides columns 1 and 3.

        Returns:
            The same ``boxes`` array, now holding normalised coordinates.
        """
        boxes[:, [0, 2]] /= im_width
        boxes[:, [1, 3]] /= im_height
        return boxes





def detection_collate(batch):
    """Collate detection samples whose number of annotations varies per image.

    Images are stacked into a single batch tensor. Annotations cannot be
    stacked because each image may have a different number of boxes, so
    they are returned as a list with one tensor per image.

    Arguments:
        batch: (iterable) samples whose element 0 is an image tensor and
            whose element 1 is that image's annotations.

    Return:
        A tuple ``(images, targets)``:
            1) (tensor) all images stacked along a new dimension 0
            2) (list of tensors) each image's annotations as a FloatTensor
    """
    samples = list(batch)  # materialise once so any iterable is consumed a single time
    stacked_images = torch.stack([sample[0] for sample in samples], 0)
    per_image_targets = [torch.FloatTensor(sample[1]) for sample in samples]
    return stacked_images, per_image_targets





if __name__ == '__main__':
    # Manual smoke test: build the dataset from the list file named by
    # cfg.FACE_TRAIN_FILE and push sample 14 through pull_item.
    from config import cfg

    WIDERDetection(cfg.FACE_TRAIN_FILE).pull_item(14)