"""
-------------------------------------------------------------------------
This file is part of the Vision SDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
Vision SDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
Description: Tensor.
Author: Vision SDK
Create: 2024
History: NA
"""
import base64
import logging
from enum import Enum
MAX_NAME_LEN = 100
MIN_TENSOR_ID = 0
MAX_TENSOR_ID = 10000
SHAPE_SIZE_MIN = 1
SHAPE_SIZE_MAX = 10000
MIN_DATA_TYPE_STR_LEN = 4
MAX_DATA_TYPE_STR_LEN = 8
MAX_DATA_SIZE = 51200 * 1024
MAX_DATA_SIZE_INT = 51200 * 1024
LEGAL_TENSOR_NAME_CHAR = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_+-/:"
class TensorKind(Enum):
SERVER_KIND = 0
CLIENT_KIND = 1
class DataType(Enum):
UNDEFINED = -1
FLOAT32 = 0
FLOAT16 = 1
INT8 = 2
INT32 = 3
UINT8 = 4
UINT16 = 5
UINT32 = 6
INT64 = 7
UINT64 = 8
DOUBLE64 = 9
BOOL = 10
STRING = 11
BINARY = 12
class Tensor:
def __init__(self, **kwargs):
self.name_ = kwargs.get("name")
self.id_ = kwargs.get("id")
self.format_ = kwargs.get("format")
self.shape_ = kwargs.get("shape")
self.data_type_size_ = 0
self.data_size_ = 0
self.data_ = None
self.data_type_str_ = ""
self.data_type_ = DataType.UNDEFINED
if kwargs.get("data_type_str"):
self.data_type_str_ = kwargs.get("data_type_str")
elif kwargs.get("data_type"):
self.data_type_ = kwargs.get("data_type")
def check_data_size(self, data_size):
if data_size == 0 or data_size > MAX_DATA_SIZE:
logging.error("Tensor data size is invalid.")
return False
if data_size != self.data_size_:
logging.error("The length of the received data does not match the dimension of the tensor. "
"The length of the received data = %s", data_size)
self.print_tensor_info()
return False
return True
def set_data(self, data_size, data_ptr):
if not data_ptr:
logging.error("Data is None.")
return False
self.data_ = data_ptr
return True
def name(self):
return self.name_
def id(self):
return self.id_
def data(self):
return self.data_
def data_size(self):
return len(self.data_)
def get_data_type(self):
return self.data_type_
def get_tensor_json(self):
return {
"name": self.name_,
"data_type": self.data_type_str_,
"format": self.format_,
"shape": self.shape_,
"data": base64.b64encode(self.data_).decode(),
}
def check_tensor_data_size(self, kind):
if kind == TensorKind.CLIENT_KIND:
count = 1
for s in self.shape_:
if s <= 0 or s > MAX_DATA_SIZE_INT or MAX_DATA_SIZE / s < count:
logging.error("Tensor shape dim is invalid.")
return False
count *= s
if count == 0:
logging.error("Tensor shape dim count is invalid.")
return False
self.data_size_ = count * self.data_type_size_
elif kind == TensorKind.SERVER_KIND:
count = 1
for s in self.shape_:
if s == -1:
continue
if s <= 0 or s > MAX_DATA_SIZE_INT or MAX_DATA_SIZE / s < count:
logging.error("Tensor shape dim is invalid.")
return False
count *= s
return True
def check_tensor_data(self, kind):
if len(self.name_) == 0 or len(self.name_) > MAX_NAME_LEN:
logging.error("Tensor name length is invalid.")
return False
for i in self.name_:
if i not in LEGAL_TENSOR_NAME_CHAR:
logging.error("Tensor name has invalid char.")
return False
if self.id_ < MIN_TENSOR_ID or self.id_ > MAX_TENSOR_ID:
logging.error("Tensor id is out of range.")
return False
if len(self.data_type_str_) < MIN_DATA_TYPE_STR_LEN or len(self.data_type_str_) > MAX_DATA_TYPE_STR_LEN:
logging.error("DataType string len out of range.")
return False
if not (self.format_ == "FORMAT_NONE" or self.format_ == "FORMAT_NHWC" or self.format_ == "FORMAT_NCWH"):
logging.error("Tensor format is invaid.")
return False
if len(self.shape_) < SHAPE_SIZE_MIN or len(self.shape_) > SHAPE_SIZE_MAX:
logging.error("Tensor shape size is out of range.")
return False
return self.check_tensor_data_size(kind)
def print_tensor_info(self):
shape_str = ",".join([str(i) for i in self.shape_])
logging.info("Tensor info: dataTypeStr=%s, dataType=%s, dataTypeSize=%s, format=%s, shape=%s, dataSize=%s",
self.data_type_str_, self.data_type_, self.data_type_size_, self.format_, shape_str,
self.data_size_)