42e2723c创建于 2025年7月14日历史提交
# Copyright (c) 2025, Huawei Technologies Co., Ltd.  All rights reserved.
import torch_npu


def npu_matmul_add_fp32(total_input, grad_output, grad):
    # check if any dimension of total_input is 0
    for dim in total_input.shape:
        if dim == 0:
            return

    # check if any dimension of grad_output is 0
    for dim in grad_output.shape:
        if dim == 0:
            return

    torch_npu.npu_matmul_add_fp32(grad_output, total_input, grad)