* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <atomic>
#include "common_util.h"
#include "error.h"
#include "llm_manager.h"
#include "memory_utils.h"
#include "status.h"
using namespace mindie_llm;
namespace py = pybind11;
namespace mindie_llm {
// Largest number of items a Python buffer may carry when handed to InferTensor.set_buffer.
constexpr uint32_t MAX_INPUTS_NUM = 4 * 1024 * 1024;
// Largest byte size accepted by InferTensor.set_buffer: the item limit times the widest
// accepted item (sizeof(int64_t) bytes).
constexpr uint32_t MAX_BYTE_ALLOWED = MAX_INPUTS_NUM * sizeof(int64_t);
/**
 * @brief Registers the plain enumerations of the library on the Python module.
 * @param m Module that receives the enum types.
 */
void StructDefine(py::module &m) {
    // Only the host-memory value is exposed to Python.
    py::enum_<MemType> memTypeEnum(m, "MemType");
    memTypeEnum.value("HOST_MEM", MemType::HOST_MEM);

    // Element type tags of an InferTensor.
    py::enum_<InferDataType> dataTypeEnum(m, "InferDataType");
    dataTypeEnum.value("TYPE_INVALID", InferDataType::TYPE_INVALID);
    dataTypeEnum.value("TYPE_BOOL", InferDataType::TYPE_BOOL);
    dataTypeEnum.value("TYPE_UINT8", InferDataType::TYPE_UINT8);
    dataTypeEnum.value("TYPE_UINT16", InferDataType::TYPE_UINT16);
    dataTypeEnum.value("TYPE_UINT32", InferDataType::TYPE_UINT32);
    dataTypeEnum.value("TYPE_UINT64", InferDataType::TYPE_UINT64);
    dataTypeEnum.value("TYPE_INT8", InferDataType::TYPE_INT8);
    dataTypeEnum.value("TYPE_INT16", InferDataType::TYPE_INT16);
    dataTypeEnum.value("TYPE_INT32", InferDataType::TYPE_INT32);
    dataTypeEnum.value("TYPE_INT64", InferDataType::TYPE_INT64);
    dataTypeEnum.value("TYPE_FP16", InferDataType::TYPE_FP16);
    dataTypeEnum.value("TYPE_FP32", InferDataType::TYPE_FP32);
    dataTypeEnum.value("TYPE_FP64", InferDataType::TYPE_FP64);
    dataTypeEnum.value("TYPE_STRING", InferDataType::TYPE_STRING);
    dataTypeEnum.value("TYPE_BF16", InferDataType::TYPE_BF16);
    dataTypeEnum.value("TYPE_BUTT", InferDataType::TYPE_BUTT);

    // Result codes carried by Error and Status.
    py::enum_<Error::Code> codeEnum(m, "Code");
    codeEnum.value("OK", Error::Code::OK);
    codeEnum.value("ERROR", Error::Code::ERROR);
    codeEnum.value("INVALID_ARG", Error::Code::INVALID_ARG);
    codeEnum.value("NOT_FOUND", Error::Code::NOT_FOUND);

    // Representation kind of an InferRequestId (numeric or string).
    py::enum_<InferRequestId::DataType> requestIdTypeEnum(m, "DataType");
    requestIdTypeEnum.value("UINT64", InferRequestId::DataType::UINT64);
    requestIdTypeEnum.value("STRING", InferRequestId::DataType::STRING);

    py::enum_<StatusResponseType> statusResponseEnum(m, "StatusResponseType");
    statusResponseEnum.value("CONTROL_SIGNAL_STATUS", StatusResponseType::CONTROL_SIGNAL_STATUS);
    statusResponseEnum.value("REQUEST_ENQUEUE_STATUS", StatusResponseType::REQUEST_ENQUEUE_STATUS);

    py::enum_<Operation> operationEnum(m, "Operation");
    operationEnum.value("STOP", Operation::STOP);
    operationEnum.value("RELEASE_KV", Operation::RELEASE_KV);
}
/**
 * @brief Exposes the Status class to Python.
 *
 * A Status can be built from an error code, from a code plus message, or from an Error;
 * the three accessors are published under snake_case names.
 *
 * @param m Module that receives the class.
 */
void StatusDefine(py::module &m) {
    py::class_<Status> statusClass(m, "Status");

    // Constructors.
    statusClass.def(py::init<Error::Code>());
    statusClass.def(py::init<Error::Code, const std::string>());
    statusClass.def(py::init<Error>());

    // Read accessors.
    statusClass.def("is_ok", &Status::IsOk);
    statusClass.def("status_code", &Status::StatusCode);
    statusClass.def("status_msg", &Status::StatusMsg);
}
/**
 * @brief Exposes the Error class to Python.
 *
 * An Error can be built from a code alone or from a code plus message.
 *
 * @param m Module that receives the class.
 */
void ErrorDefine(py::module &m) {
    py::class_<Error> errorClass(m, "Error");

    // Constructors.
    errorClass.def(py::init<Error::Code>());
    errorClass.def(py::init<Error::Code, std::string>());

    // Read accessors.
    errorClass.def("error_code", &Error::ErrorCode);
    errorClass.def("message", &Error::Message);
    errorClass.def("is_ok", &Error::IsOk);
}
/**
 * @brief Translates an InferDataType tag into the NumPy dtype of the matching C++ scalar type.
 * @param dataType Element type tag to translate.
 * @return dtype describing bool, the 8/16/32/64-bit signed and unsigned integers, float or double.
 * @throws std::runtime_error for every tag without a case below (for example TYPE_FP16,
 *         TYPE_BF16, TYPE_STRING, TYPE_INVALID and TYPE_BUTT).
 */
py::dtype GetNumpyDtype(InferDataType dataType) {
    switch (dataType) {
        case InferDataType::TYPE_BOOL:   return py::dtype::of<bool>();
        case InferDataType::TYPE_UINT8:  return py::dtype::of<uint8_t>();
        case InferDataType::TYPE_UINT16: return py::dtype::of<uint16_t>();
        case InferDataType::TYPE_UINT32: return py::dtype::of<uint32_t>();
        case InferDataType::TYPE_UINT64: return py::dtype::of<uint64_t>();
        case InferDataType::TYPE_INT8:   return py::dtype::of<int8_t>();
        case InferDataType::TYPE_INT16:  return py::dtype::of<int16_t>();
        case InferDataType::TYPE_INT32:  return py::dtype::of<int32_t>();
        case InferDataType::TYPE_INT64:  return py::dtype::of<int64_t>();
        case InferDataType::TYPE_FP32:   return py::dtype::of<float>();
        case InferDataType::TYPE_FP64:   return py::dtype::of<double>();
        default:
            break;
    }
    // Reached only for tags that have no dtype mapping above.
    throw std::runtime_error("Unsupported data type");
}
/**
 * @brief Builds a NumPy array from the shape, data pointer and element type of a tensor.
 * @param tensor Tensor to convert.
 * @return Array with the tensor's shape and the dtype chosen by GetNumpyDtype.
 * @throws std::runtime_error when GetNumpyDtype has no mapping for the tensor's data type.
 *
 * NOTE(review): no base object is passed to py::array; per the pybind11 documentation this
 * should make the array copy the data rather than alias it — confirm for the pybind11
 * version in use.
 */
py::array TensorToNumpy(const InferTensor &tensor) {
    auto dims = tensor.GetShape();
    void *rawData = tensor.GetData();
    py::dtype elementType = GetNumpyDtype(tensor.GetDataType());
    return py::array(elementType, dims, rawData);
}
/**
 * @brief Exposes InferTensor, the tensor_to_numpy helper and the TensorMap container to Python.
 *
 * Most methods are direct bindings of the C++ members. set_buffer is a custom wrapper that
 * validates a Python buffer object, copies its bytes into freshly malloc'ed memory and hands
 * that copy to InferTensor::SetBuffer.
 *
 * @param m Module that receives the bindings.
 */
void InferTensorDefine(py::module &m) {
    using TensorMap = std::unordered_map<std::string, std::shared_ptr<InferTensor>>;
    py::class_<InferTensor, std::shared_ptr<InferTensor>>(m, "InferTensor")
        .def(py::init<>())
        .def(py::init<std::string, InferDataType, std::vector<int64_t>>(), py::arg("name"), py::arg("data_type"),
             py::arg("data_shape"))
        .def("get_shape", &InferTensor::GetShape)
        .def("get_size", &InferTensor::GetSize)
        .def("get_data_type", &InferTensor::GetDataType)
        .def("get_mem_type", &InferTensor::GetMemType)
        .def("get_data", &InferTensor::GetData)
        .def("get_name", &InferTensor::GetName)
        .def("allocate", &InferTensor::Allocate, py::arg("size"))
        .def("set_buffer",
             // Copies a one-dimensional, densely packed Python buffer into the tensor.
             // Throws std::runtime_error when the buffer violates the item count, item size,
             // byte size, dimensionality or layout limits, or when allocation/copy fails.
             [](InferTensor &self, py::buffer &buf, bool needRelease) {
                 auto bufferInfo = buf.request();
                 if (bufferInfo.size < 0 || bufferInfo.size > MAX_INPUTS_NUM ||
                     static_cast<uint64_t>(bufferInfo.itemsize) > sizeof(int64_t)) {
                     std::string message = "The number of items or item size in input buffer is error. ";
                     throw std::runtime_error(message + "the number of items must in the range of [0, " +
                                              std::to_string(MAX_INPUTS_NUM) + "]." +
                                              "the item size must in the range of (0, 8].");
                 }
                 auto bufferSize = bufferInfo.size * bufferInfo.itemsize;
                 if (bufferSize > MAX_BYTE_ALLOWED || bufferSize <= 0) {
                     std::string mallocSize = std::to_string(bufferSize);
                     throw std::runtime_error("valid byte allowed is (0, " + std::to_string(MAX_BYTE_ALLOWED) +
                                              "). try to allocate " + mallocSize);
                 }
                 // Check the layout before allocating or reading anything, so a rejected
                 // buffer is never touched.
                 if (bufferInfo.ndim != 1) {
                     throw std::runtime_error("Buffer must be one-dimensional.");
                 }
                 // The copy below reads bufferSize consecutive bytes starting at ptr. That is
                 // only valid for a densely packed buffer; a strided or reversed view would
                 // yield wrong data or read outside the exporter's memory. A single-item
                 // buffer is always dense, whatever stride the exporter reports.
                 if (bufferInfo.size > 1 && bufferInfo.strides[0] != bufferInfo.itemsize) {
                     throw std::runtime_error("Buffer must be contiguous.");
                 }
                 // NOTE(review): the copy is heap-allocated regardless of needRelease; if
                 // SetBuffer does not take ownership when needRelease is false, this copy is
                 // never freed — confirm against InferTensor::SetBuffer.
                 void *data = malloc(bufferSize);
                 if (data == nullptr) {
                     throw std::runtime_error("malloc data failed.");
                 }
                 try {
                     if (memcpy_s(data, bufferSize, bufferInfo.ptr, bufferSize) != 0) {
                         throw std::runtime_error("Error occured in set_buffer memcpy_s.");
                     }
                     self.SetBuffer(data, bufferSize, needRelease);
                 } catch (...) {
                     free(data);
                     // Bare rethrow keeps the dynamic exception type and its message;
                     // `throw e;` on a std::exception reference would slice both away.
                     throw;
                 }
             })
        .def("set_release", &InferTensor::SetRelease, py::arg("release_flag"))
        .def("release", &InferTensor::Release);
    m.def("tensor_to_numpy", &TensorToNumpy, "Converts the InferTensor's data to a NumPy array.");
    py::bind_map<TensorMap>(m, "TensorMap");
}
/**
 * @brief Exposes InferRequestId and InferRequest to Python.
 * @param m Module that receives the classes.
 */
void InferRequestDefine(py::module &m) {
    // Request identifier: default-constructible, or built from a string or a uint64_t.
    py::class_<InferRequestId> requestIdClass(m, "InferRequestId");
    requestIdClass.def(py::init<>());
    requestIdClass.def(py::init<std::string>());
    requestIdClass.def(py::init<uint64_t>());
    requestIdClass.def("type", &InferRequestId::Type);
    requestIdClass.def("string_value", &InferRequestId::StringValue);
    requestIdClass.def("unsigned_int_value", &InferRequestId::UnsignedIntValue);

    // Request object, held by std::shared_ptr on the C++ side.
    py::class_<InferRequest, std::shared_ptr<InferRequest>> requestClass(m, "InferRequest");
    requestClass.def(py::init<InferRequestId>());

    // Tensor management.
    requestClass.def("add_tensor", &InferRequest::AddTensor, py::arg("tensor_name"), py::arg("tensor"));
    requestClass.def("set_tensor", &InferRequest::SetTensor, py::arg("tensor_name"), py::arg("tensor"));
    // NOTE(review): "tensor" is bound as an ordinary argument; if GetTensorByName fills it as
    // a C++ out-parameter, the result may not reach the Python caller — confirm.
    requestClass.def("get_tensor_by_name", &InferRequest::GetTensorByName, py::arg("tensor_name"),
                     py::arg("tensor"));
    requestClass.def("del_tensor_by_name", &InferRequest::DelTensorByName, py::arg("name"));

    // Request attributes.
    requestClass.def("get_request_id", &InferRequest::GetRequestId);
    requestClass.def("set_max_output_len", &InferRequest::SetMaxOutputLen, py::arg("max_output_len"));
    requestClass.def("get_max_output_len", &InferRequest::GetMaxOutputLen);
    requestClass.def("immutable_inputs", &InferRequest::ImmutableInputs);
}
/**
 * @brief Exposes LlmManager to Python.
 *
 * The constructor takes the configuration path and five callbacks; init is published
 * twice, once per C++ overload (with and without the extend_info map).
 *
 * NOTE(review): no call guard is attached to these bindings, so the GIL stays held while
 * the C++ methods run; if init/shutdown wait on threads that invoke the Python callbacks,
 * a py::call_guard<py::gil_scoped_release> may be needed — confirm.
 *
 * @param m Module that receives the class.
 */
void LlmManagerDefine(py::module &m) {
    // Parameter types shared by the two Init overloads.
    using DeviceIdSet = std::set<size_t>;
    using ExtendInfoMap = std::map<std::string, std::string>;

    py::class_<LlmManager, std::shared_ptr<LlmManager>> managerClass(m, "LlmManager");

    managerClass.def(py::init<const std::string &, GetRequestsCallback, SendResponsesCallback, ControlSignalCallback,
                              LlmManagerStatsCallback, SendStatusResponseCallback>(),
                     py::arg("llm_config_path"), py::arg("get_request"), py::arg("send_response"),
                     py::arg("control_callback"), py::arg("status_callback"), py::arg("status_response_callback"));

    managerClass.def("get_max_position_embeddings", &LlmManager::GetMaxPositionEmbeddings);
    managerClass.def("shutdown", &LlmManager::Shutdown);

    // Two-argument overload.
    managerClass.def("init", py::overload_cast<uint32_t, DeviceIdSet>(&LlmManager::Init),
                     py::arg("model_instanceId"), py::arg("npu_device_ids"));
    // Three-argument overload carrying additional key/value settings.
    managerClass.def("init", py::overload_cast<uint32_t, DeviceIdSet, ExtendInfoMap>(&LlmManager::Init),
                     py::arg("model_instanceId"), py::arg("npu_device_ids"), py::arg("extend_info"));
}
/**
 * Entry point of the llm_manager_python extension module.
 *
 * Types are registered before the bindings that mention them in a signature: the enums
 * (including Error::Code) first, then Error, then Status, whose constructors take
 * Error::Code and Error. pybind11 renders each signature docstring when def() runs, so a
 * type that is not yet registered at that moment appears under its raw C++ name in
 * help() output instead of its Python name.
 */
PYBIND11_MODULE(llm_manager_python, m) {
    StructDefine(m);
    ErrorDefine(m);
    StatusDefine(m);
    InferTensorDefine(m);
    InferRequestDefine(m);
    LlmManagerDefine(m);
}
}