# Copyright 2020 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import os

import sys

import numpy as np

from PIL import Image

from tqdm import tqdm





def resize(img, size, interpolation=Image.BILINEAR):
    """Return ``img`` resized to ``size``.

    Args:
        img (PIL Image): Source image.
        size (int or sequence): Either an explicit ``(h, w)`` target, or a
            single int giving the target length of the shorter edge. In the
            int case the longer edge is scaled by the same factor (truncated
            to an int), so the aspect ratio is kept; e.g. when
            height > width the result is ``(size * height / width, size)``
            in (h, w) order.
        interpolation (int, optional): PIL resampling filter handed to
            ``img.resize``. Defaults to ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image: The resized image, or ``img`` itself (no copy) when
        ``size`` is an int already equal to the shorter edge.
    """
    if not isinstance(size, int):
        # Caller gives (h, w); PIL's resize wants (width, height).
        return img.resize(size[::-1], interpolation)

    width, height = img.size
    if min(width, height) == size:
        return img  # shorter edge already has the requested length

    if width < height:
        target = (size, int(size * height / width))
    else:
        target = (int(size * width / height), size)
    return img.resize(target, interpolation)





def center_crop(img, out_height, out_width):
    """Crop the central ``out_height`` x ``out_width`` region of an array image.

    Args:
        img (numpy.ndarray): Image laid out as (height, width) or
            (height, width, channels); only the first two axes are cropped.
        out_height (int): Desired crop height in pixels.
        out_width (int): Desired crop width in pixels.

    Returns:
        numpy.ndarray: A view of ``img`` over the centered region. When the
        margin along an axis is odd, the extra pixel is dropped from the
        bottom/right side. If the image is smaller than the requested size
        along an axis, that whole axis is kept (no padding is applied).
    """
    out_height, out_width = int(out_height), int(out_width)
    height, width = img.shape[:2]
    # Clamp to 0: a negative start index would make the slice wrap around to
    # the end of the axis and silently return only a sliver of the image.
    top = max(0, (height - out_height) // 2)
    left = max(0, (width - out_width) // 2)
    return img[top:top + out_height, left:left + out_width]





def preprocess(file_path, bin_path, resize_size=(256, 256), crop_size=224):
    """Convert every image in ``file_path`` into a normalized float32 ``.bin`` file.

    Each image is converted to RGB, resized with :func:`resize`, center-cropped
    to ``crop_size`` x ``crop_size`` with :func:`center_crop`, scaled to
    [0, 1], normalized per channel with the mean/std below, transposed from
    HWC to CHW and written raw (float32, C order via ``ndarray.tofile``) to
    ``<bin_path>/<file name without its last extension>.bin``.

    Args:
        file_path (str): Directory containing the input images. Entries that
            are not regular files (e.g. sub-directories) are skipped.
        bin_path (str): Output directory; created if it does not exist.
        resize_size (int or tuple): Passed to :func:`resize`. The default
            ``(256, 256)`` stretches every image to exactly 256x256, which is
            this script's historical behavior. Pass ``256`` to instead scale
            the shorter edge to 256 and keep the aspect ratio.
            NOTE(review): this step was previously labelled
            ``transforms.Resize(256)`` (shorter-edge resize) while the code
            stretched to 256x256 — confirm which one the model expects.
        crop_size (int): Side length of the square center crop.
    """
    os.makedirs(bin_path, exist_ok=True)

    # Per-channel (R, G, B) normalization constants, kept float32 so the
    # arithmetic stays in float32 exactly as the per-channel in-place ops did.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    in_files = sorted(
        name for name in os.listdir(file_path)
        if os.path.isfile(os.path.join(file_path, name))
    )
    for name in tqdm(in_files):
        with Image.open(os.path.join(file_path, name)) as src:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the source file is closed.
            img = src.convert('RGB')
        img = resize(img, resize_size)
        img = np.array(img, dtype=np.float32)
        img = center_crop(img, crop_size, crop_size)

        # Scale to [0, 1], then subtract mean and divide by std per channel
        # (mean/std broadcast over the trailing channel axis of the HWC array).
        img = (img / 255. - mean) / std

        img = img.transpose(2, 0, 1)  # HWC -> CHW
        # splitext strips only the last extension, so "a.v1.jpg" and
        # "a.v2.jpg" no longer both collapse onto "a.bin".
        stem = os.path.splitext(name)[0]
        img.tofile(os.path.join(bin_path, stem + '.bin'))





if __name__ == "__main__":
    # Usage: python <this script> <image_dir> <output_bin_dir>
    # Extra trailing arguments are ignored, as before.
    if len(sys.argv) < 3:
        sys.exit("usage: %s <image_dir> <bin_dir>" % sys.argv[0])
    file_path = os.path.abspath(sys.argv[1])
    bin_path = os.path.abspath(sys.argv[2])
    preprocess(file_path, bin_path)