* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "tprt_error_code.h"
#include "tprt_api.h"
#include "tprt_type.h"
#include "tprt_worker.hpp"
#include "tprt.hpp"
#if defined(__cplusplus)
extern "C" {
#endif
/**
 * @brief Validates the open configuration and forwards the request to the TPRT manager.
 * @param devId Identifier of the device to open.
 * @param cfg   Configuration for the device; must not be null.
 * @return TPRT_INPUT_NULL when cfg is null; TPRT_INPUT_INVALID when sqcqMaxDepth is zero or
 *         greater than SQCQ_MAX_DEPTH, or when sqcqMaxNum is zero; otherwise whatever
 *         TprtManage::TprtDeviceOpen returns.
 */
uint32_t TprtDeviceOpen(const uint32_t devId, const TprtCfgInfo_t *cfg)
{
    if (cfg == nullptr) {
        TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] input is null.", devId);
        return TPRT_INPUT_NULL;
    }

    // Depth must lie in [1, SQCQ_MAX_DEPTH] and at least one sq/cq must be allowed.
    const bool depthInRange = (cfg->sqcqMaxDepth != 0U) && (cfg->sqcqMaxDepth <= SQCQ_MAX_DEPTH);
    const bool numIsNonZero = (cfg->sqcqMaxNum != 0U);
    if (!(depthInRange && numIsNonZero)) {
        TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] sqcqMaxDepth[%u] sqcqMaxNum[%u].",
                 devId, cfg->sqcqMaxDepth, cfg->sqcqMaxNum);
        return TPRT_INPUT_INVALID;
    }

    return cce::tprt::TprtManage::Instance()->TprtDeviceOpen(devId, cfg);
}
uint32_t TprtDeviceClose(uint32_t devId)
{
return cce::tprt::TprtManage::Instance()->TprtDeviceClose(devId);
}
uint32_t TprtSqCqCreate(const uint32_t devId, const TprtSqCqInputInfo *sqInfo, const TprtSqCqInputInfo *cqInfo)
{
if (sqInfo == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "sqInfo is null, device_id=%u.", devId);
return TPRT_INPUT_INVALID;
}
if (cqInfo == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "cqInfo is null, device_id=%u.", devId);
return TPRT_INPUT_INVALID;
}
if (sqInfo->inputType != TPRT_ALLOC_SQ_TYPE) {
TPRT_LOG(TPRT_LOG_ERROR, "Input type of sqInfo is wrong, device_id=%u, inputType=%u.", devId, sqInfo->inputType);
return TPRT_INPUT_INVALID;
}
if (cqInfo->inputType != TPRT_ALLOC_CQ_TYPE) {
TPRT_LOG(TPRT_LOG_ERROR, "Input type of cqInfo is wrong, device_id=%u, inputType=%u.", devId, cqInfo->inputType);
return TPRT_INPUT_INVALID;
}
uint32_t sqCqMaxNum = cce::tprt::TprtManage::Instance()->TprtGetSqCqMaxNum();
if (sqInfo->reqId >= sqCqMaxNum) {
TPRT_LOG(TPRT_LOG_ERROR,
"The reqId of sqInfo is larger than max num, device_id=%u, reqId=%u, sqCqMaxNum=%u.",
devId,
sqInfo->reqId,
sqCqMaxNum);
return TPRT_INPUT_INVALID;
}
if (cqInfo->reqId >= sqCqMaxNum) {
TPRT_LOG(TPRT_LOG_ERROR,
"The reqId of cqInfo is larger than max num, device_id=%u, reqId=%u, sqCqMaxNum=%u.",
devId,
cqInfo->reqId,
sqCqMaxNum);
return TPRT_INPUT_INVALID;
}
cce::tprt::TprtDevice *device = cce::tprt::TprtManage::Instance()->GetDeviceByDevId(devId);
if (device == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] is not found.", devId);
return TPRT_DEVICE_INVALID;
}
uint32_t error = device->TprtSqCqAlloc(sqInfo->reqId, cqInfo->reqId);
if (error != TPRT_SUCCESS) {
TPRT_LOG(TPRT_LOG_ERROR,
"Failed to alloc sq cq, device_id=%u, sq_id=%u, cq_id=%u.",
devId,
sqInfo->reqId,
cqInfo->reqId);
return error;
}
return TPRT_SUCCESS;
}
uint32_t TprtSqCqDestroy(const uint32_t devId, const TprtSqCqInputInfo *sqInfo, const TprtSqCqInputInfo *cqInfo)
{
if (sqInfo == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "sqInfo is null, device_id=%u.", devId);
return TPRT_INPUT_INVALID;
}
if (cqInfo == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "cqInfo is null, device_id=%u.", devId);
return TPRT_INPUT_INVALID;
}
if (sqInfo->inputType != TPRT_FREE_SQ_TYPE) {
TPRT_LOG(TPRT_LOG_ERROR, "Input type of sqInfo is wrong, device_id=%u, inputType=%u.", devId, sqInfo->inputType);
return TPRT_INPUT_INVALID;
}
if (cqInfo->inputType != TPRT_FREE_CQ_TYPE) {
TPRT_LOG(TPRT_LOG_ERROR, "Input type of cqInfo is wrong, device_id=%u, inputType=%u.", devId, cqInfo->inputType);
return TPRT_INPUT_INVALID;
}
cce::tprt::TprtDevice *device = cce::tprt::TprtManage::Instance()->GetDeviceByDevId(devId);
if (device == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] is not found.", devId);
return TPRT_DEVICE_INVALID;
}
uint32_t error = device->TprtSqCqDeAlloc(sqInfo->reqId, cqInfo->reqId);
return error;
}
uint32_t TprtSqPushTask(const uint32_t devId, const TprtTaskSendInfo_t *sendInfo)
{
const uint32_t depth = cce::tprt::TprtManage::Instance()->TprtGetSqMaxDepth();
if ((sendInfo == nullptr) || (sendInfo->sqeNum >= depth)) {
TPRT_LOG(TPRT_LOG_ERROR, "input is invalid device_id[%u].", devId);
return TPRT_INPUT_INVALID;
}
cce::tprt::TprtManage *manage = cce::tprt::TprtManage::Instance();
cce::tprt::TprtDevice *dev = manage->GetDeviceByDevId(devId);
if (dev == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] is invalid.", devId);
return TPRT_DEVICE_INVALID;
}
cce::tprt::TprtSqHandle *sqHandle = dev->TprtGetSqHandleBySqId(sendInfo->sqId);
if (sqHandle == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] sq_id[%u] is invalid.", devId, sendInfo->sqId);
return TPRT_SQ_HANDLE_INVALID;
}
cce::tprt::TprtWorker *worker = dev->TprtGetWorkHandleBySqHandle(sqHandle);
if (worker == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] sq_id[%u] worker is invalid.", devId, sendInfo->sqId);
return TPRT_WORKER_INVALID;
}
uint32_t error = sqHandle->SqPushTask(sendInfo->sqeAddr, sendInfo->sqeNum);
if (error != TPRT_SUCCESS) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] sq_id[%u] push task failed, error=%u.", devId,
sendInfo->sqId, error);
return error;
}
worker->WorkerWakeUp();
return error;
}
uint32_t TprtOpSqCqInfo(uint32_t devId, TprtSqCqOpInfo_t *opInfo)
{
if (opInfo == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "input is null device_id[%u].", devId);
return TPRT_INPUT_INVALID;
}
cce::tprt::TprtManage *manage = cce::tprt::TprtManage::Instance();
cce::tprt::TprtDevice *dev = manage->GetDeviceByDevId(devId);
if (dev == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] is invalid.", devId);
return TPRT_DEVICE_INVALID;
}
const uint32_t error = dev->TprtDevOpSqCqInfo(opInfo);
return error;
}
uint32_t TprtCqReportRecv(uint32_t devId, TprtReportCqeInfo_t *cqeInfo)
{
if ((cqeInfo == nullptr) || (cqeInfo->cqeNum == 0U) || (cqeInfo->cqeAddr == nullptr)) {
TPRT_LOG(TPRT_LOG_ERROR, "cqe info is invalid, device_id=%u.", devId);
return TPRT_INPUT_INVALID;
}
if (cqeInfo->type != TPRT_QUERY_CQ_INFO) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id=%u op type[%u] is invalid.", devId, cqeInfo->type);
return TPRT_INPUT_OP_TYPE_INVALID;
}
cce::tprt::TprtManage *manage = cce::tprt::TprtManage::Instance();
cce::tprt::TprtDevice *dev = manage->GetDeviceByDevId(devId);
if (dev == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] is invalid.", devId);
return TPRT_DEVICE_INVALID;
}
cce::tprt::TprtCqHandle *cqHandle = dev->TprtGetCqHandleByCqId(cqeInfo->cqId);
if (cqHandle == nullptr) {
TPRT_LOG(TPRT_LOG_ERROR, "device_id[%u] cq_id=%u cqhandle can not find.", devId, cqeInfo->cqId);
return TPRT_CQ_HANDLE_INVALID;
}
cqHandle->TprtCqHandleGetCqe(cqeInfo);
return TPRT_SUCCESS;
}
/**
 * @brief Forwards the enable flag to TprtManage::setTprtTaskReportEnable.
 * @param isEnable Value passed unchanged to the manager.
 * @return Always 0.
 */
uint32_t TprtProfilingEnable(bool isEnable)
{
    cce::tprt::TprtManage::Instance()->setTprtTaskReportEnable(isEnable);
    return 0;
}
#ifdef __cplusplus
}
#endif