#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
constexpr size_t X_DIM = 2;
constexpr size_t FUSED_TYPE_ARRAY_SIZE = 100;

// Infers the broadcast batch shape of matmul(x1, x2) and appends it to batch_record,
// outermost dim first. The batch shape is every dim except the trailing two matrix dims.
//
// The lower-rank operand is right-aligned against the higher-rank one, and its missing
// leading dims are treated as size 1. Two batch dims are compatible when they are equal
// or either one is 1 (standard broadcasting). The resulting dim is the non-1 size, so a
// 0-sized dim broadcast against 1 stays 0.
//
// Raises (TORCH_CHECK) when a pair of batch dims cannot be broadcast.
// The caller is expected to have checked that x1.dim() >= 2 and x2.dim() >= 2. If that
// does not hold, no batch dims are appended.
void infer_out_batch_shape(const at::Tensor &x1, const at::Tensor &x2, std::vector<uint64_t> &batch_record)
{
    const int64_t x1_dim_num = x1.dim();
    const int64_t x2_dim_num = x2.dim();
    const int64_t out_dim_num = std::max(x1_dim_num, x2_dim_num);
    // When the ranks tie, x2 is treated as the "long" operand.
    const at::Tensor &shape_long = x1_dim_num > x2_dim_num ? x1 : x2;
    const at::Tensor &shape_short = x1_dim_num > x2_dim_num ? x2 : x1;
    // Number of leading dims that exist only in the higher-rank operand.
    const int64_t valid_offset = out_dim_num - std::min(x1_dim_num, x2_dim_num);
    // Signed arithmetic: `out_dim_num - X_DIM` in size_t would wrap for ranks below 2.
    const int64_t batch_dim_num = out_dim_num - static_cast<int64_t>(X_DIM);
    batch_record.reserve(batch_record.size() + static_cast<size_t>(std::max<int64_t>(batch_dim_num, 0)));
    for (int64_t i = 0; i < batch_dim_num; i++) {
        const int64_t short_dim = i < valid_offset ? 1 : shape_short.size(i - valid_offset);
        const int64_t long_dim = shape_long.size(i);
        TORCH_CHECK(short_dim == long_dim || short_dim == 1 || long_dim == 1,
            "the x1 shape and x2 shape not supported for broadcast, the short_dim is ", short_dim,
            " and the long_dim is ", long_dim, OPS_ERROR(ErrCode::PARAM));
        // A size-1 dim stretches to the other size (including 0); otherwise both are equal.
        const int64_t cur_batch_value = short_dim == 1 ? long_dim : short_dim;
        batch_record.push_back(static_cast<uint64_t>(cur_batch_value));
    }
}
}
// Allocates the output for matmul(x, x2) and dispatches aclnnFusedMatmul. The kernel is
// given the optional bias and x3 tensors and the fusion selected by fused_op_type.
//
// Requirements, checked with TORCH_CHECK:
// - x and x2 each have at least 2 dims.
// - x's trailing two dims are (m, k) and x2's are (k, n), with matching k.
// - fused_op_type is shorter than FUSED_TYPE_ARRAY_SIZE characters. It is handed to the
//   kernel as a NUL-terminated char buffer of that size.
//
// Leading (batch) dims are broadcast via infer_out_batch_shape. The output shape is the
// broadcast batch shape followed by (m, n).
// The output dtype is x's dtype, except float32 when fused_op_type is "16cast32".
// An absent bias or x3 is forwarded as an undefined (default-constructed) tensor.
at::Tensor npu_fused_matmul(
    const at::Tensor &x, const at::Tensor &x2,
    const c10::optional<at::Tensor> &bias, const c10::optional<at::Tensor> &x3,
    c10::string_view fused_op_type
    )
{
    const int64_t x1_dim_num = x.dim();
    TORCH_CHECK(x1_dim_num >= static_cast<int64_t>(X_DIM), "x1 shape dim num cannot be less than 2, but it is ",
        x1_dim_num, OPS_ERROR(ErrCode::PARAM));
    const int64_t x2_dim_num = x2.dim();
    TORCH_CHECK(x2_dim_num >= static_cast<int64_t>(X_DIM), "x2 shape dim num cannot be less than 2, but it is ",
        x2_dim_num, OPS_ERROR(ErrCode::PARAM));
    const int64_t x1_m_dim = x.size(x1_dim_num - static_cast<int64_t>(X_DIM));
    const int64_t x1_k_dim = x.size(x1_dim_num - 1);
    const int64_t x2_n_dim = x2.size(x2_dim_num - 1);
    const int64_t x2_k_dim = x2.size(x2_dim_num - static_cast<int64_t>(X_DIM));
    TORCH_CHECK(x1_k_dim == x2_k_dim, "The k of x1 and x2 should be equal. but x1_k_dim is ",
        x1_k_dim, ", x2_k_dim is ", x2_k_dim, OPS_ERROR(ErrCode::PARAM));

    // Validate the fusion type before allocating anything.
    // The check is strict '<' so one byte stays free for the terminating NUL. With '<=',
    // a type of exactly FUSED_TYPE_ARRAY_SIZE characters would fill the whole buffer and
    // reach the kernel unterminated.
    const std::string fused_op_str(fused_op_type);
    TORCH_CHECK(fused_op_str.size() < FUSED_TYPE_ARRAY_SIZE,
        "the len of fused_op_type must be less than ", FUSED_TYPE_ARRAY_SIZE,
        ", but it is ", fused_op_str.size(), OPS_ERROR(ErrCode::PARAM));

    std::vector<uint64_t> batch_record;
    infer_out_batch_shape(x, x2, batch_record);
    // Same tie-break as infer_out_batch_shape: x2 wins when ranks are equal.
    const at::Tensor &long_tensor = x1_dim_num > x2_dim_num ? x : x2;
    const int64_t out_dim_num = long_tensor.dim();
    const int64_t batch_dim_num = out_dim_num - static_cast<int64_t>(X_DIM);
    auto output_size = op_infer::array_to_small_vector(long_tensor.sizes());
    for (int64_t i = 0; i < batch_dim_num; i++) {
        output_size[i] = static_cast<int64_t>(batch_record[i]);
    }
    output_size[batch_dim_num] = x1_m_dim;
    output_size[out_dim_num - 1] = x2_n_dim;

    auto output_type = x.dtype();
    if (fused_op_type == "16cast32") {
        output_type = at::kFloat;
    }
    auto result = npu_preparation::apply_tensor_without_format(output_size, output_type);

    const at::Tensor &x3_real = x3.value_or(at::Tensor());
    const at::Tensor &bias_real = bias.value_or(at::Tensor());
    int8_t cube_math_type = op_plugin::utils::get_cube_math_type_with_passthrough();
    // Zero-initialised, and at most FUSED_TYPE_ARRAY_SIZE - 1 chars are copied (checked
    // above), so the buffer is always NUL-terminated.
    char fused_type[FUSED_TYPE_ARRAY_SIZE] = {0};
    fused_op_str.copy(fused_type, FUSED_TYPE_ARRAY_SIZE - 1);
    EXEC_NPU_CMD(aclnnFusedMatmul, x, x2, bias_real, x3_real, fused_type, cube_math_type, result);
    return result;
}
}