* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "urma_endpoint.h"
#include <algorithm>
#include "endpoint_mgr.h"
#include "log.h"
#include "hccl/hccl_res.h"
#include "urma_mem.h"
#include "adapter_rts_common.h"
#include "server_socket_manager.h"
#include "ra_rs_comm.h"
namespace hcomm {
UrmaEndpoint::UrmaEndpoint(const EndpointDesc &endpointDesc)
: Endpoint(endpointDesc)
{
}
UrmaEndpoint::~UrmaEndpoint() noexcept
{
std::lock_guard<std::mutex> lock(portMutex_);
if (dynamicPort_ != HCCL_INVALID_PORT) {
ServerSocketStopListenImpl(dynamicPort_);
}
dynamicPort_ = HCCL_INVALID_PORT;
}
HcclResult UrmaEndpoint::Init()
{
HCCL_INFO("[%s] localEndpoint protocol[%d]", __func__, endpointDesc_.protocol);
if (endpointDesc_.loc.locType != ENDPOINT_LOC_TYPE_DEVICE){
HCCL_ERROR("[UrmaEndpoint][%s] endpointDesc.loc.locType[%d] only support ENDPOINT_LOC_TYPE_DEVICE", __func__, endpointDesc_.loc.locType);
return HCCL_E_PARA;
}
Hccl::IpAddress ipAddr{};
HcclResult ret = CommAddrToIpAddress(endpointDesc_.commAddr, ipAddr);
if(ret!= HCCL_SUCCESS) {
HCCL_ERROR("call CommAddrToIpAddress failed");
return ret;
}
u32 devPhyId;
s32 deviceLogicId;
ret = hrtGetDevice(&deviceLogicId);
if(ret != HCCL_SUCCESS){
HCCL_ERROR("call hrtGetDevice failed, deviceLogicId[%d]", deviceLogicId);
return ret;
}
Hccl::HccpHdcManager::GetInstance().Init(deviceLogicId);
ret = hrtGetDevicePhyIdByIndex(static_cast<uint32_t>(deviceLogicId), devPhyId);
if(ret!= HCCL_SUCCESS){
HCCL_ERROR("call hrtGetDevicePhyIdByIndex failed, deviceLogicId[%d], devPhyId[%u]",deviceLogicId, devPhyId);
return ret;
}
if (endpointDesc_.loc.device.devPhyId != devPhyId){
HCCL_WARNING("[UrmaEndpoint][%s] endpointDesc.loc.device.devPhyId[%u] incorrect", __func__, endpointDesc_.loc.device.devPhyId);
endpointDesc_.loc.device.devPhyId = devPhyId;
}
auto &rdmaHandleMgr = Hccl::RdmaHandleManager::GetInstance();
EXCEPTION_CATCH(ctxHandle_ = static_cast<void *>(rdmaHandleMgr.GetByIp(endpointDesc_.loc.device.devPhyId, ipAddr)), return HCCL_E_PARA);
CHK_PTR_NULL(ctxHandle_);
HCCL_INFO("%s success, devId[%u], ipAddr[%s], ctxHandle[%p]",
__func__, devPhyId, ipAddr.Describe().c_str(), ctxHandle_);
EXCEPTION_CATCH(this->regedMemMgr_ = std::make_unique<UbRegedMemMgr>(), return HCCL_E_INTERNAL);
this->regedMemMgr_->rdmaHandle_ = this->ctxHandle_;
ccuChannelCtxPool_.reset(new (std::nothrow) CcuChannelCtxPool(deviceLogicId));
CHK_PTR_NULL(ccuChannelCtxPool_);
return HcclResult::HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::ServerSocketListen(const uint32_t port)
{
if (endpointDesc_.loc.locType != ENDPOINT_LOC_TYPE_DEVICE){
HCCL_INFO("[UrmaEndpoint][%s] endpointDesc.loc.locType[%d] skip create ServerSocket", __func__, endpointDesc_.loc.locType);
return HCCL_SUCCESS;
}
Hccl::IpAddress ipaddr{};
CHK_RET(CommAddrToIpAddress(endpointDesc_.commAddr, ipaddr));
HCCL_INFO("[UrmaEndpoint][%s] endpointDesc_.loc.device.devPhyId %u", __func__, endpointDesc_.loc.device.devPhyId);
Hccl::DevNetPortType type = Hccl::DevNetPortType(Hccl::ConnectProtoType::UB);
Hccl::PortData localPort = Hccl::PortData(static_cast<Hccl::RankId>(endpointDesc_.loc.device.devPhyId), type, 0, ipaddr);
HCCL_INFO("[UrmaEndpoint][%s] devicePhyId[%u] localPort[%s]",
__func__,
endpointDesc_.loc.device.devPhyId,
localPort.Describe().c_str()
);
uint32_t requestPort = port;
CHK_RET(ServerSocketManager::GetInstance().ServerSocketStartListen(localPort, Hccl::NicType::DEVICE_NIC_TYPE, endpointDesc_.loc.device.devPhyId, &requestPort));
return HCCL_SUCCESS;
}
inline HcclResult UrmaEndpoint::ServerSocketStopListenImpl(const uint32_t port)
{
Hccl::IpAddress ipAddr{};
CHK_RET(CommAddrToIpAddress(endpointDesc_.commAddr, ipAddr));
Hccl::DevNetPortType type = Hccl::DevNetPortType(Hccl::ConnectProtoType::UB);
Hccl::PortData localPort = Hccl::PortData(static_cast<Hccl::RankId>(endpointDesc_.loc.device.devPhyId), type, 0, ipAddr);
CHK_RET(ServerSocketManager::GetInstance().ServerSocketStopListen(localPort, Hccl::NicType::DEVICE_NIC_TYPE, port));
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::ServerSocketStopListen(const uint32_t port)
{
return ServerSocketStopListenImpl(port);
}
HcclResult UrmaEndpoint::ServerSocketGetListenPort(uint32_t *port)
{
std::lock_guard<std::mutex> lock(portMutex_);
if (endpointDesc_.loc.locType != ENDPOINT_LOC_TYPE_DEVICE){
HCCL_INFO("[UrmaEndpoint][%s] endpointDesc.loc.locType[%d] skip create ServerSocket", __func__, endpointDesc_.loc.locType);
return HCCL_SUCCESS;
}
HCCL_INFO("[UrmaEndpoint][%s] endpointDesc_.loc.device.devPhyId %u", __func__, endpointDesc_.loc.device.devPhyId);
Hccl::IpAddress ipaddr{};
CHK_RET(CommAddrToIpAddress(endpointDesc_.commAddr, ipaddr));
Hccl::DevNetPortType type = Hccl::DevNetPortType(Hccl::ConnectProtoType::UB);
Hccl::PortData localPort = Hccl::PortData(static_cast<Hccl::RankId>(endpointDesc_.loc.device.devPhyId), type, 0, ipaddr);
HCCL_INFO("[UrmaEndpoint][%s] devicePhyId[%u] localPort[%s]",
__func__,
endpointDesc_.loc.device.devPhyId,
localPort.Describe().c_str()
);
if (dynamicPort_ != HCCL_INVALID_PORT) {
*port = dynamicPort_;
HCCL_INFO("[UrmaEndpoint::%s] already listening, return existing port[%u]", __func__, dynamicPort_);
return HCCL_SUCCESS;
}
uint32_t requestPort = 0;
CHK_RET(ServerSocketManager::GetInstance().ServerSocketStartListen(localPort, Hccl::NicType::DEVICE_NIC_TYPE, endpointDesc_.loc.device.devPhyId, &requestPort));
if (requestPort == 0 || requestPort == HCCL_INVALID_PORT) {
HCCL_ERROR("[UrmaEndpoint::%s] get listen port failed, port is invalid.", __func__);
return HCCL_E_NETWORK;
}
dynamicPort_ = requestPort;
*port = dynamicPort_;
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::RegisterMemory(HcommMem mem, const char *memTag, void **memHandle)
{
CHK_RET(this->regedMemMgr_->RegisterMemory(mem, memTag, memHandle));
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::UnregisterMemory(void* memHandle)
{
CHK_RET(this->regedMemMgr_->UnregisterMemory(memHandle));
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::MemoryExport(void *memHandle, void **memDesc, uint32_t *memDescLen)
{
CHK_RET(this->regedMemMgr_->MemoryExport(this->endpointDesc_, memHandle, memDesc, memDescLen));
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::MemoryImport(const void *memDesc, uint32_t descLen, HcommMem *outMem)
{
CHK_RET(this->regedMemMgr_->MemoryImport(memDesc, descLen, outMem));
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::MemoryUnimport(const void *memDesc, uint32_t descLen)
{
CHK_RET(this->regedMemMgr_->MemoryUnimport(memDesc, descLen));
return HCCL_SUCCESS;
}
HcclResult UrmaEndpoint::GetAllMemHandles(void **memHandles, uint32_t *memHandleNum)
{
CHK_RET(this->regedMemMgr_->GetAllMemHandles(memHandles, memHandleNum));
return HCCL_SUCCESS;
}
CcuChannelCtxPool *UrmaEndpoint::GetCcuChannelCtxPool()
{
return ccuChannelCtxPool_.get();
}
HcclResult UrmaEndpoint::GetAsyncEvents(uint32_t devPhyId, struct AsyncEvent events[], uint32_t &num)
{
uint32_t interfaceVersion{0};
int ret;
ret = RaGetInterfaceVersion(devPhyId, RA_RS_CTX_GET_ASYNC_EVENTS, &interfaceVersion);
if (ret != 0) {
HCCL_ERROR("[%s] devPhyId[%u] RaGetInterfaceVersion failed, ret[%d]", __func__, devPhyId, ret);
return HCCL_E_INTERNAL;
}
if (interfaceVersion <= 1) {
HCCL_ERROR("[%s] devPhyId[%u] version[%u] not support", __func__, devPhyId, interfaceVersion);
return HCCL_E_NOT_SUPPORT;
}
CHK_PTR_NULL(ctxHandle_);
ret = RaCtxGetAsyncEvents(ctxHandle_, events, &num);
if (ret != 0) {
HCCL_ERROR("[%s] devPhyId[%u] RaCtxGetAsyncEvents failed, ctxHandle[%p] ret[%d]", __func__, devPhyId,
(void *)ctxHandle_, ret);
return HCCL_E_INTERNAL;
}
return HCCL_SUCCESS;
}
}