import unittest
import sys
import random
import os
import ctypes
from multiprocessing import Pool, cpu_count
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
# Test-case configuration. generate_input_func() copies A, M, BS, K_plus_1 and
# expert_num into the CommonArea header of the schedule context.
A = 4
M = 1
BS = 54
K_plus_1 = 9
H = 7168
expert_num = 2
out_num = A
tokenDtype = 2
need_schedule = 1
invaildProb = 0

# Collapse any non-zero need_schedule value to exactly 1.
need_schedule = 1 if need_schedule != 0 else need_schedule

# Total number of (session, token, expert-slot) entries.
Y = A * BS * K_plus_1

# Bytes per token record. For tokenDtype == 2 the record holds H int8 values
# plus a 4-byte scale, rounded up to a multiple of 512; otherwise it is H
# two-byte values with no rounding.
HS = (H + 4 + 512 - 1) // 512 * 512 if tokenDtype == 2 else H * 2

# NOTE(review): this rebinds the name `F`, which the import section of this
# file binds to torch.nn.functional. Nothing visible uses either meaning --
# confirm before relying on `F`.
F = 1 + 1 + BS * K_plus_1

cur_micro_batch_id = 0
uniq_expert_id_cnt = 0
class CommonArea(ctypes.Structure):
    """Byte-packed common header of the schedule context (128 bytes).

    Seven unsigned 32-bit counters/sizes, a signed 32-bit ``schedule_mode``
    and 96 reserved bytes.
    """

    _pack_ = 1
    _fields_ = (
        [
            (field_name, ctypes.c_uint32)
            for field_name in (
                "session_num",
                "micro_batch_num",
                "micro_batch_size",
                "selected_expert_num",
                "expert_num",
                "attn_to_ffn_token_size",
                "ffn_to_attn_token_size",
            )
        ]
        + [
            ("schedule_mode", ctypes.c_int32),
            # Reserved bytes that bring the structure up to 128 bytes.
            ("reserve0", ctypes.c_int8 * 96),
        ]
    )
class ControlArea(ctypes.Structure):
    """Byte-packed control section of the schedule context (128 bytes)."""

    _pack_ = 1
    _fields_ = [
        ("run_flag", ctypes.c_int32),
        # Four reserved bytes; they place the 64-bit fields at offset 8.
        ("reserve1", ctypes.c_int8 * 4),
        # Three 64-bit wait-buffer values.
        ("schedule_wait_buf", ctypes.c_uint64),
        ("ffn_wait_buf", ctypes.c_uint64),
        ("attention_wait_buf", ctypes.c_uint64),
        # Reserved bytes that bring the structure up to 128 bytes.
        ("reserve2", ctypes.c_int8 * 96),
    ]
class AttentionArea(ctypes.Structure):
    """Byte-packed attention section of the schedule context (128 bytes).

    Two (buffer, size) pairs of unsigned 64-bit values, an unsigned 32-bit
    ``micro_batch_id`` and 92 reserved bytes.
    """

    _pack_ = 1
    _fields_ = (
        [
            (field_name, ctypes.c_uint64)
            for field_name in (
                "token_info_buf",
                "token_info_buf_size",
                "token_data_buf",
                "token_data_buf_size",
            )
        ]
        + [
            ("micro_batch_id", ctypes.c_uint32),
            ("reserve5", ctypes.c_int8 * 92),
        ]
    )
class FfnArea(ctypes.Structure):
    """Byte-packed FFN section of the schedule context (256 bytes).

    The first 128 bytes hold the token info/data (buffer, size) pairs, the
    polling index and reserved bytes; the second 128 bytes hold four more
    (buffer, size) pairs, ``out_num`` and reserved bytes.
    """

    _pack_ = 1
    _fields_ = (
        [
            (field_name, ctypes.c_uint64)
            for field_name in (
                "token_info_buf",
                "token_info_buf_size",
                "token_data_buf",
                "token_data_buf_size",
                "polling_index",
            )
        ]
        # Reserved bytes that end the first 128-byte half.
        + [("reserve3", ctypes.c_int8 * 88)]
        + [
            (field_name, ctypes.c_uint64)
            for field_name in (
                "layer_ids_buf",
                "layer_ids_buf_size",
                "session_ids_buf",
                "session_ids_buf_size",
                "micro_batch_ids_buf",
                "micro_batch_ids_buf_size",
                "expert_ids_buf",
                "expert_ids_buf_size",
            )
        ]
        + [
            ("out_num", ctypes.c_uint32),
            # Reserved bytes that end the second 128-byte half.
            ("reserve4", ctypes.c_int8 * 60),
        ]
    )
class DataDesc(ctypes.Structure):
    """Byte-packed token-info record.

    Two signed 32-bit values (``flag`` and ``layer_id``) followed by
    ``BS * K_plus_1`` signed 32-bit expert ids.
    """

    _pack_ = 1
    _fields_ = [
        ("flag", ctypes.c_int32),
        ("layer_id", ctypes.c_int32),
        # One expert id per (token, expert-slot) pair of a micro batch.
        ("expert_ids", ctypes.c_int32 * (BS * K_plus_1)),
    ]
# Exactly one token record type is defined, chosen by tokenDtype. Both are
# byte-packed and padded so that one record occupies HS bytes.
if tokenDtype != 2:
    class TokenData(ctypes.Structure):
        """Token record holding H 16-bit hidden-state values plus padding."""

        _pack_ = 1
        _fields_ = [
            ("hidden_states", ctypes.c_uint16 * H),
            ("padding", ctypes.c_uint8 * (HS - H * 2)),
        ]
else:
    class TokenDataQuant(ctypes.Structure):
        """Token record holding H int8 hidden-state values, a float scale
        and padding."""

        _pack_ = 1
        _fields_ = [
            ("hidden_states", ctypes.c_int8 * H),
            ("quant_scale", ctypes.c_float),
            ("padding", ctypes.c_uint8 * (HS - H - 4)),
        ]
class ScheduleContext(ctypes.Structure):
    """Byte-packed schedule context header.

    Concatenates the common, control, attention and FFN sections and ends
    with 384 reserved bytes.
    """

    _pack_ = 1
    _fields_ = [
        ("common", CommonArea),
        ("control", ControlArea),
        ("attention", AttentionArea),
        ("ffn", FfnArea),
        # Trailing reserved bytes.
        ("reserve6", ctypes.c_int8 * 384),
    ]
def read_schedule_context(file_path):
    """Load a schedule context file and decode the FFN payload blocks.

    The file starts with a ``ScheduleContext`` header. Each ``*_buf`` field of
    its FFN section is treated as a byte offset into the same file and the
    matching ``*_buf_size`` field as the block length in bytes.

    Args:
        file_path: Path of the binary file to read.

    Returns:
        A ``(ctx, parsed_data)`` tuple. ``ctx`` is the decoded
        ``ScheduleContext``. ``parsed_data`` maps ``'token_info_buf'``,
        ``'token_data_buf'``, ``'session_ids_buf'``,
        ``'micro_batch_ids_buf'`` and ``'expert_ids_buf'`` to a ctypes array
        copied out of the file, or to ``None`` when the offset or the size of
        that block is zero.
    """
    with open(file_path, "rb") as f:
        file_content = f.read()

    header_size = ctypes.sizeof(ScheduleContext)
    ctx = ScheduleContext.from_buffer_copy(file_content[:header_size])

    def parse_data_block(buf_field, size_field, data_type):
        """Copy one block out of the file as an array of ``data_type``."""
        offset = getattr(ctx.ffn, buf_field)
        size = getattr(ctx.ffn, size_field)
        if offset == 0 or size == 0:
            return None
        # Whole elements only; a trailing partial element is ignored.
        count = size // ctypes.sizeof(data_type)
        return (data_type * count).from_buffer_copy(file_content[offset: offset + size])

    # Only the record type matching tokenDtype exists at module level.
    token_record_type = TokenDataQuant if tokenDtype == 2 else TokenData

    parsed_data = {
        'token_info_buf': parse_data_block('token_info_buf', 'token_info_buf_size', DataDesc),
        'token_data_buf': parse_data_block('token_data_buf', 'token_data_buf_size', token_record_type),
        'session_ids_buf': parse_data_block('session_ids_buf', 'session_ids_buf_size', ctypes.c_int32),
        'micro_batch_ids_buf': parse_data_block('micro_batch_ids_buf', 'micro_batch_ids_buf_size',
                                                ctypes.c_int32),
        'expert_ids_buf': parse_data_block('expert_ids_buf', 'expert_ids_buf_size', ctypes.c_int32),
    }
    return ctx, parsed_data
def convert_ctypes_to_torch(ctypes_array, expected_dtype):
    """Copy a ctypes array into a new torch tensor.

    Args:
        ctypes_array: A ctypes array of ``c_int32``, ``c_float``, ``c_uint8``
            or ``c_int64`` elements.
        expected_dtype: Accepted but not used; the tensor dtype is derived
            from the element type of ``ctypes_array``.

    Returns:
        A one-dimensional tensor that owns its data (the result is cloned, so
        it does not alias the ctypes memory).

    Raises:
        TypeError: ``ctypes_array`` is not a ctypes array.
        ValueError: the element type has no torch counterpart in the table.
    """
    if not isinstance(ctypes_array, ctypes.Array):
        raise TypeError("输入必须是ctypes数组")

    element_type = ctypes_array._type_
    dtype = {
        ctypes.c_int32: torch.int32,
        ctypes.c_float: torch.float32,
        ctypes.c_uint8: torch.uint8,
        ctypes.c_int64: torch.int64,
    }.get(element_type)
    if dtype is None:
        raise ValueError(f"不支持的ctypes类型: {ctypes_array._type_}")

    # Expose the array's memory as raw bytes so torch can reinterpret it.
    byte_count = len(ctypes_array) * ctypes.sizeof(element_type)
    raw_view = (ctypes.c_byte * byte_count).from_address(ctypes.addressof(ctypes_array))
    return torch.frombuffer(raw_view, dtype=dtype).clone()
def create_random_c_array2(array_length, min_val, max_val, c_type=ctypes.c_int32, extreme_val=None):
    """Create a ctypes array filled with random integers.

    Args:
        array_length: Number of elements.
        min_val: Smallest value that may be drawn (inclusive).
        max_val: Largest value that may be drawn (inclusive).
        c_type: ctypes element type, ``ctypes.c_int32`` by default.
        extreme_val: Accepted but not used by this function.

    Returns:
        A ``c_type`` array of ``array_length`` elements.
    """
    array_type = c_type * array_length
    return array_type(*(random.randint(min_val, max_val) for _ in range(array_length)))
def create_random_c_array(array_length, min_val, max_val, c_type=ctypes.c_int32, prob=0.0, extreme_val=None):
    """Create a ctypes array of random integers, optionally with outliers.

    Every element is first drawn uniformly from ``[min_val, max_val]``. When
    ``prob`` is positive, each element is then replaced by ``extreme_val``
    with probability ``prob``; at most 1512 elements are replaced.

    Args:
        array_length: Number of elements.
        min_val: Smallest value that may be drawn (inclusive).
        max_val: Largest value that may be drawn (inclusive).
        c_type: ctypes element type, ``ctypes.c_int32`` by default.
        prob: Per-element replacement probability (default 0.0).
        extreme_val: Replacement value; 214748364 when ``None``.

    Returns:
        A ``c_type`` array of ``array_length`` elements.
    """
    replacement = 214748364 if extreme_val is None else extreme_val
    values = [random.randint(min_val, max_val) for _ in range(array_length)]

    if prob > 0.0:
        replaced = 0
        for position in range(array_length):
            # Draw for every element, even once the cap has been reached.
            selected = random.random() < prob
            if selected and replaced < 1512:
                values[position] = replacement
                replaced += 1

    return (c_type * array_length)(*values)
def generate_token_data_element(_):
    """Build one randomly filled token record.

    The argument is ignored; it exists so the function can be mapped over a
    range. For ``tokenDtype == 2`` a ``TokenDataQuant`` is returned with int8
    hidden states drawn from ``[-10, 10)`` and a scale drawn from
    ``[0.1, 2.0)``; otherwise a ``TokenData`` is returned whose hidden states
    are float16 values drawn from ``[-1.0, 1.0)``. Padding is zeroed.
    """
    if tokenDtype != 2:
        token = TokenData()
        half_values = np.random.uniform(-1.0, 1.0, size=H).astype(np.float16)
        # Copy the raw float16 bits into the 16-bit hidden_states field.
        ctypes.memmove(token.hidden_states, half_values.ctypes.data, ctypes.sizeof(token.hidden_states))
        ctypes.memset(ctypes.addressof(token.padding), 0, HS - H * 2)
        return token

    token = TokenDataQuant()
    quantized = np.random.randint(-10, 10, size=H).astype(np.int8)
    ctypes.memmove(token.hidden_states, quantized.ctypes.data, ctypes.sizeof(token.hidden_states))
    token.quant_scale = np.random.uniform(0.1, 2.0)
    ctypes.memset(ctypes.addressof(token.padding), 0, HS - H - 4)
    return token
def generate_token_info_element(AM_idx):
    """Build one ``DataDesc`` record with random expert ids.

    Args:
        AM_idx: Position of the record in the token-info array.

    Returns:
        A ``DataDesc`` whose ``flag`` is 1 when ``AM_idx % M`` equals
        ``cur_micro_batch_id`` and 0 otherwise. ``expert_ids`` holds
        ``BS * K_plus_1`` values drawn from ``{0, 1}`` (subject to
        ``invaildProb`` replacement) in either case.
    """
    tokenInfoDes = DataDesc()
    tokenInfoDes.flag = 1 if AM_idx % M == cur_micro_batch_id else 0
    # The expert ids are generated the same way regardless of the flag.
    tokenInfoDes.expert_ids = create_random_c_array((BS * K_plus_1), 0, 1, prob=invaildProb)
    return tokenInfoDes
def generate_token_info():
    """Build the complete token-info array using a process pool.

    Returns:
        A ctypes array of ``A * M`` ``DataDesc`` records, in index order.
    """
    total_elements = A * M
    worker_count = cpu_count()
    # Never hand out fewer than 1000 items per task.
    chunksize = max(1000, total_elements // (worker_count * 4))

    with Pool(worker_count) as pool:
        records = list(
            pool.imap(generate_token_info_element, range(total_elements), chunksize=chunksize)
        )

    return (DataDesc * total_elements)(*records)
def generate_token_data():
    """Build the complete token-data array using a process pool.

    Returns:
        A ctypes array of ``A * M * BS * K_plus_1`` token records, of type
        ``TokenDataQuant`` when ``tokenDtype == 2`` and ``TokenData``
        otherwise.
    """
    # NOTE(review): the element generator draws from NumPy's global RNG; with
    # a fork start method each worker may start from the same RNG state --
    # confirm whether repeated payloads across workers matter here.
    total_elements = A * M * BS * K_plus_1
    worker_count = cpu_count()
    # Never hand out fewer than 1000 items per task.
    chunksize = max(1000, total_elements // (worker_count * 4))

    with Pool(worker_count) as pool:
        records = list(
            pool.imap(generate_token_data_element, range(total_elements), chunksize=chunksize)
        )

    record_type = TokenDataQuant if tokenDtype == 2 else TokenData
    return (record_type * total_elements)(*records)
# Device tensors keyed by payload part name; filled by generate_input_func().
tensor_buffers = dict()
# (name, payload) pairs in the order they are written after the header.
ffn_data_parts = list()
def generate_input_func():
    """Generate the test inputs and publish them to disk and to the device.

    Side effects:
      * rebuilds the module-level ``ffn_data_parts`` list;
      * when ``need_schedule == 1``, dumps the token-info words to
        ``tokeninfo_org.txt``;
      * writes the ``ScheduleContext`` header followed by every payload part
        to ``schedule_context.bin``; inside that file the FFN ``*_buf``
        fields hold byte offsets from the start of the file;
      * refills ``tensor_buffers`` with an NPU copy of every payload part.
    """
    # Drop parts left over from an earlier call. Without this, a second call
    # would write the old parts again and compute offsets over stale data
    # (tensor_buffers is reset below for the same reason).
    ffn_data_parts.clear()

    ctx = ScheduleContext()

    # ---- common section ----
    ctx.common.session_num = A
    ctx.common.micro_batch_num = M
    ctx.common.micro_batch_size = BS
    ctx.common.selected_expert_num = K_plus_1
    ctx.common.expert_num = expert_num
    ctx.common.attn_to_ffn_token_size = HS
    ctx.common.ffn_to_attn_token_size = 512
    ctx.common.schedule_mode = 1
    ctypes.memset(ctx.common.reserve0, 0, 96)

    # ---- control section (fixed placeholder values) ----
    ctx.control.run_flag = 1
    ctypes.memset(ctx.control.reserve1, 0, 4)
    ctx.control.schedule_wait_buf = 0x1000
    ctx.control.ffn_wait_buf = 0x2000
    ctx.control.attention_wait_buf = 0x3000
    ctypes.memset(ctx.control.reserve2, 0, 96)

    # ---- attention section (fixed placeholder values) ----
    ctx.attention.token_info_buf = 0x4000
    ctx.attention.token_info_buf_size = 1024
    ctx.attention.token_data_buf = 0x5000
    ctx.attention.token_data_buf_size = 2048
    ctx.attention.micro_batch_id = 1
    ctypes.memset(ctx.attention.reserve5, 0, 92)

    # ---- payload parts, in file order ----
    if need_schedule == 1:
        token_info_bytes = bytes(generate_token_info())
        np.savetxt(
            'tokeninfo_org.txt',
            np.frombuffer(token_info_bytes, dtype=np.int32).copy(),
            fmt='%d',
            delimiter=',',
        )
        ffn_data_parts.append(("token_info", token_info_bytes))

    ffn_data_parts.append(("token_data", bytes(generate_token_data())))
    ffn_data_parts.append(("session_ids", (ctypes.c_int32 * A)(*range(A))))
    ffn_data_parts.append(
        ("micro_batch_ids", (ctypes.c_int32 * M)(*([cur_micro_batch_id] * M)))
    )
    if need_schedule == 0:
        expert_ids = create_random_c_array((out_num * BS * K_plus_1), 0, 1, prob=invaildProb)
        ffn_data_parts.append(("expert_ids", expert_ids))

    # Convert every part to bytes once; the list is reused three times below.
    payloads = [
        (name, bytes(data) if isinstance(data, ctypes.Array) else data)
        for name, data in ffn_data_parts
    ]

    # Each part named <name> is described by the FFN fields <name>_buf and
    # <name>_buf_size. Parts are laid out back to back after the header.
    current_offset = ctypes.sizeof(ctx)
    for name, data_bytes in payloads:
        setattr(ctx.ffn, f"{name}_buf", current_offset)
        setattr(ctx.ffn, f"{name}_buf_size", len(data_bytes))
        current_offset += len(data_bytes)

    # The part that does not belong to the active mode is marked absent.
    if need_schedule == 0:
        ctx.ffn.token_info_buf = 0
        ctx.ffn.token_info_buf_size = 0
    else:
        ctx.ffn.expert_ids_buf = 0
        ctx.ffn.expert_ids_buf_size = 0
    ctx.ffn.polling_index = cur_micro_batch_id
    ctx.ffn.out_num = out_num

    with open("schedule_context.bin", "wb") as f:
        f.write(bytes(ctx))
        for _, data_bytes in payloads:
            f.write(data_bytes)

    # ---- device copies ----
    int32_parts = ("token_info", "session_ids", "micro_batch_ids", "expert_ids")
    if tokenDtype == 2:
        token_np_dtype, token_torch_dtype = np.int8, torch.int8
    else:
        token_np_dtype, token_torch_dtype = np.float16, torch.float16

    tensor_buffers.clear()
    for name, data_bytes in payloads:
        if name in int32_parts:
            host_array = np.frombuffer(data_bytes, dtype=np.int32).copy()
            tensor = torch.from_numpy(host_array).int()
        else:
            host_array = np.frombuffer(data_bytes, dtype=token_np_dtype).copy()
            tensor = torch.from_numpy(host_array).to(token_torch_dtype)
        tensor_buffers[name] = tensor.npu()
def calc_expect_func():
    """Compute the CPU reference outputs from ``schedule_context.bin``.

    Expert ids are flattened and stably sorted; tokens are then gathered in
    that order. Ids greater than or equal to 1,000,000 are treated as invalid:
    they are excluded from ``actual_token_num`` and from ``group_list``, but
    their tokens are still present in ``y`` (they sort to the end).

    Side effects:
        Sets the module-level ``uniq_expert_id_cnt`` to the number of distinct
        valid expert ids.

    Returns:
        An 8-tuple ``(y, group_list, session_ids, micro_batch_ids, token_ids,
        expert_offsets, dynamic_scale, actual_token_num)``. ``group_list`` is
        an ``(expert_num, 2)`` int64 tensor whose leading rows hold
        ``(expert id, token count)`` pairs and whose remaining rows are zero.
        ``dynamic_scale`` is ``None`` unless ``tokenDtype == 2``.
    """
    global uniq_expert_id_cnt

    # Ids at or above this value count as invalid everywhere below.
    invalid_id_threshold = 1_000_000

    ctx, ffn_area = read_schedule_context("schedule_context.bin")
    token_records = ffn_area["token_data_buf"]

    # ---- token payload -> (A, M, BS, K_plus_1, H) tensor ----
    dynamic_scales = None
    if tokenDtype == 2:
        scale_values = np.array([token.quant_scale for token in token_records], dtype=np.float32)
        dynamic_scales = torch.from_numpy(scale_values).view(A, M, BS, K_plus_1, 1)
        hidden = np.array([token.hidden_states for token in token_records], dtype=np.int8)
        token_data = torch.from_numpy(hidden).view(A, M, BS, K_plus_1, H)
    else:
        raw_bits = np.array([token.hidden_states for token in token_records], dtype=np.uint16)
        # Reinterpret the stored 16-bit words as float16 without conversion.
        token_data = torch.from_numpy(raw_bits.view(np.float16)).view(A, M, BS, K_plus_1, H)

    # ---- expert ids ----
    if need_schedule == 0:
        expert_ids = convert_ctypes_to_torch(ffn_area["expert_ids_buf"], torch.int32)
    else:
        # Keep only the records that belong to the current micro batch and
        # join them in a single concatenate call.
        selected = [
            np.array(token_info.expert_ids, dtype=np.int32)
            for AM_idx, token_info in enumerate(ffn_area["token_info_buf"])
            if AM_idx % M == cur_micro_batch_id
        ]
        expert_ids_array = np.concatenate(selected) if selected else np.array([], dtype=np.int32)
        expert_ids = torch.from_numpy(expert_ids_array)

    # Shape check only: raises unless there are exactly A * BS * K_plus_1 ids.
    expert_ids.view(A, BS, K_plus_1)

    session_ids_buf = convert_ctypes_to_torch(ffn_area["session_ids_buf"], torch.int32)
    micro_batch_ids_buf = convert_ctypes_to_torch(ffn_area["micro_batch_ids_buf"], torch.int32)

    # ---- stable sort by expert id ----
    flat_expert_ids = expert_ids.clone().view(-1)
    masked_count = (flat_expert_ids >= invalid_id_threshold).sum().item()
    sorted_indices = torch.argsort(flat_expert_ids, stable=True)
    sorted_expert_ids = flat_expert_ids[sorted_indices]
    actual_token_num = len(sorted_indices) - masked_count

    # Decompose each flat position into (session, token, expert-slot).
    gather_idx = sorted_indices.to(torch.int32)
    a_indices = gather_idx // (BS * K_plus_1)
    remainder = gather_idx % (BS * K_plus_1)
    bs_indices = remainder // K_plus_1
    k_indices = remainder % K_plus_1

    session_ids = session_ids_buf[a_indices]
    micro_batch_ids = micro_batch_ids_buf[cur_micro_batch_id]
    token_data_indices = (session_ids, micro_batch_ids, bs_indices, k_indices)
    y = token_data[token_data_indices]
    dynamic_scale = dynamic_scales[token_data_indices] if tokenDtype == 2 else None

    # ---- per-expert token counts ----
    unique_experts, counts = torch.unique_consecutive(sorted_expert_ids, return_counts=True)
    # Same threshold as masked_count, so the counts sum to actual_token_num.
    valid = unique_experts < invalid_id_threshold
    filtered_experts = unique_experts[valid]
    filtered_counts = counts[valid]
    uniq_expert_id_cnt = len(filtered_experts)

    group_list = torch.zeros(expert_num, 2, dtype=torch.int64, device=token_data.device)
    # Fill only as many rows as there are distinct valid experts. Assigning to
    # the whole column would broadcast a single expert over every row.
    group_list[:uniq_expert_id_cnt, 0] = filtered_experts
    group_list[:uniq_expert_id_cnt, 1] = filtered_counts

    token_ids = bs_indices
    expert_offsets = k_indices
    if tokenDtype == 2:
        return y, group_list, session_ids, micro_batch_ids, token_ids, expert_offsets, dynamic_scale.squeeze(), actual_token_num
    return y, group_list, session_ids, micro_batch_ids, token_ids, expert_offsets, None, actual_token_num
class TestFfnWorkerBatching(TestCase):
    """Compares ``torch_npu.npu_ffn_worker_batching`` with the CPU reference."""

    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn_worker_batching_001(self):
        """Generate inputs, run the NPU operator and compare the ``y`` output."""
        generate_input_func()
        (y_cpu, _group_list, _session_ids, _micro_batch_ids,
         _token_ids, _expert_offsets, _dynamic_scale, _actual_token_num) = calc_expect_func()

        ctx, _ = read_schedule_context("./schedule_context.bin")

        # In the file the FFN buffer fields are file offsets; replace them
        # with the addresses of the device tensors holding the same data.
        patchable = ("token_info", "token_data", "session_ids", "micro_batch_ids", "expert_ids")
        for name, data in ffn_data_parts:
            nbytes = len(bytes(data) if isinstance(data, ctypes.Array) else data)
            if name in patchable:
                setattr(ctx.ffn, name + "_buf", tensor_buffers[name].data_ptr())
                setattr(ctx.ffn, name + "_buf_size", nbytes)

        # Ship the patched header to the device as a flat int8 tensor.
        header_bytes = np.frombuffer(bytes(ctx), dtype=np.int8).copy()
        schedule_context_npu = torch.from_numpy(header_bytes).to(torch.int8).npu()

        max_out_shape = [A, BS, K_plus_1, H]
        outputs = torch_npu.npu_ffn_worker_batching(
            schedule_context_npu,
            expert_num,
            max_out_shape,
            token_dtype=tokenDtype,
            need_schedule=need_schedule,
            layer_num=0,
        )
        (y_npu, _group_list_npu, _session_ids_npu, _micro_batch_ids_npu,
         _token_ids_npu, _expert_offsets_npu, _dynamic_scale_npu, _actual_token_num_npu) = outputs
        torch_npu.npu.synchronize()

        # Only the gathered token output is checked.
        self.assertRtolEqual(y_cpu, y_npu)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()