import unittest
import itertools
import importlib.util
from os import scandir
import numpy as np
import random
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
def numpy_hifloat8():
    """Return the ``hifloat8`` object exported by the ``en_dtypes`` package.

    The import is performed at call time, so this module can be imported
    without ``en_dtypes`` installed.

    Returns:
        The ``hifloat8`` name imported from ``en_dtypes``.

    Raises:
        RuntimeError: If ``en_dtypes`` cannot be found, or if it is installed
            but does not export ``hifloat8``. The underlying import error is
            attached as ``__cause__``.
    """
    try:
        from en_dtypes import hifloat8
    except ModuleNotFoundError as err:
        # Must be handled before ImportError: ModuleNotFoundError is a subclass of it.
        raise RuntimeError("en_dtypes is needed to support hifloat8 dtype!!! "
                           "Please install with `pip3 install en-dtypes`") from err
    except ImportError as err:
        raise RuntimeError("Please upgrade en_dtypes to v0.0.3 at least to support hifloat8 dtype!!! "
                           "Command is `pip3 install --upgrade en-dtypes`") from err
    return hifloat8
def numpy_to_torch(np_arr):
    """Convert a NumPy array to a torch tensor, bit-casting low-precision float dtypes.

    Arrays whose ``str(dtype)`` is one of ``bfloat16``, ``float8_e5m2``,
    ``float8_e4m3fn``, ``float8_e8m0``, ``hifloat8`` or ``float4_e2m1`` are
    reinterpreted as ``uint8`` and then viewed as the matching torch/torch_npu
    dtype. Every other array goes through ``torch.from_numpy`` unchanged.

    Args:
        np_arr: The NumPy array to convert.

    Returns:
        A torch tensor created with ``torch.from_numpy`` (viewed as the mapped
        dtype for the low-precision cases).

    Raises:
        RuntimeError: If the array is ``float8_e8m0`` but the installed torch
            has no ``float8_e8m0fnu`` dtype.
    """
    # Resolvers are lazy so that a missing attribute on torch/torch_npu only
    # matters when an array of that particular dtype is actually converted.
    dtype_resolvers = {
        "bfloat16": lambda: torch.bfloat16,
        "float8_e5m2": lambda: torch.float8_e5m2,
        "float8_e4m3fn": lambda: torch.float8_e4m3fn,
        "float8_e8m0": lambda: getattr(torch, "float8_e8m0fnu", None),
        "hifloat8": lambda: torch_npu.hifloat8,
        "float4_e2m1": lambda: torch_npu.float4_e2m1fn_x2,
    }

    dtype_name = str(np_arr.dtype)
    resolver = dtype_resolvers.get(dtype_name)
    if resolver is None:
        return torch.from_numpy(np_arr)

    torch_dtype = resolver()
    if torch_dtype is None:
        raise RuntimeError(
            f"torch has no dtype matching numpy dtype '{dtype_name}' "
            "(torch.float8_e8m0fnu is not available)")

    # torch.from_numpy does not accept these dtypes directly, so go through raw bytes.
    return torch.from_numpy(np_arr.view(np.uint8)).view(torch_dtype)
def assert_tensors_close(x: torch.Tensor, y: torch.Tensor, rtol=1e-3, atol=1e-5, label="Tensor"):
    """Assert that two tensors match element-wise within ``atol + rtol * |y|``.

    Both tensors are moved to CPU and converted to ``float32`` before the
    comparison. NaN positions must coincide, then Inf positions must coincide,
    and only the remaining finite elements are compared numerically.

    Args:
        x: Tensor under test.
        y: Reference tensor; the relative tolerance is scaled by ``|y|``.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        label: Prefix used in failure messages.

    Raises:
        AssertionError: On shape mismatch, NaN/Inf position mismatch, or when
            any finite element differs by more than the tolerance.
    """
    # Tensor.cpu() returns the tensor itself when it already lives on the CPU.
    x = x.cpu()
    y = y.cpu()

    # Report differing shapes explicitly instead of as a misleading NaN mismatch.
    if x.shape != y.shape:
        raise AssertionError(
            f"[{label}] Shape mismatch: x is {tuple(x.shape)}, y is {tuple(y.shape)}")

    x_f32 = x.to(torch.float32)
    y_f32 = y.to(torch.float32)

    nan_mask_x = torch.isnan(x_f32)
    nan_mask_y = torch.isnan(y_f32)
    if not torch.equal(nan_mask_x, nan_mask_y):
        raise AssertionError(
            f"[{label}] NaN mismatch: x has NaNs at {torch.where(nan_mask_x)}, y has NaNs at {torch.where(nan_mask_y)}")

    # Boolean indexing flattens, so everything below works on 1-D tensors.
    valid_mask = ~nan_mask_x
    x_valid = x_f32[valid_mask]
    y_valid = y_f32[valid_mask]

    inf_mask_x = torch.isinf(x_valid)
    inf_mask_y = torch.isinf(y_valid)
    if not torch.equal(inf_mask_x, inf_mask_y):
        raise AssertionError(f"[{label}] Inf mismatch.")

    finite_mask = ~inf_mask_x
    x_final = x_valid[finite_mask]
    y_final = y_valid[finite_mask]
    if x_final.numel() == 0:
        return

    diff = torch.abs(x_final - y_final)
    tolerance = atol + (rtol * torch.abs(y_final))
    if not torch.any(diff > tolerance):
        return

    max_diff = diff.max().item()
    max_idx = torch.argmax(diff).item()
    # Avoid dividing by zero when reporting the relative error.
    y_safe = y_final.clone()
    y_safe[y_safe == 0] = 1e-6
    max_rel_err = (diff / torch.abs(y_safe)).max().item()
    error_msg = (
        f"\n[{label}] Mismatch found!\n"
        f"Type: {x.dtype}\n"
        f"Max Abs Diff: {max_diff:.6f} (Threshold: {atol})\n"
        f"Max Rel Diff: {max_rel_err:.6f} (Threshold: {rtol})\n"
        # The index comes from argmax(diff), i.e. the worst element, not the first failing one.
        f"Largest diff at index {max_idx} (in flattened valid array):\n"
        f"  x: {x_final[max_idx].item()}\n"
        f"  y: {y_final[max_idx].item()}\n"
    )
    raise AssertionError(error_msg)
def simplified_mx_quantize(fp_array: np.ndarray, mx_ele_dtype: str = "float8_e4m3fn") -> tuple:
"""
Simplified MX quantization function.

Input: fp_array (shape=[n, h], dtype=fp16/bf16)
       mx_ele_dtype: "float8_e5m2", "float8_e4m3fn" or "float4_e2m1"
Output: (scale_array, ele_array)
        scale_array holds one power-of-two scale per 32-element block of the
        last axis, grouped in pairs along a new trailing axis and cast to
        float8_e8m0; ele_array holds the scaled, rounded elements cast to the
        requested element dtype, with the original last-axis length.
Raises: AssertionError if ml_dtypes / en_dtypes cannot be imported,
        ValueError for an unsupported mx_ele_dtype.
"""
try:
from ml_dtypes import float8_e5m2, float8_e4m3fn
from en_dtypes import float8_e8m0, float4_e2m1
except ImportError:
raise AssertionError("Unsupported UT testcase due to lack of package ml_dtypes or en_dtypes")
# Blocks of 32 elements along the last axis share one scale.
BLOCK_SIZE = 32
AXIS = -1
# Per element format: largest magnitude kept by the final clip, bit widths, NumPy dtype.
if mx_ele_dtype == "float8_e5m2":
max_norm = 57344.0
exp_bits, mantissa_bits = 5, 2
target_dtype = float8_e5m2
elif mx_ele_dtype == "float8_e4m3fn":
max_norm = 448.0
exp_bits, mantissa_bits = 4, 3
target_dtype = float8_e4m3fn
elif mx_ele_dtype == "float4_e2m1":
max_norm = float.fromhex("0x1.8p2")
exp_bits, mantissa_bits = 2, 1
target_dtype = float4_e2m1
else:
raise ValueError(f"Unsupported mx_ele_dtype: {mx_ele_dtype}")
orig_shape = fp_array.shape
h_dim = orig_shape[AXIS]
# Zero-pad the last axis up to the next multiple of BLOCK_SIZE.
pad_len = (BLOCK_SIZE - (h_dim % BLOCK_SIZE)) % BLOCK_SIZE
if pad_len > 0:
pad_width = [(0, 0)] * fp_array.ndim
pad_width[AXIS] = (0, pad_len)
fp_array = np.pad(fp_array, pad_width, 'constant')
padded_shape = fp_array.shape
# Split the last axis into (number of blocks, BLOCK_SIZE).
new_shape = list(padded_shape)
new_shape[AXIS] = new_shape[AXIS] // BLOCK_SIZE
new_shape.append(BLOCK_SIZE)
fp_array_blocked = fp_array.reshape(new_shape)
# Integer part of log2(max_norm).
ele_emax = int(np.log2(max_norm))
fp_abs_max = np.max(np.abs(fp_array_blocked), axis=-1, keepdims=True)
FP32_MIN_NORMAL = 2 ** (-126)
# The added term is non-zero only for all-zero blocks, which keeps log2 away from 0.
share_exp = np.floor(
np.log2(fp_abs_max + FP32_MIN_NORMAL * (fp_abs_max == 0))) - ele_emax
# All-zero blocks are marked -inf here; the lower clamp below turns that into -SCALE_EMAX.
share_exp[fp_abs_max == 0] = -float("inf")
SCALE_EMAX = 127
# Exponents above the scale range become NaN, those below are clamped to -SCALE_EMAX.
share_exp[share_exp > SCALE_EMAX] = float("NaN")
share_exp[share_exp < -SCALE_EMAX] = -SCALE_EMAX
scale_val = 2.0 ** share_exp
scaled_input = fp_array_blocked / scale_val
# Lower bound applied to each element's own exponent.
min_exp = -(2 ** (exp_bits - 1)) + 2
abs_scaled = np.abs(scaled_input)
# (abs_scaled == 0) adds 1 to exact zeros so their log2 is 0 instead of -inf.
private_exp = np.floor(
np.log2(abs_scaled + (abs_scaled == 0))).astype(np.int32)
private_exp = np.clip(private_exp, min_exp, None)
# Round each element to mantissa_bits fractional bits relative to its exponent.
step_scale = 2.0 ** (mantissa_bits - private_exp)
ret = scaled_input * step_scale
ret = np.rint(ret)
ret = ret / step_scale
ret = np.clip(ret, -max_norm, max_norm)
# Undo the blocking and drop the padded tail of the last axis.
ele_array = ret.reshape(padded_shape)
if pad_len > 0:
ele_array = ele_array[..., :h_dim]
share_exp = np.squeeze(share_exp, axis=-1)
scale_array = 2.0 ** share_exp
# Pad the block axis to an even length with 2**-127, then group scales in pairs.
# NOTE(review): the pad widths below are written for a 2-D scale_array.
if scale_array.shape[-1] % 2 != 0:
scale_array = np.pad(scale_array, ((0, 0), (0, 1)),
'constant', constant_values=2**-127)
s_shape = list(scale_array.shape)
s_shape[-1] //= 2
s_shape.append(2)
scale_array = scale_array.reshape(s_shape)
# NOTE(review): this tests the truthiness of the imported type object; presumably always true - confirm.
if float8_e8m0:
scale_array = scale_array.astype(float8_e8m0)
ele_array = ele_array.astype(np.float32)
# NaNs are replaced by 0.0 before the cast to the element dtype.
ele_array = np.nan_to_num(ele_array, nan=0.0)
# NOTE(review): same truthiness pattern as above; target_dtype is always assigned on this path.
if target_dtype:
ele_array = ele_array.astype(target_dtype)
return scale_array, ele_array
class MoeInitRoutingV2Model(torch.nn.Module):
    """Module wrapper that forwards to ``torch.ops.npu.npu_moe_init_routing_v2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, expert_idx, *, scale=None, offset=None,
                active_num=-1, expert_capacity=-1, expert_num=-1, drop_pad_mode=0,
                expert_tokens_num_type=0, expert_tokens_num_flag=False,
                quant_mode=0, active_expert_range=0, row_idx_type=0):
        """Print every argument, call the NPU operator and return its four outputs.

        Returns:
            Tuple ``(expanded_x, expanded_row_idx,
            expert_token_cumsum_or_count, expanded_scale)``.
        """
        # Keyword arguments of the operator, in the order they are printed.
        op_kwargs = {
            "scale": scale,
            "offset": offset,
            "active_num": active_num,
            "expert_capacity": expert_capacity,
            "expert_num": expert_num,
            "drop_pad_mode": drop_pad_mode,
            "expert_tokens_num_type": expert_tokens_num_type,
            "expert_tokens_num_flag": expert_tokens_num_flag,
            "quant_mode": quant_mode,
            "active_expert_range": active_expert_range,
            "row_idx_type": row_idx_type,
        }
        for arg_name, arg_value in (("x", x), ("expert_idx", expert_idx), *op_kwargs.items()):
            print(arg_name, arg_value)

        outputs = torch.ops.npu.npu_moe_init_routing_v2(x, expert_idx, **op_kwargs)
        expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = outputs
        return expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale
class TestNpuMoeInitRoutingV2(TestCase):
def adapter_capacity(self, sorted_row_idx, sorted_expert_idx, capacity):
    """Mark, in place, the entries that exceed ``capacity`` within each run of equal expert ids.

    Walks ``sorted_expert_idx`` from the start and counts the length of the
    current run of identical values. Every position whose running count is
    greater than ``capacity`` is overwritten with ``-1`` in both
    ``sorted_expert_idx`` and ``sorted_row_idx``.

    Args:
        sorted_row_idx: Mutable sequence aligned with ``sorted_expert_idx``.
        sorted_expert_idx: Mutable sequence of expert ids.
        capacity: Maximum number of entries kept per run.

    Returns:
        None. Both sequences are modified in place; empty input is a no-op.
    """
    # Nothing to cap, and indexing element 0 below would raise on empty input.
    if len(sorted_expert_idx) == 0:
        return

    count = 0
    last = sorted_expert_idx[0]
    for i, val in enumerate(sorted_expert_idx):
        if last != val:
            # A new run starts; the current element is its first member.
            count = 1
            last = val
        else:
            count += 1
        # `val` was read before the overwrite, so `last` keeps tracking the real id.
        if count > capacity:
            sorted_expert_idx[i] = -1
            sorted_row_idx[i] = -1
def cpu_op_exec(self, x, expert_idx, scale, offset, expert_range, quant_mode, row_idx_type, expert_tokens_num_flag, expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity):
"""NumPy golden reference whose outputs the tests compare against the NPU operator.

Args:
    x: 2-D input; rows are gathered by ``sorted index // k``.
    expert_idx: Expert id array; its last dimension is ``k``. Assumed to be
        a NumPy array (``.copy()`` / ``.reshape`` are called on it).
    scale, offset: Optional quantization inputs; usage depends on ``quant_mode``.
    expert_range: ``[start, end)`` pair of expert ids that count as active.
    quant_mode: Selects the quantization path handled below.
    row_idx_type: ``1`` returns the sorting permutation, otherwise its scatter form.
    expert_tokens_num_flag: When false, no per-expert token statistic is produced.
    expert_tokens_num_type: ``0`` cumsum, ``1`` count, ``2`` (expert id, count) pairs.
    drop_pad_mode: ``0`` dropless path, otherwise the capacity-limited path.
    active_num: Row limit for the dropless path; ``0`` means all active rows.
    expert_capacity: Per-expert capacity for the capacity-limited path.

Returns:
    ``(expanded_x, expanded_row_idx as int32, expert_tokens_count, expanded_scale)``.
"""
expert_start = expert_range[0]
expert_end = expert_range[1]
# NOTE(review): expert_num is hard-coded here, matching the value npu_op_exec passes to the operator.
expert_num = 32
num_rows = x.shape[0]
h = x.shape[1]
k = expert_idx.shape[-1]
expert_idx_in = expert_idx.copy().reshape(-1)
# Number of (token, slot) entries whose expert id lies in [expert_start, expert_end).
actual_expert_total_num = np.sum(
(expert_idx >= expert_start) & (expert_idx < expert_end))
# Ids below expert_start are replaced by INT32 max so that they sort to the end.
expert_idx_in[(expert_idx_in < expert_start)
] = np.int32(np.iinfo(np.int32).max)
# Stable sort keeps the original order among entries of the same expert.
sorted_expert_indices = np.argsort(
expert_idx_in, axis=-1, kind="stable")
sorted_expert_idx = expert_idx_in[sorted_expert_indices]
if row_idx_type == 1:
expanded_row_idx = sorted_expert_indices.astype(np.int32)
else:
# Scatter form: source position -> destination row, -1 for entries outside the active range.
expanded_row_idx = np.ones(num_rows * k).astype(np.int32) * -1
tmp_indices = np.arange(actual_expert_total_num)
expanded_row_idx[sorted_expert_indices[:actual_expert_total_num]] = tmp_indices
if not expert_tokens_num_flag:
expert_tokens_count = None
else:
if drop_pad_mode == 0:
if expert_tokens_num_type == 1:
# Per-expert counts, zero-padded to the width of the active range.
expert_tokens_count = np.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start)
expert_tokens_count = np.concatenate(
[expert_tokens_count, np.zeros((expert_end - expert_start) - len(expert_tokens_count)).astype(np.int64)])
elif expert_tokens_num_type == 0:
# Same counts as above, turned into a running sum.
expert_tokens_count = np.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start)
expert_tokens_count = np.concatenate(
[expert_tokens_count, np.zeros((expert_end - expert_start) - len(expert_tokens_count)).astype(np.int64)])
expert_tokens_count = np.cumsum(expert_tokens_count)
elif expert_tokens_num_type == 2:
# (expert id, count) rows for experts that received tokens, followed by one [0, 0] row
# when fewer than expert_num experts are present.
expert_id, counts = np.unique(
sorted_expert_idx[:actual_expert_total_num], return_counts=True)
expert_tokens_count = np.column_stack((expert_id, counts))
if expert_tokens_count.shape[0] < expert_num:
expert_tokens_count = np.concatenate(
(expert_tokens_count, [[0, 0],]), axis=0)
else:
expert_tokens_count = np.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start)
expert_tokens_count = np.concatenate(
[expert_tokens_count, np.zeros((expert_end - expert_start) - len(expert_tokens_count)).astype(np.int64)])
expert_tokens_count = expert_tokens_count.astype(np.int64)
if drop_pad_mode == 0:
# Dropless path: keep at most active_num of the active rows (0 means all of them).
if active_num == 0:
active_num = actual_expert_total_num
else:
active_num = min(active_num, actual_expert_total_num)
expanded_scale = None
# Integer division by k maps a flattened (token, slot) position back to its token row.
expanded_x = x[sorted_expert_indices[:active_num] // k, :]
if scale is not None and quant_mode == -1:
expanded_scale = scale[sorted_expert_indices[:active_num] // k]
else:
# Capacity-limited path: entries beyond expert_capacity per expert are marked -1 in place.
self.adapter_capacity(sorted_expert_indices,
sorted_expert_idx, expert_capacity)
# Slot table of expert_num * expert_capacity entries; -1 marks an empty slot.
sort_row_tmp = np.full(
(expert_num * expert_capacity), -1, dtype=int)
offset_tmp = 0
lastExpertId = 0
for i, val in enumerate(sorted_expert_indices):
if val != -1:
if lastExpertId != sorted_expert_idx[i]:
offset_tmp = 0
lastExpertId = sorted_expert_idx[i]
sort_row_tmp[sorted_expert_idx[i] * expert_capacity +
offset_tmp] = sorted_expert_indices[i]
offset_tmp = offset_tmp + 1
# Invert the slot table: source position -> slot index, -1 for dropped entries.
expanded_row_idx = np.full(sorted_expert_indices.shape, -1)
for i, val in enumerate(sort_row_tmp):
if val != -1:
expanded_row_idx[val] = i
# Mask value 1 marks padding slots; rows that receive a token are reset to 0.
expanded_x_mask = np.full(
(expert_num * expert_capacity, h), 1, dtype=int)
expanded_x = np.full(
(expert_num * expert_capacity, h), 0, dtype=x.dtype)
for i, val in enumerate(sort_row_tmp):
if val != -1:
expanded_x[i] = x[val // k]
expanded_x_mask[i] = np.full((h,), 0, dtype=int)
# quant_mode -1: no quantization; only the scale is gathered for the capacity-limited path.
if quant_mode == -1:
expanded_x = expanded_x
expanded_row_idx = expanded_row_idx
if scale is not None and drop_pad_mode == 1:
expanded_scale = np.full(
(expert_num * expert_capacity,), 0, dtype=scale.dtype)
for i, val in enumerate(sort_row_tmp):
if val != -1:
expanded_scale[i] = scale[val // k]
if scale is None:
expanded_scale = None
# quant_mode 0: static int8 quantization using the first scale and offset entries.
if quant_mode == 0:
expanded_scale = None
expanded_x_fp16 = expanded_x.astype(np.float16)
scale_val = scale.astype(np.float16)
offset_val = offset.astype(np.float16)
scale_rst = expanded_x_fp16 * scale_val[0]
add_offset = scale_rst + offset_val[0]
round_data = np.rint(add_offset)
round_data = np.clip(round_data, -128, 127)
expanded_x = round_data.astype(np.int8)
# quant_mode 1: dynamic int8 quantization with a per-row scale of max|x| / 127.
if quant_mode == 1:
x_final = expanded_x.astype(np.float32)
if scale is None:
x_abs = np.abs(x_final)
x_max = np.max(x_abs, axis=-1, keepdims=True)
expanded_scale = x_max / 127
expanded_x = x_final / expanded_scale
expanded_x = np.round(expanded_x).astype(np.int8)
else:
# The input scale is applied first: one shared row, or one row per expert.
if scale.shape[0] == 1:
x_final = x_final * scale
else:
if drop_pad_mode == 0:
x_final = x_final * \
scale[sorted_expert_idx[:active_num] - expert_start]
else:
for i, val in enumerate(sort_row_tmp):
if val != -1:
x_final[i] = x_final[i] * \
scale[i // expert_capacity]
x_abs = np.abs(x_final)
x_max = np.max(x_abs, axis=-1, keepdims=True)
expanded_scale = x_max / 127
expanded_x = x_final / expanded_scale
expanded_x = np.round(expanded_x).astype(np.int8)
if x.dtype == np.int8:
expanded_scale = None
# quant_mode 2 / 3: MX quantization to float8_e5m2 / float8_e4m3fn; the paired scale axes are merged.
if quant_mode == 2 or quant_mode == 3:
quant_mode_dtype_str_map = {2: "float8_e5m2", 3: "float8_e4m3fn"}
expanded_scale, expanded_x = simplified_mx_quantize(
expanded_x, mx_ele_dtype=quant_mode_dtype_str_map[quant_mode])
ess = expanded_scale.shape
expanded_scale = expanded_scale.reshape(
*ess[:-2], ess[-2] * ess[-1])
elif quant_mode == 6:
# Plain cast to hifloat8 without a scale.
expanded_x = expanded_x.astype(numpy_hifloat8())
expanded_scale = None
elif quant_mode == 7:
# Multiply by the input scale, clip, then cast to hifloat8.
HIFLOAT8_MIN = -32768.0
HIFLOAT8_MAX = 32768.0
expanded_x = expanded_x.astype(np.float32)
expanded_x = expanded_x * scale
expanded_x = np.clip(expanded_x, HIFLOAT8_MIN, HIFLOAT8_MAX)
expanded_x = expanded_x.astype(numpy_hifloat8(), copy=False)
expanded_scale = None
elif quant_mode == 8:
# Per-row dynamic scale of max|x| / 32768; a zero scale is replaced by 1.0 for the division.
HIFLOAT8_MIN = -32768.0
HIFLOAT8_MAX = 32768.0
expanded_x = expanded_x.astype(np.float32)
x_abs = np.abs(expanded_x)
x_max_per_token = np.max(x_abs, axis=-1, keepdims=True)
expanded_scale = x_max_per_token / 32768.0
expand_scale_safe = np.where(expanded_scale == 0, 1.0, expanded_scale)
expanded_x = expanded_x / expand_scale_safe
expanded_x = np.clip(expanded_x, HIFLOAT8_MIN, HIFLOAT8_MAX)
expanded_x = expanded_x.astype(numpy_hifloat8(), copy=False)
elif quant_mode == 9:
# NOTE(review): quant_mode_dtype_str_map does not appear to be assigned on this path in the
# visible code, and neither visible map has a key 9 - confirm the intended element dtype.
expanded_scale, expanded_x = simplified_mx_quantize(
expanded_x, mx_ele_dtype=quant_mode_dtype_str_map[quant_mode])
elif quant_mode == 11 or quant_mode == 12:
# Per-block (1 x 128) float8 quantization.
# NOTE(review): this path reads `expand_x`, which is not assigned earlier in the visible code,
# and stores its results in `expand_x` / `expand_scales` rather than in the returned
# `expanded_x` / `expanded_scale` - confirm before relying on this branch.
quant_mode_dtype_str_map = {11: "float8_e5m2", 12: "float8_e4m3fn"}
from ml_dtypes import float8_e5m2, float8_e4m3fn
def block_max_with_padding(x, row_block_size, col_block_size):
# Zero-pads x to whole blocks and returns the maximum of every block.
rows, cols = x.shape
pad_rows = (row_block_size - rows % row_block_size) % row_block_size
pad_cols = (col_block_size - cols % col_block_size) % col_block_size
x_padded = np.pad(x, ((0, pad_rows), (0, pad_cols)),
mode='constant', constant_values=0)
padded_rows, padded_cols = x_padded.shape
row_blocks = padded_rows // row_block_size
col_blocks = padded_cols // col_block_size
result = np.zeros((row_blocks, col_blocks))
for i in range(row_blocks):
for j in range(col_blocks):
block = x_padded[i * row_block_size:(i + 1) * row_block_size,
j * col_block_size:(j + 1) * col_block_size]
result[i, j] = np.max(block)
return result
def pad_to_even_zero(tensor, axis):
# Appends one zero along `axis` when its length is odd.
if tensor.shape[axis] % 2 == 0:
return tensor
pad_width = [(0, 0)] * tensor.ndim
pad_width[axis] = (0, 1)
return np.pad(tensor, pad_width, mode='constant', constant_values=0)
def get_dtype_range(dt):
# Returns (min, max) for the two float8 names.
# NOTE(review): the fall-through returns the functions np.min / np.max, not numeric limits.
if "float8_e5m2" in str(dt):
return -float.fromhex("0x1.Cp15"), float.fromhex("0x1.Cp15")
if "float8_e4m3fn" in str(dt):
return -float.fromhex("0x1.Cp8"), float.fromhex("0x1.Cp8")
numpy_dtype = np.dtype(dt)
if numpy_dtype.kind in "iu":
numpy_info = np.iinfo(numpy_dtype)
else:
numpy_info = np.finfo(numpy_dtype)
return np.min, np.max
row_block_size = 1
col_block_size = 128
dst_type_str = quant_mode_dtype_str_map[quant_mode]
max_value = 0
if dst_type_str == 'float8_e5m2':
max_value = (2 - pow(2, -2)) * pow(2, 15)
elif dst_type_str == 'float8_e4m3fn':
max_value = (2 - pow(2, -2)) * pow(2, 8)
x_abs = np.abs(expand_x)
block_max = block_max_with_padding(x_abs, row_block_size, col_block_size)
block_max_f32 = block_max.astype(np.float32)
# NOTE(review): from here on `scale` rebinds the function parameter of the same name.
scale = block_max_f32 / max_value
min_normal_fp32 = np.finfo(np.float32).tiny
scale = np.where(scale < min_normal_fp32, 0, scale)
# NOTE(review): block_max_with_padding returns a 2-D array, while axis=2 is requested here - confirm.
scale = pad_to_even_zero(scale, axis=2)
scale_expanded = np.zeros_like(expand_x).astype(np.float32)
for i in range(scale.shape[0]):
for j in range(scale.shape[1]):
scale_expanded[i * row_block_size:(i + 1) * row_block_size,
j * col_block_size:(j + 1) * col_block_size] = scale[i, j]
x_f32 = expand_x.astype(np.float32)
out_f32 = x_f32 / scale_expanded
max_norm = get_dtype_range(dst_type_str)[1]
np.clip(out_f32, a_min=-max_norm, a_max=max_norm, out=out_f32)
output_scale = scale.astype("float32")
round_data = np.round(out_f32, 8)
round_data = np.nan_to_num(round_data, nan=0.0, copy=False)
if dst_type_str == 'float8_e5m2':
round_data = round_data.astype(float8_e5m2, copy=False)
elif dst_type_str == 'float8_e4m3fn':
round_data = round_data.astype(float8_e4m3fn, copy=False)
expand_x = round_data
expand_scales = output_scale
if drop_pad_mode == 1:
# Zero the padding slots and expose the result as (expert_num, expert_capacity, h).
expanded_x = np.ma.array(
expanded_x, mask=expanded_x_mask).filled(0)
expanded_x = expanded_x.reshape(expert_num, expert_capacity, h)
return expanded_x, expanded_row_idx.astype(np.int32), expert_tokens_count, expanded_scale
def npu_op_exec(self, x, expert_idx, scale, offset, active_expert_range, quant_mode, row_idx_type, expert_tokens_num_flag, expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity, x_dtype=None):
    """Run ``torch_npu.npu_moe_init_routing_v2`` with ``expert_num`` fixed to 32.

    All other arguments are forwarded to the operator unchanged.

    Returns:
        Tuple ``(expanded_x, expanded_row_idx,
        expert_token_cumsum_or_count, expanded_scale)`` as produced by the operator.
    """
    op_kwargs = dict(
        scale=scale,
        offset=offset,
        active_num=active_num,
        expert_capacity=expert_capacity,
        expert_num=32,
        drop_pad_mode=drop_pad_mode,
        expert_tokens_num_type=expert_tokens_num_type,
        expert_tokens_num_flag=expert_tokens_num_flag,
        active_expert_range=active_expert_range,
        quant_mode=quant_mode,
        row_idx_type=row_idx_type,
        x_dtype=x_dtype,
    )
    outputs = torch_npu.npu_moe_init_routing_v2(x, expert_idx, **op_kwargs)
    # Unpacking keeps the four-output contract explicit.
    expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = outputs
    return expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale
def assertExpandedXRtolEqual(self, expanded_x, local_expanded_x_npu, dtype):
    """Compare the golden ``expanded_x`` with the operator result using a dtype-specific rule.

    Args:
        expanded_x: Golden result (NumPy array).
        local_expanded_x_npu: Operator result as a CPU torch tensor.
        dtype: Selects the comparison. ``torch.bfloat16`` converts the golden
            data to a bfloat16 tensor and uses ``assertRtolEqual``;
            ``np.int8`` requires equal shape, int8 dtype and a maximum
            absolute difference of 1; anything else uses ``assertRtolEqual``
            on NumPy arrays.
    """
    if dtype == torch.bfloat16:
        self.assertRtolEqual(torch.tensor(
            expanded_x, dtype=torch.bfloat16), local_expanded_x_npu)
    elif dtype == np.int8:
        self.assertEqual(expanded_x.shape, local_expanded_x_npu.shape)
        actual = local_expanded_x_npu.numpy()
        self.assertEqual(np.int8, actual.dtype)
        # Widen before subtracting: an int8 difference wraps around
        # (127 - (-128) becomes -1), which would hide large mismatches.
        max_diff = np.abs(expanded_x.astype(np.int32) - actual.astype(np.int32)).max()
        self.assertLessEqual(max_diff, 1)
    else:
        self.assertRtolEqual(expanded_x, local_expanded_x_npu.numpy())
def generate_inputs(self, bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode):
    """Create random host inputs, their NPU copies and the randomized operator options.

    Args:
        bs, h, k: ``x`` has shape ``(bs, h)`` and ``expert_idx`` has shape ``(bs, k)``.
        dtype: Requested element type of ``x`` (NumPy or torch dtype object).
        scale_shape: Shape of the generated scale and offset.
        none_scale: When true, no scale (and no offset) is generated.
        none_offset: When true, no offset is generated.
        drop_pad_mode: ``1`` fixes ``expert_tokens_num_type`` to 1 and
            ``row_idx_type`` to 0; otherwise both are drawn at random.

    Returns:
        Tuple ``(x, expert_idx, scale, offset, x_npu, expert_idx_npu,
        scale_npu, offset_npu, expert_tokens_num_type, row_idx_type,
        active_num, expert_capacity)``.
    """
    x_shape = (bs, h)
    if dtype == torch.bfloat16:
        # Host data stays float32; only the device copy is bfloat16.
        x = np.random.uniform(-1, 1, size=x_shape).astype(np.float32)
        x_npu = torch.tensor(x, dtype=torch.bfloat16).npu()
    elif dtype == np.int8:
        x = np.random.uniform(-127, 128, size=x_shape).astype(dtype)
        x_npu = torch.from_numpy(x).npu()
    elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        x = torch.randn(x_shape, dtype=torch.float32).to(dtype)
        x_npu = x.npu()
    elif dtype == torch_npu.float4_e2m1fn_x2:
        # The float4 case is carried in a uint8 tensor.
        x = torch.randn(x_shape, dtype=torch.float32).to(torch.uint8)
        x_npu = x.npu()
    else:
        x = np.random.uniform(-1, 1, size=x_shape).astype(dtype)
        x_npu = torch.from_numpy(x).npu()

    # Expert ids are drawn from [0, 32).
    expert_idx = np.random.randint(0, 32, size=(bs, k)).astype(np.int32)

    scale = None
    scale_npu = None
    if not none_scale:
        if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2 or dtype == torch_npu.float4_e2m1fn_x2:
            # Random bytes reinterpreted as float8_e8m0fnu.
            raw_scale = torch.randint(0, 256, scale_shape, dtype=torch.uint8)
            scale = raw_scale.view(torch.float8_e8m0fnu)
            scale_npu = scale.npu()
        else:
            scale = np.random.uniform(-1, 1, size=scale_shape).astype(np.float32)
            scale_npu = torch.from_numpy(scale).contiguous().npu()

    # An offset is only generated together with a scale.
    if none_offset or none_scale:
        offset = None
    else:
        offset = np.random.uniform(-1, 1, size=scale_shape).astype(np.float32)

    expert_idx_npu = torch.from_numpy(expert_idx).npu()
    if offset is None:
        offset_npu = None
    else:
        offset_npu = torch.from_numpy(offset).contiguous().npu()

    if drop_pad_mode == 1:
        expert_tokens_num_type = 1
        row_idx_type = 0
    else:
        expert_tokens_num_type = random.choice([0, 1, 2])
        row_idx_type = random.choice([0, 1])

    active_num = bs * k
    if drop_pad_mode == 0:
        expert_capacity = -1
    else:
        expert_capacity = random.randint(1, bs)

    return (x, expert_idx, scale, offset, x_npu, expert_idx_npu, scale_npu, offset_npu,
            expert_tokens_num_type, row_idx_type, active_num, expert_capacity)
def calc_npu_vs_golden(self, x, expert_idx, scale, offset,
x_npu, expert_idx_npu, scale_npu, offset_npu, expert_range,
quant_mode, row_idx_type, expert_tokens_num_flag, expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity, x_dtype=None):
"""Run the NPU operator and the NumPy golden on matching inputs and return both result sets.

The first four arguments are the host inputs for ``cpu_op_exec``; the four
``*_npu`` arguments are the corresponding device inputs for ``npu_op_exec``.
The remaining options are passed to both, except ``x_dtype``, which only
goes to the NPU call.

Returns:
    ``(expanded_x, local_expanded_x_npu, expanded_row_idx,
    local_expanded_row_idx_npu, expert_tokens_count,
    local_expert_tokens_count_npu, expanded_scale, local_expanded_scale_npu)``
    where every ``local_*_npu`` value is the operator output moved to CPU.
"""
expanded_x_npu, expanded_row_idx_npu, expert_tokens_count_npu, expanded_scale_npu = self.npu_op_exec(
x_npu, expert_idx_npu, scale=scale_npu, offset=offset_npu,
active_expert_range=expert_range, quant_mode=quant_mode, row_idx_type=row_idx_type, expert_tokens_num_flag=expert_tokens_num_flag, expert_tokens_num_type=expert_tokens_num_type,
drop_pad_mode=drop_pad_mode, active_num=active_num, expert_capacity=expert_capacity, x_dtype=x_dtype)
expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale = self.cpu_op_exec(
x, expert_idx, scale=scale, offset=offset,
expert_range=expert_range, quant_mode=quant_mode, row_idx_type=row_idx_type, expert_tokens_num_flag=expert_tokens_num_flag, expert_tokens_num_type=expert_tokens_num_type,
drop_pad_mode=drop_pad_mode, active_num=active_num, expert_capacity=expert_capacity)
# NOTE(review): .cpu() is called on all four outputs unconditionally, so the operator is
# assumed to return tensors for the token count and the scale in every configuration.
local_expanded_x_npu = expanded_x_npu.cpu()
local_expanded_row_idx_npu = expanded_row_idx_npu.cpu()
local_expert_tokens_count_npu = expert_tokens_count_npu.cpu()
local_expanded_scale_npu = expanded_scale_npu.cpu()
# Same active-entry count as the golden computes: ids inside [expert_range[0], expert_range[1]).
actual_expert_total_num = np.sum(
(expert_idx >= expert_range[0]) & (expert_idx < expert_range[1]))
if drop_pad_mode == 0:
if active_num == 0:
actual_expert_count = actual_expert_total_num
else:
actual_expert_count = min(active_num, actual_expert_total_num)
# Only the leading rows of the operator outputs are kept for the comparison.
local_expanded_x_npu = local_expanded_x_npu[:actual_expert_count]
local_expanded_scale_npu = local_expanded_scale_npu[:actual_expert_count]
# For the (expert id, count) statistic, trim the operator output to the golden's row count.
if expert_tokens_num_flag and expert_tokens_num_type == 2:
length = expert_tokens_count.shape[0]
local_expert_tokens_count_npu = local_expert_tokens_count_npu[:length]
return expanded_x, local_expanded_x_npu, expanded_row_idx, local_expanded_row_idx_npu, \
expert_tokens_count, local_expert_tokens_count_npu, expanded_scale, local_expanded_scale_npu
@SupportedDevices(['Ascend910B'])
def test_npu_moe_init_routing_empty_tensor(self):
    """Check the operator outputs for a zero-row input (``bs == 0``).

    With no tokens, ``expanded_x``, ``expanded_row_idx`` and
    ``expanded_scale`` must be empty and every entry of the token count
    output must be zero.
    """
    bs_list = [0]
    h_list = [8192]
    k_list = [12]
    expert_range_list = [[0, 768]]
    quant_mode_list = [1]
    drop_pad_mode_list = [0]
    row_idx_type_list = [1]
    expert_tokens_num_flags = [True]
    dtype_list = [torch.bfloat16]
    none_scales = [False]
    none_offsets = [True]
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype, none_scale, none_offset,
         expert_tokens_num_flag, drop_pad_mode) in itertools.product(
            bs_list, h_list, k_list, expert_range_list, quant_mode_list, row_idx_type_list,
            dtype_list, none_scales, none_offsets, expert_tokens_num_flags, drop_pad_mode_list):
        scale_shape = (bs,)
        # The range from the parameter grid is replaced by one derived from drop_pad_mode.
        expert_range = [0, 32] if drop_pad_mode == 1 else [0, 16]
        # generate_inputs also overrides row_idx_type from the grid.
        x, expert_idx, scale, offset, x_npu, expert_idx_npu, scale_npu, offset_npu, expert_tokens_num_type, row_idx_type, active_num, expert_capacity = self.generate_inputs(
            bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode)
        x_dtype = torch_npu.hifloat8 if dtype == torch_npu.hifloat8 else None
        expanded_x_npu, expanded_row_idx_npu, expert_tokens_count_npu, expanded_scale_npu = self.npu_op_exec(
            x_npu, expert_idx_npu, scale=scale_npu, offset=offset_npu,
            active_expert_range=expert_range, quant_mode=quant_mode, row_idx_type=row_idx_type,
            expert_tokens_num_flag=expert_tokens_num_flag, expert_tokens_num_type=expert_tokens_num_type,
            drop_pad_mode=drop_pad_mode, active_num=active_num, expert_capacity=expert_capacity, x_dtype=x_dtype)
        # unittest assertions are used instead of bare `assert`, which is removed under `python -O`.
        self.assertEqual(torch.numel(expanded_x_npu), 0)
        self.assertEqual(torch.numel(expanded_row_idx_npu), 0)
        self.assertTrue(torch.all(expert_tokens_count_npu == 0).item())
        self.assertEqual(torch.numel(expanded_scale_npu), 0)
@SupportedDevices(['Ascend910B', 'Ascend950'])
def test_npu_moe_init_routing_no_quant(self):
    """Compare the operator with the NumPy golden for ``quant_mode == -1``.

    Only dropless cases (``drop_pad_mode == 0``) with token statistics enabled
    and ``expert_tokens_num_type != 0`` are compared; other combinations of
    the parameter grid are generated and then skipped.
    """
    bs_list = [32]
    h_list = [14, 200]
    k_list = [5]
    expert_range_list = [[0, 16]]
    quant_mode_list = [-1]
    drop_pad_mode_list = [0, 1]
    row_idx_type_list = [0, 1]
    expert_tokens_num_flags = [True, False]
    dtype_list = [np.int8, np.float16, np.float32, torch.bfloat16, torch_npu.hifloat8]
    none_scales = [True, False]
    none_offsets = [True]
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype, none_scale, none_offset,
         expert_tokens_num_flag, drop_pad_mode) in itertools.product(
            bs_list, h_list, k_list, expert_range_list, quant_mode_list, row_idx_type_list,
            dtype_list, none_scales, none_offsets, expert_tokens_num_flags, drop_pad_mode_list):
        scale_shape = (bs,)
        # The range from the parameter grid is replaced by one derived from drop_pad_mode.
        expert_range = [0, 32] if drop_pad_mode == 1 else [0, 16]
        # generate_inputs also overrides row_idx_type from the grid.
        x, expert_idx, scale, offset, x_npu, expert_idx_npu, scale_npu, offset_npu, expert_tokens_num_type, row_idx_type, active_num, expert_capacity = self.generate_inputs(
            bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode)
        if drop_pad_mode == 1 or not expert_tokens_num_flag or expert_tokens_num_type == 0:
            continue
        x_dtype = torch_npu.hifloat8 if dtype == torch_npu.hifloat8 else None
        expanded_x, local_expanded_x_npu, expanded_row_idx, local_expanded_row_idx_npu, \
            expert_tokens_count, local_expert_tokens_count_npu, expanded_scale, local_expanded_scale_npu \
            = self.calc_npu_vs_golden(x, expert_idx, scale, offset,
                                      x_npu, expert_idx_npu, scale_npu, offset_npu,
                                      expert_range, quant_mode, row_idx_type, expert_tokens_num_flag,
                                      expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity, x_dtype)
        self.assertExpandedXRtolEqual(
            expanded_x, local_expanded_x_npu, dtype)
        self.assertRtolEqual(
            expanded_row_idx, local_expanded_row_idx_npu.numpy())
        if expert_tokens_num_flag:
            self.assertRtolEqual(expert_tokens_count,
                                 local_expert_tokens_count_npu.numpy())
        if none_scale:
            # Nothing to compare for the scale output. `continue`, not `return`:
            # returning here would end the whole test after the first scale-less case.
            continue
        self.assertRtolEqual(
            expanded_scale, local_expanded_scale_npu.numpy())
@unittest.skip("Skipping test_npu_moe_init_routing_static_quant for now")
@SupportedDevices(['Ascend910B'])
def test_npu_moe_init_routing_static_quant(self):
    """Compare the operator with the NumPy golden for static quantization (``quant_mode == 0``).

    Currently skipped. A single-element scale and offset are generated for
    every combination of the parameter grid.
    """
    case_grid = itertools.product(
        [32],                                      # bs
        [14, 200],                                 # h
        [5, 128],                                  # k
        [[0, 16]],                                 # expert_range (replaced below)
        [0],                                       # quant_mode
        [0, 1],                                    # row_idx_type (replaced by generate_inputs)
        [np.float16, np.float32, torch.bfloat16],  # dtype
        [False],                                   # none_scale
        [False],                                   # none_offset
        [True, False],                             # expert_tokens_num_flag
        [0, 1],                                    # drop_pad_mode
    )
    for (bs, h, k, _grid_range, quant_mode, _grid_row_idx_type, dtype,
         none_scale, none_offset, expert_tokens_num_flag, drop_pad_mode) in case_grid:
        if drop_pad_mode == 1:
            expert_range = [0, 32]
        else:
            expert_range = [0, 16]

        generated = self.generate_inputs(
            bs, h, k, dtype, (1,), none_scale, none_offset, drop_pad_mode)
        # First eight entries: host inputs followed by their NPU copies.
        host_and_npu_inputs = generated[:8]
        expert_tokens_num_type, row_idx_type, active_num, expert_capacity = generated[8:]

        results = self.calc_npu_vs_golden(
            *host_and_npu_inputs, expert_range, quant_mode, row_idx_type,
            expert_tokens_num_flag, expert_tokens_num_type, drop_pad_mode,
            active_num, expert_capacity)
        (golden_x, npu_x, golden_row_idx, npu_row_idx,
         golden_tokens_count, npu_tokens_count, _golden_scale, _npu_scale) = results

        # The quantized output is int8 regardless of the input dtype.
        self.assertExpandedXRtolEqual(golden_x, npu_x, np.int8)
        self.assertRtolEqual(golden_row_idx, npu_row_idx.numpy())
        if expert_tokens_num_flag:
            self.assertRtolEqual(golden_tokens_count, npu_tokens_count.numpy())
@SupportedDevices(['Ascend910B', 'Ascend950'])
def test_npu_moe_init_routing_dynamic_quant(self):
    """Compare the operator with the NumPy golden for dynamic quantization (``quant_mode == 1``).

    Only dropless cases (``drop_pad_mode == 0``) with token statistics enabled
    and ``expert_tokens_num_type != 0`` are compared; other combinations of
    the parameter grid are generated and then skipped.
    """
    bs_list = [32]
    h_list = [14, 200]
    k_list = [8]
    expert_range_list = [[0, 16]]
    quant_mode_list = [1]
    drop_pad_mode_list = [0, 1]
    row_idx_type_list = [0, 1]
    expert_tokens_num_flags = [True, False]
    dtype_list = [np.float16, np.float32, torch.bfloat16]
    none_scales = [True, False]
    none_offsets = [True]
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype, none_scale, none_offset,
         expert_tokens_num_flag, drop_pad_mode) in itertools.product(
            bs_list, h_list, k_list, expert_range_list, quant_mode_list, row_idx_type_list,
            dtype_list, none_scales, none_offsets, expert_tokens_num_flags, drop_pad_mode_list):
        # The scale has one row per expert of the grid's range, computed before
        # the range is replaced by the drop_pad_mode-dependent one.
        expert_range_length = expert_range[1] - expert_range[0]
        scale_shape = (expert_range_length, h)
        expert_range = [0, 32] if drop_pad_mode == 1 else [0, 16]
        # generate_inputs also overrides row_idx_type from the grid.
        x, expert_idx, scale, offset, x_npu, expert_idx_npu, scale_npu, offset_npu, expert_tokens_num_type, row_idx_type, active_num, expert_capacity = self.generate_inputs(
            bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode)
        if drop_pad_mode == 1 or not expert_tokens_num_flag or expert_tokens_num_type == 0:
            continue
        expanded_x, local_expanded_x_npu, expanded_row_idx, local_expanded_row_idx_npu, \
            expert_tokens_count, local_expert_tokens_count_npu, expanded_scale, local_expanded_scale_npu \
            = self.calc_npu_vs_golden(x, expert_idx, scale, offset,
                                      x_npu, expert_idx_npu, scale_npu, offset_npu,
                                      expert_range, quant_mode, row_idx_type, expert_tokens_num_flag,
                                      expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity)
        # The quantized output is int8 regardless of the input dtype.
        self.assertExpandedXRtolEqual(
            expanded_x, local_expanded_x_npu, np.int8)
        self.assertRtolEqual(
            expanded_row_idx, local_expanded_row_idx_npu.numpy())
        if expert_tokens_num_flag:
            self.assertRtolEqual(expert_tokens_count,
                                 local_expert_tokens_count_npu.numpy())
        if none_scale:
            # The scale output is not compared for this case. `continue`, not `return`:
            # returning here would end the whole test after the first scale-less case.
            continue
        self.assertRtolEqual(
            expanded_scale, local_expanded_scale_npu.numpy())
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_int4_dynamic_quant_mode_13(self):
    """Check output dtypes and shapes for ``quant_mode=13`` with ``x_dtype=torch_npu.int4``.

    Runs float32 and bfloat16 inputs, each without a scale and with an
    all-ones ``(1, h)`` float32 scale. Only metadata is verified, not values.
    """
    bs, h, k = 4, 16, 2
    expert_num = 8
    active_num = -1
    expected_rows = bs * k
    expert_idx = torch.randint(0, expert_num, (bs, k), dtype=torch.int32).npu()

    # Same order as a literal list: (dtype, no scale), then (dtype, ones scale), per dtype.
    cases = [
        (input_dtype, input_scale)
        for input_dtype in (torch.float32, torch.bfloat16)
        for input_scale in (None, torch.ones((1, h), dtype=torch.float32).npu())
    ]
    fixed_kwargs = dict(
        offset=None,
        active_num=active_num,
        expert_capacity=-1,
        expert_num=expert_num,
        drop_pad_mode=0,
        expert_tokens_num_type=1,
        expert_tokens_num_flag=True,
        quant_mode=13,
        active_expert_range=[0, expert_num],
        row_idx_type=0,
        x_dtype=torch_npu.int4,
    )
    for dtype, scale in cases:
        x = torch.randn((bs, h), dtype=dtype).npu()
        expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
            x, expert_idx, scale=scale, **fixed_kwargs)

        # (output, expected dtype, expected shape); expanded_x is uint8 with h // 2 columns.
        expectations = (
            (expanded_x, torch.uint8, (expected_rows, h // 2)),
            (expanded_row_idx, torch.int32, (expected_rows,)),
            (expert_tokens_count, torch.int64, (expert_num,)),
            (expanded_scale, torch.float32, (expected_rows,)),
        )
        for output, want_dtype, want_shape in expectations:
            self.assertEqual(output.dtype, want_dtype)
            self.assertEqual(tuple(output.shape), want_shape)
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_int4_dynamic_quant_reject_legacy_mode(self):
    """Expect a ``RuntimeError`` matching ``quant_mode=13`` when ``x_dtype=torch_npu.int4`` is used with ``quant_mode=1``."""
    bs, h, k = 4, 16, 2
    expert_num = 8
    x = torch.randn((bs, h), dtype=torch.float32).npu()
    expert_idx = torch.randint(0, expert_num, (bs, k), dtype=torch.int32).npu()

    # int4 output requested together with quant_mode 1 instead of 13.
    rejected_kwargs = dict(
        scale=None,
        offset=None,
        active_num=-1,
        expert_capacity=-1,
        expert_num=expert_num,
        drop_pad_mode=0,
        expert_tokens_num_type=1,
        expert_tokens_num_flag=True,
        quant_mode=1,
        active_expert_range=[0, expert_num],
        row_idx_type=0,
        x_dtype=torch_npu.int4,
    )
    with self.assertRaisesRegex(RuntimeError, "quant_mode=13"):
        torch_npu.npu_moe_init_routing_v2(x, expert_idx, **rejected_kwargs)
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_hif8_no_quant(self):
    """Compare the operator with the NumPy golden for hifloat8 input and ``quant_mode == -1``.

    ``x_dtype=torch_npu.hifloat8`` is always passed to the operator. Only
    dropless cases (``drop_pad_mode == 0``) with token statistics enabled and
    ``expert_tokens_num_type != 0`` are compared; other combinations of the
    parameter grid are generated and then skipped.
    """
    bs_list = [32]
    h_list = [14, 200]
    k_list = [5]
    expert_range_list = [[0, 16]]
    quant_mode_list = [-1]
    drop_pad_mode_list = [0, 1]
    row_idx_type_list = [0, 1]
    expert_tokens_num_flags = [True, False]
    dtype_list = [torch_npu.hifloat8]
    none_scales = [True, False]
    none_offsets = [True]
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype, none_scale, none_offset,
         expert_tokens_num_flag, drop_pad_mode) in itertools.product(
            bs_list, h_list, k_list, expert_range_list, quant_mode_list, row_idx_type_list,
            dtype_list, none_scales, none_offsets, expert_tokens_num_flags, drop_pad_mode_list):
        scale_shape = (bs,)
        # The range from the parameter grid is replaced by one derived from drop_pad_mode.
        expert_range = [0, 32] if drop_pad_mode == 1 else [0, 16]
        # generate_inputs also overrides row_idx_type from the grid.
        x, expert_idx, scale, offset, x_npu, expert_idx_npu, scale_npu, offset_npu, expert_tokens_num_type, row_idx_type, active_num, expert_capacity = self.generate_inputs(
            bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode)
        if drop_pad_mode == 1 or not expert_tokens_num_flag or expert_tokens_num_type == 0:
            continue
        x_dtype = torch_npu.hifloat8
        expanded_x, local_expanded_x_npu, expanded_row_idx, local_expanded_row_idx_npu, \
            expert_tokens_count, local_expert_tokens_count_npu, expanded_scale, local_expanded_scale_npu \
            = self.calc_npu_vs_golden(x, expert_idx, scale, offset,
                                      x_npu, expert_idx_npu, scale_npu, offset_npu, expert_range,
                                      quant_mode, row_idx_type, expert_tokens_num_flag,
                                      expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity, x_dtype)
        self.assertExpandedXRtolEqual(
            expanded_x, local_expanded_x_npu, dtype)
        self.assertRtolEqual(
            expanded_row_idx, local_expanded_row_idx_npu.numpy())
        if expert_tokens_num_flag:
            self.assertRtolEqual(expert_tokens_count,
                                 local_expert_tokens_count_npu.numpy())
        if none_scale:
            # Nothing to compare for the scale output. `continue`, not `return`:
            # returning here would end the whole test after the first scale-less case.
            continue
        self.assertRtolEqual(
            expanded_scale, local_expanded_scale_npu.numpy())
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_mxfp_no_quant(self):
    """Check float8/float4 inputs with quant_mode=-1 against the reference results.

    Every combination of the parameter lists below is passed through
    ``self.generate_inputs`` and ``self.calc_npu_vs_golden``; each reference
    output is converted with ``numpy_to_torch`` and compared with its NPU
    counterpart via ``assert_tensors_close``. Combinations with
    ``drop_pad_mode == 1``, with ``expert_tokens_num_flag`` unset, or for which
    ``generate_inputs`` reports ``expert_tokens_num_type == 0`` are skipped.
    """
    ALIGN_BASE = 64
    MXFP8_SCALE_THIRD_DIM_SIZE = 2
    bs_list = [32]
    h_list = [14, 200]
    k_list = [5]
    expert_range_list = [[0, 16]]
    quant_mode_list = [-1]
    drop_pad_mode_list = [0, 1]
    row_idx_type_list = [0, 1]
    expert_tokens_num_flags = [True, False]
    dtype_list = [torch.float8_e4m3fn, torch.float8_e5m2, torch_npu.float4_e2m1fn_x2]
    none_scales = [True, False]
    none_offsets = [True]
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype, none_scale,
         none_offset, expert_tokens_num_flag, drop_pad_mode) in itertools.product(
            bs_list, h_list, k_list, expert_range_list, quant_mode_list, row_idx_type_list,
            dtype_list, none_scales, none_offsets, expert_tokens_num_flags, drop_pad_mode_list):
        # Middle dimension is h divided by ALIGN_BASE, rounded up.
        scale_shape = (bs, (h + ALIGN_BASE - 1) // ALIGN_BASE, MXFP8_SCALE_THIRD_DIM_SIZE)
        # The expert range used for the call depends only on drop_pad_mode;
        # the value taken from expert_range_list is replaced here.
        expert_range = [0, 32] if drop_pad_mode == 1 else [0, 16]
        # NOTE(review): row_idx_type is rebound to the value generate_inputs
        # returns, so the loop's row_idx_type is not what reaches the operator.
        (x, expert_idx, scale, offset, x_npu, expert_idx_npu, scale_npu, offset_npu,
         expert_tokens_num_type, row_idx_type, active_num, expert_capacity) = self.generate_inputs(
            bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode)
        # The skip decision needs expert_tokens_num_type, which only
        # generate_inputs provides, so inputs are generated before skipping.
        if drop_pad_mode == 1 or not expert_tokens_num_flag or expert_tokens_num_type == 0:
            continue
        # Only the float4 input passes an explicit x_dtype; the float8 inputs pass None.
        x_dtype = torch_npu.float4_e2m1fn_x2 if dtype == torch_npu.float4_e2m1fn_x2 else None
        (expanded_x, local_expanded_x_npu, expanded_row_idx, local_expanded_row_idx_npu,
         expert_tokens_count, local_expert_tokens_count_npu,
         expanded_scale, local_expanded_scale_npu) = self.calc_npu_vs_golden(
            x, expert_idx, scale, offset,
            x_npu, expert_idx_npu, scale_npu, offset_npu, expert_range,
            quant_mode, row_idx_type, expert_tokens_num_flag, expert_tokens_num_type,
            drop_pad_mode, active_num, expert_capacity, x_dtype)
        assert_tensors_close(numpy_to_torch(
            expanded_x), local_expanded_x_npu)
        assert_tensors_close(numpy_to_torch(
            expanded_row_idx), local_expanded_row_idx_npu)
        if expert_tokens_num_flag:
            assert_tensors_close(numpy_to_torch(expert_tokens_count),
                                 local_expert_tokens_count_npu)
        if none_scale:
            # Only the scale comparison is skipped for this combination.
            # A `return` here would end the test after the first verified
            # case and leave the remaining combinations unchecked.
            continue
        assert_tensors_close(numpy_to_torch(
            expanded_scale), local_expanded_scale_npu)
@unittest.skipIf(
    torch.__version__ < "2.7" or
    importlib.util.find_spec("ml_dtypes") is None or
    importlib.util.find_spec("en_dtypes") is None,
    "Unittest for mxfp8 need torch.version>=2.7 and package ml_dtypes, en_dtypes"
)
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_mxfp_quant(self):
    """Sweep quant modes 2, 3 and 9 on float16/bfloat16 inputs and compare with the reference.

    For every parameter combination the inputs come from
    ``self.generate_inputs`` and the reference/NPU result pairs from
    ``self.calc_npu_vs_golden``. All four pairs (expanded x, expanded row
    index, expert token count, expanded scale) are compared with
    ``assert_tensors_close``.
    """
    # Fixed for the whole sweep; the value generate_inputs returns is discarded.
    expert_tokens_num_type = 1
    param_grid = itertools.product(
        [32],                             # bs
        [7168, 7184],                     # h
        [8],                              # k
        [[0, 32]],                        # expert_range
        [2, 3, 9],                        # quant_mode
        [0, 1],                           # row_idx_type
        [torch.float16, torch.bfloat16],  # x_dtype
        [True, False],                    # none_scale
        [True],                           # none_offset
        [True],                           # expert_tokens_num_flag
        [0],                              # drop_pad_mode
    )
    for (bs, h, k, expert_range, quant_mode, row_idx_type, x_dtype,
         none_scale, none_offset, expert_tokens_num_flag, drop_pad_mode) in param_grid:
        generated = self.generate_inputs(
            bs, h, k, x_dtype, (bs * k, h), none_scale, none_offset, drop_pad_mode)
        (x, expert_idx, scale, offset,
         x_npu, expert_idx_npu, scale_npu, offset_npu,
         _, _, active_num, expert_capacity) = generated

        outcome = self.calc_npu_vs_golden(
            x, expert_idx, scale, offset,
            x_npu, expert_idx_npu, scale_npu, offset_npu,
            expert_range, quant_mode, row_idx_type, expert_tokens_num_flag,
            expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity)
        (ref_x, npu_x, ref_row_idx, npu_row_idx,
         ref_token_count, npu_token_count, ref_scale, npu_scale) = outcome

        # Compare each reference array with its NPU counterpart, in the same
        # order as the result pairs are returned.
        checked_pairs = (
            (ref_x, npu_x),
            (ref_row_idx, npu_row_idx),
            (ref_token_count, npu_token_count),
            (ref_scale, npu_scale),
        )
        for reference, actual in checked_pairs:
            assert_tensors_close(numpy_to_torch(reference), actual)
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_hif8_pertensor_quant(self):
    """Sweep quant_mode 7 with a single-element scale and compare with the reference.

    Inputs come from ``self.generate_inputs`` (scale shape ``(1,)``, scale
    and offset always requested) and the reference/NPU result pairs from
    ``self.calc_npu_vs_golden``. Expanded x, expanded row index and expert
    token count are compared; the expanded scale pair is not checked here.
    """
    param_grid = itertools.product(
        [32],                             # bs
        [14, 200],                        # h
        [5, 128],                         # k
        [[0, 16]],                        # expert_range (replaced inside the loop)
        [7],                              # quant_mode
        [0, 1],                           # row_idx_type
        [torch.float16, torch.bfloat16],  # dtype
        [False],                          # none_scale
        [False],                          # none_offset
        [True, False],                    # expert_tokens_num_flag
        [0, 1],                           # drop_pad_mode
    )
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype,
         none_scale, none_offset, expert_tokens_num_flag, drop_pad_mode) in param_grid:
        # The expert range passed on depends only on drop_pad_mode.
        if drop_pad_mode == 1:
            expert_range = [0, 32]
        else:
            expert_range = [0, 16]

        # row_idx_type is rebound to the value generate_inputs returns.
        (x, expert_idx, scale, offset,
         x_npu, expert_idx_npu, scale_npu, offset_npu,
         expert_tokens_num_type, row_idx_type,
         active_num, expert_capacity) = self.generate_inputs(
            bs, h, k, dtype, (1,), none_scale, none_offset, drop_pad_mode)

        (ref_x, npu_x, ref_row_idx, npu_row_idx,
         ref_token_count, npu_token_count,
         ref_scale, npu_scale) = self.calc_npu_vs_golden(
            x, expert_idx, scale, offset,
            x_npu, expert_idx_npu, scale_npu, offset_npu,
            expert_range, quant_mode, row_idx_type, expert_tokens_num_flag,
            expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity)

        for reference, actual in (
                (ref_x, npu_x),
                (ref_row_idx, npu_row_idx),
                (ref_token_count, npu_token_count)):
            assert_tensors_close(numpy_to_torch(reference), actual)
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_hif8_pertoken_quant(self):
    """Sweep quant_mode 8 and compare all four result pairs with the reference.

    Inputs come from ``self.generate_inputs`` and the reference/NPU result
    pairs from ``self.calc_npu_vs_golden``. Combinations with
    ``drop_pad_mode == 1``, with ``expert_tokens_num_flag`` unset, or for
    which ``generate_inputs`` reports ``expert_tokens_num_type == 0`` are
    skipped after input generation.
    """
    param_grid = itertools.product(
        [32],                             # bs
        [14, 200],                        # h
        [8],                              # k
        [[0, 16]],                        # expert_range (sizes the scale, then replaced)
        [8],                              # quant_mode
        [0, 1],                           # row_idx_type
        [torch.float16, torch.bfloat16],  # dtype
        [True, False],                    # none_scale
        [True],                           # none_offset
        [True, False],                    # expert_tokens_num_flag
        [0, 1],                           # drop_pad_mode
    )
    for (bs, h, k, expert_range, quant_mode, row_idx_type, dtype,
         none_scale, none_offset, expert_tokens_num_flag, drop_pad_mode) in param_grid:
        # The scale shape uses the width of the range from the parameter grid,
        # taken before expert_range is replaced below.
        range_start, range_stop = expert_range[0], expert_range[1]
        scale_shape = (range_stop - range_start, h)
        if drop_pad_mode == 1:
            expert_range = [0, 32]
        else:
            expert_range = [0, 16]

        # row_idx_type is rebound to the value generate_inputs returns.
        (x, expert_idx, scale, offset,
         x_npu, expert_idx_npu, scale_npu, offset_npu,
         expert_tokens_num_type, row_idx_type,
         active_num, expert_capacity) = self.generate_inputs(
            bs, h, k, dtype, scale_shape, none_scale, none_offset, drop_pad_mode)

        is_verified = (drop_pad_mode != 1
                       and expert_tokens_num_flag
                       and expert_tokens_num_type != 0)
        if not is_verified:
            continue

        (ref_x, npu_x, ref_row_idx, npu_row_idx,
         ref_token_count, npu_token_count,
         ref_scale, npu_scale) = self.calc_npu_vs_golden(
            x, expert_idx, scale, offset,
            x_npu, expert_idx_npu, scale_npu, offset_npu,
            expert_range, quant_mode, row_idx_type, expert_tokens_num_flag,
            expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity)

        for reference, actual in (
                (ref_x, npu_x),
                (ref_row_idx, npu_row_idx),
                (ref_token_count, npu_token_count),
                (ref_scale, npu_scale)):
            assert_tensors_close(numpy_to_torch(reference), actual)
@SupportedDevices(['Ascend950'])
def test_npu_moe_init_routing_fp8_perblock_quant(self):
    """Sweep quant modes 11 and 12 on float16/bfloat16 inputs and compare with the reference.

    Inputs come from ``self.generate_inputs`` (no scale, no offset) and the
    reference/NPU result pairs from ``self.calc_npu_vs_golden``. All four
    pairs (expanded x, expanded row index, expert token count, expanded
    scale) are compared with ``assert_tensors_close``.
    """
    # Fixed for the whole sweep; the value generate_inputs returns is discarded.
    expert_tokens_num_type = 1
    batch_sizes = [32]
    hidden_sizes = [7168, 7184]
    top_ks = [8]
    expert_ranges = [[0, 32]]
    quant_modes = [11, 12]
    row_idx_types = [0, 1]
    input_dtypes = [torch.float16, torch.bfloat16]
    none_scale_options = [True]
    none_offset_options = [True]
    token_num_flags = [True]
    drop_pad_modes = [0]

    combinations = itertools.product(
        batch_sizes, hidden_sizes, top_ks, expert_ranges, quant_modes, row_idx_types,
        input_dtypes, none_scale_options, none_offset_options, token_num_flags,
        drop_pad_modes)
    for combo in combinations:
        (bs, h, k, expert_range, quant_mode, row_idx_type, x_dtype,
         none_scale, none_offset, expert_tokens_num_flag, drop_pad_mode) = combo
        scale_shape = (bs * k, h)

        (x, expert_idx, scale, offset,
         x_npu, expert_idx_npu, scale_npu, offset_npu,
         _, _, active_num, expert_capacity) = self.generate_inputs(
            bs, h, k, x_dtype, scale_shape, none_scale, none_offset, drop_pad_mode)

        (ref_x, npu_x, ref_row_idx, npu_row_idx,
         ref_token_count, npu_token_count,
         ref_scale, npu_scale) = self.calc_npu_vs_golden(
            x, expert_idx, scale, offset,
            x_npu, expert_idx_npu, scale_npu, offset_npu,
            expert_range, quant_mode, row_idx_type, expert_tokens_num_flag,
            expert_tokens_num_type, drop_pad_mode, active_num, expert_capacity)

        assert_tensors_close(numpy_to_torch(ref_x), npu_x)
        assert_tensors_close(numpy_to_torch(ref_row_idx), npu_row_idx)
        assert_tensors_close(numpy_to_torch(ref_token_count), npu_token_count)
        assert_tensors_close(numpy_to_torch(ref_scale), npu_scale)
# Script entry point: when this file is executed directly, hand control to
# run_tests (imported from torch_npu.testing.testcase).
if __name__ == "__main__":
    run_tests()