#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
// Accepted values of split_item (0/1/2/3), as enforced by check_dims.
// NOTE(review): the names suggest whether the input / output tensor lists are
// split per group -- confirm exact semantics against the aclnnGroupedMatmul docs.
static constexpr int64_t IN_NOT_SPLIT_OUT_NOT_SPLIT = 0;
static constexpr int64_t IN_SPLIT_OUT_NOT_SPLIT = 1;
static constexpr int64_t IN_NOT_SPLIT_OUT_SPLIT = 2;
static constexpr int64_t IN_SPLIT_OUT_SPLIT = 3;

// Multiplier applied to the last weight dimension when the weight dtype is Int
// (eight 4-bit values fit into one 32-bit word).
static constexpr int64_t INT4_NUMS_IN_INT32 = 8;

// Accepted values of group_type (-1, 0 or 2).
static constexpr int64_t DEFAULT_SPLIT = -1;
static constexpr int64_t M_SPLIT = 0;
static constexpr int64_t K_SPLIT = 2;
// Short alias for the NPU preparation helpers; used below to allocate output
// tensors and to convert between scalar types and ACL data types.
using npu_preparation = at_npu::native::OpPreparation;
/**
 * Validates the tensor-list lengths of a grouped matmul call against split_item.
 *
 * @param split_item      Split mode; must be one of 0/1/2/3.
 * @param num_x           Number of x tensors; must be non-zero.
 * @param num_weight      Number of weight tensors; must be non-zero.
 * @param num_group_list  Length of group_list (0 when group_list is absent).
 *
 * For split_item 0 or 1 the lengths must additionally be consistent:
 * with a group_list, x must have length 1 and weight must match group_list;
 * without one, x and weight must have equal length.
 * Any violation raises through TORCH_CHECK.
 */
static void check_dims(int64_t split_item, size_t num_x, size_t num_weight, size_t num_group_list) {
    TORCH_CHECK(num_x > 0 && num_weight > 0,
                "Invalid inputs: neither x nor weight could be empty." + OPS_ERROR(ErrCode::PARAM));

    const bool out_not_split =
        (split_item == IN_NOT_SPLIT_OUT_NOT_SPLIT) || (split_item == IN_SPLIT_OUT_NOT_SPLIT);
    const bool out_split = (split_item == IN_NOT_SPLIT_OUT_SPLIT) || (split_item == IN_SPLIT_OUT_SPLIT);
    TORCH_CHECK(out_not_split || out_split,
                "Invalid value of split_item [", split_item,
                "], which should only be one of 0/1/2/3." + OPS_ERROR(ErrCode::PARAM));

    // Only split_item 0 / 1 carry extra length constraints.
    if (!out_not_split) {
        return;
    }

    if (num_group_list == 0) {
        TORCH_CHECK(num_x == num_weight,
                    "When split_item = 0 or 1 and input group_list is None, "
                    "the num of x tensors must equal the num of weight tensors."
                    "Actual lengths: x [",
                    num_x, "], weight [", num_weight, "]." + OPS_ERROR(ErrCode::PARAM));
        return;
    }

    TORCH_CHECK(num_x == 1 && num_weight == num_group_list,
                "Invalid inputs. "
                "When split_item = 0 or 1 and input group_list is not None, "
                "the following two conditions are supposed to be satisfied: "
                "(1) length of x equals 1; (2) length of weight equals that of group_list. "
                "Actual lengths: x [",
                num_x, "], weight [", num_weight,
                "], "
                "group_list [",
                num_group_list, "]." + OPS_ERROR(ErrCode::PARAM));
}
/**
 * Appends to y a freshly allocated tensor whose shape equals that of x_i
 * except for the last dimension, which is replaced by n.
 *
 * @param y        Output list that receives the new tensor.
 * @param x_i      Tensor whose shape is used as the template.
 * @param n        Size of the last dimension of the new tensor.
 * @param options  Tensor options used for the allocation.
 */
static void create_new_tensor_multi_dim(
    std::vector<at::Tensor> &y, const at::Tensor &x_i, size_t n, c10::TensorOptions options) {
    const auto src_shape = x_i.sizes();
    std::vector<int64_t> dst_shape(src_shape.begin(), src_shape.end());
    // .at() is bounds-checked, so a tensor without dimensions throws here.
    dst_shape.at(src_shape.size() - 1) = static_cast<int64_t>(n);
    auto dst_small = op_infer::array_to_small_vector(dst_shape);
    y.emplace_back(npu_preparation::apply_tensor_without_format(dst_small, options));
}
/**
 * Appends to y a freshly allocated 2D tensor of shape {dim_m, dim_n}.
 *
 * @param y        Output list that receives the new tensor.
 * @param dim_m    Size of the first dimension.
 * @param dim_n    Size of the second dimension.
 * @param options  Tensor options used for the allocation.
 */
static void create_new_tensor(std::vector<at::Tensor> &y, size_t dim_m, size_t dim_n, c10::TensorOptions options) {
    auto shape_2d = op_infer::array_to_small_vector({dim_m, dim_n});
    auto allocated = npu_preparation::apply_tensor_without_format(shape_2d, options);
    y.emplace_back(allocated);
}
/**
 * Appends to y a freshly allocated 3D tensor of shape {batch, dim_m, dim_n}.
 *
 * @param y        Output list that receives the new tensor.
 * @param batch    Size of the leading dimension.
 * @param dim_m    Size of the second dimension.
 * @param dim_n    Size of the third dimension.
 * @param options  Tensor options used for the allocation.
 */
static void create_new_tensor_batch(
    std::vector<at::Tensor> &y, size_t batch, size_t dim_m, size_t dim_n, c10::TensorOptions options) {
    auto shape_3d = op_infer::array_to_small_vector({batch, dim_m, dim_n});
    auto allocated = npu_preparation::apply_tensor_without_format(shape_3d, options);
    y.emplace_back(allocated);
}
/**
 * Adds the first-dimension size of each of the first num_x tensors in x to dim_m.
 *
 * @param dim_m  Accumulator, updated in place (not reset by this function).
 * @param num_x  Number of leading tensors of x to visit.
 * @param x      Tensor list being summed over.
 */
static void calculate_dim_m(size_t &dim_m, size_t num_x, const at::TensorList x) {
    size_t idx = 0;
    while (idx < num_x) {
        dim_m += x[idx].sizes()[0];
        ++idx;
    }
}
/**
 * Reports whether the last two dimensions of tensor are laid out transposed:
 * the second-to-last dimension has stride 1 and the last dimension's stride
 * equals the size of the second-to-last dimension.
 *
 * @param tensor  Tensor to inspect.
 * @return true when both stride conditions hold.
 */
static bool is_weight_trans(const at::Tensor &tensor) {
    const int64_t last_dim = tensor.dim() - 1;
    const int64_t second_last_dim = tensor.dim() - 2;
    if (tensor.stride(second_last_dim) != 1) {
        return false;
    }
    return tensor.stride(last_dim) == tensor.size(second_last_dim);
}
/**
 * NPU implementation of _scaled_grouped_mm on top of the aclnnGroupedMatmul kernel family.
 *
 * Validates the inputs, wraps mat_a / mat_b / scale_a / scale_b into single-element
 * tensor lists, allocates the output and dispatches to aclnnGroupedMatmul,
 * aclnnGroupedMatmulWeightNz, aclnnGroupedMatmulV4 or aclnnGroupedMatmulV5 depending on
 * kernel availability, weight format and dtypes.
 *
 * @param mat_a          2D or 3D left operand; a 3D tensor is reshaped to 2D before dispatch.
 * @param mat_b          2D or 3D right operand. A 2D mat_b selects K_SPLIT, which is rejected.
 * @param scale_a        float32 or float8_e8m0fnu scale for mat_a; passed as per-token scale.
 * @param scale_b        float32 or float8_e8m0fnu scale for mat_b; passed as scale.
 * @param offs           1D int32 offsets; required if and only if mat_a or mat_b is 2D.
 *                       Converted to int64 and passed as group_list.
 * @param bias           Must be empty (not supported).
 * @param scale_result   Must be empty (not supported).
 * @param out_dtype      Output dtype; only BFloat16 (also the default) is accepted.
 * @param use_fast_accum Not read by this implementation.
 * @return The single output tensor produced by the grouped matmul.
 * @throws c10::Error (through TORCH_CHECK) when any validation fails.
 */
at::Tensor _scaled_grouped_mm(const at::Tensor &mat_a, const at::Tensor &mat_b, const at::Tensor &scale_a,
                              const at::Tensor &scale_b,
                              const c10::optional<at::Tensor> &offs,
                              const c10::optional<at::Tensor> &bias,
                              const c10::optional<at::Tensor> &scale_result,
                              c10::optional<c10::ScalarType> out_dtype, bool use_fast_accum) {
    TORCH_CHECK(c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950,
                "This interface is supported only on the Ascend950 platform and after.", OPS_ERROR(ErrCode::PARAM));

    // ---- Rank validation of the matrices -------------------------------------------------
    const int32_t ndim_a = mat_a.dim();
    const int32_t ndim_b = mat_b.dim();
    const int32_t ndim_sa = scale_a.dim();
    const int32_t ndim_sb = scale_b.dim();
    TORCH_CHECK(
        ndim_a == 2 || ndim_a == 3,
        "mat_a dimension must be 2D or 3D, actual dimension: ", ndim_a
    );
    TORCH_CHECK(
        ndim_b == 2 || ndim_b == 3,
        "mat_b dimension must be 2D or 3D, actual dimension: ", ndim_b
    );
    const bool use_a_2d = (ndim_a == 2);
    const bool use_b_2d = (ndim_b == 2);

    // ---- Scale dtype / rank validation ---------------------------------------------------
    // A float32 scale selects the "fp8" checks, a float8_e8m0fnu scale the "mx" checks.
    bool is_fp8_a = (scale_a.scalar_type() == at::kFloat);
    bool is_fp8_b = (scale_b.scalar_type() == at::kFloat);
    bool is_mx_a = (scale_a.scalar_type() == at::kFloat8_e8m0fnu);
    bool is_mx_b = (scale_b.scalar_type() == at::kFloat8_e8m0fnu);
    if (is_fp8_a) {
        TORCH_CHECK(ndim_sa == 1 || ndim_sa == 2,
                    "scale_a dimension must be 1D or 2D for fp8, actual dimension: ", ndim_sa);
    } else if (is_mx_a) {
        TORCH_CHECK(ndim_sa == 2 || ndim_sa == 3,
                    "scale_a dimension must be 2D or 3D for mx, actual dimension: ", ndim_sa);
    } else {
        TORCH_CHECK(false, "scale_a must be float32 or float8_e8m0fnu, but got ", scale_a.dtype());
    }
    if (is_fp8_b) {
        TORCH_CHECK(ndim_sb == 1 || ndim_sb == 2,
                    "scale_b dimension must be 1D or 2D for fp8, actual dimension: ", ndim_sb);
    } else if (is_mx_b) {
        TORCH_CHECK(ndim_sb == 2 || ndim_sb == 3 || ndim_sb == 4,
                    "scale_b dimension must be 2D/3D/4D for mx, actual dimension: ", ndim_sb);
    } else {
        TORCH_CHECK(false, "scale_b must be float32 or float8_e8m0fnu, but got ", scale_b.dtype());
    }

    // ---- Offsets validation --------------------------------------------------------------
    // This must run before the fp8 shape checks below: they read offs->size(0), and
    // dereferencing an empty optional is undefined behaviour.
    const bool req_offsets = (use_a_2d || use_b_2d);
    TORCH_CHECK(
        offs.has_value() == req_offsets,
        "offsets required when using 2D input tensor"
    );
    if (offs.has_value()) {
        TORCH_CHECK(offs->dim() == 1, "offsets tensor must be 1D");
        TORCH_CHECK(offs->dtype() == at::kInt, "offsets data type must be int32");
    }

    // ---- fp8 scale shape checks ----------------------------------------------------------
    if (is_fp8_a && is_fp8_b) {
        // When both matrices are 2D the scales are expected once per group, so the expected
        // length is multiplied by the number of offsets. offs is guaranteed to hold a 1D
        // tensor here because req_offsets is true whenever a matrix is 2D.
        const int64_t scale_multiplier = (use_a_2d && use_b_2d) ? offs->size(0) : 1;
        if (use_a_2d) {
            TORCH_CHECK(ndim_sa == 1, "scale_a must be 1D for 2D mat_a (KC mode), but got ", ndim_sa, "D");
            TORCH_CHECK(scale_a.is_contiguous(), "scale_a must be contiguous");
            TORCH_CHECK(scale_a.size(0) == mat_a.size(0) * scale_multiplier, "scale_a size[0] must equal ",
                        mat_a.size(0) * scale_multiplier);
        } else {
            TORCH_CHECK(ndim_sa == 2, "scale_a must be 2D for 3D mat_a, but got ", ndim_sa, "D");
            TORCH_CHECK(scale_a.stride(1) == 1, "scale_a must be contiguous in last dim");
            TORCH_CHECK(scale_a.size(0) == mat_a.size(0), "scale_a size[0] must equal mat_a batch dim (G)");
        }
        if (use_b_2d) {
            TORCH_CHECK(ndim_sb == 1, "scale_b must be 1D for 2D mat_b, but got ", ndim_sb, "D");
            TORCH_CHECK(scale_b.is_contiguous(), "scale_b must be contiguous in last dim");
            int64_t expected_n = mat_b.size(1);
            TORCH_CHECK(scale_b.size(0) == expected_n * scale_multiplier, "scale_b size mismatch");
        } else {
            TORCH_CHECK(ndim_sb == 2, "scale_b must be 2D for 3D mat_b, but got ", ndim_sb, "D");
            TORCH_CHECK(scale_b.stride(1) == 1, "scale_b must be contiguous in last dim");
            TORCH_CHECK(scale_b.size(0) == mat_b.size(0), "scale_b size[0] must equal mat_b batch dim (G)");
            TORCH_CHECK(scale_b.size(1) == mat_b.size(2), "scale_b size[1] must equal mat_b N dim");
        }
    }

    // ---- Remaining argument validation ---------------------------------------------------
    if (!use_a_2d || !use_b_2d) {
        TORCH_CHECK(
            mat_a.size(-1) == mat_b.size(-2),
            "contraction dimension mismatch between mat_a and mat_b"
        );
    }
    TORCH_CHECK(
        !bias.has_value(),
        "NPU _scaled_grouped_mm does not support bias yet"
    );
    TORCH_CHECK(
        !scale_result.has_value(),
        "NPU _scaled_grouped_mm does not support scale_result yet"
    );
    auto out_type = out_dtype.value_or(at::kBFloat16);
    TORCH_CHECK(
        out_type == at::kBFloat16,
        "_scaled_grouped_mm on NPU only supports BF16 output type"
    );

    // ---- Group type selection ------------------------------------------------------------
    // Done before any tensor is copied or expanded so that the unsupported K_SPLIT case is
    // reported with its own message instead of failing later inside tensor preparation.
    int64_t split_item_val = IN_NOT_SPLIT_OUT_SPLIT;
    c10::optional<int64_t> split_item = split_item_val;
    c10::optional<int64_t> group_type = DEFAULT_SPLIT;
    if (use_b_2d) {
        group_type = K_SPLIT;
    } else if (use_a_2d) {
        group_type = M_SPLIT;
    } else {
        if (mat_b.size(0) == 1) {
            group_type = K_SPLIT;
        } else if (mat_a.size(0) == mat_b.size(0)) {
            group_type = DEFAULT_SPLIT;
        } else {
            group_type = M_SPLIT;
        }
    }
    TORCH_CHECK(group_type != K_SPLIT,
                "K_SPLIT (group_type=2) is not supported yet. "
                "This occurs when mat_b is 2D or mat_b has only 1 weight shared by multiple groups. "
                "Current mat_a size: ", mat_a.sizes(), ", mat_b size: ", mat_b.sizes());

    // ---- Wrap the inputs into the tensor lists expected by the grouped matmul kernels ----
    std::vector<at::Tensor> x_vec;
    if (use_a_2d) {
        x_vec.push_back(mat_a);
    } else {
        // Collapse all leading dimensions of a 3D mat_a into one.
        x_vec.push_back(mat_a.reshape({-1, mat_a.size(-1)}));
    }
    at::TensorList x = at::TensorList(x_vec);
    std::vector<at::Tensor> weight_vec;
    if (use_b_2d) {
        auto b_expanded = mat_b.unsqueeze(0).expand({1, -1, -1}).contiguous();
        weight_vec.push_back(b_expanded);
    } else {
        weight_vec.push_back(mat_b);
    }
    at::TensorList weight = at::TensorList(weight_vec);
    std::vector<at::Tensor> per_token_scale_vec;
    per_token_scale_vec.push_back(scale_a);
    c10::optional<at::TensorList> per_token_scale = c10::optional<at::TensorList>(at::TensorList(per_token_scale_vec));
    std::vector<at::Tensor> scale_vec;
    if (scale_b.dim() == 2 || scale_b.dim() == 4) {
        scale_vec.push_back(scale_b);
    } else {
        scale_vec.push_back(scale_b.unsqueeze(0).expand({1, -1, -1, -1}).contiguous());
    }
    c10::optional<at::TensorList> scale = c10::optional<at::TensorList>(at::TensorList(scale_vec));
    c10::optional<at::Tensor> group_list = c10::nullopt;
    if (offs.has_value()) {
        // The offsets arrive as int32 and are handed to the kernel as int64.
        group_list = offs->to(at::kLong);
    }

    // ---- Fixed / absent optional arguments of the underlying grouped matmul --------------
    c10::optional<int64_t> group_list_type = 0;
    c10::optional<int64_t> act_type = 0;
    c10::optional<int64_t> output_dtype = static_cast<int64_t>(out_type);
    c10::optional<at::TensorList> bias_tl = c10::nullopt;
    c10::optional<at::TensorList> offset_tl = c10::nullopt;
    c10::optional<at::TensorList> antiquant_scale_tl = c10::nullopt;
    c10::optional<at::TensorList> antiquant_offset_tl = c10::nullopt;
    c10::optional<at::TensorList> activation_input_tl = c10::nullopt;
    c10::optional<at::TensorList> activation_quant_scale_tl = c10::nullopt;
    c10::optional<at::TensorList> activation_quant_offset_tl = c10::nullopt;
    c10::OptionalIntArrayRef tuning_config = c10::OptionalIntArrayRef{};
    c10::optional<int64_t> x_dtype = c10::nullopt;
    c10::optional<int64_t> weight_dtype = c10::nullopt;
    c10::optional<int64_t> scale_dtype = c10::nullopt;
    c10::optional<int64_t> per_token_scale_dtype = c10::nullopt;
    TORCH_CHECK(
        group_type.has_value(), "Requires manual passing group_type, current is None.", OPS_ERROR(ErrCode::VALUE));
    int64_t group_type_value = group_type.value();
    TORCH_CHECK(group_type_value == DEFAULT_SPLIT || group_type_value == M_SPLIT || group_type_value == K_SPLIT,
                "Use Tensor input with current cann version, "
                "The group type must be -1, 0 or 2, but now is [",
                group_type_value, "]", OPS_ERROR(ErrCode::VALUE));

    // ---- Fallback when aclnnGroupedMatmulV4 is unavailable -------------------------------
    // Kernel availability does not depend on the call arguments, so caching it is safe.
    static const bool is_grouped_matmul_V4_available = check_aclnn_kernel_available("aclnnGroupedMatmulV4");
    if (C10_UNLIKELY(!is_grouped_matmul_V4_available)) {
        TORCH_CHECK(!group_list.has_value(),
                    "group_list don't support Tensor input with current cann version. "
                    "Please update cann version to 8.0.RC3 or higher, or use List[int] as input.",
                    OPS_ERROR(ErrCode::VALUE));
        auto num_x = x.size();
        auto num_weight = weight.size();
        auto group_list_real = at::IntArrayRef{};
        size_t num_group_list = 0;
        int64_t split_item_value = split_item.value_or(0);
        check_dims(split_item_value, num_x, num_weight, num_group_list);
        std::vector<at::Tensor> y;
        c10::TensorOptions options = x[0].options().dtype(output_dtype.has_value()
            ? npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(output_dtype.value()))
            : x[0].scalar_type());
        if (split_item_value == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item_value == IN_SPLIT_OUT_NOT_SPLIT) {
            y.reserve(num_x);
            for (size_t i = 0; i < num_x; i++) {
                create_new_tensor_multi_dim(y, x[i], weight[i].size(1), options);
            }
        } else if (split_item_value == IN_NOT_SPLIT_OUT_SPLIT || split_item_value == IN_SPLIT_OUT_SPLIT) {
            if (num_x > 1) {
                size_t dim_m = 0;
                calculate_dim_m(dim_m, num_x, x);
                create_new_tensor(y, dim_m, weight[0].sizes()[1], options);
            } else if (num_x == 1) {
                create_new_tensor(y, x[0].sizes()[0], weight[0].sizes()[1], options);
            }
        }
        at::TensorList result = at::TensorList(y);
        auto bias_real = bias_tl.value_or(at::TensorList());
        auto scale_real = scale.value_or(at::TensorList());
        auto offset_real = offset_tl.value_or(at::TensorList());
        auto antiquant_scale_real = antiquant_scale_tl.value_or(at::TensorList());
        auto antiquant_offset_real = antiquant_offset_tl.value_or(at::TensorList());
        EXEC_NPU_CMD(aclnnGroupedMatmul, x, weight, bias_real, scale_real, offset_real, antiquant_scale_real,
                     antiquant_offset_real, group_list_real, split_item_value, result);
        return y[0];
    }

    // ---- Main path: V4 / V5 / WeightNz ---------------------------------------------------
    auto num_x = x.size();
    // A single 3D weight counts as weight.size(0) separate weights.
    bool singleWeight = weight.size() == 1 && weight[0].sizes().size() == 3;
    auto num_weight = singleWeight ? static_cast<size_t>(weight[0].size(0)) : static_cast<size_t>(weight.size());
    auto group_list_real = group_list.value_or(at::Tensor());
    auto num_group_list = group_list_real.size(0);
    int64_t split_item_value = split_item.value_or(0);
    check_dims(split_item_value, num_x, num_weight, num_group_list);
    std::vector<at::Tensor> y;
    c10::TensorOptions options = x[0].options().dtype(output_dtype.has_value()
        ? npu_preparation::convert_to_scalar_type(c10_npu::GetAclDataType(output_dtype.value()))
        : x[0].scalar_type());
    size_t dim_num_w = weight[0].sizes().size();
    size_t n0 = static_cast<size_t>(weight[0].size(dim_num_w - 1));
    bool weight_trans = is_weight_trans(weight[0]);
    // NOTE(review): mxfp4_valid is only declared inside the version-guarded blocks below and
    // FP4_IN_INT8 is not declared in the visible part of this file; presumably both are
    // always available in supported builds -- confirm.
#if VERSION_BETWEEN(V2R1, V2R7)
    bool mxfp4_valid = x_dtype.has_value() && weight_dtype.has_value() &&
        (x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
        x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2)) &&
        (weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2) ||
        weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1));
#endif
#if VERSION_BETWEEN(V2R8, VERSION_NEWEST)
    bool mxfp4_valid = false;
    if (x_dtype.has_value()) {
        mxfp4_valid = (x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
            x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2));
    } else {
        mxfp4_valid = x[0].scalar_type() == at::ScalarType::Float4_e2m1fn_x2;
    }
    if (weight_dtype.has_value()) {
        mxfp4_valid = mxfp4_valid &&
            (weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) ||
            weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2));
    } else {
        mxfp4_valid = mxfp4_valid && weight[0].scalar_type() == at::ScalarType::Float4_e2m1fn_x2;
    }
#endif
    size_t n_new = (mxfp4_valid && !weight_trans) ? (n0 * FP4_IN_INT8) : n0;
    if (mxfp4_valid) {
        TORCH_CHECK(x[0].size(1) != 1, "In mxfp4, dim K should not be 2.", OPS_ERROR(ErrCode::VALUE));
    }

    // ---- Output allocation ---------------------------------------------------------------
    if (split_item_value == IN_NOT_SPLIT_OUT_NOT_SPLIT || split_item_value == IN_SPLIT_OUT_NOT_SPLIT) {
        if (num_group_list > 0) {
            // group_list entries are read as cumulative values: each output takes the
            // difference to the previous entry as its first dimension.
            y.reserve(num_group_list);
            int64_t glr_value_0 = group_list_real[0].item<int64_t>();
            TORCH_CHECK(glr_value_0 >= 0, "group_list[0] should be larger than or equal to 0, but now is ", glr_value_0,
                        "." + OPS_ERROR(ErrCode::VALUE));
            create_new_tensor(y, glr_value_0, n0, options);
            int64_t glr_value_pre = glr_value_0;
            for (int i = 1; i < num_group_list; i++) {
                int64_t glr_value_cur = group_list_real[i].item<int64_t>();
                TORCH_CHECK(glr_value_cur - glr_value_pre >= 0, "group_list[", i, "] - group_list[", i - 1,
                            "] should be larger than or equal to 0, but now is ", glr_value_cur - glr_value_pre,
                            "." + OPS_ERROR(ErrCode::VALUE));
                size_t ni = singleWeight ? n0 : weight[i].size(dim_num_w - 1);
                create_new_tensor(y, glr_value_cur - glr_value_pre, ni, options);
                glr_value_pre = glr_value_cur;
            }
        } else {
            y.reserve(num_x);
            for (size_t i = 0; i < num_x; i++) {
                size_t ni = singleWeight ? n0 : weight[i].size(dim_num_w - 1);
                create_new_tensor_multi_dim(y, x[i], ni, options);
            }
        }
    } else if (split_item_value == IN_NOT_SPLIT_OUT_SPLIT || split_item_value == IN_SPLIT_OUT_SPLIT) {
        if (num_x > 1) {
            size_t dim_m = 0;
            for (size_t i = 0; i < num_x; i++) {
                dim_m += static_cast<size_t>(x[i].size(0));
            }
            weight[0].dtype() == at::ScalarType::Int ? create_new_tensor(y, dim_m, n0 * INT4_NUMS_IN_INT32, options)
                                                     : create_new_tensor(y, dim_m, n_new, options);
        } else if (num_x == 1) {
            if (group_type_value == K_SPLIT) {
                TORCH_CHECK(num_weight == 1,
                            "When group_list is 2(K_SPLIT) and split_item is 2/3, the length of weight must equal x.");
                weight[0].dtype() == at::ScalarType::Int
                    ? create_new_tensor_batch(y, num_group_list, x[0].size(0), n0 * INT4_NUMS_IN_INT32, options)
                    : create_new_tensor_batch(y, num_group_list, x[0].size(0), n_new, options);
            } else {
                (weight[0].dtype() == at::ScalarType::Int ||
                 (weight[0].dtype() == at::ScalarType::Float && weight[0].dtype() != x[0].dtype())) &&
                        (!weight_trans)
                    ? create_new_tensor(y, x[0].size(0), n0 * INT4_NUMS_IN_INT32, options)
                    : create_new_tensor(y, x[0].size(0), n_new, options);
            }
        }
    }

    // ---- Kernel arguments ----------------------------------------------------------------
    at::TensorList result = at::TensorList(y);
    auto bias_real = bias_tl.value_or(at::TensorList());
    auto scale_real = scale.value_or(at::TensorList());
    auto offset_real = offset_tl.value_or(at::TensorList());
    auto antiquant_scale_real = antiquant_scale_tl.value_or(at::TensorList());
    auto antiquant_offset_real = antiquant_offset_tl.value_or(at::TensorList());
    auto per_token_scale_real = per_token_scale.value_or(at::TensorList());
    auto activation_input_real = activation_input_tl.value_or(at::TensorList());
    auto activation_quant_scale_real = activation_quant_scale_tl.value_or(at::TensorList());
    auto activation_quant_offset_real = activation_quant_offset_tl.value_or(at::TensorList());
    auto act_out = at::TensorList();
    auto dynamic_quant_scale_out = at::TensorList();
    int64_t group_list_type_value = group_list_type.value_or(0);
    int64_t act_type_value = act_type.value_or(0);
    auto tuning_config_real = tuning_config.value_or(at::IntArrayRef{});
    // Each wrapper pairs a tensor list with the ACL dtype to report for it: an explicit
    // override when given, otherwise the dtype of the first tensor (or a fixed default
    // when the list is empty).
    TensorListWrapper x_wrapper = {x,
        x_dtype.has_value() ? c10_npu::GetAclDataType(x_dtype.value())
                            : npu_preparation::convert_to_acl_data_type(x[0].scalar_type())};
    TensorListWrapper weight_wrapper = {weight,
        weight_dtype.has_value() ? c10_npu::GetAclDataType(weight_dtype.value())
                                 : npu_preparation::convert_to_acl_data_type(weight[0].scalar_type())};
    TensorListWrapper scale_wrapper = {scale_real,
        scale_dtype.has_value()
            ? c10_npu::GetAclDataType(scale_dtype.value())
            : (scale_real.empty() ? aclDataType::ACL_UINT64
                                  : npu_preparation::convert_to_acl_data_type(scale_real[0].scalar_type()))};
    TensorListWrapper per_token_scale_wrapper = {per_token_scale_real,
        per_token_scale_dtype.has_value()
            ? c10_npu::GetAclDataType(per_token_scale_dtype.value())
            : (per_token_scale_real.empty()
                ? aclDataType::ACL_FLOAT
                : npu_preparation::convert_to_acl_data_type(per_token_scale_real[0].scalar_type()))};
    TensorListWrapper antiquant_scale_wrapper = {antiquant_scale_real,
        antiquant_scale_real.empty()
            ? aclDataType::ACL_FLOAT16
            : (antiquant_scale_real[0].scalar_type() == at::ScalarType::Byte
                ? aclDataType::ACL_FLOAT8_E8M0
                : npu_preparation::convert_to_acl_data_type(antiquant_scale_real[0].scalar_type()))};

    // ---- FRACTAL_NZ weights use a dedicated kernel ---------------------------------------
    int64_t weight_format = at_npu::native::custom_ops::get_npu_format(weight[0]);
    const bool is_weight_nz = (weight_format == ACL_FORMAT_FRACTAL_NZ) ||
        (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_2) || (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_4) ||
        (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_16);
    if (is_weight_nz) {
        static const bool is_weight_nz_available = check_aclnn_kernel_available("aclnnGroupedMatmulWeightNz");
        TORCH_CHECK(is_weight_nz_available,
                    "Format of weight in npu_grouped_matmul is FRACTAL_NZ, current CANN version "
                    "do not support with this format. Please try to update the version of CANN." +
                    OPS_ERROR(ErrCode::PARAM));
        int64_t quant_per_group_size = 0;
        EXEC_NPU_CMD(aclnnGroupedMatmulWeightNz, x_wrapper, weight_wrapper, bias_real, scale_wrapper, offset_real,
                     antiquant_scale_wrapper, antiquant_offset_real, per_token_scale_wrapper, group_list_real,
                     activation_input_real, activation_quant_scale_real, activation_quant_offset_real, split_item_value,
                     group_type_value, group_list_type_value, act_type_value, tuning_config_real, quant_per_group_size, result,
                     act_out, dynamic_quant_scale_out);
        return y[0];
    }

    // ---- V4 vs V5 dispatch ---------------------------------------------------------------
    static const bool is_grouped_matmul_V5_available = check_aclnn_kernel_available("aclnnGroupedMatmulV5");
    // Deliberately NOT static: this depends on the dtype of the current inputs, so caching
    // the first call's result would mis-dispatch later calls with a different dtype.
    const bool dtypeValid = x[0].scalar_type() != at::ScalarType::Float8_e5m2 &&
        x[0].scalar_type() != at::ScalarType::Float8_e4m3fn && !x_dtype.has_value() && !weight_dtype.has_value();
    if (!is_grouped_matmul_V5_available || !dtypeValid || mxfp4_valid) {
        EXEC_NPU_CMD(aclnnGroupedMatmulV4, x_wrapper, weight_wrapper, bias_real, scale_wrapper, offset_real,
                     antiquant_scale_real, antiquant_offset_real, per_token_scale_wrapper, group_list_real,
                     activation_input_real, activation_quant_scale_real, activation_quant_offset_real, split_item_value,
                     group_type_value, group_list_type_value, act_type_value, result, act_out, dynamic_quant_scale_out);
    } else {
        EXEC_NPU_CMD(aclnnGroupedMatmulV5, x_wrapper, weight_wrapper, bias_real, scale_wrapper, offset_real,
                     antiquant_scale_real, antiquant_offset_real, per_token_scale_wrapper, group_list_real,
                     activation_input_real, activation_quant_scale_real, activation_quant_offset_real, split_item_value,
                     group_type_value, group_list_type_value, act_type_value, tuning_config_real, result, act_out,
                     dynamic_quant_scale_out);
    }
    return y[0];
}
}