__all__ = ["rebuild_npu_event", "rebuild_npu_tensor"]
import multiprocessing
import torch
from torch.multiprocessing.reductions import (
shared_cache,
rebuild_storage_filename,
rebuild_storage_empty,
rebuild_storage_fd,
StorageWeakRef,
fd_id,
rebuild_tensor,
storage_from_cache,
rebuild_meta_tensor,
reduce_nested_tensor,
)
import torch_npu
def rebuild_npu_event(device, handle):
    """Recreate an NPU event in the current process from an IPC handle.

    This is the rebuild function paired with ``_npu_reduce_event``: the
    unpickler calls it with the ``(device, handle)`` pair that the reducer
    captured in the sending process.

    Args:
        device: The device recorded from ``event.device`` by the reducer.
        handle: The value returned by ``event.ipc_handle()`` in the sender.

    Returns:
        The result of ``torch.npu.Event.from_ipc_handle(device, handle)``.
    """
    event_type = torch.npu.Event
    return event_type.from_ipc_handle(device, handle)
def _npu_reduce_event(event):
    """Pickle reducer for ``torch.npu.Event`` objects.

    Obtains the event's IPC handle and returns the ``(callable, args)`` pair
    the pickler expects, so the receiving process rebuilds the event with
    ``rebuild_npu_event(event.device, handle)``.
    """
    ipc_handle = event.ipc_handle()
    rebuild_args = (event.device, ipc_handle)
    return rebuild_npu_event, rebuild_args
def rebuild_npu_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    """Reconstruct a tensor in the receiving process from an NPU IPC payload.

    The positional arguments are, in order, the fields packed by
    ``_npu_reduce_tensor`` for NPU tensors.

    Storage resolution:
      * If ``storage_handle`` is ``None`` or ``storage_size_bytes`` is 0, a
        fresh zero-length ``storage_cls`` is created on ``storage_device``.
      * Otherwise ``shared_cache`` is consulted under the key
        ``(storage_handle, storage_offset_bytes)``. On a miss the NPU runtime
        is lazily initialised, the storage is opened with
        ``storage_cls._new_shared_npu(...)``, and a weak reference to it is
        cached. On a hit, ``storage_cls._release_ipc_counter_npu`` is called
        with the producer's ref-counter handle/offset (presumably so the
        producer-side counter for this transfer is released, since the cached
        storage is reused instead — confirm against torch_npu).

    The storage is then wrapped as a ``TypedStorage`` of ``dtype`` and viewed
    with ``tensor_offset`` / ``tensor_size`` / ``tensor_stride``. When
    ``tensor_cls`` is ``Parameter``, ``requires_grad`` is passed to the
    ``Parameter`` constructor; otherwise it is assigned on the plain tensor.

    Returns:
        The rebuilt tensor (a ``Parameter`` when ``tensor_cls`` is one).
    """
    has_shared_block = storage_handle is not None and storage_size_bytes != 0

    if not has_shared_block:
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        cache_key = (storage_handle, storage_offset_bytes)
        storage = storage_from_cache(storage_cls, cache_key)
        if storage is not None:
            # Cached storage is reused; hand back the producer's ref counter.
            storage_cls._release_ipc_counter_npu(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )
        else:
            torch_npu.npu._lazy_init()
            storage = storage_cls._new_shared_npu(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[cache_key] = StorageWeakRef(storage)

    if isinstance(storage, torch.UntypedStorage):
        untyped = storage
    else:
        untyped = storage._untyped_storage

    typed = torch.storage.TypedStorage(wrap_storage=untyped, dtype=dtype, _internal=True)
    rebuilt = torch._utils._rebuild_tensor(typed, tensor_offset, tensor_size, tensor_stride)

    if tensor_cls != torch.nn.parameter.Parameter:
        rebuilt.requires_grad = requires_grad
        return rebuilt
    # Pass requires_grad through the constructor rather than setting it
    # afterwards on the Parameter.
    return torch.nn.parameter.Parameter(rebuilt, requires_grad=requires_grad)
def _npu_reduce_tensor(tensor):
from torch.nested._internal.nested_tensor import NestedTensor
if tensor.is_nested and not isinstance(tensor, NestedTensor):
return reduce_nested_tensor(tensor)
storage = tensor._typed_storage()
if tensor.requires_grad and not tensor.is_leaf:
raise RuntimeError(
"Cowardly refusing to serialize non-leaf tensor which requires_grad, "
"since autograd does not support crossing process boundaries. "
"If you just want to transfer the data, call detach() on the tensor "
"before serializing (e.g., putting it on the queue)."
)
torch._namedtensor_internals.check_serializing_named_tensor(tensor)
torch.utils.hooks.warn_if_has_hooks(tensor)
if storage._untyped_storage.device.type == "npu":
(
device,
handle,
storage_size_bytes,
storage_offset_bytes,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
) = storage._share_npu_()
tensor_offset = tensor.storage_offset()
shared_cache[handle] = StorageWeakRef(storage)
return (
rebuild_npu_tensor,
(
type(tensor),
tensor.size(),
tensor.stride(),
tensor_offset,
type(storage),
tensor.dtype,
device,
handle,
storage_size_bytes,
storage_offset_bytes,
tensor.requires_grad,
ref_counter_handle,
ref_counter_offset,
event_handle,
event_sync_required,
),
)
elif storage._untyped_storage.device.type == "meta":
return (
rebuild_meta_tensor,
(
type(tensor),
tensor.size(),
tensor.stride(),
tensor.storage_offset(),
tensor.dtype,
tensor.untyped_storage().size(),
tensor.requires_grad,
),
)
metadata = (
tensor.storage_offset(),
tensor.size(),
tensor.stride(),
tensor.requires_grad,
)
return (rebuild_tensor, (type(tensor), storage, metadata))
def _npu_reduce_storage(storage):
    """Pickle reducer for storages sent through ``torch.multiprocessing``.

    NPU storages cannot be pickled directly; they must travel inside a tensor,
    which ``_npu_reduce_tensor`` shares over IPC. Meta storages are rejected
    as well: they carry no data to share, and meta tensors are handled by
    ``_npu_reduce_tensor``. Any other (CPU) storage is shared according to
    the active sharing strategy:

      * ``"file_system"``: shared by filename, with the storage's dtype
        appended when it is a ``TypedStorage``;
      * an empty storage (any other strategy) is rebuilt as empty;
      * otherwise: shared by a duplicated file descriptor.

    A weak reference to the storage is recorded in ``shared_cache`` under the
    filename/fd key.

    Returns:
        A ``(rebuild_fn, args)`` tuple whose ``args`` start with
        ``type(storage)``.

    Raises:
        RuntimeError: if ``storage`` is on an NPU or on the meta device.
    """
    from torch.multiprocessing import get_sharing_strategy

    # Same device test as _npu_reduce_tensor (equivalent to the generated
    # ``is_npu`` property for the renamed privateuse1 backend).
    device_type = storage.device.type
    if device_type == "npu":
        raise RuntimeError(
            "Cannot pickle NPU storage; try pickling a NPU tensor instead"
        )
    elif device_type == "meta":
        # Without this check a meta storage would reach the CPU sharing
        # calls below and fail with an unrelated error.
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # Empty storages are special-cased, presumably because a zero-size
        # region cannot be mmapped by the fd path.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd
    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)
def _add_reductions_methods():
    """Install the NPU-aware reducers into torch's multiprocessing machinery.

    Registers ``_npu_reduce_event`` for ``torch.npu.Event`` with
    ``multiprocessing.reduction``. It then replaces ``reduce_event``,
    ``reduce_tensor`` and ``reduce_storage`` in
    ``torch.multiprocessing.reductions`` with the NPU versions and calls that
    module's ``init_reductions()``. The replacements must happen before that
    call, presumably so its registrations pick up the patched reducers.
    """
    reductions = torch.multiprocessing.reductions
    multiprocessing.reduction.register(torch.npu.Event, _npu_reduce_event)
    overrides = (
        ("reduce_event", _npu_reduce_event),
        ("reduce_tensor", _npu_reduce_tensor),
        ("reduce_storage", _npu_reduce_storage),
    )
    for attr_name, reducer in overrides:
        setattr(reductions, attr_name, reducer)
    reductions.init_reductions()