#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
constexpr int64_t INT8 = 2;
constexpr int64_t INT4 = 29;
constexpr int64_t MAX_LAST_DIM = 8192;
constexpr int64_t DIV_LAST_DIM = 4;
constexpr int64_t INT4_IN_INT8_NUM = 2;
}
std::tuple<at::Tensor, at::Tensor> npu_swiglu_quant(const at::Tensor& x, const c10::optional<at::Tensor>& smooth_scales,
const c10::optional<at::Tensor>& offsets, const c10::optional<at::Tensor>& group_index, bool activate_left,
int64_t quant_mode, int64_t group_list_type, c10::optional<at::ScalarType> dst_type)
{
TORCH_CHECK(quant_mode == 0 || quant_mode == 1, "quant_mode only support 0(static) or 1(dynamic), but got ",
quant_mode, OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(group_list_type == 0 || group_list_type == 1,
"group_list_type only support 0(cumsum) or 1(count), but got ",
group_list_type, OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(!dst_type.has_value() || dst_type == at::ScalarType::Char || dst_type == at::ScalarType::QUInt4x2,
"dtype must be torch.int8 for int8 or torch.quint4x2 for int4" + OPS_ERROR(ErrCode::TYPE));
if (!dst_type.has_value()) {
dst_type = at::ScalarType::Char;
}
const at::Tensor& smooth_scales_opt = c10::value_or_else(smooth_scales, [] { return at::Tensor(); });
const at::Tensor& offsets_opt = c10::value_or_else(offsets, [] { return at::Tensor(); });
TORCH_CHECK(quant_mode != 0 || smooth_scales_opt.sizes() == offsets_opt.sizes(),
"smooth_scales and offsets should have the same shape when quant_mode is 0",
OPS_ERROR(ErrCode::PARAM));
int64_t x_last_dim = x.size(x.dim() - 1);
TORCH_CHECK(x_last_dim <= MAX_LAST_DIM, "x last dim size should not be larger than ", MAX_LAST_DIM, ", but got ",
x_last_dim, OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(x_last_dim % 2 == 0, "x last dim size should be even", OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(dst_type != at::ScalarType::QUInt4x2 || x_last_dim % DIV_LAST_DIM == 0,
"x shape last dim must be divded by 4 when dst_type is torch.quint4x2, but got ",
x_last_dim, OPS_ERROR(ErrCode::PARAM));
at::SmallVector<int64_t, op_infer::SIZE> y_size;
at::SmallVector<int64_t, op_infer::SIZE> scale_size;
for (int i = 0; i < x.dim() - 1; i++) {
y_size.push_back(x.size(i));
scale_size.push_back(x.size(i));
}
auto last_dim = dst_type == at::ScalarType::Char ? x_last_dim / 2 : x_last_dim / 2 / INT4_IN_INT8_NUM;
y_size.push_back(last_dim);
at::Tensor y = npu_preparation::apply_tensor_without_format(y_size, c10::dtype(c10::ScalarType::Char));
at::Tensor scale = npu_preparation::apply_tensor_without_format(scale_size, c10::dtype(c10::ScalarType::Float));
std::string quant_mode_str = quant_mode == 0 ? "static" : "dynamic";
char* quant_mode_ptr = const_cast<char*>(quant_mode_str.c_str());
const at::Tensor& group_index_opt = c10::value_or_else(group_index, [] { return at::Tensor(); });
int output_type = dst_type == at::ScalarType::Char ? INT8 : INT4;
EXEC_NPU_CMD(aclnnSwiGluQuantV2, x, smooth_scales_opt, offsets_opt, group_index_opt, activate_left, quant_mode_ptr,
group_list_type, output_type, y, scale);
return std::tie(y, scale);
}
}