# Revision 62fbbf9e, created on 2021-03-20 (commit history).
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class for XAI metrics."""

import copy
import math
from typing import Callable

import numpy as np

import mindspore as ms
from mindspore import log as logger
from mindspore.train._utils import check_value_type
from ..._operators import Tensor
from ..._utils import format_tensor_to_ndarray
from ...explanation._attribution.attribution import Attribution

_Explainer = Attribution


def verify_argument(inputs, arg_name):
    """
    Validate an input argument used for single-sample evaluation.

    Args:
        inputs (Tensor): The value to validate; its type is checked with `check_value_type`.
        arg_name (str): Name of the argument, used in the error messages.

    Raises:
        ValueError: If `inputs` is not 4-dimensional, or if its length (first dimension) exceeds 1.
    """
    check_value_type(arg_name, inputs, Tensor)

    rank = len(inputs.shape)
    if rank != 4:
        raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))

    # Only one sample per evaluation call is supported.
    batch_size = len(inputs)
    if batch_size > 1:
        raise ValueError('Support single data evaluation only, but got {}.'.format(batch_size))


def verify_targets(targets, num_labels):
    """
    Validate that `targets` denotes exactly one label inside the label range.

    Args:
        targets (Union[int, Tensor]): A single label id, given as an int or as a 0D / length-1 1D Tensor.
        num_labels (int): Number of labels; valid ids are 0 .. num_labels - 1.

    Raises:
        ValueError: If a Tensor `targets` is neither 0D nor a 1D Tensor of length 1, or if the label id
            falls outside [0, num_labels - 1].
    """
    check_value_type('targets', targets, (int, Tensor))

    if isinstance(targets, Tensor):
        rank = len(targets.shape)
        is_scalar = rank == 0
        is_single_element_vector = rank == 1 and len(targets) == 1
        if not (is_scalar or is_single_element_vector):
            raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                             'it should have the length = 1 as we only support single evaluation now.')
        # Unwrap the tensor into a plain Python int before the range check.
        if rank == 1:
            targets = int(targets.asnumpy()[0])
        else:
            targets = int(targets.asnumpy())

    if not 0 <= targets <= num_labels - 1:
        raise ValueError('Parsed targets exceed the label range.')


class AttributionMetric:
    """Common base of the XAI metric classes used in classification scenarios."""

    def __init__(self):
        # The explainer seen in the most recent evaluation; None until one is recorded.
        self._explainer = None

    evaluate: Callable
    """
    This method evaluates the explainer on the given attribution and returns the evaluation results.
    Derived class should implement this method according to specific algorithms of the metric.
    """

    def _record_explainer(self, explainer: _Explainer):
        """
        Remember the explainer used in the current evaluation.

        If a different explainer object was recorded before, an info-level message is logged
        before the stored explainer is replaced.

        Args:
            explainer (Attribution): The explainer being evaluated.
        """
        previous = self._explainer
        # Identity comparison: a different object counts as a different explainer.
        if previous is not None and previous is not explainer:
            logger.info('Provided explainer is not the same as previously evaluated one. Please reset the evaluated '
                        'results. Previous explainer: %s, current explainer: %s', previous, explainer)
        self._explainer = explainer


class LabelAgnosticMetric(AttributionMetric):
    """Base class that adds result bookkeeping for label-agnostic metrics."""

    def __init__(self):
        super().__init__()
        # Flat list of every aggregated score.
        self._global_results = []

    @property
    def performance(self) -> float:
        """
        Average of all finite aggregated results.

        Returns:
            float, the mean over the finite entries of the global results. 0.0 is returned when there is
            no finite entry.
        """
        finite_scores = [score for score in self._global_results if math.isfinite(score)]
        if not finite_scores:
            return 0.
        total = 0
        # Accumulate left to right with plain addition so the float result is unchanged.
        for score in finite_scores:
            total += score
        return total / len(finite_scores)

    def aggregate(self, result):
        """
        Append a single evaluation result to the global results.

        Args:
            result (Union[float, ms.Tensor, np.ndarray]): A float is appended as is; a tensor or array is
                flattened and each element is appended as a float.

        Raises:
            TypeError: If `result` is none of the supported types.
        """
        if isinstance(result, float):
            self._global_results.append(result)
            return
        if not isinstance(result, (ms.Tensor, np.ndarray)):
            raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
        flattened = format_tensor_to_ndarray(result).reshape(-1)
        # Build the full list first so a conversion failure leaves the stored results untouched.
        new_scores = [float(item) for item in flattened]
        self._global_results.extend(new_scores)

    def get_results(self):
        """Return a shallow copy of the global results list."""
        return list(self._global_results)

    def reset(self):
        """Remove all stored results in place."""
        del self._global_results[:]

    def _check_evaluate_param(self, explainer, inputs):
        """
        Validate the parameters of an evaluation call and record the explainer.

        Args:
            explainer (Attribution): The explainer to evaluate.
            inputs (Tensor): The input sample, validated by `verify_argument`.
        """
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')


class LabelSensitiveMetric(AttributionMetric):
    """
    Base class that adds per-label result bookkeeping for label-sensitive metrics.

    Results are kept in a dict mapping each label id in 0 .. num_labels - 1 to a list of float scores.

    Args:
        num_labels (int): Number of labels; must be an integer greater than 0.

    Raises:
        ValueError: If `num_labels` is less than 1.
    """

    def __init__(self, num_labels: int):
        super().__init__()
        LabelSensitiveMetric._verify_params(num_labels)
        self._num_labels = num_labels
        self._global_results = {i: [] for i in range(num_labels)}

    @property
    def num_labels(self):
        """Number of labels used in evaluation."""
        return self._num_labels

    @staticmethod
    def _verify_params(num_labels):
        """
        Check whether `num_labels` is valid.

        Raises:
            ValueError: If `num_labels` is less than 1.
        """
        check_value_type("num_labels", num_labels, int)
        if num_labels < 1:
            raise ValueError("Argument num_labels must be parsed with a integer > 0.")

    def aggregate(self, result, targets):
        """
        Aggregate evaluation result(s) into the per-label global results.

        Args:
            result (Union[float, ms.Tensor, np.ndarray]): A single score, or a tensor/array of scores
                (flattened before use).
            targets (Union[int, ms.Tensor, np.ndarray]): The label id(s) the result belongs to. An int
                applies to every score in `result`; otherwise it is flattened and must provide exactly
                one label id per score.

        Raises:
            ValueError: If a float `result` comes with more than one or with zero targets, or if the
                number of scores does not match the number of targets.
            TypeError: If `result` is none of the supported types.
        """
        if isinstance(result, float):
            if isinstance(targets, int):
                self._global_results[targets].append(result)
            else:
                # Flatten so that 0-d and 1-d single-element targets are handled the same way.
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if len(target_np) > 1:
                    raise ValueError("One result can not be aggreated to multiple targets.")
                if len(target_np) == 0:
                    raise ValueError("Argument targets must not be empty.")
                # Record the score; previously a single-element target was validated but the
                # result was silently dropped.
                self._global_results[int(target_np[0])].append(result)
        elif isinstance(result, (ms.Tensor, np.ndarray)):
            result_np = format_tensor_to_ndarray(result).reshape(-1)
            if isinstance(targets, int):
                for res in result_np:
                    self._global_results[targets].append(float(res))
            else:
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if len(target_np) != len(result_np):
                    raise ValueError("Length of result does not match with length of targets.")
                for tar, res in zip(target_np, result_np):
                    self._global_results[int(tar)].append(float(res))
        else:
            raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))

    def reset(self):
        """Reset the global results to an empty list for every label."""
        self._global_results = {i: [] for i in range(self._num_labels)}

    @property
    def class_performances(self):
        """
        Get the class performances by global result.

        Returns:
            (:class:`list`): a list of performances where each value is the average of the finite scores
            of the corresponding label, or 0.0 if that label has no finite score.
        """
        results_on_labels = []
        for label_id in range(self._num_labels):
            sum_of_label, count_of_label = 0, 0
            for res in self._global_results[label_id]:
                # Non-finite scores (NaN / inf) are excluded from the average.
                if math.isfinite(res):
                    sum_of_label += res
                    count_of_label += 1
            results_on_labels.append(0. if count_of_label == 0 else sum_of_label / count_of_label)
        return results_on_labels

    @property
    def performance(self):
        """
        Get the performance by global result.

        Returns:
            (:class:`float`): mean of all finite scores over all labels, or 0.0 if there is none.
        """
        result_sum, count = 0, 0
        for label_id in range(self._num_labels):
            for res in self._global_results[label_id]:
                if math.isfinite(res):
                    result_sum += res
                    count += 1
        return 0. if count == 0 else result_sum / count

    def get_results(self):
        """Return a deep copy of the per-label global results."""
        return copy.deepcopy(self._global_results)

    def _check_evaluate_param(self, explainer, inputs, targets, saliency):
        """
        Validate the parameters of an evaluation call and record the explainer.

        Args:
            explainer (Attribution): The explainer to evaluate; its `network` is run once on `inputs`.
            inputs (Tensor): The input sample, validated by `verify_argument`.
            targets (Union[int, Tensor]): The target label, validated by `verify_targets`.
            saliency (Union[Tensor, None]): Optional saliency map; only its type is checked here.

        Raises:
            ValueError: If the second dimension of the network output differs from `num_labels`.
        """
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')
        output = explainer.network(inputs)
        check_value_type("output of explainer model", output, Tensor)
        # Reuse the output computed above instead of running a second forward pass.
        output_dim = output.shape[1]
        if output_dim != self._num_labels:
            raise ValueError("The output dimension of of black-box model in explainer does not match the dimension "
                             "of num_labels set in the __init__, please check explainer and num_labels again.")
        verify_targets(targets, self._num_labels)
        check_value_type('saliency', saliency, (Tensor, type(None)))