import enum
import torch
from torch_npu._C import _npu_from_dlpack
from torch_npu._C import _npu_to_dlpack
def _to_dlpack(tensor):
    """Export *tensor* through the torch_npu ``_npu_to_dlpack`` binding.

    The argument is forwarded untouched and the binding's result is returned
    as-is (presumably a DLPack capsule -- confirm against the C++ binding).
    """
    exported = _npu_to_dlpack(tensor)
    return exported
def _from_dlpack(ext_tensor) -> 'torch.Tensor':
    """Build a tensor from *ext_tensor* via the ``_npu_from_dlpack`` binding.

    If *ext_tensor* exposes a ``__dlpack__`` attribute, it is called with no
    arguments and the result is what gets converted. Otherwise *ext_tensor*
    itself is handed to the binding unchanged (presumably an already exported
    DLPack capsule -- confirm against callers).
    """
    # Guard clause: objects without the protocol method are passed straight
    # through to the binding.
    if not hasattr(ext_tensor, '__dlpack__'):
        return _npu_from_dlpack(ext_tensor)

    return _npu_from_dlpack(ext_tensor.__dlpack__())
def _apply_dlpack_patch():
    """Route torch's DLPack entry points through the torch_npu implementation.

    Replaces ``torch.utils.dlpack.to_dlpack`` / ``from_dlpack`` and, when they
    exist, the same-named attributes on ``torch.utils`` and ``torch`` with
    wrappers defined here. For each of ``to_dlpack`` / ``from_dlpack`` that
    exists on ``torch``, the name is also appended to ``torch.__all__``
    (creating the list if it is missing).

    Calling this function again once the wrappers are installed is a no-op,
    so wrappers are never stacked on top of each other.
    """
    import torch.utils.dlpack as torch_dlpack

    # Attribute set on every wrapper so a repeated call can recognise them.
    patched_marker = '_torch_npu_dlpack_patched'
    if getattr(torch_dlpack.to_dlpack, patched_marker, False):
        return

    # Captured once; every wrapper falls back to these stock implementations.
    original_to_dlpack = torch_dlpack.to_dlpack
    original_from_dlpack = torch_dlpack.from_dlpack

    def create_patched_to_dlpack(module_name):
        """Create a patched to_dlpack function with proper __module__ attribute"""
        def patched_to_dlpack(tensor):
            """Patched to_dlpack that uses torch_npu implementation for NPU tensors"""
            # Nested getattr: objects without ``device`` or whose ``device``
            # has no ``type`` are deferred to the original implementation
            # instead of raising AttributeError here.
            device = getattr(tensor, 'device', None)
            if getattr(device, 'type', None) == 'npu':
                return _to_dlpack(tensor)
            return original_to_dlpack(tensor)

        # Make the wrapper report the namespace it is installed into.
        patched_to_dlpack.__module__ = module_name
        setattr(patched_to_dlpack, patched_marker, True)
        return patched_to_dlpack

    def create_patched_from_dlpack(module_name):
        """Create a patched from_dlpack function with proper __module__ attribute"""
        def patched_from_dlpack(ext_tensor):
            """Patched from_dlpack that uses torch_npu implementation when appropriate"""
            # Deliberate best-effort: try the NPU importer first and, on any
            # failure, let the original implementation handle the input. If
            # the fallback raises too, the NPU error stays attached to it as
            # ``__context__``.
            try:
                return _from_dlpack(ext_tensor)
            except Exception:
                return original_from_dlpack(ext_tensor)

        patched_from_dlpack.__module__ = module_name
        setattr(patched_from_dlpack, patched_marker, True)
        return patched_from_dlpack

    def export_from_torch(name):
        """Ensure *name* is listed in ``torch.__all__``, creating it if absent."""
        if not hasattr(torch, '__all__'):
            torch.__all__ = []
        if name not in torch.__all__:
            torch.__all__.append(name)

    torch_dlpack.to_dlpack = create_patched_to_dlpack('torch.utils.dlpack')
    torch_dlpack.from_dlpack = create_patched_from_dlpack('torch.utils.dlpack')

    # Re-exports are only replaced where they already exist; nothing new is
    # added to ``torch.utils``.
    if hasattr(torch.utils, 'to_dlpack'):
        torch.utils.to_dlpack = create_patched_to_dlpack('torch.utils')
    if hasattr(torch.utils, 'from_dlpack'):
        torch.utils.from_dlpack = create_patched_from_dlpack('torch.utils')

    if hasattr(torch, 'to_dlpack'):
        torch.to_dlpack = create_patched_to_dlpack('torch')
        export_from_torch('to_dlpack')
    if hasattr(torch, 'from_dlpack'):
        torch.from_dlpack = create_patched_from_dlpack('torch')
        export_from_torch('from_dlpack')