import unittest
import numpy as np
from mindie_llm.text_generator.utils.generation_output import GenerationOutput
class TestGenerationOutput(unittest.TestCase):
    """Unit tests for ``GenerationOutput``.

    Covers ``collate``, ``pad_output``, ``concatenate_output``,
    ``make_empty`` and ``remove``.  Every test starts from the fresh
    three-sequence fixture built in ``setUp``.
    """

    def setUp(self):
        """Create a three-sequence output with one new token per sequence."""
        self.generation_output = GenerationOutput(
            sequence_ids=np.array([0, 1, 2]),
            parent_sequence_ids=np.array([0, 0, 0]),
            group_indices=[(0, 1), (1, 2), (2, 3)],
            # Plain nested lists give the same (3, 1) arrays as wrapping each
            # row in its own np.array, and match the fixture used in the
            # concatenation test.
            token_ids=np.array([[100], [101], [102]]),
            logprobs=np.array([[-0.1], [-0.2], [-0.3]]),
            top_token_ids=np.array([[[100]], [[101]], [[102]]]),
            top_logprobs=np.array([[[-0.1]], [[-0.2]], [[-0.3]]]),
            num_new_tokens=np.array([1, 1, 1]),
            num_top_tokens=np.array([1, 1, 1]),
            cumulative_logprobs=np.array([-0.1, -0.2, -0.3]),
            finish_reason=np.array([0, 0, 0]),
            truncation_indices=np.array([0, 0, 0]),
            current_token_indices=[1, 1, 1],
            trace_ids=np.array([0, 1, 2]),
        )

    def test_collate(self):
        """``collate`` fills ``eos_info`` with one ``[0, 1]`` row per sequence."""
        self.generation_output.collate()
        self.assertEqual(
            self.generation_output.eos_info.tolist(),
            [[0, 1], [0, 1], [0, 1]],
        )

    def test_pad_output(self):
        """``pad_output(3)`` widens ``token_ids`` from (3, 1) to (3, 3)."""
        max_generated_tokens = 3
        self.generation_output.pad_output(max_generated_tokens)
        self.assertEqual(self.generation_output.token_ids.shape, (3, 3))

    def _assert_concatenated_attrs(self, original, new, concatenated, axis=0):
        """Assert ``concatenated`` equals ``original`` joined with ``new``.

        Args:
            original: Array holding the values before concatenation.
            new: Array that was appended.
            concatenated: Actual result to check.
            axis: Axis along which the two inputs are expected to be joined.
        """
        expected = np.concatenate([original, new], axis=axis)
        np.testing.assert_array_equal(concatenated, expected)

    def test_concatenate_normal_no_overlap(self):
        """Concatenating a disjoint output appends ids and shifts group indices."""
        # Kept local: the fixture is only needed by this test, so it does not
        # belong on the TestCase instance.
        new_output = GenerationOutput(
            sequence_ids=np.array([3, 4, 5]),
            parent_sequence_ids=np.array([3, 3, 3]),
            group_indices=[(0, 1), (1, 2), (2, 3)],
            token_ids=np.array([[200], [201], [202]]),
            logprobs=np.array([[-0.4], [-0.5], [-0.6]]),
            top_token_ids=np.array([[[200]], [[201]], [[202]]]),
            top_logprobs=np.array([[[-0.4]], [[-0.5]], [[-0.6]]]),
            num_new_tokens=np.array([1, 1, 1]),
            num_top_tokens=np.array([1, 1, 1]),
            cumulative_logprobs=np.array([-0.4, -0.5, -0.6]),
            finish_reason=np.array([1, 1, 1]),
            truncation_indices=np.array([1, 1, 1]),
            current_token_indices=[2, 2, 2],
            trace_ids=np.array([3, 4, 5]),
        )
        max_generated_tokens = 2

        self.generation_output.concatenate_output(new_output, max_generated_tokens)

        expected_seq_ids = np.array([0, 1, 2, 3, 4, 5])
        np.testing.assert_array_equal(
            self.generation_output.sequence_ids, expected_seq_ids
        )
        # The appended groups are expected to continue after the existing
        # three groups rather than restart at zero.
        expected_group_indices = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
        self.assertEqual(
            self.generation_output.group_indices, expected_group_indices
        )

    def test_make_empty(self):
        """``make_empty`` returns an output whose containers are all empty."""
        empty_output = GenerationOutput.make_empty()
        self.assertEqual(empty_output.sequence_ids.shape, (0,))
        self.assertEqual(empty_output.group_indices, [])
        self.assertEqual(empty_output.token_ids.shape, (0, 0))
        self.assertEqual(empty_output.eos_info.shape, (0,))

    def test_remove(self):
        """Removing index 1 drops that sequence and renumbers the groups."""
        self.generation_output.collate()
        self.generation_output.remove(np.array([1]))
        np.testing.assert_array_equal(
            self.generation_output.sequence_ids, np.array([0, 2])
        )
        self.assertEqual(self.generation_output.group_indices, [(0, 1), (1, 2)])
        self.assertEqual(self.generation_output.token_ids.shape[0], 2)
        self.assertEqual(self.generation_output.eos_info.shape[0], 2)
# Run the suite only when this file is executed as a script, not on import.
if __name__ == '__main__':
    unittest.main()