* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "domain_socket.h"

#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <string>
#include <system_error>
#include <thread>
#include <utility>

#include "umask_guard.h"
#include "securec.h"
#include "file_system.h"
#include "log.h"
namespace {
// umask installed (through UmaskGuard) while the server socket is bound.
// 0177 clears the owner's execute bit and every group/other bit, so any
// filesystem entry created by bind() ends up with at most mode 0600.
constexpr mode_t SOCK_UMASK{0177};
}
namespace Sanitizer {
/**
 * @brief Stores the socket name; no descriptor is created until CreateSocket() runs.
 *
 * @param socketPath name used to build the socket address. It is taken by value
 *        (sink parameter) and moved into socketPath_ instead of being copied again.
 */
DomainSocket::DomainSocket(std::string socketPath) : socketPath_(std::move(socketPath)) { }
/**
 * @brief Closes the socket descriptor held in sfd_, if one was ever created.
 */
DomainSocket::~DomainSocket()
{
    if (sfd_ == -1) {
        return;
    }
    close(sfd_);
}
/**
 * @brief Creates the AF_UNIX stream socket (sfd_) and fills in its address (addr_).
 *
 * The address uses the Linux abstract namespace. addr_ is zeroed, so sun_path[0]
 * stays '\0', and socketPath_ is copied starting at sun_path[1]. At most
 * sizeof(sun_path) - 2 bytes of the path are used; longer paths are silently truncated.
 * The socket gets a 1 second receive timeout (SO_RCVTIMEO) and has SO_PASSCRED enabled.
 *
 * A descriptor left over from an earlier call is closed first. If any later setup step
 * fails, the new descriptor is closed again, so sfd_ is -1 whenever an error is returned.
 *
 * @return Result carrying a message that describes the failing step, if any.
 */
Result DomainSocket::CreateSocket()
{
    Result result;
    // Calling this twice (e.g. a client retrying Connect()) must not leak the old descriptor.
    if (sfd_ != -1) {
        close(sfd_);
        sfd_ = -1;
    }
    sfd_ = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sfd_ == -1) {
        std::error_condition ec(errno, std::generic_category());
        result.SetError("socket failed. " + ec.message());
        return result;
    }
    // Undo socket() when a later step fails, leaving the object in its "no socket" state.
    auto releaseSocket = [this]() {
        close(sfd_);
        sfd_ = -1;
    };
    // Zeroing the whole struct also guarantees sun_path[0] == '\0' (abstract address).
    if (memset_s(&addr_, sizeof(addr_), 0, sizeof(addr_)) != 0) {
        releaseSocket();
        result.SetError("socket struct memset_s zero failed.");
        return result;
    }
    addr_.sun_family = AF_UNIX;
    // One byte is reserved for the leading '\0' of the abstract name, one for a trailing '\0'.
    constexpr size_t reservedSunPathBytes = 2;
    socketPath_.copy(addr_.sun_path + 1,
        std::min(sizeof(addr_.sun_path) - reservedSunPathBytes, socketPath_.size()));
    timeval timeout{};
    timeout.tv_sec = 1;
    timeout.tv_usec = 0;
    if (setsockopt(sfd_, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) == -1) {
        // Capture errno before close() can overwrite it.
        std::error_condition ec(errno, std::generic_category());
        releaseSocket();
        result.SetError("setsockopt failed. " + ec.message());
        return result;
    }
    int opt = 1;
    if (setsockopt(sfd_, SOL_SOCKET, SO_PASSCRED, &opt, sizeof(opt)) == -1) {
        std::error_condition ec(errno, std::generic_category());
        releaseSocket();
        result.SetError("Socket set SO_PASSCRED failed: " + ec.message());
        return result;
    }
    return result;
}
/**
 * @brief Removes the filesystem entry named by addr_.sun_path; a missing entry is not an error.
 *
 * After CreateSocket() the address is an abstract one (sun_path[0] == '\0'). In that case
 * remove() receives an empty string and fails with ENOENT, so in practice nothing is removed.
 *
 * @return Result carrying "remove failed. <reason>" for any failure other than ENOENT.
 */
Result DomainSocket::Clean(void)
{
    Result result;
    const bool removeFailed = remove(addr_.sun_path) == -1 && errno != ENOENT;
    if (!removeFailed) {
        return result;
    }
    std::error_condition removeErr(errno, std::generic_category());
    result.SetError("remove failed. " + removeErr.message());
    return result;
}
/**
 * @brief Constructs a server bound (later) to socketPath that registers at most
 *        maxClientNum clients in Accept().
 *
 * socketPath is a sink parameter and is moved into the base class instead of copied.
 */
DomainSocketServer::DomainSocketServer(std::string socketPath, std::size_t maxClientNum)
    : DomainSocket(std::move(socketPath)), maxClientNum_(maxClientNum) { }
/**
 * @brief Closes every registered client descriptor and unlinks addr_.sun_path.
 *
 * The listening socket itself (sfd_) is closed afterwards by ~DomainSocket(). For an
 * abstract address (sun_path[0] == '\0'), unlink() gets an empty path and has no effect.
 */
DomainSocketServer::~DomainSocketServer(void)
{
    for (std::size_t slot = 0; slot < cfds_.size(); ++slot) {
        close(cfds_[slot]);
    }
    unlink(addr_.sun_path);
}
/**
 * @brief Creates the server socket, binds it to addr_ and starts listening (backlog 1).
 *
 * bind() runs while SOCK_UMASK is installed by a UmaskGuard. The guard is scoped to
 * bind() alone, so the previous umask is back in place before listen() is called.
 *
 * @return The CreateSocket() result on failure; otherwise an error for a failed bind()
 *         or listen(), or success.
 */
Result DomainSocketServer::ListenAndBind()
{
    Result status = CreateSocket();
    if (status.Fail()) {
        return status;
    }
    {
        UmaskGuard restrictiveUmask(SOCK_UMASK);
        auto *rawAddr = reinterpret_cast<sockaddr *>(&addr_);
        if (bind(sfd_, rawAddr, sizeof(addr_)) == -1) {
            std::error_condition bindErr(errno, std::generic_category());
            status.SetError("bind failed. " + bindErr.message());
            return status;
        }
    }
    constexpr int backlog = 1;
    if (listen(sfd_, backlog) == -1) {
        std::error_condition listenErr(errno, std::generic_category());
        status.SetError("listen failed. " + listenErr.message());
    }
    return status;
}
/**
 * @brief Accepts one pending connection and registers it as a client.
 *
 * The peer's credentials (SO_PEERCRED) must match this process's real uid and gid.
 * A connection that cannot be verified or fails the check is closed rather than kept open.
 *
 * @param[out] id index of the new client in cfds_, for use with Read()/Write().
 *             Only written on success.
 * @return Result carrying an error when the client limit is reached, accept() fails,
 *         the peer credentials cannot be read, or the credential check fails.
 */
Result DomainSocketServer::Accept(ClientId &id)
{
    Result result;
    {
        // cfds_ is modified under cfdsMutex_, so its size is read under the same lock.
        std::lock_guard<std::mutex> guard(cfdsMutex_);
        if (cfds_.size() >= maxClientNum_) {
            // errno is unrelated here, so it is not part of the message.
            result.SetError("over max client num. limit=" + std::to_string(maxClientNum_));
            return result;
        }
    }
    int32_t cfd = accept(sfd_, nullptr, nullptr);
    if (cfd == -1) {
        std::error_condition ec(errno, std::generic_category());
        result.SetError("accept failed. " + ec.message());
        return result;
    }
    struct ucred cred{};
    socklen_t credLen = sizeof(cred);
    if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &cred, &credLen) == -1) {
        // Capture errno before close() can overwrite it.
        std::error_condition ec(errno, std::generic_category());
        close(cfd);  // not registered, so nothing else would ever close it
        result.SetError("get client SO_PEERCRED failed: " + ec.message());
        return result;
    }
    if (getuid() != cred.uid || getgid() != cred.gid) {
        // Drop the unauthorized peer instead of leaking an open connection to it.
        close(cfd);
        result.SetError("client SO_PEERCRED check permission failed, recv id: uid=" + std::to_string(cred.uid) +
            ", gid=" + std::to_string(cred.gid));
        return result;
    }
    {
        std::lock_guard<std::mutex> guard(cfdsMutex_);
        id = cfds_.size();
        cfds_.push_back(cfd);
    }
    return result;
}
/**
 * @brief Returns the number of entries currently held in cfds_.
 *
 * NOTE(review): unlike Accept()/Read()/Write(), this reads cfds_ without cfdsMutex_.
 * If it can run concurrently with Accept() that is a data race; confirm the calling
 * threads (locking here needs cfdsMutex_ to be mutable).
 */
std::size_t DomainSocketServer::GetClientNum() const
{
    const std::size_t registered = cfds_.size();
    return registered;
}
/**
 * @brief Performs a single read() of up to maxBytes from a registered client.
 *
 * A successful read of 0 bytes returns success with receivedBytes == 0. For a stream
 * socket this means end of stream (peer closed), unless maxBytes itself is 0.
 *
 * @param id client index returned by Accept().
 * @param[out] message replaced with the received bytes on success; untouched on failure.
 * @param maxBytes upper bound on the number of bytes read.
 * @param[out] receivedBytes number of bytes read; 0 whenever an error is returned.
 * @return Result carrying an error for an unknown/closed client id or a failed read().
 */
Result DomainSocketServer::Read(ClientId id, std::string &message, size_t maxBytes, size_t &receivedBytes)
{
    Result result;
    // Keep the out-parameter well-defined on every error path.
    receivedBytes = 0;
    int32_t cfd = -1;
    {
        std::lock_guard<std::mutex> guard(cfdsMutex_);
        if (id >= cfds_.size() || cfds_[id] == -1) {
            result.SetError("invalid client id ");
            return result;
        }
        cfd = cfds_[id];
    }
    std::vector<char> buffer(maxBytes);
    ssize_t ret = read(cfd, buffer.data(), maxBytes);
    if (ret == -1) {
        std::error_condition ec(errno, std::generic_category());
        result.SetError("recv failed. " + ec.message());
        return result;
    }
    receivedBytes = static_cast<size_t>(ret);
    message.assign(buffer.data(), receivedBytes);
    return result;
}
/**
 * @brief Sends the whole message to a registered client, looping over partial sends.
 *
 * send() is called with MSG_NOSIGNAL, so writing to a client that has already closed
 * its end fails with EPIPE instead of raising SIGPIPE, whose default action terminates
 * the process. A call interrupted by a signal before sending anything (EINTR) is retried.
 *
 * @param id client index returned by Accept().
 * @param message bytes to send.
 * @param[out] sentBytes bytes actually sent. Equals message.size() on success; may be
 *             smaller when an error is returned.
 * @return Result carrying an error for an unknown/closed client id or a failed send().
 */
Result DomainSocketServer::Write(ClientId id, const std::string &message, size_t &sentBytes)
{
    Result result;
    sentBytes = 0;
    int32_t cfd = -1;
    {
        std::lock_guard<std::mutex> guard(cfdsMutex_);
        if (id >= cfds_.size() || cfds_[id] == -1) {
            result.SetError("invalid client id ");
            return result;
        }
        cfd = cfds_[id];
    }
    const char *cursor = message.data();
    size_t remaining = message.size();
    while (remaining > 0) {
        ssize_t ret = send(cfd, cursor, remaining, MSG_NOSIGNAL);
        if (ret == -1) {
            if (errno == EINTR) {
                continue;  // interrupted before any data was transferred; retry
            }
            std::error_condition ec(errno, std::generic_category());
            result.SetError("write failed. " + ec.message());
            break;
        }
        auto written = static_cast<size_t>(ret);
        sentBytes += written;
        remaining -= written;
        cursor += written;
    }
    return result;
}
/**
 * @brief Constructs a client for socketPath; the connection is made later by Connect().
 *
 * socketPath is a sink parameter and is moved into the base class instead of copied.
 */
DomainSocketClient::DomainSocketClient(std::string socketPath) : DomainSocket(std::move(socketPath)) {}
// Nothing client-specific to release; the socket descriptor is closed by ~DomainSocket().
DomainSocketClient::~DomainSocketClient() = default;
/**
 * @brief Creates the client socket and connects it to the server address in addr_.
 *
 * On any failure the socket is closed and sfd_ is reset to -1. As a result, Read()/Write()
 * report "connect failed." instead of operating on an unconnected socket, and a later
 * Connect() retry does not leak the descriptor from this attempt.
 *
 * @return The CreateSocket() result on failure; otherwise an error for a failed
 *         connect(), or success.
 */
Result DomainSocketClient::Connect()
{
    // Closes a socket that could not be connected so it is neither leaked nor used.
    auto dropSocket = [this]() {
        if (sfd_ != -1) {
            close(sfd_);
            sfd_ = -1;
        }
    };
    Result result = CreateSocket();
    if (result.Fail()) {
        // Report before close() so errno is still the one left by CreateSocket().
        std::cerr<<"Error in create socket:"<<strerror(errno)<<"(errno:"<<errno<<")"<<std::endl;
        dropSocket();
        return result;
    }
    if (connect(sfd_, reinterpret_cast<sockaddr *>(&addr_), sizeof(addr_)) == -1) {
        std::error_condition ec(errno, std::generic_category());
        dropSocket();
        result.SetError("connect failed. " + ec.message());
        return result;
    }
    return result;
}
/**
 * @brief Performs a single read() of up to maxBytes from the connected server.
 *
 * CreateSocket() sets SO_RCVTIMEO to 1 second on this socket. A read that sees no data
 * within that time is therefore expected to fail with EAGAIN/EWOULDBLOCK (reported as
 * "read failed. ..."). A successful read of 0 bytes returns success with
 * receivedBytes == 0.
 *
 * @param[out] message replaced with the received bytes on success; untouched on failure.
 * @param maxBytes upper bound on the number of bytes read.
 * @param[out] receivedBytes number of bytes read; 0 whenever an error is returned.
 * @return Result carrying an error when there is no socket or read() fails.
 */
Result DomainSocketClient::Read(std::string &message, uint64_t maxBytes, size_t &receivedBytes)
{
    Result result;
    // Keep the out-parameter well-defined on every error path.
    receivedBytes = 0;
    if (sfd_ == -1) {
        result.SetError("connect failed. ");
        return result;
    }
    std::vector<char> buffer(maxBytes);
    ssize_t ret = read(sfd_, buffer.data(), maxBytes);
    if (ret == -1) {
        std::error_condition ec(errno, std::generic_category());
        result.SetError("read failed. " + ec.message());
        return result;
    }
    receivedBytes = static_cast<size_t>(ret);
    message.assign(buffer.data(), receivedBytes);
    return result;
}
/**
 * @brief Sends the whole message to the server, looping over partial sends.
 *
 * send() is called with MSG_NOSIGNAL, so writing after the server has closed the
 * connection fails with EPIPE instead of raising SIGPIPE, whose default action
 * terminates the process. A call interrupted by a signal before sending anything
 * (EINTR) is retried.
 *
 * @param message bytes to send.
 * @param[out] sentBytes bytes actually sent. Equals message.size() on success; may be
 *             smaller when an error is returned.
 * @return Result carrying an error when there is no socket or send() fails.
 */
Result DomainSocketClient::Write(const std::string &message, size_t &sentBytes)
{
    Result result;
    sentBytes = 0;
    if (sfd_ == -1) {
        result.SetError("connect failed. ");
        return result;
    }
    const char *cursor = message.data();
    size_t remaining = message.size();
    while (remaining > 0) {
        ssize_t ret = send(sfd_, cursor, remaining, MSG_NOSIGNAL);
        if (ret == -1) {
            if (errno == EINTR) {
                continue;  // interrupted before any data was transferred; retry
            }
            std::error_condition ec(errno, std::generic_category());
            result.SetError("write failed. " + ec.message());
            break;
        }
        auto written = static_cast<size_t>(ret);
        sentBytes += written;
        remaining -= written;
        cursor += written;
    }
    return result;
}
}