* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef QUEUE_SCHEDULE_BIND_RELATION_H
#define QUEUE_SCHEDULE_BIND_RELATION_H
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "hccl/hccl_types_in.h"
#include "queue_schedule/dgw_client.h"
#include "common/bqs_status.h"
#include "common/bqs_util.h"
#include "fsm/state_define.h"
#include "hccl/comm_channel_manager.h"
#include "data_obj_manager.h"
#include "common/bqs_log.h"
namespace bqs {
enum class EventType : uint32_t {
ENQUEUE,
F2NF,
};
struct OptionalArg {
const dgw::CommChannel* channelPtr = nullptr;
dgw::EntityType eType = dgw::EntityType::ENTITY_QUEUE;
bqs::GroupPolicy policy = bqs::GroupPolicy::HASH;
uint32_t peerInstanceNum = 1U;
uint32_t localInstanceIndex = 0U;
uint32_t globalId = 0U;
uint32_t uuId = 0U;
uint32_t schedCfgKey = 0U;
uint32_t queueType = bqs::LOCAL_Q;
};
class EntityInfo {
public:
EntityInfo() = delete;
~EntityInfo() = default;
explicit EntityInfo(const uint32_t id, const uint32_t deviceId, const OptionalArg *const args = nullptr)
: id_(id), deviceId_(deviceId)
{
if (args != nullptr) {
optionalArgs_ = *args;
}
(void)entityDesc_.append("qid:").append(std::to_string(id_))
.append(", type:").append(std::to_string(static_cast<int32_t>(optionalArgs_.eType)))
.append(", globalId:").append(std::to_string(optionalArgs_.globalId))
.append(", schedCfgKey:").append(std::to_string(optionalArgs_.schedCfgKey))
.append(", queue type:").append(std::to_string(optionalArgs_.queueType))
.append(", deviceId:").append(std::to_string(deviceId_));
if (optionalArgs_.channelPtr != nullptr) {
(void)entityDesc_.append(", ").append(optionalArgs_.channelPtr->ToString());
}
}
bool operator==(const EntityInfo &entityInfo) const
{
if (id_ != entityInfo.id_) {
return false;
}
if (optionalArgs_.eType != entityInfo.optionalArgs_.eType) {
return false;
}
if (optionalArgs_.queueType != entityInfo.optionalArgs_.queueType) {
return false;
}
if (deviceId_ != entityInfo.deviceId_) {
return false;
}
return true;
}
inline uint32_t GetId() const
{
return id_;
}
inline dgw::EntityType GetType() const
{
return optionalArgs_.eType;
}
inline bqs::GroupPolicy GetGroupPolicy() const
{
return optionalArgs_.policy;
}
inline const dgw::CommChannel *GetCommChannel() const
{
return optionalArgs_.channelPtr;
}
inline const dgw::EntityPtr &GetEntity() const
{
return entity_;
}
inline void SetEntity(const dgw::EntityPtr &entity)
{
entity_ = entity;
}
inline const std::string &ToString() const
{
return entityDesc_;
}
inline uint32_t GetPeerInstanceNum() const
{
return optionalArgs_.peerInstanceNum;
}
inline uint32_t GetLocalInstanceIndex() const
{
return optionalArgs_.localInstanceIndex;
}
inline uint32_t GetGlobalId() const
{
return optionalArgs_.globalId;
}
inline uint32_t GetUuId() const
{
return optionalArgs_.uuId;
}
inline uint32_t GetSchedCfgKey() const
{
return optionalArgs_.schedCfgKey;
}
inline uint32_t GetDeviceId() const
{
return deviceId_;
}
inline uint32_t GetQueueType() const
{
return optionalArgs_.queueType;
}
private:
uint32_t id_;
uint32_t deviceId_;
OptionalArg optionalArgs_ = {};
std::string entityDesc_;
dgw::EntityPtr entity_ = nullptr;
};
class EntityInfoHash {
public:
size_t operator()(const EntityInfo &info) const
{
return std::hash<uint32_t>()(info.GetId()) ^ std::hash<uint16_t>()(static_cast<uint16_t>(info.GetType()));
}
};
using EntityInfoPtr = std::shared_ptr<EntityInfo>;
using EntityInfoSet = std::unordered_set<EntityInfo, EntityInfoHash>;
using MapEnitityInfoToInfoSet = std::unordered_map<EntityInfo, EntityInfoSet, EntityInfoHash>;
class BindRelation {
public:
static BindRelation &GetInstance();
~BindRelation() = default;
BindRelation(const BindRelation &) = delete;
BindRelation &operator=(const BindRelation &) = delete;
BindRelation(BindRelation &&) = delete;
BindRelation &operator=(BindRelation &&) = delete;
* create relation srcEntity->dstEntity.
* @param srcEntity src queue
* @param dstEntity destination queue
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus Bind(EntityInfo &srcEntity, EntityInfo &dstEntity, const uint32_t resIndex = 0U);
* delete relation srcEntity->dstEntity.
* @param srcEntity src queue id
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus UnBind(EntityInfo &srcEntity, EntityInfo &dstEntity, const uint32_t resIndex = 0U);
* delete relation srcEntity->*.
* @param srcEntity src queue id
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus UnBindBySrc(const EntityInfo& srcEntity);
* delete relation *->dstEntity.
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus UnBindByDst(const EntityInfo& dstEntity);
* order subscribe queue id by topology.
*/
void Order(const uint32_t index = 0U);
* create group.
* @param entities entity in group
* @param groupId group id
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus CreateGroup(const std::vector<EntityInfoPtr> &entities, uint32_t &groupId);
* delete group.
* @param groupId group id
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus DeleteGroup(const uint32_t groupId);
* get src to destination relation.
* @return src to destination relation
*/
const MapEnitityInfoToInfoSet &GetSrcToDstRelation() const;
* get destination to src relation.
* @return destination to src relation
*/
const MapEnitityInfoToInfoSet &GetDstToSrcRelation() const;
* Get ordered subscribe queue.
* @return ordered subscribe queue id
*/
const std::vector<EntityInfo> &GetOrderedSubscribeQueueId() const;
* Get bind size.
* @return size of all binds
*/
uint32_t CountBinds() const;
* Get abnormal bind size.
* @return size of all binds
*/
uint32_t CountAbnormalBinds() const;
* Get group config.
* @return endpoint entity info in current group
*/
const std::vector<EntityInfoPtr> &GetEntitiesInGroup(const uint32_t groupId) const;
BqsStatus UnBindRelationBySrc(const EntityInfo &srcEntity);
BqsStatus UnBindRelationByDst(const EntityInfo &dstEntity);
void MarkAbnormalSrc(const EntityInfo &srcEntity);
void MarkAbnormalDst(const EntityInfo &dstEntity);
const MapEnitityInfoToInfoSet &GetAbnormalSrcToDstRelation() const;
const MapEnitityInfoToInfoSet &GetAbnormalDstToSrcRelation() const;
BqsStatus GetBindRelationIndex(const EntityInfo &srcEntity, const EntityInfo &dstEntity, uint32_t &index) const;
const MapEnitityInfoToInfoSet &GetSrcToDstExtraRelation() const;
const MapEnitityInfoToInfoSet &GetDstToSrcExtraRelation() const;
const std::vector<EntityInfo> &GetOrderedSubscribeQueueIdExtra() const;
BqsStatus GetBindIndexBySrc(const EntityInfo &srcEntity, uint32_t &index) const;
BqsStatus ClearInputQueue(const uint32_t index, const std::unordered_set<uint32_t>& keySet);
BqsStatus MakeSureOutputCompletion(const uint32_t index, const std::unordered_set<uint32_t>& keySet);
void ClearAbnormalEntityInfo(const uint32_t index);
void AppendAbnormalEntity(const EntityInfo &info, const dgw::EntityDirection direction, const uint32_t index);
void UpdateRelation(const uint32_t index);
private:
BindRelation() = default;
* check whether can add relation srcEntity->dstEntity.
* @param srcEntity src queue id
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:failed
*/
BqsStatus CheckBind(const EntityInfo& srcEntity, const EntityInfo& dstEntity,
const uint32_t resIndex, uint32_t &index) const;
* check whether exist multi-layer bind relation.
* @param srcEntity src entity
* @param dstEntity dst entity
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus CheckMultiLayerBind(const EntityInfo& srcEntity, const EntityInfo& dstEntity, const uint32_t index) const;
* check whether entity exist in group.
* @param src src entity
* @param dst dst entity
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus CheckEntityExistInGroup(const EntityInfo& src, const EntityInfo& dst, const uint32_t resIndex = 0U) const;
* del relation from srcToDstRelation_
* @param srcEntity src queue id
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus DelSrcToDst(const EntityInfo& srcEntity, const EntityInfo& dstEntity, const uint32_t index = 0U);
* add relation to srcToDstRelation_
* @param srcEntity src queue id
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus AddSrcToDst(EntityInfo &srcEntity, EntityInfo &dstEntity, const uint32_t index = 0U);
* del relation from dstToSrcRelation_
* @param srcEntity src queue id
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus DelDstToSrc(const EntityInfo& srcEntity, const EntityInfo& dstEntity, const uint32_t index);
* del relation to dstToSrcRelation_
* @param srcEntity src queue id
* @param dstEntity destination queue id
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus AddDstToSrc(EntityInfo &srcEntity, EntityInfo &dstEntity, const uint32_t index = 0U);
* generate group id
* @return group id
*/
uint32_t GenerateGroupId();
* Create entity
* @param src src entity info
* @param dst dst entity info
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus CreateEntity(const EntityInfo &src, const EntityInfo &dst, const uint32_t resIndex = 0U);
* Create entity
* @param info entity info
* @param isSrc is src
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus CreateEntity(const EntityInfo &info, const bool isSrc, const uint32_t resIndex);
* create entity for group
* @param groupEntity group entity
* @param isSrc is src
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus CreateEntityForGroup(const EntityInfo &groupEntity, const bool isSrc, const uint32_t resIndex);
* Delete entity
* @param info entity info
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus DeleteEntity(const EntityInfo &info, const bool isSrc, const uint32_t resIndex = 0U) const;
* delete entity for group
* @param groupId group id
* @return BQS_STATUS_OK:success, other:exist
*/
BqsStatus DeleteEntityForGroup(const uint32_t queueType, const uint32_t deviceId,
const uint32_t groupId, const dgw::EntityDirection direction, const uint32_t resIndex) const;
* subscribe event
* @param subscribeEntity entity info
* @param eventType event type
* @return BQS_STATUS_OK: success, other: failed
*/
BqsStatus SubscribeEvent(const EntityInfo &subscribeEntity, const EventType eventType,
const uint32_t index) const;
* unsubscribe event
* @param subscribeEntity entity info
* @param eventType event type
* @return BQS_STATUS_OK: success, other: failed
*/
BqsStatus UnsubscribeEvent(const EntityInfo &subscribeEntity, const EventType eventType,
const uint32_t index) const;
* update subscribe event
* @param subscribeEntity entity info
* @param eventType event type
* @return BQS_STATUS_OK: success, other: failed
*/
BqsStatus UpdateSubscribeEvent(const EntityInfo &subscribeEntity, const EventType eventType,
const uint32_t index = 0U) const;
* check whether can execute unbind for entity
* @param entity entity info
* @return BQS_STATUS_OK: success, other: failed
*/
BqsStatus CheckUnBind(const EntityInfo &entity) const;
* set entityptr for entityInfo
* @param entitiInfo to expand
* @return BQS_STATUS_OK: success, other: failed
*/
BqsStatus SetEntityPtr(EntityInfo &entityInfo, const dgw::EntityDirection direction, const uint32_t index) const;
void UnBindAbnormalRelationBySrc(const EntityInfo &srcEntity);
void UnBindAbnormalRelationByDst(const EntityInfo &dstEntity);
void DelAbnormalSrcToDst(const EntityInfo &srcEntity, const EntityInfo &dstEntity);
void DelAbnormalDstToSrc(const EntityInfo &srcEntity, const EntityInfo &dstEntity);
BqsStatus GetBindIndexByDst(const EntityInfo &srcEntity, uint32_t &index) const;
void OrderOneTable(std::vector<EntityInfo> &orderedSubscribeQueueId,
const MapEnitityInfoToInfoSet &srcToDstRelation,
const MapEnitityInfoToInfoSet &dstToSrcRelation);
MapEnitityInfoToInfoSet srcToDstRelation_;
MapEnitityInfoToInfoSet dstToSrcRelation_;
MapEnitityInfoToInfoSet abnormalSrcToDst_;
MapEnitityInfoToInfoSet abnormalDstToSrc_;
std::vector<EntityInfo> orderedSubscribeQueueId_;
std::unordered_map<uint32_t, std::vector<EntityInfoPtr>> allGroupConfig_;
std::unordered_map<uint32_t, std::pair<uint32_t, uint32_t>> group2ResIndex_;
bool isHasLoop_ = false;
SpinLock lockForGroup_;
MapEnitityInfoToInfoSet srcToDstRelationExtra_;
MapEnitityInfoToInfoSet dstToSrcRelationExtra_;
std::vector<EntityInfo> orderedSubscribeQueueIdExtra_;
std::vector<EntityInfo> abnormalSrc_;
std::vector<EntityInfo> abnormalDst_;
};
}
#endif