#include <cstring>
#include <set>
#include "op_plugin/OpApiInterface.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
// Tensor ranks referenced by the dimension checks in this file.
static const int DIM_ONE = 1;
static const int DIM_TWO = 2;
static const int DIM_THREE = 3;
static const int DIM_FOUR = 4;
// Quantization-mode codes that the argument checks in this file branch on.
static const int NO_QUANT_MODE = 0;
static const int PERTENSOR_QUANT_MODE = 1;
static const int MX_QUANT_MODE = 6;
// Accepted ep_world_size values. The key type is int64_t (not int) because the value being
// looked up is an int64_t: with an int-keyed set the lookup key would be narrowed first, so an
// out-of-range value such as 2^32 + 2 would wrap to 2 and be accepted.
const std::set<int64_t> SUPPORT_WORLD_SIZE_LIST{ 2, 4, 8, 16, 32, 64, 128, 256 };
static const std::set<int64_t> SUPPORT_MX_QUANT_DTYPE_LIST = { static_cast<int64_t>(c10_npu::DType::FLOAT8_E4M3FN),
static_cast<int64_t>(c10_npu::DType::FLOAT8_E5M2), static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1),
static_cast<int64_t>(at::ScalarType::Float8_e4m3fn), static_cast<int64_t>(at::ScalarType::Float8_e5m2) };
/// Returns true when `value` is a member of `list`, false otherwise.
static bool IsContains(const std::set<int64_t> &list, int64_t value)
{
    return list.count(value) != 0;
}
/**
 * Checks whether the strides of `tensor` match a layout in which the last two dimensions are
 * swapped relative to a densely packed tensor.
 *
 * Returns true only when all of the following hold:
 *  - the rank is between 2 and 6 inclusive;
 *  - the second-to-last dimension has stride 1 and the last dimension has a stride equal to the
 *    size of the second-to-last dimension;
 *  - every leading (batch) dimension has the stride of a densely packed tensor;
 *  - the last two dimensions are not both of size 1 (the layout is ambiguous in that case).
 */
static bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    const int64_t rank = tensor.dim();
    if (rank < 2 || rank > 6) {
        return false;
    }

    const int64_t last = rank - 1;
    const int64_t second_last = rank - 2;
    const int64_t last_size = tensor.size(last);
    const int64_t second_last_size = tensor.size(second_last);

    // The inner two dimensions must be stored with their roles exchanged.
    const bool inner_swapped = (tensor.stride(second_last) == 1) && (tensor.stride(last) == second_last_size);
    if (!inner_swapped) {
        return false;
    }

    // Walk the batch dimensions from innermost to outermost; each one must step over exactly
    // the number of elements contained in everything to its right.
    int64_t expected_stride = last_size * second_last_size;
    for (int64_t batch = rank - 3; batch >= 0; --batch) {
        if (tensor.stride(batch) != expected_stride) {
            return false;
        }
        expected_stride *= tensor.size(batch);
    }

    // A 1x1 trailing matrix is not reported as transposed.
    return !(last_size == 1 && second_last_size == 1);
}
/**
 * @brief Validates the arguments, allocates the outputs and launches the all-to-all-v quantized
 *        grouped matmul kernel.
 *
 * The function dispatches to `aclnnAlltoAllvQuantGroupedMatMulV2` when `comm_mode` is provided
 * (after checking that the kernel is available) and to `aclnnAlltoAllvQuantGroupedMatMul`
 * otherwise. All validation failures are raised through TORCH_CHECK.
 *
 * @param gmm_x, gmm_weight            Required tensors; must be defined, 2-D and 3-D respectively.
 * @param gmm_x_scale, gmm_weight_scale Required scale tensors. Expected rank is 1 unless the
 *                                     matching quant mode is MX, in which case it is 3
 *                                     (gmm_x_scale) or 4 (gmm_weight_scale).
 * @param hcom                         Communication group name; its data pointer is passed to the kernel.
 * @param ep_world_size                Must be a member of SUPPORT_WORLD_SIZE_LIST.
 * @param send_counts, recv_counts     Passed to the kernel; the sum of recv_counts is the row
 *                                     count of gmm_y and permute_out.
 * @param gmm_y_dtype                  Dtype code of the gmm_y output.
 * @param send_counts_tensor, recv_counts_tensor Optional tensors forwarded to the kernel
 *                                     (an empty at::Tensor is forwarded when absent).
 * @param mm_x, mm_weight              Optional pair; both or neither must be provided. When
 *                                     provided, both must be 2-D and mm_y_dtype, the mm quant modes
 *                                     and both mm scale tensors become mandatory.
 * @param mm_x_scale, mm_weight_scale  Scale tensors for the optional mm pair.
 * @param gmm_x_quant_mode, gmm_weight_quant_mode Required quant-mode codes.
 * @param mm_x_quant_mode, mm_weight_quant_mode   Quant-mode codes for the mm pair; must be None
 *                                     when the mm pair is absent.
 * @param permute_out_flag             When true a permute_out tensor is allocated.
 * @param group_size                   Optional; reduced to a single value by
 *                                     op_plugin::utils::check_and_get_group_size.
 * @param *_dtype                      Optional dtype codes describing how the matching tensor
 *                                     should be interpreted; validated per quant mode below.
 * @param comm_mode                    Optional; selects the V2 kernel when provided.
 *
 * @return (gmm_y, mm_y, permute_out). mm_y is only allocated when mm_x/mm_weight are provided and
 *         permute_out only when permute_out_flag is true; otherwise they keep their initial
 *         (nullptr-constructed) state.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_alltoallv_quant_gmm(const at::Tensor &gmm_x,
    const at::Tensor &gmm_weight, const at::Tensor &gmm_x_scale, const at::Tensor &gmm_weight_scale,
    c10::string_view hcom, int64_t ep_world_size, at::IntArrayRef send_counts, at::IntArrayRef recv_counts,
    int64_t gmm_y_dtype, const c10::optional<at::Tensor> &send_counts_tensor,
    const c10::optional<at::Tensor> &recv_counts_tensor, const c10::optional<at::Tensor> &mm_x,
    const c10::optional<at::Tensor> &mm_weight, const c10::optional<at::Tensor> &mm_x_scale,
    const c10::optional<at::Tensor> &mm_weight_scale, c10::optional<int64_t> gmm_x_quant_mode,
    c10::optional<int64_t> gmm_weight_quant_mode, c10::optional<int64_t> mm_x_quant_mode,
    c10::optional<int64_t> mm_weight_quant_mode, bool permute_out_flag, c10::OptionalIntArrayRef group_size,
    c10::optional<int64_t> gmm_x_dtype, c10::optional<int64_t> gmm_weight_dtype,
    c10::optional<int64_t> gmm_x_scale_dtype, c10::optional<int64_t> gmm_weight_scale_dtype,
    c10::optional<int64_t> mm_x_dtype, c10::optional<int64_t> mm_weight_dtype, c10::optional<int64_t> mm_x_scale_dtype,
    c10::optional<int64_t> mm_weight_scale_dtype, c10::optional<int64_t> mm_y_dtype,
    c10::optional<c10::string_view> comm_mode)
{
    // ---- Mandatory quant modes for the grouped matmul inputs ----
    TORCH_CHECK(gmm_x_quant_mode.has_value(), "The input gmm_x_quant_mode must be provided, but got None.",
        OPS_ERROR(ErrCode::PARAM));
    int64_t gmm_x_quant_mode_value = gmm_x_quant_mode.value();
    TORCH_CHECK(gmm_weight_quant_mode.has_value(), "The input gmm_weight_quant_mode must be provided, but got None.",
        OPS_ERROR(ErrCode::PARAM));
    int64_t gmm_weight_quant_mode_value = gmm_weight_quant_mode.value();

    // ---- gmm_x: rank and dtype checks ----
    TORCH_CHECK(gmm_x.defined(), "The input tensor gmm_x cannot be None.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(gmm_x.dim() == DIM_TWO, "The dim of gmm_x should be 2, but got ", gmm_x.dim(), ".",
        OPS_ERROR(ErrCode::PARAM));
    if (gmm_x_quant_mode_value == PERTENSOR_QUANT_MODE) {
        // Per-tensor mode: the declared dtype must be HIFLOAT8 carried in a uint8/int8 tensor.
        TORCH_CHECK(gmm_x_dtype.has_value(), "The input gmm_x_dtype must be provided for pertensor quant mode.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(gmm_x_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
            "The input gmm_x_dtype should be hifloat8 for pertensor quant mode, but got ",
            op_plugin::utils::DTypeToString(gmm_x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        TORCH_CHECK(gmm_x.scalar_type() == at::ScalarType::Byte || gmm_x.scalar_type() == at::ScalarType::Char,
            "The input gmm_x tensor dtype should be uint8 or int8 for pertensor quant mode, but got ",
            op_plugin::utils::DTypeToString(static_cast<int64_t>(gmm_x.scalar_type())), ".",
            OPS_ERROR(ErrCode::TYPE));
    } else if (gmm_x_quant_mode_value == MX_QUANT_MODE) {
        // MX mode: a uint8 tensor must be declared as FLOAT4_E2M1; any declared dtype must be in
        // the supported MX list.
        if (gmm_x.scalar_type() == at::ScalarType::Byte) {
            TORCH_CHECK(gmm_x_dtype.has_value(),
                "The input gmm_x_dtype must be provided for mx quant mode when gmm_x dtype is uint8.",
                OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(gmm_x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1),
                "The input gmm_x_dtype should be float4_e2m1 for mx quant mode when gmm_x dtype is uint8, but got ",
                op_plugin::utils::DTypeToString(gmm_x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        }
        if (gmm_x_dtype.has_value()) {
            TORCH_CHECK(IsContains(SUPPORT_MX_QUANT_DTYPE_LIST, gmm_x_dtype.value()),
                "The input gmm_x_dtype must be float8_e4m3_fn, float8_e5m2 or float4_e2m1 for mx quant mode, but got ",
                op_plugin::utils::DTypeToString(gmm_x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        }
    }

    // ---- gmm_weight: rank and dtype checks ----
    TORCH_CHECK(gmm_weight.defined(), "The input tensor gmm_weight cannot be None.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(gmm_weight.dim() == DIM_THREE, "The dim of gmm_weight should be 3, but got ", gmm_weight.dim(), ".",
        OPS_ERROR(ErrCode::PARAM));
    if (gmm_weight_quant_mode_value == PERTENSOR_QUANT_MODE) {
        TORCH_CHECK(gmm_weight_dtype.has_value(),
            "The input gmm_weight_dtype must be provided for pertensor quant mode.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(gmm_weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
            "The input gmm_weight_dtype should be hifloat8 for pertensor quant mode, but got ",
            op_plugin::utils::DTypeToString(gmm_weight_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        TORCH_CHECK(gmm_weight.scalar_type() == at::ScalarType::Byte ||
            gmm_weight.scalar_type() == at::ScalarType::Char,
            "The input gmm_weight tensor dtype should be uint8 or int8 for pertensor quant mode, but got ",
            op_plugin::utils::DTypeToString(static_cast<int64_t>(gmm_weight.scalar_type())), ".",
            OPS_ERROR(ErrCode::TYPE));
    } else if (gmm_weight_quant_mode_value == MX_QUANT_MODE) {
        if (gmm_weight.scalar_type() == at::ScalarType::Byte) {
            TORCH_CHECK(gmm_weight_dtype.has_value(),
                "The input gmm_weight_dtype must be provided for mx quant mode when gmm_weight dtype is uint8.",
                OPS_ERROR(ErrCode::PARAM));
            // The message names the dtype "float4_e2m1", matching the enum checked and the sibling messages.
            TORCH_CHECK(gmm_weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1),
                "The input gmm_weight_dtype should be float4_e2m1 for mx quant mode when gmm_weight dtype is uint8, "
                "but got ",
                op_plugin::utils::DTypeToString(gmm_weight_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        }
        if (gmm_weight_dtype.has_value()) {
            TORCH_CHECK(IsContains(SUPPORT_MX_QUANT_DTYPE_LIST, gmm_weight_dtype.value()),
                "The input gmm_weight_dtype must be float8_e4m3_fn, float8_e5m2 or float4_e2m1 for mx quant mode, but "
                "got ",
                op_plugin::utils::DTypeToString(gmm_weight_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        }
    }

    // ---- gmm_x_scale: rank depends on the quant mode (3 for MX, 1 otherwise) ----
    TORCH_CHECK(gmm_x_scale.defined(), "The input tensor gmm_x_scale cannot be None.", OPS_ERROR(ErrCode::PARAM));
    int gmm_x_scale_dim = (gmm_x_quant_mode_value == MX_QUANT_MODE) ? DIM_THREE : DIM_ONE;
    TORCH_CHECK(gmm_x_scale.dim() == gmm_x_scale_dim, "The dim of gmm_x_scale should be ", gmm_x_scale_dim,
        ", but got ", gmm_x_scale.dim(), ".", OPS_ERROR(ErrCode::PARAM));
    if (gmm_x_quant_mode_value == PERTENSOR_QUANT_MODE) {
        if (gmm_x_scale_dtype.has_value()) {
            TORCH_CHECK(gmm_x_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT) ||
                gmm_x_scale_dtype.value() == static_cast<int64_t>(at::ScalarType::Float),
                "The input gmm_x_scale_dtype should be float32 for pertensor quant mode, but got ",
                op_plugin::utils::DTypeToString(gmm_x_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        }
        TORCH_CHECK(gmm_x_scale.scalar_type() == at::ScalarType::Float,
            "The input gmm_x_scale tensor dtype should be float32 for pertensor quant mode, but got ",
            op_plugin::utils::DTypeToString(static_cast<int64_t>(gmm_x_scale.scalar_type())), ".",
            OPS_ERROR(ErrCode::TYPE));
    } else if (gmm_x_quant_mode_value == MX_QUANT_MODE) {
        TORCH_CHECK(gmm_x_scale_dtype.has_value(), "The input gmm_x_scale_dtype must be provided for mx quant mode.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(gmm_x_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
            "The input gmm_x_scale_dtype should be float8_e8m0 for mx quant mode, but got ",
            op_plugin::utils::DTypeToString(gmm_x_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
    }

    // ---- gmm_weight_scale: rank depends on the quant mode (4 for MX, 1 otherwise) ----
    TORCH_CHECK(gmm_weight_scale.defined(), "The input tensor gmm_weight_scale cannot be None.",
        OPS_ERROR(ErrCode::PARAM));
    int gmm_weight_scale_dim = (gmm_weight_quant_mode_value == MX_QUANT_MODE) ? DIM_FOUR : DIM_ONE;
    TORCH_CHECK(gmm_weight_scale.dim() == gmm_weight_scale_dim, "The dim of gmm_weight_scale should be ",
        gmm_weight_scale_dim, ", but got ", gmm_weight_scale.dim(), ".", OPS_ERROR(ErrCode::PARAM));
    if (gmm_weight_quant_mode_value == PERTENSOR_QUANT_MODE) {
        if (gmm_weight_scale_dtype.has_value()) {
            TORCH_CHECK(gmm_weight_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT) ||
                gmm_weight_scale_dtype.value() == static_cast<int64_t>(at::ScalarType::Float),
                "The input gmm_weight_scale_dtype should be float32 for pertensor quant mode, but got ",
                op_plugin::utils::DTypeToString(gmm_weight_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
        }
        TORCH_CHECK(gmm_weight_scale.scalar_type() == at::ScalarType::Float,
            "The input gmm_weight_scale tensor dtype should be float32 for pertensor quant mode, but got ",
            op_plugin::utils::DTypeToString(static_cast<int64_t>(gmm_weight_scale.scalar_type())), ".",
            OPS_ERROR(ErrCode::TYPE));
    } else if (gmm_weight_quant_mode_value == MX_QUANT_MODE) {
        TORCH_CHECK(gmm_weight_scale_dtype.has_value(),
            "The input gmm_weight_scale_dtype must be provided for mx quant mode.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(gmm_weight_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
            "The input gmm_weight_scale_dtype should be float8_e8m0 for mx quant mode, but got ",
            op_plugin::utils::DTypeToString(gmm_weight_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
    }

    // ---- Communication parameters ----
    // NOTE(review): string_view::data() is not guaranteed to be null-terminated; presumably the
    // caller always passes a view over a null-terminated string - confirm.
    char *hcom_ptr = const_cast<char *>(hcom.data());
    TORCH_CHECK(SUPPORT_WORLD_SIZE_LIST.find(ep_world_size) != SUPPORT_WORLD_SIZE_LIST.end(),
        "The input ep_world_size should be in [2, 4, 8, 16, 32, 64, 128, 256], but got ", ep_world_size, ".",
        OPS_ERROR(ErrCode::VALUE));

    // Absent optional tensors are forwarded to the kernel as default-constructed tensors.
    const at::Tensor &send_count_tensor_real = send_counts_tensor.value_or(at::Tensor());
    const at::Tensor &recv_count_tensor_real = recv_counts_tensor.value_or(at::Tensor());
    const at::Tensor &mm_x_real = mm_x.value_or(at::Tensor());
    const at::Tensor &mm_weight_real = mm_weight.value_or(at::Tensor());
    const at::Tensor &mm_x_scale_real = mm_x_scale.value_or(at::Tensor());
    const at::Tensor &mm_weight_scale_real = mm_weight_scale.value_or(at::Tensor());

    // ---- gmm_y allocation ----
    // Both gmm inputs being declared FLOAT4_E2M1 under MX mode marks the "mxfp4" case.
    bool is_gmm_mxfp4 = false;
    if ((gmm_x_quant_mode_value == MX_QUANT_MODE) && (gmm_weight_quant_mode_value == MX_QUANT_MODE)) {
        if ((gmm_x_dtype.has_value() && gmm_weight_dtype.has_value()) &&
            (gmm_x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) &&
            gmm_weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1))) {
            is_gmm_mxfp4 = true;
        }
    }
    bool gmm_weight_trans = is_transpose_last_two_dims(gmm_weight);
    int64_t n1 = gmm_weight.size(DIM_TWO);
    if (is_gmm_mxfp4 && !gmm_weight_trans) {
        // FP4_IN_INT8 is defined outside this view; presumably the number of FP4 values packed
        // into one stored byte, so the logical N is larger than the stored size - confirm.
        n1 *= FP4_IN_INT8;
    }
    // Output row count is the total number of received entries.
    int64_t recv_total = 0;
    for (const auto count : recv_counts) {
        recv_total += count;
    }
    aclDataType gmm_y_acltype = c10_npu::GetAclDataType(gmm_y_dtype);
    at::ScalarType gmm_y_scalar_dtype = at_npu::native::OpPreparation::convert_to_scalar_type(gmm_y_acltype);
    auto gmm_y =
        at_npu::native::OpPreparation::apply_tensor_without_format({ recv_total, n1 }, c10::dtype(gmm_y_scalar_dtype));

    // ---- Optional mm branch ----
    at::Tensor mm_y{ nullptr };
    int64_t mm_x_quant_mode_value = NO_QUANT_MODE;
    int64_t mm_weight_quant_mode_value = NO_QUANT_MODE;
    bool mm_x_has_value = mm_x.has_value();
    bool mm_weight_has_value = mm_weight.has_value();
    if (mm_x_has_value != mm_weight_has_value) {
        std::string error_msg = "The input mm_x and mm_weight must be both provided or both None, but the input ";
        if (mm_x_has_value && !mm_weight_has_value) {
            error_msg += "mm_x is provided and mm_weight is None.";
        } else if (!mm_x_has_value && mm_weight_has_value) {
            error_msg += "mm_weight is provided and mm_x is None.";
        }
        TORCH_CHECK(false, error_msg, OPS_ERROR(ErrCode::PARAM));
    }
    if (mm_x.has_value() && mm_weight.has_value()) {
        // With the mm pair present, all companion arguments become mandatory.
        TORCH_CHECK(mm_y_dtype.has_value(),
            "The input mm_y_dtype must be provided when mm_x and mm_weight are present.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(mm_x_quant_mode.has_value(), "The input mm_x_quant_mode must be provided when mm_x is present.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(mm_weight_quant_mode.has_value(),
            "The input mm_weight_quant_mode must be provided when mm_weight is present.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(mm_x_scale.has_value(), "The input mm_x_scale must be provided when mm_x is present.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(mm_weight_scale.has_value(),
            "The input mm_weight_scale must be provided when mm_weight is present.", OPS_ERROR(ErrCode::PARAM));
        mm_x_quant_mode_value = mm_x_quant_mode.value();
        mm_weight_quant_mode_value = mm_weight_quant_mode.value();
        const at::Tensor &mm_x_value = mm_x.value();
        const at::Tensor &mm_weight_value = mm_weight.value();

        // mm_x: rank and dtype checks (mirrors the gmm_x checks).
        TORCH_CHECK(mm_x_value.dim() == DIM_TWO, "The dim of mm_x should be 2, but got ", mm_x_value.dim(), ".",
            OPS_ERROR(ErrCode::PARAM));
        if (mm_x_quant_mode_value == PERTENSOR_QUANT_MODE) {
            TORCH_CHECK(mm_x_dtype.has_value(), "The input mm_x_dtype must be provided for pertensor quant mode.",
                OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(mm_x_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
                "The input mm_x_dtype should be hifloat8 for pertensor quant mode, but got ",
                op_plugin::utils::DTypeToString(mm_x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            TORCH_CHECK(mm_x_value.scalar_type() == at::ScalarType::Byte ||
                mm_x_value.scalar_type() == at::ScalarType::Char,
                "The input mm_x tensor dtype should be uint8 or int8 for pertensor quant mode, but got ",
                op_plugin::utils::DTypeToString(static_cast<int64_t>(mm_x_value.scalar_type())), ".",
                OPS_ERROR(ErrCode::TYPE));
        } else if (mm_x_quant_mode_value == MX_QUANT_MODE) {
            if (mm_x_value.scalar_type() == at::ScalarType::Byte) {
                TORCH_CHECK(mm_x_dtype.has_value(),
                    "The input mm_x_dtype must be provided for mx quant mode when mm_x dtype is uint8.",
                    OPS_ERROR(ErrCode::PARAM));
                // The message refers to mm_x, the tensor actually inspected here.
                TORCH_CHECK(mm_x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1),
                    "The input mm_x_dtype should be float4_e2m1 for mx quant mode when mm_x dtype is uint8, but got ",
                    op_plugin::utils::DTypeToString(mm_x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            }
            if (mm_x_dtype.has_value()) {
                TORCH_CHECK(IsContains(SUPPORT_MX_QUANT_DTYPE_LIST, mm_x_dtype.value()),
                    "The input mm_x_dtype must be float8_e4m3_fn, float8_e5m2 or float4_e2m1 for mx quant mode, but "
                    "got ",
                    op_plugin::utils::DTypeToString(mm_x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            }
        }

        // mm_weight: rank and dtype checks (mirrors the gmm_weight checks, but 2-D).
        TORCH_CHECK(mm_weight_value.dim() == DIM_TWO, "The dim of mm_weight should be 2, but got ",
            mm_weight_value.dim(), ".", OPS_ERROR(ErrCode::PARAM));
        if (mm_weight_quant_mode_value == PERTENSOR_QUANT_MODE) {
            TORCH_CHECK(mm_weight_dtype.has_value(),
                "The input mm_weight_dtype must be provided for pertensor quant mode.", OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(mm_weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
                "The input mm_weight_dtype should be hifloat8 for pertensor quant mode, but got ",
                op_plugin::utils::DTypeToString(mm_weight_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            TORCH_CHECK(mm_weight_value.scalar_type() == at::ScalarType::Byte ||
                mm_weight_value.scalar_type() == at::ScalarType::Char,
                "The input mm_weight tensor dtype should be uint8 or int8 for pertensor quant mode, but got ",
                op_plugin::utils::DTypeToString(static_cast<int64_t>(mm_weight_value.scalar_type())), ".",
                OPS_ERROR(ErrCode::TYPE));
        } else if (mm_weight_quant_mode_value == MX_QUANT_MODE) {
            if (mm_weight_value.scalar_type() == at::ScalarType::Byte) {
                TORCH_CHECK(mm_weight_dtype.has_value(),
                    "The input mm_weight_dtype must be provided for mx quant mode when mm_weight dtype is uint8.",
                    OPS_ERROR(ErrCode::PARAM));
                TORCH_CHECK(mm_weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1),
                    "The input mm_weight_dtype should be float4_e2m1 for mx quant mode when mm_weight dtype is uint8, "
                    "but got ",
                    op_plugin::utils::DTypeToString(mm_weight_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            }
            if (mm_weight_dtype.has_value()) {
                TORCH_CHECK(IsContains(SUPPORT_MX_QUANT_DTYPE_LIST, mm_weight_dtype.value()),
                    "The input mm_weight_dtype must be float8_e4m3_fn, float8_e5m2 or float4_e2m1 for mx quant mode, "
                    "but got ",
                    op_plugin::utils::DTypeToString(mm_weight_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            }
        }

        // mm_x_scale: rank is 3 for MX mode and 1 otherwise.
        if (mm_x_scale.has_value()) {
            const at::Tensor &mm_x_scale_val = mm_x_scale.value();
            TORCH_CHECK(mm_x_scale_val.dim() == ((mm_x_quant_mode_value == MX_QUANT_MODE) ? DIM_THREE : DIM_ONE),
                "The dim of mm_x_scale should be ", ((mm_x_quant_mode_value == MX_QUANT_MODE) ? DIM_THREE : DIM_ONE),
                " for quant_mode ", mm_x_quant_mode_value, ", but got ", mm_x_scale_val.dim(), ".",
                OPS_ERROR(ErrCode::PARAM));
            if (mm_x_quant_mode_value == PERTENSOR_QUANT_MODE) {
                if (mm_x_scale_dtype.has_value()) {
                    TORCH_CHECK(mm_x_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT) ||
                        mm_x_scale_dtype.value() == static_cast<int64_t>(at::ScalarType::Float),
                        "The input mm_x_scale_dtype should be float32 for pertensor quant mode, but got ",
                        op_plugin::utils::DTypeToString(mm_x_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
                }
                TORCH_CHECK(mm_x_scale_val.scalar_type() == at::ScalarType::Float,
                    "The input mm_x_scale tensor dtype should be float32 for pertensor quant mode, but got ",
                    op_plugin::utils::DTypeToString(static_cast<int64_t>(mm_x_scale_val.scalar_type())), ".",
                    OPS_ERROR(ErrCode::TYPE));
            } else if (mm_x_quant_mode_value == MX_QUANT_MODE) {
                // Only enforce the MX scale dtype when the mode really is MX, as the other three
                // scale checks in this function do; other modes are left to the kernel.
                TORCH_CHECK(mm_x_scale_dtype.has_value(), "mm_x_scale_dtype must be provided for mx quant mode.",
                    OPS_ERROR(ErrCode::PARAM));
                TORCH_CHECK(mm_x_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
                    "mm_x_scale_dtype should be float8_e8m0 for mx quant mode, but got ",
                    op_plugin::utils::DTypeToString(mm_x_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            }
        }

        // mm_weight_scale: rank is 3 for MX mode and 1 otherwise.
        if (mm_weight_scale.has_value()) {
            const at::Tensor &mm_weight_scale_val = mm_weight_scale.value();
            TORCH_CHECK(mm_weight_scale_val.dim() ==
                ((mm_weight_quant_mode_value == MX_QUANT_MODE) ? DIM_THREE : DIM_ONE),
                "The dim of mm_weight_scale should be ",
                ((mm_weight_quant_mode_value == MX_QUANT_MODE) ? DIM_THREE : DIM_ONE), " for quant_mode ",
                mm_weight_quant_mode_value, ", but got ", mm_weight_scale_val.dim(), ".", OPS_ERROR(ErrCode::PARAM));
            if (mm_weight_quant_mode_value == PERTENSOR_QUANT_MODE) {
                if (mm_weight_scale_dtype.has_value()) {
                    TORCH_CHECK(mm_weight_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT) ||
                        mm_weight_scale_dtype.value() == static_cast<int64_t>(at::ScalarType::Float),
                        "The input mm_weight_scale_dtype should be float32 for pertensor quant mode, but got ",
                        op_plugin::utils::DTypeToString(mm_weight_scale_dtype.value()), ".",
                        OPS_ERROR(ErrCode::VALUE));
                }
                TORCH_CHECK(mm_weight_scale_val.scalar_type() == at::ScalarType::Float,
                    "The input mm_weight_scale tensor dtype should be float32 for pertensor quant mode, but got ",
                    op_plugin::utils::DTypeToString(static_cast<int64_t>(mm_weight_scale_val.scalar_type())), ".",
                    OPS_ERROR(ErrCode::TYPE));
            } else if (mm_weight_quant_mode_value == MX_QUANT_MODE) {
                TORCH_CHECK(mm_weight_scale_dtype.has_value(),
                    "The input mm_weight_scale_dtype must be provided for mx quant mode.", OPS_ERROR(ErrCode::PARAM));
                TORCH_CHECK(mm_weight_scale_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0),
                    "The input mm_weight_scale_dtype should be float8_e8m0 for mx quant mode, but got ",
                    op_plugin::utils::DTypeToString(mm_weight_scale_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
            }
        }

        // mm_y allocation, following the same shape rule as gmm_y.
        bool is_mm_mxfp4 = false;
        if ((mm_x_quant_mode_value == MX_QUANT_MODE) && (mm_weight_quant_mode_value == MX_QUANT_MODE)) {
            if ((mm_x_dtype.has_value() && mm_weight_dtype.has_value()) &&
                (mm_x_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1) &&
                mm_weight_dtype.value() == static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1))) {
                is_mm_mxfp4 = true;
            }
        }
        bool mm_weight_trans = is_transpose_last_two_dims(mm_weight.value_or(at::Tensor()));
        int64_t bs = mm_x_value.size(0);
        int64_t n2 = mm_weight_value.size(1);
        if (is_mm_mxfp4 && !mm_weight_trans) {
            n2 *= FP4_IN_INT8;
        }
        aclDataType mm_y_acltype = c10_npu::GetAclDataType(mm_y_dtype.value());
        at::ScalarType mm_y_scalar_dtype = at_npu::native::OpPreparation::convert_to_scalar_type(mm_y_acltype);
        mm_y = at_npu::native::OpPreparation::apply_tensor_without_format({ bs, n2 }, c10::dtype(mm_y_scalar_dtype));
    } else {
        // Without the mm pair, none of the mm-specific modes or dtypes may be supplied.
        TORCH_CHECK(!mm_x_quant_mode.has_value(), "The input mm_x_quant_mode should be None when mm_x is not provided.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!mm_x_dtype.has_value(), "The input mm_x_dtype should be None when mm_x is not provided.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!mm_x_scale_dtype.has_value(),
            "The input mm_x_scale_dtype should be None when mm_x is not provided.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!mm_weight_quant_mode.has_value(),
            "The input mm_weight_quant_mode should be None when mm_weight is not provided.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!mm_weight_dtype.has_value(),
            "The input mm_weight_dtype should be None when mm_weight is not provided.", OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!mm_weight_scale_dtype.has_value(),
            "The input mm_weight_scale_dtype should be None when mm_weight is not provided.",
            OPS_ERROR(ErrCode::PARAM));
    }

    // ---- Optional permute_out allocation: {sum(recv_counts), gmm_x.size(1)} in gmm_x's dtype ----
    at::Tensor permute_out{ nullptr };
    if (permute_out_flag) {
        int64_t h1 = gmm_x.size(1);
        aclDataType permute_out_acltype;
        if (gmm_x_dtype.has_value()) {
            permute_out_acltype = c10_npu::GetAclDataType(gmm_x_dtype.value());
        } else {
            permute_out_acltype = at_npu::native::OpPreparation::convert_to_acl_data_type(gmm_x.scalar_type());
        }
        at::ScalarType permute_out_scalar_dtype =
            at_npu::native::OpPreparation::convert_to_scalar_type(permute_out_acltype);
        permute_out = at_npu::native::OpPreparation::apply_tensor_without_format({ recv_total, h1 },
            c10::dtype(permute_out_scalar_dtype));
    }

    at::IntArrayRef group_size_list = group_size.value_or(at::IntArrayRef{});
    int64_t group_sizes = op_plugin::utils::check_and_get_group_size(group_size_list);

    // ---- Pair each tensor with its declared dtype for the kernel call ----
    TensorWrapper gmm_x_wrapper = make_wrapper(gmm_x, gmm_x_dtype);
    TensorWrapper gmm_weight_wrapper = make_wrapper(gmm_weight, gmm_weight_dtype);
    TensorWrapper gmm_x_scale_wrapper = make_wrapper(gmm_x_scale, gmm_x_scale_dtype);
    TensorWrapper gmm_weight_scale_wrapper = make_wrapper(gmm_weight_scale, gmm_weight_scale_dtype);
    TensorWrapper mm_x_wrapper = make_wrapper(mm_x_real, mm_x_dtype);
    TensorWrapper mm_weight_wrapper = make_wrapper(mm_weight_real, mm_weight_dtype);
    TensorWrapper mm_x_scale_wrapper = make_wrapper(mm_x_scale_real, mm_x_scale_dtype);
    TensorWrapper mm_weight_scale_wrapper = make_wrapper(mm_weight_scale_real, mm_weight_scale_dtype);
    TensorWrapper permute_out_wrapper = { permute_out, (gmm_x_dtype.has_value()) ?
        c10_npu::GetAclDataType(gmm_x_dtype.value()) :
        at_npu::native::OpPreparation::convert_to_acl_data_type(gmm_x.scalar_type()) };

    // The transpose flags are always passed as false from this entry point.
    bool transposeGmmWeight = false;
    bool transposeMmWeight = false;

    // ---- Kernel dispatch: V2 when comm_mode is supplied, V1 otherwise ----
    if (comm_mode.has_value()) {
        TORCH_CHECK(check_aclnn_kernel_available("aclnnAlltoAllvQuantGroupedMatMulV2"),
            "Current CANN version do not support this api. Please try to update the version of CANN.",
            OPS_ERROR(ErrCode::PARAM));
        // NOTE(review): same null-termination caveat as hcom_ptr above - confirm.
        const char* comm_mode_real = const_cast<char *>(comm_mode.value().data());
        EXEC_NPU_CMD(aclnnAlltoAllvQuantGroupedMatMulV2, gmm_x_wrapper, gmm_weight_wrapper, gmm_x_scale_wrapper,
            gmm_weight_scale_wrapper, send_count_tensor_real, recv_count_tensor_real, mm_x_wrapper, mm_weight_wrapper,
            mm_x_scale_wrapper, mm_weight_scale_wrapper, gmm_x_quant_mode_value, gmm_weight_quant_mode_value,
            mm_x_quant_mode_value, mm_weight_quant_mode_value, hcom_ptr, comm_mode_real, ep_world_size, send_counts,
            recv_counts, transposeGmmWeight, transposeMmWeight, group_sizes, permute_out_flag,
            gmm_y, mm_y, permute_out_wrapper);
    } else {
        EXEC_NPU_CMD(aclnnAlltoAllvQuantGroupedMatMul, gmm_x_wrapper, gmm_weight_wrapper, gmm_x_scale_wrapper,
            gmm_weight_scale_wrapper, send_count_tensor_real, recv_count_tensor_real, mm_x_wrapper, mm_weight_wrapper,
            mm_x_scale_wrapper, mm_weight_scale_wrapper, gmm_x_quant_mode_value, gmm_weight_quant_mode_value,
            mm_x_quant_mode_value, mm_weight_quant_mode_value, hcom_ptr, ep_world_size, send_counts, recv_counts,
            transposeGmmWeight, transposeMmWeight, group_sizes, permute_out_flag, gmm_y, mm_y, permute_out_wrapper);
    }
    return std::tie(gmm_y, mm_y, permute_out);
}
}