/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
// File-local constants used by the MoE init-routing wrapper in this file.
namespace {
// Ranks the wrapper's TORCH_CHECKs require for x / expert_idx (DIM_X is also
// the rank required of scale in INT4 dynamic quantization).
constexpr int64_t DIM_X = 2;
constexpr int64_t DIM_EXPERT_IDX = 2;
// Number of entries active_expert_range must contain.
constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE = 2;

// Values compared against expert_tokens_num_type.
constexpr int64_t EXPERT_TOKENS_COUNT = 1;
constexpr int64_t EXPERT_TOKENS_KEY_VALUE = 2;

// Values accepted for quant_mode.
constexpr int64_t QUANT_MODE_UNQUANT = -1;
constexpr int64_t QUANT_MODE_STATIC = 0;
constexpr int64_t QUANT_MODE_DYNAMIC = 1;
constexpr int64_t QUANT_MODE_MXFP8_E5M2 = 2;
constexpr int64_t QUANT_MODE_MXFP8_E4M3FN = 3;
constexpr int64_t QUANT_MODE_HIF8_CAST = 6;
constexpr int64_t QUANT_MODE_HIF8_PERTENSOR = 7;
constexpr int64_t QUANT_MODE_HIF8_PER_TOKEN_DIM = 8;
constexpr int64_t QUANT_MODE_MXFP4_E2M1 = 9;
constexpr int64_t QUANT_MODE_FP8_PERBLOCK_E5M2 = 11;
constexpr int64_t QUANT_MODE_FP8_PERBLOCK_E4M3FN = 12;
constexpr int64_t QUANT_MODE_INT4_DYNAMIC = 13;

// Block sizes and packing factors used when sizing the quantized outputs and
// their scale tensors.
constexpr int64_t MXQUANT_BLOCK_SIZE = 32;
constexpr int64_t FP8_QUANT_BLOCK_SIZE = 128;
constexpr int64_t PAD_TO_EVEN_FACTOR = 2;
constexpr int64_t INT4_NUMS_IN_INT8 = 2;

// The one configuration CheckV2Case matches: expert_num, the two bounds of
// active_expert_range, and the hidden size.
constexpr int64_t EXPERT_NUM_V2 = 128;
constexpr int64_t EXPERT_NUM_MIN_V2 = 0;
constexpr int64_t EXPERT_NUM_MAX_V2 = 128;
constexpr int64_t HIDDEN_DIM_VAL_V2 = 2048;
}  // namespace
// True when quantMode is the MXFP4 (E2M1) mode.
inline bool IsQuantModeMXFP4(int64_t quantMode)
{
    const bool isMxfp4 = (QUANT_MODE_MXFP4_E2M1 == quantMode);
    return isMxfp4;
}
// True when quantMode is one of the two MXFP8 modes (E5M2 or E4M3FN).
inline bool IsQuantModeMXFP8(int64_t quantMode)
{
    switch (quantMode) {
        case QUANT_MODE_MXFP8_E5M2:
        case QUANT_MODE_MXFP8_E4M3FN:
            return true;
        default:
            return false;
    }
}
// True when quantMode is one of the two per-block FP8 modes (E5M2 or E4M3FN).
inline bool IsQuantModeFP8(int64_t quantMode)
{
    switch (quantMode) {
        case QUANT_MODE_FP8_PERBLOCK_E5M2:
        case QUANT_MODE_FP8_PERBLOCK_E4M3FN:
            return true;
        default:
            return false;
    }
}
// True when quantMode is any of the three HIF8 modes (cast, per-tensor, per-token-dim).
inline bool IsQuantModeHIF8(int64_t quantMode)
{
    switch (quantMode) {
        case QUANT_MODE_HIF8_CAST:
        case QUANT_MODE_HIF8_PERTENSOR:
        case QUANT_MODE_HIF8_PER_TOKEN_DIM:
            return true;
        default:
            return false;
    }
}
// True only when xDtype is present and equals the integer code of c10_npu::DType::INT4.
inline bool IsInt4OutputDType(c10::optional<int64_t> xDtype)
{
    if (!xDtype.has_value()) {
        return false;
    }
    const int64_t int4Code = static_cast<int64_t>(c10_npu::DType::INT4);
    return xDtype.value() == int4Code;
}
// True when quantMode is the INT4 dynamic mode and xDtype is either absent or INT4.
inline bool IsDynamicQuantInt4Output(int64_t quantMode, c10::optional<int64_t> xDtype)
{
    if (quantMode != QUANT_MODE_INT4_DYNAMIC) {
        return false;
    }
    if (!xDtype.has_value()) {
        return true;
    }
    return IsInt4OutputDType(xDtype);
}
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
using tensor_list = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>;
/**
 * Resolves the effective active-expert range.
 *
 * @param active_expert_range  Caller-supplied range; returned unchanged when non-empty.
 * @param expert_num           Upper bound used for the default range.
 * @return active_expert_range itself when it is non-empty, otherwise a view of {0, expert_num}.
 *
 * The default is kept in thread-local storage because at::IntArrayRef does not own its
 * elements. The storage is rewritten on every call, so a returned default view is only
 * valid until the next call on the same thread. (The previous implementation used a
 * plain function-local static that was initialised once, so later calls with a different
 * expert_num silently received the first call's value.)
 */
at::IntArrayRef init_new_active_expert_range(at::IntArrayRef &active_expert_range, int64_t expert_num)
{
    if (!active_expert_range.empty()) {
        return active_expert_range;
    }
    static thread_local std::vector<int64_t> default_active_expert_range = {0, 0};
    // Refresh in place on every call; the vector keeps its size so no reallocation happens.
    default_active_expert_range[0] = 0;
    default_active_expert_range[1] = expert_num;
    return at::IntArrayRef(default_active_expert_range);
}
static bool CheckV2Case(int hidden_dim, int64_t expert_num, at::IntArrayRef active_expert_range,
int64_t expert_tokens_num_type, int64_t quant_mode) {
if (expert_num == EXPERT_NUM_V2 && active_expert_range[0] == EXPERT_NUM_MIN_V2 &&
active_expert_range[1] == EXPERT_NUM_MAX_V2 && hidden_dim == HIDDEN_DIM_VAL_V2) {
if (quant_mode == -1 && expert_tokens_num_type == 1) {
return true;
}
}
return false;
}
/**
 * Validates the arguments, allocates the output tensors and launches the MoE
 * init-routing kernel (aclnnMoeInitRoutingV2 or aclnnMoeInitRoutingV3).
 *
 * @param x                       2-D input; x.size(1) is the hidden size.
 * @param expert_idx              2-D expert index tensor.
 * @param scale                   Optional scale tensor.
 * @param offset                  Optional offset tensor (rejected for quant_mode 13).
 * @param active_num              When > 0 (and drop_pad_mode != 1) caps the number of output rows.
 * @param expert_capacity         Per-expert capacity used for the drop_pad_mode == 1 layout.
 * @param expert_num              Total number of experts.
 * @param drop_pad_mode           1 selects the {expert_num, expert_capacity, h} output layout.
 * @param expert_tokens_num_type  Selects the shape of the expert-token statistics output.
 * @param expert_tokens_num_flag  Forwarded to the V3 kernel.
 * @param quant_mode              One of the QUANT_MODE_* values.
 * @param active_expert_range     Two-entry range; empty means {0, expert_num}.
 * @param row_idx_type            Forwarded to the V3 kernel.
 * @param x_dtype                 Optional dtype code overriding the dtype reported for x / expanded_x.
 * @return (expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale).
 *         On the aclnnMoeInitRoutingV2 path the fourth element is
 *         expert_tokens_before_capacity instead of expanded_scale.
 */
tensor_list npu_moe_init_routing_v2(const at::Tensor &x, const at::Tensor &expert_idx,
    const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
    int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type,
    bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type,
    c10::optional<int64_t> x_dtype)
{
#if !VERSION_BETWEEN(V2R7, VERSION_NEWEST)
    TORCH_CHECK(!IsQuantModeMXFP8(quant_mode),
        "Unsupported quant_mode:",
        quant_mode,
        " on this version of torch with torch_npu. Please upgrade to at least v2.7.",
        OPS_ERROR(ErrCode::PARAM));
#endif

    // ---- Rank and range validation ----
    int64_t x_dim = x.dim();
    TORCH_CHECK(x_dim == DIM_X,
        "The x should be ",
        DIM_X,
        "-Dimension, current is ",
        x_dim,
        "-Dimension.",
        OPS_ERROR(ErrCode::PARAM));
    int64_t expert_idx_dim = expert_idx.dim();
    TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX,
        "The expert_idx should be ",
        DIM_EXPERT_IDX,
        "-Dimension, current is ",
        expert_idx_dim,
        "-Dimension.",
        OPS_ERROR(ErrCode::PARAM));
    at::IntArrayRef current_active_expert_range = init_new_active_expert_range(active_expert_range, expert_num);
    int64_t active_expert_range_length = current_active_expert_range.size();
    TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE,
        "The length of list active_expert_range should be ",
        LENGTH_ACTIVE_EXPERT_RANGE,
        ", current is ",
        active_expert_range_length,
        ".",
        OPS_ERROR(ErrCode::PARAM));
    // Sizes are kept in int64_t so products such as bs * k are not computed in 32 bits.
    const int64_t expert_length = current_active_expert_range[1] - current_active_expert_range[0];
    auto x_size = x.sizes();
    auto expert_idx_size = expert_idx.sizes();
    const at::Tensor &p_scale = c10::value_or_else(scale, [] { return at::Tensor(); });
    const at::Tensor &p_offset = c10::value_or_else(offset, [] { return at::Tensor(); });
    const int64_t bs = x_size[0];
    int64_t h = x_size[1];
    aclDataType x_acl_type = c10_npu::GetAclDataType(x_dtype.value_or(static_cast<int64_t>(x.scalar_type())));
    if (x_acl_type == aclDataType::ACL_FLOAT4_E2M1) {
        // Presumably two FP4 values are packed per stored element of x — TODO confirm.
        h = h * 2;
    }
    const int64_t k = expert_idx_size[1];
    // Use the resolved range: the caller-supplied one may be empty and must not be indexed.
    bool using_v2 = CheckV2Case(h, expert_num, current_active_expert_range, expert_tokens_num_type, quant_mode);

    // ---- INT4 dynamic quantization constraints ----
    TORCH_CHECK(!(quant_mode == QUANT_MODE_DYNAMIC && IsInt4OutputDType(x_dtype)),
        "INT4 dynamic quantization uses quant_mode=13. quant_mode=1 only supports INT8 dynamic quantization.",
        OPS_ERROR(ErrCode::PARAM));
    if (quant_mode == QUANT_MODE_INT4_DYNAMIC) {
        TORCH_CHECK(drop_pad_mode == 0,
            "INT4 dynamic quantization only supports drop_pad_mode=0.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
            "INT4 dynamic quantization only supports float32 or bfloat16 x.",
            OPS_ERROR(ErrCode::TYPE));
        TORCH_CHECK(!p_offset.defined(),
            "INT4 dynamic quantization does not support offset.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(h % INT4_NUMS_IN_INT8 == 0,
            "INT4 dynamic quantization requires the hidden size of x to be even.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!x_dtype.has_value() || IsInt4OutputDType(x_dtype),
            "The optional parameter x_dtype must be torch_npu.int4 or None when quant_mode=13.",
            OPS_ERROR(ErrCode::PARAM));
        if (p_scale.defined()) {
            TORCH_CHECK(p_scale.scalar_type() == at::kFloat,
                "The scale dtype must be float32 in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::TYPE));
            TORCH_CHECK(p_scale.dim() == DIM_X,
                "The scale shape supports only 2D in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(p_scale.size(0) == 1,
                "The first dim of scale must be 1 in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(p_scale.size(1) == x_size[1],
                "The second dim of scale should be the same as the second dim of x in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::PARAM));
        }
    }

    // ---- Allocate expanded_x; shape and dtype depend on drop_pad_mode / quant_mode ----
    int64_t expanded_scale_len = 0;
    at::Tensor expanded_x;
    if (drop_pad_mode == 1) {
        if (quant_mode == QUANT_MODE_UNQUANT) {
            expanded_x = npu_preparation::apply_tensor_without_format(x, {expert_num, expert_capacity, h});
        } else {
            expanded_x = npu_preparation::apply_tensor_without_format(
                {expert_num, expert_capacity, h}, x.options().dtype(at::kChar));
        }
        expanded_scale_len = expert_num * expert_capacity;
    } else {
        expanded_scale_len = (active_num <= 0) ? bs * k : std::min<int64_t>(active_num, bs * k);
        switch (quant_mode) {
#if VERSION_BETWEEN(V2R7, VERSION_NEWEST)
            case QUANT_MODE_MXFP8_E5M2:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e5m2));
                break;
            case QUANT_MODE_MXFP8_E4M3FN:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e4m3fn));
                break;
            case QUANT_MODE_MXFP4_E2M1:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h / 2}, x.options().dtype(at::kByte));
                break;
#endif
            case QUANT_MODE_STATIC:
            case QUANT_MODE_DYNAMIC:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kChar));
                break;
            case QUANT_MODE_INT4_DYNAMIC:
                // Allocated as bytes with half the columns; reported to the kernel as ACL_INT4 below.
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h / INT4_NUMS_IN_INT8}, x.options().dtype(at::kByte));
                break;
            case QUANT_MODE_HIF8_CAST:
            case QUANT_MODE_HIF8_PERTENSOR:
            case QUANT_MODE_HIF8_PER_TOKEN_DIM:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kByte));
                break;
            case QUANT_MODE_FP8_PERBLOCK_E5M2:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e5m2));
                break;
            case QUANT_MODE_FP8_PERBLOCK_E4M3FN:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e4m3fn));
                break;
            default:
                // Unquantized (or unrecognised) mode: same dtype as x, original column count.
                expanded_x = npu_preparation::apply_tensor_without_format(x, {expanded_scale_len, x_size[1]});
        }
    }

    // ---- Allocate the row-index and expert-token statistics outputs ----
    at::Tensor expanded_row_idx = npu_preparation::apply_tensor_without_format(expert_idx, {bs * k});
    at::Tensor expert_tokens_count_or_cumsum;
    if (Is310PBoolCheck()) {
        expert_tokens_count_or_cumsum =
            npu_preparation::apply_tensor_without_format({expert_length}, x.options().dtype(at::kInt));
    } else if (expert_tokens_num_type < EXPERT_TOKENS_KEY_VALUE) {
        expert_tokens_count_or_cumsum =
            npu_preparation::apply_tensor_without_format({expert_length}, x.options().dtype(at::kLong));
    } else if (expert_tokens_num_type == EXPERT_TOKENS_KEY_VALUE) {
        expert_tokens_count_or_cumsum =
            npu_preparation::apply_tensor_without_format({expert_num, 2}, x.options().dtype(at::kLong));
    }
    // NOTE(review): for expert_tokens_num_type > EXPERT_TOKENS_KEY_VALUE (non-310P) the tensor
    // above stays undefined — confirm whether that input should be rejected up front.

    // ---- aclnnMoeInitRoutingV2 path ----
    if ((using_v2 && !op_plugin::utils::is_gte_cann_version_850alpha003()) || Is310PBoolCheck()) {
        at::Tensor expert_tokens_before_capacity =
            npu_preparation::apply_tensor_without_format({expert_num}, x.options().dtype(at::kInt));
        // This path always runs with capacity 0 and drop_pad_mode 0, regardless of the arguments.
        expert_capacity = 0;
        drop_pad_mode = 0;
        int64_t expert_tokens_count_or_cumsum_flag = Is310PBoolCheck() ? 1 : 2;
        bool expert_tokens_before_capacity_flag = false;
        if (bs == 0) {
            // Nothing to route: skip the kernel launch and return zeroed statistics.
            expert_tokens_count_or_cumsum.zero_();
            return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum,
                            expert_tokens_before_capacity);
        }
        EXEC_NPU_CMD(aclnnMoeInitRoutingV2,
            x,
            expert_idx,
            active_num,
            expert_capacity,
            expert_num,
            drop_pad_mode,
            expert_tokens_count_or_cumsum_flag,
            expert_tokens_before_capacity_flag,
            expanded_x,
            expanded_row_idx,
            expert_tokens_count_or_cumsum,
            expert_tokens_before_capacity);
        return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum,
                        expert_tokens_before_capacity);
    }

    // Unquantized call whose input is FP8 (by scalar type) or FP4 (by ACL type): the scale
    // tensors are reported to the kernel as ACL_FLOAT8_E8M0.
    const bool unquant_low_precision_input = quant_mode == QUANT_MODE_UNQUANT &&
        (x.scalar_type() == at::kFloat8_e5m2 || x.scalar_type() == at::kFloat8_e4m3fn ||
         x_acl_type == aclDataType::ACL_FLOAT4_E2M1);

    // ---- Allocate expanded_scale ----
#if VERSION_BETWEEN(V2R7, VERSION_NEWEST)
    at::Tensor expanded_scale;
    if (IsQuantModeMXFP8(quant_mode)) {
        // One scale per MXQUANT_BLOCK_SIZE columns, with the column count rounded up to even.
        int64_t scale_cols = (h + MXQUANT_BLOCK_SIZE - 1) / MXQUANT_BLOCK_SIZE;
        scale_cols = (scale_cols + PAD_TO_EVEN_FACTOR - 1) / PAD_TO_EVEN_FACTOR * PAD_TO_EVEN_FACTOR;
        expanded_scale = npu_preparation::apply_tensor_without_format(
            {expanded_scale_len, scale_cols}, x.options().dtype(at::kFloat8_e8m0fnu));
    } else if (IsQuantModeFP8(quant_mode)) {
        int64_t block_size = FP8_QUANT_BLOCK_SIZE * 2;
        expanded_scale = npu_preparation::apply_tensor_without_format(
            {expanded_scale_len, op_infer::CeilDiv(h, block_size), 2}, x.options().dtype(at::kFloat));
    } else if ((unquant_low_precision_input && scale.has_value()) || IsQuantModeMXFP4(quant_mode)) {
        expanded_scale = npu_preparation::apply_tensor_without_format(
            {expanded_scale_len, op_infer::CeilDiv(h, 64), 2}, x.options().dtype(at::kByte));
    } else {
        expanded_scale =
            npu_preparation::apply_tensor_without_format({expanded_scale_len}, x.options().dtype(at::kFloat));
    }
#else
    at::Tensor expanded_scale =
        npu_preparation::apply_tensor_without_format({expanded_scale_len}, x.options().dtype(at::kFloat));
#endif

    // ---- Wrap tensors with the ACL dtypes the kernel should see ----
    auto scale_scalar_dtype = p_scale.defined() ? p_scale.scalar_type() : at::kFloat;
    auto expanded_scale_scalar_dtype = expanded_scale.defined() ? expanded_scale.scalar_type() : at::kFloat;
    TensorWrapper scale_wrapper = {
        p_scale,
        unquant_low_precision_input ? aclDataType::ACL_FLOAT8_E8M0
                                    : npu_preparation::convert_to_acl_data_type(scale_scalar_dtype)};
    TensorWrapper expanded_scale_wrapper = {
        expanded_scale,
        unquant_low_precision_input ? aclDataType::ACL_FLOAT8_E8M0
                                    : npu_preparation::convert_to_acl_data_type(expanded_scale_scalar_dtype)};
    const bool unquant_with_x_dtype = quant_mode == QUANT_MODE_UNQUANT && x_dtype.has_value();
    TensorWrapper x_wrapper = {
        x,
        unquant_with_x_dtype ? c10_npu::GetAclDataType(x_dtype.value())
                             : npu_preparation::convert_to_acl_data_type(x.scalar_type())};
    TensorWrapper expanded_x_wrapper = {
        expanded_x, npu_preparation::convert_to_acl_data_type(expanded_x.scalar_type())};
    if (unquant_with_x_dtype) {
        expanded_x_wrapper.dtype = c10_npu::GetAclDataType(x_dtype.value());
    } else if (IsDynamicQuantInt4Output(quant_mode, x_dtype)) {
        expanded_x_wrapper.dtype = aclDataType::ACL_INT4;
    } else if (IsQuantModeHIF8(quant_mode)) {
        expanded_x_wrapper.dtype = aclDataType::ACL_HIFLOAT8;
    } else if (IsQuantModeMXFP4(quant_mode)) {
        expanded_x_wrapper.dtype = aclDataType::ACL_FLOAT4_E2M1;
        expanded_scale_wrapper.dtype = aclDataType::ACL_FLOAT8_E8M0;
    }

    if (bs == 0) {
        // Nothing to route: skip the kernel launch and return zeroed statistics.
        expert_tokens_count_or_cumsum.zero_();
        return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
    }

    // NOTE(review): the caller-supplied active_expert_range (possibly empty) is forwarded
    // here rather than the resolved default — confirm the kernel treats an empty range as
    // {0, expert_num}.
    EXEC_NPU_CMD(aclnnMoeInitRoutingV3,
        x_wrapper,
        expert_idx,
        scale_wrapper,
        p_offset,
        active_num,
        expert_capacity,
        expert_num,
        drop_pad_mode,
        expert_tokens_num_type,
        expert_tokens_num_flag,
        quant_mode,
        active_expert_range,
        row_idx_type,
        expanded_x_wrapper,
        expanded_row_idx,
        expert_tokens_count_or_cumsum,
        expanded_scale_wrapper);
    return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
}
}