import unittest
from unittest.mock import MagicMock, patch
from mindie_llm.text_generator.adapter import get_generator_backend
from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch
class TestAdapter(unittest.TestCase):
    """Tests for backend selection performed by ``get_generator_backend``."""

    @patch('mindie_llm.text_generator.adapter.generator_torch.GeneratorTorch',
           return_value=MagicMock(spec=GeneratorTorch))
    def test_get_generator_torch(self, patched_generator_cls):
        """Selecting the 'atb' backend type should produce a GeneratorTorch.

        ``GeneratorTorch`` is patched in its defining module, and the patch
        returns a mock spec'd on the real class. A spec'd mock reports that
        class as its ``__class__``, so it satisfies the ``isinstance`` check.
        NOTE(review): this assumes ``get_generator_backend`` looks up
        ``GeneratorTorch`` through that module at call time, so that the
        patch is what gets constructed. Confirm against the adapter package.
        """
        backend = get_generator_backend({'backend_type': 'atb'})
        self.assertIsInstance(backend, GeneratorTorch)

    def test_get_generator_backend_exception(self):
        """An unknown backend type should raise NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          get_generator_backend,
                          {'backend_type': 'xxx'})
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()