* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef SEQUENCE_GROUP_MAPPER_H
#define SEQUENCE_GROUP_MAPPER_H
#include <pthread.h>
#include <mutex>
#include <unordered_map>
#include "concurrent_map.h"
#include "sequence_group.h"
namespace mindie_llm {
constexpr size_t MAX_DP_COUNT = 32;
class LiveInferContext {
public:
static std::shared_ptr<LiveInferContext> GetInstance(size_t localDPRank);
LiveInferContext();
~LiveInferContext();
void Add(SequenceGroupSPtr &seqGroup);
void AddIntoSeqRootMap(SequenceId seqId, SequenceGroupSPtr &rootSeqGroup);
void Remove(RequestId reqId);
void Remove(SequenceId seqId);
void RemoveFromSeqRootMap(SequenceId seqId);
void RemoveAll();
SequenceSPtr GetSeq(SequenceId seqId);
SequenceGroupSPtr GetSeqGroup(RequestId reqId);
SequenceGroupSPtr GetSeqGroup(SequenceId seqId);
SequenceGroupSPtr GetSeqGroupFromSeqRootMap(SequenceId seqId);
void SetInferInstanceFlexRole4Req(RequestId reqId, Role role);
Role GetInferInstanceFlexRole4Req(RequestId reqId);
* 根据id从本机的所有的dp rank中查询seqgrp;只有主节点有可能有多个rank
*/
template <typename T>
static std::pair<size_t, SequenceGroupSPtr> FindSeqGroupInAllRank(T id);
private:
std::unordered_map<RequestId, SequenceGroupSPtr> reqId2SeqGroupMap_;
std::unordered_map<SequenceId, SequenceGroupSPtr> seqId2SeqGroupMap_;
std::unordered_map<SequenceId, SequenceGroupSPtr> seqId2RootSeqGroupMap_;
std::unordered_map<RequestId, Role> reqId2UsedInstanceRoleMap_;
mutable pthread_spinlock_t spinlock_;
};
using LiveInferContextSPtr = std::shared_ptr<LiveInferContext>;
template <typename T>
std::pair<size_t, SequenceGroupSPtr> LiveInferContext::FindSeqGroupInAllRank(T id) {
for (size_t i = 0; i < MAX_DP_COUNT; i++) {
LiveInferContextSPtr lifeContext = GetInstance(i);
if (lifeContext == nullptr || lifeContext->GetSeqGroup(id) == nullptr) {
continue;
}
return {i, lifeContext->GetSeqGroup(id)};
}
return {0, nullptr};
}
}
#endif