import sys
import types
import torch
import torch_npu
def accelerator_getattr(module, fallback_module):
    """Build a module-level ``__getattr__`` hook that defers to a fallback.

    The returned callable is meant to be stored as ``module.__getattr__``,
    where it is consulted for names that ``module`` does not define itself.

    Args:
        module: Object on which successfully resolved attributes are cached.
        fallback_module: Object queried for names missing from ``module``.

    Returns:
        A one-argument function ``__getattr__(name)`` that returns the
        attribute found on ``fallback_module`` and raises ``AttributeError``
        when ``fallback_module`` does not provide ``name`` either.
    """
    # Private sentinel so that a legitimate attribute value of ``None``
    # is never mistaken for "attribute not found".
    missing = object()

    def __getattr__(name):
        # Resolve the name exactly once: a hasattr()/getattr() pair would
        # trigger any lazy lookup hook on the fallback twice.
        attr = getattr(fallback_module, name, missing)
        if attr is missing:
            raise AttributeError(f'module {module} and {fallback_module} has no attribute {name}.')
        # Cache on the primary module so later accesses bypass this hook.
        setattr(module, name, attr)
        return attr

    return __getattr__
def set_accelerator_compatible(fallback_module=torch.npu):
    """Replace ``torch.accelerator`` with a shim that falls back to another module.

    A fresh module object named ``torch.accelerator`` is created and
    populated with every attribute of the current ``torch.accelerator``
    whose name does not start with a double underscore. A ``__getattr__``
    hook built by ``accelerator_getattr`` is then attached so that names
    missing from the shim are looked up on ``fallback_module``. Finally the
    shim is installed both as ``torch.accelerator`` and as the
    ``sys.modules['torch.accelerator']`` entry.

    Args:
        fallback_module: Module consulted for attributes the shim lacks.
            The default, ``torch.npu``, is evaluated once when this function
            is defined (presumably made available by importing
            ``torch_npu`` -- confirm against the package setup).
    """
    original = torch.accelerator

    shim = types.ModuleType('torch.accelerator')
    shim.__doc__ = 'Fallback accelerator module that delegates to torch.npu'

    # Mirror the existing attributes; names starting with a double
    # underscore (module metadata such as __name__) are left at the values
    # the new module object already has.
    for attr_name in dir(original):
        if not attr_name.startswith('__'):
            setattr(shim, attr_name, getattr(original, attr_name))

    # Names not copied above are resolved lazily through the fallback.
    shim.__getattr__ = accelerator_getattr(shim, fallback_module)

    # Publish the shim for both import-based and attribute-based access.
    sys.modules['torch.accelerator'] = shim
    torch.accelerator = shim