# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for MindExplain"""

# Public helpers exported by ``from <this module> import *``.
__all__ = [
    'ForwardProbe', 'abs_max', 'calc_auc', 'calc_correlation',
    'deprecated_error', 'format_tensor_to_ndarray', 'generate_one_hot',
    'rank_pixels', 'resize', 'retrieve_layer_by_name', 'retrieve_layer',
    'unify_inputs', 'unify_targets'
]

from typing import Tuple, Union

import numpy as np
from PIL import Image

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as op

# Short private aliases used in the type annotations below.
_Array, _Module, _Tensor = np.ndarray, nn.Cell, ms.Tensor


class DeprecatedError(RuntimeError):
    """Error raised when a deprecated ``mindspore.explainer`` API is used."""

    def __init__(self):
        message = ("'mindspore.explainer' is deprecated from version 1.5 and "
                   "will be removed in a future version, use MindSpore XAI "
                   "https://gitee.com/mindspore/xai instead.")
        super().__init__(message)


def deprecated_error(func_or_cls):
    """
    Decorator that rejects use of a deprecated function or class.

    It never returns. The argument is discarded and DeprecatedError is raised
    on every call, so applying it as a decorator fails at decoration time.

    Raises:
        DeprecatedError: Always.
    """
    del func_or_cls
    raise DeprecatedError


def abs_max(gradients):
    """
    Turn gradients into a saliency map: element-wise absolute value, then the
    maximum over axis 1 (channels), keeping that axis with size 1.

    Args:
        gradients (_Tensor): Gradients which will be transformed to saliency map.

    Returns:
        _Tensor, saliency map integrated from gradients.
    """
    channel_max = op.ReduceMax(keep_dims=True)
    magnitudes = op.Abs()(gradients)
    return channel_max(magnitudes, axis=1)


def generate_one_hot(indices, depth):
    r"""
    Encode ``indices`` as one-hot vectors of length ``depth``.

    Thin wrapper around the ``OneHot`` operation with on_value fixed to 1.0
    and off_value fixed to 0.0 (both float32 tensors).
    """
    one_hot = op.OneHot()
    return one_hot(indices, depth,
                   ms.Tensor(1.0, ms.float32),
                   ms.Tensor(0.0, ms.float32))


def unify_inputs(inputs) -> tuple:
    """
    Normalise explainer inputs to a tuple.

    A tuple is returned unchanged; a single ms.Tensor is wrapped in a
    1-tuple; a np.ndarray is converted to ms.Tensor and wrapped in a 1-tuple.

    Raises:
        TypeError: If inputs is none of tuple, ms.Tensor or np.ndarray.
    """
    if isinstance(inputs, tuple):
        return inputs
    if isinstance(inputs, ms.Tensor):
        return (inputs,)
    if isinstance(inputs, np.ndarray):
        return (ms.Tensor(inputs),)
    raise TypeError(
        'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
        'but get {}'.format(type(inputs)))


def unify_targets(targets) -> ms.Tensor:
    """
    Unify target labels of explainer into an int32 ms.Tensor.

    Args:
        targets (int, list, ms.Tensor): Target label(s). An ms.Tensor is
            returned unchanged; a list is converted to an int32 tensor; a
            single int is wrapped into a 1-element int32 tensor.

    Returns:
        ms.Tensor, the unified targets.

    Raises:
        TypeError: If targets is none of int, list or ms.Tensor.
    """
    if isinstance(targets, ms.Tensor):
        return targets
    # Each accepted type returns directly. The list branch must not fall
    # through to the int check: the converted Tensor is not an int and
    # would otherwise hit the TypeError below.
    if isinstance(targets, list):
        return ms.Tensor(targets, dtype=ms.int32)
    if isinstance(targets, int):
        return ms.Tensor([targets], dtype=ms.int32)
    raise TypeError(
        'targets must be one of [int, list or ms.Tensor], '
        'but get {}'.format(type(targets)))


def retrieve_layer_by_name(model: _Module, layer_name: str):
    """
    Retrieve the layer in the model by the given layer_name.

    Args:
        model (Cell): Model which contains the target layer.
        layer_name (str): Name of target layer, matched against the names
            yielded by ``model.cells_and_names()``. An empty string selects
            ``model`` itself.

    Returns:
        Cell, the target layer.

    Raises:
        TypeError: If layer_name is not a str.
        ValueError: If module with given layer_name is not found in the model.
    """
    if not isinstance(layer_name, str):
        raise TypeError('layer_name should be type of str, but receive {}.'
                        .format(type(layer_name)))

    if not layer_name:
        return model

    # The first cell whose name matches exactly wins.
    for name, cell in model.cells_and_names():
        if name == layer_name:
            return cell

    raise ValueError(
        'Cannot match {}, please provide target layer '
        'in the given model.'.format(layer_name))


def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
    """
    Retrieve the layer in the model.

    'target_layer' can be either a layer name or a Cell object. Given the
    layer name, the method will search through the model and return the
    matched layer (an empty name returns the model itself). If a Cell object
    is provided, it checks (by identity) whether that cell appears in
    ``model.cells_and_names()``. If the target layer is not found in the
    model, ValueError will be raised.

    Args:
        model (Cell): Model which contains the target layer.
        target_layer (str, Cell): Name of target layer or the target layer instance.

    Returns:
        Cell, the target layer.

    Raises:
        TypeError: If target_layer is neither a str nor a Cell.
        ValueError: If the target layer is not found in the model.
    """
    if isinstance(target_layer, str):
        return retrieve_layer_by_name(model, target_layer)

    if isinstance(target_layer, _Module):
        # Identity (not equality) check: the very same cell object must be
        # part of the model.
        if any(cell is target_layer for _, cell in model.cells_and_names()):
            return target_layer
        raise ValueError(
            'Model not contain cell {}, fail to probe.'.format(target_layer)
        )

    raise TypeError('target_layer must have type of str or ms.nn.Cell, '
                    'but receive {}'.format(type(target_layer)))


class ForwardProbe:
    """
    Context manager that records the output of one layer of a model.

    While the context is active, the layer's ``construct`` is replaced by a
    wrapper that calls the original ``construct`` and stores its result in
    ``value``. On exit, the ``construct`` captured at construction time is
    assigned back and the stored value is cleared.

    Args:
        target_layer (Cell): The layer instance whose output is captured
            (it must provide a ``construct`` method).
    """

    def __init__(self, target_layer: _Module):
        self._intermediate_tensor = None
        self._target_layer = target_layer
        self._original_construct = target_layer.construct

    @property
    def value(self):
        """Output of the most recent call of the layer inside the context.

        None before any such call, and again after the context has exited.
        """
        return self._intermediate_tensor

    def __enter__(self):
        layer = self._target_layer
        layer.construct = self._new_construct
        return self

    def __exit__(self, *_):
        self._target_layer.construct = self._original_construct
        self._intermediate_tensor = None
        # Never suppress exceptions raised inside the with-block.
        return False

    def _new_construct(self, *inputs):
        """Run the original construct and remember what it returned."""
        self._intermediate_tensor = self._original_construct(*inputs)
        return self._intermediate_tensor


def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
    """
    Return ``x`` as a numpy array.

    An ms.Tensor is converted with ``asnumpy()``; a np.ndarray is returned
    as-is.

    Raises:
        TypeError: If x is neither ms.Tensor nor np.ndarray.
    """
    array = x.asnumpy() if isinstance(x, ms.Tensor) else x
    if isinstance(array, np.ndarray):
        return array
    raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
                    ' but receive {}'.format(type(array)))


def calc_correlation(x: Union[ms.Tensor, np.ndarray],
                     y: Union[ms.Tensor, np.ndarray]) -> float:
    """
    Calculate Pearson correlation coefficient between two vectors.

    Args:
        x (ms.Tensor, np.ndarray): First vector (at most 1-D).
        y (ms.Tensor, np.ndarray): Second vector (at most 1-D).

    Returns:
        float, ``np.corrcoef(x, y)[0, 1]``, or 0.0 when either vector is
        entirely zero.

    Raises:
        TypeError: If x or y is neither ms.Tensor nor np.ndarray.
        ValueError: If x or y has more than one dimension.
    """
    x = format_tensor_to_ndarray(x)
    y = format_tensor_to_ndarray(y)

    if x.ndim > 1 or y.ndim > 1:
        raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.'
                         .format(x.shape, y.shape))

    # An all-zero vector has no variance, so the coefficient is undefined;
    # report 0.0 instead. (``np.float`` was removed in NumPy 1.24, so a
    # plain float literal is used.)
    # NOTE(review): a constant but non-zero vector still reaches np.corrcoef
    # and yields nan.
    if np.all(x == 0) or np.all(y == 0):
        return 0.0
    return np.corrcoef(x, y)[0, 1]


def calc_auc(x: _Array) -> _Array:
    """
    Calculate the Area under Curve.

    Args:
        x (np.ndarray): Curve values along axis 0. A 4-D input is first
            averaged over axes 2 and 3.

    Returns:
        (sum of all entries - x[0] - x[-1]) / len(x). Both endpoints are
        excluded entirely, and the result is normalised by the number of
        entries along axis 0.
    """
    # take mean for multiple patches if the model is fully convolutional model
    if x.ndim == 4:
        # Reduce both trailing axes in a single call. A nested
        # mean(mean(x, axis=2), axis=3) is out of range, because the inner
        # reduction leaves only three axes.
        x = np.mean(x, axis=(2, 3))
    auc = (x.sum() - x[0] - x[-1]) / len(x)
    return auc


def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
    """
    Rank every pixel of each sample in a batch.

    Axis 0 is treated as the batch axis. Each sample ``inputs[i]`` is
    flattened and its pixels receive ranks 0 .. (num_pixel-1): rank 0 goes
    to the largest value when ``descending`` is True, to the smallest
    otherwise. The result has the same shape as ``inputs``.

    Raises:
        ValueError: If inputs is not 2-D or 3-D.
    """
    if not 2 <= len(inputs.shape) <= 3:
        raise ValueError('Only support 2D or 3D inputs currently.')

    flat = inputs.reshape(inputs.shape[0], -1)
    sign = -1 if descending else 1
    order = np.argsort(sign * flat, axis=1)
    ranks = np.zeros_like(order)
    # Invert each row's sort permutation: the pixel at order[i, j] has rank j.
    np.put_along_axis(ranks, order, np.arange(flat.shape[1]), axis=1)
    return ranks.reshape(inputs.shape)


def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
    """
    Resize a batch of (N, C, H, W) maps, e.g. intermediate-layer attributions,
    to the spatial size of the inputs.

    Args:
        inputs (Tensor): The input tensor to be resized, laid out as
            (N, C, H, W).
        size (tuple[int]): The target size as (height, width).
        mode (str): The resize mode. Options: 'nearest_neighbor', 'bilinear'.

    Returns:
        Tensor, the resized tensor of shape (N, C, height, width). For
        'bilinear' the result is cast back to the dtype of ``inputs``.

    Raises:
        ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear'].
    """
    h, w = size
    if mode == 'nearest_neighbor':
        resize_nn = op.ResizeNearestNeighbor((h, w))
        outputs = resize_nn(inputs)

    elif mode == 'bilinear':
        samples = inputs.asnumpy()
        resized = np.empty(samples.shape[:2] + (h, w), dtype=np.float32)
        # Resize every (sample, channel) plane as a 32-bit float ("F" mode)
        # PIL image. This avoids 8-bit quantisation, keeps values outside
        # [0, 1] intact, and works for any number of channels.
        for idx in np.ndindex(*samples.shape[:2]):
            plane = np.ascontiguousarray(samples[idx], dtype=np.float32)
            # PIL expects (width, height); ``size`` is (height, width).
            image = Image.fromarray(plane).resize((w, h), resample=Image.BILINEAR)
            resized[idx] = np.asarray(image)
        outputs = ms.Tensor(resized, inputs.dtype)
    else:
        raise ValueError('Unsupported resize mode {}.'.format(mode))

    return outputs