# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import os

import argparse

import numpy as np

import cv2

import mmcv

import torch



from tqdm import tqdm





# Per-channel normalization statistics on the 0-255 pixel scale, consumed by
# mmcv.imnormalize in voc2012_preprocess. Channel order is presumably RGB
# (these are the common ImageNet statistics) -- confirm against the model.
dataset_config = dict(
    mean=(123.675, 116.28, 103.53),
    std=(58.395, 57.12, 57.375),
)





# Spatial size of every tensor written by voc2012_preprocess: images are
# resized to fit inside tensor_width x tensor_height and then zero-padded
# up to exactly this size.
tensor_height, tensor_width = 500, 500





def resize(img, size):
    """Scale ``img`` to fit inside a ``size`` box, keeping its aspect ratio.

    Args:
        img: image array whose first two axes are (height, width).
        size: target box given as (width, height).

    Returns:
        The image resized with mmcv's cv2 backend. Both output dimensions are
        computed with ``np.floor``, so each is at most the corresponding box
        dimension (floating-point rounding may make the fitted side one pixel
        short of the box).
    """
    src_h, src_w = img.shape[:2]
    ratio = min(size[0] / src_w, size[1] / src_h)
    target_wh = (int(np.floor(src_w * ratio)), int(np.floor(src_h * ratio)))
    return mmcv.imresize(img, target_wh, backend='cv2')





def voc2012_preprocess(input_image, output_bin_path):
    """Preprocess one image and write it as a raw float16 ``.bin`` file.

    Steps: read the image with mmcv (cv2 backend); resize it to fit inside
    ``(tensor_width, tensor_height)`` keeping the aspect ratio; normalize it
    with ``dataset_config`` mean/std via ``mmcv.imnormalize``; zero-pad it
    (centered) to exactly ``tensor_height x tensor_width``; cast to float16;
    transpose from (H, W, C) to (C, H, W); dump the raw bytes.

    Args:
        input_image: path to the source image file.
        output_bin_path: existing directory that receives ``<image stem>.bin``.
    """
    # splitext keeps everything but the final extension, so names containing
    # extra dots (e.g. "a.b.jpg") do not collide on "a.bin".
    stem = os.path.splitext(os.path.basename(input_image))[0]
    bin_fl = os.path.join(output_bin_path, stem + ".bin")

    one_img = mmcv.imread(input_image, backend='cv2')
    one_img = resize(one_img, (tensor_width, tensor_height))

    # Keep the statistics at float32 precision: float16 would round them
    # (e.g. 123.675 -> 123.6875) before normalization.
    mean = np.array(dataset_config['mean'], dtype=np.float32)
    std = np.array(dataset_config['std'], dtype=np.float32)
    one_img = mmcv.imnormalize(one_img, mean, std)

    # Center the resized image; an odd leftover pixel goes to the right/bottom.
    h, w = one_img.shape[:2]
    pad_left = (tensor_width - w) // 2
    pad_top = (tensor_height - h) // 2
    pad_right = tensor_width - pad_left - w
    pad_bottom = tensor_height - pad_top - h
    one_img = mmcv.impad(one_img, padding=(pad_left, pad_top, pad_right, pad_bottom), pad_val=0)

    # (H, W, C) -> (C, H, W); tofile always writes in C order.
    one_img.astype(np.float16).transpose(2, 0, 1).tofile(bin_fl)





def load_split_names(split_path, img_suffix='.jpg'):
    """Return the set of image file names listed in a VOC split file.

    Each non-blank line of ``split_path`` is an image id; ``img_suffix`` is
    appended to form the file name. Returns None when ``split_path`` is None,
    meaning "do not filter by split".
    """
    if split_path is None:
        return None
    with open(split_path) as f:
        return {line.strip() + img_suffix for line in f if line.strip()}


def select_images(image_names, wanted=None, extensions=('.jpeg', '.JPEG', '.jpg')):
    """Return the names with a supported extension, restricted to ``wanted``.

    Args:
        image_names: iterable of file names (e.g. from ``os.listdir``).
        wanted: optional set of allowed file names; None keeps every image.
        extensions: accepted file-name suffixes.

    Returns:
        The matching names, in input order.
    """
    return [name for name in image_names
            if name.endswith(extensions) and (wanted is None or name in wanted)]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='preprocess of FCN-8s pytorch model')
    parser.add_argument("--image_folder_path", default="/opt/npu/VOCdevkit/VOC2012/JPEGImages/",
                        help='image of dataset')
    parser.add_argument("--split", default="/opt/npu/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt",
                        help='split file listing the image ids to preprocess')
    parser.add_argument("--bin_folder_path", default="./voc12_bin/", help='Preprocessed image buffer')
    flags = parser.parse_args()

    os.makedirs(flags.bin_folder_path, exist_ok=True)

    # The set gives O(1) membership tests. Every extension is filtered by the
    # split: the old `a or b or c and d` condition let .jpeg files bypass it.
    wanted = load_split_names(flags.split)
    images = select_images(sorted(os.listdir(flags.image_folder_path)), wanted)

    for image_name in tqdm(images):
        path_image = os.path.join(flags.image_folder_path, image_name)
        voc2012_preprocess(path_image, flags.bin_folder_path)