* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "endpoint_mgr.h"
#include "endpoint.h"
#include "aicpu_ts_roce_endpoint.h"
#include "cpu_roce_endpoint.h"
#include "urma_endpoint.h"
#include "ub_mem_endpoint.h"
#include "uboe_endpoint.h"
#include "cpu_urma_endpoint.h"
#include "aicputs_hccs_endpoint.h"
#include "hccp_nda.h"
#include "adapter_rts_common.h"
namespace hcomm{
/**
 * @brief Pre-filters an endpoint descriptor before the factory tries to build it.
 *
 * Accepts the descriptor only if its protocol is one of ROCE, UBC_TP, UBC_CTP,
 * UB_MEM, PCIE, UBOE or HCCS, and its location type is DEVICE or HOST.
 * Note that this checks each field on its own; the factory
 * (CreateEndpointBase) still rejects unsupported protocol/location pairs.
 *
 * @param endpointDesc Descriptor to inspect.
 * @return true if both the protocol and the location type are recognized.
 */
static bool IsSupported(const EndpointDesc &endpointDesc)
{
    // Any protocol outside this list is rejected immediately.
    switch (endpointDesc.protocol) {
        case COMM_PROTOCOL_ROCE:
        case COMM_PROTOCOL_UBC_TP:
        case COMM_PROTOCOL_UBC_CTP:
        case COMM_PROTOCOL_UB_MEM:
        case COMM_PROTOCOL_PCIE:
        case COMM_PROTOCOL_UBOE:
        case COMM_PROTOCOL_HCCS:
            break;
        default:
            return false;
    }

    // The protocol is known; the outcome now depends only on the location type.
    switch (endpointDesc.loc.locType) {
        case ENDPOINT_LOC_TYPE_DEVICE:
        case ENDPOINT_LOC_TYPE_HOST:
            return true;
        default:
            return false;
    }
}
/**
 * @brief Constructs an endpoint bound to the given descriptor.
 *
 * The descriptor is copied into endpointDesc_ through the member-initializer
 * list. The member is initialized directly, rather than default-constructed
 * and then assigned.
 *
 * @param endpointDesc Descriptor describing the protocol, location and address
 *                     of this endpoint; stored by value.
 */
Endpoint::Endpoint(const EndpointDesc &endpointDesc) : endpointDesc_(endpointDesc)
{
}
/**
 * @brief Public factory entry point: validates the descriptor, then builds the endpoint.
 *
 * @param endpointDesc Descriptor selecting the protocol and location.
 * @param endpointPtr  Output. On success it receives the newly created endpoint.
 *                     If the descriptor fails IsSupported(), it is left untouched.
 * @return HCCL_E_PARA if the descriptor fails IsSupported(); otherwise whatever
 *         CreateEndpointBase returns.
 */
HcclResult Endpoint::CreateEndpoint(const EndpointDesc &endpointDesc, std::unique_ptr<Endpoint> &endpointPtr)
{
    const bool descAccepted = IsSupported(endpointDesc);
    if (descAccepted) {
        HCCL_INFO("[%s]endpointDesc.protocol [%d] endpointDesc.loc.locType [%d].",
            __func__, endpointDesc.protocol, endpointDesc.loc.locType);
        return CreateEndpointBase(endpointDesc, endpointPtr);
    }

    HCCL_ERROR("[%s]endpointDesc is not supported. endpointDesc.protocol [%d] endpointDesc.loc.locType [%d].",
        __func__, endpointDesc.protocol, endpointDesc.loc.locType);
    return HCCL_E_PARA;
}
/**
 * @brief Instantiates the concrete Endpoint subclass for a (protocol, location) pair.
 *
 * Supported pairs:
 *   ROCE            + HOST   -> CpuRoceEndpoint
 *   ROCE            + DEVICE -> AicpuTsRoceEndpoint
 *   UBC_TP / UBC_CTP + HOST   -> CpuUrmaEndpoint
 *   UBC_TP / UBC_CTP + DEVICE -> UrmaEndpoint
 *   UB_MEM          + DEVICE -> UbMemEndpoint
 *   PCIE            + DEVICE -> UbMemEndpoint
 *   UBOE            + DEVICE -> UboeEndpoint
 *   HCCS            + DEVICE -> AicpuTsHccsEndpoint
 *
 * @param endpointDesc Descriptor selecting the implementation; also forwarded to its constructor.
 * @param endpointPtr  Output. It receives the new endpoint on success, and is reset to
 *                     nullptr when the pair is not in the table above.
 * @return HCCL_SUCCESS on success. HCCL_E_PTR if construction raised an exception
 *         (caught by EXCEPTION_CATCH). HCCL_E_PARA for an unsupported pair.
 */
HcclResult Endpoint::CreateEndpointBase(const EndpointDesc &endpointDesc, std::unique_ptr<Endpoint> &endpointPtr)
{
    const bool onHost = (endpointDesc.loc.locType == ENDPOINT_LOC_TYPE_HOST);
    const bool onDevice = (endpointDesc.loc.locType == ENDPOINT_LOC_TYPE_DEVICE);

    switch (endpointDesc.protocol) {
        case COMM_PROTOCOL_ROCE: {
            if (onHost) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<CpuRoceEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            if (onDevice) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<AicpuTsRoceEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            break;
        }
        case COMM_PROTOCOL_UBC_TP:
        case COMM_PROTOCOL_UBC_CTP: {
            if (onHost) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<CpuUrmaEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            if (onDevice) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<UrmaEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            break;
        }
        case COMM_PROTOCOL_UB_MEM: {
            if (onDevice) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<UbMemEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            break;
        }
        case COMM_PROTOCOL_PCIE: {
            // NOTE(review): PCIE on device reuses UbMemEndpoint; no dedicated PCIE
            // endpoint class is included in this file. Confirm this is intentional.
            if (onDevice) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<UbMemEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            break;
        }
        case COMM_PROTOCOL_UBOE: {
            if (onDevice) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<UboeEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            break;
        }
        case COMM_PROTOCOL_HCCS: {
            if (onDevice) {
                EXCEPTION_CATCH(endpointPtr = std::make_unique<AicpuTsHccsEndpoint>(endpointDesc), return HCCL_E_PTR);
                return HCCL_SUCCESS;
            }
            break;
        }
        default:
            break;
    }

    // Reaching here means no row of the table matched: clear the output and report.
    endpointPtr = nullptr;
    HCCL_ERROR("[%s] failed, endpointDesc.protocol [%d] and endpointDesc.loc.locType [%d] do not match.",
        __func__, endpointDesc.protocol, endpointDesc.loc.locType);
    return HCCL_E_PARA;
}
/**
 * @brief Queries whether an endpoint descriptor supports an optional feature.
 *
 * Currently only HCOMM_ENDPOINT_FEATURE_NDA is evaluated, and only for
 * ROCE endpoints located on the host. For those, the RDMA handle bound to the
 * descriptor's address (on the physical device backing the current device) is
 * looked up, and its NDA direct flag is read. Any other feature type, or any
 * other protocol/location, yields value == false with HCCL_SUCCESS.
 *
 * @param endpointDesc Descriptor to query.
 * @param featureType  Feature being asked about.
 * @param value        Output: true if the feature is supported. It is set to false
 *                     before any other work, so it is well-defined on every
 *                     return path, including error returns.
 * @return HCCL_SUCCESS when the query completed (whatever the answer). An error
 *         code from address conversion or device lookup. HCCL_E_PTR-style failure
 *         from CHK_PTR_NULL when no RDMA handle is found. HCCL_E_INTERNAL if
 *         reading the direct flag fails.
 */
HcclResult Endpoint::CheckFeature(const EndpointDesc &endpointDesc, HcommEndpointFeatureType featureType, bool &value)
{
    // Pessimistic default: the early error returns below (CHK_RET / CHK_PTR_NULL /
    // CHK_PRT_RET) previously left the caller's flag unassigned.
    value = false;

    if (featureType != HCOMM_ENDPOINT_FEATURE_NDA) {
        HCCL_WARNING("[%s] unsupported featureType[%d]", __func__, featureType);
        return HCCL_SUCCESS;
    }

    // NDA is only ever reported for host-side RoCE endpoints.
    if (endpointDesc.protocol != COMM_PROTOCOL_ROCE || endpointDesc.loc.locType != ENDPOINT_LOC_TYPE_HOST) {
        HCCL_WARNING("[%s] not support NDA, protocol[%d], locType[%d]",
            __func__, endpointDesc.protocol, endpointDesc.loc.locType);
        return HCCL_SUCCESS;
    }

    Hccl::IpAddress ipAddr{};
    CHK_RET(CommAddrToIpAddress(endpointDesc.commAddr, ipAddr));

    // Resolve the physical id of the device current to this thread; the RDMA
    // handle manager is keyed by physical device id.
    s32 devId = 0;
    CHK_RET(hrtGetDevice(&devId));
    u32 devPhyId = 0;
    CHK_RET(hrtGetDevicePhyIdByIndex(devId, devPhyId));

    auto &rdmaHandleMgr = Hccl::RdmaHandleManager::GetInstance();
    void *rdmaHandle = static_cast<void *>(
        rdmaHandleMgr.GetByAddr(devPhyId, Hccl::LinkProtoType::RDMA, ipAddr, Hccl::PortDeploymentType::HOST_NET));
    CHK_PTR_NULL(rdmaHandle);

    s32 directFlag = 0;
    s32 ret = RaNdaGetDirectFlag(rdmaHandle, &directFlag);
    CHK_PRT_RET(ret != HCCL_SUCCESS,
        HCCL_ERROR("[%s] failed to get directFlag, ret[%d]", __func__, ret), HCCL_E_INTERNAL);

    value = (directFlag != DIRECT_FLAG_NOTSUPP);
    HCCL_INFO("[%s] %s NDA, rdmaHandle[%p], directFlag[%d]",
        __func__, value ? "support" : "not support", rdmaHandle, directFlag);
    return HCCL_SUCCESS;
}
}