#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
namespace acl_op {
namespace {
// Computes the gradients of a matmul(self, other) product with respect to its two operands,
// choosing a formula based on the ranks (dim()) of `self` and `other`.
//
// grad        incoming gradient; presumably shaped like matmul(self, other) -- the caller is
//             expected to guarantee this, it is not checked here.
// self, other the two matmul operands.
// mask        mask[0] requests the gradient for `self`, mask[1] the gradient for `other`.
// grad_self, grad_other
//             in/out parameters; each is overwritten only when the matching mask entry is set
//             and is left untouched otherwise.
//
// Returns a tuple of (grad_self, grad_other) as they stand after the update.
std::tuple<at::Tensor, at::Tensor> npu_matmul_backward(const at::Tensor &grad, const at::Tensor &self,
                                                       const at::Tensor &other, std::array<bool, 2> mask,
                                                       at::Tensor &grad_self, at::Tensor &grad_other)
{
    auto dim_self = self.dim();
    auto dim_other = other.dim();
    auto size_grad = grad.sizes();
    auto size_self = self.sizes();
    auto size_other = other.sizes();

    if (dim_self == 1 && dim_other == 1) {
        // vector x vector: each gradient is the other operand multiplied by grad.
        if (mask[0]) {
            grad_self = other.mul(grad);
        }
        if (mask[1]) {
            grad_other = self.mul(grad);
        }
    } else if (dim_self == 2 && dim_other == 1) {
        // matrix x vector: grad is lifted to a column so that plain 2-D mm can be used.
        if (mask[0]) {
            grad_self = grad.unsqueeze(1).mm(other.unsqueeze(0));
        }
        if (mask[1]) {
            grad_other = self.transpose(-1, -2).mm(grad.unsqueeze(1)).squeeze_(1);
        }
    } else if (dim_self == 1 && dim_other == 2) {
        // vector x matrix: grad is lifted to a row so that plain 2-D mm can be used.
        if (mask[0]) {
            grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0);
        }
        if (mask[1]) {
            grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0));
        }
    } else if (dim_self >= 3 && (dim_other == 1 || dim_other == 2)) {
        // batched self x (vector | matrix): fold all leading dims of grad into rows of a 2-D matrix.
        const int64_t view_size = dim_other == 1 ? 1 : size_grad[size_grad.size() - 1];
        auto unfolded_grad = (dim_other == 1 ? grad.unsqueeze(-1) : grad).contiguous().view({-1, view_size});
        if (mask[0]) {
            grad_self =
                unfolded_grad.mm(dim_other == 1 ? other.unsqueeze(0) : other.transpose(-1, -2)).view(size_self);
        }
        if (mask[1]) {
            // Fold self the same way so that a single mm reduces over every leading dim at once.
            auto unfolded_self = self.contiguous().view({-1, size_self[dim_self - 1]});
            grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other);
        }
    } else if ((dim_self == 1 || dim_self == 2) && dim_other >= 3) {
        // (vector | matrix) x batched other: work on the transposed problem, folding leading dims into rows.
        const int64_t view_size = dim_self == 1 ? 1 : size_grad[size_grad.size() - 2];
        // view() needs mergeable strides, so make grad contiguous first in BOTH cases; the 1-D case
        // previously called view() directly on grad and could throw for a non-contiguous gradient.
        auto unfolded_grad_T = dim_self == 1 ? grad.contiguous().view({-1, view_size})
                                             : grad.transpose(-1, -2).contiguous().view({-1, view_size});
        if (mask[0]) {
            auto unfolded_other_T =
                other.transpose(-1, -2).contiguous().view({-1, size_other[dim_other - 2]}).transpose(-1, -2);
            grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self);
        }
        if (mask[1]) {
            // Shape of `other` with its last two dims swapped; the result is transposed back afterwards.
            std::vector<int64_t> size_other_T(size_other.begin(), size_other.end() - 2);
            size_other_T.insert(size_other_T.end(), {size_other[dim_other - 1], size_other[dim_other - 2]});
            grad_other =
                unfolded_grad_T.mm(dim_self == 1 ? self.unsqueeze(0) : self).view(size_other_T).transpose(-1, -2);
        }
    } else {
        // Every remaining rank combination is delegated to acl_op::matmul.
        if (mask[0]) {
            grad_self = acl_op::matmul(grad, other.transpose(-1, -2));
        }
        if (mask[1]) {
            grad_other = acl_op::matmul(self.transpose(-1, -2), grad);
        }
    }
    return std::make_tuple(grad_self, grad_other);
}
}
// Backward of matmul(self, other): returns (grad_self, grad_other).
//
// grad        incoming gradient; when it is undefined, two undefined tensors are returned.
// self, other the two matmul operands; both must have at least one dimension.
// mask        mask[0] requests the gradient for `self`, mask[1] the gradient for `other`.
//             A gradient that is not requested is returned as an undefined tensor.
//
// Raises (via TORCH_CHECK) when either operand is 0-dimensional.
std::tuple<at::Tensor, at::Tensor> matmul_backward(const at::Tensor &grad, const at::Tensor &self,
                                                   const at::Tensor &other, std::array<bool, 2> mask)
{
    if (!grad.defined()) {
        return std::make_tuple(at::Tensor(), at::Tensor());
    }
    // The pieces are concatenated verbatim, so the separating spaces must be part of the literals.
    TORCH_CHECK(self.dim() > 0 && other.dim() > 0,
                "both matrices must be at least 1D, but they are ", self.dim(), "D and ", other.dim(),
                "D" + OPS_ERROR(ErrCode::PARAM));

    at::Tensor grad_self;
    at::Tensor grad_other;
    if (!mask[0] && !mask[1]) {
        // Nothing requested: hand back two undefined tensors.
        return std::make_tuple(grad_self, grad_other);
    }

    // Half-precision NPU operands take the bmmV2-based matmul when that path is enabled.
    const bool both_on_npu = torch_npu::utils::is_npu(self) && torch_npu::utils::is_npu(other);
    const bool both_half = self.scalar_type() == at::kHalf && other.scalar_type() == at::kHalf;
    if (both_on_npu && both_half && at_npu::native::env::CheckBmmV2Enable()) {
        if (mask[0]) {
            grad_self = at_npu::native::matmul_by_bmmV2(grad, other.transpose(-1, -2));
        }
        if (mask[1]) {
            grad_other = at_npu::native::matmul_by_bmmV2(self.transpose(-1, -2), grad);
        }
        return std::make_tuple(grad_self, grad_other);
    }

    // The helper fills grad_self / grad_other in place and returns them as a tuple.
    return npu_matmul_backward(grad, self, other, mask, grad_self, grad_other);
}
}