#include <torch/extension.h>
#include <torch_npu/csrc/framework/utils/RandomOpAdapter.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/include/ops.h>
#include "inc/aclnn_common.h"
// Shorthand alias for at_npu::native::OpPreparation.
// NOTE(review): not referenced in the visible portion of this file — confirm whether it is still needed.
using npu_preparation = at_npu::native::OpPreparation;
namespace op_infer {

// Inline capacity of the SmallVector returned by array_to_small_vector.
constexpr int SIZE = 8;

/// Copies every element of `shape`, in order, into a freshly built
/// c10::SmallVector<int64_t, SIZE> and returns it by value.
c10::SmallVector<int64_t, SIZE> array_to_small_vector(c10::IntArrayRef shape)
{
    return c10::SmallVector<int64_t, SIZE>(shape.begin(), shape.end());
}

} // namespace op_infer
/**
 * @brief Forward pass of the weight-quantized grouped matmul; dispatches to aclnnGroupedMatmulV4.
 *
 * Allocates a single output tensor of shape {x[0].size(0), weight[0].size(-1)} that shares
 * x[0]'s tensor options (dtype/device/layout), runs the aclnn kernel into it, and returns it
 * wrapped in a one-element vector.
 *
 * @param x                 Input tensor list; must be non-empty and x[0] must have >= 1 dim.
 * @param weight            Weight tensor list; must be non-empty and weight[0] must have >= 1 dim.
 * @param antiquant_scale   Forwarded unchanged to the kernel as its antiquant-scale list.
 * @param antiquant_offset  Forwarded unchanged to the kernel as its antiquant-offset list.
 * @param bias              Forwarded unchanged to the kernel as its bias list.
 * @param group_list        Optional group-list tensor; an undefined tensor is passed when absent.
 * @param group_list_type   Optional group-list type; defaults to 0 when absent.
 * @param act_type          Optional activation type; defaults to 0 when absent.
 * @return One-element vector holding the output tensor.
 * @throws c10::Error (via TORCH_CHECK) if x or weight is empty, or their first tensor is 0-dim.
 */
std::vector<at::Tensor> npu_weight_quant_gmm(const std::vector<at::Tensor>& x,
                                             const std::vector<at::Tensor>& weight,
                                             const std::vector<at::Tensor>& antiquant_scale,
                                             const std::vector<at::Tensor>& antiquant_offset,
                                             const std::vector<at::Tensor>& bias,
                                             const c10::optional<at::Tensor>& group_list,
                                             c10::optional<int64_t> group_list_type,
                                             c10::optional<int64_t> act_type)
{
    // Validate before any indexing: x[0] / weight[0] on an empty vector, or sizes()[...]
    // on a 0-dim tensor, would be undefined behaviour rather than a reportable error.
    TORCH_CHECK(!x.empty(), "npu_weight_quant_gmm: x must contain at least one tensor");
    TORCH_CHECK(!weight.empty(), "npu_weight_quant_gmm: weight must contain at least one tensor");
    TORCH_CHECK(x[0].dim() >= 1, "npu_weight_quant_gmm: x[0] must have at least 1 dimension");
    TORCH_CHECK(weight[0].dim() >= 1, "npu_weight_quant_gmm: weight[0] must have at least 1 dimension");

    auto group_list_real = group_list.value_or(at::Tensor());

    // Fixed kernel attributes for this entry point.
    // NOTE(review): the meaning of split_item == 3 and group_type == 0 is defined by the
    // aclnnGroupedMatmulV4 API, not by this file — confirm against the CANN documentation.
    int64_t split_item_value = 3;
    int64_t group_type_value = 0;
    int64_t group_list_type_value = group_list_type.value_or(0);
    int64_t act_type_value = act_type.value_or(0);

    // Non-owning views over the caller's vectors, as expected by the aclnn call below.
    const at::TensorList x_(x);
    const at::TensorList weight_(weight);
    const at::TensorList bias_(bias);
    const at::TensorList antiquant_scale_(antiquant_scale);
    const at::TensorList antiquant_offset_(antiquant_offset);

    // The output inherits x[0]'s options; these already carry x[0]'s dtype.
    c10::TensorOptions options = x[0].options();
    auto output_size = op_infer::array_to_small_vector({x[0].size(0), weight[0].size(-1)});
    std::vector<at::Tensor> y{at::empty(output_size, options)};
    at::TensorList result = at::TensorList(y);

    // Kernel arguments not used by this entry point are passed as null.
    auto scale = nullptr;
    auto offset = nullptr;
    auto per_token_scale = nullptr;
    auto act_in = nullptr;
    auto act_quant_scale = nullptr;
    auto act_quant_offset = nullptr;
    auto act_out = nullptr;
    auto dynamic_quant_scale_out = nullptr;

    ACLNN_CMD(aclnnGroupedMatmulV4, x_, weight_, bias_, scale, offset, antiquant_scale_,
              antiquant_offset_, per_token_scale, group_list_real, act_in, act_quant_scale, act_quant_offset,
              split_item_value, group_type_value, group_list_type_value, act_type_value,
              result, act_out, dynamic_quant_scale_out);
    return y;
}
// Python bindings: exposes npu_weight_quant_gmm on the extension module under the same name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def(
        "npu_weight_quant_gmm",
        &npu_weight_quant_gmm,
        "weight quantize grouped matmul forward");
}