#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
// Accepted values of `split_item` (validated in check_dims as 0/1/2/3).
// NOTE(review): the names suggest "is the input list split / is the output list split" — confirm.
static constexpr int64_t IN_NOT_SPLIT_OUT_NOT_SPLIT = 0;
static constexpr int64_t IN_SPLIT_OUT_NOT_SPLIT = 1;
static constexpr int64_t IN_NOT_SPLIT_OUT_SPLIT = 2;
static constexpr int64_t IN_SPLIT_OUT_SPLIT = 3;
// Multiplier applied to the weight's last dimension when the weight dtype is Int
// (32 bits / 4 bits = 8 packed values per element).
static constexpr int64_t INT4_NUMS_IN_INT32 = 8;
// Accepted values of `group_type` (-1, 0 or 2).
static constexpr int64_t DEFAULT_SPLIT = -1;
static constexpr int64_t M_SPLIT = 0;
static constexpr int64_t K_SPLIT = 2;
// Short alias for the OpPreparation helper; used in this file to allocate output
// tensors and to convert between ACL data types and torch scalar types.
using npu_preparation = at_npu::native::OpPreparation;
/**
 * Validates the list lengths passed to grouped matmul against `split_item`.
 *
 * @param split_item      Split mode; must be one of 0/1/2/3.
 * @param num_x           Number of x tensors; must be > 0.
 * @param num_weight      Number of weight tensors (or groups); must be > 0.
 * @param num_group_list  Length of group_list, 0 when group_list is absent.
 *
 * Raises (via TORCH_CHECK) when any constraint is violated. The length
 * relationship between x / weight / group_list is only enforced for
 * split_item 0 and 1; no additional length check is made for 2 and 3.
 */
void check_dims(int64_t split_item, size_t num_x, size_t num_weight, size_t num_group_list)
{
    TORCH_CHECK(num_x > 0 && num_weight > 0,
        "Invalid inputs: neither x nor weight could be empty." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(split_item == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item == IN_SPLIT_OUT_NOT_SPLIT ||
        split_item == IN_NOT_SPLIT_OUT_SPLIT || split_item == IN_SPLIT_OUT_SPLIT,
        "Invalid value of split_item [", split_item, "], which should only be one of 0/1/2/3."
        + OPS_ERROR(ErrCode::PARAM));
    if (split_item == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item == IN_SPLIT_OUT_NOT_SPLIT) {
        if (num_group_list > 0) {
            // With a group_list, a single x is paired with one weight per group.
            TORCH_CHECK(num_x == 1 && num_weight == num_group_list, "Invalid inputs. "
                "When split_item = 0 or 1 and input group_list is not None, "
                "the following two conditions are supposed to be satisfied: "
                "(1) length of x equals 1; (2) length of weight equals that of group_list. "
                "Actual lengths: x [", num_x, "], weight [", num_weight, "], "
                "group_list [", num_group_list, "]." + OPS_ERROR(ErrCode::PARAM));
        } else {
            // Without a group_list, x and weight are paired one-to-one.
            // The trailing space after "tensors." keeps the two sentences separated in the message.
            TORCH_CHECK(num_x == num_weight,
                "When split_item = 0 or 1 and input group_list is None, "
                "the num of x tensors must equal the num of weight tensors. "
                "Actual lengths: x [", num_x, "], weight [", num_weight, "]." + OPS_ERROR(ErrCode::PARAM));
        }
    }
}
/**
 * Appends to `y` a new tensor whose shape equals that of `x_i` except that
 * the last dimension is replaced by `n`.
 *
 * @param y        Output list the new tensor is appended to.
 * @param x_i      Tensor whose shape is used as the template.
 * @param n        Size of the last dimension of the new tensor.
 * @param options  Tensor options used for the allocation.
 */
void create_new_tensor_multi_dim(std::vector<at::Tensor> &y, const at::Tensor &x_i, size_t n,
                                 c10::TensorOptions options)
{
    const auto src_shape = x_i.sizes();
    std::vector<int64_t> dst_shape(src_shape.begin(), src_shape.end());
    // Bounds-checked access: a shape with no dimensions raises instead of writing out of range.
    dst_shape.at(src_shape.size() - 1) = static_cast<int64_t>(n);
    auto output_size = op_infer::array_to_small_vector(dst_shape);
    y.emplace_back(npu_preparation::apply_tensor_without_format(output_size, options));
}
/**
 * Appends to `y` a new 2-D tensor of shape [dim_m, dim_n].
 *
 * @param y        Output list the new tensor is appended to.
 * @param dim_m    Size of the first dimension.
 * @param dim_n    Size of the second dimension.
 * @param options  Tensor options used for the allocation.
 */
void create_new_tensor(std::vector<at::Tensor> &y, size_t dim_m, size_t dim_n, c10::TensorOptions options)
{
    // Convert explicitly: a braced list may not implicitly narrow size_t to int64_t.
    auto output_size = op_infer::array_to_small_vector(
        {static_cast<int64_t>(dim_m), static_cast<int64_t>(dim_n)});
    y.emplace_back(npu_preparation::apply_tensor_without_format(output_size, options));
}
/**
 * Appends to `y` a new 3-D tensor of shape [batch, dim_m, dim_n].
 *
 * @param y        Output list the new tensor is appended to.
 * @param batch    Size of the first dimension.
 * @param dim_m    Size of the second dimension.
 * @param dim_n    Size of the third dimension.
 * @param options  Tensor options used for the allocation.
 */
void create_new_tensor_batch(std::vector<at::Tensor> &y, size_t batch, size_t dim_m, size_t dim_n,
                             c10::TensorOptions options)
{
    // Convert explicitly: a braced list may not implicitly narrow size_t to int64_t.
    auto output_size = op_infer::array_to_small_vector(
        {static_cast<int64_t>(batch), static_cast<int64_t>(dim_m), static_cast<int64_t>(dim_n)});
    y.emplace_back(npu_preparation::apply_tensor_without_format(output_size, options));
}
/**
 * Adds the first-dimension size of each of the first `num_x` tensors in `x`
 * to `dim_m`.
 *
 * @param dim_m  Accumulator, updated in place (not reset here).
 * @param num_x  Number of leading tensors of `x` to visit.
 * @param x      Tensor list to read sizes from.
 */
void calculate_dim_m(size_t& dim_m, size_t num_x, const at::TensorList x)
{
    size_t idx = 0;
    while (idx < num_x) {
        dim_m += x[idx].sizes()[0];
        ++idx;
    }
}
/**
 * Reports whether the last two dimensions of `tensor` have a transposed
 * stride layout: the second-to-last dimension has stride 1 and the last
 * dimension's stride equals the size of the second-to-last dimension.
 *
 * @param tensor  Tensor to inspect.
 * @return true when both stride conditions hold.
 */
bool is_weight_trans(const at::Tensor &tensor)
{
    const int64_t rank = tensor.dim();
    const int64_t last_dim = rank - 1;
    const int64_t prev_dim = rank - 2;
    if (tensor.stride(prev_dim) != 1) {
        return false;
    }
    return tensor.stride(last_dim) == tensor.size(prev_dim);
}
/**
 * Grouped matmul entry point taking `group_list` as an optional Tensor.
 *
 * Flow:
 *  1. `group_type` is mandatory and must be -1, 0 or 2.
 *  2. If the aclnnGroupedMatmulV4 kernel is not available, a Tensor `group_list`
 *     is rejected and the call falls back to aclnnGroupedMatmul.
 *  3. Otherwise the output tensors are allocated according to `split_item`
 *     (default 0) and the call is dispatched to aclnnGroupedMatmulWeightNz when
 *     weight[0] is stored in a FRACTAL_NZ format, or else to
 *     aclnnGroupedMatmulV4 / aclnnGroupedMatmulV5.
 *
 * Optional tensor lists default to empty lists; `group_list_type` and `act_type`
 * default to 0; `tuning_config` defaults to an empty array. The `*_dtype`
 * arguments, when given, override the ACL data type derived from the tensors.
 *
 * @return The list of newly allocated output tensors.
 * Raises (via TORCH_CHECK) on invalid arguments.
 */
std::vector<at::Tensor> npu_grouped_matmul(const at::TensorList x,
                                           const at::TensorList weight,
                                           const c10::optional<at::TensorList> bias,
                                           const c10::optional<at::TensorList> scale,
                                           const c10::optional<at::TensorList> offset,
                                           const c10::optional<at::TensorList> antiquant_scale,
                                           const c10::optional<at::TensorList> antiquant_offset,
                                           const c10::optional<at::TensorList> per_token_scale,
                                           const c10::optional<at::Tensor>& group_list,
                                           const c10::optional<at::TensorList> activation_input,
                                           const c10::optional<at::TensorList> activation_quant_scale,
                                           const c10::optional<at::TensorList> activation_quant_offset,
                                           c10::optional<int64_t> split_item,
                                           c10::optional<int64_t> group_type,
                                           c10::optional<int64_t> group_list_type,
                                           c10::optional<int64_t> act_type,
                                           const c10::OptionalIntArrayRef tuning_config,
                                           c10::optional<int64_t> output_dtype,
                                           c10::optional<int64_t> x_dtype,
                                           c10::optional<int64_t> weight_dtype,
                                           c10::optional<int64_t> scale_dtype,
                                           c10::optional<int64_t> per_token_scale_dtype)
{
    TORCH_CHECK(group_type.has_value(),
        "Requires manual passing group_type, current is None.", OPS_ERROR(ErrCode::VALUE));
    int64_t group_type_value = group_type.value();
    TORCH_CHECK(group_type_value == DEFAULT_SPLIT || group_type_value == M_SPLIT || group_type_value == K_SPLIT,
        "Use Tensor input with current cann version, "
        "The group type must be -1, 0 or 2, but now is [", group_type_value, "]",
        OPS_ERROR(ErrCode::VALUE));

    // Kernel availability does not depend on the arguments, so caching it in a static is safe.
    static const bool is_grouped_matmul_V4_available = check_aclnn_kernel_available("aclnnGroupedMatmulV4");
    if (C10_UNLIKELY(!is_grouped_matmul_V4_available)) {
        // Fallback path: only the original aclnnGroupedMatmul kernel, without a Tensor group_list.
        TORCH_CHECK(!group_list.has_value(),
            "group_list don't support Tensor input with current cann version. "
            "Please update cann version to 8.0.RC3 or higher, or use List[int] as input.",
            OPS_ERROR(ErrCode::VALUE));
        auto num_x = x.size();
        auto num_weight = weight.size();
        auto group_list_real = at::IntArrayRef{};
        size_t num_group_list = 0;
        int64_t split_item_value = split_item.value_or(0);
        check_dims(split_item_value, num_x, num_weight, num_group_list);

        std::vector<at::Tensor> y;
        c10::TensorOptions options = x[0].options().dtype(
            output_dtype.has_value()
                ? npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(output_dtype.value()))
                : x[0].scalar_type());

        if (split_item_value == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item_value == IN_SPLIT_OUT_NOT_SPLIT) {
            // One output per x, with the last dimension taken from the matching weight.
            y.reserve(num_x);
            for (size_t i = 0; i < num_x; i++) {
                create_new_tensor_multi_dim(y, x[i], weight[i].size(1), options);
            }
        } else if (split_item_value == IN_NOT_SPLIT_OUT_SPLIT || split_item_value == IN_SPLIT_OUT_SPLIT) {
            // A single 2-D output whose row count is the sum of the first dimensions of all x.
            if (num_x > 1) {
                size_t dim_m = 0;
                calculate_dim_m(dim_m, num_x, x);
                create_new_tensor(y, dim_m, weight[0].sizes()[1], options);
            } else if (num_x == 1) {
                create_new_tensor(y, x[0].sizes()[0], weight[0].sizes()[1], options);
            }
        }

        at::TensorList result = at::TensorList(y);
        auto bias_real = bias.value_or(at::TensorList());
        auto scale_real = scale.value_or(at::TensorList());
        auto offset_real = offset.value_or(at::TensorList());
        auto antiquant_scale_real = antiquant_scale.value_or(at::TensorList());
        auto antiquant_offset_real = antiquant_offset.value_or(at::TensorList());
        EXEC_NPU_CMD(aclnnGroupedMatmul, x, weight, bias_real, scale_real, offset_real, antiquant_scale_real,
                     antiquant_offset_real, group_list_real, split_item_value, result);
        return y;
    }

    auto num_x = x.size();
    // A single 3-D weight is treated as a stack of weights along its first dimension.
    bool singleWeight = weight.size() == 1 && weight[0].sizes().size() == 3;
    auto num_weight = singleWeight ? static_cast<size_t>(weight[0].size(0)) : static_cast<size_t>(weight.size());
    auto group_list_real = group_list.value_or(at::Tensor());
    // Do not query the size of an undefined tensor: an absent group_list has zero groups.
    int64_t num_group_list = group_list_real.defined() ? group_list_real.size(0) : 0;
    int64_t split_item_value = split_item.value_or(0);
    check_dims(split_item_value, num_x, num_weight, num_group_list);

    std::vector<at::Tensor> y;
    c10::TensorOptions options = x[0].options().dtype(
        output_dtype.has_value()
            ? npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(output_dtype.value()))
            : x[0].scalar_type());

    size_t dim_num_w = weight[0].sizes().size();
    size_t n0 = static_cast<size_t>(weight[0].size(dim_num_w - 1));
    bool weight_trans = is_weight_trans(weight[0]);

    // mxfp4_valid: both x and weight are one of the FLOAT4 types. Depending on the
    // version, the type comes only from the explicit *_dtype arguments, or falls
    // back to the tensors' own scalar type.
#if VERSION_BETWEEN(V2R1, V2R7)
    bool mxfp4_valid = x_dtype.has_value() && weight_dtype.has_value() &&
        (x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
         x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2)) &&
        (weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2) ||
         weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1));
#endif
#if VERSION_BETWEEN(V2R8, VERSION_NEWEST)
    bool mxfp4_valid = false;
    if (x_dtype.has_value()) {
        mxfp4_valid = (x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
                       x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2));
    } else {
        mxfp4_valid = x[0].scalar_type() == at::ScalarType::Float4_e2m1fn_x2;
    }
    if (weight_dtype.has_value()) {
        mxfp4_valid = mxfp4_valid &&
                      (weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
                       weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2));
    } else {
        mxfp4_valid = mxfp4_valid && weight[0].scalar_type() == at::ScalarType::Float4_e2m1fn_x2;
    }
#endif

    // For non-transposed mxfp4 weights the last weight dimension is scaled by FP4_IN_INT8.
    size_t n_new = (mxfp4_valid && !weight_trans) ? (n0 * FP4_IN_INT8) : n0;
    if (mxfp4_valid) {
        // NOTE(review): presumably two fp4 values are packed per stored element, so
        // size(1) == 1 corresponds to K == 2 as the message says — confirm.
        TORCH_CHECK(x[0].size(1) != 1, "In mxfp4, dim K should not be 2.", OPS_ERROR(ErrCode::VALUE));
    }

    if (split_item_value == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item_value == IN_SPLIT_OUT_NOT_SPLIT) {
        if (num_group_list > 0) {
            // group_list holds non-decreasing values; each output's row count is the
            // difference between consecutive entries (the first entry is used as-is).
            y.reserve(num_group_list);
            int64_t glr_value_0 = group_list_real[0].item<int64_t>();
            TORCH_CHECK(glr_value_0 >= 0,
                "group_list[0] should be larger than or equal to 0, but now is ", glr_value_0, "." +
                OPS_ERROR(ErrCode::VALUE));
            create_new_tensor(y, glr_value_0, n0, options);
            int64_t glr_value_pre = glr_value_0;
            for (int64_t i = 1; i < num_group_list; i++) {
                int64_t glr_value_cur = group_list_real[i].item<int64_t>();
                TORCH_CHECK(glr_value_cur - glr_value_pre >= 0,
                    "group_list[", i, "] - group_list[", i - 1, "] should be larger than or equal to 0, but now is ",
                    glr_value_cur - glr_value_pre, "." + OPS_ERROR(ErrCode::VALUE));
                size_t ni = singleWeight ? n0 : weight[i].size(dim_num_w - 1);
                create_new_tensor(y, glr_value_cur - glr_value_pre, ni, options);
                glr_value_pre = glr_value_cur;
            }
        } else {
            y.reserve(num_x);
            for (size_t i = 0; i < num_x; i++) {
                size_t ni = singleWeight ? n0 : weight[i].size(dim_num_w - 1);
                create_new_tensor_multi_dim(y, x[i], ni, options);
            }
        }
    } else if (split_item_value == IN_NOT_SPLIT_OUT_SPLIT || split_item_value == IN_SPLIT_OUT_SPLIT) {
        if (num_x > 1) {
            size_t dim_m = 0;
            for (size_t i = 0; i < num_x; i++) {
                dim_m += static_cast<size_t>(x[i].size(0));
            }
            // Int weights use the last weight dimension scaled by INT4_NUMS_IN_INT32.
            const size_t out_n = weight[0].dtype() == at::ScalarType::Int ? n0 * INT4_NUMS_IN_INT32 : n_new;
            create_new_tensor(y, dim_m, out_n, options);
        } else if (num_x == 1) {
            if (group_type_value == K_SPLIT) {
                TORCH_CHECK(num_weight == 1,
                    "When group_type is 2(K_SPLIT) and split_item is 2/3, the length of weight must equal x.",
                    OPS_ERROR(ErrCode::PARAM));
                // K-split produces one [M, N] result per group, stacked along a leading batch dim.
                const size_t out_n = weight[0].dtype() == at::ScalarType::Int ? n0 * INT4_NUMS_IN_INT32 : n_new;
                create_new_tensor_batch(y, num_group_list, x[0].size(0), out_n, options);
            } else {
                const bool scaled_n =
                    (weight[0].dtype() == at::ScalarType::Int ||
                     (weight[0].dtype() == at::ScalarType::Float && weight[0].dtype() != x[0].dtype())) &&
                    (!weight_trans);
                const size_t out_n = scaled_n ? n0 * INT4_NUMS_IN_INT32 : n_new;
                create_new_tensor(y, x[0].size(0), out_n, options);
            }
        }
    }

    at::TensorList result = at::TensorList(y);
    auto bias_real = bias.value_or(at::TensorList());
    auto scale_real = scale.value_or(at::TensorList());
    auto offset_real = offset.value_or(at::TensorList());
    auto antiquant_scale_real = antiquant_scale.value_or(at::TensorList());
    auto antiquant_offset_real = antiquant_offset.value_or(at::TensorList());
    auto per_token_scale_real = per_token_scale.value_or(at::TensorList());
    auto activation_input_real = activation_input.value_or(at::TensorList());
    auto activation_quant_scale_real = activation_quant_scale.value_or(at::TensorList());
    auto activation_quant_offset_real = activation_quant_offset.value_or(at::TensorList());
    auto act_out = at::TensorList();
    auto dynamic_quant_scale_out = at::TensorList();
    int64_t group_list_type_value = group_list_type.value_or(0);
    int64_t act_type_value = act_type.value_or(0);
    auto tuning_config_real = tuning_config.value_or(at::IntArrayRef{});

    // Each wrapper pairs a tensor list with the ACL data type handed to the kernel:
    // the explicit *_dtype override when given, otherwise the first tensor's type
    // (or a fixed default when the list is empty).
    TensorListWrapper x_wrapper = {x,
        x_dtype.has_value() ? c10_npu::GetAclDataType(x_dtype.value())
                            : npu_preparation::convert_to_acl_data_type(x[0].scalar_type())};
    TensorListWrapper weight_wrapper = {weight,
        weight_dtype.has_value() ? c10_npu::GetAclDataType(weight_dtype.value())
                                 : npu_preparation::convert_to_acl_data_type(weight[0].scalar_type())};
    TensorListWrapper scale_wrapper = {scale_real,
        scale_dtype.has_value() ? c10_npu::GetAclDataType(scale_dtype.value())
                                : (scale_real.empty() ? aclDataType::ACL_UINT64
                                    : npu_preparation::convert_to_acl_data_type(scale_real[0].scalar_type()))};
    TensorListWrapper per_token_scale_wrapper = {per_token_scale_real,
        per_token_scale_dtype.has_value()
            ? c10_npu::GetAclDataType(per_token_scale_dtype.value())
            : (per_token_scale_real.empty() ? aclDataType::ACL_FLOAT
                : npu_preparation::convert_to_acl_data_type(per_token_scale_real[0].scalar_type()))};
    TensorListWrapper antiquant_scale_wrapper = {antiquant_scale_real,
        antiquant_scale_real.empty()
            ? aclDataType::ACL_FLOAT16
            : (antiquant_scale_real[0].scalar_type() == at::ScalarType::Byte ? aclDataType::ACL_FLOAT8_E8M0
                : npu_preparation::convert_to_acl_data_type(antiquant_scale_real[0].scalar_type()))};

    int64_t weight_format = at_npu::native::custom_ops::get_npu_format(weight[0]);
    const bool is_weight_nz = (weight_format == ACL_FORMAT_FRACTAL_NZ) ||
                              (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_2) ||
                              (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_4) ||
                              (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_16);
    if (is_weight_nz) {
        static const bool is_weight_nz_available = check_aclnn_kernel_available("aclnnGroupedMatmulWeightNz");
        TORCH_CHECK(is_weight_nz_available,
            "Format of weight in npu_grouped_matmul is FRACTAL_NZ, current CANN version "
            "do not support with this format. Please try to update the version of CANN."
            + OPS_ERROR(ErrCode::PARAM));
        int64_t quant_per_group_size = 0;
        EXEC_NPU_CMD(aclnnGroupedMatmulWeightNz, x_wrapper, weight_wrapper, bias_real, scale_wrapper, offset_real, antiquant_scale_wrapper,
                     antiquant_offset_real, per_token_scale_wrapper, group_list_real, activation_input_real,
                     activation_quant_scale_real, activation_quant_offset_real, split_item_value, group_type_value,
                     group_list_type_value, act_type_value, tuning_config_real, quant_per_group_size,
                     result, act_out, dynamic_quant_scale_out);
        return y;
    }

    static const bool is_grouped_matmul_V5_available = check_aclnn_kernel_available("aclnnGroupedMatmulV5");
    // Must NOT be static: it depends on this call's x / x_dtype / weight_dtype, so caching
    // the first call's result would select the wrong kernel on later calls.
    const bool dtypeValid = x[0].scalar_type() != at::ScalarType::Float8_e5m2 &&
                            x[0].scalar_type() != at::ScalarType::Float8_e4m3fn &&
                            !x_dtype.has_value() && !weight_dtype.has_value();
    if (!is_grouped_matmul_V5_available || !dtypeValid || mxfp4_valid) {
        EXEC_NPU_CMD(aclnnGroupedMatmulV4, x_wrapper, weight_wrapper, bias_real, scale_wrapper, offset_real, antiquant_scale_real,
                     antiquant_offset_real, per_token_scale_wrapper, group_list_real, activation_input_real,
                     activation_quant_scale_real, activation_quant_offset_real, split_item_value, group_type_value,
                     group_list_type_value, act_type_value, result, act_out, dynamic_quant_scale_out);
    } else {
        // V5 additionally receives tuning_config.
        EXEC_NPU_CMD(aclnnGroupedMatmulV5, x_wrapper, weight_wrapper, bias_real, scale_wrapper, offset_real, antiquant_scale_real,
                     antiquant_offset_real, per_token_scale_wrapper, group_list_real, activation_input_real,
                     activation_quant_scale_real, activation_quant_offset_real, split_item_value, group_type_value,
                     group_list_type_value, act_type_value, tuning_config_real, result, act_out, dynamic_quant_scale_out);
    }
    return y;
}
/**
 * Grouped matmul entry point taking `group_list` as an optional integer array.
 *
 * Validates the list lengths with check_dims, allocates the output tensors
 * according to `split_item` (default 0) and runs aclnnGroupedMatmul.
 * per_token_scale, the activation_* lists, group_type, group_list_type and
 * act_type are accepted but not used by this overload.
 *
 * @return The list of newly allocated output tensors.
 * Raises (via TORCH_CHECK) on invalid lengths or a decreasing / negative group_list.
 */
std::vector<at::Tensor> npu_grouped_matmul(const at::TensorList x,
                                           const at::TensorList weight,
                                           const c10::optional<at::TensorList> bias,
                                           const c10::optional<at::TensorList> scale,
                                           const c10::optional<at::TensorList> offset,
                                           const c10::optional<at::TensorList> antiquant_scale,
                                           const c10::optional<at::TensorList> antiquant_offset,
                                           const c10::optional<at::TensorList> per_token_scale,
                                           c10::OptionalIntArrayRef group_list,
                                           const c10::optional<at::TensorList> activation_input,
                                           const c10::optional<at::TensorList> activation_quant_scale,
                                           const c10::optional<at::TensorList> activation_quant_offset,
                                           c10::optional<int64_t> split_item,
                                           c10::optional<int64_t> group_type,
                                           c10::optional<int64_t> group_list_type,
                                           c10::optional<int64_t> act_type,
                                           c10::optional<at::ScalarType> output_dtype)
{
    const size_t x_count = x.size();
    const size_t weight_count = weight.size();
    auto group_list_real = group_list.value_or(at::IntArrayRef{});
    const size_t group_count = group_list_real.size();
    int64_t split_item_value = split_item.value_or(0);
    check_dims(split_item_value, x_count, weight_count, group_count);

    // Outputs share x[0]'s options, with the dtype overridden when output_dtype is given.
    const at::ScalarType out_scalar_type = output_dtype.value_or(x[0].scalar_type());
    c10::TensorOptions options = x[0].options().dtype(out_scalar_type);

    const bool output_not_split =
        split_item_value == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item_value == IN_SPLIT_OUT_NOT_SPLIT;
    const bool output_split =
        split_item_value == IN_NOT_SPLIT_OUT_SPLIT || split_item_value == IN_SPLIT_OUT_SPLIT;

    std::vector<at::Tensor> y;
    if (output_not_split && group_count > 0) {
        // One output per group; rows are the difference between consecutive group_list entries.
        y.reserve(group_count);
        int64_t prev_end = group_list_real[0];
        TORCH_CHECK(prev_end >= 0,
            "group_list[0] should be larger than or equal to 0, but now is ", prev_end, "." +
            OPS_ERROR(ErrCode::VALUE));
        create_new_tensor(y, prev_end, weight[0].sizes()[1], options);
        for (size_t i = 1; i < group_count; ++i) {
            const int64_t cur_end = group_list_real[i];
            const int64_t rows = cur_end - prev_end;
            TORCH_CHECK(rows >= 0,
                "group_list[", i, "] - group_list[", i - 1, "] should be larger than or equal to 0, but now is ",
                rows, "." + OPS_ERROR(ErrCode::VALUE));
            create_new_tensor(y, rows, weight[i].sizes()[1], options);
            prev_end = cur_end;
        }
    } else if (output_not_split) {
        // One output per x, last dimension taken from the matching weight.
        y.reserve(x_count);
        for (size_t i = 0; i < x_count; ++i) {
            create_new_tensor_multi_dim(y, x[i], weight[i].size(1), options);
        }
    } else if (output_split && x_count >= 1) {
        // Single 2-D output: rows are the summed first dimensions of every x.
        size_t total_rows = 0;
        for (size_t i = 0; i < x_count; ++i) {
            total_rows += x[i].sizes()[0];
        }
        create_new_tensor(y, total_rows, weight[0].sizes()[1], options);
    }

    at::TensorList result = at::TensorList(y);
    auto bias_real = bias.value_or(at::TensorList());
    auto scale_real = scale.value_or(at::TensorList());
    auto offset_real = offset.value_or(at::TensorList());
    auto antiquant_scale_real = antiquant_scale.value_or(at::TensorList());
    auto antiquant_offset_real = antiquant_offset.value_or(at::TensorList());
    EXEC_NPU_CMD(aclnnGroupedMatmul, x, weight, bias_real, scale_real, offset_real, antiquant_scale_real,
                 antiquant_offset_real, group_list_real, split_item_value, result);
    return y;
}
}