import torch
import torch_npu
# Public names for ``from <module> import *``.
# NOTE(review): this list is empty, so a star-import of this module exports
# nothing (not even MatmulApply) -- confirm this is intended.
__all__ = []
class MatmulApply(torch.autograd.Function):
    """Compute ``self @ mat2.transpose(-2, -1)`` with an NPU-specific backward.

    The backward pass uses the NPU custom operator ``torch_npu.npu_bmmV2``
    instead of the native autograd formula, which is intended to be faster
    on NPU than the plain PyTorch expression.

    Compute function::

        attn = q @ k.transpose(-2, -1)

    .. note::
        In the dynamic-shape scenario, broadcasting is not supported because
        of an operator restriction.

    Args (positional, passed to ``MatmulApply.apply``):
        self (Tensor): the first operand (left-hand side of the product).
        mat2 (Tensor): the second operand; its last two dimensions are
            transposed before the multiplication, so it must be at least 2-D.

    Returns:
        Tensor: ``torch.matmul(self, mat2.transpose(-2, -1))``.

    Examples::

        >>> tensor1 = torch.randn(68, 5, 75, 16).npu()
        >>> tensor1.requires_grad_(True)
        >>> tensor2 = torch.randn(68, 5, 75, 16).npu()
        >>> tensor2.requires_grad_(True)
        >>> output = MatmulApply.apply(tensor1, tensor2)
        >>> output.sum().backward()
    """

    @staticmethod
    def forward(ctx, self, mat2):
        """Return ``self @ mat2^T`` and save both operands for backward.

        ``self`` is an ordinary positional parameter here (this is a
        staticmethod); the name is kept so existing callers are unaffected.
        """
        ctx.save_for_backward(self, mat2)
        result = torch.matmul(self, mat2.transpose(-2, -1))
        return result.detach()

    @staticmethod
    def backward(ctx, grad):
        """Return the gradients with respect to ``self`` and ``mat2``.

        For ``out = self @ mat2^T``::

            d(self) = grad   @ mat2
            d(mat2) = grad^T @ self

        A gradient is only computed when autograd reports that the matching
        input needs it; otherwise ``None`` is returned for that input, which
        avoids a wasted NPU matmul (e.g. when one operand is frozen).
        """
        self, mat2 = ctx.saved_tensors
        need_self_grad, need_mat2_grad = ctx.needs_input_grad
        self_grad = None
        mat2_grad = None
        if need_self_grad:
            self_grad = torch_npu.npu_bmmV2(grad, mat2, [])
        if need_mat2_grad:
            mat2_grad = torch_npu.npu_bmmV2(grad.transpose(-2, -1), self, [])
        return self_grad, mat2_grad