from functools import lru_cache
import torch
import torch_npu
from mindspeed.op_builder import MatmulAddOpBuilder
# Public API for ``from <module> import *``: only the fp32 entry point is listed.
# NOTE(review): npu_matmul_add_fp16 is also public by name but absent from
# __all__ — confirm whether that omission is intentional.
__all__ = ["npu_matmul_add_fp32"]
# Module-level builder instance, constructed once at import time. The non-A5
# path of npu_matmul_add_fp32 calls ``.load()`` on it to obtain the module
# that provides the custom ``npu_matmul_add_fp32`` kernel.
matmul_add_op_builder = MatmulAddOpBuilder()
@lru_cache
def is_a5():
    """Return True if the NPU device name identifies an A5-class part.

    The name reported by ``torch_npu.npu.get_device_name()`` is checked for
    the substrings ``"Ascend910_95"`` or ``"Ascend950"``. Any exception raised
    while querying or matching is swallowed and reported as ``False``, so
    callers always get a boolean and fall back to the non-A5 path.

    The result is memoized by ``functools.lru_cache``. That includes a
    ``False`` produced by a failed query. If this is probed before the device
    is usable, call ``is_a5.cache_clear()`` to re-probe.

    Returns:
        bool: whether the device name matches one of the A5 tags.
    """
    try:
        # Query the device once rather than once per candidate tag.
        device_name = torch_npu.npu.get_device_name()
        # The match stays inside the try so that an unexpected return value
        # (e.g. a non-string) still degrades to False rather than raising.
        return "Ascend910_95" in device_name or "Ascend950" in device_name
    except Exception:
        return False
def npu_matmul_add_fp32(total_input, grad_output, grad):
    """Accumulate ``grad_output.t() @ total_input`` into ``grad`` in place.

    Nothing is computed when either operand has a zero-length dimension.
    Otherwise the work is dispatched by device:

    * when :func:`is_a5` is true, ``grad.addmm_(grad_output.t(), total_input)``
      is used directly;
    * otherwise the ``npu_matmul_add_fp32`` kernel from the module returned by
      ``matmul_add_op_builder.load()`` is called with
      ``(grad_output, total_input, grad)``.

    Args:
        total_input: right-hand operand of the product (presumably the layer
            input, ``(tokens, in_features)`` — TODO confirm against callers).
        grad_output: left-hand operand; it is transposed with ``.t()``, so it
            must have at most 2 dimensions.
        grad: accumulator tensor, updated in place.

    Returns:
        None.
    """
    # An empty operand means there is nothing to accumulate.
    if 0 in total_input.shape or 0 in grad_output.shape:
        return
    if not is_a5():
        # Generic path: delegate to the custom extension kernel.
        extension = matmul_add_op_builder.load()
        extension.npu_matmul_add_fp32(grad_output, total_input, grad)
        return
    # A5 path: plain in-place addmm, i.e. grad += grad_output^T @ total_input.
    # NOTE(review): if grad is fp32 while the operands are fp16/bf16, confirm
    # that addmm_ accepts mixed dtypes on this device.
    grad.addmm_(grad_output.t(), total_input)
def npu_matmul_add_fp16(total_input, grad_output, grad):
    """Add ``grad_output.t() @ total_input`` to ``grad`` in place.

    This always uses plain PyTorch ops. The product is materialized as a
    temporary tensor and then added with ``Tensor.add_``. The call is a no-op
    when either operand has a zero-length dimension.

    Args:
        total_input: right-hand operand of the product (presumably the layer
            input, ``(tokens, in_features)`` — TODO confirm against callers).
        grad_output: left-hand operand; it is transposed with ``.t()``, so it
            must have at most 2 dimensions.
        grad: accumulator tensor, updated in place.

    Returns:
        None.
    """
    # An empty operand means there is nothing to accumulate.
    if 0 in total_input.shape or 0 in grad_output.shape:
        return
    # NOTE(review): a fused grad.addmm_ would avoid this temporary, but unlike
    # add_ it does not promote mixed dtypes — check grad's dtype before fusing.
    product = grad_output.t() @ total_input
    grad.add_(product)