#include <vector>
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
namespace op_api {
// Distance from the end of a shape to its second-to-last dimension (rank - 2).
constexpr size_t LAST_SECOND_DIM_INDEX = 2;
constexpr int64_t PERGROUP_DIM_NUM = 2;
// Number of int4 values held by one int32 element.
constexpr int64_t INT4_NUMS_IN_INT32 = 8;
// Largest value accepted for each group_sizes entry; each entry is packed into 16 bits.
constexpr uint64_t GROUP_MAX = 65535UL;
// Number of entries group_sizes must contain when it is non-empty.
constexpr size_t A8W4_GROUP_DIM = 3;
// Presumably the expected rank of x1/x2 in A8W4 mode — TODO confirm against the kernel spec.
constexpr size_t A8W4_INPUT_DIM = 2;
using npu_preparation = at_npu::native::OpPreparation;
// Returns true when `tensor` (rank 2..6) is a dense tensor whose last two dimensions are stored
// column-major, i.e. the layout of a contiguous tensor after transposing its last two dims.
// Batch dimensions must be densely packed around that matrix. A 1x1 trailing matrix has no
// meaningful orientation and is reported as not transposed.
static bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    const int64_t rank = tensor.dim();
    if (rank < 2 || rank > 6) {
        return false;
    }
    const int64_t last = rank - 1;
    const int64_t second_last = rank - 2;

    // Column-major matrix: rows step by 1, columns step by the row count.
    const bool column_major = tensor.stride(second_last) == 1 &&
                              tensor.stride(last) == tensor.size(second_last);
    if (!column_major) {
        return false;
    }

    // Each batch dim must stride exactly over everything inside it (no gaps).
    int64_t expected_stride = tensor.size(last) * tensor.size(second_last);
    for (int64_t d = rank - 3; d >= 0; --d) {
        if (tensor.stride(d) != expected_stride) {
            return false;
        }
        expected_stride *= tensor.size(d);
    }

    const bool one_by_one = tensor.size(last) == 1 && tensor.size(second_last) == 1;
    return !one_by_one;
}
// Returns true when dimension `dim + 1` strides over a full run of dimension `dim`.
// That means `dim` varies fastest of the pair, i.e. the pair (dim, dim + 1) is stored transposed.
static bool is_transpose_certain_two_dims(const at::Tensor &tensor, int64_t dim)
{
    const int64_t span_of_dim = tensor.stride(dim) * tensor.size(dim);
    return span_of_dim == tensor.stride(dim + 1);
}
// Checks that the dimension pair (dim_x, dim_x + 1) of `x` and the pair (dim_scale, dim_scale + 1)
// of `scale` have the same orientation (both transposed or both not).
// The check is vacuously true when either tensor lacks such a pair, or when either pair is 1x1
// (a 1x1 pair has no meaningful orientation).
static bool is_x_scale_same_transpose(const at::Tensor &x, const at::Tensor &scale, int64_t dim_x, int64_t dim_scale)
{
    if (x.dim() < dim_x + 2 || scale.dim() < dim_scale + 2) {
        return true;
    }
    auto is_one_by_one = [](const at::Tensor &t, int64_t d) {
        return t.size(d) == 1 && t.size(d + 1) == 1;
    };
    if (is_one_by_one(x, dim_x) || is_one_by_one(scale, dim_scale)) {
        return true;
    }
    const bool x_transposed = is_transpose_certain_two_dims(x, dim_x);
    const bool scale_transposed = is_transpose_certain_two_dims(scale, dim_scale);
    return x_transposed == scale_transposed;
}
// Returns true when the NPU storage format of `x2` is one of the FRACTAL_NZ variants
// (ACL_FORMAT_FRACTAL_NZ, ACL_FORMAT_FRACTAL_NZ_C0_4, ACL_FORMAT_FRACTAL_NZ_C0_16).
static bool is_nz_format(const at::Tensor& x2)
{
    const auto storage_format = torch_npu::NPUBridge::GetNpuStorageImpl(x2)->npu_desc_.npu_format_;
    switch (storage_format) {
        case ACL_FORMAT_FRACTAL_NZ:
        case ACL_FORMAT_FRACTAL_NZ_C0_4:
        case ACL_FORMAT_FRACTAL_NZ_C0_16:
            return true;
        default:
            return false;
    }
}
// Validates x2's storage format and the broadcast compatibility of the batch dims of x1 and x2.
// Batch dims are every dim except the last two, right-aligned as in numpy-style broadcasting.
// Appends each broadcast batch extent to `batch_record` (existing contents are kept) and returns
// the product of those extents (1 when there are no batch dims).
// Throws (TORCH_CHECK) if:
//   - x2 is neither a base format nor an NZ format;
//   - both inputs have fewer than 2 dims;
//   - two batch dims are both > 1 and differ.
uint64_t infer_out_batch_shape(const at::Tensor &x1, const at::Tensor &x2, std::vector<uint64_t> &batch_record)
{
    TORCH_CHECK(at_npu::native::FormatHelper::IsBaseFormatType(x2) || is_nz_format(x2),
                "x2 should be in the original image format or nz format, but it is ",
                npu_preparation::get_tensor_npu_format(x2), OPS_ERROR(ErrCode::PARAM));
    const int64_t x1_dim_num = x1.dim();
    const int64_t x2_dim_num = x2.dim();
    const int64_t out_dim_num = std::max(x1_dim_num, x2_dim_num);
    // Keep rank arithmetic signed. Subtracting the size_t constant directly would wrap around to a
    // huge bound when out_dim_num < 2, and the loop would then index past the end of the shapes.
    const int64_t matrix_dim_num = static_cast<int64_t>(LAST_SECOND_DIM_INDEX);
    TORCH_CHECK(out_dim_num >= matrix_dim_num,
                "at least one of x1 and x2 should have 2 or more dims, but got x1 dim = ", x1_dim_num,
                " and x2 dim = ", x2_dim_num, OPS_ERROR(ErrCode::PARAM));
    const at::Tensor &shape_long = x1_dim_num > x2_dim_num ? x1 : x2;
    const at::Tensor &shape_short = x1_dim_num > x2_dim_num ? x2 : x1;
    // Number of leading dims the shorter shape lacks; those dims broadcast as size 1.
    const int64_t valid_offset = out_dim_num - std::min(x1_dim_num, x2_dim_num);
    const int64_t batch_dim_num = out_dim_num - matrix_dim_num;
    batch_record.reserve(batch_record.size() + static_cast<size_t>(batch_dim_num));
    uint64_t batch_val = 1;
    for (int64_t i = 0; i < batch_dim_num; i++) {
        const int64_t short_dim = i < valid_offset ? 1 : shape_short.size(i - valid_offset);
        const int64_t long_dim = shape_long.size(i);
        TORCH_CHECK(!(short_dim > 1 && long_dim > 1 && short_dim != long_dim),
                    "the x1 shape and x2 shape not supported for broadcast, the short_dim is ",
                    short_dim, " and the long_dim is ", long_dim, OPS_ERROR(ErrCode::PARAM));
        const uint64_t cur_batch_value = static_cast<uint64_t>(std::max(short_dim, long_dim));
        batch_val *= cur_batch_value;
        batch_record.push_back(cur_batch_value);
    }
    return batch_val;
}
// Validates `group_size_list` and packs it into the single int64 `group_size` value that
// npu_quant_matmul forwards to the aclnn kernels.
// - Returns 0 when the list is empty.
// - Otherwise the list must hold exactly A8W4_GROUP_DIM entries [group_m, group_n, group_k],
//   each in [0, GROUP_MAX]. They are packed as (group_m << 32) | (group_n << 16) | group_k.
// x1, x2, scale and pertoken_scale are not read here; they are kept for interface compatibility.
// Throws (TORCH_CHECK) on a wrong entry count or an out-of-range entry.
int64_t check_and_get_groups(
    at::IntArrayRef group_size_list,
    const at::Tensor& x1,
    const at::Tensor& x2,
    const at::Tensor& scale,
    const c10::optional<at::Tensor>& pertoken_scale)
{
    if (group_size_list.empty()) {
        return 0;
    }
    const size_t group_dim = group_size_list.size();
    TORCH_CHECK(group_dim == A8W4_GROUP_DIM, "group_sizes only support input with three elements, but got ",
                group_dim, OPS_ERROR(ErrCode::PARAM));
    const int64_t group_m = static_cast<int64_t>(group_size_list[0]);
    const int64_t group_n = static_cast<int64_t>(group_size_list[1]);
    const int64_t group_k = static_cast<int64_t>(group_size_list[2]);
    // Compare in signed arithmetic; GROUP_MAX is unsigned and would otherwise force a
    // signed/unsigned comparison.
    const int64_t group_max = static_cast<int64_t>(GROUP_MAX);
    auto in_range = [group_max](int64_t value) { return value >= 0 && value <= group_max; };
    const bool is_valid_group_param = in_range(group_m) && in_range(group_n) && in_range(group_k);
    TORCH_CHECK(is_valid_group_param, "group param value must conform to range [0, 65535]", OPS_ERROR(ErrCode::VALUE));
    // Each entry fits in 16 bits after the range check, so the shifted fields never overlap.
    return static_cast<int64_t>((static_cast<uint64_t>(group_m) << 32) |
                                (static_cast<uint64_t>(group_n) << 16) |
                                static_cast<uint64_t>(group_k));
}
// Quantized matmul of x1 and x2 on NPU.
// - Output shape is inferred per mode:
//   - A8W4 (int8 x1 / int32 x2, or float8_e4m3fn x1 / float32 x2): 2-D inputs only.
//   - mxfp4 (FLOAT4_E2M1 x1/x2): packed last dims are expanded by FP4_IN_INT8.
//   - Otherwise: batch dims broadcast via infer_out_batch_shape.
// - Output dtype is `output_dtype`, or int8 when absent.
// - When `scale` is float32, no pertoken_scale is given, and the output dtype allows it, scale/offset
//   are first converted with npu_trans_quant_param.
// - Dispatch: aclnnQuantMatmulWeightNz when x2 is in an NZ format (except A4W4 on the
//   non-trans-quant-param path); aclnnQuantMatmulV5 otherwise.
// - In A8W4-int mode `offset` is passed as y_offset; otherwise it is passed as x2_offset.
at::Tensor npu_quant_matmul(const at::Tensor &x1, const at::Tensor &x2, const at::Tensor &scale,
    const c10::optional<at::Tensor> &offset, const c10::optional<at::Tensor> &pertoken_scale,
    const c10::optional<at::Tensor> &bias, c10::optional<int64_t> output_dtype,
    c10::optional<int64_t> x1_dtype, c10::optional<int64_t> x2_dtype,
    c10::optional<int64_t> pertoken_scale_dtype, c10::optional<int64_t> scale_dtype,
    c10::OptionalIntArrayRef group_sizes, const c10::optional<at::Tensor> &y_scale)
{
    // x2's storage format selects the aclnn entry point and cannot change during this call, so query it once.
    const bool x2_is_nz = is_nz_format(x2);
    if (x2_is_nz) {
        static const bool is_quant_matmul_weight_nz_available = check_aclnn_kernel_available("aclnnQuantMatmulWeightNz");
        TORCH_CHECK(is_quant_matmul_weight_nz_available,
                    "Get aclnnQuantMatmulWeightNz or aclnnQuantMatmulWeightNzGetWorkspaceSize failed, only "
                    "aclnnQuantMatmulWeightNz support X2's format is nz, please upgrade CANN.",
                    OPS_ERROR(ErrCode::PARAM));
    } else {
        static const bool is_quant_matmul_v5_available = check_aclnn_kernel_available("aclnnQuantMatmulV5");
        TORCH_CHECK(is_quant_matmul_v5_available,
                    "Get aclnnQuantMatmulV5 or aclnnQuantMatmulV5GetWorkspaceSize failed, only "
                    "aclnnQuantMatmulV5 support A8W4, please upgrade CANN.",
                    OPS_ERROR(ErrCode::TYPE));
    }
    bool is_a8W4_int = x1.dtype() == at::kChar && x2.dtype() == at::kInt;
    bool is_a8W4_float = x1.dtype() == at::kFloat8_e4m3fn && x2.dtype() == at::kFloat;
    at::IntArrayRef group_size_list = group_sizes.value_or(at::IntArrayRef{});
    int64_t group_size = check_and_get_groups(group_size_list, x1, x2, scale, pertoken_scale);
    bool is_a4w4 = x1.dtype() == at::kInt && x2.dtype() == at::kInt;
    bool trans_x1 = is_transpose_last_two_dims(x1);
    bool trans_x2 = is_transpose_last_two_dims(x2);
    const int64_t x1_dim_num = x1.dim();
    const int64_t x2_dim_num = x2.dim();
    // Signed copy of LAST_SECOND_DIM_INDEX so rank arithmetic below never mixes int64_t with size_t.
    const int64_t last_second_dim = static_cast<int64_t>(LAST_SECOND_DIM_INDEX);
    // For A4W4 with non-transposed x2, the last dim holds INT4_NUMS_IN_INT32 int4 values per int32 element.
    auto x2_n_dim = (is_a4w4 && !trans_x2) ? x2.size(x2_dim_num - 1) * INT4_NUMS_IN_INT32 : x2.size(x2_dim_num - 1);
#if VERSION_BETWEEN(V2R1, V2R7)
    bool mxfp4_valid = x1_dtype.has_value() && x2_dtype.has_value() &&
        x1_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) &&
        x2_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1);
#endif
#if VERSION_BETWEEN(V2R8, VERSION_NEWEST)
    bool mxfp4_valid = false;
    if (x1_dtype.has_value()) {
        mxfp4_valid = x1_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1);
    } else {
        mxfp4_valid = x1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2;
    }
    if (x2_dtype.has_value()) {
        mxfp4_valid = mxfp4_valid && x2_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1);
    } else {
        mxfp4_valid = mxfp4_valid && x2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2;
    }
#endif
    c10::SmallVector<int64_t, SIZE> output_size;
    if (is_a8W4_int || is_a8W4_float) {
        // The A8W4 shapes below index sizes()[0] and sizes()[1] directly. ArrayRef::operator[] does no
        // bounds checking, so reject other ranks up front instead of reading out of range.
        TORCH_CHECK(x1_dim_num == static_cast<int64_t>(A8W4_INPUT_DIM) &&
                    x2_dim_num == static_cast<int64_t>(A8W4_INPUT_DIM),
                    "A8W4 only supports 2-D x1 and x2, but got x1 dim = ", x1_dim_num,
                    " and x2 dim = ", x2_dim_num, OPS_ERROR(ErrCode::PARAM));
    }
    if (is_a8W4_int) {
        output_size = {x1.sizes()[0], x2.sizes()[1] * INT4_NUMS_IN_INT32};
    } else if (is_a8W4_float) {
        if (trans_x2) {
            output_size = {x1.sizes()[0], x2.sizes()[1]};
        } else {
            output_size = {x1.sizes()[0], x2.sizes()[1] * INT4_NUMS_IN_INT32};
        }
    } else {
        // Validates broadcasting and records the broadcast batch extents.
        std::vector<uint64_t> batch_record;
        infer_out_batch_shape(x1, x2, batch_record);
        const at::Tensor &long_tensor = x1_dim_num > x2_dim_num ? x1 : x2;
        const int64_t out_dim_num = long_tensor.dim();
        output_size = op_infer::array_to_small_vector(long_tensor.sizes());
        if (mxfp4_valid) {
            TORCH_CHECK(x1.dim() >= 2 && x1.dim() <= 6,
                        "x1 dim num should be 2 ~ 6, please check x1 dim num. Actual x1 dim = ", x1.dim(),
                        OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(x2.dim() >= 2 && x2.dim() <= 6,
                        "x2 dim num should be 2 ~ 6, please check x2 dim num. Actual x2 dim = ", x2.dim(),
                        OPS_ERROR(ErrCode::PARAM));
            int64_t x1_size_last_second = x1.size(x1_dim_num - last_second_dim);
            int64_t x2_size_last = x2.size(x2_dim_num - 1);
            // The packed fp4 dim is the one stored innermost; expand it by FP4_IN_INT8.
            int64_t real_m = !trans_x1 ? x1_size_last_second : x1_size_last_second * FP4_IN_INT8;
            int64_t real_n = trans_x2 ? x2_size_last : x2_size_last * FP4_IN_INT8;
            output_size[out_dim_num - last_second_dim] = real_m;
            output_size[out_dim_num - 1] = real_n;
        } else {
            output_size[out_dim_num - last_second_dim] = x1.size(x1_dim_num - last_second_dim);
            output_size[out_dim_num - 1] = x2_n_dim;
        }
        // Leading dims take the broadcast batch extents computed by infer_out_batch_shape.
        for (int64_t i = 0; i < out_dim_num - last_second_dim; i++) {
            output_size[i] = static_cast<int64_t>(batch_record[i]);
        }
    }
    c10::TensorOptions options;
    aclDataType output_acltype = ACL_INT8;
    if (!output_dtype.has_value()) {
        options = x1.options().dtype(at::kChar);
    } else {
        output_acltype = c10_npu::GetAclDataType(output_dtype.value());
        options = x1.options().dtype(npu_preparation::convert_to_scalar_type(output_acltype));
    }
    at::Tensor result = npu_preparation::apply_tensor_without_format(output_size, options);
    const at::Tensor &offset_real = offset.value_or(at::Tensor());
    const at::Tensor &pertoken_scale_real = pertoken_scale.value_or(at::Tensor());
    const at::Tensor &bias_real = bias.value_or(at::Tensor());
    bool transpose1 = false;
    bool transpose2 = false;
    TensorWrapper x1_wrapper = make_wrapper(x1, x1_dtype);
    TensorWrapper x2_wrapper = make_wrapper(x2, x2_dtype);
    TensorWrapper x1_scale_wrapper = make_wrapper(pertoken_scale_real, pertoken_scale_dtype);
    TensorWrapper x2_scale_wrapper = make_wrapper(scale, scale_dtype);
    TensorWrapper result_wrapper = make_wrapper(result, output_dtype);
    at::Tensor x1_offset = at::empty({0}, options);
    at::Tensor x2_offset = at::Tensor();
    at::Tensor y_offset = at::empty({0}, options);
    // In A8W4-int mode the user offset is applied to the output; otherwise it is x2's offset.
    if (is_a8W4_int) {
        y_offset = offset_real;
    } else {
        x2_offset = offset_real;
    }
    bool use_aclnn_v5 = x1_dtype.has_value() || (x1.dtype() != at::kInt && x1.dtype() != at::kChar) ||
        is_a8W4_float || is_a8W4_int;
    aclDataType pertoken_scale_dtype_real = pertoken_scale_dtype.has_value()
        ? c10_npu::GetAclDataType(pertoken_scale_dtype.value())
        : (pertoken_scale.has_value()
            ? c10_npu::GetAclDataType(static_cast<int64_t>(pertoken_scale_real.scalar_type()))
            : aclDataType::ACL_INT8);
    bool need_check_trans = pertoken_scale.has_value()
        && (((pertoken_scale_real.dim() == x1.dim() && scale.dim() == x2.dim())
            || pertoken_scale_dtype_real == aclDataType::ACL_FLOAT8_E8M0)
            && (pertoken_scale_real.dim() >= 2 && scale.dim() >= 2))
        && !(is_a8W4_float || is_a8W4_int);
    if (need_check_trans) {
        int64_t dim_x1 = x1.dim() - 2;
        int64_t dim_x2 = x2.dim() - 2;
        // For E8M0 scales the leading two dims are compared; otherwise the trailing two.
        int64_t dim_x1_scale = 0;
        int64_t dim_x2_scale = 0;
        if (pertoken_scale_dtype_real != aclDataType::ACL_FLOAT8_E8M0) {
            dim_x1_scale = pertoken_scale_real.dim() - 2;
            dim_x2_scale = scale.dim() - 2;
        }
        TORCH_CHECK(is_x_scale_same_transpose(x1, pertoken_scale_real, dim_x1, dim_x1_scale),
                    "Input x1 tensor and pertoken_scale tensor's transpose are not same, please check input.",
                    OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(is_x_scale_same_transpose(x2, scale, dim_x2, dim_x2_scale),
                    "Input x2 tensor and scale tensor's transpose are not same, please check input.",
                    OPS_ERROR(ErrCode::PARAM));
    }
    bool use_trans_quant_param = scale.dtype() == at::kFloat && !pertoken_scale.has_value() &&
        (output_acltype != ACL_BF16 || use_aclnn_v5) && output_acltype != ACL_INT32;
    if (use_trans_quant_param) {
        const at::Tensor quant_param = op_api::npu_trans_quant_param(scale, offset);
        if (x2_is_nz) {
            EXEC_NPU_CMD(aclnnQuantMatmulWeightNz, x1_wrapper, x2_wrapper, pertoken_scale_real, quant_param, y_scale,
                x1_offset, x2_offset, y_offset, bias_real, transpose1, transpose2, group_size, result_wrapper);
        } else {
            EXEC_NPU_CMD(aclnnQuantMatmulV5, x1_wrapper, x2_wrapper, pertoken_scale_real, quant_param, y_scale,
                x1_offset, x2_offset, y_offset, bias_real, transpose1, transpose2, group_size, result_wrapper);
        }
    } else {
        if (!is_a4w4 && x2_is_nz) {
            EXEC_NPU_CMD(aclnnQuantMatmulWeightNz, x1_wrapper, x2_wrapper, x1_scale_wrapper, x2_scale_wrapper, y_scale,
                x1_offset, x2_offset, y_offset, bias_real, transpose1, transpose2, group_size, result_wrapper);
        } else {
            EXEC_NPU_CMD(aclnnQuantMatmulV5, x1_wrapper, x2_wrapper, x1_scale_wrapper, x2_scale_wrapper, y_scale,
                x1_offset, x2_offset, y_offset, bias_real, transpose1, transpose2, group_size, result_wrapper);
        }
    }
    return result;
}
}