#pragma once
#include "common/config.h"
#include "common/simple_log.h"
#include "benchmark/benchmark_runner.h"
#include "benchmark/kv_client_adapter.h"

#include <datasystem/kv_client.h>
#include <datasystem/utils/connection.h>
#include <datasystem/utils/service_discovery.h>

#include <sys/wait.h>
#include <unistd.h>

#include <cerrno>
#include <chrono>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <thread>
#include <vector>
/// Command codes carried in CmdMsg::cmd from the parent to a child process.
/// The numeric values are part of the pipe protocol and must not change.
enum ChildCmd : int32_t {
    CMD_EXIT = 0,      ///< Leave the command loop and terminate the child.
    CMD_RUN_SET = 1,   ///< Run the SET phase.
    CMD_RUN_GET = 2,   ///< Run the GET phase.
    CMD_RUN_DEL = 3,   ///< Run the DEL phase.
    CMD_RUN_MSET = 4,  ///< Run the batched MSET phase.
    CMD_RUN_MGET = 5   ///< Run the batched MGET phase.
};
/// Role a child process plays; selects how its client connects and the name
/// used for its log file.
enum ChildRole : int32_t {
    ROLE_SET = 0,
    ROLE_GET = 1,
    ROLE_DEL = 2
};
/// Fixed-size command record written by the parent to a child's pipe.
/// Kept as a plain aggregate so it can be brace-initialized and sent as raw bytes.
struct CmdMsg {
    int32_t cmd{0};         ///< A ChildCmd value.
    int32_t round{0};       ///< Benchmark round the command applies to.
    int32_t numThreads{0};  ///< Thread count override; values <= 0 mean "use the config value".
};
/// Fixed-size result record a child writes back after running one phase.
struct ResultMsg {
    int32_t successCount{0};  ///< Number of successful operations in the phase.
    int32_t reserved{0};      ///< Not written by the code in this file; presumably layout padding — confirm.
    double avgMs{0.0};        ///< Average latency.
    double p50Ms{0.0};        ///< 50th percentile latency.
    double p90Ms{0.0};        ///< 90th percentile latency.
    double p99Ms{0.0};        ///< 99th percentile latency.
    double maxMs{0.0};        ///< Maximum latency.
    double totalLatMs{0.0};   ///< Sum of all recorded latencies.
};
/// Handshake record a child sends once after attempting client initialization.
struct InitMsg {
    int32_t ok{0};         ///< 1 when initialization succeeded, 0 otherwise.
    char errorMsg[256]{};  ///< Failure description; zero-filled by default.
};
/// Writes exactly `len` bytes from `buf` to `fd`, looping over short writes.
///
/// A write() interrupted by a signal (EINTR) is retried rather than reported
/// as a failure, so a stray signal cannot desynchronize the pipe protocol.
///
/// @param fd   Descriptor to write to.
/// @param buf  Source buffer holding at least `len` bytes.
/// @param len  Number of bytes to send; 0 succeeds without calling write().
/// @return true once every byte has been written; false on any other write()
///         error or when write() reports zero progress.
inline bool WriteExact(int fd, const void *buf, size_t len) {
    const char *p = static_cast<const char *>(buf);
    size_t written = 0;
    while (written < len) {
        const ssize_t n = write(fd, p + written, len - written);
        if (n < 0) {
            if (errno == EINTR) continue;  // interrupted before any byte moved: retry
            return false;
        }
        if (n == 0) return false;  // no progress; bail out instead of spinning
        written += static_cast<size_t>(n);
    }
    return true;
}
/// Reads exactly `len` bytes from `fd` into `buf`, looping over short reads.
///
/// A read() interrupted by a signal (EINTR) is retried rather than reported
/// as a failure. End-of-file before `len` bytes arrive is a failure.
///
/// @param fd   Descriptor to read from.
/// @param buf  Destination buffer with room for at least `len` bytes.
/// @param len  Number of bytes to receive; 0 succeeds without calling read().
/// @return true once `len` bytes have been received; false on EOF or on any
///         other read() error.
inline bool ReadExact(int fd, void *buf, size_t len) {
    char *p = static_cast<char *>(buf);
    size_t got = 0;
    while (got < len) {
        const ssize_t n = read(fd, p + got, len - got);
        if (n < 0) {
            if (errno == EINTR) continue;  // interrupted before any byte moved: retry
            return false;
        }
        if (n == 0) return false;  // peer closed the pipe mid-message
        got += static_cast<size_t>(n);
    }
    return true;
}
/// Parent-side handle for one forked child: its pid plus the parent's ends of
/// the two pipes used to talk to it.
struct ChildProcess {
    pid_t pid{-1};              ///< Child pid; -1 when there is no live child.
    int toChildFd{-1};          ///< Write end of the command pipe; -1 when unset.
    int fromChildFd{-1};        ///< Read end of the result pipe; -1 when unset.
    ChildRole role = ROLE_SET;  ///< Role the child was spawned with.
    bool initOk{false};         ///< Set by WaitForInit when the child reported ok == 1.
};
/// Decides whether a child with the given role connects through service
/// discovery for this test mode (as opposed to a fixed worker address).
///
/// ROLE_DEL follows exactly the same rule as ROLE_SET. Any role value other
/// than SET/GET/DEL yields false.
inline bool RoleUsesServiceDiscovery(ChildRole role, TestMode testMode) {
    // The DEL child mirrors the SET child's connection policy.
    const ChildRole effectiveRole = (role == ROLE_DEL) ? ROLE_SET : role;

    if (effectiveRole == ROLE_SET) {
        const TestMode setModes[] = {
            TestMode::SET_LOCAL,
            TestMode::GET_LOCAL,
            TestMode::GET_REMOTE_CROSS,
            TestMode::MIXED_LOCAL_SET_GET,
            TestMode::MIXED_LOCAL_SET_CROSS_GET,
            TestMode::MSET_LOCAL,
            TestMode::MGET_LOCAL,
            TestMode::MGET_REMOTE_CROSS,
        };
        for (const TestMode candidate : setModes) {
            if (testMode == candidate) return true;
        }
        return false;
    }

    if (effectiveRole == ROLE_GET) {
        const TestMode getModes[] = {
            TestMode::GET_LOCAL,
            TestMode::GET_CROSS_NODE,
            TestMode::MIXED_LOCAL_SET_GET,
            TestMode::MIXED_REMOTE_SET_REMOTE_CROSS_GET,
            TestMode::MGET_LOCAL,
            TestMode::MGET_CROSS_NODE,
        };
        for (const TestMode candidate : getModes) {
            if (testMode == candidate) return true;
        }
        return false;
    }

    return false;
}
/// Returns true for GET_CROSS_NODE, GET_REMOTE_CROSS, and every mode accepted
/// by IsMixedMode() or IsMGetMode(). Judging by the name, callers use this to
/// decide whether a dedicated GET child is required.
inline bool NeedsSeparateGetChild(TestMode testMode) {
    if (testMode == TestMode::GET_CROSS_NODE) return true;
    if (testMode == TestMode::GET_REMOTE_CROSS) return true;
    if (IsMixedMode(testMode)) return true;
    return IsMGetMode(testMode);
}
/// Builds and initializes a KVClient configured for the given child role.
///
/// Connection goes through service discovery when
/// RoleUsesServiceDiscovery(role, cfg.testMode) says so; otherwise the client
/// is pointed at cfg.remoteWorker directly.
///
/// @param role  Role of the calling child process.
/// @param cfg   Benchmark configuration supplying timeouts and addresses.
/// @return The initialized client, or nullptr when service discovery or the
///         client itself fails to initialize (the failure is logged).
inline std::shared_ptr<datasystem::KVClient> CreateClientForRole(
    ChildRole role, const Config &cfg) {
    using namespace datasystem;

    ConnectOptions connectOpts;
    connectOpts.connectTimeoutMs = cfg.connectTimeoutMs;
    connectOpts.enableCrossNodeConnection = cfg.enableCrossNodeConnection;
    connectOpts.fastTransportMemSize = cfg.fastTransportMemSize;
    // DEL always uses a fixed request timeout instead of the configured one.
    if (role != ROLE_DEL) {
        connectOpts.requestTimeoutMs = cfg.requestTimeoutMs;
    } else {
        connectOpts.requestTimeoutMs = 5000;
    }

    if (!RoleUsesServiceDiscovery(role, cfg.testMode)) {
        // Direct connection to the configured remote worker.
        connectOpts.host = cfg.remoteWorker.host;
        connectOpts.port = cfg.remoteWorker.port;
    } else {
        ServiceDiscoveryOptions discoveryOpts;
        discoveryOpts.etcdAddress = cfg.etcdAddress;
        discoveryOpts.clusterName = cfg.clusterName;
        discoveryOpts.hostIdEnvName = cfg.hostIdEnvName;

        auto discovery = std::make_shared<ServiceDiscovery>(discoveryOpts);
        Status discoveryStatus = discovery->Init();
        if (!discoveryStatus.IsOk()) {
            SLOG_ERROR("Child ServiceDiscovery init failed: " << discoveryStatus.GetMsg());
            return nullptr;
        }
        connectOpts.serviceDiscovery = discovery;
    }

    auto kvClient = std::make_shared<KVClient>(connectOpts);
    Status clientStatus = kvClient->Init();
    if (!clientStatus.IsOk()) {
        SLOG_ERROR("Child KVClient init failed: " << clientStatus.GetMsg());
        return nullptr;
    }
    return kvClient;
}
/// Runs one benchmark phase across `numThreads` worker threads and merges the
/// per-thread results.
///
/// Each thread asks ThreadKeyRange() for its slice of the `keysPerRound` keys
/// and runs the phase routine matching `phase` on that slice. Threads whose
/// slice is empty, and unrecognized `phase` values, contribute an empty result.
///
/// @param adapter        Client adapter shared by all worker threads.
/// @param phase          Which phase routine to run.
/// @param round          Round number forwarded to the phase routine.
/// @param numThreads     Number of worker threads; values <= 0 run nothing and
///                       return an empty result.
/// @param keysPerRound   Total keys to split between the threads.
/// @param setApi         Forwarded to RunSetPhase.
/// @param data           Payload forwarded to the SET/MSET routines.
/// @param instanceId     Forwarded to every phase routine.
/// @param msetBatchSize  Batch size forwarded to RunMSetPhase.
/// @param mgetBatchSize  Batch size forwarded to RunMGetPhase.
/// @return Sum of the success counts and the concatenation of all per-thread
///         latency samples, in thread order.
inline PhaseResult RunPhaseMultiThread(
    KVClientAdapter *adapter, ChildCmd phase, int round,
    int numThreads, int keysPerRound, const std::string &setApi,
    const std::string &data, int instanceId,
    int msetBatchSize = 8, int mgetBatchSize = 8) {
    PhaseResult merged{};
    // A negative count would otherwise be converted to an enormous size_t when
    // sizing the vectors below and throw; treat it like "no work".
    if (numThreads <= 0) return merged;

    const size_t workerCount = static_cast<size_t>(numThreads);
    std::vector<PhaseResult> threadResults(workerCount);
    std::vector<std::thread> threads;
    threads.reserve(workerCount);

    for (int t = 0; t < numThreads; t++) {
        // Capturing by reference is safe: every thread is joined before this
        // function returns. Each thread writes only its own result slot.
        threads.emplace_back([&, t]() {
            auto range = ThreadKeyRange(keysPerRound, numThreads, t);
            int startKey = range.first;
            int numKeys = range.second;
            if (numKeys == 0) return;
            switch (phase) {
                case CMD_RUN_SET:
                    threadResults[t] = RunSetPhase(adapter, instanceId, round, startKey, numKeys, setApi, data);
                    break;
                case CMD_RUN_GET:
                    threadResults[t] = RunGetPhase(adapter, instanceId, round, startKey, numKeys);
                    break;
                case CMD_RUN_DEL:
                    threadResults[t] = RunDelPhase(adapter, instanceId, round, startKey, numKeys);
                    break;
                case CMD_RUN_MSET:
                    threadResults[t] = RunMSetPhase(adapter, instanceId, round, startKey, numKeys,
                                                    msetBatchSize, data);
                    break;
                case CMD_RUN_MGET:
                    threadResults[t] = RunMGetPhase(adapter, instanceId, round, startKey, numKeys,
                                                    mgetBatchSize);
                    break;
                default:
                    break;
            }
        });
    }
    for (auto &worker : threads) worker.join();

    for (auto &r : threadResults) {
        merged.successCount += r.successCount;
        merged.latenciesMs.insert(merged.latenciesMs.end(),
                                  r.latenciesMs.begin(), r.latenciesMs.end());
    }
    return merged;
}
/// Condenses a PhaseResult into the fixed-size ResultMsg sent over the pipe.
///
/// The success count is always copied. When there are no latency samples all
/// latency fields stay zero; otherwise they are filled from
/// ComputePercentiles() and totalLatMs is the sum of the samples.
inline ResultMsg PhaseResultToMsg(const PhaseResult &result) {
    ResultMsg out{};
    out.successCount = result.successCount;

    const auto &samples = result.latenciesMs;
    if (!samples.empty()) {
        auto stats = ComputePercentiles(samples);
        out.avgMs = stats.avg;
        out.p50Ms = stats.p50;
        out.p90Ms = stats.p90;
        out.p99Ms = stats.p99;
        out.maxMs = stats.max;

        double sum = 0;
        for (auto it = samples.begin(); it != samples.end(); ++it) {
            sum += *it;
        }
        out.totalLatMs = sum;
    }
    return out;
}
// Entry point executed inside a forked child process; never returns (ends in _exit).
//
// Sequence: ignore SIGINT/SIGPIPE, redirect std::cout/std::cerr to a per-role
// log file, create the KV client, report the outcome to the parent as an
// InitMsg on writeFd, then serve CmdMsg requests from readFd until CMD_EXIT
// arrives or the pipe fails. Each request is answered with one ResultMsg.
//
// readFd  - read end of the parent->child command pipe.
// writeFd - write end of the child->parent result pipe.
// cfg     - benchmark configuration (copied into the child by fork).
// role    - role of this child; selects the log name and client setup.
inline void ChildProcessMain(int readFd, int writeFd, const Config &cfg, ChildRole role) {
// SIGINT is ignored in the child; presumably the parent owns Ctrl-C handling — confirm.
signal(SIGINT, SIG_IGN);
// With SIGPIPE ignored, writing to a closed pipe fails in WriteExact instead of killing the child.
signal(SIGPIPE, SIG_IGN);
// Any role other than SET or GET is labelled "del".
const char *roleName = (role == ROLE_SET) ? "set" :
(role == ROLE_GET) ? "get" : "del";
std::string logPath = cfg.outputDir + "/child_" + std::string(roleName) + ".log";
// Static so the stream buffer installed into cout/cerr below stays alive for the rest of the process.
static std::ofstream childLog(logPath, std::ios::app);
if (childLog.is_open()) {
std::cout.rdbuf(childLog.rdbuf());
std::cerr.rdbuf(childLog.rdbuf());
}
SLOG_INFO("Child process started, role=" << roleName << ", pid=" << getpid());
// Set before the client is created. Judging by the names these turn off the
// client's parallel memory-copy paths — confirm against the datasystem client docs.
setenv("CLIENT_MEMORY_COPY_THREAD_NUM", "0", 1);
setenv("CLIENT_MEMORY_COPY_THREAD_NUM_PER_KEY", "0", 1);
setenv("CLIENT_MEMCOPY_PARALLEL_THRESHOLD", "2147483647", 1);
setenv("CLIENT_PARALLEL_THREAD_MIN_NUM", "0", 1);
auto client = CreateClientForRole(role, cfg);
InitMsg init{};
if (!client) {
// Report the failure to the parent (best effort) and terminate with status 1.
init.ok = 0;
snprintf(init.errorMsg, sizeof(init.errorMsg), "KVClient init failed for role %s", roleName);
WriteExact(writeFd, &init, sizeof(init));
_exit(1);
}
init.ok = 1;
if (!WriteExact(writeFd, &init, sizeof(init))) _exit(1);
SLOG_INFO("Child " << roleName << " KVClient initialized OK, waiting 3s for init to settle...");
// Note: the parent has already been told init succeeded before this pause starts.
std::this_thread::sleep_for(std::chrono::seconds(3));
datasystem::SetParam param;
param.writeMode = datasystem::WriteMode::NONE_L2_CACHE_EVICT;
// A non-positive ttlSeconds leaves the SetParam default TTL untouched.
if (cfg.ttlSeconds > 0) param.ttlSecond = cfg.ttlSeconds;
KVClientAdapter adapter(client, param);
// Only the first configured data size is used by the child.
// NOTE(review): assumes cfg.dataSizes is non-empty — confirm the caller validates this.
uint64_t dataSize = cfg.dataSizes[0];
std::string data(dataSize, 'A');
int keysPerRound = CalcKeysPerRound(cfg.workerMemoryMb, dataSize);
// Command loop: one CmdMsg in, one ResultMsg out.
while (true) {
CmdMsg cmd{};
// A failed read (e.g. the parent closed its end) ends the loop like CMD_EXIT.
if (!ReadExact(readFd, &cmd, sizeof(cmd))) break;
if (cmd.cmd == CMD_EXIT) break;
ChildCmd phase = static_cast<ChildCmd>(cmd.cmd);
PhaseResult result;
// MSET/MGET run on this thread over the whole key range and ignore cmd.numThreads.
if (phase == CMD_RUN_MSET) {
result = RunMSetPhase(&adapter, cfg.instanceId, cmd.round, 0,
keysPerRound, cfg.msetBatchSize, data);
} else if (phase == CMD_RUN_MGET) {
result = RunMGetPhase(&adapter, cfg.instanceId, cmd.round, 0,
keysPerRound, cfg.mgetBatchSize);
} else {
// A positive thread count in the command overrides the configured one.
int nThreads = cmd.numThreads > 0 ? cmd.numThreads : cfg.numThreads;
result = RunPhaseMultiThread(
&adapter, phase, cmd.round, nThreads, keysPerRound,
cfg.setApi, data, cfg.instanceId,
cfg.msetBatchSize, cfg.mgetBatchSize);
}
ResultMsg msg = PhaseResultToMsg(result);
if (!WriteExact(writeFd, &msg, sizeof(msg))) break;
SLOG_INFO("Child " << roleName << " round=" << cmd.round
<< " phase=" << cmd.cmd << " ok=" << result.successCount);
}
SLOG_INFO("Child " << roleName << " waiting 3s for in-flight operations to complete...");
std::this_thread::sleep_for(std::chrono::seconds(3));
SLOG_INFO("Child " << roleName << " exiting");
// _exit() terminates immediately without running destructors or flushing streams.
// NOTE(review): whether the last log lines reach the file depends on SLOG_* flushing — confirm.
_exit(0);
}
/// Forks a child process connected to the parent by two pipes.
///
/// The child runs ChildProcessMain() and never returns from this function.
/// The parent keeps the write end of the command pipe and the read end of the
/// result pipe.
///
/// @param cfg   Benchmark configuration passed to the child.
/// @param role  Role the child will run as.
/// @return A handle with pid > 0 and both descriptors set on success. On
///         pipe() or fork() failure the default handle (pid == -1) is
///         returned and no descriptors are left open.
inline ChildProcess SpawnChild(const Config &cfg, ChildRole role) {
    ChildProcess cp;
    int toChild[2] = {-1, -1};
    int fromChild[2] = {-1, -1};

    if (pipe(toChild) != 0) {
        const int err = errno;  // capture before logging can overwrite it
        SLOG_ERROR("pipe() failed: " << strerror(err));
        return cp;
    }
    if (pipe(fromChild) != 0) {
        const int err = errno;
        SLOG_ERROR("pipe() failed: " << strerror(err));
        // The first pipe was created successfully; release it so it does not leak.
        close(toChild[0]);
        close(toChild[1]);
        return cp;
    }

    const pid_t pid = fork();
    if (pid < 0) {
        const int err = errno;
        SLOG_ERROR("fork() failed: " << strerror(err));
        close(toChild[0]); close(toChild[1]);
        close(fromChild[0]); close(fromChild[1]);
        return cp;
    }
    if (pid == 0) {
        // Child: keep only the read end of the command pipe and the write end
        // of the result pipe.
        close(toChild[1]);
        close(fromChild[0]);
        ChildProcessMain(toChild[0], fromChild[1], cfg, role);
        _exit(1);  // not reached if ChildProcessMain exits as expected
    }

    // Parent: close the ends that belong to the child.
    close(toChild[0]);
    close(fromChild[1]);
    cp.pid = pid;
    cp.toChildFd = toChild[1];
    cp.fromChildFd = fromChild[0];
    cp.role = role;
    return cp;
}
/// Blocks until the child sends its InitMsg handshake and records the outcome.
///
/// @param cp  Child handle; cp.initOk is updated when a handshake is received
///            and left untouched when the read fails.
/// @return true only if a handshake arrived with ok == 1.
inline bool WaitForInit(ChildProcess &cp) {
    InitMsg handshake{};
    const bool received = ReadExact(cp.fromChildFd, &handshake, sizeof(handshake));
    if (!received) {
        SLOG_ERROR("Child (pid=" << cp.pid << ") init failed: pipe closed");
        return false;
    }

    if (handshake.ok == 1) {
        cp.initOk = true;
        return true;
    }

    cp.initOk = false;
    SLOG_ERROR("Child (pid=" << cp.pid << ") init failed: " << handshake.errorMsg);
    return false;
}
inline bool SendCommand(const ChildProcess &cp, ChildCmd cmd, int32_t round,
int32_t numThreads = 0) {
CmdMsg msg{cmd, round, numThreads};
return WriteExact(cp.toChildFd, &msg, sizeof(msg));
}
inline bool RecvResult(const ChildProcess &cp, ResultMsg &result) {
return ReadExact(cp.fromChildFd, &result, sizeof(result));
}
/// Asks a child to exit, waits up to about 5 seconds, then force-kills it.
///
/// Sends CMD_EXIT (best effort), closes both pipe ends, and polls waitpid()
/// 50 times at 100 ms intervals. If the child is still alive afterwards it is
/// sent SIGKILL and reaped. On return cp.pid and both descriptors are -1, so
/// later cleanup calls on the same handle cannot close or signal anything twice.
///
/// @param cp  Child handle; a handle with pid <= 0 is left untouched.
inline void ShutdownChild(ChildProcess &cp) {
    if (cp.pid <= 0) return;

    // Best effort: the child may already be gone, in which case the write fails.
    CmdMsg exitCmd{CMD_EXIT, 0};
    WriteExact(cp.toChildFd, &exitCmd, sizeof(exitCmd));

    // Reset the descriptors after closing so they are never closed a second
    // time (the numbers may have been reused for unrelated files by then).
    if (cp.toChildFd >= 0) {
        close(cp.toChildFd);
        cp.toChildFd = -1;
    }
    if (cp.fromChildFd >= 0) {
        close(cp.fromChildFd);
        cp.fromChildFd = -1;
    }

    int status = 0;
    for (int i = 0; i < 50; i++) {
        const pid_t ret = waitpid(cp.pid, &status, WNOHANG);
        if (ret == cp.pid) {
            cp.pid = -1;
            return;
        }
        if (ret < 0) {
            if (errno == EINTR) continue;
            // e.g. ECHILD: the pid is no longer our child, so it must not be
            // signalled — it could belong to an unrelated process by now.
            cp.pid = -1;
            return;
        }
        usleep(100000);
    }

    // Still running after the grace period: force termination and reap.
    kill(cp.pid, SIGKILL);
    while (waitpid(cp.pid, &status, 0) < 0 && errno == EINTR) {
    }
    cp.pid = -1;
}
/// Force-terminates every live child, reaps it, and releases its pipes.
///
/// All children are signalled first and reaped afterwards so they shut down
/// in parallel. On return every handle has pid == -1 and both descriptors set
/// to -1, which makes a repeated call a no-op instead of a double close.
///
/// @param children  Handles to clean up; entries with pid <= 0 only have
///                  their still-open descriptors closed.
inline void KillAllChildren(std::vector<ChildProcess> &children) {
    for (auto &cp : children) {
        if (cp.pid > 0) {
            kill(cp.pid, SIGKILL);
        }
    }
    for (auto &cp : children) {
        if (cp.pid > 0) {
            int status = 0;
            // Retry when a signal interrupts the blocking wait so the child
            // is not left behind as a zombie.
            while (waitpid(cp.pid, &status, 0) < 0 && errno == EINTR) {
            }
        }
        if (cp.toChildFd >= 0) {
            close(cp.toChildFd);
            cp.toChildFd = -1;
        }
        if (cp.fromChildFd >= 0) {
            close(cp.fromChildFd);
            cp.fromChildFd = -1;
        }
        cp.pid = -1;
    }
}