import abc
import ctypes
from typing import Protocol, Union
from typing_extensions import Self
import numpy
from ..lib import runtime as rt
class TorchTensor(Protocol):
nbytes: int
def data_ptr(self) -> int:
...
def ravel(self) -> Self:
...
class MemoryHandle(abc.ABC):
@abc.abstractmethod
def copy_to_device(self) -> int:
raise NotImplementedError
@abc.abstractmethod
def copy_from_device(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def release_memory(self) -> None:
raise NotImplementedError
class ByteArrayHandle(MemoryHandle):
def __init__(self, data: Union[bytearray, bytes]):
self.data = data if isinstance(data, bytearray) else bytearray(data)
self.buffer_ptr = ctypes.cast(ctypes.pointer(ctypes.c_char.from_buffer(self.data)), ctypes.c_void_p)
def copy_to_device(self) -> int:
self.handle = rt.copy_data_to_device(self.buffer_ptr, len(self.data))
return int(self.handle.value)
def copy_from_device(self) -> None:
rt.copy_data_from_device(self.buffer_ptr, self.handle, len(self.data))
def release_memory(self) -> None:
rt.free(self.handle)
class NumpyArrayHandle(MemoryHandle):
def __init__(self, array: numpy.ndarray):
self.array = array
def copy_to_device(self) -> int:
flat = self.array.ravel(order="C")
self.handle = rt.copy_data_to_device(flat.ctypes.data_as(ctypes.c_void_p), flat.nbytes)
return int(self.handle.value)
def copy_from_device(self) -> None:
rt.copy_data_from_device(self.array.ctypes.data_as(ctypes.c_void_p), self.handle, self.array.nbytes)
def release_memory(self) -> None:
rt.free(self.handle)
class TorchCpuTensorHandle(MemoryHandle):
def __init__(self, tensor: TorchTensor):
self.tensor = tensor
def copy_to_device(self) -> int:
flat = self.tensor.ravel()
self.handle = rt.copy_data_to_device(ctypes.c_void_p(flat.data_ptr()), flat.nbytes)
return self.handle.value
def copy_from_device(self) -> None:
rt.copy_data_from_device(ctypes.c_void_p(self.tensor.data_ptr()), self.handle, self.tensor.nbytes)
def release_memory(self) -> None:
rt.free(self.handle)
class TorchNpuTensorArgument(MemoryHandle):
def __init__(self, tensor: TorchTensor):
self.tensor = tensor
def copy_to_device(self) -> int:
return self.tensor.data_ptr()
def copy_from_device(self) -> None:
pass
def release_memory(self) -> None:
pass
def resolve_memory_handle(obj) -> MemoryHandle:
if isinstance(obj, MemoryHandle):
return obj
if isinstance(obj, (bytes, bytearray)):
return ByteArrayHandle(obj)
if isinstance(obj, numpy.ndarray):
return NumpyArrayHandle(obj)
try:
import torch
if isinstance(obj, torch.Tensor):
if getattr(obj, "is_cpu", False):
return TorchCpuTensorHandle(obj)
if getattr(obj, "is_npu", False):
return TorchNpuTensorArgument(obj)
except ModuleNotFoundError:
pass
raise RuntimeError(f"Unsupported memory handle of type {obj.__class__.__name__}")