* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <algorithm>
#include <cerrno>
#include <cstring>
#include <map>
#include <memory>
#include <regex>
#include <string>
#include <vector>
#include "shmemi_string_util.h"
#include "store_net_common.h"
#include "host/shmem_host_def.h"
#include "shmemi_logger.h"
namespace shm {
namespace store {
// URL scheme prefixes recognised by ExtractTcpURL: "tcp://" selects IPv4, "tcp6://" selects IPv6.
const std::string PROTOCOL_TCP4 = "tcp://";
const std::string PROTOCOL_TCP6 = "tcp6://";
// Address family implied by the scheme of a parsed URL; IPNONE means "no scheme recognised".
enum class PROTOCOLTYPE {
PROTOCOLV4,
PROTOCOLV6,
IPNONE,
};
// File-scope parser state: ExtractTcpURL writes it and IsValidIp reads it, so IsValidIp's verdict
// depends on the most recently parsed URL.
// NOTE(review): this is unsynchronized mutable state; if URLs can be parsed on several threads at once
// the value may be overwritten between ExtractTcpURL and IsValidIp — confirm the threading model.
static PROTOCOLTYPE type = PROTOCOLTYPE::IPNONE;
/**
 * Split `src` at the LAST occurrence of `sep` and append the pieces to `out` (existing elements are kept).
 *
 * Plain input "head<sep>tail" appends "head" and then, only when non-empty, "tail". Input that does not
 * contain `sep` is appended whole (when non-empty).
 *
 * Input starting with '[' is treated as a bracketed host such as "[::1]:8080". It is split only when
 * `sep[0]` occurs at least twice and the character immediately before the last `sep` is ']'. The brackets
 * are stripped from the head ("::1"), and the tail is appended when non-empty. Any other bracketed input
 * appends nothing, so callers that expect exactly two parts reject it instead of receiving a mangled host.
 *
 * @param src input text
 * @param sep separator; intended to be a single character (":" and "/" are used in this file)
 * @param out receives the pieces, appended in order
 */
inline void Split(const std::string &src, const std::string &sep, std::vector<std::string> &out)
{
    // rfind (not find_last_of) so the split position and the tail offset agree even for multi-char seps.
    const std::string::size_type sepPos = sep.empty() ? std::string::npos : src.rfind(sep);
    const bool bracketed = !src.empty() && src[0] == '[';
    if (!bracketed) {
        if (sepPos == std::string::npos) {
            if (!src.empty()) {
                out.emplace_back(src);
            }
            return;
        }
        out.emplace_back(src.substr(0, sepPos));
    } else {
        // A bracketed IPv6 host carries at least one separator of its own besides the port separator.
        constexpr int minSepCount = 2;
        if (sepPos == std::string::npos || std::count(src.begin(), src.end(), sep[0]) < minSepCount) {
            return;
        }
        // Require "[...]<sep>"; otherwise stripping one character from each side would corrupt the host.
        if (sepPos < 1 || src[sepPos - 1] != ']') {
            return;
        }
        // Drop the leading '[' and the trailing ']'.
        out.emplace_back(src.substr(1, sepPos - 2));
    }
    const std::string::size_type tailPos = sepPos + sep.size();
    if (tailPos < src.size()) {
        out.emplace_back(src.substr(tailPos));
    }
}
/**
 * Length sanity check of an IP string against the address family of the last URL parsed by ExtractTcpURL.
 *
 * Only the length is checked; the textual format itself is not validated here.
 *
 * @param address textual IP address
 * @return true when the protocol is IPv4 and `address` has at most 15 characters, or the protocol is IPv6
 *         and `address` has at most 45 characters; false otherwise (including when no protocol is set).
 */
bool IsValidIp(const std::string &address)
{
    // Longest textual forms excluding the terminating NUL (INET_ADDRSTRLEN - 1 / INET6_ADDRSTRLEN - 1).
    constexpr size_t maxIpLenV4 = 15;  // "255.255.255.255"
    constexpr size_t maxIpLenV6 = 45;  // "ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255"
    switch (type) {
        case PROTOCOLTYPE::PROTOCOLV4:
            return address.size() <= maxIpLenV4;
        case PROTOCOLTYPE::PROTOCOLV6:
            return address.size() <= maxIpLenV6;
        default:
            return false;
    }
}
/**
 * Parse "tcp://<ip>[/<mask>]:<port>" or "tcp6://<ip>[/<mask>]:<port>" (IPv6 hosts may be bracketed,
 * e.g. "tcp6://[::1]:8080").
 *
 * On success, writes details["ip"] and details["port"], plus details["mask"] when a "/mask" suffix is
 * present, and records the scheme's family in the file-scope `type` read by IsValidIp.
 * On failure, `details` is left untouched and `type` is reset to PROTOCOLTYPE::IPNONE.
 *
 * @param url     URL to parse
 * @param details receives the extracted fields
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INVALID_PARAM for an unknown scheme or malformed host/port/mask
 */
Result ExtractTcpURL(const std::string &url, std::map<std::string, std::string> &details)
{
    // Forget any earlier call's protocol so a failed parse never leaves a stale family for IsValidIp.
    type = PROTOCOLTYPE::IPNONE;

    PROTOCOLTYPE parsedType = PROTOCOLTYPE::IPNONE;
    std::string hostPort;
    // "tcp6://" is tested first; "tcp://" is not a prefix of it, so the order is only for readability.
    if (url.compare(0, PROTOCOL_TCP6.size(), PROTOCOL_TCP6) == 0) {
        parsedType = PROTOCOLTYPE::PROTOCOLV6;
        hostPort = url.substr(PROTOCOL_TCP6.size());
    } else if (url.compare(0, PROTOCOL_TCP4.size(), PROTOCOL_TCP4) == 0) {
        parsedType = PROTOCOLTYPE::PROTOCOLV4;
        hostPort = url.substr(PROTOCOL_TCP4.size());
    } else {
        return ACLSHMEM_INVALID_PARAM;
    }

    std::vector<std::string> splits;
    Split(hostPort, ":", splits);
    if (splits.size() != UN2) {
        return ACLSHMEM_INVALID_PARAM;
    }
    std::string ip = splits[0];
    const std::string port = splits[1];

    std::string mask;
    const bool hasMask = ip.find('/') != std::string::npos;
    if (hasMask) {
        std::vector<std::string> ipAndMask;
        Split(ip, "/", ipAndMask);
        if (ipAndMask.size() != UN2) {
            return ACLSHMEM_INVALID_PARAM;
        }
        ip = ipAndMask[0];
        mask = ipAndMask[1];
    }

    // Publish results only after the whole URL has been validated.
    details["port"] = port;
    details["ip"] = ip;
    if (hasMask) {
        details["mask"] = mask;
    }
    type = parsedType;
    return ACLSHMEM_SUCCESS;
}
/**
 * Parse `url` ("tcp://ip[/mask]:port" or "tcp6://ip[/mask]:port") and store the host in `ip` and the
 * port in `port`.
 *
 * @param url URL to parse
 * @return ACLSHMEM_SUCCESS on success. Otherwise returns ExtractTcpURL's error when the URL cannot be
 *         parsed, or ACLSHMEM_INVALID_PARAM when the port is not a number, the port is outside
 *         (N1024, UINT16_MAX], or the ip fails IsValidIp. Members are only updated on success.
 */
Result UrlExtraction::ExtractIpPortFromUrl(const std::string &url)
{
    std::map<std::string, std::string> details;
    const auto result = ExtractTcpURL(url, details);
    // Never consult `details` after a failed parse: it may be incomplete. An absent "ip" would read back
    // as "" and could still pass the length-only IsValidIp check.
    if (result != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("Invalid url :" << url << ", extract ip and port failed");
        return result;
    }
    const std::string ipStr = details["ip"];
    const std::string portStr = details["port"];
    long tmpPort = 0;
    if (!StrToLong(portStr, tmpPort)) {
        SHM_LOG_ERROR("Invalid portStr :" << portStr << ", which is invalid");
        return ACLSHMEM_INVALID_PARAM;
    }
    // Ports up to and including N1024 are rejected, as are values that do not fit in 16 bits.
    if (!IsValidIp(ipStr) || tmpPort <= N1024 || tmpPort > UINT16_MAX) {
        SHM_LOG_ERROR("Invalid ipStr :" << ipStr << " or port :" << tmpPort << ", which is invalid");
        return ACLSHMEM_INVALID_PARAM;
    }
    ip = ipStr;
    port = tmpPort;
    return ACLSHMEM_SUCCESS;
}
/**
 * Convert a matching local IPv6 address to text and record it as the selected local address.
 *
 * @param localIp       binary IPv6 address of the matching interface
 * @param localResultIp scratch buffer of `size` bytes that receives the textual address
 * @param size          capacity of `localResultIp`
 * @param ipaddr        on success, receives the IpV6 tag and the 16 raw address bytes
 * @param local         on success, receives the textual address
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_SMEM_ERROR when inet_ntop fails (ipaddr and local are not modified)
 */
static Result GetLocalIpWithTargetWhenIpv6(struct in6_addr &localIp, char *localResultIp, int size,
    mf_ip_addr &ipaddr, std::string &local)
{
    const char *converted = inet_ntop(AF_INET6, &localIp, localResultIp, size);
    if (converted == nullptr) {
        SHM_LOG_ERROR("convert local ipv6 to string failed. ");
        return ACLSHMEM_SMEM_ERROR;
    }
    ipaddr.type = IpV6;
    std::copy(std::begin(localIp.s6_addr), std::end(localIp.s6_addr), std::begin(ipaddr.addr.addrv6));
    local.assign(localResultIp);
    return ACLSHMEM_SUCCESS;
}
/**
 * Convert a matching local IPv4 address to text and record it as the selected local address.
 *
 * @param localIp       binary IPv4 address (network byte order) of the matching interface
 * @param localResultIp scratch buffer of `size` bytes that receives the textual address
 * @param size          capacity of `localResultIp`
 * @param ipaddr        on success, receives the IpV4 tag and the address converted with ntohl
 * @param local         on success, receives the textual address
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_SMEM_ERROR when inet_ntop fails (ipaddr and local are not modified)
 */
static Result GetLocalIpWithTargetWhenIpv4(struct in_addr &localIp, char *localResultIp, int size,
    mf_ip_addr &ipaddr, std::string &local)
{
    const char *converted = inet_ntop(AF_INET, &localIp, localResultIp, size);
    if (converted == nullptr) {
        SHM_LOG_ERROR("convert local ipv4 to string failed. ");
        return ACLSHMEM_SMEM_ERROR;
    }
    ipaddr.type = IpV4;
    // addrv4 is stored in host byte order, unlike in_addr::s_addr.
    ipaddr.addr.addrv4 = ntohl(localIp.s_addr);
    local.assign(localResultIp);
    return ACLSHMEM_SUCCESS;
}
/**
 * Parse `target` as IPv4 first, then as IPv6.
 *
 * @param target     textual address to parse
 * @param targetIpV4 receives the binary address when `target` is IPv4
 * @param targetIpV6 receives the binary address when `target` is IPv6 (only attempted if IPv4 parsing fails)
 * @param isTargetV6 set to the detected family on success; left unchanged on failure
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INVALID_PARAM when `target` is neither family
 */
static Result DetermineTargetIpType(const std::string &target, struct in_addr &targetIpV4,
    struct in6_addr &targetIpV6, bool &isTargetV6)
{
    const char *text = target.c_str();
    const bool parsedAsV4 = inet_pton(AF_INET, text, &targetIpV4) == 1;
    // Short-circuit: the IPv6 parse is only tried when the IPv4 parse failed.
    if (!parsedAsV4 && inet_pton(AF_INET6, text, &targetIpV6) != 1) {
        SHM_LOG_ERROR("target ip address invalid.");
        return ACLSHMEM_INVALID_PARAM;
    }
    isTargetV6 = !parsedAsV4;
    return ACLSHMEM_SUCCESS;
}
/**
 * Test whether `targetIp` lies in the IPv6 network described by `localIp` / `localMask`.
 *
 * @return true when every bit selected by `localMask` is equal in `localIp` and `targetIp`
 */
static bool IsSameNetwork(const struct in6_addr &localIp, const struct in6_addr &localMask,
    const struct in6_addr &targetIp)
{
    const std::size_t byteCount = sizeof(localIp.s6_addr);
    std::size_t idx = 0;
    while (idx < byteCount) {
        // A differing bit that the mask keeps means the two addresses are in different networks.
        const unsigned int differingBits =
            static_cast<unsigned int>(localIp.s6_addr[idx] ^ targetIp.s6_addr[idx]);
        if ((differingBits & localMask.s6_addr[idx]) != 0U) {
            return false;
        }
        ++idx;
    }
    return true;
}
/**
 * Pick the local interface address that lies in the same subnet as `target`.
 *
 * `target` is parsed as IPv4 first, then IPv6, and only interfaces of that family are considered.
 * Interfaces are scanned in getifaddrs() order. The first one whose address matches `target` under
 * that interface's netmask is used.
 *
 * @param target textual IPv4/IPv6 address of the peer
 * @param local  receives the textual form of the selected local address
 * @param ipaddr receives the family tag and binary form of the selected local address
 * @return ACLSHMEM_SUCCESS when an address was selected. ACLSHMEM_INVALID_PARAM when `target` is not a
 *         valid address. ACLSHMEM_SMEM_ERROR when the interfaces cannot be listed, no interface shares a
 *         subnet with `target`, or the matching address cannot be converted to text.
 */
Result GetLocalIpWithTarget(const std::string &target, std::string &local, mf_ip_addr &ipaddr)
{
    const int SIZE = 64;
    char localResultIp[SIZE];
    bool isTargetV6 = false;
    struct in_addr targetIpV4 {};
    struct in6_addr targetIpV6 {};
    if (DetermineTargetIpType(target, targetIpV4, targetIpV6, isTargetV6) != ACLSHMEM_SUCCESS) {
        return ACLSHMEM_INVALID_PARAM;
    }

    struct ifaddrs *ifaddr = nullptr;
    if (getifaddrs(&ifaddr) == -1) {
        // Capture errno before logging, since the logging path may itself modify errno.
        const int savedErrno = errno;
        SHM_LOG_ERROR("get local net interfaces failed: " << savedErrno << ": " << strerror(savedErrno));
        return ACLSHMEM_SMEM_ERROR;
    }
    // Owns the interface list so it is released on every return path below.
    const std::unique_ptr<struct ifaddrs, decltype(&freeifaddrs)> ifaddrGuard(ifaddr, &freeifaddrs);

    for (auto ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) {
        if ((ifa->ifa_addr == nullptr) || (ifa->ifa_netmask == nullptr)) {
            continue;
        }
        // Families other than AF_INET/AF_INET6 (e.g. link-layer entries) match neither branch.
        const auto family = ifa->ifa_addr->sa_family;
        if (!isTargetV6 && family == AF_INET) {
            auto localIp = reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr)->sin_addr;
            auto localMask = reinterpret_cast<struct sockaddr_in *>(ifa->ifa_netmask)->sin_addr;
            if ((localIp.s_addr & localMask.s_addr) == (targetIpV4.s_addr & localMask.s_addr)) {
                return GetLocalIpWithTargetWhenIpv4(localIp, localResultIp, SIZE, ipaddr, local);
            }
        } else if (isTargetV6 && family == AF_INET6) {
            auto localIp = reinterpret_cast<struct sockaddr_in6 *>(ifa->ifa_addr)->sin6_addr;
            auto localMask = reinterpret_cast<struct sockaddr_in6 *>(ifa->ifa_netmask)->sin6_addr;
            if (IsSameNetwork(localIp, localMask, targetIpV6)) {
                return GetLocalIpWithTargetWhenIpv6(localIp, localResultIp, SIZE, ipaddr, local);
            }
        }
    }
    SHM_LOG_ERROR("no local net interface shares a network with target ip: " << target);
    return ACLSHMEM_SMEM_ERROR;
}
}
}