* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef BUILT_IN_ENTITY_LLM_COMM_ENTITY_H
#define BUILT_IN_ENTITY_LLM_COMM_ENTITY_H
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ascend_hal.h"
#include "flow_func/flow_msg.h"
#include "fsm/state_define.h"
#include "llm_common/llm_common.h"
#include "llm_common/statistic_manager.h"
namespace FlowFunc {
/**
 * @brief Communication entity that wraps one HcclConn together with a local and a remote HcclAddr
 *        and keeps an FsmState (see ProcessState / ChangeState / GetCurState).
 *
 * Only declarations live in this header; the behaviour of every member function is defined in the
 * implementation file. The class owns two std::mutex objects and several std::atomic_bool flags that
 * are handed out by reference through the Get*() accessors below.
 *
 * The entity is neither copyable nor movable.
 */
class LlmCommEntity {
public:
    /** Mbuf handles and raw addresses kept for a sync-KV request and its response meta. */
    struct SyncKvAddrInfo {
        Mbuf *sync_kv_req_mbuf;
        void *sync_kv_req_addr;
        Mbuf *sync_kv_resp_meta_mbuf;
        void *sync_kv_resp_meta_addr;
        int32_t req_info_count;
    };

    /** Mbuf handles and raw addresses kept for a transfer-KV request and its response meta. */
    struct TransferKvAddrInfo {
        Mbuf *transfer_kv_req_mbuf;
        void *transfer_kv_req_addr;
        Mbuf *transfer_kv_resp_meta_mbuf;
        void *transfer_kv_resp_meta_addr;
        int32_t req_info_count;
    };

    /** Bookkeeping for KV sending: the tensors involved plus success/completion counters. */
    struct SendKvRecordInfo {
        std::vector<KvTensor> kv_tensors;
        bool send_kv_meta_succ_flag;
        uint64_t send_kv_succ_count;
        uint64_t send_complete_cnt;
    };

    /** Bookkeeping (flags, counters, slot index) for receiving a transfer-KV request. */
    struct RecvTransferKvRecordInfo {
        uint32_t recv_req_suc_flag;
        uint32_t call_send_meta_flag;
        uint32_t send_meta_complete_flag;
        uint32_t recv_call_count;
        uint32_t recv_suc_count;
        uint64_t recv_slot_index;
        int32_t last_probed_count;
    };

    /** Start ticks recorded on the server side; presumably used for timing statistics -- confirm in the .cpp. */
    struct ServerTickRecord {
        uint64_t probe_req_start_tick;
        uint64_t send_meta_start_tick;
        uint64_t send_kv_start_tick;
        uint64_t link_start_tick;
    };

    /** Start ticks recorded on the client side; presumably used for timing statistics -- confirm in the .cpp. */
    struct ClientTickRecord {
        uint64_t send_req_start_tick;
        uint64_t probe_resp_start_tick;
        uint64_t probe_kv_start_tick;
        uint64_t send_transfer_req_start_tick;
        uint64_t probe_transfer_resp_start_tick;
        uint64_t transfer_kv_start_tick;
    };

    /** Completion counter for KV receiving. */
    struct RecvKvRecordInfo {
        uint64_t recv_complete_cnt;
    };

    /**
     * @param type             entity type
     * @param conn             HCCL connection handle used by this entity
     * @param local_hccl_addr  local HCCL address
     * @param remote_hccl_addr remote HCCL address
     */
    LlmCommEntity(EntityType type, HcclConn conn, HcclAddr &local_hccl_addr, HcclAddr &remote_hccl_addr);
    virtual ~LlmCommEntity();

    // ---- state machine ----
    FsmStatus ProcessState();
    FsmStatus ChangeState(FsmState next_state);
    FsmState GetCurState() const;
    const std::string &GetDesc() const;
    static const std::string &GetStateDesc(FsmState state);

    // ---- connection / identity ----
    HcclConn &GetConn();
    void SetConn(HcclConn conn);
    uint64_t GetRemoteIp() const;
    uint64_t GetLocalIp() const;
    EntityType GetType() const;
    uint64_t GetRemoteClusterId() const;
    void SetRemoteClusterId(uint64_t remote_cluster_id);
    ClientClusterInfo &GetClientClusterInfo();

    // ---- "cur" request attributes (plain getters/setters) ----
    uint64_t GetCurReqId() const;
    uint64_t GetCurPrefixId() const;
    uint64_t GetCurModelId() const;
    bool GetCurIsPullBlock() const;
    bool GetCurIsPullWithOffset() const;
    const std::pair<int64_t, uint32_t> &GetCurCacheIdAndBatchIndex() const;
    void SetCurReqId(uint64_t req_id);
    void SetCurPrefixId(uint64_t prefix_id);
    void SetCurModelId(uint64_t model_id);
    void SetCurIsPullBlock(bool is_pull_block);
    void SetCurIsPullWithOffset(bool is_pull_with_offset);
    void SetCurCacheIdAndBatchIndex(std::pair<int64_t, uint32_t> cache_id_and_batch_index);

    // ---- mutable access to internal containers and records ----
    std::vector<HcclRequest> &GetReceiveRequests();
    std::vector<HcclRequest> &GetSendRequests();
    std::vector<HcclMessage> &GetProbeMsgs();
    std::vector<uint64_t> &GetPushDstAddrs();
    std::once_flag &GetsendMetaOnceFlag();
    std::once_flag &GetReceiveReqOnceFlag();
    SyncKvAddrInfo &GetSyncKvAddrInfo();
    TransferKvAddrInfo &GetTransferKvAddrInfo();
    SendKvRecordInfo &GetSendKvRecordInfo();
    RecvKvRecordInfo &GetRecvKvRecordInfo();
    RecvTransferKvRecordInfo &GetRecvTransferKvRecordInfo();
    EntityStatisticInfo &GetStatisticInfo();
    ServerTickRecord &GetServerTickRecord();
    uint8_t &GetCheckLinkRecvData();
    std::pair<uint32_t, const TransferInfo *> GetTensorNumAndIndices() const;

    // ---- link flags ----
    bool GetProbeLinkClusterInfoFlag() const;
    void SetProbeLinkClusterInfoFlag(bool link_empty_resp_flag);
    bool GetLinkEstablished() const;
    void SetLinkEstablished(bool link_established);

    // ---- synchronisation primitives exposed to callers ----
    std::atomic_bool &GetEntityOccupied();
    std::atomic_bool &GetReqIsUsing();
    std::atomic_bool &GetIsUnlinking();
    std::mutex &GetMutex();
    std::mutex &GetUnlinkMutex();

    // ---- diagnostics ----
    void SetForcePrintFlag(bool force_print_flag);
    void Dump();
    void PrintStatisticInfo();
    void PrintPutStatisticInfo();

    // ---- operations ----
    FsmStatus SendAsync(void *buff, size_t buff_size);
    FsmStatus PullOrTransferCache(const TransferKvReqInfo &req_info,
                                  const PullKvReqInfo &pull_req_info,
                                  const CheckLinkReqInfo &check_req_info,
                                  const CacheEntry *cache_entry,
                                  const PullOrTransferExtraInfo &extra_info);
    FsmStatus AllocMbuf(Mbuf *&mbuf, uint64_t count, void *&data_buff) const;
    void ClearResource();
    /**
     * @param requests           array of requests to test
     * @param request_count      number of entries in @p requests
     * @param current_comp_count output parameter (non-const reference) written by the implementation
     */
    FsmStatus TestCompleteAsync(HcclRequest requests[], uint64_t request_count, int32_t &current_comp_count);
    FsmStatus MarkEntityDeletedByConn();
    FsmStatus CheckLink(const CheckLinkReqInfo &req_info);

    // Non-copyable and non-movable: the entity holds mutexes, once_flags and atomics.
    LlmCommEntity(const LlmCommEntity &) = delete;
    LlmCommEntity(LlmCommEntity &&) = delete;
    LlmCommEntity &operator=(const LlmCommEntity &) = delete;
    LlmCommEntity &operator=(LlmCommEntity &&) = delete;

private:
    template<typename T>
    static void GeneratePromptSendBlockInfo(const T &req_info,
                                            std::vector<std::pair<uint64_t, uint64_t>> &prompt_send_block_info);
    template<typename T>
    void AggregateKvBlocks(const T &req_info);

    FsmStatus PullKvFromPrompt(const PullKvReqInfo &req_info,
                               bool enable_paged_attention,
                               const CacheEntry *cache_entry,
                               uint64_t start_tick);
    FsmStatus TransferCache(const TransferKvReqInfo &req_info,
                            bool enable_paged_attention,
                            const CacheEntry *cache_entry,
                            uint64_t start_tick);
    FsmStatus AllocMbufForTransferKvMeta();
    void GenerateTransferSlotInfos(const TransferKvReqInfo &req_info, TransferToRemoteReq *transfer_req);
    FsmStatus SendTransferReq(const TransferKvReqInfo &req_info);
    FsmStatus SendReqToPrompt(const PullKvReqInfo &req_info);
    FsmStatus ReceiveTransferKvMeta(const TransferKvReqInfo &req_info);
    FsmStatus TransferCacheData(const TransferKvReqInfo &req_info,
                                const CacheEntry *cache_entry);
    FsmStatus ReceiveKvCache(const PullKvReqInfo &req_info,
                             const SyncKvMetaInfo &resp_info,
                             const CacheEntry *cache_entry);
    FsmStatus ReceiveSyncKvMetaInfo(const PullKvReqInfo &req_info,
                                    SyncKvMetaInfo &resp_info,
                                    const CacheEntry *cache_entry);
    // NOTE(review): "data_aize" looks like a typo for "data_size"; kept so the declaration stays
    // in step with the definition, which is not visible here.
    FsmStatus ProbeSync(uint64_t data_aize, uint64_t timeout, bool is_receive_meta = false);
    FsmStatus TrySendWithTimeout(void *buff, size_t buff_size, uint64_t timeout);
    FsmStatus TestSingleRequestSync(HcclRequest &request, uint64_t timeout) const;
    FsmStatus TestMultiRequestsSync(HcclRequest requests[], uint64_t request_count, uint64_t &total_comp_count,
                                    uint64_t timeout) const;
    FsmStatus AllocMbufForSyncKvMetaInfo();
    FsmStatus CheckSyncKvMetaInfo(const PullKvReqInfo &req_info,
                                  SyncKvMetaInfo &resp_info,
                                  const CacheEntry *cache_entry) const;
    std::vector<uintptr_t> GetRecvAddrs(const PullKvReqInfo &req_info,
                                        const SyncKvMetaInfo &resp_info,
                                        const CacheEntry *cache_entry);
    FsmStatus ReceiveKvCacheAsync(const PullKvReqInfo &req_info,
                                  const SyncKvMetaInfo &resp_info,
                                  const std::vector<uintptr_t> &recv_addrs,
                                  uint64_t timeout);
    FsmStatus RecvSingleKvTransfer(size_t index, uintptr_t kv_addr_of_one_layer, const PullKvReqInfo &req_info,
                                   const SyncKvMetaInfo &resp_info, HcclRequest &receive_request);
    void ResetNewTransfer(size_t &transfer_index);
    void GenerateBufferInfos(const PullKvReqInfo &req_info, SyncKvReqInfo *sync_req_info);
    static void CalBuffSize(TransferToRemoteReq *transfer_req_info, uint64_t slot_index, uint64_t &buff_size);
    FsmStatus TestUncompleteCheckReq(uint64_t timeout);
    void GenerateTensorIndices(const PullKvReqInfo &req_info, SyncKvReqInfo *sync_req_info) const;
    void FillSyncKvReqInfo(const PullKvReqInfo &req_info, SyncKvReqInfo *sync_req_info) const;

private:
    // Data members. Declaration order is preserved on purpose: it fixes both the object layout
    // and the order in which the constructor initialises the members.
    std::mutex unlink_mutex_;
    // type_, conn_ and cur_state_ carry no default initializer; they are expected to be set by
    // the constructor (definition not visible here -- confirm).
    EntityType type_;
    HcclConn conn_;
    FsmState cur_state_;
    std::vector<HcclRequest> receive_requests_;
    std::vector<HcclRequest> send_requests_;
    std::vector<HcclMessage> probe_msgs_;
    std::pair<int64_t, uint32_t> cur_cache_id_and_batch_index_;
    // Zero-initialised so the getters never observe an indeterminate value, even when the
    // constructor does not mention these members.
    uint64_t cur_req_id_ = 0U;
    uint64_t cur_prefix_id_ = 0U;
    uint64_t cur_model_id_ = 0U;
    SyncKvAddrInfo sync_kv_addr_info_ = {};
    TransferKvAddrInfo transfer_kv_addr_info_ = {};
    SendKvRecordInfo send_kv_record_info_ = {};
    RecvKvRecordInfo recv_kv_record_info_ = {};
    RecvTransferKvRecordInfo recv_transfer_kv_record_info_ = {};
    HcclAddr local_hccl_addr_ = {};
    HcclAddr remote_hccl_addr_ = {};
    std::string desc_;
    ClientTickRecord client_tick_record_ = {};
    ServerTickRecord server_tick_record_ = {};
    EntityStatisticInfo stat_info_ = {};
    bool force_print_flag_ = false;
    std::once_flag receive_resp_once_flag_;
    std::once_flag send_meta_once_flag_;
    uint8_t check_link_data_ = UINT8_MAX;
    uint8_t check_link_recv_data_ = UINT8_MAX;
    bool probe_link_cluster_info_succ_flag_ = false;
    bool link_established_ = false;
    bool enable_paged_attention_ = false;
    // Defaulted to false like the other boolean flags of this class.
    bool cur_is_pull_block_ = false;
    std::vector<std::pair<uint64_t, std::vector<uint64_t>>> transfer_indices_;
    std::vector<std::vector<uint64_t>> transfer_continuous_nums_;
    std::atomic_bool entity_occupied_ = {false};
    std::atomic_bool req_is_using_ = {false};
    std::mutex mutex_;
    ClientClusterInfo client_cluster_info_ = {};
    uint64_t remote_cluster_id_ = 0U;
    uint64_t pull_or_transfer_start_tick_ = 0U;
    std::atomic_bool is_unlinking_ = {false};
    bool cur_is_pull_with_offset_ = false;
    std::vector<uint64_t> push_dst_addrs_;
};
}
#endif