import os
import re
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices, create_common_tensor
class TestNPUAdvanceStepFlashattn(TestCase):
    """Check ``torch_npu.npu_advance_step_flashattn`` against golden references.

    Every test builds tensor pairs with ``create_common_tensor`` (used here as
    a reference copy and a device copy), advances the reference copies with one
    of the golden methods below, runs the operator on the device copies, and
    asserts that the buffers updated in place agree.
    """

    @staticmethod
    def _query_num_core(default=40):
        """Return twice the "Aicore Count" reported by ``npu-smi``, or *default*.

        The command output is scanned line by line; when several lines mention
        "Aicore Count" the last parsable one wins. *default* is returned
        unchanged when the command prints no such line or the line does not
        end with an integer.
        """
        num_core = default
        # The context manager closes the pipe and reaps the child process.
        with os.popen('npu-smi info -t common -i 0') as pipe:
            lines = pipe.readlines()
        for line in lines:
            if not re.search('Aicore Count', line):
                continue
            # Take the integer that ends the line, whatever its width.
            trailing_number = re.search(r'(\d+)\s*$', line)
            if trailing_number:
                num_core = int(trailing_number.group(1)) * 2
        return num_core

    def advance_step_flashattn_golden(self, num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
                                      input_positions, seq_lens, slot_mapping, block_tables):
        """Advance the decode-step buffers in place, one query per emulated block.

        Args:
            num_seqs: Number of entries in the per-sequence buffers.
            num_queries: Number of leading entries that are advanced.
            block_size: Divisor used to split a position into block index and
                offset.
            input_tokens: Overwritten per query with the sampled token.
            sampled_token_ids: Source of the new token for each query.
            input_positions: Overwritten per query with the previous sequence
                length.
            seq_lens: Incremented by one per query.
            slot_mapping: Overwritten per query with the computed slot number.
            block_tables: Table the block id is read from.

        Returns:
            None. All updates are written into the argument tensors.
        """
        num_core = self._query_num_core()
        n_pad = num_seqs - num_queries
        for block_idx in range(num_core):
            if block_idx == 0 and n_pad > 0:
                # Only the first block resets padding entries, striding by
                # num_core from the first entry after the queries.
                # NOTE(review): with n_pad <= num_core this touches a single
                # padding entry - confirm this matches the operator.
                for pad_idx in range(num_queries, num_seqs, num_core):
                    input_tokens[pad_idx] = 0
                    input_positions[pad_idx] = 0
                    slot_mapping[pad_idx] = -1

            # One block per query: blocks beyond the query count do nothing.
            num_query_blocks = num_queries // 1
            if block_idx >= num_query_blocks:
                break
            cur_query_id = block_idx * 1 + 0
            if cur_query_id >= num_queries:
                break

            input_tokens[cur_query_id] = sampled_token_ids[cur_query_id]

            # The new token sits at the old sequence length.
            seq_len = seq_lens[cur_query_id]
            next_seq_len = seq_len + 1
            next_input_pos = next_seq_len - 1
            seq_lens[cur_query_id] = next_seq_len
            input_positions[cur_query_id] = next_input_pos

            block_index = next_input_pos // block_size
            block_offset = next_input_pos % block_size
            # NOTE(review): the block id is read at flat index `block_index`
            # (not offset by the query's row) and the row-stride term is added
            # to the block id itself - confirm this mirrors the operator.
            cur_block = block_tables.flatten()[block_index]
            slot_num = (cur_block + block_tables.stride(0) * cur_query_id) * block_size + block_offset
            slot_mapping[cur_query_id] = slot_num

    def advance_step_spec_flashattn_golden(self, num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
                                           input_positions, seq_lens, slot_mapping, block_tables, spec_token,
                                           accepted_num):
        """Advance the buffers in place for the speculative-token variant.

        Args:
            num_seqs: Number of requests.
            num_queries: Accepted for signature parity; not read here.
            block_size: Divisor used to split a position into block index and
                offset.
            input_tokens: Flat buffer holding ``1 + len(spec_token[0])`` tokens
                per request; rewritten with the last sampled token followed by
                the speculative tokens.
            sampled_token_ids: 2-D sampled tokens, one row per request.
            input_positions: Incremented by the request's ``accepted_num + 1``.
            seq_lens: Overwritten with ``input_positions + 1``.
            slot_mapping: Overwritten with the computed slot numbers.
            block_tables: 2-D table of block numbers, one row per request.
            spec_token: 2-D speculative tokens, one row per request.
            accepted_num: Per-request count added to the positions.

        Returns:
            None. All updates are written into the argument tensors.
        """
        token_each_reqs = 1 + len(spec_token[0])

        # Every token slot of a request moves forward by accepted_num + 1.
        input_positions += torch.repeat_interleave(accepted_num, token_each_reqs) + 1
        seq_lens.copy_((input_positions + 1).to(seq_lens.dtype))

        # Append a -1 column and take the column just before each row's
        # minimum as the last sampled token of that row.
        padded_sampled = torch.cat([
            sampled_token_ids,
            torch.full((num_seqs, 1), -1, device=sampled_token_ids.device)
        ], dim=1)
        index = torch.argmin(padded_sampled, dim=1) - 1
        last_tokens = sampled_token_ids[torch.arange(num_seqs), index]

        if token_each_reqs == 1:
            input_tokens[:num_seqs] = last_tokens.to(dtype=input_tokens.dtype)
        else:
            # The view shares storage, so writes land in input_tokens.
            input_tokens_2d = input_tokens.view(-1, token_each_reqs)
            input_tokens_2d[:num_seqs, 0] = last_tokens
            input_tokens_2d[:num_seqs, 1:] = spec_token

        # Flat index into block_tables: row of the owning request plus the
        # block that contains the token's position.
        req_indices = torch.repeat_interleave(
            torch.arange(num_seqs),
            token_each_reqs,
            dim=0
        )
        max_num_blocks_per_req = block_tables.shape[1]
        block_tables_indices = (
            req_indices * max_num_blocks_per_req +
            input_positions // block_size
        )
        block_numbers = block_tables.flatten()[block_tables_indices]
        block_offset = input_positions % block_size
        slot_mapping.copy_(block_numbers * block_size + block_offset)

    @unittest.skip("Skipping test_npu_advance_step_flashattn for now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_advance_step_flashattn(self):
        """Run the plain advance step for seven growing problem sizes."""
        iter_num = 7
        for i in range(iter_num):
            num_seqs = 2 * i + 2
            num_queries = i + 1
            block_size = i + 1

            shape_format_num_seqs = [np.int64, 2, [num_seqs, ]]
            shape_format_num_queries = [np.int64, 2, [num_queries, 1]]
            input_tokens_cpu, input_tokens = create_common_tensor(shape_format_num_seqs, 5, 50)
            sampled_token_ids_cpu, sampled_token_ids = create_common_tensor(shape_format_num_queries, 5, 50)
            input_positions_cpu, input_positions = create_common_tensor(shape_format_num_seqs, 5, 50)
            seq_lens_cpu, seq_lens = create_common_tensor(shape_format_num_seqs, 5, 50)
            slot_mapping_cpu, slot_mapping = create_common_tensor(shape_format_num_seqs, 5, 50)

            # Size the table from the longest sequence; use a plain int so the
            # shape list holds no tensor.
            max_num_blocks = int(torch.max(seq_lens_cpu)) // block_size + 1
            shape_format_block_tables = [np.int64, 2, [num_seqs, max_num_blocks]]
            block_tables_cpu, block_tables = create_common_tensor(shape_format_block_tables, 5, 50)

            self.advance_step_flashattn_golden(num_seqs, num_queries, block_size, input_tokens_cpu,
                                               sampled_token_ids_cpu, input_positions_cpu, seq_lens_cpu,
                                               slot_mapping_cpu, block_tables_cpu)
            torch_npu.npu_advance_step_flashattn(input_tokens, sampled_token_ids, input_positions,
                                                 seq_lens, slot_mapping, block_tables, num_seqs,
                                                 num_queries, block_size)

            self.assertRtolEqual(input_tokens_cpu, input_tokens.cpu())
            self.assertRtolEqual(input_positions_cpu, input_positions.cpu())
            self.assertRtolEqual(seq_lens_cpu, seq_lens.cpu())
            self.assertRtolEqual(slot_mapping_cpu, slot_mapping.cpu())

    @unittest.skip("Skip until CANN is updated to 8.3.RC1 to support aclnnAdvanceStepV2")
    @SupportedDevices(['Ascend910B'])
    def test_npu_advance_step_spec_flashattn(self):
        """Run the speculative advance step for seven growing request counts."""
        iter_num = 7
        for i in range(iter_num):
            num_seqs = 2 * i + 2
            block_size = 8
            spec_num = 3

            # Per-token buffers hold spec_num + 1 entries for every request.
            shape_format_num_seqs = [np.int64, 2, [num_seqs * (spec_num + 1), ]]
            shape_format_spec_token = [np.int64, 2, [num_seqs, spec_num]]
            shape_format_accepted_num = [np.int64, 2, [num_seqs, ]]
            shape_format_sampled_token = [np.int64, 2, [num_seqs, spec_num + 1]]
            input_tokens_cpu, input_tokens = create_common_tensor(shape_format_num_seqs, 5, 50)
            sampled_token_ids_cpu, sampled_token_ids = create_common_tensor(shape_format_sampled_token, 5, 50)
            input_positions_cpu, input_positions = create_common_tensor(shape_format_num_seqs, 5, 50)
            seq_lens_cpu, seq_lens = create_common_tensor(shape_format_num_seqs, 5, 50)
            slot_mapping_cpu, slot_mapping = create_common_tensor(shape_format_num_seqs, 5, 50)
            spec_token_cpu, spec_token = create_common_tensor(shape_format_spec_token, 5, 50)
            accepted_num_cpu, accepted_num = create_common_tensor(shape_format_accepted_num, 5, 50)
            shape_format_block_tables = [np.int64, 2, [num_seqs, 10000]]
            block_tables_cpu, block_tables = create_common_tensor(shape_format_block_tables, 5, 50)

            self.advance_step_spec_flashattn_golden(num_seqs, num_seqs, block_size, input_tokens_cpu,
                                                    sampled_token_ids_cpu, input_positions_cpu, seq_lens_cpu,
                                                    slot_mapping_cpu, block_tables_cpu, spec_token_cpu,
                                                    accepted_num_cpu)
            torch_npu.npu_advance_step_flashattn(input_tokens, sampled_token_ids, input_positions,
                                                 seq_lens, slot_mapping, block_tables, num_seqs,
                                                 num_seqs, block_size, spec_token=spec_token,
                                                 accepted_num=accepted_num)

            self.assertRtolEqual(input_tokens_cpu, input_tokens.cpu())
            self.assertRtolEqual(input_positions_cpu, input_positions.cpu())
            self.assertRtolEqual(seq_lens_cpu, seq_lens.cpu())
            self.assertRtolEqual(slot_mapping_cpu, slot_mapping.cpu())
# Run the test suite only when this file is executed directly.
if __name__ == "__main__":
    run_tests()