* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Description: UcpWorkerPool wraps and manages multiple UcpWorkers. It
* automatically handles reuse of UcpWorkers and provides methods for removing
* the information associated with a bad IP address.
*/
#include "datasystem/common/rdma/ucp_worker_pool.h"
#include <mutex>
#include "datasystem/common/flags/flags.h"
#include "datasystem/common/log/log.h"
#include "datasystem/common/util/format.h"
#include "datasystem/common/util/status_helper.h"
namespace datasystem {
/**
 * @brief Construct a pool bound to a UCP context.
 *
 * Only records the arguments; no UcpWorker is created here. Workers are
 * created and initialized by Init().
 *
 * @param[in] ucpContext UCP context handle stored in context_ and later passed
 * to every UcpWorker the pool creates.
 * @param[in] workerN Number of workers Init() will create.
 */
UcpWorkerPool::UcpWorkerPool(const ucp_context_h &ucpContext, uint32_t workerN)
: context_(ucpContext), workerN_(workerN)
{
}
/**
 * @brief Destroy the pool, releasing its bookkeeping through Clean().
 */
UcpWorkerPool::~UcpWorkerPool()
{
// Clean() clears the recv map and drops the pool's references to its workers.
Clean();
}
/**
 * @brief Create and initialize the workerN_ workers owned by this pool.
 *
 * Each worker is built from the pool's UCP context and its index, initialized,
 * and then registered in localWorkerPool_ under that index.
 *
 * @return Status::OK() when every worker was initialized. If a worker's Init()
 * fails, RETURN_IF_NOT_OK leaves this function early (presumably returning the
 * failing status); workers registered before the failure stay in the pool.
 */
Status UcpWorkerPool::Init()
{
    uint32_t workerId = 0;
    while (workerId < workerN_) {
        auto newWorker = std::make_shared<UcpWorker>(context_, workerId);
        // Register the worker only after it has been initialized successfully.
        RETURN_IF_NOT_OK(newWorker->Init());
        localWorkerPool_.emplace(workerId, std::move(newWorker));
        ++workerId;
    }
    return Status::OK();
}
/**
 * @brief Perform a single-segment write to a remote peer through a pooled worker.
 *
 * Picks a send worker via GetOrSelSendWorker() and forwards every transfer
 * argument unchanged to UcpWorker::Write(); see that method for the exact
 * meaning of the individual arguments.
 *
 * @param[in] remoteRkey Forwarded to UcpWorker::Write().
 * @param[in] remoteSegAddr Forwarded to UcpWorker::Write().
 * @param[in] remoteWorkerAddr Forwarded to UcpWorker::Write().
 * @param[in] ipAddr Peer IP; used for worker selection, the failure log, and forwarded.
 * @param[in] localSegAddr Forwarded to UcpWorker::Write().
 * @param[in] localSegSize Forwarded to UcpWorker::Write().
 * @param[in] requestID Passed to worker selection and forwarded.
 * @param[in] event Moved into UcpWorker::Write().
 * @return K_RDMA_ERROR if no worker could be obtained; otherwise the status
 * returned by UcpWorker::Write().
 */
Status UcpWorkerPool::Write(const std::string &remoteRkey, const uintptr_t remoteSegAddr,
                            const std::string &remoteWorkerAddr, const std::string &ipAddr,
                            const uintptr_t localSegAddr, size_t localSegSize, uint64_t requestID,
                            std::shared_ptr<Event> event)
{
    UcpWorker *worker = GetOrSelSendWorker(ipAddr, requestID);
    if (worker == nullptr) {
        // This is a genuine failure, so log it with ERROR severity rather than
        // through the verbosity-gated VLOG channel, where it could be filtered out.
        LOG(ERROR) << FormatString("Communication with IP %s Failed", ipAddr);
        RETURN_STATUS(K_RDMA_ERROR, std::string("[UcpWorkerPool] Failed to obtain worker for communication."));
    }
    return worker->Write(remoteRkey, remoteSegAddr, remoteWorkerAddr, ipAddr, localSegAddr, localSegSize, requestID,
                         std::move(event));
}
/**
 * @brief Perform a multi-segment write to a remote peer through a pooled worker.
 *
 * Picks a send worker via GetOrSelSendWorker() and forwards every transfer
 * argument unchanged to UcpWorker::WriteN(); see that method for the exact
 * meaning of the individual arguments.
 *
 * @param[in] remoteRkey Forwarded to UcpWorker::WriteN().
 * @param[in] remoteBaseAddr Forwarded to UcpWorker::WriteN().
 * @param[in] remoteWorkerAddr Forwarded to UcpWorker::WriteN().
 * @param[in] ipAddr Peer IP; used for worker selection, the failure log, and forwarded.
 * @param[in] segments Segment list forwarded to UcpWorker::WriteN().
 * @param[in] requestID Passed to worker selection and forwarded.
 * @param[in] event Moved into UcpWorker::WriteN().
 * @return K_RDMA_ERROR if no worker could be obtained; otherwise the status
 * returned by UcpWorker::WriteN().
 */
Status UcpWorkerPool::WriteN(const std::string &remoteRkey, uintptr_t remoteBaseAddr,
                             const std::string &remoteWorkerAddr, const std::string &ipAddr,
                             const std::vector<IovSegment> &segments, uint64_t requestID, std::shared_ptr<Event> event)
{
    UcpWorker *worker = GetOrSelSendWorker(ipAddr, requestID);
    if (worker == nullptr) {
        // This is a genuine failure, so log it with ERROR severity rather than
        // through the verbosity-gated VLOG channel, where it could be filtered out.
        LOG(ERROR) << FormatString("Communication with IP %s Failed", ipAddr);
        RETURN_STATUS(K_RDMA_ERROR, std::string("[UcpWorkerPool] Failed to obtain worker for communication."));
    }
    return worker->WriteN(remoteRkey, remoteBaseAddr, remoteWorkerAddr, ipAddr, segments, requestID, std::move(event));
}
/**
 * @brief Return the local worker address used to receive from a peer IP,
 * assigning one on first use.
 *
 * The IP-to-address map is first consulted under a shared lock. On a miss the
 * lock is re-taken exclusively, the map is checked again (another thread may
 * have inserted the entry in between), and only then is a worker picked from
 * the pool using the roundRobin_ counter and recorded for this IP.
 *
 * @param[in] ipAddr Peer IP address used as the map key.
 * @return The recorded or newly selected worker address, or an empty string
 * when the worker pool is empty.
 */
std::string UcpWorkerPool::GetOrSelRecvWorkerAddr(const std::string &ipAddr)
{
    // Read-only attempt: an already known peer needs no exclusive lock.
    {
        std::shared_lock<std::shared_mutex> sharedGuard(recvMapMutex_);
        if (auto known = localWorkerRecvMap_.find(ipAddr); known != localWorkerRecvMap_.end()) {
            return known->second;
        }
    }

    std::unique_lock<std::shared_mutex> exclusiveGuard(recvMapMutex_);
    // Look again now that we hold the lock exclusively.
    if (auto known = localWorkerRecvMap_.find(ipAddr); known != localWorkerRecvMap_.end()) {
        return known->second;
    }
    if (localWorkerPool_.empty()) {
        LOG(ERROR) << "Failed to select recv worker for " << ipAddr << ": UCP worker pool is empty";
        return "";
    }

    const size_t poolSize = localWorkerPool_.size();
    const size_t slot = roundRobin_.fetch_add(1, std::memory_order_relaxed) % poolSize;
    VLOG(1) << "Select new ucp worker " << slot << " for " << ipAddr << " to recv";
    const std::string &selectedAddr = localWorkerPool_.at(slot)->GetLocalWorkerAddr();
    localWorkerRecvMap_.emplace(ipAddr, selectedAddr);
    return selectedAddr;
}
/**
 * @brief Drop everything the pool tracks for a peer IP.
 *
 * Erases the IP's entry from the recv map (under an exclusive lock), then asks
 * every worker in the pool to remove its endpoint for that IP. Per-worker
 * failures are logged as warnings and do not stop the loop.
 *
 * @param[in] ipAddr Peer IP address whose state should be removed.
 * @return Always Status::OK(); failures are only logged.
 */
Status UcpWorkerPool::RemoveByIp(const std::string &ipAddr)
{
    {
        std::unique_lock exclusiveGuard(recvMapMutex_);
        const auto erasedCount = localWorkerRecvMap_.erase(ipAddr);
        if (erasedCount == 0) {
            LOG(INFO) << FormatString("Try to remove by IP but never received from %s", ipAddr);
        }
    }

    for (auto &poolEntry : localWorkerPool_) {
        Status removal = poolEntry.second->RemoveEndpointByIp(ipAddr);
        if (!removal.IsError()) {
            continue;
        }
        // Best effort: report the failure and carry on with the remaining workers.
        LOG(WARNING) << FormatString("Try to remove by IP %s on worker %u but failed: %s", ipAddr, poolEntry.first,
                                     removal.ToString());
    }
    return Status::OK();
}
/**
 * @brief Pick the worker to use for the next send.
 *
 * Selection advances the roundRobin_ counter and takes it modulo the pool
 * size; the arguments are currently not used for the choice.
 *
 * @param[in] ipAddr Peer IP address (unused).
 * @param[in] requestID Request identifier (unused).
 * @return Non-owning pointer to the selected worker, or nullptr when the pool
 * is empty or holds no worker under the selected index.
 */
UcpWorker *UcpWorkerPool::GetOrSelSendWorker(const std::string &ipAddr, uint64_t requestID)
{
    (void)ipAddr;
    (void)requestID;
    if (localWorkerPool_.empty()) {
        return nullptr;
    }
    const size_t workerCount = localWorkerPool_.size();
    const size_t workerIndex = roundRobin_.fetch_add(1, std::memory_order_relaxed) % workerCount;
    // Use find() rather than operator[]: operator[] is a mutating call that
    // default-inserts a null entry for a missing key and is not safe to invoke
    // from several threads at once, whereas find() is a pure lookup.
    const auto selected = localWorkerPool_.find(workerIndex);
    if (selected == localWorkerPool_.end()) {
        return nullptr;
    }
    return selected->second.get();
}
/**
 * @brief Reset the pool to an empty state.
 *
 * Clears the IP-to-address recv map while holding recvMapMutex_ exclusively,
 * then clears localWorkerPool_, dropping the pool's shared_ptr references to
 * its workers.
 */
void UcpWorkerPool::Clean()
{
    {
        // The exclusive lock covers only the recv map; it is released before the pool is cleared.
        std::lock_guard<std::shared_mutex> exclusiveGuard(recvMapMutex_);
        localWorkerRecvMap_.clear();
    }
    localWorkerPool_.clear();
}
}