from functools import lru_cache
import torch
import torch_npu
from mindspeed.op_builder import GroupMatmulAddOpBuilder
from mindspeed.ops.npu_matmul_add import is_a5
# Public API of this module: only `npu_groupmatmul_add_fp32` is exported.
__all__ = ["npu_groupmatmul_add_fp32"]
# Builder instance created once at import time. `npu_groupmatmul_add_fp32`
# calls its `load()` on every invocation that takes the non-`is_a5()` path.
# NOTE(review): presumably `load()` caches the compiled extension so repeated
# calls are cheap — confirm against GroupMatmulAddOpBuilder.
groupmatmul_add_op_builder = GroupMatmulAddOpBuilder()
def npu_groupmatmul_add_fp32(x, dy, grouplist, grad):
    """Run the grouped matmul-add kernel on whichever backend is available.

    Two code paths are selected by ``is_a5()``:

    * ``is_a5()`` is true: ``grad`` is viewed as
      ``(grouplist.shape[0], x.shape[-1], dy.shape[-1])`` and handed, together
      with ``x``, ``dy`` and ``grouplist`` (unchanged), to
      ``torch_npu.npu_grouped_matmul_add_``.
    * otherwise: the extension produced by ``groupmatmul_add_op_builder.load()``
      is used, and ``grouplist`` is moved to the ``'npu'`` device before the
      call to its ``npu_groupmatmul_add_fp32``.

    Args:
        x: First matmul operand; only ``x.shape[-1]`` is read here.
        dy: Second matmul operand; only ``dy.shape[-1]`` is read here.
        grouplist: Group description tensor; ``grouplist.shape[0]`` is used as
            the leading dimension of the ``grad`` view on the A5 path.
        grad: Tensor passed to the kernel as the output/accumulator argument.
            NOTE(review): presumably updated in place (the kernel names
            suggest an accumulate-into-grad contract, likely fp32) — confirm
            against the kernel documentation.

    Returns:
        None. Any result is delivered through the kernel's effect on its
        arguments.
    """
    if not is_a5():
        # Fallback path: custom op built/loaded through the MindSpeed builder.
        fallback_ops = groupmatmul_add_op_builder.load()
        npu_grouplist = grouplist.to('npu')
        fallback_ops.npu_groupmatmul_add_fp32(x, dy, npu_grouplist, grad)
        return

    # A5 path: the torch_npu kernel receives grad as a 3-D per-group view.
    fused_matmul_add = torch_npu.npu_grouped_matmul_add_
    grad_per_group = grad.view(grouplist.shape[0], x.shape[-1], dy.shape[-1])
    fused_matmul_add(grad_per_group, x, dy, grouplist)