"""Base class for XAI metrics."""
import copy
import math
from typing import Callable
import numpy as np
import mindspore as ms
from mindspore import log as logger
from mindspore.train._utils import check_value_type
from ..._operators import Tensor
from ..._utils import format_tensor_to_ndarray
from ...explanation._attribution.attribution import Attribution
# Alias of Attribution, used below as the type annotation for explainer arguments.
_Explainer = Attribution
def verify_argument(inputs, arg_name):
    """
    Validate that an argument is a 4D Tensor holding at most one sample.

    The type check is delegated to ``check_value_type`` with ``Tensor`` as the
    accepted type; the shape checks are performed here.

    Args:
        inputs: Value to validate.
        arg_name (str): Name of the argument, used in the error messages.

    Raises:
        ValueError: If ``inputs`` does not have exactly 4 dimensions, or if
            ``len(inputs)`` is greater than 1.
    """
    check_value_type(arg_name, inputs, Tensor)

    rank = len(inputs.shape)
    if rank != 4:
        raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))

    # Only single-sample evaluation is supported.
    num_samples = len(inputs)
    if num_samples > 1:
        raise ValueError('Support single data evaluation only, but got {}.'.format(num_samples))
def verify_targets(targets, num_labels):
    """
    Validate that ``targets`` denotes a single label inside the label range.

    The type check (``int`` or ``Tensor``) is delegated to ``check_value_type``.
    A Tensor target is converted to a plain ``int`` before the range check.

    Args:
        targets: An ``int`` label, or a 0D Tensor / 1D Tensor of length 1.
        num_labels (int): Number of labels; valid labels are ``0 .. num_labels - 1``.

    Raises:
        ValueError: If a Tensor target has more than one dimension or is a 1D
            Tensor whose length is not 1, or if the label is outside the
            valid range.
    """
    check_value_type('targets', targets, (int, Tensor))

    if isinstance(targets, Tensor):
        rank = len(targets.shape)
        if rank > 1 or (rank == 1 and len(targets) != 1):
            raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                             'it should have the length = 1 as we only support single evaluation now.')
        target_np = targets.asnumpy()
        # A 1D Tensor holds the label as its only element; a 0D Tensor is the label itself.
        targets = int(target_np[0]) if rank == 1 else int(target_np)

    if not 0 <= targets <= num_labels - 1:
        raise ValueError('Parsed targets exceed the label range.')
class AttributionMetric:
    """
    Super class of XAI metric class used in classification scenarios.

    Keeps a reference to the explainer seen in the most recent evaluation so
    that a change of explainer between evaluations can be reported.
    """

    def __init__(self):
        # Explainer recorded by the latest call to _record_explainer; None until then.
        self._explainer = None

    evaluate: Callable
    """
    This method evaluates the explainer on the given attribution and returns the evaluation results.
    Derived class should implement this method according to specific algorithms of the metric.
    """

    def _record_explainer(self, explainer: _Explainer):
        """
        Remember the explainer used in the current evaluation.

        If a different explainer object was recorded before, an info message is
        logged; in every case the given explainer becomes the recorded one.

        Args:
            explainer: The explainer used in the current evaluation.
        """
        previous = self._explainer
        if previous is not None and previous is not explainer:
            logger.info('Provided explainer is not the same as previously evaluated one. Please reset the evaluated '
                        'results. Previous explainer: %s, current explainer: %s', previous, explainer)
        self._explainer = explainer
class LabelAgnosticMetric(AttributionMetric):
    """
    Super class adding result bookkeeping for label-agnostic metrics.

    Evaluation results are accumulated in one flat list of floats.
    """

    def __init__(self):
        super().__init__()
        # Flat list of every aggregated evaluation result.
        self._global_results = []

    @property
    def performance(self) -> float:
        """
        Return the average of the finite aggregated results.

        Non-finite values (NaN, +/-inf) are skipped.

        Return:
            float, averaged result. If there is no finite result in the global results, 0.0 is returned.
        """
        total, finite_count = 0, 0
        for value in filter(math.isfinite, self._global_results):
            total += value
            finite_count += 1
        if finite_count == 0:
            return 0.
        return total / finite_count

    def aggregate(self, result):
        """
        Add evaluation result(s) to the global results.

        Args:
            result: A float, or an ms.Tensor / np.ndarray whose elements are
                each appended as a float after flattening.

        Raises:
            TypeError: If result is none of float, ms.Tensor or np.ndarray.
        """
        if isinstance(result, float):
            self._global_results.append(result)
            return
        if not isinstance(result, (ms.Tensor, np.ndarray)):
            raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
        flattened = format_tensor_to_ndarray(result).reshape(-1)
        self._global_results.extend(list(map(float, flattened)))

    def get_results(self):
        """Return a shallow copy of the global results list."""
        return list(self._global_results)

    def reset(self):
        """Remove all aggregated results (the list is cleared in place)."""
        self._global_results.clear()

    def _check_evaluate_param(self, explainer, inputs):
        """
        Check the evaluate parameters.

        Type-checks the explainer against Attribution, records it, and
        validates inputs with verify_argument.
        """
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')
class LabelSensitiveMetric(AttributionMetric):
    """
    Super class adding result bookkeeping for label-sensitive metrics.

    Evaluation results are stored per label, in a dict that maps every label id
    in ``range(num_labels)`` to a list of float scores.

    Args:
        num_labels (int): Number of labels; must be an integer > 0.
    """

    def __init__(self, num_labels: int):
        super().__init__()
        LabelSensitiveMetric._verify_params(num_labels)
        self._num_labels = num_labels
        self._global_results = {i: [] for i in range(num_labels)}

    @property
    def num_labels(self):
        """Number of labels used in evaluation."""
        return self._num_labels

    @staticmethod
    def _verify_params(num_labels):
        """
        Check whether num_labels is valid.

        Raises:
            ValueError: If num_labels is smaller than 1.
        """
        check_value_type("num_labels", num_labels, int)
        if num_labels < 1:
            raise ValueError("Argument num_labels must be parsed with a integer > 0.")

    @staticmethod
    def _finite_mean(values):
        """Return the mean of the finite entries of ``values``, or 0.0 if there is none."""
        total, count = 0, 0
        for value in values:
            if math.isfinite(value):
                total += value
                count += 1
        return 0. if count == 0 else total / count

    def aggregate(self, result, targets):
        """
        Aggregate evaluation result(s) into the per-label global results.

        Args:
            result: A float, or an ms.Tensor / np.ndarray of results (flattened
                before aggregation).
            targets: An int label, or a Tensor / ndarray of labels. With a float
                result it must hold exactly one label; with an array result and
                non-int targets it must hold one label per result element.

        Raises:
            ValueError: If a float result comes with zero or several targets, or
                if the number of targets differs from the number of results.
            TypeError: If result is none of float, ms.Tensor or np.ndarray.
        """
        if isinstance(result, float):
            if isinstance(targets, int):
                self._global_results[targets].append(result)
                return
            # Flatten first so that a 0D target is handled like a 1-element one.
            target_np = format_tensor_to_ndarray(targets).reshape(-1)
            if len(target_np) > 1:
                raise ValueError("One result can not be aggreated to multiple targets.")
            if len(target_np) == 0:
                raise ValueError("Argument targets is empty, one target is required for one result.")
            # Store the score under its single target instead of dropping it.
            self._global_results[int(target_np[0])].append(result)
        elif isinstance(result, (ms.Tensor, np.ndarray)):
            result_np = format_tensor_to_ndarray(result).reshape(-1)
            if isinstance(targets, int):
                # All results belong to the same label.
                self._global_results[targets].extend([float(res) for res in result_np])
            else:
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if len(target_np) != len(result_np):
                    raise ValueError("Length of result does not match with length of targets.")
                for tar, res in zip(target_np, result_np):
                    self._global_results[int(tar)].append(float(res))
        else:
            raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))

    def reset(self):
        """Reset the global results to an empty list for every label."""
        self._global_results = {i: [] for i in range(self._num_labels)}

    @property
    def class_performances(self):
        """
        Get the class performances by global result.

        Returns:
            (:class:`list`): a list of performances where each value is the average of the finite scores of
            one label (0.0 for a label without finite scores), ordered by label id.
        """
        return [self._finite_mean(self._global_results[label_id]) for label_id in range(self._num_labels)]

    @property
    def performance(self):
        """
        Get the performance by global result.

        Returns:
            (:class:`float`): mean of the finite scores over all labels, 0.0 if there is none.
        """
        all_results = (res
                       for label_id in range(self._num_labels)
                       for res in self._global_results[label_id])
        return self._finite_mean(all_results)

    def get_results(self):
        """Return a deep copy of the per-label global results."""
        return copy.deepcopy(self._global_results)

    def _check_evaluate_param(self, explainer, inputs, targets, saliency):
        """
        Check the evaluate parameters.

        Type-checks and records the explainer, validates inputs, checks that the
        second dimension of the explainer network's output equals num_labels,
        validates targets, and type-checks saliency (Tensor or None).

        Raises:
            ValueError: If the network output dimension does not match num_labels.
        """
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')
        output = explainer.network(inputs)
        check_value_type("output of explainer model", output, Tensor)
        # Reuse the output computed above rather than running the network a second time.
        output_dim = output.shape[1]
        if output_dim != self._num_labels:
            raise ValueError("The output dimension of of black-box model in explainer does not match the dimension "
                             "of num_labels set in the __init__, please check explainer and num_labels again.")
        verify_targets(targets, self._num_labels)
        check_value_type('saliency', saliency, (Tensor, type(None)))