#include <string>
#include <signal.h>
#if defined(__linux__)
#include <unistd.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/syscall.h>
#include <sys/inotify.h>
#elif defined(_WIN32)
#include <windows.h>
#include <tlhelp32.h>
#endif
#include "xsched/utils/log.h"
#include "xsched/utils/common.h"
#include "xsched/utils/xassert.h"
#include "xsched/utils/waitpid.h"
#define SCAN_INTERVAL_US 1000
#ifndef SYS_pidfd_open
#define SYS_pidfd_open 434
#endif
using namespace xsched::utils;
/// Factory for the platform's PidWaiter implementation.
/// On Linux, a pidfd is opened on this process (GetProcessId()) as a probe:
/// if that succeeds, the epoll/pidfd based PidFdWaiter is returned, otherwise
/// the polling ScanPidWaiter. On Windows, WinPidWaiter is always returned.
/// @param callback handed to the chosen waiter; must not be null.
std::unique_ptr<PidWaiter> PidWaiter::Create(TerminateCallback callback)
{
    if (callback == nullptr) XERRO("callback is nullptr");
#if defined(__linux__)
    // Probe: any failure here is treated as "pidfd_open unsupported".
    int probe_fd = PidFdWaiter::OpenPidFd(GetProcessId(), 0);
    if (probe_fd != -1) {
        XASSERT(!close(probe_fd), "fail to close self pid fd");
        XINFO("pidfd_open is supported, using pidfd_wait method");
        return std::make_unique<PidFdWaiter>(callback);
    }
    XWARN("pidfd_open is not supported, using scan method, which may consume more CPU");
    return std::make_unique<ScanPidWaiter>(callback);
#elif defined(_WIN32)
    return std::make_unique<WinPidWaiter>(callback);
#endif
}
#if defined(__linux__)
/// Shuts down the worker thread and releases every descriptor via Stop().
PidFdWaiter::~PidFdWaiter()
{
    Stop();
}
/// Creates the wake-up eventfd and the epoll instance, registers the eventfd
/// under the terminate event tag, and launches WaitWorker() on a new thread.
void PidFdWaiter::Start()
{
    // Both descriptors are close-on-exec so they do not leak across exec().
    event_fd_ = eventfd(0, EFD_CLOEXEC);
    XASSERT(event_fd_ >= 0, "fail to create event fd");
    epoll_fd_ = epoll_create1(EPOLL_CLOEXEC);
    XASSERT(epoll_fd_ >= 0, "fail to create epoll fd");

    // A readable eventfd is how Stop() tells the worker to exit.
    struct epoll_event terminate_ev;
    terminate_ev.data.u64 = PackEventData(kEpollEventTerminate, 0);
    terminate_ev.events = EPOLLIN;
    XASSERT(!epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, event_fd_, &terminate_ev),
            "fail to add event fd to epoll");

    thread_ = std::make_unique<std::thread>(&PidFdWaiter::WaitWorker, this);
}
/// Wakes and joins the worker thread (if one is running), then closes the
/// eventfd, the epoll fd and every pidfd still being watched.
/// A second call finds nothing left to do.
void PidFdWaiter::Stop()
{
    if (thread_) {
        // Writing to the eventfd makes it readable; the worker treats that
        // as a terminate request and returns.
        XASSERT(!eventfd_write(event_fd_, 1), "fail to write event fd");
        thread_->join();
        thread_.reset();
    }

    if (event_fd_ >= 0) close(event_fd_);
    if (epoll_fd_ >= 0) close(epoll_fd_);
    for (auto& entry : pid_fds_) close(entry.second);

    event_fd_ = -1;
    epoll_fd_ = -1;
    pid_fds_.clear();
}
/// Starts watching @p pid for termination through a pidfd registered in epoll.
/// If pidfd_open fails with ESRCH (process already gone), callback_ is invoked
/// immediately on the calling thread. Adding a PID that is already being
/// watched is a no-op.
void PidFdWaiter::AddWait(PID pid)
{
    int pid_fd = OpenPidFd(pid, 0);
    if (pid_fd < 0) {
        if (errno == ESRCH) {
            XDEBG("process %d is already terminated", pid);
            callback_(pid);
            return;
        }
        XERRO("fail to open pid fd for pid %d", pid);
    }

    // Publish the fd in pid_fds_ before arming it in epoll, so the worker can
    // always find it when the event fires. If the PID is already watched, keep
    // the existing registration. Overwriting the map entry would leak the old
    // fd while leaving it armed (level-triggered) in epoll. Its event would then
    // fail the "pid fd not found" lookup in WaitWorker().
    mtx_.lock();
    const bool already_watched = pid_fds_.find(pid) != pid_fds_.end();
    if (!already_watched) pid_fds_[pid] = pid_fd;
    mtx_.unlock();
    if (already_watched) {
        XDEBG("process %d is already being waited", pid);
        close(pid_fd);
        return;
    }

    struct epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.u64 = PackEventData(kEpollEventPid, pid);
    XASSERT(!epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, pid_fd, &ev),
            "fail to add pid fd to epoll");
}
void PidFdWaiter::WaitWorker()
{
struct epoll_event ev;
while (true) {
if (epoll_wait(epoll_fd_, &ev, 1, -1) == -1) {
XASSERT(errno == EINTR, "fail during epoll wait");
continue;
}
if (GetEventType(ev.data.u64) == kEpollEventTerminate) {
eventfd_t v;
XASSERT(!eventfd_read(event_fd_, &v), "fail to read event fd");
return;
}
XASSERT(GetEventType(ev.data.u64) == kEpollEventPid,
"invalid event type: %d", GetEventType(ev.data.u64));
PID pid = GetEventPid(ev.data.u64);
mtx_.lock();
auto it = pid_fds_.find(pid);
XASSERT(it != pid_fds_.end(), "pid fd not found");
int pid_fd = it->second;
pid_fds_.erase(it);
mtx_.unlock();
XASSERT(!epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, pid_fd, nullptr),
"fail to remove pid fd from epoll");
XASSERT(!close(pid_fd), "fail to close pid fd");
callback_(pid);
}
}
/// Extracts the PID stored in the low 32 bits of an epoll user-data word.
PID PidFdWaiter::GetEventPid(uint64_t data)
{
    const uint64_t low_bits = data & 0xFFFFFFFFull;
    return static_cast<PID>(low_bits);
}
/// Extracts the event type stored in the high 32 bits of an epoll user-data word.
EpollEventType PidFdWaiter::GetEventType(uint64_t data)
{
    const uint64_t high_bits = data >> 32;
    return static_cast<EpollEventType>(high_bits);
}
/// Packs an event type (high 32 bits) and a PID (low 32 bits) into a single
/// epoll user-data word. This is the inverse of GetEventType()/GetEventPid().
uint64_t PidFdWaiter::PackEventData(EpollEventType type, PID pid)
{
    // Truncate the PID to 32 bits before widening. If PID is a signed type
    // (pid_t is), widening a negative value directly would sign-extend into
    // the high word and overwrite the type bits. GetEventPid() only ever reads
    // the low 32 bits, so this keeps the two functions symmetric.
    return (static_cast<uint64_t>(type) << 32) |
           static_cast<uint64_t>(static_cast<uint32_t>(pid));
}
/// Invokes the pidfd_open(2) system call through syscall(2). The raw syscall
/// is presumably used because older libc versions lack a wrapper.
/// @return the new pidfd, or -1 with errno set on failure.
int PidFdWaiter::OpenPidFd(PID pid, unsigned int flags)
{
    const long ret = syscall(SYS_pidfd_open, static_cast<pid_t>(pid), flags);
    return static_cast<int>(ret);
}
/// Stops and joins the polling thread via Stop().
ScanPidWaiter::~ScanPidWaiter()
{
    Stop();
}
/// Marks the waiter as running and launches the polling worker thread.
void ScanPidWaiter::Start()
{
    // Set before the thread exists so the worker's first check sees true.
    running_.store(true);
    thread_ = std::make_unique<std::thread>(&ScanPidWaiter::WaitWorker, this);
}
/// Clears the running flag and joins the worker. The worker notices the flag
/// on its next loop iteration. Calling this when never started is harmless.
void ScanPidWaiter::Stop()
{
    running_.store(false);
    if (thread_) {
        thread_->join();
    }
    thread_.reset();
}
/// Adds @p pid to the set polled by WaitWorker(); duplicates are ignored
/// because the set holds each PID once.
void ScanPidWaiter::AddWait(PID pid)
{
    std::unique_lock<std::mutex> guard(mtx_);
    pids_.emplace(pid);
}
/// Polling loop: every SCAN_INTERVAL_US microseconds, each watched PID is
/// probed with kill(pid, 0). PIDs reported as nonexistent (ESRCH) are passed
/// to callback_ and removed from the watch set. Runs until running_ is cleared.
void ScanPidWaiter::WaitWorker()
{
    while (running_.load()) {
        // Work on a snapshot so callbacks run without holding mtx_.
        std::unordered_set<PID> set;
        {
            std::lock_guard<std::mutex> lock(mtx_);
            set = pids_;
        }

        std::list<PID> terminated;
        for (auto pid : set) {
            // Signal 0 delivers nothing; it only performs the existence and
            // permission checks.
            if (kill(pid, 0) == 0) continue;
            // EPERM: the process exists but we may not signal it, so it is
            // still alive. Treating this as a failure would log a warning for
            // that PID on every scan.
            if (errno == EPERM) continue;
            if (errno != ESRCH) {
                XWARN("fail to test process %d, errno: %d", pid, errno);
                continue;
            }
            terminated.emplace_back(pid);
            callback_(pid);
        }

        {
            std::lock_guard<std::mutex> lock(mtx_);
            for (auto pid : terminated) pids_.erase(pid);
        }
        std::this_thread::sleep_for(std::chrono::microseconds(SCAN_INTERVAL_US));
    }
}
/// Stops the worker thread and closes the inotify descriptor via Stop().
INotifyPidWaiter::~INotifyPidWaiter()
{
    Stop();
}
/// Creates the inotify instance and launches WaitWorker() on a new thread.
void INotifyPidWaiter::Start()
{
    // IN_CLOEXEC keeps the descriptor from leaking into exec'd children. This
    // matches the EFD_CLOEXEC / EPOLL_CLOEXEC flags used by PidFdWaiter::Start().
    inotify_fd_ = inotify_init1(IN_CLOEXEC);
    XASSERT(inotify_fd_ >= 0, "fail to create inotify fd");
    thread_ = std::make_unique<std::thread>(&INotifyPidWaiter::WaitWorker, this);
}
void INotifyPidWaiter::Stop()
{
if (inotify_fd_ >= 0) close(inotify_fd_);
inotify_fd_ = -1;
if (thread_ != nullptr) thread_->join();
thread_ = nullptr;
}
/// Watches /proc/<pid> for IN_DELETE_SELF. If the directory no longer exists
/// (ENOENT), callback_ is invoked immediately on the calling thread.
void INotifyPidWaiter::AddWait(PID pid)
{
    std::string proc_path = "/proc/" + std::to_string(pid);
    // Hold mtx_ across inotify_add_watch and the map insert. The worker looks
    // up wds under the same lock, so it never sees an event for a wd that is
    // not yet in watch_pids_.
    std::unique_lock<std::mutex> lock(mtx_);
    int wd = inotify_add_watch(inotify_fd_, proc_path.c_str(), IN_DELETE_SELF);
    if (wd < 0) {
        int err = errno;
        // Release mtx_ before running the user callback. A callback that calls
        // AddWait() again would otherwise self-deadlock on the non-recursive
        // mutex. The worker thread also invokes callbacks without the lock.
        lock.unlock();
        if (err == ENOENT) {
            XDEBG("process %d is already terminated", pid);
            callback_(pid);
            return;
        }
        XERRO("fail to add watch for pid %d", pid);
        return;
    }
    watch_pids_[wd] = pid;
}
void INotifyPidWaiter::WaitWorker()
{
char buf[4096];
while (inotify_fd_ >= 0) {
ssize_t n = read(inotify_fd_, buf, sizeof(buf));
if (n < 0) {
if (errno == EBADF) return;
XWARN("read error during inotify wait");
continue;
}
ssize_t i = 0;
while (i < n) {
struct inotify_event *event = (struct inotify_event *) &buf[i];
if (event->mask & IN_DELETE_SELF) {
mtx_.lock();
auto it = watch_pids_.find(event->wd);
XASSERT(it != watch_pids_.end(), "watch fd not found");
PID pid = it->second;
watch_pids_.erase(it);
mtx_.unlock();
callback_(pid);
inotify_rm_watch(inotify_fd_, event->wd);
}
i += sizeof(struct inotify_event) + event->len;
}
}
}
#elif defined(_WIN32)
/// Stops and joins the polling thread via Stop().
WinPidWaiter::~WinPidWaiter()
{
    Stop();
}
/// Marks the waiter as running and launches the polling worker thread.
void WinPidWaiter::Start()
{
    // Set before the thread exists so the worker's first check sees true.
    running_.store(true);
    thread_ = std::make_unique<std::thread>(&WinPidWaiter::WaitWorker, this);
}
/// Clears the running flag, joins the worker if it is joinable, and drops
/// the thread object. Harmless when the waiter was never started.
void WinPidWaiter::Stop()
{
    running_.store(false);
    const bool can_join = thread_ != nullptr && thread_->joinable();
    if (can_join) thread_->join();
    thread_.reset();
}
/// Adds @p pid to the set polled by WaitWorker(); duplicates are ignored
/// because the set holds each PID once.
void WinPidWaiter::AddWait(PID pid)
{
    std::unique_lock<std::mutex> guard(mtx_);
    pids_.emplace(pid);
}
/// Polling loop: every SCAN_INTERVAL_US microseconds, each watched PID is
/// checked. PIDs whose process has terminated, or no longer exists, are passed
/// to callback_ and removed from the watch set. Runs until running_ is cleared.
void WinPidWaiter::WaitWorker()
{
    while (running_.load()) {
        // Work on a snapshot so callbacks run without holding mtx_.
        std::unordered_set<PID> current_pids;
        {
            std::lock_guard<std::mutex> lock(mtx_);
            current_pids = pids_;
        }
        std::list<PID> terminated;
        for (auto pid : current_pids) {
            // PROCESS_QUERY_LIMITED_INFORMATION is granted for more targets
            // (e.g. elevated or protected processes) than
            // PROCESS_QUERY_INFORMATION. SYNCHRONIZE allows waiting on the handle.
            HANDLE hProcess = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION | SYNCHRONIZE,
                                          FALSE, pid);
            if (hProcess == NULL) {
                DWORD err = GetLastError();
                if (err == ERROR_INVALID_PARAMETER) {
                    // No process with this id exists any more.
                    terminated.push_back(pid);
                    callback_(pid);
                } else {
                    XWARN("fail to open process " FMT_PID ", error: %lu", pid, err);
                }
                continue;
            }
            // A process handle becomes signaled when the process terminates.
            // Comparing GetExitCodeProcess() against STILL_ACTIVE would instead
            // treat a process that exited with code 259 as alive forever.
            DWORD wait_result = WaitForSingleObject(hProcess, 0);
            if (wait_result == WAIT_OBJECT_0) {
                terminated.push_back(pid);
                callback_(pid);
            } else if (wait_result == WAIT_FAILED) {
                XWARN("fail to wait for process " FMT_PID ", error: %lu", pid, GetLastError());
            }
            CloseHandle(hProcess);
        }
        if (!terminated.empty()) {
            std::lock_guard<std::mutex> lock(mtx_);
            for (auto pid : terminated) {
                pids_.erase(pid);
            }
        }
        std::this_thread::sleep_for(std::chrono::microseconds(SCAN_INTERVAL_US));
    }
}
#endif