import unittest
import numpy as np
from mindie_llm.text_generator.utils.sampling_output import SamplingOutput
class TestSamplingOutput(unittest.TestCase):
    """Tests for ``SamplingOutput.truncate_after_eos``.

    Token id 3 is used as the EOS id throughout.
    """

    def setUp(self):
        """Build a 3-row ``SamplingOutput`` holding 5 new tokens per row.

        Row 0 has the EOS id (3) in the middle, row 1 contains no EOS id,
        and row 2 starts with the EOS id.
        """
        batch_size, width = 3, 5
        self.token_ids = np.array([
            [10, 20, 3, 40, 50],
            [11, 21, 31, 41, 51],
            [3, 12, 22, 32, 42],
        ])
        self.num_new_tokens = np.array([5, 5, 5])
        # Each field gets its own freshly allocated array; copies of the
        # reference arrays above are passed so truncation can't touch them.
        self.sampling_output = SamplingOutput(
            sequence_ids=np.arange(batch_size),
            parent_sequence_ids=np.arange(batch_size),
            group_indices=[(i, i + 1) for i in range(batch_size)],
            repeating_indices=np.zeros(batch_size),
            token_ids=self.token_ids.copy(),
            logprobs=np.zeros((batch_size, width)),
            top_token_ids=np.zeros((batch_size, width)),
            top_logprobs=np.zeros((batch_size, width)),
            cumulative_logprobs=np.zeros(batch_size),
            num_new_tokens=self.num_new_tokens.copy(),
        )

    def _assert_row(self, row, expected_count, expected_tokens):
        """Assert the new-token count and token ids stored for one row."""
        self.assertEqual(
            self.sampling_output.num_new_tokens[row], expected_count
        )
        np.testing.assert_array_equal(
            self.sampling_output.token_ids[row],
            np.array(expected_tokens),
        )

    def test_truncate_after_eos(self):
        """Rows are cut just after their EOS and zero-filled past it."""
        eos_id = 3
        self.sampling_output.truncate_after_eos(eos_id)
        # (row, expected new-token count, expected token ids)
        expectations = (
            (0, 3, [10, 20, 3, 0, 0]),
            (1, 5, [11, 21, 31, 41, 51]),  # no EOS: row left untouched
            (2, 1, [3, 0, 0, 0, 0]),
        )
        for row, count, tokens in expectations:
            self._assert_row(row, count, tokens)

    def test_truncate_with_multiple_eos(self):
        """With repeated EOS ids, only the first one is kept."""
        output = self.sampling_output
        output.token_ids[1] = np.array([10, 3, 3, 3, 10])
        # Already 5 from setUp; restated so the precondition is explicit.
        output.num_new_tokens[1] = 5
        output.truncate_after_eos(eos_token_id=3)
        self._assert_row(1, 2, [10, 3, 0, 0, 0])
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main()