e0cccf2e创建于 2025年6月20日历史提交
__all__ = []

import math
import torch
from torch import matmul
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error


def _matmul_checksum(a, b, c):
    r"""
    Compare whether there are any feature anomalies in the calculation results of matmul.
    Args:
        a(Tensor): matmul's input parameter a, and the device must be npu.
        b(Tensor): matmul's input parameter b, and the device must be npu.
        c(Tensor): matmul's output result c, and the device must be npu.

    Returns: The bool scalar tensor, located on the npu side, indicates whether there are any anomalies in the calculation result.

    """
    if not isinstance(a, torch.Tensor) or a.device.type != 'npu':
        raise TypeError(f"tensor should be torch.Tensor, and device type should be npu" + pta_error(ErrCode.PARAM))
    if not isinstance(b, torch.Tensor) or b.device.type != 'npu':
        raise TypeError(f"tensor should be torch.Tensor, and device type should be npu" + pta_error(ErrCode.PARAM))
    if not isinstance(c, torch.Tensor) or c.device.type != 'npu':
        raise TypeError(f"tensor should be torch.Tensor, and device type should be npu" + pta_error(ErrCode.PARAM))

    c_sum = torch.sum(c, dim=-1, dtype=torch.float32)
    b1 = torch.sum(b, dim=-1, keepdim=True, dtype=torch.float32)
    c1 = matmul(a.to(torch.float32), b1)
    c1_trans = c1.squeeze(-1)
    n_b = b.shape[-1]

    c_max, _ = torch.max(torch.abs(c), dim=-1)
    c_mean = torch.mean(torch.abs(c), dim=-1)
    if torch.min(c_max / c_mean) > 5:
        c_ele_round_error_accum = c_max * 2 ** (-8) * math.sqrt(n_b)
    else:
        c_ele_round_error_accum = c_mean * 2 ** (-8) * n_b

    error_total = (c_ele_round_error_accum).to(torch.float)

    error = torch.abs(c_sum - c1_trans)
    flag = (error - 5 * error_total) > 5 * 1e-20
    any_flag = torch.any(flag)
    if any_flag:
        matmul(a, b, out=c)
        c_mean2 = torch.mean(torch.abs(c), dim=-1)
        return torch.any(c_mean != c_mean2)
    return any_flag