import unittest
import numpy as np
from mindie_llm.text_generator.utils.request import Request
from mindie_llm.text_generator.utils.generation_metadata import GenerationParams
class TestRequest(unittest.TestCase):
    """Unit tests for ``Request`` construction, warmup requests and block-table bookkeeping."""

    # Value passed as ``max_placeholder_num`` (third argument of
    # ``Request.from_warmup``) by tests that do not exercise the placeholder
    # limit. Presumably chosen to be large enough not to constrain them — TODO confirm.
    LARGE_MAX_PLACEHOLDER_NUM = 130000

    def setUp(self):
        """Create the generation/sampling parameters shared by several tests."""
        self.generation_params = GenerationParams(max_new_tokens=50, best_of=3, ignore_eos=True)
        # NOTE: seq_id and req_id both carry the value 1, so an assertion that
        # compares against the wrong one still passes; reference the right attribute.
        self.seq_id = 1
        self.req_id = np.array([1])
        self.sampling_params = np.array([0.5, 0.5])

    def test_initialization(self):
        """The constructor exposes the ids, GenerationParams fields and sampling params."""
        request = Request(
            req_id=1,
            seq_id=self.seq_id,
            input_ids=np.array([1, 2, 3, 4, 5], dtype=np.int64),
            generation_params=self.generation_params,
            has_sampling=True,
            sampling_params=self.sampling_params,
        )
        self.assertEqual(request.req_id, 1)
        self.assertEqual(request.sequences[self.seq_id].seq_id, self.seq_id)
        self.assertEqual(request.best_of, 3)
        self.assertEqual(request.max_new_tokens, 50)
        self.assertTrue(request.ignore_eos)
        self.assertEqual(request.skip_special_tokens, self.generation_params.skip_special_tokens)
        np.testing.assert_array_equal(request.sampling_params, self.sampling_params)

    def test_from_warmup(self):
        """A warmup request has req_id 0 and takes its lengths from the arguments."""
        input_len = 5
        max_output_len = 10
        warmup_request = Request.from_warmup(input_len, max_output_len, self.LARGE_MAX_PLACEHOLDER_NUM)
        self.assertEqual(warmup_request.req_id, 0)
        self.assertEqual(warmup_request.max_new_tokens, max_output_len)
        self.assertEqual(warmup_request.input_length, input_len)

    def test_request_from_token(self):
        """``request_from_token`` builds a request from raw token ids and params."""
        input_ids = [10, 20, 30]
        sampling_params = np.array([0.7, 0.3])
        generation_params = GenerationParams(max_new_tokens=15, best_of=2)
        token_request = Request.request_from_token(
            input_ids, sampling_params, generation_params, req_id=self.req_id, seq_id=self.seq_id)
        self.assertEqual(token_request.req_id, 1)
        # The sequence stored under seq_id must report that same seq_id
        # (previously compared against req_id, which only passed by coincidence).
        self.assertEqual(token_request.sequences[self.seq_id].seq_id, self.seq_id)
        np.testing.assert_array_equal(token_request.input_ids, np.array([10, 20, 30]))
        self.assertEqual(token_request.max_new_tokens, 15)
        self.assertEqual(token_request.best_of, 2)
        np.testing.assert_array_equal(token_request.sampling_params, sampling_params)

    def test_request_lifecycle(self):
        """Pin block tables and rank bookkeeping after ``build`` and one ``step``."""
        block_size = 128
        scp_size = 4
        num_prompt_blocks = 6
        # The prompt fills exactly num_prompt_blocks blocks of block_size tokens.
        input_len = block_size * num_prompt_blocks
        max_output_len = 2
        max_placeholder_num = 1
        request = Request.from_warmup(input_len, max_output_len, max_placeholder_num)
        request.build(dp_rank_id=0, scp_size=scp_size, block_size=block_size, is_mix_model=False)
        # Golden table: 4 rows (presumably one per SCP rank) x 6 columns; -1 marks unused slots.
        gold_block_table = np.array([[0, 0, -1, -1, -1, -1],
                                     [0, -1, -1, -1, -1, -1],
                                     [0, -1, -1, -1, -1, -1],
                                     [0, 0, -1, -1, -1, -1]])
        self.assertEqual(request.dp_rank_id, 0)
        self.assertEqual(request.sp_tokens.size, 4)
        self.assertEqual(request.sp_rank_id, 3)
        self.assertEqual(request.prefill_block_rank_id.size, 6)
        np.testing.assert_array_equal(request.block_tables, gold_block_table)
        self.assertEqual(len(request.computed_blocks), 4)
        self.assertEqual(len(request.remote_computed_blocks), 4)

        request.step(num_new_token=1, scp_size=scp_size, block_size=block_size, is_mix_model=False)
        # After one decode step the golden table gains one entry in row 0,
        # matching the asserted block_rank_id == 0 and is_append_block.
        gold_block_table = np.array([[0, 0, 0, -1, -1, -1],
                                     [0, -1, -1, -1, -1, -1],
                                     [0, -1, -1, -1, -1, -1],
                                     [0, 0, -1, -1, -1, -1]])
        self.assertEqual(request.output_len, 1)
        self.assertEqual(request.max_placeholder_num, 1)
        self.assertTrue(request.is_append_block)
        self.assertEqual(request.block_rank_id, 0)
        np.testing.assert_array_equal(request.block_tables, gold_block_table)

    def test_update_with_chunkprefill_features(self):
        """With scp_size=1 and is_prefill, the whole prompt is one split and the last prompt chunk."""
        prompt_len = 100
        request = Request.from_warmup(prompt_len, 1, self.LARGE_MAX_PLACEHOLDER_NUM)
        request.update_with_features(
            is_mix_model=True,
            scp_size=1,
            is_prefill=True,
        )
        self.assertEqual(request.split_start_position, 0)
        self.assertEqual(request.split_end_position, prompt_len)
        self.assertTrue(request.last_prompt)
if __name__ == '__main__':
    # Run this test module directly (e.g. `python <this_file>.py`).
    unittest.main()