* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* Description: fd pass over scmtcp test.
*/
#include "datasystem/common/rpc/unix_sock_fd.h"
#include "datasystem/common/util/fd_pass.h"
#include <gtest/gtest.h>
#include <sys/mman.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include <securec.h>
#include "ut/common.h"
#include "datasystem/common/util/format.h"
#include "datasystem/common/util/status_helper.h"
namespace datasystem {
namespace ut {
// Payload the server writes into its shared-memory file; the receiving side reads it back and compares.
constexpr const char *SHM_CONTENT = "Hello, Shared Memory!";
// Size in bytes of the shared-memory file and of every mapping of it.
constexpr size_t MMAP_SIZE = 1024;
/**
 * Test-side SCM TCP client.
 *
 * Connect() opens a connection to tcp://127.0.0.1:<port>. The resulting descriptor is owned by this
 * object and closed in the destructor; copying is disabled so the descriptor can never be closed twice.
 */
class ScmTcpSocketClient {
public:
    /**
     * @param port TCP port on 127.0.0.1 that the server listens on.
     */
    explicit ScmTcpSocketClient(uint32_t port) : tcpPort_(port)
    {
    }

    ~ScmTcpSocketClient()
    {
        // 0 is a valid descriptor; only -1 means "not connected".
        if (serverFd_ >= 0) {
            RETRY_ON_EINTR(close(serverFd_));
        }
    }

    ScmTcpSocketClient(const ScmTcpSocketClient &) = delete;
    ScmTcpSocketClient &operator=(const ScmTcpSocketClient &) = delete;

    /**
     * Connects to the server, applies STUB_FRONTEND_TIMEOUT to the socket, and then sleeps for two seconds.
     * Failures are reported through DS_ASSERT_OK.
     */
    void Connect()
    {
        auto endpoint = FormatString("tcp://%s:%d", "127.0.0.1", tcpPort_);
        UnixSockFd sock(RPC_NO_FILE_FD, true);
        DS_ASSERT_OK(sock.Connect(endpoint));
        // Take ownership immediately after a successful connect so that a SetTimeout failure
        // (which returns early) does not leak the descriptor.
        serverFd_ = sock.GetFd();
        DS_ASSERT_OK(sock.SetTimeout(STUB_FRONTEND_TIMEOUT));
        // Presumably gives the server's background accept thread time to publish the new connection
        // before callers query the server side.
        const unsigned int waitTimeSec = 2;
        sleep(waitTimeSec);
    }

    /** @return the connected descriptor, or -1 if Connect() has not succeeded. */
    int32_t GetServerFd() const
    {
        return serverFd_;
    }

private:
    int32_t serverFd_ = -1;
    uint32_t tcpPort_;
};
/**
 * Test-side SCM TCP server.
 *
 * Init() binds a listening socket on tcp://127.0.0.1:<port>, creates an unlinked /dev/shm file holding
 * SHM_CONTENT, and starts a background thread that accepts connections. GetClientFd() returns the most
 * recently accepted connection (-1 until one arrives). Every descriptor the server opened is closed in
 * the destructor.
 */
class ScmTcpSocketServer {
public:
    /**
     * @param port TCP port on 127.0.0.1 to listen on.
     */
    explicit ScmTcpSocketServer(uint32_t port) : tcpPort_(port)
    {
    }

    ~ScmTcpSocketServer()
    {
        // Stop and join the accept thread first so no new descriptor appears while they are being closed.
        stop_ = true;
        if (acceptTrd_.joinable()) {
            acceptTrd_.join();
        }
        // acceptedFds_ is written only by the accept thread, which has been joined above.
        for (int fd : acceptedFds_) {
            CloseIfValid(fd);
        }
        acceptedFds_.clear();
        clientFd_ = -1;
        CloseIfValid(listenFd_);
        CloseIfValid(fileFd_);
    }

    /**
     * Binds the listening socket, prepares the shared-memory file and starts the accept thread.
     * Skips the current test when binding fails with errno == EINVAL (treated as "SCM TCP not supported");
     * any other failure is reported as a fatal test failure.
     */
    void Init()
    {
        UnixSockFd scmSockFd(RPC_NO_FILE_FD, true);
        std::string tmp;
        auto rc = scmSockFd.Bind(FormatString("tcp://%s:%d", "127.0.0.1", tcpPort_), RPC_SOCK_MODE, tmp);
        if (rc.IsError()) {
            if (errno == EINVAL) {
                GTEST_SKIP() << "SCM TCP not supported on this platform";
            }
            ASSERT_TRUE(false) << "Failed to bind SCM TCP socket: " << rc.ToString();
        }
        listenFd_ = scmSockFd.GetFd();

        std::string fileTemplate = "/dev/shm/test-XXXXXX";
        std::vector<char> fileName(fileTemplate.begin(), fileTemplate.end());
        fileName.push_back('\0');
        fileFd_ = mkstemp(fileName.data());
        ASSERT_GE(fileFd_, 0);
        LOG(INFO) << "File descriptor: " << fileFd_;
        // Unlink right away: the file stays alive only through fileFd_ and any descriptors passed from it.
        ASSERT_EQ(unlink(fileName.data()), 0);
        ASSERT_EQ(ftruncate(fileFd_, static_cast<off_t>(MMAP_SIZE)), 0);

        void *space = mmap(nullptr, MMAP_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fileFd_, 0);
        ASSERT_TRUE(space != MAP_FAILED);
        int copyRc = memcpy_s(space, MMAP_SIZE, SHM_CONTENT, strlen(SHM_CONTENT));
        // With MAP_SHARED the write lands in the file itself, so the mapping is not needed afterwards.
        (void)munmap(space, MMAP_SIZE);
        ASSERT_EQ(copyRc, EOK);

        acceptTrd_ = std::thread(&ScmTcpSocketServer::Start, this);
    }

    /** @return the most recently accepted connection, or -1 if none has been accepted yet. */
    int GetClientFd() const
    {
        return clientFd_;
    }

    /** @return the descriptor of the shared-memory file holding SHM_CONTENT, or -1 if it was not created. */
    int GetFileFd() const
    {
        return fileFd_;
    }

private:
    // Closes fd unless it is the -1 "not open" sentinel (0 is a valid descriptor).
    static void CloseIfValid(int fd)
    {
        if (fd >= 0) {
            RETRY_ON_EINTR(close(fd));
        }
    }

    // Accept loop run on acceptTrd_: waits up to RPC_POLL_TIME per poll() so stop_ is re-checked regularly.
    void Start()
    {
        while (!stop_) {
            struct pollfd pfd{ .fd = listenFd_, .events = POLLIN, .revents = 0 };
            int n = poll(&pfd, 1, RPC_POLL_TIME);
            if (n <= 0) {
                continue;
            }
            int fd = accept(listenFd_, nullptr, nullptr);
            if (fd >= 0) {
                // Remember every accepted fd so none leaks when a later connection replaces clientFd_.
                acceptedFds_.push_back(fd);
                clientFd_ = fd;
            }
        }
    }

    std::thread acceptTrd_;
    std::atomic<int> clientFd_{ -1 };
    // All descriptors returned by accept(); written only by acceptTrd_, closed in the destructor after join().
    std::vector<int> acceptedFds_;
    int listenFd_ = -1;
    std::atomic<bool> stop_{ false };
    int fileFd_ = -1;
    uint32_t tcpPort_;
};
class FdPassOverTcpTest : public CommonTest {
public:
void SetUp() override
{
GetFreePort();
StartServer();
}
void TearDown() override
{
}
void StartServer()
{
server_ = std::make_unique<ScmTcpSocketServer>(port_);
server_->Init();
}
void GetFreePort()
{
int sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
ASSERT_GE(sockfd, 0);
struct sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
addr.sin_port = 0;
ASSERT_EQ(bind(sockfd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)), 0);
socklen_t addrLen = sizeof(addr);
ASSERT_EQ(getsockname(sockfd, reinterpret_cast<struct sockaddr *>(&addr), &addrLen), 0);
port_ = ntohs(addr.sin_port);
LOG(INFO) << "Free port: " << port_;
RETRY_ON_EINTR(close(sockfd));
}
protected:
std::unique_ptr<ScmTcpSocketServer> server_;
uint32_t port_ = 0;
const uint64_t requestId_ = 1;
};
/**
 * Sends the server's shared-memory file descriptor to a connected client over SCM TCP and checks that the
 * client receives exactly one descriptor with the matching request id and can map SHM_CONTENT from it.
 */
TEST_F(FdPassOverTcpTest, TestBasicFunction)
{
    LOG(INFO) << "Test fd pass over scmtcp basic function.";
    ScmTcpSocketClient client(port_);
    client.Connect();
    int clientFd = server_->GetClientFd();
    ASSERT_GE(clientFd, 0);
    DS_ASSERT_OK(SockSendFd(clientFd, true, { server_->GetFileFd() }, requestId_));

    std::vector<int> recvFds;
    // Received descriptors are new fds owned by this test; close them when the test body returns,
    // including when an ASSERT below returns early.
    struct FdCloser {
        std::vector<int> &fds;
        ~FdCloser()
        {
            for (int fd : fds) {
                if (fd >= 0) {
                    RETRY_ON_EINTR(close(fd));
                }
            }
        }
    };
    FdCloser closer{ recvFds };

    uint64_t recvRequestId = 0;
    DS_ASSERT_OK(SockRecvFd(client.GetServerFd(), true, recvFds, recvRequestId));
    ASSERT_EQ(recvRequestId, requestId_);
    ASSERT_EQ(recvFds.size(), 1u);
    ASSERT_GE(recvFds[0], 0);

    void *mapped = mmap(nullptr, MMAP_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, recvFds[0], 0);
    ASSERT_NE(mapped, MAP_FAILED);
    std::string msg(static_cast<const char *>(mapped), strlen(SHM_CONTENT));
    (void)munmap(mapped, MMAP_SIZE);
    ASSERT_EQ(msg, SHM_CONTENT);
    LOG(INFO) << "Test fd pass over scmtcp basic function done.";
}
/**
 * Expects SockSendFd to fail when asked to pass an arbitrary integer that is not meant to be an open fd.
 * NOTE(review): the socket argument is -1, so the call may fail on the invalid socket before the payload
 * value is ever examined — confirm this exercises the intended error path.
 */
TEST_F(FdPassOverTcpTest, TestPassRandomInteger)
{
    LOG(INFO) << "Test fd pass with a random integer (not a fd).";
    int invalidSock = -1;
    const int notAnFd = 438;
    DS_ASSERT_NOT_OK(SockSendFd(invalidSock, false, { notAnFd }, requestId_));
    LOG(INFO) << "Test fd pass with a random integer (not a fd) done.";
}
/**
 * Connects a client, lets it close its end of the connection, and then expects sending the shared-memory
 * fd on the server-side socket to fail.
 */
TEST_F(FdPassOverTcpTest, TestSendFdToDisconnectedClient)
{
    LOG(INFO) << "Test send fd to disconnected client.";
    {
        // The client closes its descriptor when it is destroyed at the end of this scope.
        ScmTcpSocketClient shortLivedClient(port_);
        shortLivedClient.Connect();
    }
    // Short pause (100 ms) before sending, presumably so the peer close has reached the server side.
    const uint64_t settleMicros = 100'000;
    usleep(settleMicros);
    int acceptedFd = server_->GetClientFd();
    ASSERT_GE(acceptedFd, 0);
    DS_ASSERT_NOT_OK(SockSendFd(acceptedFd, false, { server_->GetFileFd() }, requestId_));
    LOG(INFO) << "Test send fd to disconnected client done.";
}
}
}