#include <sys/epoll.h>
#include <sys/prctl.h>
#include <sys/signalfd.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/timerfd.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>

#include <algorithm>
#include <array>
#include <cerrno>
#include <csignal>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <iostream>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>
// Signals that sigConf::block_signals() removes from the otherwise full
// blocking mask, i.e. the only signals that stay deliverable to this process.
// NOTE(review): presumably kept unblocked because they report faults raised by
// the process itself - confirm.
constexpr std::array<int, 9> unblock_signals{
    SIGABRT,
    SIGBUS,
    SIGFPE,
    SIGILL,
    SIGSEGV,
    SIGSYS,
    SIGTRAP,
    SIGXCPU,
    SIGXFSZ,
};
namespace {
// Result of reaping one child process with waitpid().
// Both members are zero-initialized so that a default-constructed object never
// carries indeterminate values.
struct WaitPidResult
{
    // Process id the status belongs to.
    pid_t pid{};
    // Raw wait status as written by waitpid(); decode it with WIFEXITED,
    // WEXITSTATUS, WIFSIGNALED and WTERMSIG.
    int status{};
};
// Writes "<msg>: <description of errno>" followed by a newline to stderr.
// errno is left unchanged for the caller.
void print_sys_error(std::string_view msg) noexcept
{
    // Snapshot errno before touching the stream: the insertions below are
    // evaluated left to right and may issue system calls that overwrite errno
    // before it would be formatted.
    const int saved_errno = errno;
    std::cerr << msg << ": " << ::strerror(saved_errno) << std::endl;
    errno = saved_errno;
}
// Writes "<msg>: <ec.message()>" followed by a newline to stderr.
void print_sys_error(std::string_view msg, const std::error_code &ec) noexcept
{
    const std::string reason = ec.message();
    std::cerr << msg << ": " << reason << std::endl;
}
// Writes msg and a newline to stderr, but only when the environment variable
// LINYAPS_INIT_VERBOSE_OUTPUT is set (to any value). The environment is
// inspected once, on the first call.
void print_info(std::string_view msg) noexcept
{
    static const bool verbose = [] {
        return ::getenv("LINYAPS_INIT_VERBOSE_OUTPUT") != nullptr;
    }();

    if (!verbose) {
        return;
    }

    std::cerr << msg << std::endl;
}
// Manages the process signal mask.
//
// block_signals() blocks every signal except those listed in unblock_signals
// and remembers the mask that was active before; restore_signals() re-installs
// that remembered mask (run() and delegate_run() call it in a forked child
// before exec). The object is neither copyable nor movable.
class sigConf
{
public:
    sigConf() noexcept = default;
    sigConf(const sigConf &) = delete;
    sigConf(sigConf &&) = delete;
    sigConf &operator=(const sigConf &) = delete;
    sigConf &operator=(sigConf &&) = delete;
    ~sigConf() noexcept = default;

    // Installs the blocking mask and saves the previous one.
    // Returns false (after logging) when sigprocmask() fails.
    [[nodiscard]] bool block_signals() noexcept
    {
        ::sigfillset(&cur_set);
        for (const auto signal : unblock_signals) {
            ::sigdelset(&cur_set, signal);
        }

        if (::sigprocmask(SIG_SETMASK, &cur_set, &old_set) == -1) {
            print_sys_error("Failed to set signal mask");
            return false;
        }

        return true;
    }

    // Re-installs the mask saved by block_signals().
    // Returns false (after logging) when sigprocmask() fails.
    [[nodiscard]] bool restore_signals() const noexcept
    {
        if (::sigprocmask(SIG_SETMASK, &old_set, nullptr) == -1) {
            print_sys_error("Failed to restore signal mask");
            return false;
        }

        return true;
    }

    // The set of signals blocked by block_signals(); main() hands it to
    // signalfd() so that those signals can be read from a descriptor.
    [[nodiscard]] const sigset_t &current_sigset() const noexcept { return cur_set; }

private:
    sigset_t cur_set{}; // mask installed by block_signals()
    sigset_t old_set{}; // mask that was active before block_signals()
};
// Move-only owner of a file descriptor.
//
// The value -1 means "no descriptor". The owned descriptor is closed when the
// wrapper is destroyed, when close() is called, or when another wrapper is
// move-assigned over it. Converts implicitly to the raw descriptor and
// explicitly to bool (true when a descriptor is held).
class file_descriptor_wrapper
{
public:
    file_descriptor_wrapper() noexcept = default;

    explicit file_descriptor_wrapper(int fd) noexcept
        : handle{ fd }
    {
    }

    file_descriptor_wrapper(const file_descriptor_wrapper &) = delete;
    file_descriptor_wrapper &operator=(const file_descriptor_wrapper &) = delete;

    file_descriptor_wrapper(file_descriptor_wrapper &&other) noexcept
        : handle{ other.release() }
    {
    }

    file_descriptor_wrapper &operator=(file_descriptor_wrapper &&other) noexcept
    {
        if (this != &other) {
            close();
            handle = other.release();
        }
        return *this;
    }

    ~file_descriptor_wrapper() noexcept { close(); }

    // Closes the held descriptor, if any, and marks the wrapper empty.
    void close() noexcept
    {
        if (handle == invalid_fd) {
            return;
        }
        ::close(handle);
        handle = invalid_fd;
    }

    explicit operator bool() const noexcept { return handle != invalid_fd; }

    operator int() const noexcept { return handle; }

private:
    // Gives up ownership without closing and returns the raw descriptor.
    int release() noexcept
    {
        const int taken = handle;
        handle = invalid_fd;
        return taken;
    }

    static constexpr int invalid_fd{ -1 };
    int handle{ invalid_fd };
};
// Creates a non-blocking signalfd that reports the signals in sigset.
// Returns an empty wrapper (after logging) when signalfd() fails.
file_descriptor_wrapper create_signalfd(const sigset_t &sigset) noexcept
{
    // SFD_CLOEXEC: programs exec'd later (delegate_run forks and execs while
    // this descriptor is open) must not inherit it.
    const int fd = ::signalfd(-1, &sigset, SFD_NONBLOCK | SFD_CLOEXEC);
    if (fd == -1) {
        print_sys_error("Failed to create signalfd");
    }
    return file_descriptor_wrapper(fd);
}
// Creates a non-blocking timerfd on CLOCK_MONOTONIC.
// Returns an empty wrapper (after logging) when timerfd_create() fails.
file_descriptor_wrapper create_timerfd() noexcept
{
    // TFD_CLOEXEC: programs exec'd later (delegate_run forks and execs while
    // this descriptor is open) must not inherit it.
    const int fd = ::timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC);
    if (fd == -1) {
        print_sys_error("Failed to create timerfd");
    }
    return file_descriptor_wrapper(fd);
}
// Creates an epoll instance.
// Returns an empty wrapper (after logging) when epoll_create1() fails.
file_descriptor_wrapper create_epoll() noexcept
{
    // EPOLL_CLOEXEC: programs exec'd later (delegate_run forks and execs while
    // this descriptor is open) must not inherit it.
    const int fd = ::epoll_create1(EPOLL_CLOEXEC);
    if (fd == -1) {
        print_sys_error("Failed to create epoll");
    }
    return file_descriptor_wrapper(fd);
}
// Copies a string literal into a std::array<char, N - 1>, i.e. without the
// trailing element of the source array (the NUL terminator of a literal).
// Usable in constant expressions.
template <std::size_t N>
constexpr auto make_array(const char (&str)[N]) noexcept
{
    static_assert(N > 0, "N must be greater than 0");

    constexpr std::size_t length = N - 1;
    std::array<char, length> result{};

    std::size_t pos = 0;
    while (pos != length) {
        result[pos] = str[pos];
        ++pos;
    }

    return result;
}
// Builds the AF_UNIX address of the control socket at
// "/run/linglong/init/socket".
// Returns the address together with its length: the offset of sun_path plus
// the number of path characters (the NUL terminator is stored in sun_path but
// not counted in the length).
std::pair<struct sockaddr_un, socklen_t> get_socket_address() noexcept
{
    constexpr auto socket_path = make_array("/run/linglong/init/socket");
    constexpr std::size_t path_len = socket_path.size();

    struct sockaddr_un address{};
    address.sun_family = AF_UNIX;
    for (std::size_t i = 0; i != path_len; ++i) {
        address.sun_path[i] = socket_path[i];
    }
    address.sun_path[path_len] = 0;

    const auto address_len =
      static_cast<socklen_t>(offsetof(sockaddr_un, sun_path) + path_len);
    return { address, address_len };
}
// Creates the listening control socket: a non-blocking AF_UNIX SOCK_SEQPACKET
// socket bound to the path from get_socket_address(), with a backlog of 1.
//
// Returns the listening socket on success. On ANY failure (socket, address,
// bind or listen) the error is logged and an empty wrapper is returned, so the
// caller can detect the failure with `if (!result)`; a half-initialised socket
// is never handed out.
file_descriptor_wrapper create_fs_uds() noexcept
{
    // SOCK_CLOEXEC: programs exec'd by delegate_run must not inherit the
    // listening socket.
    file_descriptor_wrapper socket_fd{
        ::socket(AF_UNIX, SOCK_NONBLOCK | SOCK_CLOEXEC | SOCK_SEQPACKET, 0)
    };
    if (!socket_fd) {
        print_sys_error("Failed to create unix domain socket");
        return {};
    }

    auto [addr, len] = get_socket_address();
    if (len == 0) {
        print_info("Failed to get socket address");
        return {};
    }

    // Best effort: remove a socket file left at the same path so that bind()
    // does not fail because the path already exists.
    ::unlink(addr.sun_path);

    if (::bind(socket_fd, reinterpret_cast<struct sockaddr *>(&addr), len) == -1) {
        print_sys_error("Failed to bind unix domain socket");
        return {};
    }

    if (::listen(socket_fd, 1) == -1) {
        print_sys_error("Failed to listen on unix domain socket");
        return {};
    }

    return socket_fd;
}
// Converts the command line into an exec-style argument vector: argv[1] up to
// argv[argc - 1] in order, followed by a terminating nullptr. argv[0] is
// skipped, so the result always contains at least the nullptr.
std::vector<const char *> parse_args(int argc, char *argv[]) noexcept
{
    std::vector<const char *> result;
    if (argc > 1) {
        result.insert(result.end(), argv + 1, argv + argc);
    }
    result.push_back(nullptr);
    return result;
}
// Logs (through print_info) how the child identified by `pid` terminated:
// its exit status, the signal that terminated it, or "unknown status" when the
// wait status is neither.
void print_child_status(int status, const std::string &pid) noexcept
{
    std::string outcome;
    if (WIFEXITED(status)) {
        outcome = " exited with status " + std::to_string(WEXITSTATUS(status));
    } else if (WIFSIGNALED(status)) {
        outcome = " exited with signal " + std::to_string(WTERMSIG(status));
    } else {
        outcome = " exited with unknown status";
    }

    print_info("child " + pid + outcome);
}
pid_t run(std::vector<const char *> args, const sigConf &conf) noexcept
{
auto pid = ::fork();
if (pid == -1) {
print_sys_error("Failed to fork");
return -1;
}
if (pid == 0) {
auto ret = ::setpgid(0, 0);
if (ret == -1) {
print_sys_error("Failed to set process group");
return -1;
}
ret = ::tcsetpgrp(0, ::getpid());
if (ret == -1 && errno != ENOTTY) {
print_sys_error("Failed to set terminal process group");
return -1;
}
if (!conf.restore_signals()) {
return -1;
}
::execvp(args[0], const_cast<char *const *>(args.data()));
print_sys_error("Failed to run process");
::_exit(EXIT_FAILURE);
}
return pid;
}
// Drains the (non-blocking) signalfd and reacts to every pending signal.
//
// - Signals other than SIGCHLD that carry a sender pid are forwarded to
//   `child` with kill(); those with ssi_pid == 0 are only logged and ignored.
// - On SIGCHLD every terminated child is reaped with waitpid(-1, WNOHANG).
//   When the reaped pid equals `child`, its pid and raw wait status are stored
//   in waitChild.
//
// Returns false (after logging) when reading the signalfd or waitpid() fails
// unexpectedly, true once the descriptor has been drained.
//
// NOTE(review): `child` is passed to kill() unchanged; a value of -1 makes
// kill() address every process the caller may signal - confirm that callers
// intend this when they pass -1.
bool handle_sigevent(const file_descriptor_wrapper &sigfd,
                     pid_t child,
                     struct WaitPidResult &waitChild) noexcept
{
    while (true) {
        signalfd_siginfo info{};
        const auto bytes = ::read(sigfd, &info, sizeof(info));
        if (bytes == -1) {
            if (errno == EINTR) {
                continue;
            }
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                break; // descriptor drained
            }
            print_sys_error("Failed to read from signalfd");
            return false;
        }

        // Never interpret a partially filled record.
        if (static_cast<std::size_t>(bytes) != sizeof(info)) {
            print_info("Short read from signalfd");
            return false;
        }

        if (info.ssi_signo != SIGCHLD) {
            if (info.ssi_pid != 0) {
                // Sent by a process: pass it on to the child.
                if (::kill(child, info.ssi_signo) == -1) {
                    auto msg =
                      std::string("Failed to forward signal ") + ::strsignal(info.ssi_signo);
                    print_sys_error(msg);
                }
            } else {
                print_info("Received signal " + std::to_string(info.ssi_signo)
                           + " from kernel, just ignore it");
            }
            continue;
        }

        // One SIGCHLD may stand for several terminated children: reap them all.
        while (true) {
            int status{};
            const auto reaped = ::waitpid(-1, &status, WNOHANG);
            if (reaped == 0 || (reaped == -1 && errno == ECHILD)) {
                break; // nothing (more) to reap right now
            }

            if (reaped == -1) {
                print_sys_error("Failed to wait for child");
                return false;
            }

            print_child_status(status, std::to_string(reaped));
            if (reaped == child) {
                waitChild.pid = child;
                waitChild.status = status;
            }
        }
    }

    return true;
}
// Reports whether any process other than pid 1 is visible in /proc.
//
// Scans the directories in /proc whose name consists only of decimal digits
// and returns true as soon as one of them is not "1". Returns false when no
// such entry exists, or (after logging) when /proc cannot be opened or read.
// Entries that cannot be inspected (for example because the process vanished
// during the scan) are skipped.
//
// The scan uses only the error_code based filesystem API so that no
// filesystem exception can escape from this noexcept function.
bool shouldWait() noexcept
{
    std::error_code ec;
    std::filesystem::directory_iterator proc_it{
        "/proc",
        std::filesystem::directory_options::skip_permission_denied,
        ec
    };
    if (ec) {
        print_sys_error("Failed to open /proc", ec);
        return false;
    }

    const auto is_pid_name = [](const std::string &name) noexcept {
        if (name.empty()) {
            return false;
        }
        for (const char c : name) {
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    };

    const std::filesystem::directory_iterator end_it{};
    while (proc_it != end_it) {
        std::error_code entry_ec;
        // is_directory() yields false when the entry cannot be inspected.
        if (proc_it->is_directory(entry_ec)) {
            const std::string name = proc_it->path().filename().string();
            if (is_pid_name(name) && name != "1") {
                return true;
            }
        }

        proc_it.increment(ec);
        if (ec) {
            print_sys_error("Failed to read /proc", ec);
            return false;
        }
    }

    return false;
}
// Drains the (non-blocking) timerfd and then re-evaluates shouldWait().
// Returns -1 (after logging) when reading the timerfd fails, 1 when
// shouldWait() is true and 0 when it is false.
int handle_timerfdevent(const file_descriptor_wrapper &timerfd) noexcept
{
    uint64_t expirations{};
    for (;;) {
        if (::read(timerfd, &expirations, sizeof(expirations)) != -1) {
            continue; // keep reading until the descriptor is empty
        }

        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            print_sys_error("Failed to read from timerfd");
            return -1;
        }
        break;
    }

    if (shouldWait()) {
        return 1;
    }
    return 0;
}
// Creates a periodic timerfd and registers it (edge-triggered, EPOLLIN) with
// the epoll instance epfd.
//
// The first expiration always happens after 3 seconds. The repeat period is
// 3 seconds by default and can be overridden with the environment variable
// LINYAPS_INIT_WAIT_INTERVAL (whole seconds). Values that cannot be parsed or
// are not positive fall back to the default: a period of 0 would arm a
// one-shot timer and a negative one is rejected by timerfd_settime().
//
// Returns the timer descriptor, or an empty wrapper (after logging) when
// creating, arming or registering the timer fails.
file_descriptor_wrapper start_timer(const file_descriptor_wrapper &epfd) noexcept
{
    auto timerfd = create_timerfd();
    if (!timerfd) {
        return timerfd;
    }

    constexpr int default_interval{ 3 };
    int interval_sec{ default_interval };
    if (const auto *interval = ::getenv("LINYAPS_INIT_WAIT_INTERVAL"); interval != nullptr) {
        int parsed{ 0 };
        try {
            parsed = std::stoi(interval);
        } catch (...) {
            parsed = 0; // treated as invalid below
        }

        if (parsed > 0) {
            interval_sec = parsed;
        } else {
            print_info("Invalid interval, using default 3 seconds");
        }
    }

    struct itimerspec timer_spec{};
    timer_spec.it_interval.tv_sec = interval_sec;
    timer_spec.it_interval.tv_nsec = 0;
    timer_spec.it_value.tv_sec = default_interval;
    timer_spec.it_value.tv_nsec = 0;

    auto ret = ::timerfd_settime(timerfd, 0, &timer_spec, nullptr);
    if (ret == -1) {
        print_sys_error("Failed to set timerfd");
        return {};
    }

    struct epoll_event timer_ev{ .events = EPOLLIN | EPOLLET, .data = { .fd = timerfd } };
    ret = ::epoll_ctl(epfd, EPOLL_CTL_ADD, timerfd, &timer_ev);
    if (ret == -1) {
        print_sys_error("Failed to add timerfd to epoll");
        return {};
    }

    return timerfd;
}
// Upper bound, in bytes, that handle_client() accepts for a command:
// sysconf(_SC_ARG_MAX) minus 4096, or 256 KiB when sysconf() reports -1.
unsigned long get_arg_max() noexcept
{
    constexpr unsigned long fallback_limit = 256UL * 1024UL;
    constexpr long reserved = 4096;

    const long limit = ::sysconf(_SC_ARG_MAX);
    return limit == -1 ? fallback_limit : static_cast<unsigned long>(limit - reserved);
}
// Forks a child that restores the saved signal mask and executes args
// (args[0] is the program, looked up via PATH by execvp).
//
// Returns 0 in the parent once the fork succeeded and -1 (after logging) when
// fork() fails; whether the exec itself succeeds is not reported to the
// caller. The child never returns: it either becomes the new program or
// terminates with EXIT_FAILURE.
int delegate_run(const std::vector<std::string> &args, const sigConf &conf) noexcept
{
    const auto pid = ::fork();
    if (pid == -1) {
        print_sys_error("Failed to fork child");
        return -1;
    }

    if (pid != 0) {
        return 0;
    }

    // Child: build the nullptr-terminated argv that execvp() expects.
    std::vector<char *> argv;
    argv.reserve(args.size() + 1);
    for (const std::string &arg : args) {
        argv.push_back(const_cast<char *>(arg.c_str()));
    }
    argv.push_back(nullptr);

    if (conf.restore_signals()) {
        ::execvp(argv[0], argv.data());
        // Only reached when exec failed.
        print_sys_error("Failed to exec");
    }

    ::_exit(EXIT_FAILURE);
}
// Accepts one client on the control socket, reads a command from it and runs
// the command through delegate_run().
//
// Wire format as read here: first a message holding the command length
// (received into a zero-initialised 64-bit integer), then the command itself
// as NUL-separated arguments, possibly spread over several messages. Receiving
// times out after 3 seconds. The result of delegate_run() is sent back to the
// client as a native ssize_t.
//
// Returns true when a command was received and handed to delegate_run()
// (even if delegating failed), false when the request could not be read or
// was invalid.
bool handle_client(const file_descriptor_wrapper &unix_socket, const sigConf &conf) noexcept
{
    static const unsigned long arg_max = get_arg_max();

    // SOCK_CLOEXEC: the delegated program must not inherit the client
    // connection.
    const file_descriptor_wrapper client{
        ::accept4(unix_socket, nullptr, nullptr, SOCK_CLOEXEC)
    };
    if (!client) {
        print_sys_error("Failed to accept client");
        return false;
    }

    const struct timeval tv{ 3, 0 };
    if (::setsockopt(client, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == -1) {
        print_sys_error("Failed to set recv timeout");
        return false;
    }

    std::uint64_t arg_len{ 0 };
    auto ret = ::recv(client, &arg_len, sizeof(arg_len), 0);
    if (ret == -1) {
        print_sys_error("Failed to read from client");
        return false;
    }

    if (ret == 0) {
        print_info("Client closed the connection before sending a command");
        return false;
    }

    if (arg_len > arg_max) {
        print_info("Command too long");
        return false;
    }

    std::string command_buffer;
    command_buffer.reserve(arg_len);
    std::array<char, 4096> buffer{};
    while (true) {
        const auto readBytes = ::recv(client, buffer.data(), buffer.size(), 0);
        if (readBytes < 0) {
            print_sys_error("Failed to read from client");
            return false;
        }

        if (readBytes == 0) {
            // The peer closed the connection: nothing more will ever arrive,
            // so waiting for the announced length would spin forever.
            if (command_buffer.size() < arg_len) {
                print_info("Client closed the connection before sending the whole command");
                return false;
            }
            break;
        }

        command_buffer.append(buffer.data(), static_cast<std::size_t>(readBytes));
        if (command_buffer.size() >= arg_len) {
            break;
        }
    }

    if (command_buffer.empty()) {
        print_info("Empty command");
        return false;
    }

    // Guarantee a terminator so that the split below always finds one.
    if (command_buffer.back() != 0) {
        command_buffer.push_back(0);
    }

    // Split the buffer at every NUL byte into the individual arguments.
    std::vector<std::string> commands;
    std::size_t begin = 0;
    while (begin < command_buffer.size()) {
        const auto end = command_buffer.find('\0', begin);
        commands.emplace_back(command_buffer, begin, end - begin);
        begin = end + 1;
    }

    if (commands.empty()) {
        print_info("Command may be invalid");
        return false;
    }

    // `ret` keeps its ssize_t type on purpose: its object representation is
    // what gets sent to the client.
    ret = delegate_run(commands, conf);
    if (ret == -1) {
        print_sys_error("Failed to delegate command");
    }

    if (::send(client, &ret, sizeof(ret), 0) == -1) {
        print_sys_error("Failed to send result to client");
    }

    return true;
}
// Adds fd to the epoll instance epfd with the event mask and user data in ev.
// Returns false (after logging) when epoll_ctl() fails.
bool register_event(const file_descriptor_wrapper &epfd,
                    const file_descriptor_wrapper &fd,
                    epoll_event ev) noexcept
{
    if (::epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev) != -1) {
        return true;
    }

    print_sys_error("Failed to add event to epoll");
    return false;
}
}
// Entry point of the init process.
//
// Blocks signals, becomes a child subreaper, starts the command given on the
// command line and then serves an epoll loop over:
//   - a signalfd (signal forwarding and child reaping),
//   - the control socket (unless LINYAPS_INIT_SINGLE_MODE is "1"), through
//     which further commands can be started,
//   - a timer that is armed once the main child has exited and periodically
//     re-checks whether other processes are still around.
// The loop ends when shouldWait() reports that nothing is left to wait for.
//
// Returns the main child's exit status, or 128 + signal number when it was
// terminated by a signal; -1 on setup or event-loop failures.
int main(int argc, char **argv)
{
    sigConf conf;
    if (!conf.block_signals()) {
        return -1;
    }

    if (::prctl(PR_SET_CHILD_SUBREAPER, 1) == -1) {
        print_sys_error("Failed to set child subreaper");
        return -1;
    }

    auto args = parse_args(argc, argv);
    // parse_args() always appends the terminating nullptr, so a vector with a
    // single element means that no command was given.
    if (args.size() < 2) {
        print_info("No arguments provided");
        return -1;
    }

    auto *singleModeEnv = ::getenv("LINYAPS_INIT_SINGLE_MODE");
    const bool singleMode = singleModeEnv != nullptr && std::string_view{ singleModeEnv } == "1";

    const auto child = run(args, conf);
    if (child == -1) {
        print_info("Failed to run child process");
        return -1;
    }

    auto epfd = create_epoll();
    if (!epfd) {
        return -1;
    }

    auto sigfd = create_signalfd(conf.current_sigset());
    if (!sigfd) {
        return -1;
    }

    const struct epoll_event ev{ .events = EPOLLIN | EPOLLET, .data = { .fd = sigfd } };
    if (!register_event(epfd, sigfd, ev)) {
        return -1;
    }

    file_descriptor_wrapper unix_socket;
    if (!singleMode) {
        unix_socket = create_fs_uds();
        if (!unix_socket) {
            return -1;
        }

        // NOTE(review): the listening socket is edge-triggered while
        // handle_client() accepts one connection per wakeup; a second pending
        // connection is only picked up on the next edge - confirm acceptable.
        const struct epoll_event socketEv{ .events = EPOLLIN | EPOLLET,
                                           .data = { .fd = unix_socket } };
        if (!register_event(epfd, unix_socket, socketEv)) {
            return -1;
        }
    }

    file_descriptor_wrapper timerfd;
    bool done{ false };
    std::array<struct epoll_event, 10> events{};

    // handle_sigevent() stores the main child's pid here when it reaps it.
    // The pid starts at -1 so that unrelated signal events are not mistaken
    // for the main child's exit.
    WaitPidResult waitChild{ .pid = -1 };
    // Pid handed to handle_sigevent(); switched to -1 once the main child is
    // gone.
    pid_t forwardTarget{ child };
    int childExitCode = 0;

    while (true) {
        // Keep the number of ready events in its own variable: it is the loop
        // bound below and must not be overwritten by handler results.
        const int ready =
          ::epoll_wait(epfd, events.data(), static_cast<int>(events.size()), -1);
        if (ready == -1) {
            if (errno == EINTR) {
                continue;
            }

            print_sys_error("Failed to wait for events");
            return -1;
        }

        for (int i = 0; i < ready; ++i) {
            const auto &event = events.at(static_cast<std::size_t>(i));

            if (event.data.fd == sigfd) {
                if (!handle_sigevent(sigfd, forwardTarget, waitChild)) {
                    return -1;
                }

                // True only right after the main child has been reaped.
                if (waitChild.pid == child) {
                    if (WIFEXITED(waitChild.status)) {
                        childExitCode = WEXITSTATUS(waitChild.status);
                    } else if (WIFSIGNALED(waitChild.status)) {
                        childExitCode = 128 + WTERMSIG(waitChild.status);
                    }

                    // Record the exit exactly once.
                    waitChild.pid = -1;
                    forwardTarget = -1;

                    if (!shouldWait()) {
                        done = true;
                    }

                    timerfd = start_timer(epfd);
                    if (!timerfd) {
                        return -1;
                    }
                }

                continue;
            }

            if (timerfd && event.data.fd == timerfd) {
                const int waitState = handle_timerfdevent(timerfd);
                if (waitState == -1) {
                    return -1;
                }

                if (waitState == 0) {
                    done = true;
                }

                continue;
            }

            if (unix_socket && event.data.fd == unix_socket) {
                // A newly delegated command means there is work to wait for.
                if (handle_client(unix_socket, conf)) {
                    done = false;
                }
            }
        }

        if (done) {
            unix_socket.close();
            break;
        }
    }

    return childExitCode;
}