from pathlib import Path
from .base import AccurateTest


class DirectoryMappingStrategy(AccurateTest):
    """
    Map modified files under ``torch_npu/`` to the unit tests that cover them.

    A modified file is looked up in ``mapping_list`` twice: first by its module
    name (see ``get_module_name``), then by its file stem. Each matching entry
    is resolved against ``self.base_dir``. An entry naming a file yields that
    file. Any other entry is treated as a directory and yields the
    ``test_*.py`` files directly inside it.
    """
    # Keys: torch_npu module names or file stems. Values: paths relative to
    # self.base_dir, naming either a single test file or a directory of tests.
    # (Despite the name, this is a dict.)
    mapping_list = {
        'contrib': 'test/contrib',
        'cpp_extension': 'test/cpp_extensions',
        'distributed': 'test/distributed',
        'fx': 'test/test_fx.py',
        'optim': 'test/optim',
        'profiler': 'test/profiler',
        'onnx': 'test/onnx',
        'utils': 'test/test_utils.py',
        'testing': 'test/test_testing.py',
        'jit': 'test/test_jit.py',
        'rpc': 'test/distributed/rpc',
        'meta': 'test/test_fake_tensor.py',
    }

    @staticmethod
    def get_module_name(modify_file):
        """
        Return the ``mapping_list`` key for the module ``modify_file`` belongs to.

        The key is normally the second path component (``torch_npu/<module>/...``).
        Special cases, applied in this order:
          * ``.../csrc/<module>/...`` uses the component after ``csrc``;
          * any path containing an ``rpc`` component maps to ``'rpc'``;
          * a path whose module is ``utils`` and whose third component is
            ``cpp_extension.py`` maps to ``'cpp_extension'``.

        A path with fewer than two components yields ``''``, and ``.../csrc``
        with nothing after it stays ``'csrc'``. Neither raises ``IndexError``.

        :param modify_file: path of the modified file (str or path-like).
        :return: module name as a str.
        """
        parts = Path(modify_file).parts
        module_name = parts[1] if len(parts) > 1 else ''
        if module_name == 'csrc' and len(parts) > 2:
            module_name = parts[2]
        # rpc lives in several places (e.g. under distributed); any rpc component wins.
        if 'rpc' in parts:
            module_name = 'rpc'
        if module_name == 'utils' and len(parts) > 2 and parts[2] == 'cpp_extension.py':
            module_name = 'cpp_extension'
        return module_name

    def identify(self, modify_file):
        """
        Return the unit-test file paths related to ``modify_file``.

        Only files whose first path component is ``torch_npu`` are mapped.
        Anything else, including an empty path, yields ``[]``.

        :param modify_file: path of the modified file (str or path-like).
        :return: list of str test-file paths built from ``self.base_dir``.
            The list is free of duplicates and follows mapping order
            (module name, then file stem). Directory entries contribute
            their ``test_*.py`` files in sorted order.
        """
        path = Path(modify_file)
        if not path.parts or path.parts[0] != 'torch_npu':
            return []

        # Look up by module name first, then by file stem. dict.fromkeys drops
        # duplicate entries (e.g. torch_npu/profiler/profiler.py matches
        # 'profiler' both ways) while keeping first-seen order.
        lookup_keys = (self.get_module_name(modify_file), path.stem)
        mapped_ut_paths = dict.fromkeys(
            self.mapping_list[key] for key in lookup_keys if key in self.mapping_list
        )

        current_all_ut_path = []
        for mapped_path in mapped_ut_paths:
            # NOTE(review): self.base_dir is provided by AccurateTest; assumed to
            # be a pathlib.Path (joinpath/glob are called on it) — confirm.
            target = self.base_dir.joinpath(mapped_path)
            if target.is_file():
                current_all_ut_path.append(str(target))
            else:
                # Non-recursive match: only test_*.py directly inside the
                # directory. Sorted so the result does not depend on the
                # filesystem's listing order.
                current_all_ut_path.extend(sorted(str(i) for i in target.glob('test_*.py')))
        return current_all_ut_path