#include "op_plugin/AclOpsInterface.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_utils = at_npu::native::NpuUtils;
namespace {
// Core matmul dispatch used by both matmul() and matmul_out() in this file.
//
// out_opt  Optional destination. When it holds a value, results are produced through the
//          *_out kernels and/or rebound onto that tensor with set_(), and that tensor is returned.
// tensor1  Left operand.
// tensor2  Right operand.
//
// Returns the product. Fails via TORCH_CHECK when no rank combination below matches, i.e. when
// either operand is 0-D and the BmmV2 fast path was not taken.
//
// A NoNamesGuard is held for the whole call; the public wrappers compute and propagate names.
at::Tensor matmul_opt_nocheck(c10::optional<at::Tensor> out_opt, const at::Tensor& tensor1, const at::Tensor& tensor2)
{
    at::NoNamesGuard guard;
    auto has_out = out_opt.has_value();
    at::Tensor out = out_opt.value_or(at::Tensor());

    // Fast path: both operands on NPU, both fp16, and BmmV2 enabled by the environment check.
    if (torch_npu::utils::is_npu(tensor1) &&
        torch_npu::utils::is_npu(tensor2) &&
        tensor1.scalar_type() == at::kHalf &&
        tensor2.scalar_type() == at::kHalf &&
        at_npu::native::env::CheckBmmV2Enable()) {
        auto res = at_npu::native::matmul_by_bmmV2(tensor1, tensor2);
        return has_out ? out.set_(res) : res;
    }

    auto dim_tensor1 = tensor1.dim();
    auto dim_tensor2 = tensor2.dim();
    if (dim_tensor1 == 1 && dim_tensor2 == 1) {
        // vector . vector
        return has_out ? at::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
    } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
        // matrix @ vector
        return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
    } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
        // vector @ matrix: promote the vector to a 1-row matrix, then drop that row dim again.
        return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0) :
                         tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
    } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
        // matrix @ matrix
        return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
    } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
        // Batched lhs with an unbatched rhs: fold every leading dim of tensor1 into rows so a
        // single mm covers the whole batch, then view the result back to the batched shape.
        at::Tensor t2 = dim_tensor2 == 1 ? tensor2.unsqueeze(-1) : tensor2;
        auto size1 = tensor1.sizes();
        auto size2 = t2.sizes();
        std::vector<int64_t> output_size;
        output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
        if (dim_tensor2 > 1) {
            // Only a 2-D rhs contributes a trailing output dimension.
            output_size.push_back(size2[dim_tensor2 - 1]);
        }
        at::Tensor t1 = tensor1.contiguous().view({-1, size1[size1.size() - 1]});
        at::Tensor output =
            has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size) : at::_unsafe_view(t1.mm(t2), output_size);
        return has_out ? out.set_(output) : output;
    } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
        // Unbatched lhs with a batched rhs: compute (tensor2^T @ tensor1^T) through a recursive
        // call, which lands in the branch above, then undo the transpose on the result.
        const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
        const int64_t m = tensor1.size(-1);
        const int64_t p = tensor2.size(-1);
        const at::Tensor t2_T = tensor2.transpose(-1, -2);
        const at::Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
        const at::Tensor res_T = matmul_opt_nocheck(out_opt, t2_T, t1_T);
        if (dim_tensor1 == 2) {
            at::Tensor res = res_T.transpose(-1, -2).contiguous();
            return has_out ? out.set_(res) : res;
        } else {
            // 1-D lhs: the result keeps tensor2's batch dims followed by p.
            std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
            shape.push_back(p);
            at::Tensor res = res_T.reshape(shape).contiguous();
            return has_out ? out.set_(res) : res;
        }
    } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
        // General batched case: broadcast the batch dims of both operands, flatten them into a
        // single batch dim, run bmm, then view the result back to the broadcast shape.
        int64_t n = dim_tensor1 > 1 ? tensor1.size(-2) : 1;
        int64_t m1 = tensor1.size(-1);
        at::IntArrayRef batch_tensor1(tensor1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0));
        int64_t m2 = dim_tensor2 > 1 ? tensor2.size(-2) : 1;
        int64_t p = tensor2.size(-1);
        at::IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0));
        std::vector<int64_t> expand_batch_portion = at::infer_size(batch_tensor1, batch_tensor2);

        std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
        tensor1_expand_size.insert(tensor1_expand_size.end(), {n, m1});
        std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
        tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});

        // The seed value fixes std::accumulate's accumulator type. Seeding with a plain `1`
        // would accumulate in `int` and truncate large batch products, so seed with int64_t.
        const int64_t expand_batch_product = std::accumulate(
            expand_batch_portion.begin(), expand_batch_portion.end(), static_cast<int64_t>(1),
            std::multiplies<int64_t>());

        std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
        tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
        std::vector<int64_t> tensor2_bmm_view({expand_batch_product});
        tensor2_bmm_view.insert(tensor2_bmm_view.end(), {m2, p});

        at::Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).contiguous().view(tensor1_bmm_view);
        at::Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).contiguous().view(tensor2_bmm_view);

        std::vector<int64_t> output_shape(expand_batch_portion);
        if (dim_tensor1 > 1) {
            output_shape.push_back(n);
        }
        if (dim_tensor2 > 1) {
            output_shape.push_back(p);
        }
        at::Tensor output = has_out ?
            at::_unsafe_view(at::bmm_out(out, tensor1_expanded, tensor2_expanded), output_shape) :
            at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
        return has_out ? out.set_(output) : output;
    }

    // No branch matched: at least one operand is 0-D.
    TORCH_CHECK(false, "both arguments to matmul need to be at least 1D, but they are ", dim_tensor1, "D and ",
                dim_tensor2, "D" + OPS_ERROR(ErrCode::PARAM));
}
}
// Functional matmul entry point: allocates and returns a new result tensor.
//
// Rejects operands whose scalar type is Char (int8), delegates the computation to
// matmul_opt_nocheck without a destination tensor, then applies the output names inferred
// from the operands when that name list is non-empty.
at::Tensor matmul(const at::Tensor& tensor1, const at::Tensor& tensor2)
{
    TORCH_CHECK(!(tensor1.scalar_type() == at::ScalarType::Char || tensor2.scalar_type() == at::ScalarType::Char),
                "matmul is not support int8 dtype" + OPS_ERROR(ErrCode::TYPE));

    // Names are inferred up front; the worker runs under a NoNamesGuard.
    const auto inferred_names = at::namedinference::compute_matmul_outnames(tensor1, tensor2);

    at::Tensor product = matmul_opt_nocheck(c10::nullopt, tensor1, tensor2);
    at::namedinference::propagate_names_if_nonempty(product, inferred_names);
    return product;
}
// Out-variant matmul entry point: writes the product into `result` and returns it.
//
// Rejects operands whose scalar type is Char (int8). A contiguous `result` is handed to
// matmul_opt_nocheck directly. Otherwise the work goes through the tensor returned by
// npu_utils::format_contiguous(result), and npu_utils::format_fresh_view is then called with
// `result` and that tensor. Output names inferred from the operands are applied at the end
// when that name list is non-empty.
at::Tensor& matmul_out(const at::Tensor& tensor1, const at::Tensor& tensor2, at::Tensor& result)
{
    TORCH_CHECK(!(tensor1.scalar_type() == at::ScalarType::Char || tensor2.scalar_type() == at::ScalarType::Char),
                "matmul is not support int8 dtype" + OPS_ERROR(ErrCode::TYPE));

    const auto inferred_names = at::namedinference::compute_matmul_outnames(tensor1, tensor2);

    if (result.is_contiguous()) {
        matmul_opt_nocheck(c10::optional<at::Tensor>(result), tensor1, tensor2);
    } else {
        // Compute into the contiguous stand-in, then refresh `result` from it.
        at::Tensor staging = npu_utils::format_contiguous(result);
        matmul_opt_nocheck(c10::optional<at::Tensor>(staging), tensor1, tensor2);
        npu_utils::format_fresh_view(result, staging);
    }

    at::namedinference::propagate_names_if_nonempty(result, inferred_names);
    return result;
}
}