* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "entity_manager.h"
#include <algorithm>
#include "bqs_util.h"
#include "client_entity.h"
#include "common/bqs_log.h"
#include "driver/ascend_hal_external.h"
#include "group_entity.h"
#include "queue_manager.h"
#include "schedule_config.h"
#include "simple_entity.h"
#include "state_manager.h"
namespace dgw {
namespace {
// Shift applied to a 32-bit schedule config key so that it occupies the upper half of the
// 64-bit composite key; the 32-bit global id is OR-ed into the lower half.
const uint32_t SHIFT32 = 32U;
// Step by which the per-channel-set request/completion buffers grow their reserved capacity
// when the number of comm channels exceeds the current capacity.
constexpr size_t DEFAULT_CHANNEL_CAPACITY = 32UL;
}
/**
 * Destroys the read-write locks guarding the source and destination comm channel sets.
 * The return codes of pthread_rwlock_destroy are deliberately ignored.
 */
EntityManager::~EntityManager()
{
    pthread_rwlock_t *const channelLocks[] = {&srcCommChannels_.lock, &dstCommChannels_.lock};
    for (pthread_rwlock_t *const channelLock : channelLocks) {
        (void)pthread_rwlock_destroy(channelLock);
    }
}
/**
 * Returns the manager singleton for the given resource index.
 * Index 0 maps to the primary instance; every non-zero index maps to one shared extra instance,
 * which is constructed with the resIndex supplied by the first non-zero call.
 */
EntityManager &EntityManager::Instance(const uint32_t resIndex)
{
    if (resIndex == 0U) {
        static EntityManager primaryManager(resIndex);
        return primaryManager;
    }
    static EntityManager extraManager(resIndex);
    return extraManager;
}
/**
 * Looks up an entity by walking the nested index deviceId -> queueType -> entity type -> id and
 * then selecting by direction. No logging is done here.
 * @return the matching entity, or nullptr when any level of the index has no entry or no stored
 *         entity has the requested direction. If several stored entities share the direction,
 *         the last one in the list is returned.
 */
EntityPtr EntityManager::DoGetEntity(const uint32_t queueType, const uint32_t deviceId, const EntityType eType,
    const uint32_t id, const EntityDirection direction)
{
    const auto deviceIt = idToEntity_.find(deviceId);
    if (deviceIt == idToEntity_.end()) {
        return nullptr;
    }
    const auto &byQueueType = deviceIt->second;

    const auto queueTypeIt = byQueueType.find(queueType);
    if (queueTypeIt == byQueueType.end()) {
        return nullptr;
    }
    const auto &byEntityType = queueTypeIt->second;

    const auto entityTypeIt = byEntityType.find(eType);
    if (entityTypeIt == byEntityType.end()) {
        return nullptr;
    }
    const auto &byId = entityTypeIt->second;

    const auto idIt = byId.find(id);
    if (idIt == byId.end()) {
        return nullptr;
    }
    const auto &candidates = idIt->second;

    // Scan the whole list without breaking so that the last match wins.
    EntityPtr matched = nullptr;
    for (auto candidateIt = candidates.begin(); candidateIt != candidates.end(); ++candidateIt) {
        if ((*candidateIt)->GetDirection() == direction) {
            matched = *candidateIt;
        }
    }
    return matched;
}
/**
 * Logging wrapper around DoGetEntity.
 * Emits a debug log describing the entity when it is found and a warning when it is not.
 * @return the entity, or nullptr when it does not exist.
 */
EntityPtr EntityManager::GetEntityById(const uint32_t queueType, const uint32_t deviceId, const EntityType eType,
    const uint32_t id, const EntityDirection direction)
{
    EntityPtr found = DoGetEntity(queueType, deviceId, eType, id, direction);
    if (found == nullptr) {
        DGW_LOG_WARN("Type:[%s] id:[%u], device id:[%u], queueType:[%u], direction[%d] does not exist",
            StateManager::Instance().GetTypeDesc(eType).c_str(), id, deviceId,
            queueType, static_cast<int32_t>(direction));
        return found;
    }
    DGW_LOG_DEBUG("Direction:[%s] type:[%s] id:[%u] state:[%s] queue id:[%u] "
        "device id:[%u] owner deviceid:[%u] is found",
        GetDirectionDesc(found->GetDirection()), found->GetTypeDesc().c_str(), id,
        found->GetStateDesc(found->GetCurState()).c_str(),
        found->GetQueueId(), deviceId, found->GetDeviceId());
    return found;
}
/**
 * Finds a source-side entity by its composite key.
 * The 64-bit lookup key holds `key` in the upper 32 bits and `globalId` in the lower 32 bits.
 * @return the registered entity, or nullptr when the key is unknown.
 */
EntityPtr EntityManager::GetSrcEntityByGlobalId(const uint32_t key, const uint32_t globalId) const
{
    const uint64_t compositeKey = (static_cast<uint64_t>(key) << SHIFT32) | static_cast<uint64_t>(globalId);
    const auto found = globalIdToSrcEntity_.find(compositeKey);
    if (found == globalIdToSrcEntity_.end()) {
        return nullptr;
    }
    return found->second;
}
/**
 * Finds a destination-side entity by its composite key.
 * The 64-bit lookup key holds `key` in the upper 32 bits and `globalId` in the lower 32 bits.
 * @return the registered entity, or nullptr when the key is unknown.
 */
EntityPtr EntityManager::GetDstEntityByGlobalId(const uint32_t key, const uint32_t globalId) const
{
    const uint64_t compositeKey = (static_cast<uint64_t>(key) << SHIFT32) | static_cast<uint64_t>(globalId);
    const auto found = globalIdToDstEntity_.find(compositeKey);
    if (found == globalIdToDstEntity_.end()) {
        return nullptr;
    }
    return found->second;
}
const std::vector<EntityPtr> &EntityManager::GetEntitiesInGroup(const uint32_t groupId)
{
static const std::vector<EntityPtr> emptyVec;
const auto iter = groupEntityMap_.find(groupId);
if (iter != groupEntityMap_.end()) {
return iter->second;
}
return emptyVec;
}
/**
 * Registers a copy of `entities` under `groupId`.
 * An already registered group id is left untouched and only produces a warning.
 * @return always FSM_SUCCESS.
 */
FsmStatus EntityManager::CreateGroup(const uint32_t groupId, std::vector<EntityPtr>& entities)
{
    if (groupEntityMap_.find(groupId) == groupEntityMap_.end()) {
        (void) groupEntityMap_.insert(std::make_pair(groupId, entities));
    } else {
        DGW_LOG_WARN("groupId[%u] has exist.", groupId);
    }
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Removes the group registered under `groupId`.
 * An unknown group id only produces a warning.
 * @return always FSM_SUCCESS.
 */
FsmStatus EntityManager::DeleteGroup(const uint32_t groupId)
{
    const auto groupIt = groupEntityMap_.find(groupId);
    if (groupIt != groupEntityMap_.end()) {
        (void) groupEntityMap_.erase(groupIt);
    } else {
        DGW_LOG_WARN("groupId[%u] does not exist.", groupId);
    }
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Creates an entity described by `material`, or reuses it when it already exists.
 *
 * An existing entity (same queue type, resource id, entity type, id and direction) only has its
 * reference count increased. Otherwise a new entity is allocated, initialised, reference counted,
 * and registered in the id index and in the per-direction global-id index. Entities of type
 * ENTITY_TAG are additionally stored as comm channels.
 *
 * @return the existing or newly created entity, or nullptr when allocation or Init fails.
 */
EntityPtr EntityManager::CreateEntity(const EntityMaterial &material)
{
    const bool isRecv = (material.direction == EntityDirection::DIRECTION_RECV);

    auto entity = GetEntityById(material.queueType, material.resId, material.eType, material.id, material.direction);
    if (entity != nullptr) {
        DGW_LOG_INFO("Type:[%s] id:[%u], direction:[%d] is existed", entity->GetTypeDesc().c_str(), material.id,
            static_cast<int32_t>(material.direction));
        entity->IncreaseRefCount();
        return entity;
    }

    entity = AllocEntity(material);
    if (entity == nullptr) {
        DGW_LOG_ERROR("Create [%s] entity failed type:[%s] id:[%u].", GetDirectionDesc(material.direction),
            StateManager::Instance().GetTypeDesc(material.eType).c_str(), material.id);
        return nullptr;
    }

    // Receiving entities start in the wait-push state, all others start idle.
    FsmState initialState = isRecv ? FsmState::FSM_WAIT_PUSH_STATE : FsmState::FSM_IDLE_STATE;
    if (entity->Init(initialState, material.direction) != FsmStatus::FSM_SUCCESS) {
        DGW_LOG_ERROR("Init entity failed, direction:[%s].", entity->ToString().c_str());
        return nullptr;
    }

    entity->IncreaseRefCount();
    idToEntity_[material.resId][material.queueType][material.eType][material.id].emplace_back(entity);

    // Composite key: schedule config key in the upper 32 bits, global id in the lower 32 bits.
    const uint64_t compositeKey =
        (static_cast<uint64_t>(material.schedCfgKey) << SHIFT32) | static_cast<uint64_t>(material.globalId);
    if (isRecv) {
        DGW_LOG_INFO("add globalIdToDstEntity_ for %u:%u", material.schedCfgKey, material.globalId);
        globalIdToDstEntity_[compositeKey] = entity;
    } else {
        DGW_LOG_INFO("add globalIdToSrcEntity_ for %u:%u", material.schedCfgKey, material.globalId);
        globalIdToSrcEntity_[compositeKey] = entity;
    }

    DGW_LOG_INFO("Create entity[%s] success, state:[%s]",
        entity->ToString().c_str(), entity->GetStateDesc(entity->GetCurState()).c_str());
    bqs::StatisticManager::GetInstance().SetExistEntityFlag(true);

    if (material.eType == dgw::EntityType::ENTITY_TAG) {
        const bool isSrcChannel = (material.direction == EntityDirection::DIRECTION_SEND);
        const ChannelEntityPtr channelEntity = std::dynamic_pointer_cast<ChannelEntity>(entity);
        (void)InsertCommChannel(channelEntity, isSrcChannel);
    }
    return entity;
}
/**
 * Drops one reference to the entity and tears it down once the reference count reaches zero.
 *
 * Teardown removes the entity from the id index (pruning empty index levels), from the
 * per-direction global-id index, from the dynamic schedule manager (queue entities only, and only
 * when schedule keys are configured), and from the comm channel set (ENTITY_TAG only), and finally
 * calls Uninit on it.
 *
 * @return always FSM_SUCCESS, including when the entity cannot be found.
 */
FsmStatus EntityManager::DeleteEntity(const uint32_t queueType, const uint32_t deviceId, const EntityType eType,
    const uint32_t id, const EntityDirection direction)
{
    EntityPtr entity = DoGetEntity(queueType, deviceId, eType, id, direction);
    if (entity == nullptr) {
        DGW_LOG_WARN("Failed to find entity for queueType[%u], deviceId[%u], type[%s], id:[%u], direction:[%d]",
            queueType, deviceId, StateManager::Instance().GetTypeDesc(eType).c_str(), id,
            static_cast<int32_t>(direction));
        return FsmStatus::FSM_SUCCESS;
    }

    entity->DecreaseRefCount();
    if (entity->GetRefCount() != 0U) {
        // Still referenced elsewhere: keep it registered.
        return FsmStatus::FSM_SUCCESS;
    }

    // Remove the first stored entity with the requested direction from the id index.
    auto &sameIdEntities = idToEntity_[deviceId][queueType][eType][id];
    const auto victim = std::find_if(sameIdEntities.begin(), sameIdEntities.end(),
        [direction](const EntityPtr &candidate) { return candidate->GetDirection() == direction; });
    if (victim != sameIdEntities.end()) {
        DGW_LOG_INFO("Delete entity[%s] from entityMap", entity->ToString().c_str());
        sameIdEntities.erase(victim);
    }
    CleanEntityMap(queueType, deviceId, eType, id);

    const uint64_t compositeKey =
        (static_cast<uint64_t>(entity->GetSchedCfgKey()) << SHIFT32) | static_cast<uint64_t>(entity->GetGlobalId());
    if (direction == EntityDirection::DIRECTION_RECV) {
        globalIdToDstEntity_.erase(compositeKey);
    } else {
        globalIdToSrcEntity_.erase(compositeKey);
    }
    DGW_LOG_INFO("Success to delete entity[%s], direction:[%s].",
        entity->ToString().c_str(), GetDirectionDesc(direction));

    const bool hasSchedKeys = !dgw::ScheduleConfig::GetInstance().GetSchedKeys().empty();
    if (hasSchedKeys && (eType == dgw::EntityType::ENTITY_QUEUE)) {
        dgw::DynamicSchedMgr::GetInstance(resIndex_).DeleteQueue(entity->GetGlobalId(), entity->GetSchedCfgKey());
    }

    if (eType == dgw::EntityType::ENTITY_TAG) {
        const bool isSrcChannel = (direction == EntityDirection::DIRECTION_SEND);
        (void)EraseCommChannel(entity, isSrcChannel);
    }

    DGW_LOG_INFO("Finish to delete entity[%s], direction:[%s].",
        entity->ToString().c_str(), GetDirectionDesc(direction));
    (void)entity->Uninit();
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Prunes empty levels of the nested id index, bottom-up.
 *
 * Starting from the entity list stored for `id`, each level that has become empty is erased from
 * its parent (id, then entity type, then queue type, then device id). Pruning stops at the first
 * level that still has content. When the whole index ends up empty, the statistic manager's
 * exist-entity flag is cleared.
 */
void EntityManager::CleanEntityMap(const uint32_t queueType, const uint32_t deviceId, const EntityType eType,
    const uint32_t id)
{
    // operator[] is used on purpose at every level, exactly like a chained lookup would do.
    auto &byQueueType = idToEntity_[deviceId];
    auto &byEntityType = byQueueType[queueType];
    auto &byId = byEntityType[eType];

    if (!byId[id].empty()) {
        return;
    }
    DGW_LOG_INFO("Erase id[%u] in entityMap[%u:%u:%d].", id, deviceId, queueType, static_cast<int32_t>(eType));
    byId.erase(id);
    if (!byId.empty()) {
        return;
    }

    DGW_LOG_INFO("Erase eType[%u] in entityMap[%u:%u].", static_cast<int32_t>(eType), deviceId, queueType);
    byEntityType.erase(eType);  // byId is dangling from here on and must not be used again
    if (!byEntityType.empty()) {
        return;
    }

    DGW_LOG_INFO("Erase queueType[%u] in entityMap[%u].", queueType, deviceId);
    byQueueType.erase(queueType);  // byEntityType is dangling from here on
    if (!byQueueType.empty()) {
        return;
    }

    DGW_LOG_INFO("Erase deviceId[%u] in entityMap.", deviceId);
    idToEntity_.erase(deviceId);  // byQueueType is dangling from here on
    if (idToEntity_.empty()) {
        bqs::StatisticManager::GetInstance().SetExistEntityFlag(false);
    }
}
/**
 * Invokes `procFunc` on every source comm channel while holding the source channel read lock.
 *
 * `procFunc` reports its probe count through its second argument; its return value is ignored.
 * When no channel reports a non-zero probe count, the empty-schedule statistic for receive
 * requests is recorded.
 *
 * @return always FSM_SUCCESS.
 */
FsmStatus EntityManager::ProbeSrcCommChannel(
    const std::function<FsmStatus(const ChannelEntityPtr &, uint32_t &)> procFunc)
{
    bool anyProbed = false;
    (void)pthread_rwlock_rdlock(&srcCommChannels_.lock);
    auto &srcChannels = srcCommChannels_.entities;
    DGW_LOG_INFO("EntityManager[%u] srcCommChannels_'s size is %zu", resIndex_, srcChannels.size());
    for (const auto &channel : srcChannels) {
        uint32_t probeCount = 0U;
        (void)procFunc(channel, probeCount);
        if (probeCount != 0U) {
            anyProbed = true;
        }
    }
    (void)pthread_rwlock_unlock(&srcCommChannels_.lock);

    if (!anyProbed) {
        bqs::StatisticManager::GetInstance().HcclMpiRecvReqEmptySchedStat();
    }
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Supplies an EVENT_RECV_REQUEST_MSG event if any source comm channel asks for one.
 * @return the status reported by SupplyEventForRecvRequest.
 */
FsmStatus EntityManager::SupplyRecvRequestEvent()
{
    const FsmStatus status = SupplyEventForRecvRequest(EVENT_RECV_REQUEST_MSG);
    return status;
}
/**
 * Supplies an EVENT_RECV_COMPLETION_MSG event if any source comm channel asks for one.
 * @return the status reported by SupplyEventForRecvRequest.
 */
FsmStatus EntityManager::SupplyOneTrackEvent()
{
    const FsmStatus status = SupplyEventForRecvRequest(EVENT_RECV_COMPLETION_MSG);
    return status;
}
/**
 * Submits an event of type `msgType` when at least one source comm channel reports, via
 * CheckRecvReqEventContinue(), that the receive-request event should continue.
 *
 * The channel list is only inspected while the source channel read lock is held; the event is
 * submitted after the lock has been released. The result of the submission is ignored.
 *
 * @param msgType event id forwarded to SupplyEvent.
 * @return always FSM_SUCCESS.
 */
FsmStatus EntityManager::SupplyEventForRecvRequest(uint32_t msgType)
{
    // Take the shared pointer by const reference: passing it by value would cost an atomic
    // reference-count increment/decrement per channel while the read lock is held.
    const auto checkFunc = [](const ChannelEntityPtr &entity)->bool {
        return entity->CheckRecvReqEventContinue();
    };

    (void)pthread_rwlock_rdlock(&srcCommChannels_.lock);
    const auto &entities = srcCommChannels_.entities;
    const bool supplyEventFlag = std::any_of(entities.begin(), entities.end(), checkFunc);
    (void)pthread_rwlock_unlock(&srcCommChannels_.lock);

    if (supplyEventFlag) {
        (void)SupplyEvent(msgType);
        bqs::StatisticManager::GetInstance().RecvReqEventSupplyStat();
        DGW_LOG_DEBUG("Supply receive request event[%u].", msgType);
    }
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Runs `procFunc` on the source or destination comm channel set while holding that set's read lock.
 *
 * `procFunc` receives the channel set, an output counter for completed items, and this manager's
 * resource index. When the counter stays zero, the matching empty-schedule statistic is recorded
 * (receive-completion for the source set, send-completion for the destination set).
 *
 * @param isSrc true selects the source channel set, false the destination channel set.
 * @return the status returned by `procFunc`.
 */
FsmStatus EntityManager::TestSomeCommChannels(
    const std::function<FsmStatus(CommChannels &, uint32_t &, uint32_t &)> procFunc, const bool isSrc)
{
    CommChannels &selected = isSrc ? srcCommChannels_ : dstCommChannels_;
    uint32_t completedTotal = 0U;

    (void)pthread_rwlock_rdlock(&selected.lock);
    const FsmStatus status = procFunc(selected, completedTotal, resIndex_);
    (void)pthread_rwlock_unlock(&selected.lock);

    if (completedTotal != 0U) {
        return status;
    }
    if (isSrc) {
        bqs::StatisticManager::GetInstance().HcclMpiRecvCompEmptySchedStat();
    } else {
        bqs::StatisticManager::GetInstance().HcclMpiSendCompEmptySchedStat();
    }
    return status;
}
/**
 * Removes `entity` from the source or destination comm channel set under that set's write lock.
 *
 * @param entity channel entity to remove; nothing is erased when it is not in the set.
 * @param isSrc  true selects the source channel set, false the destination channel set.
 * @return always FSM_SUCCESS.
 */
FsmStatus EntityManager::EraseCommChannel(const EntityPtr &entity, const bool isSrc)
{
    CommChannels &channels = isSrc ? srcCommChannels_ : dstCommChannels_;
    (void)pthread_rwlock_wrlock(&channels.lock);
    auto &entities = channels.entities;
    const auto iter = std::find(entities.begin(), entities.end(), entity);
    if (iter != entities.end()) {
        (void)entities.erase(iter);
        DGW_LOG_INFO("Success to erase entity:[%s]", entity->ToString().c_str());
    }
    // Snapshot the size while the write lock is still held: reading the container after the
    // unlock would race with concurrent insertions/erasures on the same channel set.
    const size_t remainingChannels = entities.size();
    (void)pthread_rwlock_unlock(&channels.lock);
    DGW_LOG_INFO("Success to erase comm channel entity, entity:[%s], isSrc:[%d] comm channels size:[%zu].",
        entity->ToString().c_str(), static_cast<int32_t>(isSrc), remainingChannels);
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Appends `entity` to the source or destination comm channel set under that set's write lock.
 *
 * When the reserved capacity of the set's request buffer is smaller than the number of channels,
 * the capacity of the request, completion-index and completion-status buffers is raised by
 * DEFAULT_CHANNEL_CAPACITY.
 *
 * @param entity channel entity to store; must not be null.
 * @param isSrc  true selects the source channel set, false the destination channel set.
 * @return FSM_FAILED when `entity` is null, FSM_SUCCESS otherwise.
 */
FsmStatus EntityManager::InsertCommChannel(const ChannelEntityPtr &entity, const bool isSrc)
{
    // A null entry would be dereferenced below and by every later walk over the channel list.
    if (entity == nullptr) {
        DGW_LOG_ERROR("Failed to save comm channel entity, entity is null, isSrc:[%d].",
            static_cast<int32_t>(isSrc));
        return FsmStatus::FSM_FAILED;
    }

    CommChannels &channels = isSrc ? srcCommChannels_ : dstCommChannels_;
    (void)pthread_rwlock_wrlock(&channels.lock);
    (void)channels.entities.push_back(entity);
    // Snapshot the size while the write lock is still held: reading the container after the
    // unlock would race with concurrent insertions/erasures on the same channel set.
    const size_t channelCount = channels.entities.size();
    const size_t curCapacity = channels.requests.capacity();
    if (curCapacity < channelCount) {
        const size_t newCapacity = curCapacity + DEFAULT_CHANNEL_CAPACITY;
        channels.requests.reserve(newCapacity);
        channels.compIndices.reserve(newCapacity);
        channels.compStatus.reserve(newCapacity);
        DGW_LOG_INFO("Success to reserve requests capacity, primary:[%zu], current:[%zu]", curCapacity, newCapacity);
    }
    (void)pthread_rwlock_unlock(&channels.lock);
    DGW_LOG_INFO("Success to save comm channel entity:[%s], isSrc:[%d] comm channels size:[%zu].",
        entity->ToString().c_str(), static_cast<int32_t>(isSrc), channelCount);
    return FsmStatus::FSM_SUCCESS;
}
/**
 * Allocates an entity for `material` via DoAllocEntity.
 * Any std::exception thrown during allocation is logged and swallowed.
 * @return the new entity, or nullptr when an exception was caught before one was created.
 */
EntityPtr EntityManager::AllocEntity(const EntityMaterial &material) const
{
    EntityPtr created = nullptr;
    try {
        DoAllocEntity(material, created);
    } catch (const std::exception &allocError) {
        DGW_LOG_ERROR("catch %s", allocError.what());
    }
    return created;
}
/**
 * Creates the concrete entity class that matches `material` and stores it in `entity`.
 *
 * ENTITY_TAG yields a ChannelEntity and ENTITY_GROUP a GroupEntity. Every other entity type
 * yields a ClientEntity when the queue type is bqs::CLIENT_Q and a SimpleEntity otherwise.
 * Exceptions from std::make_shared propagate to the caller.
 */
void EntityManager::DoAllocEntity(const EntityMaterial &material, EntityPtr &entity) const
{
    if (material.eType == dgw::EntityType::ENTITY_TAG) {
        entity = std::make_shared<ChannelEntity>(material, resIndex_);
        return;
    }
    if (material.eType == dgw::EntityType::ENTITY_GROUP) {
        entity = std::make_shared<GroupEntity>(material, resIndex_);
        return;
    }
    if (material.queueType == bqs::CLIENT_Q) {
        entity = std::make_shared<ClientEntity>(material, resIndex_);
        return;
    }
    entity = std::make_shared<SimpleEntity>(material, resIndex_);
}
/**
 * Returns the source comm channel set when `isSrc` is true, otherwise the destination set.
 * No locking is performed here.
 */
CommChannels &EntityManager::GetCommChannels(const bool isSrc)
{
    if (isSrc) {
        return srcCommChannels_;
    }
    return dstCommChannels_;
}
/**
 * Submits `eventId` using the device id and event group that belong to this manager.
 * The manager with resource index 0 uses the queue manager's device id and ENQUEUE_GROUP_ID;
 * any other manager uses the extra device id and ENQUEUE_GROUP_ID_EXTRA.
 * @return the status of the three-argument SupplyEvent overload.
 */
FsmStatus EntityManager::SupplyEvent(const uint32_t eventId) const
{
    if (resIndex_ == 0U) {
        const uint32_t primaryGroupId = static_cast<uint32_t>(bqs::EventGroupId::ENQUEUE_GROUP_ID);
        return SupplyEvent(eventId, bqs::QueueManager::GetInstance().GetDeviceId(), primaryGroupId);
    }
    const uint32_t extraGroupId = static_cast<uint32_t>(bqs::EventGroupId::ENQUEUE_GROUP_ID_EXTRA);
    return SupplyEvent(eventId, bqs::QueueManager::GetInstance().GetExtraDeviceId(), extraGroupId);
}
/**
 * Submits `eventId` to the event scheduler through halEschedSubmitEvent.
 *
 * When the global NUMA flag is off and the event is a receive-request, send-completion or
 * receive-completion message, the event is redirected to F2NF_GROUP_ID on the queue manager's
 * device id instead of the requested device/group.
 *
 * @return FSM_SUCCESS when the driver call returns DRV_ERROR_NONE, FSM_FAILED otherwise.
 */
FsmStatus EntityManager::SupplyEvent(const uint32_t eventId, const uint32_t deviceId, const uint32_t groupId) const
{
    const bool numaEnabled = bqs::GlobalCfg::GetInstance().GetNumaFlag();
    const bool isCommMsgEvent = (eventId == static_cast<uint32_t>(EVENT_RECV_REQUEST_MSG)) ||
        (eventId == static_cast<uint32_t>(EVENT_SEND_COMPLETION_MSG)) ||
        (eventId == static_cast<uint32_t>(EVENT_RECV_COMPLETION_MSG));

    uint32_t targetDeviceId = deviceId;
    uint32_t targetGroupId = groupId;
    if ((!numaEnabled) && isCommMsgEvent) {
        targetGroupId = static_cast<uint32_t>(bqs::EventGroupId::F2NF_GROUP_ID);
        targetDeviceId = bqs::QueueManager::GetInstance().GetDeviceId();
    }

    event_summary eventInfo = { };
    eventInfo.pid = getpid();
    eventInfo.grp_id = targetGroupId;
    eventInfo.event_id = static_cast<EVENT_ID>(eventId);
    eventInfo.dst_engine = ACPU_LOCAL;

    const auto drvRet = halEschedSubmitEvent(targetDeviceId, &eventInfo);
    if (drvRet == DRV_ERROR_NONE) {
        return FsmStatus::FSM_SUCCESS;
    }
    DGW_LOG_ERROR("Call halEschedSumbmitEvent failed, event:[%u], deviceId[%u], groupId[%u], ret:[%d].",
        eventId, targetDeviceId, targetGroupId, static_cast<int32_t>(drvRet));
    return FsmStatus::FSM_FAILED;
}
/**
 * Checks that no comm channel is in the UNCONNECTED link state.
 *
 * Source channels are inspected first under the source read lock, then destination channels
 * under the destination read lock. The scan stops at the first unconnected channel, which is
 * logged together with its queue id.
 *
 * @return FSM_FAILED when an unconnected channel is found, FSM_SUCCESS otherwise.
 */
FsmStatus EntityManager::CheckLinkStatus()
{
    bool srcAllLinked = true;
    (void)pthread_rwlock_rdlock(&srcCommChannels_.lock);
    for (const auto &srcChannel : srcCommChannels_.entities) {
        if (srcChannel->linkStatus_ == dgw::ChannelLinkStatus::UNCONNECTED) {
            DGW_LOG_INFO("CheckLinkStatus srcCommChannels queueId[%u] unconnected", srcChannel->GetQueueId());
            srcAllLinked = false;
            break;
        }
    }
    (void)pthread_rwlock_unlock(&srcCommChannels_.lock);
    if (!srcAllLinked) {
        return FsmStatus::FSM_FAILED;
    }

    bool dstAllLinked = true;
    (void)pthread_rwlock_rdlock(&dstCommChannels_.lock);
    for (const auto &dstChannel : dstCommChannels_.entities) {
        if (dstChannel->linkStatus_ == dgw::ChannelLinkStatus::UNCONNECTED) {
            DGW_LOG_INFO("CheckLinkStatus dstCommChannels_ queueId[%u] unconnected", dstChannel->GetQueueId());
            dstAllLinked = false;
            break;
        }
    }
    (void)pthread_rwlock_unlock(&dstCommChannels_.lock);
    if (!dstAllLinked) {
        return FsmStatus::FSM_FAILED;
    }

    DGW_LOG_INFO("CheckLinkStatus connect success");
    return FsmStatus::FSM_SUCCESS;
}
}