from typing import Union, List, Tuple
import numpy as np
import ctypes
from llm_datadist_v1.utils import utils
from llm_datadist_v1.status import handle_llm_status
from llm_datadist_v1.data_type import _dwrapper_dtype_to_python_dtype
from llm_datadist_v1 import data_type
from llm_datadist_v1 import llm_wrapper
class TensorDesc(object):
    """Description of a tensor: its element data type and its shape."""

    def __init__(self, dtype: data_type.DataType, shape: Union[List[int], Tuple[int]]):
        """
        Initialize the description.

        Args:
            dtype: Data type of the tensor elements.
            shape: Dimensions of the tensor, as a list or tuple of ints.
        """
        # Validate both arguments (dtype first, then shape) before storing anything.
        utils.check_isinstance("dtype", dtype, data_type.DataType)
        utils.check_isinstance("shape", shape, [list, tuple], int)
        # The shape is always kept as a list, even when a tuple was passed in.
        self._shape = list(shape)
        self._dtype = dtype

    @property
    def dtype(self) -> data_type.DataType:
        """Data type given at construction."""
        return self._dtype

    @property
    def shape(self) -> List[int]:
        """Shape as a list; this is the internal list object, not a copy."""
        return self._shape

    def __str__(self):
        """Return a readable ``TensorDesc(dtype=..., shape=...)`` string."""
        dtype_text = str(self.dtype)
        shape_text = str(self.shape)
        return f"TensorDesc(dtype={dtype_text}, shape={shape_text})"
class Tensor(object):
    """
    Python-side handle to a tensor managed through ``llm_wrapper``.

    An instance stores a tensor id (``0`` means "no tensor") and an optional
    ``TensorDesc``. A non-zero id is passed to ``llm_wrapper.destroy_tensor``
    when the instance is finalized.
    """

    def __init__(self, data, tensor_desc: TensorDesc = None):
        """
        Initialize the tensor.

        Args:
            data: Source of the tensor.
                * ``np.ndarray``: a tensor is built from the array memory via
                  ``llm_wrapper.build_tensor``.
                * ``Tensor``: the other tensor is cloned via
                  ``llm_wrapper.clone_tensor`` and its description is reused.
                * ``int``: stored as the tensor id as-is.
            tensor_desc: Description (dtype and shape) of the data. Not used
                when ``data`` is a ``Tensor``.

        Raises:
            RuntimeError: If ``data`` is an ndarray that does not match
                ``tensor_desc``, has an unsupported dtype, or is not
                C-contiguous.
        """
        # Assign defaults before validating: if a check below raises, __del__
        # still runs on the half-built object and must find these attributes.
        self._tensor_id = 0
        self._tensor_desc = None
        utils.check_isinstance("data", data, [np.ndarray, Tensor, int])
        utils.check_isinstance("tensor_desc", tensor_desc, TensorDesc)
        if utils.check_type(data, Tensor):
            self._tensor_desc = data._tensor_desc
            self._tensor_id = llm_wrapper.clone_tensor(data._tensor_id)
        elif utils.check_type(data, int):
            self._tensor_desc = tensor_desc
            self._tensor_id = data
        else:
            self._init_by_ndarray(data, tensor_desc)

    def __del__(self):
        """Destroy the underlying tensor if this instance holds one."""
        # getattr guards against instances created without running __init__
        # (or whose __init__ failed very early).
        tensor_id = getattr(self, "_tensor_id", 0)
        if tensor_id != 0:
            llm_wrapper.destroy_tensor(tensor_id)

    def __str__(self):
        """Return a readable string containing the data and the description."""
        # String tensors can only be read with copy=True (see numpy()).
        return f"Tensor({self.numpy(self._is_inner_dtype_str())},tensor_desc={self._tensor_desc})"

    @staticmethod
    def from_tensor_tuple(tensor_tuple: Tuple[int, int, List[int]]):
        """
        Build a Tensor from a ``(tensor_id, wrapper_dtype, shape)`` tuple.

        Args:
            tensor_tuple: Element 0 is used as the tensor id, element 1 is
                looked up in ``_dwrapper_dtype_to_python_dtype`` to obtain the
                ``DataType``, and element 2 is the shape.

        Returns:
            A ``Tensor`` wrapping the given id with the derived description.
        """
        tensor_desc = TensorDesc(_dwrapper_dtype_to_python_dtype[tensor_tuple[1]], tensor_tuple[2])
        return Tensor(tensor_tuple[0], tensor_desc)

    def _init_by_ndarray(self, data: np.ndarray, tensor_desc: TensorDesc = None):
        """
        Validate ``data`` against ``tensor_desc`` and build the tensor from it.

        Args:
            data: Source array; must be C-contiguous.
            tensor_desc: Optional description. When given, the shape and dtype
                of ``data`` must match it; when omitted, a description is
                derived from ``data``.

        Raises:
            RuntimeError: On shape mismatch, dtype mismatch, unsupported dtype,
                or non-C-contiguous data.
        """
        is_str_data = self._is_origin_dtype_str(data.dtype)
        if tensor_desc is not None:
            if list(data.shape) != tensor_desc.shape:
                raise RuntimeError(
                    f"The shape of data:{data.shape} is not same as tensor_desc shape:{tensor_desc.shape}")
            desc_np_dtype = data_type.dtype_to_np_dtype.get(tensor_desc.dtype)
            if data.dtype != desc_np_dtype:
                raise RuntimeError(
                    f"The dtype of data:{data.dtype} is not same as tensor_desc dtype:{tensor_desc.dtype}")
            self._tensor_desc = tensor_desc
        else:
            if data.dtype not in data_type.valid_np_dtypes and not is_str_data:
                raise RuntimeError(
                    f"The dtype of data:{data.dtype} is not valid, only support {data_type.valid_np_dtypes}")
            if is_str_data:
                self._tensor_desc = TensorDesc(data_type.DataType.DT_STRING, list(data.shape))
            else:
                self._tensor_desc = TensorDesc(data_type.np_dtype_to_dtype[data.dtype], list(data.shape))
        if is_str_data:
            # String arrays are converted to bytes with a NUL appended to
            # every element before being handed over.
            data = self._convert_raw_str_data(data)
        if not data.flags.c_contiguous:
            raise RuntimeError("The data is not c_contiguous")
        # The wrapper receives the raw address and byte size of the array memory.
        data_ptr = data.ctypes.data_as(ctypes.c_void_p).value
        size = data.nbytes
        self._tensor_id = llm_wrapper.build_tensor(
            data_ptr,
            size,
            data_type.python_dtype_2_dwrapper_dtype.get(self._tensor_desc.dtype),
            list(self._tensor_desc.shape))

    def _is_origin_dtype_str(self, dtype):
        """Return whether a numpy dtype is a unicode-string or bytes-string dtype."""
        return np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)

    def _is_inner_dtype_str(self):
        """Return whether this tensor has a description whose dtype is DT_STRING."""
        return self._tensor_desc is not None and self._tensor_desc.dtype == data_type.DataType.DT_STRING

    def _convert_raw_str_data(self, data):
        """Convert a string array to a bytes array with a NUL byte appended to each element."""
        format_data = data.astype(np.bytes_)
        return np.char.add(format_data, b'\0')

    def numpy(self, copy=False):
        """
        Get the numpy representation of the data.

        Args:
            copy: Whether to return a copy of the data. Must be ``True`` for
                string tensors.

        Returns:
            The numpy representation of the data. DT_BF16 data is returned as
            a new float32 array; DT_FLOAT16 data is returned as float16.

        Raises:
            RuntimeError: If this is a string tensor and ``copy`` is ``False``.
        """
        utils.check_isinstance("copy", copy, bool)
        if self._is_inner_dtype_str():
            if not copy:
                raise RuntimeError("String tensor only support when param copy is True.")
            return np.array(llm_wrapper.get_string_tensor(self._tensor_id)).reshape(self._tensor_desc.shape)
        ret, tensor = llm_wrapper.tensor_get_buffer(self._tensor_id)
        handle_llm_status(ret, 'Tensor.numpy', 'Failed to get tensor buffer')
        # A tensor created from a bare id may have no description; treat it
        # like any non-special dtype instead of failing on None.
        dtype = self._tensor_desc.dtype if self._tensor_desc is not None else None
        # NOTE(review): np.frombuffer yields a 1-D array and the two branches
        # below do not reapply the described shape - confirm this is intended.
        if dtype == data_type.DataType.DT_BF16:
            raw_bits = np.frombuffer(tensor, dtype=np.uint16)
            # Put the 16 stored bits into the high half of a 32-bit word and
            # reinterpret as float32; astype already allocates a new array.
            return (raw_bits.astype(np.uint32) << 16).view(np.float32)
        if dtype == data_type.DataType.DT_FLOAT16:
            half_view = np.frombuffer(tensor, dtype=np.uint16).view(np.float16)
            # Honor copy=True: the view alone still aliases the wrapper buffer.
            return half_view.copy() if copy else half_view
        if copy:
            return np.array(tensor, copy=True)
        return np.asarray(tensor)