import shutil

from unittest.mock import Mock

import os

import tempfile

import hashlib

import unittest

from unittest.mock import patch

from ascend_deployer.utils import CONFIG_INST

from zipfile import ZipFile, BadZipfile

from ascend_deployer.utils import extract_zip, get_remote_md5, calculate_file_md5, update_obs_config, get_os_list, \

    get_pkg_list





class TestGetRemoteMD5(unittest.TestCase):
    """Exercise ``get_remote_md5`` with ``urllib.request`` replaced by mocks."""

    URL = 'http://example.com'
    REFERER = 'http://referer.com'

    def _call_with_header(self, urlopen_mock, header_value):
        """Stub ``urlopen`` so its context-managed response's ``getheader`` yields
        *header_value*, then invoke ``get_remote_md5`` and return its result."""
        fake_response = Mock()
        fake_response.getheader.return_value = header_value
        urlopen_mock.return_value.__enter__.return_value = fake_response
        return get_remote_md5(self.URL, self.REFERER)

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_get_remote_md5_success(self, urlopen_mock, _request_mock):
        digest = '1234567890abcdef1234567890abcdef'
        self.assertEqual(self._call_with_header(urlopen_mock, digest), digest)

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_get_remote_md5_fail_invalid_etag(self, urlopen_mock, _request_mock):
        # A header value that is not a 32-char hex digest should yield None.
        self.assertIsNone(self._call_with_header(urlopen_mock, 'invalid_md5'))

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_get_remote_md5_fail_exception(self, urlopen_mock, _request_mock):
        # Any error raised while opening the URL should be reported as None.
        urlopen_mock.side_effect = Exception('Test exception')
        self.assertIsNone(get_remote_md5(self.URL, self.REFERER))





class TestCalculateFileMD5(unittest.TestCase):
    """Check ``calculate_file_md5`` against ``hashlib`` on a small temporary file."""

    def setUp(self):
        self.payload = b"Hello, world!"
        # delete=False keeps the file on disk after the handle is closed so the
        # function under test can reopen it by path; tearDown removes it.
        with tempfile.NamedTemporaryFile(delete=False) as handle:
            handle.write(self.payload)
        self.temp_path = handle.name

    def tearDown(self):
        os.unlink(self.temp_path)

    def test_calculate_file_md5(self):
        reference = hashlib.md5(self.payload).hexdigest()

        # Default chunking, then a 1-byte chunk size, must give the same digest.
        for extra_kwargs in ({}, {'chunk_size': 1}):
            self.assertEqual(calculate_file_md5(self.temp_path, **extra_kwargs), reference)

        # A missing file and a directory are both rejected with RuntimeError.
        for unreadable in ("non_existent_file", "/"):
            with self.assertRaises(RuntimeError):
                calculate_file_md5(unreadable)





class TestUpdateObsConfig(unittest.TestCase):
    """Cases where ``update_obs_config`` should stop before downloading anything."""

    # NOTE(review): ParallelDownloader is patched at its defining module; if
    # ascend_deployer.utils imports the class by name, this patch will not intercept
    # it and the assert_not_called checks below are weaker than they look — confirm.

    def _assert_skipped_download(self, remote_md5_mock, downloader_mock, local_md5_mock, extract_mock):
        """The remote MD5 is queried exactly once and no download/verify/extract step runs."""
        remote_md5_mock.assert_called_once()
        for untouched in (downloader_mock, local_md5_mock, extract_mock):
            untouched.assert_not_called()

    @patch('ascend_deployer.utils.get_remote_md5')
    @patch('ascend_deployer.utils.calculate_file_md5')
    @patch('ascend_deployer.utils.extract_zip')
    @patch('ascend_deployer.downloader.parallel_file_downloader.ParallelDownloader')
    def test_update_obs_config_no_remote_md5(self, downloader_mock, extract_mock, local_md5_mock,
                                             remote_md5_mock):
        # No remote MD5 available.
        remote_md5_mock.return_value = None
        update_obs_config()
        self._assert_skipped_download(remote_md5_mock, downloader_mock, local_md5_mock, extract_mock)

    @patch('ascend_deployer.utils.get_remote_md5')
    @patch('ascend_deployer.utils.calculate_file_md5')
    @patch('ascend_deployer.utils.extract_zip')
    @patch('ascend_deployer.downloader.parallel_file_downloader.ParallelDownloader')
    def test_update_obs_config_md5_match(self, downloader_mock, extract_mock, local_md5_mock,
                                         remote_md5_mock):
        # Remote MD5 equals the MD5 already recorded in the config.
        remote_md5_mock.return_value = 'same_md5'
        with patch.object(CONFIG_INST, 'get_obs_downloader_config', return_value=('url', 'same_md5')):
            update_obs_config()
        self._assert_skipped_download(remote_md5_mock, downloader_mock, local_md5_mock, extract_mock)





class TestGetOsList(unittest.TestCase):
    """``get_os_list`` should report the entries found under the downloader path."""

    # NOTE(review): unlike the other patch targets in this file, this one has no
    # 'ascend_deployer.' prefix; confirm it is the module from which get_os_list
    # actually resolves get_obs_downloader_path.
    @patch('os.listdir')
    @patch('downloader.download_util.get_obs_downloader_path')
    def test_get_os_list(self, downloader_path_mock, listdir_mock):
        os_names = ['os1', 'os2', 'os3']

        # Arrange: fake downloader location and the directory entries beneath it.
        downloader_path_mock.return_value = '/path/to/downloader'
        # Hand the mock its own copy so the expected list cannot be mutated by the SUT.
        listdir_mock.return_value = list(os_names)

        # Act and assert.
        self.assertEqual(get_os_list(), os_names)





class TestExtractZip(unittest.TestCase):
    """Extract a one-file archive with ``extract_zip`` and verify the result on disk."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()
        self.test_zip = os.path.join(self.test_dir, 'test.zip')
        self.test_file = os.path.join(self.test_dir, 'test.txt')
        with open(self.test_file, 'w') as f:
            f.write('test')
        with ZipFile(self.test_zip, 'w') as z:
            z.write(self.test_file, os.path.basename(self.test_file))
        # Extract into a separate, empty directory. The source 'test.txt' already
        # lives in test_dir, so checking for it there would pass even if
        # extract_zip extracted nothing.
        self.extract_dir = os.path.join(self.test_dir, 'extracted')
        os.makedirs(self.extract_dir)

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _assert_extracted(self, members):
        """The archived file exists in extract_dir with its original content,
        and extract_zip reported exactly that member."""
        extracted = os.path.join(self.extract_dir, 'test.txt')
        self.assertTrue(os.path.isfile(extracted))
        with open(extracted) as f:
            self.assertEqual(f.read(), 'test')
        self.assertEqual(members, ['test.txt'])

    def test_extract_zip_success(self):
        members = extract_zip(self.test_zip, self.extract_dir)
        self._assert_extracted(members)

    def test_extract_zip_filter_rule(self):
        # Parameter names kept as in the original lambda in case extract_zip
        # passes them by keyword.
        def filter_rule(file, members):
            return [m for m in members if 'test' in m]

        members = extract_zip(self.test_zip, self.extract_dir, filter_rule)
        self._assert_extracted(members)





if __name__ == "__main__":
    # Allow running this test module directly (python <file>.py) as well as via a runner.
    unittest.main()