#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Shorthand for the OpPreparation helpers used throughout this file
// (tensor allocation and torch <-> ACL dtype conversion).
using npu_preparation = at_npu::native::OpPreparation;
// Small integer literals used in size/count comparisons and divisions.
constexpr int64_t NUM_ONE = 1;
constexpr int64_t NUM_TWO = 2;

// Dimension indices / tensor ranks.
constexpr int64_t DIM_0 = 0;
constexpr int64_t DIM_1 = 1;
constexpr int64_t DIM_2 = 2;
constexpr int64_t DIM_3 = 3;
constexpr int64_t DIM_4 = 4;
constexpr int64_t DIM_5 = 5;

// Divisors used when deriving the output and output-scale shapes from n.
constexpr int64_t MXFP_MULTI_BASE_SIZE = 2;
constexpr int64_t MXFP_DIVISOR_SIZE = 64;

// Data type codes compared against c10_npu::GetAclDataType() results below.
constexpr int64_t HIFLOAT8 = 34;
constexpr int64_t FLOAT8_E5M2 = 35;
constexpr int64_t FLOAT8_E4M3FN = 36;
/**
 * Allocates a 2-D tensor of shape {dim_m, dim_n} with the given options
 * (no explicit NPU format) and stores it in `y`.
 *
 * @param y       receives the newly allocated tensor (previous value is replaced)
 * @param dim_m   size of the first dimension
 * @param dim_n   size of the second dimension
 * @param options tensor options forwarded to the allocator
 */
void create_new_tensor(at::Tensor &y, size_t dim_m, size_t dim_n, c10::TensorOptions options)
{
    y = npu_preparation::apply_tensor_without_format(op_infer::array_to_small_vector({dim_m, dim_n}), options);
}
/**
 * Allocates a 3-D tensor of shape {batch, dim_m, dim_n} with the given options
 * (no explicit NPU format) and stores it in `y`.
 *
 * @param y       receives the newly allocated tensor (previous value is replaced)
 * @param batch   size of the first dimension
 * @param dim_m   size of the second dimension
 * @param dim_n   size of the third dimension
 * @param options tensor options forwarded to the allocator
 */
void create_new_tensor_batch(at::Tensor &y, size_t batch, size_t dim_m, size_t dim_n,
                             const c10::TensorOptions &options)
{
    y = npu_preparation::apply_tensor_without_format(
        op_infer::array_to_small_vector({batch, dim_m, dim_n}), options);
}
/**
 * Checks, from strides alone, whether the last two dimensions of `tensor`
 * are stored in swapped order while all leading dimensions are densely packed.
 *
 * Returns true only when all of the following hold:
 *  - the tensor has at least two dimensions;
 *  - the stride of the second-to-last dimension is 1;
 *  - the stride of the last dimension equals the size of the second-to-last one;
 *  - every leading dimension's stride equals the product of the sizes of all
 *    dimensions that follow it.
 */
bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    const int64_t rank = tensor.dim();
    if (rank < DIM_2) {
        return false;
    }

    const auto shape = tensor.sizes();
    const auto stride = tensor.strides();
    const int64_t col_axis = rank - DIM_1;
    const int64_t row_axis = rank - DIM_2;

    const bool inner_swapped = (stride[row_axis] == NUM_ONE) && (stride[col_axis] == shape[row_axis]);
    if (!inner_swapped) {
        return false;
    }

    // Walk the leading dimensions from innermost to outermost, accumulating
    // the stride a densely packed layout would have at each step.
    int64_t dense_stride = shape[col_axis] * shape[row_axis];
    int64_t axis = rank - DIM_3;
    while (axis >= DIM_0) {
        if (stride[axis] != dense_stride) {
            return false;
        }
        dense_stride *= shape[axis];
        --axis;
    }
    return true;
}
/**
 * Derives n from the shape of weight_scale[0].
 *
 * @param weight_scale the scale tensor; must have at least 2 dimensions
 * @param is_mx_quant  when true the tensor must be exactly 4-D and the size of
 *                     dimension 2 is returned; otherwise the size of the last
 *                     dimension is returned
 * @return the selected dimension size
 * Fails a TORCH_CHECK when the dimension requirements are not met.
 */
int64_t infer_nz_logical_n(const at::Tensor &weight_scale, bool is_mx_quant)
{
    TORCH_CHECK(weight_scale.dim() >= DIM_2, "The dim of weight_scale[0] should be greater than or equal to 2, "
                "but got ", weight_scale.dim(), OPS_ERROR(ErrCode::PARAM));

    // Non-MX case: n is simply the trailing dimension.
    if (!is_mx_quant) {
        return weight_scale.sizes().back();
    }

    TORCH_CHECK(weight_scale.dim() == DIM_4, "The dim of weight_scale[0] should be equal to 4 in MX quant mode, "
                "but got ", weight_scale.dim(), OPS_ERROR(ErrCode::PARAM));
    return weight_scale.size(DIM_2);
}
/**
 * Validates the inputs, allocates the `output` / `output_scale` tensors and
 * dispatches to aclnnGroupedMatmulSwigluQuantWeightNzV2 (weight treated as NZ)
 * or aclnnGroupedMatmulSwigluQuantV2 (otherwise).
 *
 * @param x                    input tensor, at least 2-D; size(0) is used as m
 * @param weight               list holding exactly one weight tensor
 * @param weight_scale         list holding exactly one, non-scalar scale tensor
 * @param x_scale              forwarded to the kernel
 * @param group_list           forwarded to the kernel
 * @param smooth_scale         optional; an undefined tensor is passed when absent
 * @param weight_assist_matrix optional; an empty list is passed when absent
 * @param bias                 optional; an undefined tensor is passed when absent
 * @param dequant_mode         defaults to 0
 * @param dequant_dtype        defaults to 0; translated to the kernel's dtype code
 * @param quant_mode           defaults to 0
 * @param quant_dtype          selects the dtype of `output`; when absent the
 *                             non-MX path allocates int8 and the MX path uses x's dtype
 * @param group_list_type      defaults to 0
 * @param tuning_config        optional; an empty array is passed when absent
 * @param x_dtype              optional ACL dtype override for x
 * @param weight_dtype         optional ACL dtype override for weight
 * @param weight_scale_dtype   optional ACL dtype override for weight_scale; its
 *                             presence selects the MX quantization path
 * @param x_scale_dtype        optional ACL dtype override for x_scale
 * @return (output, output_scale)
 *
 * Fails a TORCH_CHECK on invalid list sizes, ranks or dtype selectors, or when
 * the NZ kernel is required but not available.
 */
std::tuple<at::Tensor, at::Tensor> npu_grouped_matmul_swiglu_quant_v2(
    const at::Tensor & x,
    const at::TensorList weight,
    const at::TensorList weight_scale,
    const at::Tensor & x_scale,
    const at::Tensor & group_list,
    const c10::optional<at::Tensor> & smooth_scale,
    const c10::optional<at::TensorList> weight_assist_matrix,
    const c10::optional<at::Tensor> & bias,
    c10::optional<int64_t> dequant_mode,
    c10::optional<int64_t> dequant_dtype,
    c10::optional<int64_t> quant_mode,
    c10::optional<int64_t> quant_dtype,
    c10::optional<int64_t> group_list_type,
    const c10::OptionalIntArrayRef tuning_config,
    c10::optional<int64_t> x_dtype,
    c10::optional<int64_t> weight_dtype,
    c10::optional<int64_t> weight_scale_dtype,
    c10::optional<int64_t> x_scale_dtype)
{
    // ---- Structural validation -------------------------------------------------
    TORCH_CHECK(weight.size() == NUM_ONE, "The size of weight should be 1, current size is ", weight.size(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(weight_scale.size() == NUM_ONE, "The size of weight_scale should be 1, current size is ",
                weight_scale.size(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(x.dim() >= DIM_2, "The x dim should be greater than or equal to 2, but the actual value is ", x.dim(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(!weight_scale[DIM_0].sizes().empty(), "The weight_scale[0] is empty.", OPS_ERROR(ErrCode::PARAM));

    // A weight is handled by the NZ kernel when its storage format says so or when it is 5-D.
    const auto weight_format = at_npu::native::custom_ops::get_npu_format(weight[DIM_0]);
    const bool is_weight_nz = weight_format == ACL_FORMAT_FRACTAL_NZ ||
                              weight_format == ACL_FORMAT_FRACTAL_NZ_C0_16 ||
                              weight[DIM_0].dim() == DIM_5;
    const bool weight_trans = is_transpose_last_two_dims(weight[DIM_0]);
    const bool is_mx_quant = weight_scale_dtype.has_value();
    const bool is_5d_nz = is_weight_nz && (weight[DIM_0].dim() == DIM_5);
    const bool is_ascend950_or_later = c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950;

    // ---- Problem sizes ---------------------------------------------------------
    int n = 0;
    if (is_5d_nz) {
        n = static_cast<int>(infer_nz_logical_n(weight_scale[DIM_0], is_mx_quant));
    } else if (is_ascend950_or_later) {
        // sizes()[DIM_2] is unchecked indexing, so the rank must be validated first.
        TORCH_CHECK(weight[DIM_0].dim() >= DIM_3,
                    "weight[0] dim should be greater than or equal to 3, but the actual value is ",
                    weight[DIM_0].dim(), OPS_ERROR(ErrCode::PARAM));
        n = static_cast<int>(weight[DIM_0].sizes()[DIM_2]);
    } else {
        n = static_cast<int>(weight_scale[DIM_0].sizes().back());
    }
    auto x_size = x.sizes();
    int m = x_size[DIM_0];

    // ---- Dtype selector validation ---------------------------------------------
    const bool mxfp8w4_nz_input = is_weight_nz &&
        x.scalar_type() == at::kFloat8_e4m3fn &&
        weight_dtype.has_value() &&
        weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1);
    if (x_dtype.has_value()) {
        TORCH_CHECK(x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2)
                    || x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)
                    || x_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
                    "The optional parameter x_dtype only supports torch_npu.float4_e2m1fn_x2, torch_npu.float4_e1m2fn_x2, torch_npu.hifloat8, or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(x_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (weight_dtype.has_value()) {
        TORCH_CHECK(weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2)
                    || weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)
                    || weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
                    "The optional parameter weight_dtype only supports torch_npu.float4_e2m1fn_x2, torch_npu.float4_e1m2fn_x2, torch_npu.hifloat8, or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(weight_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (!mxfp8w4_nz_input) {
        // Outside the fp8-activation / fp4-NZ-weight case, x_dtype and weight_dtype
        // must be given together or not at all.
        TORCH_CHECK(
            (x_dtype.has_value() && weight_dtype.has_value()) || (!x_dtype.has_value() && !weight_dtype.has_value()),
            "The optional parameter x_dtype and weight_dtype should both be torch_npu.float4_e2m1fn_x2, torch_npu.float4_e1m2fn_x2, "
            "torch_npu.hifloat8, or None.",
            OPS_ERROR(ErrCode::VALUE));
    }
    if (weight_scale_dtype.has_value()) {
        TORCH_CHECK(weight_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
                    "The optional parameter weight_scale_dtype only supports float8_e8m0fnu or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(weight_scale_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (x_scale_dtype.has_value()) {
        TORCH_CHECK(x_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
                    "The optional parameter x_scale_dtype only supports float8_e8m0fnu or None, but the actual value is ",
                    c10_npu::CustomDataTypeToString(x_scale_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }
    if (is_ascend950_or_later && dequant_dtype.has_value()) {
        TORCH_CHECK(dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::Float)
                    || dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::Char)
                    || dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::Half)
                    || dequant_dtype.value() == static_cast<int64_t>(c10::ScalarType::BFloat16),
                    "The optional parameter dequant_dtype only support torch.float32, torch.int8, torch.float16 and torch.bfloat16 ,but the actual value is ",
                    c10_npu::CustomDataTypeToString(dequant_dtype.value()), "." + OPS_ERROR(ErrCode::VALUE));
    }

    // ---- Resolve optional scalar / tensor arguments ----------------------------
    int64_t dequant_mode_real = dequant_mode.value_or(0);
    int64_t dequant_dtype_real = dequant_dtype.value_or(0);
    // torch ScalarType code -> kernel dtype code for Float(6), Half(5) and BFloat16(15).
    // Codes not listed are passed through unchanged.
    static const std::map<int64_t, int64_t> TorchToGeMap = {
        {6, 0},
        {5, 1},
        {15, 27}};
    const auto mapped_dequant = TorchToGeMap.find(dequant_dtype_real);
    if (mapped_dequant != TorchToGeMap.end()) {
        dequant_dtype_real = mapped_dequant->second;
    }
    int64_t quant_mode_real = quant_mode.value_or(0);
    int64_t group_list_type_real = group_list_type.value_or(0);
    auto weight_assist_matrix_real = weight_assist_matrix.value_or(at::TensorList());
    auto tuning_config_real = tuning_config.value_or(at::IntArrayRef{});
    auto bias_real = bias.value_or(at::Tensor());
    auto smooth_scale_real = smooth_scale.value_or(at::Tensor());

    const bool mxfp4_input = x_dtype.has_value() && weight_dtype.has_value() &&
        (x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2) ||
         x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)) &&
        (weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2) ||
         weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1));

    // quant_dtype is optional: never dereference it without checking. When it is
    // absent none of the special output dtypes is selected.
    const auto quant_acl_dtype_is = [&quant_dtype](int64_t acl_dtype) {
        return quant_dtype.has_value() && c10_npu::GetAclDataType(quant_dtype.value()) == acl_dtype;
    };
    const bool quant_to_fp8 = quant_acl_dtype_is(FLOAT8_E5M2) || quant_acl_dtype_is(FLOAT8_E4M3FN);
    const bool quant_to_hifloat8 = quant_acl_dtype_is(HIFLOAT8);

    // Options for `output`: the requested quant dtype, or x's dtype when none was given.
    // x.scalar_type() is read directly; indexing x first would run a select op and
    // fail when x has zero rows.
    const auto make_output_options = [&x, &quant_dtype]() -> c10::TensorOptions {
        return x.options().dtype(quant_dtype.has_value()
            ? npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(quant_dtype.value()))
            : x.scalar_type());
    };

    // ---- Allocate outputs ------------------------------------------------------
    at::Tensor output;
    at::Tensor output_scale;
    if (!is_mx_quant) {
        if (quant_to_fp8 || quant_to_hifloat8) {
            create_new_tensor(output, m, n / MXFP_MULTI_BASE_SIZE, make_output_options());
        } else {
            output = npu_preparation::apply_tensor_without_format({m, n / MXFP_MULTI_BASE_SIZE}, c10::dtype(c10::ScalarType::Char));
        }
        output_scale = npu_preparation::apply_tensor_without_format({m}, c10::dtype(c10::ScalarType::Float));
    } else {
        if (dequant_dtype.has_value()) {
            dequant_dtype_real = static_cast<int64_t>(c10_npu::GetAclDataType(dequant_dtype.value()));
        }
        TORCH_CHECK(!weight[DIM_0].sizes().empty(), "weight[0] is empty.", OPS_ERROR(ErrCode::PARAM));
        if (!is_weight_nz) {
            // n was already derived above for this (non-NZ) case; only the rank is checked here.
            TORCH_CHECK(weight[DIM_0].dim() == DIM_3, "weight[0] dim should be equal to 3, but the actual value is ",
                        weight[DIM_0].dim(), OPS_ERROR(ErrCode::PARAM));
        }
        const c10::TensorOptions options_output = make_output_options();
        const c10::TensorOptions options = x.options().dtype(npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(weight_scale_dtype.value())));

        if (mxfp4_input && !weight_trans) {
            // fp4 inputs with a non-transposed weight: column counts are scaled by FP4_IN_INT8.
            if (quant_to_fp8) {
                create_new_tensor(output, m, ((n / MXFP_MULTI_BASE_SIZE) * FP4_IN_INT8), options_output);
            } else {
                create_new_tensor(output, m, n / MXFP_MULTI_BASE_SIZE, options_output);
            }
            create_new_tensor_batch(output_scale, m, op_infer::CeilDiv(n * FP4_IN_INT8 / MXFP_MULTI_BASE_SIZE, MXFP_DIVISOR_SIZE),
                                    MXFP_MULTI_BASE_SIZE, options);
        } else {
            if (mxfp4_input && !quant_to_fp8) {
                // fp4 inputs, transposed weight, non-fp8 output: the column count is halved once more.
                create_new_tensor(output, m, n / MXFP_MULTI_BASE_SIZE / NUM_TWO, options_output);
            } else {
                create_new_tensor(output, m, n / MXFP_MULTI_BASE_SIZE, options_output);
            }
            create_new_tensor_batch(output_scale, m, op_infer::CeilDiv(n / MXFP_MULTI_BASE_SIZE, MXFP_DIVISOR_SIZE),
                                    MXFP_MULTI_BASE_SIZE, options);
        }
    }

    // ---- Wrap tensors with the ACL dtype the kernel should see -----------------
    TensorWrapper x_wrapper = {x,
        x_dtype.has_value() ? c10_npu::GetAclDataType(x_dtype.value())
                            : npu_preparation::convert_to_acl_data_type(x.scalar_type())};
    TensorListWrapper weight_wrapper = {weight,
        weight_dtype.has_value() ? c10_npu::GetAclDataType(weight_dtype.value())
                                 : npu_preparation::convert_to_acl_data_type(weight[0].scalar_type())};
    TensorListWrapper weight_scale_wrapper = {weight_scale,
        weight_scale_dtype.has_value() ? c10_npu::GetAclDataType(weight_scale_dtype.value())
                                       : (weight_scale.empty() ? aclDataType::ACL_FLOAT
                                       : npu_preparation::convert_to_acl_data_type(weight_scale[0].scalar_type()))};
    TensorWrapper x_scale_wrapper = {x_scale,
        x_scale_dtype.has_value() ? c10_npu::GetAclDataType(x_scale_dtype.value())
                                  : (!x_scale.numel() ? aclDataType::ACL_FLOAT
                                  : npu_preparation::convert_to_acl_data_type(x_scale.scalar_type()))};
    // Without an explicit quant_dtype, describe the buffer that was actually allocated
    // above instead of claiming float for a non-float tensor.
    TensorWrapper output_wrapper = {output,
        quant_dtype.has_value() ? c10_npu::GetAclDataType(quant_dtype.value())
                                : npu_preparation::convert_to_acl_data_type(output.scalar_type())};
    TensorWrapper output_scale_wrapper = {output_scale,
        weight_scale_dtype.has_value() ? aclDataType::ACL_FLOAT8_E8M0 : aclDataType::ACL_FLOAT};

    // ---- Dispatch --------------------------------------------------------------
    if (is_weight_nz) {
        static const bool is_weight_nz_available = check_aclnn_kernel_available("aclnnGroupedMatmulSwigluQuantWeightNzV2");
        TORCH_CHECK(is_weight_nz_available,
                    "Format of weight in npu_grouped_matmul is FRACTAL_NZ, current CANN version "
                    "do not support with this format. Please try to update the version of CANN."
                    + OPS_ERROR(ErrCode::PARAM));
        if (mxfp8w4_nz_input) {
            // x and output are passed unwrapped here, i.e. with their native tensor dtype.
            EXEC_NPU_CMD(
                aclnnGroupedMatmulSwigluQuantWeightNzV2,
                x,
                weight_wrapper,
                weight_scale_wrapper,
                weight_assist_matrix_real,
                bias_real,
                x_scale_wrapper,
                smooth_scale_real,
                group_list,
                dequant_mode_real,
                dequant_dtype_real,
                quant_mode_real,
                group_list_type_real,
                tuning_config_real,
                output,
                output_scale_wrapper);
        } else {
            at::Tensor weight_for_nz = weight[DIM_0];
            // The re-tagged copy is only consumed when weight_dtype is absent (see the
            // argument selection below), so skip the device copy otherwise.
            if (!weight_dtype.has_value() && weight_format != ACL_FORMAT_FRACTAL_NZ) {
                weight_for_nz = weight_for_nz.clone();
                auto &desc = torch_npu::NPUBridge::GetNpuStorageImpl(weight_for_nz)->npu_desc_;
                desc.npu_format_ = ACL_FORMAT_FRACTAL_NZ;
                desc.storage_sizes_ = op_infer::array_to_small_vector(weight_for_nz.sizes());
            }
            c10::SmallVector<at::Tensor, 1> weight_nz_vec = {weight_for_nz};
            at::TensorList weight_nz_list(weight_nz_vec);
            TensorListWrapper weight_nz_wrapper = {weight_nz_list,
                weight_dtype.has_value() ? c10_npu::GetAclDataType(weight_dtype.value())
                                         : npu_preparation::convert_to_acl_data_type(weight[0].scalar_type())};
            EXEC_NPU_CMD(
                aclnnGroupedMatmulSwigluQuantWeightNzV2,
                x_wrapper,
                weight_dtype.has_value() ? weight_wrapper : weight_nz_wrapper,
                weight_scale_wrapper,
                weight_assist_matrix_real,
                bias_real,
                x_scale_wrapper,
                smooth_scale_real,
                group_list,
                dequant_mode_real,
                dequant_dtype_real,
                quant_mode_real,
                group_list_type_real,
                tuning_config_real,
                output_wrapper,
                output_scale_wrapper);
        }
    } else {
        EXEC_NPU_CMD(
            aclnnGroupedMatmulSwigluQuantV2,
            x_wrapper,
            weight_wrapper,
            weight_scale_wrapper,
            weight_assist_matrix_real,
            bias_real,
            x_scale_wrapper,
            smooth_scale_real,
            group_list,
            dequant_mode_real,
            dequant_dtype_real,
            quant_mode_real,
            group_list_type_real,
            tuning_config_real,
            output_wrapper,
            output_scale_wrapper);
    }
    return std::tuple<at::Tensor, at::Tensor>(output, output_scale);
}
}