#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/AclOpsInterface.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// SwiGLU halves the activated dimension; FP4 outputs are stored two-per-byte (halved last dimension).
const int64_t NUM_TWO = 2;
// Number of elements along the quant axis covered by one row of the scale tensor. The scale tensor also
// carries a trailing dimension of NUM_TWO, so presumably each row holds two 32-element MX block scales
// — TODO confirm against the aclnnSwigluMxQuant documentation.
constexpr int64_t DEFAULT_BLOCKSIZE = 64;

/**
 * @brief Fused SwiGLU activation + MX quantization on NPU, dispatched to aclnnSwigluMxQuant.
 *
 * @param x               Input tensor; must have more than one dimension.
 * @param group_index     Optional group index tensor. When present and the quant axis is not the last
 *                        dimension, its first dimension (presumably the number of groups — TODO confirm)
 *                        is added to the scale tensor's extent along the quant axis.
 * @param activate_dim    Dimension halved by SwiGLU; negative values count from the end.
 * @param activate_left   Forwarded to the kernel unchanged.
 * @param swiglu_mode     Must be 0 or 1; forwarded unchanged.
 * @param clamp_limit     Must be positive and finite; forwarded unchanged.
 * @param glu_alpha       Must be finite; forwarded unchanged.
 * @param glu_bias        Must be finite; forwarded unchanged.
 * @param group_mode      Forwarded unchanged.
 * @param axis            Quantization axis; negative values count from the end. It is range-checked in
 *                        normalized form but forwarded to the kernel exactly as given.
 * @param dst_type        c10_npu::DType code of y. FLOAT4_E2M1 / FLOAT4_E1M2 results are stored in a Byte
 *                        tensor whose last dimension is halved (it must therefore be even).
 * @param round_mode      Forwarded to the kernel as a C string.
 * @param scale_alg       Forwarded unchanged.
 * @param max_dtype_value Forwarded unchanged.
 * @return (y, scale): the quantized output and a Byte tensor presented to the kernel as ACL_FLOAT8_E8M0.
 * @throws c10::Error on invalid arguments or when the installed CANN lacks aclnnSwigluMxQuant.
 */
std::tuple<at::Tensor, at::Tensor> npu_swiglu_mx_quant(
    const at::Tensor& x, const c10::optional<at::Tensor>& group_index,
    int64_t activate_dim, bool activate_left, int64_t swiglu_mode,
    double clamp_limit, double glu_alpha, double glu_bias,
    int64_t group_mode, int64_t axis, int64_t dst_type,
    c10::string_view round_mode, int64_t scale_alg, double max_dtype_value)
{
    TORCH_CHECK(x.dim() > 1, "x dim should larger than 1", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(swiglu_mode == 0 || swiglu_mode == 1, "swiglu_mode only support 0 or 1, but got ", swiglu_mode,
        OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(std::isfinite(clamp_limit) && clamp_limit > 0.0, "clamp_limit should be positive finite",
        OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(std::isfinite(glu_alpha), "glu_alpha should be finite", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(std::isfinite(glu_bias), "glu_bias should be finite", OPS_ERROR(ErrCode::PARAM));
    // Probed once (function-local static): the kernel may be missing from older CANN installs.
    static const bool is_available = check_aclnn_kernel_available("aclnnSwigluMxQuant");
    TORCH_CHECK(is_available,
        "Current CANN version do not support this api. Please try to update the version of CANN."
        + OPS_ERROR(ErrCode::PARAM));

    // An absent group_index is forwarded as an undefined tensor.
    const at::Tensor& group_index_opt = c10::value_or_else(group_index, [] { return at::Tensor(); });
    // NOTE(review): c10::string_view does not guarantee NUL termination; this relies on the caller's
    // storage for round_mode being a NUL-terminated string that stays alive through the kernel call.
    char *round_mode_ptr = const_cast<char *>(round_mode.data());

    const int64_t x_dim = x.dim();
    int64_t activate_dim_value = activate_dim < 0 ? activate_dim + x_dim : activate_dim;
    TORCH_CHECK(activate_dim_value <= (x_dim - 1) && activate_dim_value >= 0, "activate_dim should be in range [0, x.dim()-1]", OPS_ERROR(ErrCode::PARAM));
    const int64_t quant_dim_value = axis < 0 ? axis + x_dim : axis;
    TORCH_CHECK(quant_dim_value >= 0 && quant_dim_value <= (x_dim - 1), "quant_dim should be in range [0, x.dim()-1]", OPS_ERROR(ErrCode::PARAM));

    // y has x's shape except that SwiGLU halves the activated dimension; scale starts from the same shape.
    at::SmallVector<int64_t, op_infer::SIZE> y_size;
    for (int64_t i = 0; i < x_dim; i++) {
        y_size.push_back(i == activate_dim_value ? x.size(i) / NUM_TWO : x.size(i));
    }
    at::SmallVector<int64_t, op_infer::SIZE> scale_size(y_size);

    // Number of scale rows along the quant axis. Sizes are non-negative, so integer division is exact.
    const int64_t quant_len = scale_size[quant_dim_value];
    int64_t quant_size = 0;
    if (group_index_opt.defined() && quant_dim_value != (x_dim - 1)) {
        // floor(len / block) + num_groups is an upper bound on the sum of per-group ceil(len_g / block),
        // presumably so every group can end in its own partial block.
        TORCH_CHECK(group_index_opt.dim() >= 1, "group_index should have at least 1 dimension",
            OPS_ERROR(ErrCode::PARAM));
        quant_size = quant_len / DEFAULT_BLOCKSIZE + group_index_opt.size(0);
    } else {
        // Ceil-division: a trailing partial block still needs its own scale row.
        quant_size = (quant_len + DEFAULT_BLOCKSIZE - 1) / DEFAULT_BLOCKSIZE;
    }
    scale_size[quant_dim_value] = quant_size;
    scale_size.push_back(NUM_TWO);

    const bool is_fp4 = dst_type == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
        dst_type == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2);
    aclDataType y_acltype = c10_npu::GetAclDataType(dst_type);
    at::Tensor y;
    if (is_fp4) {
        // FP4 results are packed two per byte, so the stored last dimension is halved and must be even.
        int64_t last_dim_val = y_size[x_dim - 1];
        TORCH_CHECK(last_dim_val % NUM_TWO == 0, "Y last dim should be even when type of y is float4_e1m2 or float4_e2m1", OPS_ERROR(ErrCode::PARAM));
        y_size[x_dim - 1] = last_dim_val / NUM_TWO;
        y = npu_preparation::apply_tensor_without_format(y_size, c10::ScalarType::Byte);
    } else {
        at::ScalarType scalar_dtype = npu_preparation::convert_to_scalar_type(y_acltype);
        y = npu_preparation::apply_tensor_without_format(y_size, c10::dtype(scalar_dtype));
    }
    TensorWrapper y_wrapper = {y, y_acltype};
    // Scales are allocated as raw bytes and tagged FP8 E8M0 for the kernel.
    at::Tensor scale = npu_preparation::apply_tensor_without_format(scale_size, c10::dtype(c10::ScalarType::Byte));
    TensorWrapper mxscale_wrapper = {scale, aclDataType::ACL_FLOAT8_E8M0};
    // NOTE(review): activate_dim is forwarded normalized but axis is forwarded as given — confirm the
    // kernel accepts negative axis values before normalizing it here.
    EXEC_NPU_CMD(aclnnSwigluMxQuant, x, group_index_opt, activate_dim_value, activate_left,
        swiglu_mode, clamp_limit, glu_alpha, glu_bias, group_mode, axis,
        y_acltype, round_mode_ptr, scale_alg, max_dtype_value, y_wrapper, mxscale_wrapper);
    return std::tie(y, scale);
}
}