"""
"""
import os
import torch
import torch_npu
import numpy as np
from numpy.testing import assert_allclose
from torch._subclasses.fake_tensor import FakeTensor
from torch._dynamo import allow_in_graph
import pypto
import pytest
from utils.get_format import get_format
from glm_ffn_common_interface import symmetric_quantization_per_token, dequant_dynamic, swiglu
def check_cond(cond, msg):
    """Raise ``ValueError(msg)`` unless *cond* is truthy.

    Args:
        cond: value evaluated for truthiness.
        msg: message carried by the raised ``ValueError``.
    """
    if cond:
        return
    raise ValueError(msg)
def powers_of_2(n: int) -> set[int]:
    """Return every power of two ``2**k`` (k >= 0) that does not exceed *n*.

    Used as the ``unroll_list`` of batch-tile sizes for the kernel's row loop,
    e.g. ``powers_of_2(16) == {1, 2, 4, 8, 16}``.

    Raises:
        ValueError: if *n* is not positive (via ``check_cond``).
    """
    check_cond(n > 0, "n must be positive")
    found = set()
    candidate = 1
    # Double the candidate until it passes n; every value visited is <= n.
    while candidate <= n:
        found.add(candidate)
        candidate <<= 1
    return found
def _check_tensor(tensor, name, shape, fmt, dtype):
    """Validate rank, pinned dimensions, storage format and dtype of one tensor.

    Checks run in this order: rank (``len(shape)``), each dimension whose
    expected size in ``shape`` is not ``None``, ``get_format(tensor) == fmt``,
    then ``tensor.dtype == dtype``. The first failing check raises
    ``ValueError("invalid <name> dim.|shape.|format.|dtype.")`` via ``check_cond``.
    """
    check_cond(tensor.dim() == len(shape), f"invalid {name} dim.")
    for axis, expected in enumerate(shape):
        # None marks a dimension that is not pinned (the dynamic batch axis).
        if expected is not None:
            check_cond(tensor.shape[axis] == expected, f"invalid {name} shape.")
    check_cond(get_format(tensor) == fmt, f"invalid {name} format.")
    check_cond(tensor.dtype == dtype, f"invalid {name} dtype.")


def check_args(
    gate_weight: torch.Tensor,
    hidden_states: torch.Tensor,
    top_k: int,
    renormalize: bool,
    topk_group: int,
    num_expert_group: int,
    e_score_correction_bias: torch.Tensor,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor
) -> None:
    """Validate ``moe_fusion`` inputs against the single configuration the kernel supports.

    Expected layout (160 experts, hidden size 5120, FFN intermediate size 192):
        gate_weight             (160, 5120)  ND  float32
        hidden_states           (bs, 5120)   ND  bfloat16
        e_score_correction_bias (160,)       ND  bfloat16
        w13                     (5120, 384)  NZ  int8
        w13_scale               (384,)       ND  bfloat16
        w2                      (192, 5120)  NZ  int8
        w2_scale                (5120,)      ND  bfloat16
    ``top_k``, ``topk_group`` and ``num_expert_group`` must be ints, and
    ``renormalize`` must be a bool.

    Raises:
        ValueError: on the first mismatching rank, shape, format, dtype or value.
    """
    num_experts = 160
    hidden = 5120
    intermediate = 192

    _check_tensor(gate_weight, "gate weight", (num_experts, hidden), 'ND', torch.float32)
    _check_tensor(hidden_states, "hidden states", (None, hidden), 'ND', torch.bfloat16)
    _check_tensor(e_score_correction_bias, "bias", (num_experts,), 'ND', torch.bfloat16)

    check_cond(isinstance(top_k, int), "invalid topk dtype.")
    check_cond(isinstance(renormalize, bool), "invalid renormalize dtype.")
    check_cond(isinstance(topk_group, int), "invalid topk_group dtype.")
    check_cond(isinstance(num_expert_group, int), "invalid num_expert_group dtype.")
    # Value ranges the kernel relies on:
    # - it picks top_k of the num_experts scores;
    # - it reshapes those scores to [num_expert_group, num_experts // num_expert_group],
    #   so the group count must divide the expert count;
    # - it picks topk_group of the num_expert_group groups.
    check_cond(0 < top_k <= num_experts, "invalid topk value.")
    check_cond(num_expert_group > 0 and num_experts % num_expert_group == 0,
               "invalid num_expert_group value.")
    check_cond(0 < topk_group <= num_expert_group, "invalid topk_group value.")

    _check_tensor(w13, "w13", (hidden, 2 * intermediate), 'NZ', torch.int8)
    _check_tensor(w13_scale, "w13_scale", (2 * intermediate,), 'ND', torch.bfloat16)
    _check_tensor(w2, "w2", (intermediate, hidden), 'NZ', torch.int8)
    _check_tensor(w2_scale, "w2_scale", (hidden,), 'ND', torch.bfloat16)
def gen_quan_per_channel_weight_nz(x):
    """Quantize a weight tensor to int8 per column and cast it to the NZ layout.

    Each column (the values along dim 0) is scaled by ``127 / max|column|``,
    rounded and cast to int8. This follows the same rint -> fp16 -> trunc cast
    chain that ``moe_fusion_kernel`` applies to its activations.

    Args:
        x: weight tensor on an NPU device. Callers in this file pass 2-D
            ``(K, N)`` bfloat16 tensors.

    Returns:
        ``(y_int8_nz, scale_dequant)``:
        - ``y_int8_nz``: the int8 weights, cast with
          ``torch_npu.npu_format_cast(..., 29)`` (the NZ format ``check_args``
          expects for w13 / w2).
        - ``scale_dequant``: the float32 dequant scale ``max|column| / 127``,
          with the reduced dim kept (``(1, N)`` for 2-D input).
    """
    x_fp32 = x.to(torch.float32)
    max_value = x_fp32.abs().max(dim=0, keepdim=True)[0]
    # An all-zero column would give 127 / 0 = inf, and 0 * inf = NaN before the
    # int casts. Clamping is safe: scale_dequant stays the exact reciprocal of
    # scale_quant, and 127 / 1e-12 is still finite in float32.
    max_value = torch.clamp(max_value, min=1e-12)
    scale_quant = 127.0 / max_value
    y_fp32 = x_fp32 * scale_quant
    y_rint = torch.round(y_fp32).to(torch.int32)
    y_round = torch.round(y_rint).to(torch.float16)
    y_int8 = torch.trunc(y_round).to(torch.int8)
    y_int8_nz = torch_npu.npu_format_cast(y_int8, 29)
    scale_dequant = (1 / scale_quant)
    return y_int8_nz, scale_dequant
# Tile-op storage formats used in the kernel signature below. check_args requires
# the int8 weights w13 / w2 to be 'NZ' and every other tensor to be 'ND'.
ND = pypto.TileOpFormat.TILEOP_ND
NZ = pypto.TileOpFormat.TILEOP_NZ
# moe_fusion_kernel: fused MoE router + int8-quantized FFN, processed in tiles of
# hidden_states rows.
#
# Router, per row tile:
#   logits = hidden_states @ mm_weight^T in fp32 (computed in 16-expert slices);
#   scores = sigmoid(logits); selection scores = scores + e_score_bias_input.
#   The ne selection scores are viewed as [num_expert_group, ne // num_expert_group].
#   The topk_group groups with the largest per-group max are kept, and experts in
#   the other groups get selection score 0.0. The top `topk` experts are then
#   chosen, with topk = ids_k.shape[1].
#   The routing weights are the un-biased sigmoid scores at those ids, divided by
#   their sum. They are always renormalized; there is no renormalize switch here.
# FFN, per row tile:
#   - per-row int8 quantization of hidden_states (scale 127 / max|row|);
#   - int32 matmul with w13 in K chunks of 512;
#   - dequant by the row scale and w13_scale, then swiglu;
#   - requantization via symmetric_quantization_per_token;
#   - int32 matmul with w2;
#   - dequant by that scale and w2_scale, then cast to the hidden_states dtype.
# Outputs are written in place: weight_k (fp32), ids_k (int32), ffn_res (bf16).
# NOTE(review): vec_tile_shape hard-codes hidden size 5120, and the router loop
# only covers ne // 16 * 16 experts. Both match check_args (ne=160, hidden=5120).
@pypto.frontend.jit(
    runtime_options={"device_sched_mode": 1,
                     "stitch_function_max_num": 128},
    pass_options={"cube_l1_reuse_setting": {-1: 2},
                  "cube_nbuffer_setting": {2: 2}}
)
def moe_fusion_kernel(
    hidden_states: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16, format=ND),
    mm_weight: pypto.Tensor([], pypto.DT_FP32, format=ND),
    e_score_bias_input: pypto.Tensor([], pypto.DT_BF16, format=ND),
    w13: pypto.Tensor([], pypto.DT_INT8, format=NZ),
    w13_scale: pypto.Tensor([], pypto.DT_BF16, format=ND),
    w2: pypto.Tensor([], pypto.DT_INT8, format=NZ),
    w2_scale: pypto.Tensor([], pypto.DT_BF16, format=ND),
    weight_k: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32, format=ND),
    ids_k: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_INT32, format=ND),
    ffn_res: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16, format=ND),
    topk_group,
    num_expert_group,
):
    bs = hidden_states.shape[0]
    ne = mm_weight.shape[0]
    # Number of experts selected per row comes from the output buffer's width.
    topk = ids_k.shape[1]
    pypto.experimental.set_operation_options(combine_axis=True)
    vec_tile_shape = (4, 5120)
    mm1_cube_tile_shape = (8, 256, 192)
    mm2_cube_tile_shape = (8, 192, 320)
    hidden_size = hidden_states.shape[1]
    intermediate_size = w2.shape[0]
    # View the 1-D bias / scale vectors as single-row 2-D tensors so they can be
    # combined with [rows, cols] tiles below.
    pypto.set_vec_tile_shapes(ne)
    e_score_bias_2d = pypto.reshape(e_score_bias_input, [1, ne], inplace=True)
    pypto.set_vec_tile_shapes(intermediate_size * 2)
    w13_scale_2d = pypto.reshape(w13_scale, [1, intermediate_size * 2], inplace=True)
    pypto.set_vec_tile_shapes(hidden_size)
    w2_scale_2d = pypto.reshape(w2_scale, [1, hidden_size], inplace=True)
    # Row-tile loop: bs_idx is the first row of the tile and tile_batch its row
    # count. Tile sizes come from unroll_list = powers_of_2(16) = {1, 2, 4, 8, 16};
    # see the pypto loop_unroll docs for the exact tiling policy.
    for bs_idx, tile_batch in pypto.loop_unroll(0, bs, 1, name="LOOP_MOE_FUSION_L0", idx_name="bs_idx",
                                                unroll_list=powers_of_2(16)):
        tile_hidden_states = hidden_states[bs_idx:bs_idx + tile_batch, :]
        # ---- Router logits: res[:, e] = x . mm_weight[e], 16 experts per step. ----
        # set_pass_options(sg_set_scope=(id, ...)) ... sg_set_scope=-1 brackets a
        # region of ops (presumably grouping them into one subgraph; confirm in
        # the pypto pass-option docs).
        pypto.set_pass_options(sg_set_scope=(10001, False, True))
        res = pypto.tensor([tile_batch, ne], pypto.DT_FP32, "res")
        for tmp_idx in range(ne // 16):
            pypto.set_vec_tile_shapes(min(tile_batch, 32), 1024)
            tile_hidden_states_fp32 = pypto.cast(tile_hidden_states, pypto.DT_FP32)
            pypto.set_vec_tile_shapes(16, 1024)
            tile_mm_weight_fp32 = pypto.cast(mm_weight[tmp_idx * 16:tmp_idx * 16 + 16, :], pypto.DT_FP32)
            pypto.set_cube_tile_shapes([min(tile_batch, 32), min(tile_batch, 32)], [512, 1024], [16, 16])
            res_tmp = pypto.matmul(tile_hidden_states_fp32, tile_mm_weight_fp32,
                                   tile_hidden_states_fp32.dtype, b_trans=True)
            pypto.assemble(res_tmp, [0, tmp_idx * 16], res)
        pypto.set_pass_options(sg_set_scope=-1)
        tile_logits = res
        view_first = 1
        pypto.set_vec_tile_shapes(view_first, ne)
        tile_logits_fp32 = pypto.cast(tile_logits, pypto.DT_FP32)
        # Broadcast the [1, ne] bias to every row of the tile.
        e_score_bias_2d_tile = pypto.tensor([tile_batch, ne], e_score_bias_2d.dtype, "e_score_bias_2d_tile")
        for tmp_idx in range(tile_batch):
            pypto.assemble(e_score_bias_2d, [tmp_idx, 0], e_score_bias_2d_tile)
        e_score_bias_2d_cast = pypto.cast(e_score_bias_2d_tile, tile_logits_fp32.dtype)
        # topk_weights: un-biased sigmoid scores, later gathered as routing weights.
        # topk_weights_add: biased scores, used only for selecting experts.
        topk_weights = pypto.sigmoid(tile_logits_fp32)
        topk_weights_add = pypto.add(topk_weights, e_score_bias_2d_cast)
        # ---- Group selection: per-group max, keep the best topk_group groups. ----
        group_unit = ne // num_expert_group
        r1 = pypto.reshape(topk_weights_add, [tile_batch, num_expert_group, group_unit])
        pypto.set_vec_tile_shapes(view_first, num_expert_group, group_unit)
        max1 = pypto.amax(r1, -1, False)
        group_weight = max1
        pypto.set_vec_tile_shapes(view_first, num_expert_group)
        _, topk_group_indices = pypto.topk(group_weight, topk_group, -1, True)
        # Build a [tile_batch, num_expert_group] 0/1 mask of kept groups, then
        # expand it to one entry per expert ([tile_batch, ne]).
        topk_group_mask = pypto.full([tile_batch, num_expert_group], 0.0, group_weight.dtype)
        topk_group_mask_scatter_trans = pypto.scatter_(topk_group_mask, 1, topk_group_indices, 1.0)
        twm_unsqueeze = pypto.unsqueeze(topk_group_mask_scatter_trans, -1)
        pypto.set_vec_tile_shapes(view_first, num_expert_group, ne)
        twm_expand = pypto.expand_clone(twm_unsqueeze, [tile_batch, num_expert_group, group_unit])
        pypto.set_vec_tile_shapes(view_first, num_expert_group, group_unit)
        twm_reshape = pypto.reshape(twm_expand, [tile_batch, ne])
        pypto.set_vec_tile_shapes(view_first, ne)
        # Experts outside the kept groups get selection score 0.0.
        twm_not = pypto.logical_not(twm_reshape)
        topk_weights_maskfill = pypto.where(twm_not, 0.0, topk_weights_add)
        # ---- Expert selection and renormalized routing weights. ----
        _, topk_ids = pypto.topk(topk_weights_maskfill, topk, -1, True)
        tw_gather = pypto.gather(topk_weights, 1, topk_ids)
        pypto.set_vec_tile_shapes(view_first, topk)
        denominator = pypto.sum(tw_gather, -1, True)
        topk_weight_out = pypto.div(tw_gather, denominator)
        weight_k[bs_idx:bs_idx + tile_batch, :] = topk_weight_out
        ids_k[bs_idx:bs_idx + tile_batch, :] = topk_ids
        # ---- Quantized FFN. ----
        pypto.set_vec_tile_shapes(1, intermediate_size * 2)
        w13_scale_2d_fp32 = pypto.cast(w13_scale_2d, pypto.DT_FP32, pypto.CastMode.CAST_NONE)
        pypto.set_vec_tile_shapes(1, hidden_size)
        w2_scale_2d_fp32 = pypto.cast(w2_scale_2d, pypto.DT_FP32, pypto.CastMode.CAST_NONE)
        pypto.set_vec_tile_shapes(vec_tile_shape[0], vec_tile_shape[1])
        hidden_states_offset = [bs_idx, 0]
        k_split_shape = 512
        k_split = hidden_size // k_split_shape
        matmul_res = []
        up_proj = pypto.full([tile_batch, intermediate_size * 2], 0, pypto.DT_INT32)
        hidden_states_scale = pypto.tensor([tile_batch, 1], pypto.DT_FP32, "hidden_states_scale")
        # K-split int8 matmul with w13. Each chunk recomputes the per-row scale
        # from the full row (x_max spans the whole hidden dim), so the scale
        # written to hidden_states_scale is the same on every iteration.
        for ki in range(k_split):
            pypto.set_pass_options(sg_set_scope=(10002, False, True))
            pypto.set_vec_tile_shapes(vec_tile_shape[0], vec_tile_shape[1])
            x_fp32 = pypto.cast(tile_hidden_states, pypto.DT_FP32)
            x_abs = pypto.abs(x_fp32)
            x_max = pypto.amax(x_abs, -1, True)
            pypto.set_vec_tile_shapes(tile_batch, k_split_shape)
            shape_0, shape_1 = x_max.shape[:2]
            # Quantize scale 127 / max|row|; rint -> fp16 -> saturating trunc to int8.
            x_scale = pypto.div(pypto.full([shape_0, shape_1], 127.0, pypto.DT_FP32), x_max)
            x_mul = pypto.mul(x_fp32[:, ki * k_split_shape:ki * k_split_shape + k_split_shape], x_scale)
            x_int32 = pypto.cast(x_mul, pypto.DT_INT32, pypto.CastMode.CAST_RINT)
            x_fp16 = pypto.cast(x_int32, pypto.DT_FP16, pypto.CastMode.CAST_ROUND)
            hidden_states_quant = pypto.cast(x_fp16, pypto.DT_INT8,
                                             pypto.CastMode.CAST_TRUNC, satmode=pypto.SaturationMode.ON)
            # Dequant scale for the activations: 1 / x_scale = max|row| / 127.
            hidden_states_scale_tmp = pypto.div(pypto.full([shape_0, shape_1], 1.0, pypto.DT_FP32), x_scale)
            pypto.assemble(hidden_states_scale_tmp, [0, 0], hidden_states_scale)
            # Rows [ki * 512, ki * 512 + 512) of w13 pair with this K slice.
            kn = pypto.view(w13, [k_split_shape, intermediate_size * 2], [ki * k_split_shape, 0])
            pypto.set_cube_tile_shapes([tile_batch, tile_batch],
                                       [mm1_cube_tile_shape[1], mm1_cube_tile_shape[1] * 2],
                                       [mm1_cube_tile_shape[2], mm1_cube_tile_shape[2]])
            res_tmp = pypto.matmul(hidden_states_quant, kn, pypto.DT_INT32)
            pypto.set_pass_options(sg_set_scope=-1)
            matmul_res.append(res_tmp)
        # Sum the int32 partial products of all K chunks.
        for ki in range(k_split):
            pypto.set_pass_options(sg_set_scope=(1, False, False))
            pypto.set_vec_tile_shapes(8, intermediate_size * 2)
            up_proj = pypto.add(up_proj, matmul_res[ki])
        pypto.set_vec_tile_shapes(8, intermediate_size * 2)
        # Dequantize: int32 accumulator * activation row scale * w13 column scale.
        up_proj_fp32 = pypto.cast(up_proj, pypto.DT_FP32, pypto.CastMode.CAST_NONE)
        hidden_states_scale_fp32 = pypto.cast(hidden_states_scale, pypto.DT_FP32, pypto.CastMode.CAST_NONE)
        up_proj_dequant_scale_2 = pypto.mul(up_proj_fp32, hidden_states_scale_fp32)
        up_proj_dequant = pypto.mul(up_proj_dequant_scale_2, w13_scale_2d_fp32)
        swiglu_out = swiglu(up_proj_dequant)
        # Requantize the swiglu output for the int8 down projection.
        down_proj_quant, down_proj_scale = symmetric_quantization_per_token(swiglu_out)
        pypto.set_pass_options(sg_set_scope=-1)
        pypto.set_cube_tile_shapes([tile_batch, tile_batch],
                                   [mm2_cube_tile_shape[1], mm2_cube_tile_shape[1] * 2],
                                   [mm2_cube_tile_shape[2], mm2_cube_tile_shape[2]], False)
        pypto.set_pass_options(sg_set_scope=(10003, False, True))
        down_proj = pypto.matmul(down_proj_quant, w2, pypto.DT_INT32)
        pypto.set_vec_tile_shapes(tile_batch, 640)
        # Dequantize: int32 result * requantization scale * w2 column scale.
        down_proj_fp32 = pypto.cast(down_proj, pypto.DT_FP32, pypto.CastMode.CAST_NONE)
        down_proj_scale_fp32 = pypto.cast(down_proj_scale, pypto.DT_FP32, pypto.CastMode.CAST_NONE)
        down_proj_dequant_scale_2 = pypto.mul(down_proj_fp32, down_proj_scale_fp32)
        down_proj_dequant = pypto.mul(down_proj_dequant_scale_2, w2_scale_2d_fp32)
        out = pypto.cast(down_proj_dequant, hidden_states.dtype)
        pypto.set_pass_options(sg_set_scope=-1)
        # Write this tile's rows of the FFN result into ffn_res at [bs_idx, 0].
        pypto.assemble(out, hidden_states_offset, ffn_res)
@pytest.mark.soc("950")
def test_moe_fusion():
    """End-to-end check of moe_fusion on an NPU against a torch / torch_npu reference.

    The RNG is seeded once with torch.manual_seed(0), then the test runs six
    iterations with bs=16, each on fresh random inputs. Each iteration compares
    three outputs against the reference:
    - the routing weights, against a torch grouped-top-k reference;
    - the expert ids, against the same reference;
    - the FFN output, against npu_dynamic_quant / npu_quant_matmul / npu_swiglu.

    Finally it asserts that the total pypto compile time is at most 30 s. The
    device is taken from the TILE_FWK_DEVICE_ID environment variable (default 0).
    """
    # NOTE: the tensors below are drawn from one seeded RNG stream; changing the
    # order of torch.rand calls changes the test data.
    enable_graph = False
    ne = 160
    h_num = 5120
    top_k = 8
    topk_group = 1
    num_expert_group = 1
    renormalize = True
    x_dtype = torch.bfloat16
    intermediate_size = 192
    hidden_size = h_num
    torch_npu.npu.config.allow_internal_format = True
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(device_id)
    torch.manual_seed(0)
    for bs in [16, 16, 16, 16, 16, 16]:
        # ---- Inputs: activations, int8 NZ FFN weights with per-column scales. ----
        hidden_states = torch.rand((bs, hidden_size), dtype=x_dtype, device=f'npu:{device_id}') * 0.05
        weight_gate_upper_tensor = torch.rand((hidden_size, intermediate_size * 2),
                                              dtype=x_dtype, device=f'npu:{device_id}') * 0.05
        w13, w13_scale = gen_quan_per_channel_weight_nz(weight_gate_upper_tensor)
        w13_scale = w13_scale.reshape(-1).to(x_dtype)
        weight_down_proj_tensor = torch.rand((intermediate_size, hidden_size),
                                             dtype=x_dtype, device=f'npu:{device_id}') * 0.05
        w2, w2_scale = gen_quan_per_channel_weight_nz(weight_down_proj_tensor)
        w2_scale = w2_scale.reshape(-1).to(x_dtype)
        # ---- Output buffers and router parameters. ----
        ffn_res = torch.empty((bs, hidden_size), dtype=x_dtype, device=f'npu:{device_id}')
        mm_weight = torch.rand((ne, h_num), dtype=torch.float32, device=f'npu:{device_id}')
        e_score_bias = torch.rand(
            (ne), dtype=torch.bfloat16, device=f'npu:{device_id}')
        topk_weights = torch.empty(
            (bs, top_k), dtype=torch.float32, device=f'npu:{device_id}')
        topk_ids = torch.empty(
            (bs, top_k), dtype=torch.int32, device=f'npu:{device_id}')
        inputs = [
            mm_weight,
            hidden_states,
            top_k,
            renormalize,
            topk_group,
            num_expert_group,
            e_score_bias,
            w13,
            w13_scale,
            w2,
            w2_scale
        ]
        outputs = [
            topk_weights,
            topk_ids,
            ffn_res
        ]
        # Run the fused kernel, either captured in an NPU graph and replayed,
        # or called directly.
        if enable_graph:
            g = torch.npu.NPUGraph()
            with torch.npu.graph(g):
                moe_fusion(*inputs, *outputs)
            g.replay()
        else:
            moe_fusion(*inputs, *outputs)
        # ---- Reference routing: sigmoid + bias, group mask, top-k, renormalize. ----
        result = torch.matmul(hidden_states.to(torch.float32), mm_weight.to(torch.float32).t())
        router_logits_fp32 = result.to(torch.float)
        original_weights = router_logits_fp32.sigmoid()
        bias_2d = e_score_bias.unsqueeze(0)
        topk_weights_g_add = original_weights + bias_2d
        tw_view = topk_weights_g_add.view(bs, num_expert_group, -1)
        grouped_weights = tw_view.max(dim=-1).values
        topk_group_indices_g = torch.topk(grouped_weights.to(torch.float32),
                                          k=topk_group,
                                          dim=-1,
                                          sorted=False)[1]
        topk_group_mask = torch.zeros_like(grouped_weights)
        topk_group_mask.scatter_(1, topk_group_indices_g, 1)
        tgm_unsquee = topk_group_mask.unsqueeze(-1)
        tgm_expand = tgm_unsquee.expand(bs, num_expert_group, ne // num_expert_group)
        topk_weight_mask = tgm_expand.reshape(bs, -1)
        logical_not_tmp = ~topk_weight_mask.bool()
        topk_weights_fill = topk_weights_g_add.masked_fill(
            logical_not_tmp, 0.0)
        # NOTE(review): torch.topk(sorted=False) does not guarantee an order, yet
        # the ids and weights are compared element-wise against the kernel output
        # below. Consider sorting each row before comparing.
        topk_ids_int64 = torch.topk(topk_weights_fill.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
        topk_ids_int32 = topk_ids_int64.to(torch.int32)
        topk_weights_gather = original_weights.gather(1, topk_ids_int64)
        if renormalize:
            topk_weights_out = topk_weights_gather / topk_weights_gather.sum(dim=-1, keepdim=True)
        else:
            topk_weights_out = topk_weights_gather
        topk_weight_2_tensor_list = topk_weights_out.cpu().flatten().tolist()
        topk_ids_tensor_list = topk_ids_int32.cpu().flatten().tolist()
        # ---- Reference FFN built from torch_npu quantized ops (golden). ----
        x_dtype = hidden_states.dtype
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(hidden_states)
        output_w13 = torch_npu.npu_quant_matmul(
            quantized_x,
            w13,
            w13_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=x_dtype
        )
        swiglu_out = torch_npu.npu_swiglu(output_w13)
        quantized_x, x_scale = torch_npu.npu_dynamic_quant(swiglu_out)
        golden = torch_npu.npu_quant_matmul(
            quantized_x,
            w2,
            w2_scale,
            pertoken_scale=x_scale,
            bias=None,
            output_dtype=x_dtype
        )
        assert_allclose(np.array(topk_weights.cpu().flatten().tolist()), np.array(topk_weight_2_tensor_list),
                        rtol=5e-3, atol=5e-3)
        assert_allclose(np.array(topk_ids.cpu().flatten().tolist()), np.array(topk_ids_tensor_list),
                        rtol=5e-3, atol=5e-3)
        assert_allclose(np.array(ffn_res.cpu().flatten().tolist()), np.array(golden.cpu().flatten().tolist()),
                        rtol=0.0078125, atol=0.0001)
    # Total compile time reported by the pypto compiler monitor.
    import pypto.pypto_impl as pypto_impl
    total_elapsed = pypto_impl.GetCompilerMonitorTotalElapsed()
    check_cond(total_elapsed <= 30, f"glm_moe_fusion compile elapsed timeout {total_elapsed}s > 30s.")
@allow_in_graph
def moe_fusion(
    gate_weight: torch.Tensor,
    hidden_states: torch.Tensor,
    top_k: int,
    renormalize: bool,
    topk_group: int,
    num_expert_group: int,
    e_score_bias: torch.Tensor,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    ffn_res: torch.Tensor
):
    """Validate inputs and launch moe_fusion_kernel, filling the output buffers in place.

    Args:
        gate_weight .. w2_scale: router and FFN inputs. See check_args for the
            required shapes, formats and dtypes.
        topk_weights: preallocated float32 ``(bs, top_k)`` buffer for routing weights.
        topk_ids: preallocated int32 ``(bs, top_k)`` buffer for expert ids.
        ffn_res: preallocated buffer shaped and typed like ``hidden_states``
            for the FFN output.

    Raises:
        ValueError: if check_args rejects an input, or if an output buffer does
            not match ``top_k`` / ``hidden_states`` in shape or dtype.
    """
    if isinstance(hidden_states, FakeTensor):
        # FakeTensors carry no data, so skip validation and the kernel launch
        # (presumably reached while tracing under torch.compile; confirm).
        return
    check_args(gate_weight, hidden_states, top_k, renormalize, topk_group, num_expert_group,
               e_score_bias, w13, w13_scale, w2, w2_scale)
    bs = hidden_states.shape[0]
    hidden_size = hidden_states.shape[1]
    # The kernel never receives top_k: it reads k from topk_ids.shape[1]. Its
    # signature declares weight_k / ids_k / ffn_res as fp32 / int32 / bf16.
    # Reject buffers that would silently change k or the output layout.
    check_cond(topk_weights.shape == (bs, top_k), "invalid topk_weights shape.")
    check_cond(topk_weights.dtype == torch.float32, "invalid topk_weights dtype.")
    check_cond(topk_ids.shape == (bs, top_k), "invalid topk_ids shape.")
    check_cond(topk_ids.dtype == torch.int32, "invalid topk_ids dtype.")
    check_cond(ffn_res.shape == (bs, hidden_size), "invalid ffn_res shape.")
    check_cond(ffn_res.dtype == hidden_states.dtype, "invalid ffn_res dtype.")
    # NOTE(review): moe_fusion_kernel always divides the selected weights by
    # their sum. renormalize=False passes check_args but is not honored.
    inputs = [hidden_states, gate_weight, e_score_bias, w13, w13_scale,
              w2, w2_scale, topk_weights, topk_ids, ffn_res]
    moe_fusion_kernel(*inputs, topk_group, num_expert_group)
def moe_fusion_pto(gate_layer, hidden_states, share_layer, top_k, renormalize, topk_group=None, num_expert_group=None,
                   e_score_correction_bias=None):
    """Allocate output buffers and run moe_fusion for one MoE layer.

    Args:
        gate_layer: module whose ``.weight`` is the router weight, with experts
            along dim 0.
        hidden_states: input activations; the batch size is ``hidden_states.shape[0]``.
        share_layer: module exposing ``gate_up_proj`` and ``down_proj``, each with
            ``weight`` and ``weight_scale`` attributes (the quantized FFN weights).
        top_k: number of experts selected per row.
        renormalize: forwarded to moe_fusion.
        topk_group: number of expert groups kept. ``None`` keeps every group.
        num_expert_group: number of expert groups. ``None`` means a single group.
        e_score_correction_bias: per-expert bias added to the selection scores.
            ``None`` means no bias, so a zero bfloat16 vector is passed instead.

    Returns:
        ``(topk_weights, topk_ids, ffn_res)``:
        - ``topk_weights``: float32 ``(bs, top_k)`` routing weights;
        - ``topk_ids``: int32 ``(bs, top_k)`` expert ids;
        - ``ffn_res``: the FFN output, shaped like ``hidden_states``.
    """
    bs = hidden_states.shape[0]
    ne = gate_layer.weight.shape[0]
    device_info = hidden_states.device
    # The None defaults used to be rejected by check_args. Map them to neutral
    # values that make the grouped top-k identical to a plain top-k:
    # - one group, with every group kept;
    # - a zero bias, which leaves the selection scores unchanged.
    if num_expert_group is None:
        num_expert_group = 1
    if topk_group is None:
        topk_group = num_expert_group
    if e_score_correction_bias is None:
        e_score_correction_bias = torch.zeros((ne,), dtype=torch.bfloat16, device=device_info)
    topk_weights = torch.empty((bs, top_k), dtype=torch.float32, device=device_info)
    topk_ids = torch.empty((bs, top_k), dtype=torch.int32, device=device_info)
    ffn_res = torch.empty_like(hidden_states, device=device_info)
    w13_int8 = share_layer.gate_up_proj.weight
    w13_scale = share_layer.gate_up_proj.weight_scale
    w2_int8 = share_layer.down_proj.weight
    w2_scale = share_layer.down_proj.weight_scale
    moe_fusion(
        gate_layer.weight,
        hidden_states,
        top_k,
        renormalize,
        topk_group,
        num_expert_group,
        e_score_correction_bias,
        w13_int8,
        w13_scale,
        w2_int8,
        w2_scale,
        topk_weights,
        topk_ids,
        ffn_res
    )
    return topk_weights, topk_ids, ffn_res
def main():
    """Set pypto host compile-monitor options, then run test_moe_fusion once."""
    host_options = dict(
        compile_monitor_enable=1,
        compile_timeout=10,
        compile_timeout_stage=5,
        compile_monitor_print_interval=2,
    )
    pypto.set_host_options(**host_options)
    test_moe_fusion()


if __name__ == "__main__":
    main()