# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
from functools import lru_cache
import torch
import torch_npu
from mindspeed.op_builder import GroupMatmulAddOpBuilder
from mindspeed.ops.npu_matmul_add import is_a5

__all__ = ["npu_groupmatmul_add_fp32"]

groupmatmul_add_op_builder = GroupMatmulAddOpBuilder()


def npu_groupmatmul_add_fp32(x, dy, grouplist, grad):
    if is_a5():
        torch_npu.npu_grouped_matmul_add_(grad.view(grouplist.shape[0], x.shape[-1], dy.shape[-1]), x, dy, grouplist)
    else:
        groupmatmul_add_ops = groupmatmul_add_op_builder.load()
        groupmatmul_add_ops.npu_groupmatmul_add_fp32(x, dy, grouplist.to('npu'), grad)