* -------------------------------------------------------------------------
* This file is part of the Vision SDK project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* Vision SDK is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
* Description: Manage Acl Stream.
* Author: MindX SDK
* Create: 2023
* History: NA
*/
#include <unistd.h>
#include <algorithm>
#include <exception>
#include <list>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <thread>
#include <unordered_map>
#include "acl/acl.h"
#include "acl/dvpp/hi_dvpp.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/E2eInfer/Tensor/Tensor.h"
#include "ResourceManager/StreamTensorManager/StreamTensorManager.h"
#include "MxBase/Asynchron/AscendStream.h"
namespace MxBase {
// Argument handed to usleep() between polls in CreateAscendStream (microseconds, as usleep expects).
constexpr int SLEEP_TIME{500};
// Upper bound on the number of polls CreateAscendStream performs before reporting a timeout.
constexpr size_t WAIT_TIMES{100};
// Argument passed to aclrtProcessReport() by the listener thread.
// NOTE(review): presumably a timeout in milliseconds - confirm against the ACL reference.
constexpr int CALLBACK_WAIT_TIMES{100};
// Sentinel that RecycleTensorCallback compares against MxBaseUserData::magicNumber.
constexpr char MAGIC_NUMBER{'A'};
// File-local mutex taken inside AscendStream::DefaultStream().
static std::mutex g_ascendStreamMtx;
struct MxBaseUserData {
Tensor lastTensor;
aclrtStream stream;
char magicNumber = 'A';
char reserved[15];
};
/**
 * Stream callback that removes a tensor from the StreamTensorManager and frees its payload.
 *
 * @param userData Pointer to a heap-allocated MxBaseUserData (created in AscendStream::LaunchCallBack);
 *                 it is deleted here, so the caller must not touch it afterwards.
 *
 * A mismatching magic number is only logged; processing continues regardless.
 */
void RecycleTensorCallback(void *userData)
{
    auto *recycleCtx = static_cast<MxBaseUserData*>(userData);
    const bool magicIntact = (recycleCtx->magicNumber == MAGIC_NUMBER);
    if (!magicIntact) {
        LogError << "Memory corruption detected in MxBaseUserData: magicNumber="
                 << recycleCtx->magicNumber << ". Expected: " << MAGIC_NUMBER
                 << "." << GetErrorInfo(APP_ERR_COMM_FAILURE);
    }
    const auto deleteRet =
        StreamTensorManager::GetInstance()->DeleteTensor(recycleCtx->stream, recycleCtx->lastTensor);
    if (deleteRet != APP_ERR_OK) {
        LogError << "Fail to delete tensor in MxBaseCallback." << GetErrorInfo(APP_ERR_COMM_FAILURE);
    }
    delete recycleCtx;
}
/**
 * Convenience constructor: builds a stream on the given device with FlagType::LAUNCH_SYNC
 * by delegating to the two-argument constructor.
 */
AscendStream::AscendStream(int32_t deviceId)
    : AscendStream(deviceId, FlagType::LAUNCH_SYNC) {}
/**
 * Creates an ACL stream on the given device and registers it with the StreamTensorManager.
 *
 * @param deviceId Device the stream is bound to; validated through DeviceManager::CheckDeviceId.
 * @param flag     Stream creation flag; must lie within [FlagType::DEFAULT, FlagType::LAUNCH_SYNC].
 * @throws std::runtime_error if validation, device selection, stream creation or
 *         registration with the StreamTensorManager fails.
 */
AscendStream::AscendStream(int32_t deviceId, AscendStream::FlagType flag)
{
    deviceId_ = deviceId;
    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->CheckDeviceId(deviceId);
    if (ret != APP_ERR_OK) {
        LogError << "Device id is out of range, current deviceId is " << deviceId << "." << GetErrorInfo(ret);
        throw std::runtime_error(GetErrorInfo(APP_ERR_COMM_INIT_FAIL));
    }
    if (flag < AscendStream::FlagType::DEFAULT || flag > AscendStream::FlagType::LAUNCH_SYNC) {
        // ret still holds APP_ERR_OK at this point, so report the real failure code instead.
        LogError << "Flag must be in FlagType, failed to create ascend stream."
                 << GetErrorInfo(APP_ERR_COMM_INIT_FAIL);
        throw std::runtime_error(GetErrorInfo(APP_ERR_COMM_INIT_FAIL));
    }
    DeviceContext device = {};
    device.devId = deviceId;
    ret = MxBase::DeviceManager::GetInstance()->SetDevice(device);
    if (ret != APP_ERR_OK) {
        LogError << "Set current context failed." << GetErrorInfo(ret);
        throw std::runtime_error(GetErrorInfo(APP_ERR_COMM_INIT_FAIL));
    }
    LogDebug << "SetDevice ret is " << ret << ", deviceId is " << deviceId;
    // Allocate the callback state before the stream exists so an allocation failure cannot leak the stream.
    param_ = std::make_shared<CallbackParam>();
    ret = aclrtCreateStreamWithConfig(&stream, 0, flag);
    if (ret != APP_ERR_OK) {
        LogError << "Failed to create ascend stream." << GetErrorInfo(ret, "aclrtCreateStreamWithConfig");
        throw std::runtime_error(GetErrorInfo(APP_ERR_COMM_INIT_FAIL));
    }
    ret = StreamTensorManager::GetInstance()->AddStream(stream);
    if (ret != APP_ERR_OK) {
        LogError << "Failed to add ascend stream for StreamTensorManager." << GetErrorInfo(APP_ERR_COMM_INIT_FAIL);
        // The destructor never runs for an object whose constructor throws, so the
        // freshly created stream has to be released here to avoid leaking it.
        const APP_ERROR destroyRet = aclrtDestroyStream(stream);
        if (destroyRet != APP_ERR_OK) {
            LogError << "AclrtDestroyStream execution failed." << GetErrorInfo(destroyRet, "aclrtDestroyStream");
        }
        stream = nullptr;
        throw std::runtime_error(GetErrorInfo(APP_ERR_COMM_INIT_FAIL));
    }
}
/**
 * Stops the report-listening thread if it is still running and this object holds the
 * only reference to it. The ACL stream itself is not destroyed here; that is done by
 * DestroyAscendStream().
 */
AscendStream::~AscendStream()
{
    // Nothing to tear down unless the listener is active and solely owned by this instance.
    const bool ownsRunningListener = param_->isSuccessSetContext && td_.use_count() == 1;
    if (!ownsRunningListener) {
        return;
    }
    const APP_ERROR unsubscribeRet = aclrtUnSubscribeReport(tid, stream);
    // Ask the listener loop to leave regardless of whether unsubscribing succeeded.
    param_->isExit = true;
    if (unsubscribeRet != APP_ERR_OK) {
        LogError << "AclrtUnSubscribeReport is failed." << GetErrorInfo(unsubscribeRet, "aclrtUnSubscribeReport");
    }
    if (td_->joinable()) {
        td_->join();
    }
    td_.reset();
    param_->isSuccessSetContext = false;
}
/**
 * Copies the current channel id into the caller's variable.
 *
 * @param channelId Output pointer; must not be null.
 * @return APP_ERR_OK on success, APP_ERR_COMM_FAILURE when channelId is null.
 */
APP_ERROR AscendStream::GetChannel(int* channelId)
{
    if (channelId != nullptr) {
        *channelId = chnId_;
        return APP_ERR_OK;
    }
    LogError << "Failed to get AscendStream channel." << GetErrorInfo(APP_ERR_COMM_FAILURE);
    return APP_ERR_COMM_FAILURE;
}
int32_t AscendStream::GetDeviceId() const
{
return deviceId_;
}
/**
 * Creates the VPC channel for this stream if none exists yet (chnId_ == -1).
 *
 * Only available when DeviceManager reports a 310P, 310B or Atlas800IA2 device.
 *
 * @return APP_ERR_OK when a channel exists after the call,
 *         APP_ERR_COMM_INIT_FAIL for an unsupported device, a SetDevice failure or a
 *         hi_mpi_sys_init failure, APP_ERR_ACL_FAILURE when channel creation fails.
 */
APP_ERROR AscendStream::CreateChannel()
{
    if (!DeviceManager::IsAscend310P() && !DeviceManager::IsAscend310B() && !DeviceManager::IsAtlas800IA2()) {
        LogError << "CreateChannel() is supported on device 310P/310B/Atlas800IA2 now, current device is "
                 << DeviceManager::GetSocName() << ".";
        return APP_ERR_COMM_INIT_FAIL;
    }
    DeviceContext device = {};
    device.devId = deviceId_;
    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->SetDevice(device);
    if (ret != APP_ERR_OK) {
        LogError << "SetDevice failed when create channel." << GetErrorInfo(ret);
        return APP_ERR_COMM_INIT_FAIL;
    }
    if (chnId_ == -1) {
        ret = hi_mpi_sys_init();
        if (ret != APP_ERR_OK) {
            LogError << "Failed to initialize dvpp system." << GetErrorInfo(ret, "hi_mpi_sys_init");
            return APP_ERR_COMM_INIT_FAIL;
        }
        // Zero the whole attribute struct so no indeterminate field is handed to the driver.
        hi_vpc_chn_attr stChnAttr = {};
        hi_vpc_chn channelId = 0;
        stChnAttr.attr = 0;
        ret = hi_mpi_vpc_sys_create_chn(&channelId, &stChnAttr);
        if (ret != APP_ERR_OK) {
            LogError << "Failed to create vpc hi_mpi_vpc_sys." << GetErrorInfo(ret, "hi_mpi_vpc_sys_create_chn");
            // DestroyAscendStream only calls hi_mpi_sys_exit when a channel exists, so the
            // successful hi_mpi_sys_init above must be balanced here.
            const APP_ERROR exitRet = hi_mpi_sys_exit();
            if (exitRet != APP_ERR_OK) {
                LogError << "Failed to exit dvpp system." << GetErrorInfo(exitRet, "hi_mpi_sys_exit");
            }
            return APP_ERR_ACL_FAILURE;
        }
        chnId_ = channelId;
        LogInfo << "Create vpc channel success. channel id is " << chnId_ << ".";
    }
    return ret;
}
/**
 * Starts the listener thread running ProcessCallback and records its id in `tid`.
 *
 * @return APP_ERR_OK on success,
 *         APP_ERR_COMM_ALLOC_MEM when the thread object cannot be created,
 *         APP_ERR_COMM_FAILURE when naming the thread fails (the thread is stopped
 *         and td_ is cleared again in that case).
 */
APP_ERROR AscendStream::DoProcessCallback()
{
    param_->isExit = false;
    param_->isSuccessSetContext = false;
    param_->deviceId = deviceId_;
    // make_shared / std::thread signal failure by throwing, never by yielding a null
    // pointer, so the failure has to be caught to be turned into an error code.
    try {
        td_ = std::make_shared<std::thread>(ProcessCallback, param_.get());
    } catch (const std::exception &) {
        LogError << "Create td ptr failed." << GetErrorInfo(APP_ERR_COMM_ALLOC_MEM);
        return APP_ERR_COMM_ALLOC_MEM;
    }
    int ret = pthread_setname_np(td_->native_handle(), "mx_stream");
    if (ret != 0) {
        LogError << "Failed to set stream listen thread name." << GetErrorInfo(APP_ERR_COMM_FAILURE);
        // Do not leave a running thread behind on the error path: ask the loop to
        // exit, wait for it and drop the handle so a later retry starts clean.
        param_->isExit = true;
        if (td_->joinable()) {
            td_->join();
        }
        td_.reset();
        return APP_ERR_COMM_FAILURE;
    }
    // The id later given to aclrtSubscribeReport is the numeric text form of the thread id.
    std::ostringstream oss;
    oss << td_->get_id();
    tid = std::stoull(oss.str());
    return APP_ERR_OK;
}
/**
 * Starts the listener thread, subscribes it to this stream's reports and waits until
 * the thread has set its device context.
 *
 * @return APP_ERR_OK once the listener reports a successful context setup,
 *         APP_ERR_COMM_REPEAT_INITIALIZE if a listener already exists,
 *         the DoProcessCallback error if the thread could not be started,
 *         APP_ERR_ACL_FAILURE if aclrtSubscribeReport fails,
 *         APP_ERR_COMM_TIMEOUT if the listener is not ready after WAIT_TIMES polls.
 */
APP_ERROR AscendStream::CreateAscendStream()
{
    if (td_ != nullptr) {
        LogError << "Do not call the CreateAscendStream interface repeatedly."
                 << GetErrorInfo(APP_ERR_COMM_REPEAT_INITIALIZE);
        return APP_ERR_COMM_REPEAT_INITIALIZE;
    }
    APP_ERROR ret = DoProcessCallback();
    if (ret != APP_ERR_OK) {
        LogError << "Do ProcessCallback failed." << GetErrorInfo(ret);
        // A failure after the thread was spawned must not leave td_ set, otherwise every
        // later call would be rejected as a repeated initialization.
        if (td_ != nullptr) {
            param_->isExit = true;
            if (td_->joinable()) {
                td_->join();
            }
            td_.reset();
        }
        return ret;
    }
    ret = aclrtSubscribeReport(static_cast<uint64_t>(tid), stream);
    if (ret != APP_ERR_OK) {
        LogError << "AclrtSubscribeReport failed." << GetErrorInfo(ret, "aclrtSubscribeReport");
        param_->isExit = true;
        td_->join();
        td_.reset();
        return APP_ERR_ACL_FAILURE;
    }
    // Poll until the listener thread signals that its device context is in place.
    for (size_t i = 0; i < WAIT_TIMES; i++) {
        if (param_->isSuccessSetContext) {
            return ret;
        }
        usleep(SLEEP_TIME);
    }
    // Timed out: stop the listener and undo the subscription made above so the stream
    // is not left bound to a thread that no longer exists.
    param_->isExit = true;
    td_->join();
    ret = aclrtUnSubscribeReport(tid, stream);
    if (ret != APP_ERR_OK) {
        LogError << "AclrtUnSubscribeReport is failed." << GetErrorInfo(ret, "aclrtUnSubscribeReport");
    }
    td_.reset();
    return APP_ERR_COMM_TIMEOUT;
}
void AscendStream::ProcessCallback(void *arg)
{
CallbackParam *param = static_cast<CallbackParam*>(arg);
DeviceContext device = {};
device.devId = param->deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->SetDevice(device);
if (ret != APP_ERR_OK) {
LogError << "Acl set current context failed." << GetErrorInfo(ret);
return;
} else {
param->isSuccessSetContext = true;
}
while (true) {
(void) aclrtProcessReport(CALLBACK_WAIT_TIMES);
if (param->isExit) {
param->isSuccessSetContext = false;
return;
}
}
}
/**
 * Returns the process-wide default stream, creating it on first use.
 *
 * The device is taken from DeviceManager::GetCurrentDevice on first use (device 0 is
 * the fallback when that query fails). If the default stream's handle has become null,
 * a new stream is created on the same device.
 *
 * The whole lookup runs under g_ascendStreamMtx: the original read isFirstInit, device
 * and defaultStream.stream outside the lock that guarded their writes, so concurrent
 * callers could race or recreate the stream twice.
 *
 * @return Reference to the shared default AscendStream.
 * @throws std::runtime_error propagated from the AscendStream constructor.
 */
AscendStream &AscendStream::DefaultStream()
{
    std::lock_guard<std::mutex> lock(g_ascendStreamMtx);
    static bool isFirstInit = true;
    static DeviceContext device = {};
    if (isFirstInit) {
        device.devId = 0;
        APP_ERROR ret = MxBase::DeviceManager::GetInstance()->GetCurrentDevice(device);
        if (ret != APP_ERR_OK) {
            LogDebug << "Fail to get current device." << GetErrorInfo(ret);
        }
    }
    // If this constructor throws, isFirstInit stays true and the next call retries from scratch.
    static AscendStream defaultStream(device.devId);
    if (isFirstInit) {
        defaultStream.isDefault_ = true;
        isFirstInit = false;
    }
    if (defaultStream.stream == nullptr) {
        defaultStream = AscendStream(device.devId);
        defaultStream.isDefault_ = true;
        if (defaultStream.stream == nullptr) {
            LogError << "Failed to recreate default ascend stream." << GetErrorInfo(APP_ERR_COMM_INIT_FAIL);
        }
    }
    return defaultStream;
}
// Callback enqueued by DestroyAscendStream right before the listener thread is joined.
// It only logs; its argument is ignored.
static void FinishedProcess(void * /* unused */)
{
    LogInfo << "Ready to destroy stream.";
}
/**
 * Tears the stream down: synchronizes it, stops the listener thread, releases the VPC
 * channel (if one was created), unregisters the stream from the StreamTensorManager and
 * destroys the ACL stream.
 *
 * Every step is best-effort: failures are logged and the remaining steps still run.
 *
 * @return APP_ERR_ACL_FAILURE when aclrtDestroyStream fails; otherwise the result of
 *         the last ACL/DVPP/device call that was executed.
 */
APP_ERROR AscendStream::DestroyAscendStream()
{
    DeviceContext device = {};
    device.devId = deviceId_;
    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->SetDevice(device);
    if (ret != APP_ERR_OK) {
        LogError << "SetDevice failed when destroy stream." << GetErrorInfo(ret);
    }
    if (Synchronize() != APP_ERR_OK) {
        LogError << "User stream synchronize failed when destroy stream." << GetErrorInfo(APP_ERR_COMM_FAILURE);
    }
    if (param_->isSuccessSetContext) {
        param_->isExit = true;
        // Enqueue one last callback, then synchronize, before joining the listener thread.
        ret = aclrtLaunchCallback(FinishedProcess, nullptr, ACL_CALLBACK_NO_BLOCK, stream);
        if (ret != APP_ERR_OK) {
            LogError << "Execute launch callback failed in finish process." << GetErrorInfo(ret, "aclrtLaunchCallback");
        }
        // Report the synchronize result itself; ret still holds the aclrtLaunchCallback result.
        const APP_ERROR syncRet = Synchronize();
        if (syncRet != APP_ERR_OK) {
            LogError << "Synchronize failed in finish process." << GetErrorInfo(syncRet);
        }
        // Same guard as the destructor: joining a missing or non-joinable thread would throw.
        if (td_ != nullptr && td_->joinable()) {
            td_->join();
        }
        ret = aclrtUnSubscribeReport(tid, stream);
        if (ret != APP_ERR_OK) {
            LogError << "AclrtUnSubscribeReport is failed." << GetErrorInfo(ret, "aclrtUnSubscribeReport");
        }
        td_.reset();
    }
    if (chnId_ != -1) {
        LogInfo << "DeInit with mode [DVPP_CHNMODE_VPC]";
        ret = hi_mpi_vpc_destroy_chn(chnId_);
        if (ret != APP_ERR_OK) {
            LogError << "Failed to destroy Vpc channel." << GetErrorInfo(ret, "hi_mpi_vpc_destroy_chn");
        }
        ret = hi_mpi_sys_exit();
        if (ret != APP_ERR_OK) {
            LogError << "Failed to exit dvpp system." << GetErrorInfo(ret, "hi_mpi_sys_exit");
        }
        chnId_ = -1;
    }
    if (StreamTensorManager::GetInstance()->DeleteStream(stream) != APP_ERR_OK) {
        LogError << "Fail to delete stream resource in StreamTensorManager" << GetErrorInfo(APP_ERR_COMM_FAILURE);
    }
    if (stream != nullptr) {
        ret = aclrtDestroyStream(stream);
        // Clear the handle even on failure so it is never destroyed twice.
        stream = nullptr;
        if (ret != APP_ERR_OK) {
            LogError << "AclrtDestroyStream execution failed." << GetErrorInfo(ret, "aclrtDestroyStream");
            return APP_ERR_ACL_FAILURE;
        }
    }
    return ret;
}
/**
 * Waits for the stream through aclrtSynchronizeStream, then asks the
 * StreamTensorManager to clear the tensors recorded for this stream.
 *
 * @return APP_ERR_ACL_FAILURE if the ACL synchronize call fails; otherwise the result
 *         of ClearTensorsByStream (a failure there is logged and returned).
 */
APP_ERROR AscendStream::Synchronize() const
{
    const APP_ERROR aclRet = aclrtSynchronizeStream(stream);
    if (aclRet != APP_ERR_OK) {
        LogError << "Synchronize stream execution failed." << GetErrorInfo(aclRet, "aclrtSynchronizeStream");
        return APP_ERR_ACL_FAILURE;
    }
    const APP_ERROR clearRet = StreamTensorManager::GetInstance()->ClearTensorsByStream(stream);
    if (clearRet != APP_ERR_OK) {
        LogError << "Fail to clean tensors of stream in synchronize" << GetErrorInfo(clearRet);
    }
    return clearRet;
}
/**
 * Records an error code on a non-default stream.
 *
 * While the first slot still holds APP_ERR_OK the code goes there; afterwards every
 * call overwrites the second slot. Calls on the default stream are ignored.
 *
 * @param errCode Code to record.
 */
void AscendStream::SetErrorCode(APP_ERROR errCode)
{
    if (isDefault_) {
        return;
    }
    const bool firstSlotFree = (errCodeLogger.first == APP_ERR_OK);
    if (firstSlotFree) {
        errCodeLogger.first = errCode;
        return;
    }
    errCodeLogger.second = errCode;
}
/**
 * @return A copy of the error-code pair maintained by SetErrorCode
 *         (first recorded code, most recent later code).
 */
std::pair<APP_ERROR, APP_ERROR> AscendStream::GetErrorCode()
{
    std::pair<APP_ERROR, APP_ERROR> recorded = errCodeLogger;
    return recorded;
}
/**
 * Enqueues a user callback on the stream, followed by an internal callback that
 * removes the stream's last recorded tensor from the StreamTensorManager.
 *
 * @param fn       User callback; must not be null.
 * @param userData Opaque pointer forwarded to fn.
 * @return APP_ERR_OK on success,
 *         APP_ERR_COMM_INVALID_POINTER if fn or the stream handle is null,
 *         APP_ERR_ACL_FAILURE if either aclrtLaunchCallback call fails,
 *         APP_ERR_COMM_FAILURE if the last tensor cannot be fetched.
 *
 * Note: when a failure is reported after the first launch, the user callback has
 * already been enqueued.
 */
APP_ERROR AscendStream::LaunchCallBack(aclrtCallback fn, void* userData)
{
    if (fn == nullptr) {
        LogError << "Function is nullptr."<< GetErrorInfo(APP_ERR_COMM_INVALID_POINTER);
        return APP_ERR_COMM_INVALID_POINTER;
    }
    if (stream == nullptr) {
        LogError << "AclrtStream is nullptr." << GetErrorInfo(APP_ERR_COMM_INVALID_POINTER);
        return APP_ERR_COMM_INVALID_POINTER;
    }
    APP_ERROR ret = aclrtLaunchCallback(fn, userData, ACL_CALLBACK_BLOCK, stream);
    if (ret != APP_ERR_OK) {
        LogError << "Execute aclrtLaunchCallback failed for user callback" << GetErrorInfo(ret, "aclrtLaunchCallback");
        return APP_ERR_ACL_FAILURE;
    }
    // Owned here until the recycle callback is enqueued, so every early exit
    // (including an exception) frees the payload automatically.
    std::unique_ptr<MxBaseUserData> recycleData(new MxBaseUserData());
    recycleData->stream = stream;
    ret = StreamTensorManager::GetInstance()->GetStreamTensorListLastTensor(stream, recycleData->lastTensor);
    if (ret != APP_ERR_OK) {
        LogError << "Failed to GetStreamTensorListLastTensor." << GetErrorInfo(ret);
        return APP_ERR_COMM_FAILURE;
    }
    ret = aclrtLaunchCallback(RecycleTensorCallback, recycleData.get(), ACL_CALLBACK_BLOCK, stream);
    if (ret != APP_ERR_OK) {
        LogError << "Execute aclrtLaunchCallback failed for recycle callback."
                 << GetErrorInfo(ret, "aclrtLaunchCallback");
        return APP_ERR_ACL_FAILURE;
    }
    // RecycleTensorCallback deletes the payload from now on; give up ownership.
    (void) recycleData.release();
    return ret;
}
/**
 * Registers a tensor with the StreamTensorManager under this stream.
 *
 * Registration is best-effort: a failure is logged as a warning and does not change
 * the return value.
 *
 * @param inputTensor Tensor to register.
 * @return Always APP_ERR_OK.
 */
APP_ERROR AscendStream::AddTensorRefPtr(const Tensor& inputTensor)
{
    const APP_ERROR addRet = StreamTensorManager::GetInstance()->AddTensor(stream, inputTensor);
    if (addRet == APP_ERR_OK) {
        return APP_ERR_OK;
    }
    LogWarn << "Fail to add tensor in StreamTensorManager." << GetErrorInfo(addRet);
    return APP_ERR_OK;
}
}