#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/utils/custom_functions/aclops/inner_compute.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
namespace {
// Smallest tensor rank for which the two innermost dimensions exist; also used as the
// offset from the rank to the second-to-last dimension index.
constexpr int64_t TENSORS_DIMS_LOWER_LIMIT = 2;
// Negative index addressing the second-to-last dimension of a tensor.
constexpr int64_t PENULT_DIM = -2;
// Decides whether `Tensors` can be handed to the matmul kernel as "transposed" input.
//
// Returns true only when all of the following hold:
//   * the rank is at least TENSORS_DIMS_LOWER_LIMIT;
//   * the NPU storage descriptor's base_sizes_ has the same rank as the tensor;
//   * the second-to-last dimension has stride 1 and the last dimension's stride equals
//     the size of the second-to-last dimension;
//   * the sizes of the last two dimensions are the base sizes of those dimensions swapped;
//   * the element count derived from the storage byte size equals the value reported by
//     get_storage_size.
// Fails a TORCH_CHECK when the tensor's element size is not positive.
bool is_transpose_last_two_dims_v2(const at::Tensor &Tensors)
{
    const int64_t rank = Tensors.dim();
    if (rank < TENSORS_DIMS_LOWER_LIMIT) {
        return false;
    }

    int64_t reported_numel = at_npu::native::NPUNativeFunctions::get_storage_size(Tensors);
    const int64_t last = rank - 1;
    const int64_t penult = rank - TENSORS_DIMS_LOWER_LIMIT;

    TORCH_CHECK(Tensors.element_size() > 0,
                "expected Tensors valid, "
                "but input Tensors has element_size ",
                Tensors.element_size(),
                OPS_ERROR(ErrCode::PARAM));
    int64_t storage_numel = static_cast<int64_t>(Tensors.storage().nbytes()) / Tensors.element_size();

    auto storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(Tensors)->get_npu_desc();
    const auto &base_sizes = storage_desc.base_sizes_;

    // The checks run in a fixed order: base_sizes is only indexed after its rank has been
    // confirmed to match the tensor's rank.
    if (base_sizes.size() != static_cast<uint64_t>(rank)) {
        return false;
    }
    if (Tensors.stride(penult) != 1) {
        return false;
    }
    if (Tensors.stride(last) != Tensors.size(penult)) {
        return false;
    }
    if (Tensors.size(last) != base_sizes[penult]) {
        return false;
    }
    if (Tensors.size(penult) != base_sizes[last]) {
        return false;
    }
    return storage_numel == reported_numel;
}
// Infers the result shape of a batched matrix multiply of `mat1` by `mat2`.
//
// Every dimension except the last two is treated as a batch dimension. The two batch
// shapes are right-aligned (the shorter one is padded with leading 1s) and broadcast
// against each other: a size-1 entry takes the other operand's size, equal entries are kept.
// The returned shape is the broadcast batch shape followed by {m, n}, where
//   m = mat1.size(-2), or 1 when mat1 is 1-D;
//   n = mat2.size(-1), or 1 when mat2 is 1-D.
//
// Raises (via AT_ERROR) when a pair of batch sizes differs and neither is 1.
c10::SmallVector<int64_t, SIZE> bmm_v2_output_size(const at::Tensor &mat1, const at::Tensor &mat2)
{
    auto dim_tensor1 = mat1.dim();
    auto dim_tensor2 = mat2.dim();
    int64_t m = dim_tensor1 == 1 ? 1 : mat1.size(PENULT_DIM);
    int64_t n = dim_tensor2 == 1 ? 1 : mat2.size(-1);

    // Batch portion = leading (rank - 2) sizes; empty for tensors of rank <= 2.
    auto batch_a = op_infer::array_to_small_vector(
        at::IntArrayRef(mat1.sizes().data(), std::max<int64_t>(dim_tensor1 + PENULT_DIM, 0)));
    auto batch_b = op_infer::array_to_small_vector(
        at::IntArrayRef(mat2.sizes().data(), std::max<int64_t>(dim_tensor2 + PENULT_DIM, 0)));

    // Left-pad the shorter batch shape with 1s so both have the same length.
    batch_a.insert(batch_a.begin(), std::max<int64_t>(batch_a.size(), batch_b.size()) - batch_a.size(), 1);
    batch_b.insert(batch_b.begin(), std::max<int64_t>(batch_a.size(), batch_b.size()) - batch_b.size(), 1);

    c10::SmallVector<int64_t, SIZE> output_size;
    for (size_t i = 0; i < batch_a.size(); ++i) {
        if (batch_a[i] == 1) {
            output_size.emplace_back(batch_b[i]);
        } else if (batch_b[i] == 1) {
            output_size.emplace_back(batch_a[i]);
        } else if (batch_a[i] != batch_b[i]) {
            // Stream the shapes themselves; streaming sizes().data() would only print a pointer address.
            AT_ERROR("mat1 and mat2 cannot broadcast, but they are mat1 ", mat1.sizes(), " mat2 ", mat2.sizes());
        } else {
            output_size.emplace_back(batch_a[i]);
        }
    }
    output_size.emplace_back(m);
    output_size.emplace_back(n);
    return output_size;
}
// Launches the "BatchMatMul" operator for `self` x `mat2` and returns a freshly allocated
// result tensor of shape `output_size`.
//
// * A 1-D `self` is viewed as a {1, k} matrix and a 1-D `mat2` as a {k, 1} matrix.
// * The result is allocated with ACL_FORMAT_FRACTAL_NZ when the left operand is Half,
//   otherwise with ACL_FORMAT_ND.
// * An operand that is_transpose_last_two_dims_v2 recognises is passed through unchanged
//   with its adj_x1 / adj_x2 attribute set to true; any other operand is first converted
//   with npu_utils::format_contiguous and its attribute is false.
at::Tensor pure_bmm_v2(const at::Tensor &self, const at::Tensor &mat2,
                       const c10::SmallVector<int64_t, SIZE> &output_size)
{
    auto lhs = self.dim() == 1 ? self.view({1, self.size(0)}) : self;
    auto rhs = mat2.dim() == 1 ? mat2.view({mat2.size(0), 1}) : mat2;

    at::Tensor result;
    if (lhs.scalar_type() == at::ScalarType::Half) {
        result = npu_preparation::apply_tensor_with_format(output_size, lhs.options(), ACL_FORMAT_FRACTAL_NZ, true);
    } else {
        result = npu_preparation::apply_tensor_with_format(output_size, lhs.options(), ACL_FORMAT_ND);
    }

    bool lhs_is_transposed = is_transpose_last_two_dims_v2(lhs);
    bool rhs_is_transposed = is_transpose_last_two_dims_v2(rhs);
    at::Tensor lhs_input = lhs_is_transposed ? lhs : npu_utils::format_contiguous(lhs);
    at::Tensor rhs_input = rhs_is_transposed ? rhs : npu_utils::format_contiguous(rhs);

    at_npu::native::OpCommand cmd;
    cmd.Name("BatchMatMul")
        .InputWithoutContiguous(lhs_input)
        .InputWithoutContiguous(rhs_input)
        .Output(result)
        .Attr("adj_x1", lhs_is_transposed)
        .Attr("adj_x2", rhs_is_transposed)
        .Run();
    return result;
}
// Rearranges the left operand so that every batch dimension whose entry in
// `expect_output_size` is 1 is merged into the last dimension.
//
// The dimensions are permuted to
//   [batch dims that are kept..., second-to-last dim, batch dims to merge..., last dim]
// and the result is reshaped to [kept batch dims..., second-to-last dim, -1], so the merged
// batch dims and the last dim collapse into a single trailing dimension.
// `expect_output_size` is only read, at the indices of the batch dimensions of `self`.
at::Tensor reshape_tensor_self(const at::Tensor &self, c10::SmallVector<int64_t, SIZE> &expect_output_size)
{
    const int64_t rank = self.dim();
    c10::SmallVector<int64_t, SIZE> merged_axes;
    c10::SmallVector<int64_t, SIZE> axis_order;
    for (int64_t axis = 0; axis < rank; ++axis) {
        if (axis < rank - 2 && expect_output_size[axis] == 1) {
            // Deferred: this batch axis is emitted right before the last axis.
            merged_axes.emplace_back(axis);
            continue;
        }
        if (axis == rank - 1) {
            axis_order.insert(axis_order.end(), merged_axes.begin(), merged_axes.end());
        }
        axis_order.emplace_back(axis);
    }

    at::Tensor permuted = self.permute(axis_order);

    // Number of leading dims that survive unchanged: kept batch dims plus the second-to-last dim.
    const int64_t leading_count = rank - static_cast<int64_t>(merged_axes.size()) - 1;
    c10::SmallVector<int64_t, SIZE> target_shape = op_infer::array_to_small_vector(permuted.sizes());
    target_shape.resize(static_cast<size_t>(leading_count));
    target_shape.emplace_back(-1);
    return permuted.reshape(target_shape);
}
// Rearranges the right operand so that every batch dimension whose entry in
// `expect_output_size` is 1 is merged into the second-to-last dimension.
//
// The dimensions are permuted to
//   [batch dims that are kept..., batch dims to merge..., second-to-last dim, last dim]
// and the result is reshaped to [kept batch dims..., -1, mat2.size(-1)], so the merged batch
// dims and the second-to-last dim collapse into one dimension.
// `expect_output_size` is only read, at the indices of the batch dimensions of `mat2`.
at::Tensor reshape_tensor_mat2(const at::Tensor &mat2, c10::SmallVector<int64_t, SIZE> &expect_output_size)
{
    const int64_t rank = mat2.dim();
    c10::SmallVector<int64_t, SIZE> merged_axes;
    c10::SmallVector<int64_t, SIZE> axis_order;
    for (int64_t axis = 0; axis < rank; ++axis) {
        if (axis < rank - 2 && expect_output_size[axis] == 1) {
            // Deferred: this batch axis is emitted right before the second-to-last axis.
            merged_axes.emplace_back(axis);
            continue;
        }
        if (axis == rank - 2) {
            axis_order.insert(axis_order.end(), merged_axes.begin(), merged_axes.end());
        }
        axis_order.emplace_back(axis);
    }

    at::Tensor permuted = mat2.permute(axis_order);

    // Number of leading batch dims that survive unchanged.
    const int64_t kept_batch_count = rank - static_cast<int64_t>(merged_axes.size()) - 2;
    c10::SmallVector<int64_t, SIZE> target_shape = op_infer::array_to_small_vector(permuted.sizes());
    target_shape.resize(static_cast<size_t>(kept_batch_count));
    target_shape.emplace_back(-1);
    target_shape.emplace_back(mat2.size(-1));
    return permuted.reshape(target_shape);
}
// Returns `svec` left-padded with 1s until it is as long as `golden_svec`.
// When `svec` is already at least that long it is returned unchanged.
c10::SmallVector<int64_t, SIZE> align_small_vector(c10::SmallVector<int64_t, SIZE> svec,
                                                   c10::SmallVector<int64_t, SIZE> golden_svec)
{
    if (golden_svec.size() > svec.size()) {
        svec.insert(svec.begin(), golden_svec.size() - svec.size(), 1);
    }
    return svec;
}
// Broadcasts both operands to a common batch shape and flattens that batch shape into a
// single leading dimension, in place.
//
// `expand_output_size` is the full matmul output shape; everything but its last two entries
// is taken as the batch shape to broadcast to (it is expected to hold at least two entries).
// On return:
//   self has shape {prod(batch), m, k1}
//   mat2 has shape {prod(batch), k2, n}
// where m/k1 and k2/n are the last two sizes of the (possibly 1-D-promoted) inputs.
// A 1-D `self` is first viewed as {1, k} and a 1-D `mat2` as {k, 1}.
void expand_tensor(at::Tensor &self, at::Tensor &mat2, c10::SmallVector<int64_t, SIZE> &expand_output_size)
{
    self = self.dim() == 1 ? self.view({1, self.size(0)}) : self;
    mat2 = mat2.dim() == 1 ? mat2.view({mat2.size(0), 1}) : mat2;
    int64_t m = self.size(PENULT_DIM);
    int64_t k1 = self.size(-1);
    int64_t k2 = mat2.size(PENULT_DIM);
    int64_t n = mat2.size(-1);

    std::vector<int64_t> expand_batch_portion(expand_output_size.begin(), expand_output_size.end() - 2);
    std::vector<int64_t> self_expand_size(expand_batch_portion);
    std::vector<int64_t> mat2_expand_size(expand_batch_portion);
    self_expand_size.insert(self_expand_size.end(), {m, k1});
    mat2_expand_size.insert(mat2_expand_size.end(), {k2, n});

    // The seed's type is the accumulator type of std::accumulate, so seed with int64_t
    // rather than `long`, which is narrower than int64_t on some data models.
    int64_t expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(),
                                                   static_cast<int64_t>(1), std::multiplies<int64_t>());

    std::vector<int64_t> self_bmm_view({expand_batch_product});
    std::vector<int64_t> mat2_bmm_view({expand_batch_product});
    self_bmm_view.insert(self_bmm_view.end(), {m, k1});
    mat2_bmm_view.insert(mat2_bmm_view.end(), {k2, n});

    self = self.expand(self_expand_size).reshape(self_bmm_view);
    mat2 = mat2.expand(mat2_expand_size).reshape(mat2_bmm_view);
}
// Batched matrix multiply of `self` by `mat2`, optionally shaped to `output_sizes`.
//
// Three cases are handled:
//  1. `output_sizes` is empty: the shape is inferred with bmm_v2_output_size. Operands of
//     equal rank go straight to pure_bmm_v2; otherwise they are broadcast/flattened with
//     expand_tensor, multiplied, and viewed back to the inferred shape, dropping the
//     dimension that was introduced for a 1-D `self` (second-to-last) or 1-D `mat2` (last).
//  2. `output_sizes`, left-padded with 1s to the inferred rank, equals the inferred shape:
//     same computation as case 1, with the result viewed as `output_sizes`.
//  3. Some entries differ: both operands are reshaped to the inferred rank, the batch
//     dimensions whose requested size is 1 are merged into the contracted dimension
//     (reshape_tensor_self / reshape_tensor_mat2), and the product is viewed as `output_sizes`.
//
// Raises when `output_sizes` has more dimensions than the inferred output shape, and
// propagates any error raised by the helpers (e.g. non-broadcastable batch shapes).
at::Tensor npu_bmmV2_impl(const at::Tensor &self, const at::Tensor &mat2, at::IntArrayRef output_sizes)
{
    auto expect_output_size = op_infer::array_to_small_vector(output_sizes);
    auto infer_output_size = bmm_v2_output_size(self, mat2);
    at::Tensor tmp_self = self;
    at::Tensor tmp_mat2 = mat2;

    if (expect_output_size.empty()) {
        if (tmp_self.dim() == tmp_mat2.dim()) {
            return pure_bmm_v2(tmp_self, tmp_mat2, infer_output_size);
        }
        expand_tensor(tmp_self, tmp_mat2, infer_output_size);
        // Remember the broadcast shape; the kernel itself runs on the flattened 3-D operands.
        expect_output_size = infer_output_size;
        infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
        auto res = pure_bmm_v2(tmp_self, tmp_mat2, infer_output_size).view(expect_output_size);
        infer_output_size = expect_output_size;
        if (self.dim() == 1) {
            infer_output_size.erase(infer_output_size.end() - 2);
            return res.view(infer_output_size);
        } else if (mat2.dim() == 1) {
            infer_output_size.erase(infer_output_size.end() - 1);
            return res.view(infer_output_size);
        }
        return res;
    }

    // align_small_vector only pads, never truncates: a longer requested shape would make the
    // element-wise comparison below index past the end of infer_output_size.
    TORCH_CHECK(expect_output_size.size() <= infer_output_size.size(),
                "output_sizes has ", expect_output_size.size(),
                " dims, which is more than the ", infer_output_size.size(),
                " dims of the inferred output shape",
                OPS_ERROR(ErrCode::PARAM));

    c10::SmallVector<int64_t, SIZE> tmp_expect_output_size = align_small_vector(expect_output_size, infer_output_size);
    c10::SmallVector<int64_t, SIZE> axis_reduce;
    for (size_t i = 0; i < tmp_expect_output_size.size(); ++i) {
        if (tmp_expect_output_size[i] != infer_output_size[i]) {
            axis_reduce.emplace_back(static_cast<int64_t>(i));
        }
    }

    if (axis_reduce.empty()) {
        if (tmp_self.dim() == tmp_mat2.dim()) {
            return pure_bmm_v2(tmp_self, tmp_mat2, infer_output_size);
        }
        expand_tensor(tmp_self, tmp_mat2, infer_output_size);
        infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
        return pure_bmm_v2(tmp_self, tmp_mat2, infer_output_size).view(expect_output_size);
    }

    // Bring both operands to the inferred rank before merging batch dimensions.
    c10::SmallVector<int64_t, SIZE> tmp_self_size =
        align_small_vector(op_infer::array_to_small_vector(self.sizes()), infer_output_size);
    c10::SmallVector<int64_t, SIZE> tmp_mat2_size =
        align_small_vector(op_infer::array_to_small_vector(mat2.sizes()), infer_output_size);
    tmp_self = self.reshape(tmp_self_size);
    tmp_mat2 = mat2.reshape(tmp_mat2_size);
    tmp_self = reshape_tensor_self(tmp_self, tmp_expect_output_size);
    tmp_mat2 = reshape_tensor_mat2(tmp_mat2, tmp_expect_output_size);

    infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
    expand_tensor(tmp_self, tmp_mat2, infer_output_size);
    infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
    return pure_bmm_v2(tmp_self, tmp_mat2, infer_output_size).view(expect_output_size);
}
}
// Backward helper for the first operand of npu_bmmV2.
//
// Computes npu_bmmV2(grad', mat2', sizes) where
//   * grad' is `grad` viewed with a size-1 dimension re-inserted: before the last dimension
//     when `mat1` is 1-D, otherwise after the last dimension when `mat2` is 1-D;
//   * mat2' is `mat2` with its last two dimensions transposed, or viewed as {1, k} when
//     `mat2` is 1-D;
//   * sizes is `sizes_symint` reinterpreted as concrete integers and passed as the requested
//     output sizes.
// Raises when `mat2` has fewer than one dimension.
at::Tensor npu_bmm_v2_mat1_backward_symint(const at::Tensor &grad, const at::Tensor &mat1, const at::Tensor &mat2,
                                           c10::SymIntArrayRef sizes_symint)
{
    at::IntArrayRef sizes = c10::asIntArrayRefUnchecked(sizes_symint);

    std::vector<int64_t> axis_reshape(grad.sizes().begin(), grad.sizes().end());
    if (mat1.dim() == 1) {
        axis_reshape.insert(axis_reshape.begin() + axis_reshape.size() - 1, 1);
    } else if (mat2.dim() == 1) {
        axis_reshape.insert(axis_reshape.end(), 1);
    }

    at::Tensor mat2_cp;
    if (mat2.dim() == 1) {
        mat2_cp = mat2.view({1, mat2.size(0)});
    } else {
        TORCH_CHECK(mat2.dim() >= 2, "mat2.dim must be greater than or equal to 1, but got ",
                    mat2.dim(), OPS_ERROR(ErrCode::PARAM));
        mat2_cp = mat2.transpose(-2, -1);
    }
    return acl_op::npu_bmmV2(grad.view(axis_reshape), mat2_cp, sizes);
}
// Backward helper for the second operand of npu_bmmV2.
//
// Computes npu_bmmV2(mat1', grad', sizes) where
//   * mat1' is `mat1` with its last two dimensions transposed, or viewed as {k, 1} when
//     `mat1` is 1-D;
//   * grad' is `grad` viewed with a size-1 dimension re-inserted: before the last dimension
//     when `mat1` is 1-D, otherwise after the last dimension when `mat2` is 1-D;
//   * sizes is `sizes_symint` reinterpreted as concrete integers and passed as the requested
//     output sizes.
at::Tensor npu_bmm_v2_mat2_backward_symint(const at::Tensor &grad, const at::Tensor &mat1, const at::Tensor &mat2,
                                           c10::SymIntArrayRef sizes_symint)
{
    at::IntArrayRef sizes = c10::asIntArrayRefUnchecked(sizes_symint);

    std::vector<int64_t> axis_reshape(grad.sizes().begin(), grad.sizes().end());
    if (mat1.dim() == 1) {
        axis_reshape.insert(axis_reshape.begin() + axis_reshape.size() - 1, 1);
    } else if (mat2.dim() == 1) {
        axis_reshape.insert(axis_reshape.end(), 1);
    }

    at::Tensor lhs = mat1.dim() == 1 ? mat1.view({mat1.size(0), 1}) : mat1.transpose(-2, -1);
    return acl_op::npu_bmmV2(lhs, grad.view(axis_reshape), sizes);
}
// Public entry point of the batched matmul: rejects int8 (Char) operands, then forwards
// to npu_bmmV2_impl with the same arguments.
at::Tensor npu_bmmV2(const at::Tensor &self, const at::Tensor &mat2, at::IntArrayRef output_sizes)
{
    const bool has_int8_operand =
        self.scalar_type() == at::ScalarType::Char || mat2.scalar_type() == at::ScalarType::Char;
    TORCH_CHECK(!has_int8_operand, "bmm is not support int8 dtype" + OPS_ERROR(ErrCode::TYPE));
    return npu_bmmV2_impl(self, mat2, output_sizes);
}
}