* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "llm_common/cluster_manager.h"
#include "flow_func/flow_func_log.h"
#include "flow_func/meta_flow_func.h"
#include "common/scope_guard.h"
#include "llm_common/statistic_manager.h"
#include "entity/link_req_handler.h"
#include "entity/llm_comm_entity_mgr.h"
#include "llm_common/hccl_so_manager.h"
namespace FlowFunc {
namespace {
// Attribute keys read from MetaParams by ClusterManager::Initialize / InitServerConn.
constexpr const char *kIpAttr{"ip"};
constexpr const char *kPortAttr{"port"};
constexpr const char *kDeviceIndexAttr{"device_index"};
// Limit that UpdateLink compares (requested cluster_num + LlmCommEntityMgr::QueryLinkNum()) against.
constexpr uint32_t kMaxLinkNum{512U};
}  // namespace
int32_t ClusterManager::Initialize(const std::shared_ptr<MetaParams> ¶ms,
int32_t device_id,
uint32_t unlink_output_idx,
uint32_t update_link_out_idx,
bool is_prompt) {
unlink_output_idx_ = unlink_output_idx;
update_link_output_idx_ = update_link_out_idx;
auto ret = params->GetAttr(kDeviceIndexAttr, device_index_);
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("GetAttr %s failed.", kDeviceIndexAttr);
return ret;
}
ret = HcclSoManager::GetInstance()->LoadSo();
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("Load hccl so failed.");
return ret;
}
FsmStatus entity_ret = LlmCommEntityMgr::RegisterHcclMr(static_cast<uint32_t>(device_id), mem_addrs_);
if (entity_ret != FsmStatus::kFsmSuccess) {
UDF_LOG_ERROR("Register hccl mr failed.");
return FLOW_FUNC_FAILED;
}
if (is_prompt) {
ret = InitServerConn(params);
}
UDF_LOG_INFO("Init success, unlink_output_idx = %u, update_link_out_idx = %u", unlink_output_idx, update_link_out_idx);
return ret;
}
// Passes mem_addrs_ to LlmCommEntityMgr::UnRegisterHcclMr, the counterpart of the RegisterHcclMr
// call made in Initialize. Whatever that call returns is deliberately discarded.
void ClusterManager::Finalize() {
  static_cast<void>(LlmCommEntityMgr::UnRegisterHcclMr(mem_addrs_));
}
int32_t ClusterManager::InitServerConn(const std::shared_ptr<MetaParams> ¶ms) {
int64_t listen_ip = 0L;
auto ret = params->GetAttr(kIpAttr, listen_ip);
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("GetAttr ip failed.");
return ret;
}
int64_t listen_port = 0L;
ret = params->GetAttr(kPortAttr, listen_port);
if (ret != FLOW_FUNC_SUCCESS) {
UDF_LOG_ERROR("GetAttr port failed.");
return ret;
}
auto entity_ret = LlmCommEntityMgr::GetInstance().InitServerConn(static_cast<uint32_t>(listen_ip),
static_cast<uint16_t>(listen_port));
if (entity_ret != FsmStatus::kFsmSuccess) {
UDF_LOG_ERROR("InitServerConn failed. ret = %d", static_cast<int32_t>(entity_ret));
return FLOW_FUNC_FAILED;
}
return FLOW_FUNC_SUCCESS;
}
int32_t ClusterManager::UpdateLink(const std::shared_ptr<MetaRunContext> &run_context,
const std::vector<std::shared_ptr<FlowMsg>> &input_msgs) {
uint64_t start_tick = StatisticManager::GetInstance().GetCpuTick();
StatisticManager::GetInstance().GetStatisticInfo().update_link_total_times++;
update_link_output_ = run_context->AllocTensorMsg({kUpdateLinkMaxCount}, TensorDataType::DT_INT32);
if (update_link_output_ == nullptr) {
UDF_LOG_ERROR("Fail to alloc update link output tensor.");
return FLOW_FUNC_FAILED;
}
ScopeGuard guard([this]() { update_link_output_ = nullptr; });
if (CheckInputTensor(input_msgs, 1UL, 0UL, sizeof(UpdateLinkReqInfo), true) != FLOW_FUNC_SUCCESS) {
return FLOW_FUNC_FAILED;
}
const auto update_link_req = static_cast<UpdateLinkReqInfo *>(input_msgs.front()->GetTensor()->GetData());
if ((update_link_req->cluster_num > kUpdateLinkMaxCount)) {
UDF_LOG_ERROR("Cluster num:%u is over limit:%zu.", update_link_req->cluster_num, kUpdateLinkMaxCount);
return FLOW_FUNC_FAILED;
}
const int64_t output_idx = update_link_output_idx_;
UDF_LOG_INFO("Begin to update link, operator_type:%u, cluster_num:%u, device_index:%ld, output_idx:%ld.",
update_link_req->operator_type, update_link_req->cluster_num, device_index_, output_idx);
auto update_link_ret = static_cast<int32_t *>(update_link_output_->GetTensor()->GetData());
if ((update_link_req->cluster_num + LlmCommEntityMgr::GetInstance().QueryLinkNum()) > kMaxLinkNum) {
UDF_LOG_ERROR("Link num is over limit:%u", kMaxLinkNum);
for (size_t index = 0UL; index < static_cast<size_t>(update_link_req->cluster_num); index++) {
update_link_ret[index] = static_cast<int32_t>(FsmStatus::kFsmParamInvalid);
}
} else {
LinkReqHandler::GetInstance().SetUpdateLinkReqInfo(update_link_req, update_link_ret, device_index_);
(void) LinkReqHandler::GetInstance().HandleReq(start_tick);
StatisticManager::GetInstance().Dump();
}
int32_t ret = run_context->SetOutput(output_idx, update_link_output_);
UDF_LOG_INFO("Finish to update link, operator_type:%u, cluster_num:%u, device_index:%ld, output_idx:%ld.",
update_link_req->operator_type, update_link_req->cluster_num, device_index_, output_idx);
return ret;
}
// Handles one unlink request carried in input_msgs.front().
//
// An INT32 output tensor with kUpdateLinkMaxCount elements is allocated first; it is held in
// unlink_output_ for the duration of the call and reset to nullptr on every exit path by the scope
// guard. A request that is not an unlink operation, or whose cluster_num exceeds
// kUpdateLinkMaxCount (the number of result slots in the output tensor), is answered with
// kFsmParamInvalid in the first result slot. A valid request is handed to LinkReqHandler. In all of
// these cases the tensor is published on unlink_output_idx_.
//
// Returns FLOW_FUNC_FAILED when the output cannot be allocated or the input tensor check fails;
// otherwise returns the result of SetOutput.
int32_t ClusterManager::UnlinkData(const std::shared_ptr<MetaRunContext> &run_context,
                                   const std::vector<std::shared_ptr<FlowMsg>> &input_msgs) {
  uint64_t start_tick = StatisticManager::GetInstance().GetCpuTick();
  // NOTE(review): this bumps the same update_link_total_times counter that UpdateLink uses —
  // confirm that unlink requests are meant to be counted there.
  StatisticManager::GetInstance().GetStatisticInfo().update_link_total_times++;
  unlink_output_ = run_context->AllocTensorMsg({kUpdateLinkMaxCount}, TensorDataType::DT_INT32);
  if (unlink_output_ == nullptr) {
    UDF_LOG_ERROR("Fail to alloc unlink output tensor.");
    return FLOW_FUNC_FAILED;
  }
  // Drop the member's reference to the output tensor however this function is left.
  ScopeGuard guard([this]() { unlink_output_ = nullptr; });
  if (CheckInputTensor(input_msgs, 1UL, 0UL, sizeof(UpdateLinkReqInfo), true) != FLOW_FUNC_SUCCESS) {
    return FLOW_FUNC_FAILED;
  }
  auto output_data = static_cast<int32_t *>(unlink_output_->GetTensor()->GetData());
  const auto update_link_req = static_cast<UpdateLinkReqInfo *>(input_msgs.front()->GetTensor()->GetData());
  if (update_link_req->operator_type != static_cast<uint32_t>(LinkReqType::kServerUnlink) &&
      update_link_req->operator_type != static_cast<uint32_t>(LinkReqType::kUnlink)) {
    UDF_LOG_ERROR("Invalid param, operator type must be unlink.");
    *output_data = static_cast<int32_t>(FsmStatus::kFsmParamInvalid);
    return run_context->SetOutput(unlink_output_idx_, unlink_output_);
  }
  // The output tensor only has kUpdateLinkMaxCount result slots, so a larger cluster_num must not
  // reach the handler. UpdateLink applies the same limit to the same request layout.
  if (update_link_req->cluster_num > kUpdateLinkMaxCount) {
    UDF_LOG_ERROR("Cluster num:%u is over limit:%zu.", update_link_req->cluster_num, kUpdateLinkMaxCount);
    *output_data = static_cast<int32_t>(FsmStatus::kFsmParamInvalid);
    return run_context->SetOutput(unlink_output_idx_, unlink_output_);
  }
  // Widened copy of the output index so that it matches the %ld conversions below.
  const int64_t output_idx = unlink_output_idx_;
  UDF_LOG_INFO("Begin to unlink, cluster_num:%u, output_idx:%ld.", update_link_req->cluster_num, output_idx);
  LinkReqHandler::GetInstance().SetUpdateLinkReqInfo(update_link_req, output_data, device_index_);
  (void) LinkReqHandler::GetInstance().HandleReq(start_tick);
  StatisticManager::GetInstance().Dump();
  int32_t ret = run_context->SetOutput(unlink_output_idx_, unlink_output_);
  UDF_LOG_INFO("Finish to unlink, cluster_num:%u, device_index:%ld, output_idx:%ld.", update_link_req->cluster_num,
               device_index_, output_idx);
  return ret;
}
// Checks that a pull-kv message is large enough for the block entries its header announces.
//
// msg_data_size is the byte size of the whole message: a PullKvReqInfo header followed by 8-byte
// (uint64_t) entries. The header's block_count + prompt_block_count must not exceed the number of
// entries that fit behind the header.
//
// Returns kFsmParamInvalid when the message is shorter than the header or announces more entries
// than it carries, kFsmSuccess otherwise.
FsmStatus ClusterManager::CheckPullReqSize(uint64_t msg_data_size, const PullKvReqInfo *req_info) {
  // Guard the subtraction below: on unsigned operands it would wrap to a huge value for a short
  // message and let any block count pass.
  if (msg_data_size < sizeof(PullKvReqInfo)) {
    UDF_LOG_ERROR("Invalid req size:%lu, it is less than req header size:%zu.",
                  msg_data_size, sizeof(PullKvReqInfo));
    return FsmStatus::kFsmParamInvalid;
  }
  const uint64_t available_count = (msg_data_size - sizeof(PullKvReqInfo)) / sizeof(uint64_t);
  // Add in 64 bits so that two large counts cannot wrap around to a small sum.
  const uint64_t required_count =
      static_cast<uint64_t>(req_info->block_count) + static_cast<uint64_t>(req_info->prompt_block_count);
  if (required_count > available_count) {
    UDF_LOG_ERROR("Invalid req size, block_count:%u, prompt_block_count:%u, real count:%lu.",
                  req_info->block_count, req_info->prompt_block_count, available_count);
    return FsmStatus::kFsmParamInvalid;
  }
  return FsmStatus::kFsmSuccess;
}
// Pulls cache described by req_info from the remote cluster req_info->prompt_cluster_id.
//
// The entity linked to that cluster is looked up in LlmCommEntityMgr and must be of type
// kEntityClient. The pull itself is delegated to the entity's PullOrTransferCache; on success the
// elapsed ticks since start_tick are folded into the entity's sync_kv statistics and the entity's
// statistics are printed.
//
// Returns kFsmNotLink when no entity is linked to the remote cluster, kFsmFailed when the entity
// has the wrong type, the status reported by PullOrTransferCache when it fails, and kFsmSuccess
// otherwise.
FsmStatus ClusterManager::PullCacheFromPrompt(PullKvReqInfo *req_info,
                                              bool enable_paged_attention,
                                              uint64_t start_tick,
                                              const CacheEntry *cache_entry) {
  EntityPtr entity = LlmCommEntityMgr::GetInstance().GetEntityByRemoteClusterId(req_info->prompt_cluster_id);
  if (entity == nullptr) {
    UDF_RUN_LOG_ERROR("Current cluster is not linked with remote cluster:%lu, req_id:%lu.",
                      req_info->prompt_cluster_id, req_info->req_id);
    return FsmStatus::kFsmNotLink;
  }
  if (entity->GetType() != EntityType::kEntityClient) {
    // Arguments are listed in the order of the conversions: %d (type), %lu (req_id), %lu (cluster id).
    UDF_RUN_LOG_ERROR("Invalid entity type, entity type is %d, req_id:%lu, prompt_cluster_id:%lu.",
                      static_cast<int32_t>(entity->GetType()), req_info->req_id, req_info->prompt_cluster_id);
    return FsmStatus::kFsmFailed;
  }
  // NOTE(review): the second initializer is true here and false in TransferCache — presumably the
  // "is pull" flag; confirm against the PullOrTransferExtraInfo definition.
  PullOrTransferExtraInfo extra_info{enable_paged_attention, true, false, start_tick};
  FsmStatus ret = entity->PullOrTransferCache({}, *req_info, {}, cache_entry, extra_info);
  const uint64_t finish_tick = StatisticManager::GetInstance().GetCpuTick();
  const uint64_t tick_cost = finish_tick - start_tick;
  if (ret != FsmStatus::kFsmSuccess) {
    UDF_LOG_ERROR(
        "Fail to sync kv from prompt, ret:%d, %s, time_cost:%.2f us, entity:%s.",
        static_cast<int32_t>(ret),
        CacheKeyDebugString(*req_info).c_str(),
        StatisticManager::GetInstance().GetTimeCost(tick_cost),
        entity->GetDesc().c_str());
    return ret;
  }
  bool max_tick_cost_flag = false;
  EntityStatisticInfo &stat_info = entity->GetStatisticInfo();
  UpdateTickCost(tick_cost, stat_info.sync_kv_total_times, stat_info.sync_kv_min_tick_cost,
                 stat_info.sync_kv_max_tick_cost, stat_info.sync_kv_total_tick_cost, max_tick_cost_flag);
  const double time_cost = StatisticManager::GetInstance().GetTimeCost(tick_cost);
  if (max_tick_cost_flag) {
    // UpdateTickCost raised the flag for this pull: force the statistics print and emit a run log.
    entity->SetForcePrintFlag(true);
    UDF_RUN_LOG_INFO(
        "Success to pull kv from prompt with max time cost, %s, "
        "time cost:%.2f us, total_times:%lu, entity:%s.",
        CacheKeyDebugString(*req_info).c_str(), time_cost, stat_info.sync_kv_total_times,
        entity->GetDesc().c_str());
  }
  UDF_LOG_INFO(
      "Success to pull kv from prompt, %s, time_cost:%.2f us, total_times:%lu, entity:%s.",
      CacheKeyDebugString(*req_info).c_str(), time_cost, stat_info.sync_kv_total_times, entity->GetDesc().c_str());
  entity->PrintStatisticInfo();
  return FsmStatus::kFsmSuccess;
}
// Transfers cache described by req_info to the remote cluster req_info->dst_cluster_id.
//
// The entity linked to that cluster is looked up in LlmCommEntityMgr and the transfer is delegated
// to its PullOrTransferCache. On success the elapsed ticks since start_tick are folded into the
// entity's transfer_kv statistics and the entity's put statistics are printed.
//
// Returns kFsmNotLink when no entity is linked to the remote cluster, the status reported by
// PullOrTransferCache when it fails, and kFsmSuccess otherwise.
FsmStatus ClusterManager::TransferCache(TransferKvReqInfo *req_info,
                                        bool enable_paged_attention,
                                        uint64_t start_tick,
                                        const CacheEntry *cache_entry) {
  EntityPtr dst_entity = LlmCommEntityMgr::GetInstance().GetEntityByRemoteClusterId(req_info->dst_cluster_id);
  if (dst_entity == nullptr) {
    UDF_RUN_LOG_ERROR("Current cluster is not linked with remote cluster:%lu.", req_info->dst_cluster_id);
    return FsmStatus::kFsmNotLink;
  }

  PullOrTransferExtraInfo extra_info{enable_paged_attention, false, false, start_tick};
  const FsmStatus transfer_status = dst_entity->PullOrTransferCache(*req_info, {}, {}, cache_entry, extra_info);
  // Sample the clock right after the transfer so that logging is not included in the measurement.
  const uint64_t elapsed_ticks = StatisticManager::GetInstance().GetCpuTick() - start_tick;
  if (transfer_status != FsmStatus::kFsmSuccess) {
    UDF_LOG_ERROR("Fail to transfer cache, ret:%d, %s, time cost:%.2f us, entity:%s.",
                  static_cast<int32_t>(transfer_status), TransferReqDebugString(*req_info).c_str(),
                  StatisticManager::GetInstance().GetTimeCost(elapsed_ticks),
                  dst_entity->GetDesc().c_str());
    return transfer_status;
  }

  EntityStatisticInfo &entity_stats = dst_entity->GetStatisticInfo();
  bool is_new_max_cost = false;
  UpdateTickCost(elapsed_ticks, entity_stats.transfer_kv_total_times, entity_stats.transfer_kv_min_tick_cost,
                 entity_stats.transfer_kv_max_tick_cost, entity_stats.transfer_kv_total_tick_cost, is_new_max_cost);
  const double elapsed_us = StatisticManager::GetInstance().GetTimeCost(elapsed_ticks);
  if (is_new_max_cost) {
    // UpdateTickCost raised the flag for this transfer: force the statistics print and emit a run log.
    dst_entity->SetForcePrintFlag(true);
    UDF_RUN_LOG_INFO("Success to transfer cache with max time cost, %s, time cost:%.2f us, total_times:%lu, entity:%s.",
                     TransferReqDebugString(*req_info).c_str(), elapsed_us, entity_stats.transfer_kv_total_times,
                     dst_entity->GetDesc().c_str());
  }
  UDF_LOG_INFO("Success to transfer cache, %s, time cost:%.2f us, total_times:%lu, entity:%s.",
               TransferReqDebugString(*req_info).c_str(), elapsed_us, entity_stats.transfer_kv_total_times,
               dst_entity->GetDesc().c_str());
  dst_entity->PrintPutStatisticInfo();
  return FsmStatus::kFsmSuccess;
}
// Flow-func entry that calls LlmCommEntityMgr::HandleRequest(true).
// Neither run_context nor input_msgs is read, and the call always returns FLOW_FUNC_SUCCESS.
int32_t ClusterManager::SyncKvData(const std::shared_ptr<MetaRunContext> &run_context,
                                   const std::vector<std::shared_ptr<FlowMsg>> &input_msgs) const {
  // Mark both parameters as intentionally unused.
  static_cast<void>(run_context);
  static_cast<void>(input_msgs);
  LlmCommEntityMgr::GetInstance().HandleRequest(true);
  return FLOW_FUNC_SUCCESS;
}
}