# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import numpy as np

import cv2

import torch





def ToTensor(sample):
    """Convert an H x W x C array into a 1 x C x H x W float32 tensor.

    The channel axis is moved to the front and a leading batch axis of size 1
    is added. The data is always copied into a fresh C-contiguous float32
    buffer, so the returned tensor never shares memory with ``sample``.
    """
    # (H, W, C) -> (C, H, W) -> (1, C, H, W); np.array copies and casts.
    batched = np.array(sample.transpose(2, 0, 1)[np.newaxis],
                       dtype=np.float32, order="C")
    return torch.from_numpy(batched)





def get_center(x):
    """Return the centre coordinate ``(x - 1) / 2`` of a span of size ``x``.

    Works element-wise on arrays as well as on scalars; the result is float.
    """
    span_minus_one = x - 1.
    return span_minus_one / 2.





def xyxy2cxcywh(bbox):
    """Convert a top-left / bottom-right box to ``(cx, cy, w, h)``.

    ``bbox`` is indexed as ``(x1, y1, x2, y2)``. The centre is computed with
    ``get_center`` on the coordinate sums, i.e. ``(x1 + x2 - 1) / 2``; the
    width and height are the plain differences ``x2 - x1`` and ``y2 - y1``.
    """
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    center_x = get_center(x1 + x2)
    center_y = get_center(y1 + y2)
    width = x2 - x1
    height = y2 - y1
    return center_x, center_y, width, height





# model_sz=127, a picture is resized from original_sz to model_sz
def crop_and_pad(img, cx, cy, model_sz, original_sz, img_mean=None):
    """Crop a square window centred at (cx, cy), pad it, and resize it.

    The window extends ``original_sz // 2`` pixels from the centre in each
    direction. Parts of the window that fall outside ``img`` are filled with
    a constant colour, and the patch is then resized to ``model_sz x model_sz``.

    Args:
        img (numpy.ndarray): image with three dimensions, indexed as
            ``img[y, x]`` (height, width, channels).
        cx, cy: window centre in pixel coordinates.
        model_sz (int): side length of the returned patch (e.g. 127).
        original_sz: side length of the window in ``img`` pixels.
        img_mean: constant border value passed to ``cv2.copyMakeBorder``;
            if None, the per-channel mean of ``img`` truncated to int is used.

    Returns:
        numpy.ndarray: the ``model_sz x model_sz`` patch. When neither padding
        nor resizing is needed this is a view into ``img``.
    """
    half = original_sz // 2
    xmin = cx - half
    xmax = cx + half
    ymin = cy - half
    ymax = cy + half
    im_h, im_w, _ = img.shape

    # How far the window overhangs each image edge (padding to add later).
    left = right = top = bottom = 0
    if xmin < 0:
        left = int(abs(xmin))
    if xmax > im_w:
        right = int(xmax - im_w)
    if ymin < 0:
        top = int(abs(ymin))
    if ymax > im_h:
        bottom = int(ymax - im_h)

    # Clamp the window to the image and take the in-bounds part.
    xmin = int(max(0, xmin))
    xmax = int(min(im_w, xmax))
    ymin = int(max(0, ymin))
    ymax = int(min(im_h, ymax))
    im_patch = img[ymin:ymax, xmin:xmax]

    if left != 0 or right != 0 or top != 0 or bottom != 0:
        if img_mean is None:
            img_mean = tuple(map(int, img.mean(axis=(0, 1))))
        im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
                                      cv2.BORDER_CONSTANT, value=img_mean)

    # The window is 2 * (original_sz // 2) pixels wide (one short of an odd
    # original_sz) and its edges are truncated to int, so the patch size can
    # differ from model_sz even when model_sz == original_sz. Resize in that
    # case too, so callers always receive a model_sz x model_sz patch.
    if model_sz != original_sz or im_patch.shape[:2] != (model_sz, model_sz):
        im_patch = cv2.resize(im_patch, (model_sz, model_sz))
    return im_patch





def get_exemplar_image(img, bbox, size_z, context_amount, img_mean=None):
    """Crop the exemplar patch around ``bbox`` with surrounding context.

    The box ``(x1, y1, x2, y2)`` is converted to centre/size form. A margin of
    ``context_amount * (w + h)`` is added to both width and height, and the
    crop side is the geometric mean of the padded sides,
    ``s_z = sqrt((w + margin) * (h + margin))``. That window is cropped by
    ``crop_and_pad`` and resized to ``size_z`` (e.g. 127).

    Returns:
        tuple: ``(exemplar_img, scale_z, s_z)``, where ``scale_z = size_z / s_z``
        is the factor from original pixels to exemplar pixels.
    """
    cx, cy, w, h = xyxy2cxcywh(bbox)
    margin = context_amount * (w + h)
    context_side = np.sqrt((w + margin) * (h + margin))
    scale = size_z / context_side
    patch = crop_and_pad(img, cx, cy, size_z, context_side, img_mean)
    return patch, scale, context_side





def get_pyramid_instance_image(img, center, size_x, size_x_scales, img_mean=None):
    """Crop one ``size_x`` patch per window size in ``size_x_scales``.

    All patches share the centre ``(center[0], center[1])`` and the same
    border fill value. If ``img_mean`` is None, that value is the per-channel
    mean of ``img`` truncated to int, computed once for the whole pyramid.

    Returns:
        list: patches in the same order as ``size_x_scales``.
    """
    fill = tuple(map(int, img.mean(axis=(0, 1)))) if img_mean is None else img_mean
    patches = []
    for window_sz in size_x_scales:
        patches.append(
            crop_and_pad(img, center[0], center[1], size_x, window_sz, fill))
    return patches





def center_error(rects1, rects2):
    r"""Euclidean distance between the centres of paired rectangles.

    Args:
        rects1 (numpy.ndarray): An N x 4 numpy array, each row a rectangle
            (left, top, width, height).
        rects2 (numpy.ndarray): An N x 4 numpy array, each row a rectangle
            (left, top, width, height).

    Returns:
        numpy.ndarray: one centre distance per rectangle pair (reduced over
        the last axis).
    """
    def _centers(rects):
        # Centre of a (left, top, width, height) box using the
        # (size - 1) / 2 pixel-centre convention.
        return rects[..., :2] + (rects[..., 2:] - 1) / 2

    offset = _centers(rects1) - _centers(rects2)
    return np.sqrt(np.sum(np.power(offset, 2), axis=-1))





def rect_iou(rects1, rects2, bound=None):

    r"""Intersection over union.



    Args:

        rects1 (numpy.ndarray): An N x 4 numpy array, each line represent a rectangle

            (left, top, width, height).

        rects2 (numpy.ndarray): An N x 4 numpy array, each line represent a rectangle

            (left, top, width, height).

        bound (numpy.ndarray): A 4 dimensional array, denotes the bound

            (min_left, min_top, max_width, max_height) for ``rects1`` and ``rects2``.

    """

    assert rects1.shape == rects2.shape

    if bound is not None:

        # bounded rects1

        rects1[:, 0] = np.clip(rects1[:, 0], 0, bound[0])

        rects1[:, 1] = np.clip(rects1[:, 1], 0, bound[1])

        rects1[:, 2] = np.clip(rects1[:, 2], 0, bound[0] - rects1[:, 0])

        rects1[:, 3] = np.clip(rects1[:, 3], 0, bound[1] - rects1[:, 1])

        # bounded rects2

        rects2[:, 0] = np.clip(rects2[:, 0], 0, bound[0])

        rects2[:, 1] = np.clip(rects2[:, 1], 0, bound[1])

        rects2[:, 2] = np.clip(rects2[:, 2], 0, bound[0] - rects2[:, 0])

        rects2[:, 3] = np.clip(rects2[:, 3], 0, bound[1] - rects2[:, 1])



    rects_inter = _intersection(rects1, rects2)

    areas_inter = np.prod(rects_inter[..., 2:], axis=-1)



    areas1 = np.prod(rects1[..., 2:], axis=-1)

    areas2 = np.prod(rects2[..., 2:], axis=-1)

    areas_union = areas1 + areas2 - areas_inter



    eps = np.finfo(float).eps

    ious = areas_inter / (areas_union + eps)

    ious = np.clip(ious, 0.0, 1.0)



    return ious





def _intersection(rects1, rects2):

    r"""Rectangle intersection.



    Args:

        rects1 (numpy.ndarray): An N x 4 numpy array, each line represent a rectangle

            (left, top, width, height).

        rects2 (numpy.ndarray): An N x 4 numpy array, each line represent a rectangle

            (left, top, width, height).

    """

    assert rects1.shape == rects2.shape

    x1 = np.maximum(rects1[..., 0], rects2[..., 0])

    y1 = np.maximum(rects1[..., 1], rects2[..., 1])

    x2 = np.minimum(rects1[..., 0] + rects1[..., 2],

                    rects2[..., 0] + rects2[..., 2])

    y2 = np.minimum(rects1[..., 1] + rects1[..., 3],

                    rects2[..., 1] + rects2[..., 3])



    w = np.maximum(x2 - x1, 0)

    h = np.maximum(y2 - y1, 0)



    return np.stack([x1, y1, w, h]).T