#include <torch/extension.h>
#include <torch_npu/csrc/framework/utils/RandomOpAdapter.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/include/ops.h>
#include "../flop_counter/flop_counter.h"
#include "inc/aclnn_common.h"
using npu_preparation = at_npu::native::OpPreparation;
namespace op_infer {
constexpr int SIZE = 8;
c10::SmallVector<int64_t, SIZE> array_to_small_vector(c10::IntArrayRef shape)
{
c10::SmallVector<int64_t, SIZE> small_shape;
for (size_t i = 0; i < shape.size(); ++i) {
small_shape.emplace_back(shape[i]);
}
return small_shape;
}
}
void _check_dims(size_t num_x, const at::TensorList &weight, size_t num_group_list)
{
size_t num_w = weight.size();
TORCH_CHECK(num_x > 0 && num_w > 0,
"Neither x nor weight could be empty.");
size_t dim_num_w = weight[0].sizes().size();
size_t dim_0_w = weight[0].sizes()[0];
}
void _create_new_tensor(std::vector<at::Tensor> &y, int64_t dim_m, int64_t dim_n, c10::TensorOptions options,
int64_t group_type_value, int64_t num_group_list)
{
auto output_size = (2 == group_type_value) ? op_infer::array_to_small_vector({num_group_list, dim_m, dim_n})
: op_infer::array_to_small_vector({dim_m, dim_n});
y.emplace_back(at::empty(output_size, options));
}
void _foreach_transpose(const at::TensorList &tensorList, std::vector<at::Tensor> &tensors)
{
for (int i = 0; i< tensorList.size(); i++) {
at::Tensor tensor = tensorList[i].transpose(-1, -2);
tensors.emplace_back(tensor);
}
}
bool _is_transposed(at::Tensor &tensors)
{
int dim_sum = tensors.dim();
TORCH_CHECK(dim_sum >= 2 && dim_sum <= 3,
"input tensor of is_tensor_transposed should be either 2- or 3-dimensional.");
int shape_dim = tensors.sizes().size() - 2;
if (tensors.stride(dim_sum - 2) == 1 && tensors.stride(dim_sum - 1) == tensors.sizes().at(shape_dim)) {
return true;
} else {
return false;
}
}
std::vector<at::Tensor> npu_gmm(const std::vector<at::Tensor>& x,
const std::vector<at::Tensor>& weight,
const std::vector<at::Tensor>& bias,
c10::optional<std::vector<int64_t>> group_list,
c10::optional<int64_t> group_type,
c10::optional<int64_t> group_list_type)
{
auto num_x = x.size();
auto num_w = weight.size();
auto group_list_real_ = group_list.value_or(std::vector<int64_t>{});
at::IntArrayRef group_list_real(group_list_real_);
auto num_group_list = group_list_real.size();
int64_t split_item_value = 3;
int64_t group_type_value = group_type.value_or(-1);
const at::TensorList x_(x);
const at::TensorList weight_(weight);
const at::TensorList bias_(bias);
_check_dims(num_x, weight_, num_group_list);
std::vector<at::Tensor> y;
c10::TensorOptions options = x_[0].options().dtype(x_[0].scalar_type());
size_t dim_num_w = weight[0].sizes().size();
_create_new_tensor(y, x[0].sizes()[0], weight[0].sizes()[dim_num_w - 1], options, group_type_value,
num_group_list);
at::TensorList result = at::TensorList(y);
auto scale_real = at::TensorList();
auto offset_real = at::TensorList();
auto antiquant_scale_real = at::TensorList();
auto antiquant_offset_real = at::TensorList();
ACLNN_CMD(aclnnGroupedMatmulV2, x_, weight_, bias_, scale_real, offset_real, antiquant_scale_real,
antiquant_offset_real, group_list_real, split_item_value, group_type_value, result);
#ifdef FLOP_COUNT
FLOP_COUNT(FlopCounter::gmm_flop_int, x_, weight_, group_list, group_type_value);
#endif
return y;
}
std::vector<at::Tensor> npu_gmm(const std::vector<at::Tensor>& x,
const std::vector<at::Tensor>& weight,
const std::vector<at::Tensor>& bias,
const c10::optional<at::Tensor>& group_list,
c10::optional<int64_t> group_type,
c10::optional<int64_t> group_list_type)
{
auto num_x = x.size();
auto num_w = weight.size();
auto group_list_real = group_list.value_or(at::Tensor());
auto num_group_list = group_list_real.sizes()[0];
int64_t split_item_value = 3;
int64_t group_type_value = group_type.value_or(-1);
int64_t group_list_type_value = group_list_type.value_or(0);
int64_t act_type_value = 0;
const at::TensorList x_(x);
const at::TensorList weight_(weight);
const at::TensorList bias_(bias);
_check_dims(num_x, weight_, num_group_list);
std::vector<at::Tensor> y;
c10::TensorOptions options = x_[0].options().dtype(x_[0].scalar_type());
size_t dim_num_w = weight[0].sizes().size();
_create_new_tensor(y, x[0].sizes()[0], weight[0].sizes()[dim_num_w - 1], options, group_type_value,
num_group_list);
at::TensorList result = at::TensorList(y);
auto scale_real = at::TensorList();
auto offset_real = at::TensorList();
auto antiquant_scale_real = at::TensorList();
auto antiquant_offset_real = at::TensorList();
auto perToken_scale_real = at::TensorList();
auto activation_input_real = at::TensorList();
auto activation_quant_scale_real = at::TensorList();
auto activation_quant_offset_real = at::TensorList();
auto activation_feature_out_real = at::TensorList();
auto dynQuant_scale_out_real = at::TensorList();
ACLNN_CMD(aclnnGroupedMatmulV4, x_, weight_, bias_, scale_real, offset_real, antiquant_scale_real,
antiquant_offset_real, perToken_scale_real, group_list_real, activation_input_real,
activation_quant_scale_real, activation_quant_offset_real, split_item_value, group_type_value,
group_list_type_value, act_type_value, result, activation_feature_out_real, dynQuant_scale_out_real);
#ifdef FLOP_COUNT
FLOP_COUNT(FlopCounter::gmm_flop_tensor, x_, weight_, group_list, group_type_value);
#endif
return y;
}
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>> npu_gmm_backward(
const std::vector<at::Tensor>& grad,
const std::vector<at::Tensor>& x,
const std::vector<at::Tensor>& weight,
const c10::optional<std::vector<int64_t>> group_list,
c10::optional<int64_t> group_list_type)
{
auto num_w = weight.size();
auto group_list_real = group_list.value_or(std::vector<int64_t>{});
const at::TensorList x_(x);
const at::TensorList weight_(weight);
std::vector<at::Tensor> xt;
std::vector<at::Tensor> wt;
_foreach_transpose(x_, xt);
_foreach_transpose(weight_, wt);
std::vector<at::Tensor> bias_real;
std::vector<at::Tensor> dx = npu_gmm(grad, wt, bias_real, group_list_real, 0, group_list_type);
std::vector<at::Tensor> dw = npu_gmm(xt, grad, bias_real, group_list_real, 2, group_list_type);
std::vector<at::Tensor> dbias;
std::vector<at::Tensor> dw_output;
for (int i = 0; i < num_w; i++) {
at::Tensor dw_tensor = dw[i].reshape(weight[i].sizes());
dw_output.emplace_back(dw_tensor);
}
return std::make_tuple(dx, dw_output, dbias);
}
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>> npu_gmm_backward(
const std::vector<at::Tensor>& grad,
const std::vector<at::Tensor>& x,
const std::vector<at::Tensor>& weight,
const c10::optional<at::Tensor>& group_list,
c10::optional<int64_t> group_list_type)
{
auto num_w = weight.size();
auto group_list_real = group_list.value_or(at::Tensor());
std::vector<at::Tensor> bias_real;
at::TensorList weight_(weight);
std::vector<at::Tensor> wt;
_foreach_transpose(weight_, wt);
std::vector<at::Tensor> dx = npu_gmm(grad, wt, bias_real, group_list_real, 0, group_list_type);
at::Tensor weight_tensor = weight.at(0);
bool is_weight_transposed = _is_transposed(weight_tensor);
std::vector<at::Tensor> dw;
if (is_weight_transposed == true) {
at::Tensor grad_tensor = grad.at(0).contiguous();
at::TensorList grad_(grad_tensor);
std::vector<at::Tensor> gradt;
_foreach_transpose(grad_, gradt);
std::vector<at::Tensor> dwt = npu_gmm(gradt, x, bias_real, group_list_real, 2, group_list_type);
at::TensorList dwt_(dwt);
_foreach_transpose(dwt_, dw);
} else {
at::TensorList x_(x);
std::vector<at::Tensor> xt;
_foreach_transpose(x_, xt);
dw = npu_gmm(xt, grad, bias_real, group_list_real, 2, group_list_type);
}
std::vector<at::Tensor> dbias;
std::vector<at::Tensor> dw_output;
for (int i = 0; i < num_w; i++) {
at::Tensor dw_tensor = dw[i].reshape(weight[i].sizes());
dw_output.emplace_back(dw_tensor);
}
return std::make_tuple(dx, dw_output, dbias);
}
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>> npu_gmm_backward_fusion(
const std::vector<at::Tensor>& grad,
const std::vector<at::Tensor>& weight,
const c10::optional<std::vector<int64_t>> group_list,
c10::optional<int64_t> group_list_type)
{
auto num_w = weight.size();
auto group_list_real = group_list.value_or(std::vector<int64_t>{});
const at::TensorList weight_(weight);
std::vector<at::Tensor> wt;
_foreach_transpose(weight_, wt);
std::vector<at::Tensor> bias_real;
std::vector<at::Tensor> dx = npu_gmm(grad, wt, bias_real, group_list_real, 0, group_list_type);
std::vector<at::Tensor> dbias;
std::vector<at::Tensor> dw_output;
return std::make_tuple(dx, dw_output, dbias);
}
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>> npu_gmm_backward_fusion(
const std::vector<at::Tensor>& grad,
const std::vector<at::Tensor>& weight,
const c10::optional<at::Tensor>& group_list,
c10::optional<int64_t> group_list_type)
{
auto num_w = weight.size();
auto group_list_real = group_list.value_or(at::Tensor());
std::vector<at::Tensor> bias_real;
at::TensorList weight_(weight);
std::vector<at::Tensor> wt;
_foreach_transpose(weight_, wt);
std::vector<at::Tensor> dx = npu_gmm(grad, wt, bias_real, group_list_real, 0, group_list_type);
std::vector<at::Tensor> dbias;
std::vector<at::Tensor> dw_output;
return std::make_tuple(dx, dw_output, dbias);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
using gmmv1 = std::vector<at::Tensor>(*)(const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, c10::optional<std::vector<int64_t>>, c10::optional<int64_t>, c10::optional<int64_t>);
using gmmv2 = std::vector<at::Tensor>(*)(const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const c10::optional<at::Tensor>&, c10::optional<int64_t>, c10::optional<int64_t>);
using gmmv1_backward = std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>>(*)(const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const c10::optional<std::vector<int64_t>>, c10::optional<int64_t>);
using gmmv2_backward = std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>>(*)(const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const c10::optional<at::Tensor>&, c10::optional<int64_t>);
using gmmv1_backward_fusion = std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>>(*)(const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const c10::optional<std::vector<int64_t>>, c10::optional<int64_t>);
using gmmv2_backward_fusion = std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>, std::vector<at::Tensor>>(*)(const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, const c10::optional<at::Tensor>&, c10::optional<int64_t>);
m.def("npu_gmm", (gmmv1)&npu_gmm, "grouped matmul forward with group_list type List[int]");
m.def("npu_gmm_backward", (gmmv1_backward)&npu_gmm_backward, "grouped matmul backward with group_list type List[int]");
m.def("npu_gmm_backward_fusion", (gmmv1_backward_fusion)&npu_gmm_backward_fusion, "grouped matmul backward with group_list type List[int]");
m.def("npu_gmm", (gmmv2)&npu_gmm, "grouped matmul forward with group_list type Tensor");
m.def("npu_gmm_backward", (gmmv2_backward)&npu_gmm_backward, "grouped matmul backward with group_list type Tensor");
m.def("npu_gmm_backward_fusion", (gmmv2_backward_fusion)&npu_gmm_backward_fusion, "grouped matmul backward with group_list type Tensor");
}