# 910e62b5 - created Jan 15 (commit history)
#!/usr/bin/env vpython3

# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unit tests for install.py."""

import io
from pathlib import Path
import unittest
import unittest.mock

import install
from pyfakefs import fake_filesystem_unittest

# pylint: disable=protected-access


class InstallTest(fake_filesystem_unittest.TestCase):
    """Tests for the extension installation script."""

    def setUp(self):
        """Builds a fake source tree and patches out external side effects.

        The fake filesystem contains a project root with source, testing and
        internal extension directories, one sample extension in the source
        directory and one in the testing directory. install._run_command,
        install.check_gemini_version, install.get_project_root and
        subprocess.run are replaced with mocks for every test.
        """
        self.setUpPyfakefs(additional_skip_names=['subprocess'])
        self.tmpdir = '/tmp/test'
        self.project_root = Path(self.tmpdir) / 'src'
        self.fs.create_dir(self.project_root)

        self.source_extensions_dir = self.project_root / 'agents' / 'extensions'
        self.fs.create_dir(self.source_extensions_dir)
        self.install_script_path = self.source_extensions_dir / 'install.py'
        self.fs.create_file(self.install_script_path)

        self.testing_extensions_dir = (self.project_root / 'agents' /
                                       'testing' / 'extensions')
        self.fs.create_dir(self.testing_extensions_dir)

        self.internal_extensions_dir = (self.project_root / 'internal' /
                                        'agents' / 'extensions')
        self.fs.create_dir(self.internal_extensions_dir)

        # Create sample extensions
        self.extension1_dir = self.source_extensions_dir / 'sample_1'
        self.fs.create_dir(self.extension1_dir)
        self.fs.create_file(
            self.extension1_dir / 'gemini-extension.json',
            contents='{"name": "sample_1", "version": "1.0.0"}',
        )

        self.test_extension_dir = self.testing_extensions_dir / 'test_sample'
        self.fs.create_dir(self.test_extension_dir)
        self.fs.create_file(
            self.test_extension_dir / 'gemini-extension.json',
            contents='{"name": "test_sample", "version": "1.0.0"}',
        )

        self.mock_run_command_patcher = unittest.mock.patch(
            'install._run_command')
        self.mock_run_command = self.mock_run_command_patcher.start()
        self.addCleanup(self.mock_run_command_patcher.stop)

        self.mock_check_version_patcher = unittest.mock.patch(
            'install.check_gemini_version')
        self.mock_check_version = self.mock_check_version_patcher.start()
        self.addCleanup(self.mock_check_version_patcher.stop)

        # Keep a handle on the real implementation before it is patched below,
        # so test_get_project_root() can exercise it instead of the mock.
        self.real_get_project_root = install.get_project_root
        self.mock_get_project_root_patcher = unittest.mock.patch(
            'install.get_project_root')
        self.mock_get_project_root = self.mock_get_project_root_patcher.start()
        self.addCleanup(self.mock_get_project_root_patcher.stop)
        self.mock_get_project_root.return_value = self.project_root

        self.mock_subprocess_run_patcher = unittest.mock.patch(
            'subprocess.run')
        self.mock_subprocess_run = self.mock_subprocess_run_patcher.start()
        self.addCleanup(self.mock_subprocess_run_patcher.stop)

    def test_find_extensions_dir_for_extension(self):
        """Tests finding an extension directory."""
        extensions_dirs = install.get_extensions_dirs(self.project_root)
        # Extension in source directory
        ext_dir = install.find_extensions_dir_for_extension(
            'sample_1', extensions_dirs)
        self.assertEqual(ext_dir, self.source_extensions_dir)

        # Extension in internal directory
        internal_extension_dir = self.internal_extensions_dir / 'internal_ext'
        self.fs.create_dir(internal_extension_dir)
        self.fs.create_file(
            internal_extension_dir / 'gemini-extension.json',
            contents='{"name": "internal_ext", "version": "1.0.0"}',
        )
        extensions_dirs = install.get_extensions_dirs(self.project_root)
        ext_dir = install.find_extensions_dir_for_extension(
            'internal_ext', extensions_dirs)
        self.assertEqual(ext_dir, self.internal_extensions_dir)

        # Extension in testing directory
        extensions_dirs = install.get_extensions_dirs(
            self.project_root,
            extra_extensions_dirs=[self.testing_extensions_dir])
        ext_dir = install.find_extensions_dir_for_extension(
            'test_sample', extensions_dirs)
        self.assertEqual(ext_dir, self.testing_extensions_dir)

    def test_get_extensions_dirs(self):
        """Tests that get_extensions_dirs returns correct directories."""
        # By default, test extensions should not be included
        dirs = install.get_extensions_dirs(self.project_root)
        self.assertIn(self.source_extensions_dir, dirs)
        self.assertIn(self.internal_extensions_dir, dirs)
        self.assertNotIn(self.testing_extensions_dir, dirs)

        dirs = install.get_extensions_dirs(
            self.project_root,
            extra_extensions_dirs=[self.testing_extensions_dir])
        self.assertIn(self.source_extensions_dir, dirs)
        self.assertIn(self.internal_extensions_dir, dirs)
        self.assertIn(self.testing_extensions_dir, dirs)

    def test_get_extensions_dirs_no_project_root(self):
        """Tests get_extensions_dirs() when no project root is found."""
        extensions_dirs = install.get_extensions_dirs(None)
        self.assertEqual(extensions_dirs, [])

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_add_extension_copy(self, mock_find_dir):
        """Tests that add --copy installs the extension."""
        mock_find_dir.return_value = self.source_extensions_dir
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', '--copy', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'install',
                str(self.source_extensions_dir / 'sample_1')
            ],
            skip_prompt=False)

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_add_extension_link(self, mock_find_dir):
        """Tests that add without --copy links the extension."""
        mock_find_dir.return_value = self.source_extensions_dir
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'link',
                str(self.source_extensions_dir / 'sample_1')
            ],
            skip_prompt=False)

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_add_extension_skip_prompt(self, mock_find_dir):
        """Tests that the skip_prompt flag is accepted."""
        mock_find_dir.return_value = self.source_extensions_dir
        with unittest.mock.patch(
                'sys.argv', ['install.py', 'add', '--skip-prompt', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'link',
                str(self.source_extensions_dir / 'sample_1')
            ],
            skip_prompt=True)

    def test_add_test_extension(self):
        """Tests add command with a test extension."""
        with unittest.mock.patch('sys.argv', [
                'install.py', '--extra-extensions-dir',
                str(self.testing_extensions_dir), 'add', 'test_sample'
        ]):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'link',
                str(self.testing_extensions_dir / 'test_sample')
            ],
            skip_prompt=False)

    def test_add_test_extension_without_flag_fails(self):
        """Tests that a test extension is not found without the extra dir.

        The testing extensions directory is only searched when it is passed
        via --extra-extensions-dir, so adding 'test_sample' without it should
        exit with an error and run no command.
        """
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', 'test_sample']):
            # Silence the error message; the exit code is what is checked.
            with unittest.mock.patch('sys.stderr', new_callable=io.StringIO):
                with self.assertRaises(SystemExit) as e:
                    install.main()
        self.assertEqual(e.exception.code, 1)
        self.mock_run_command.assert_not_called()

    def test_add_invalid_extension(self):
        """Tests add command with an invalid extension."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', 'nonexistent']):
            with unittest.mock.patch('sys.stderr',
                                     new_callable=io.StringIO) as mock_stderr:
                with self.assertRaises(SystemExit) as e:
                    install.main()
                self.assertEqual(e.exception.code, 1)
                self.assertIn("Extension 'nonexistent' not found.",
                              mock_stderr.getvalue())
        self.mock_run_command.assert_not_called()

    def test_update_extension(self):
        """Tests update command."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'update', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            ['gemini', 'extensions', 'update', 'sample_1'], skip_prompt=False)

    def test_update_all_extensions(self):
        """Tests update command with no extension specified."""
        with unittest.mock.patch('sys.argv', ['install.py', 'update']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            ['gemini', 'extensions', 'update', '--all'])

    def test_remove_extension(self):
        """Tests remove command."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'remove', 'sample-1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            ['gemini', 'extensions', 'uninstall', 'sample-1'])

    @unittest.mock.patch('pathlib.Path.home')
    def test_remove_legacy_extension(self, mock_home):
        """Tests remove command for legacy extensions with underscores."""
        fake_home = Path(self.tmpdir) / 'home'
        mock_home.return_value = fake_home

        # Set up a legacy extension
        legacy_extension_dir = (install.get_global_extension_dir() /
                                'my_legacy_ext')
        self.fs.create_dir(legacy_extension_dir)
        self.assertTrue(legacy_extension_dir.exists())

        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'remove', 'my_legacy_ext']):
            install.main()

        self.mock_run_command.assert_not_called()
        self.assertFalse(legacy_extension_dir.exists())

    def test_list_extensions(self):
        """Tests the list command, showing all extensions."""
        self.mock_subprocess_run.return_value.stdout = """
✓ user-enabled (1.0.0)
 ID: abc
 Path: /path/to/user-enabled
 Source: /path/to/source/user-enabled (Type: link)
 Enabled (User): true
 Enabled (Workspace): false

✓ workspace-enabled (2.0.0)
 ID: def
 Path: /path/to/workspace-enabled
 Source: /path/to/source/workspace-enabled (Type: local)
 Enabled (User): false
 Enabled (Workspace): true

✓ both-enabled (3.0.0)
 ID: ghi
 Path: /path/to/both-enabled
 Source: /path/to/source/both-enabled (Type: local)
 Enabled (User): true
 Enabled (Workspace): true
        """
        self.mock_subprocess_run.return_value.returncode = 0

        with unittest.mock.patch('sys.argv', ['install.py', 'list']):
            with unittest.mock.patch('sys.stdout',
                                     new_callable=io.StringIO) as mock_stdout:
                install.main()
                output = mock_stdout.getvalue()

        expected_extensions = {
            'workspace-enabled':
            install.ExtensionInfo(name='workspace-enabled',
                                  installed='2.0.0',
                                  linked=False,
                                  enabled_for_workspace=True),
            'user-enabled':
            install.ExtensionInfo(name='user-enabled',
                                  installed='1.0.0',
                                  linked=True,
                                  enabled_for_workspace=False),
            'both-enabled':
            install.ExtensionInfo(name='both-enabled',
                                  installed='3.0.0',
                                  linked=False,
                                  enabled_for_workspace=True),
            'sample_1':
            install.ExtensionInfo(name='sample_1', available='1.0.0'),
        }
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as expected_stdout:
            install._print_extensions_table(expected_extensions)
            expected_output = expected_stdout.getvalue()

        self.assertEqual(output, expected_output)

    def test_list_extensions_no_installed(self):
        """Tests the list command with no installed extensions."""
        self.mock_subprocess_run.return_value.stdout = ''
        self.mock_subprocess_run.return_value.returncode = 0

        with unittest.mock.patch('sys.argv', ['install.py', 'list']):
            with unittest.mock.patch('sys.stdout',
                                     new_callable=io.StringIO) as mock_stdout:
                install.main()
                output = mock_stdout.getvalue()

        expected_extensions = {
            'sample_1': install.ExtensionInfo(name='sample_1',
                                              available='1.0.0'),
        }
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as expected_stdout:
            install._print_extensions_table(expected_extensions)
            expected_output = expected_stdout.getvalue()

        self.assertEqual(output, expected_output)

    def test_list_extensions_empty_table(self):
        """Tests the list command with no available or installed extensions."""
        self.mock_subprocess_run.return_value.stdout = ''
        self.mock_subprocess_run.return_value.returncode = 0

        # Remove the sample extension created in setUp
        self.fs.remove_object(str(self.extension1_dir))

        with unittest.mock.patch('sys.argv', ['install.py', 'list']):
            with unittest.mock.patch('sys.stdout',
                                     new_callable=io.StringIO) as mock_stdout:
                install.main()
                output = mock_stdout.getvalue()

        # Only the header rows are printed when there is nothing to list.
        expected_output = (
            'EXTENSION  AVAILABLE  INSTALLED  LINKED  ENABLED\n'
            '---------  ---------  ---------  ------  -------\n')
        self.assertEqual(output, expected_output)

    def test_print_extensions_table_formatting(self):
        """Tests the formatting of the extensions table."""
        extensions_data = {
            'ext_a':
            install.ExtensionInfo(name='ext_a',
                                  available='1.0.0',
                                  installed='1.0.0',
                                  linked=True,
                                  enabled_for_workspace=True),
            'another_extension':
            install.ExtensionInfo(name='another_extension',
                                  available='2.0.0',
                                  installed='-',
                                  linked=False,
                                  enabled_for_workspace=False),
            'third_ext':
            install.ExtensionInfo(name='third_ext',
                                  available='-',
                                  installed='3.0.0',
                                  linked=False,
                                  enabled_for_workspace=True),
        }
        expected_output = (
            'EXTENSION          AVAILABLE  INSTALLED  LINKED  ENABLED  \n'
            '-----------------  ---------  ---------  ------  ---------\n'
            'another_extension  2.0.0      -          no      -        \n'
            'ext_a              1.0.0      1.0.0      yes     workspace\n'
            'third_ext          -          3.0.0      no      workspace\n')
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as mock_stdout:
            install._print_extensions_table(extensions_data)
            self.assertEqual(mock_stdout.getvalue(), expected_output)

    def test_find_extensions_dir_for_nonexistent_extension(self):
        """Tests finding a non-existent extension."""
        extensions_dirs = install.get_extensions_dirs(self.project_root)
        ext_dir = install.find_extensions_dir_for_extension(
            'nonexistent', extensions_dirs)
        self.assertIsNone(ext_dir)

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_fix_extensions(self, mock_find_dir):
        """Tests fix command."""
        mock_find_dir.return_value = self.source_extensions_dir
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        self.fs.create_file(
            project_extensions_dir / 'sample_1' / 'gemini-extension.json',
            contents='{"name": "sample_1", "version": "1.0.0"}',
        )

        with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
            install.main()

        calls = [
            unittest.mock.call([
                'gemini', 'extensions', 'link',
                str(self.source_extensions_dir / 'sample_1')
            ]),
            unittest.mock.call([
                'gemini', 'extensions', 'disable', 'sample_1', '--scope=User'
            ]),
            unittest.mock.call([
                'gemini', 'extensions', 'enable', 'sample_1',
                '--scope=Workspace'
            ]),
        ]
        self.mock_run_command.assert_has_calls(calls)
        self.assertFalse(project_extensions_dir.exists())

    def test_fix_extensions_no_project_dir(self):
        """Tests fix command when no project-level directory exists."""
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as mock_stdout:
            with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
                install.main()
            self.assertIn('No project-level extensions found to fix.',
                          mock_stdout.getvalue())

        self.mock_run_command.assert_not_called()

    def test_fix_extensions_no_extensions(self):
        """Tests fix command when no project-level extensions are found."""
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)

        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as mock_stdout:
            with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
                install.main()
            self.assertIn(
                'No valid project-level extensions found.',
                mock_stdout.getvalue(),
            )

        self.mock_run_command.assert_not_called()
        self.assertFalse(project_extensions_dir.exists())

    @unittest.mock.patch('pathlib.Path.home')
    def test_fix_skips_existing_user_extension(self, mock_home):
        """Tests that fix skips extensions that already exist for the user."""
        fake_home = Path(self.tmpdir) / 'home'
        mock_home.return_value = fake_home

        # Set up a user-level extension
        (install.get_global_extension_dir() / 'sample_1').mkdir(parents=True)

        # Create a project-level extension with the same name
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        self.fs.create_file(
            project_extensions_dir / 'sample_1' / 'gemini-extension.json',
            contents='{"name": "sample_1", "version": "1.0.0"}',
        )

        with unittest.mock.patch('sys.stderr',
                                 new_callable=io.StringIO) as mock_stderr:
            with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
                install.main()
            self.assertIn(
                'Warning: User extension "sample_1" already exists.',
                mock_stderr.getvalue(),
            )

        self.mock_run_command.assert_not_called()
        self.assertFalse(project_extensions_dir.exists())

    def test_prompt_for_fix(self):
        """Tests that the user is prompted to run fix."""
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        with unittest.mock.patch('sys.stderr',
                                 new_callable=io.StringIO) as mock_stderr:
            with unittest.mock.patch('sys.argv', ['install.py', 'list']):
                install.main()
            self.assertIn('WARNING: Project-level extensions are deprecated.',
                          mock_stderr.getvalue())

    def test_get_project_root(self):
        """Tests that get_project_root() reports the patched _PROJECT_ROOT.

        setUp() replaces install.get_project_root with a mock that always
        returns self.project_root, so calling it through the module would only
        exercise the mock. The real function captured in setUp() is called
        instead.
        """
        with unittest.mock.patch('install._PROJECT_ROOT', self.project_root):
            project_root = self.real_get_project_root()
        self.assertEqual(project_root, self.project_root)

if __name__ == '__main__':
    # Run every test case defined in this module when executed directly.
    unittest.main(module='__main__')