# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
import sys
import types

import torch
import torch_npu


def accelerator_getattr(module, fallback_module):
    def __getattr__(name):
        if hasattr(fallback_module, name):
            attr = getattr(fallback_module, name)
            setattr(module, name, attr)
            return attr
        else:
            raise AttributeError(f'module {module} and {fallback_module} has no attribute {name}.')

    return __getattr__


def set_accelerator_compatible(fallback_module=torch.npu):
    accelerator_module = types.ModuleType('torch.accelerator')
    accelerator_module.__doc__ = 'Fallback accelerator module that delegates to torch.npu'
    for attr in dir(torch.accelerator):
        if attr.startswith('__'):
            continue
        setattr(accelerator_module, attr, getattr(torch.accelerator, attr))

    accelerator_module.__getattr__ = accelerator_getattr(accelerator_module, fallback_module)
    torch.accelerator = accelerator_module
    sys.modules['torch.accelerator'] = accelerator_module