/*
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* ATF is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* atf_test.cpp
* ATF Server Functional Test (only tests whether ATF can correctly obtain the cluster state,
* and does not involve joint testing with the client)
*
* IDENTIFICATION
* ATF/atf_test.cpp
*
* -------------------------------------------------------------------------
*/
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <chrono>
#include <csignal>
#include <ctime>
#include <iostream>
#include <string>
#include <cstring>
#include <cerrno>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/time.h>
#include <sys/select.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
// ---- Test configuration (tuned for a server using non-blocking, edge-triggered I/O) ----
const std::string SERVER_IP = "172.19.0.131";   // replace with your server's IP address
constexpr int SERVER_PORT = 12345;                // port the server listens on
constexpr int CONNECT_TIMEOUT_SEC = 10;           // TCP connect timeout in seconds (matches the server's processing pace)
constexpr int RECV_TIMEOUT_SEC = 100;             // overall response timeout in seconds (generous, to cover the server's non-blocking send)
constexpr int SSL_HANDSHAKE_TIMEOUT_SEC = 10;     // TLS handshake timeout in seconds
constexpr int BUFFER_SIZE = 8192;                 // bytes requested per SSL_read (the server may send in several pieces)
constexpr int TCP_RECV_BUF_SIZE = 16384;          // SO_RCVBUF request in bytes, so the server's sends do not stall
// Helper: wait until a socket is readable (for_write == false) or writable
// (for_write == true). Used to pace non-blocking connect() and the
// SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE retry loops.
//
// poll() is used instead of select(): FD_SET() on a descriptor >= FD_SETSIZE
// (or on a negative one) is undefined behaviour, while poll() has no such limit.
// An interrupted wait (EINTR) is retried rather than reported as a failure; each
// retry restarts the full timeout, so signals can lengthen the total wait.
//
// @param fd          socket descriptor to wait on
// @param for_write   true to wait for writability, false for readability
// @param timeout_ms  maximum time to wait, in milliseconds
// @return true if the socket is ready. As with select(), a pending error or
//         hang-up also counts as "ready", so the caller's next I/O call (or
//         SO_ERROR check) observes it. Returns false on timeout, on an invalid
//         descriptor, or if poll() itself fails (the latter is logged).
bool wait_socket_ready(int fd, bool for_write, int timeout_ms) {
    if (fd < 0) {
        return false;  // poll() would silently ignore a negative fd and just time out
    }
    struct pollfd pfd;
    pfd.fd = fd;
    pfd.events = for_write ? POLLOUT : POLLIN;
    pfd.revents = 0;
    int ret;
    do {
        ret = poll(&pfd, 1, timeout_ms);
    } while (ret < 0 && errno == EINTR);
    if (ret < 0) {
        std::cerr << "[wait_socket] poll failed: " << strerror(errno) << std::endl;
        return false;
    }
    if (ret == 0) {
        return false;  // timeout: returned to the caller, not logged here
    }
    // POLLNVAL (fd not open) deliberately does not count as ready.
    return (pfd.revents & (pfd.events | POLLERR | POLLHUP)) != 0;
}
// Create the client-side SSL_CTX used for the test connection. Protocol versions
// are restricted to TLS 1.2 .. TLS 1.3 to match the server (older versions off).
//
// NOTE(review): no SSL_CTX_set_verify() call or CA store is configured here, so
// the server certificate is not verified. That is only acceptable for a test tool.
//
// @return a new SSL_CTX owned by the caller (release with SSL_CTX_free()), or
//         nullptr on failure (a message is printed to stderr).
SSL_CTX* init_ssl_context() {
    // ERR_reason_error_string() returns nullptr for an empty error queue or an
    // unregistered code; inserting a null char* into a stream is undefined
    // behaviour, so substitute a placeholder.
    auto ssl_reason = []() -> const char* {
        const char* reason = ERR_reason_error_string(ERR_get_error());
        return reason != nullptr ? reason : "unknown error";
    };

    SSL_library_init();
    OpenSSL_add_all_algorithms();
    SSL_load_error_strings();
    ERR_load_crypto_strings();

    const SSL_METHOD* method = TLS_client_method();
    if (method == nullptr) {
        std::cerr << "Failed to create TLS client method: " << ssl_reason() << std::endl;
        return nullptr;
    }
    SSL_CTX* ctx = SSL_CTX_new(method);
    if (ctx == nullptr) {
        std::cerr << "Failed to create SSL context: " << ssl_reason() << std::endl;
        return nullptr;
    }
    // Match the server's TLS versions; fail loudly rather than silently running
    // with an unrestricted context if the bounds cannot be applied.
    if (SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION) != 1 ||
        SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) != 1) {
        std::cerr << "Failed to restrict TLS protocol versions: " << ssl_reason() << std::endl;
        SSL_CTX_free(ctx);
        return nullptr;
    }
    SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_CLIENT);
    return ctx;
}
// Switch a socket to non-blocking mode (adds O_NONBLOCK, keeps every other file
// status flag). On fcntl() failure a message is printed to stderr and false is
// returned; the descriptor is left untouched.
bool set_nonblocking(int fd) {
    const int current_flags = fcntl(fd, F_GETFL, 0);
    if (current_flags == -1) {
        std::cerr << "Failed to get socket flags: " << strerror(errno) << std::endl;
        return false;
    }
    const bool switched = fcntl(fd, F_SETFL, current_flags | O_NONBLOCK) != -1;
    if (!switched) {
        std::cerr << "Failed to set socket non-blocking: " << strerror(errno) << std::endl;
    }
    return switched;
}
// Configure a TCP socket before connecting:
//   - SO_RCVBUF = TCP_RECV_BUF_SIZE (best effort: failure only logs a warning).
//     A larger receive buffer keeps the server's SSL_write from hitting WANT_WRITE.
//   - SO_RCVTIMEO / SO_SNDTIMEO = timeout_sec seconds.
//   - TCP keepalive: first probe after 5 s idle, then every 2 s, and the
//     connection is dropped after 3 unanswered probes.
// NOTE(review): in this file the socket is already non-blocking when this is
// called, so the SO_RCVTIMEO/SO_SNDTIMEO values do not bound any call; the
// explicit waits elsewhere enforce the timeouts.
//
// @return false (after logging) as soon as a required option cannot be set,
//         true otherwise.
bool set_socket_options(int fd, int timeout_sec) {
    const int rcvbuf_bytes = TCP_RECV_BUF_SIZE;
    if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &rcvbuf_bytes, sizeof(rcvbuf_bytes)) < 0) {
        std::cerr << "Warning: Failed to set SO_RCVBUF: " << strerror(errno) << std::endl;
    }

    struct timeval io_timeout;
    io_timeout.tv_sec = timeout_sec;
    io_timeout.tv_usec = 0;
    const int keepalive_on = 1;
    const int idle_before_probe_sec = 5;
    const int probe_interval_sec = 2;
    const int max_failed_probes = 3;

    // Options that must all succeed, applied in this order; `label` completes
    // the "Failed to set <label>: ..." diagnostic.
    struct RequiredOption {
        int level;
        int name;
        const void* value;
        socklen_t size;
        const char* label;
    };
    const RequiredOption required[] = {
        {SOL_SOCKET, SO_RCVTIMEO, &io_timeout, sizeof(io_timeout), "recv timeout"},
        {SOL_SOCKET, SO_SNDTIMEO, &io_timeout, sizeof(io_timeout), "send timeout"},
        {SOL_SOCKET, SO_KEEPALIVE, &keepalive_on, sizeof(keepalive_on), "SO_KEEPALIVE"},
        {IPPROTO_TCP, TCP_KEEPIDLE, &idle_before_probe_sec, sizeof(idle_before_probe_sec), "TCP_KEEPIDLE"},
        {IPPROTO_TCP, TCP_KEEPINTVL, &probe_interval_sec, sizeof(probe_interval_sec), "TCP_KEEPINTVL"},
        {IPPROTO_TCP, TCP_KEEPCNT, &max_failed_probes, sizeof(max_failed_probes), "TCP_KEEPCNT"},
    };
    for (const RequiredOption& opt : required) {
        if (setsockopt(fd, opt.level, opt.name, opt.value, opt.size) < 0) {
            const int saved_errno = errno;
            std::cerr << "Failed to set " << opt.label << ": " << strerror(saved_errno) << std::endl;
            return false;
        }
    }
    return true;
}
// Open a TCP connection to server_ip:port and perform the TLS client handshake
// on it, using non-blocking I/O with explicit time limits:
//   - connect() must complete within CONNECT_TIMEOUT_SEC seconds;
//   - the handshake must complete within SSL_HANDSHAKE_TIMEOUT_SEC seconds.
//
// @param ctx         client SSL_CTX (not owned; must outlive the returned SSL)
// @param server_ip   IPv4 address in dotted-decimal form
// @param port        TCP port in host byte order
// @param process_id  tag prefixed to every log line
// @return a handshaken SSL* whose socket is left non-blocking, or nullptr on
//         failure (everything allocated here is released before returning).
//         The caller owns the SSL object AND its descriptor (SSL_get_fd()):
//         SSL_free() does not close a descriptor attached with SSL_set_fd().
SSL* create_ssl_connection(SSL_CTX* ctx, const std::string& server_ip, int port, int process_id) {
    // Text for the oldest queued OpenSSL error. ERR_reason_error_string()
    // returns nullptr when the queue is empty (typical for SSL_ERROR_SYSCALL) or
    // the code is unregistered; streaming a null char* is undefined behaviour.
    auto ssl_reason = []() -> const char* {
        const char* reason = ERR_reason_error_string(ERR_get_error());
        return reason != nullptr ? reason : "unknown error";
    };

    // 1. Create the socket.
    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd < 0) {
        std::cerr << "Process " << process_id << " Failed to create socket: " << strerror(errno) << std::endl;
        return nullptr;
    }
    // 2. Non-blocking mode, buffer size, keepalive and timeouts. This must happen
    //    before connect(): SO_RCVBUF only affects the TCP window when set first.
    if (!set_nonblocking(sockfd) || !set_socket_options(sockfd, CONNECT_TIMEOUT_SEC)) {
        close(sockfd);
        return nullptr;
    }
    // 3. Non-blocking connect: EINPROGRESS is the normal result. Completion is
    //    signalled by writability, and the outcome is read from SO_ERROR.
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    if (inet_pton(AF_INET, server_ip.c_str(), &server_addr.sin_addr) <= 0) {
        std::cerr << "Process " << process_id << " Invalid server IP: " << server_ip << std::endl;
        close(sockfd);
        return nullptr;
    }
    int ret = connect(sockfd, reinterpret_cast<struct sockaddr*>(&server_addr), sizeof(server_addr));
    if (ret < 0 && errno != EINPROGRESS) {
        std::cerr << "Process " << process_id << " Failed to connect: " << strerror(errno) << std::endl;
        close(sockfd);
        return nullptr;
    }
    if (!wait_socket_ready(sockfd, true, CONNECT_TIMEOUT_SEC * 1000)) {
        std::cerr << "Process " << process_id << " Connect timeout (" << CONNECT_TIMEOUT_SEC << "s)" << std::endl;
        close(sockfd);
        return nullptr;
    }
    int err = 0;
    socklen_t err_len = sizeof(err);
    if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &err, &err_len) < 0) {
        err = errno;  // report why getsockopt failed, not strerror(0) ("Success")
    }
    if (err != 0) {
        std::cerr << "Process " << process_id << " Connect failed: " << strerror(err) << std::endl;
        close(sockfd);
        return nullptr;
    }
    // 4. Non-blocking TLS handshake, retried on WANT_READ/WANT_WRITE until done
    //    or until SSL_HANDSHAKE_TIMEOUT_SEC elapses.
    SSL* ssl = SSL_new(ctx);
    if (ssl == nullptr) {
        std::cerr << "Process " << process_id << " Failed to create SSL object: " << ssl_reason() << std::endl;
        close(sockfd);
        return nullptr;
    }
    if (SSL_set_fd(ssl, sockfd) != 1) {
        std::cerr << "Process " << process_id << " Failed to attach socket to SSL: " << ssl_reason() << std::endl;
        SSL_free(ssl);
        close(sockfd);
        return nullptr;
    }
    SSL_set_connect_state(ssl);
    const time_t handshake_start = time(nullptr);
    while (true) {
        if (time(nullptr) - handshake_start > SSL_HANDSHAKE_TIMEOUT_SEC) {
            std::cerr << "Process " << process_id << " SSL handshake timeout (" << SSL_HANDSHAKE_TIMEOUT_SEC << "s)" << std::endl;
            SSL_free(ssl);
            close(sockfd);
            return nullptr;
        }
        const int handshake_ret = SSL_connect(ssl);
        if (handshake_ret == 1) {
            break;  // handshake complete
        }
        const int ssl_err = SSL_get_error(ssl, handshake_ret);
        if (ssl_err == SSL_ERROR_WANT_READ) {
            wait_socket_ready(sockfd, false, 100);  // wait up to 100 ms, then retry
        } else if (ssl_err == SSL_ERROR_WANT_WRITE) {
            wait_socket_ready(sockfd, true, 100);
        } else {
            std::cerr << "Process " << process_id << " SSL handshake failed: " << ssl_reason()
                      << " (SSL error " << ssl_err << ")" << std::endl;
            SSL_free(ssl);
            close(sockfd);
            return nullptr;
        }
    }
    std::cout << "Process " << process_id << " SSL handshake success (cipher: " << SSL_get_cipher(ssl) << ")" << std::endl;
    return ssl;
}
// 构造请求(严格匹配服务端的JSON格式)
std::string build_request(const std::string& role) {
json req = {
{"command", "query"},
{"data", {{"role", role}}}
};
return req.dump();
}
// Write the whole request over a non-blocking SSL connection.
//
// After SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE, SSL_write() is retried with
// the same pointer and length, as OpenSSL requires; `total_sent` only advances
// on success. The loop gives up after `timeout_sec` seconds, so a stalled peer
// cannot hang the client forever. (recv_response() has a similar overall limit.)
//
// @param ssl          connected SSL object whose socket is non-blocking
// @param request      bytes to send verbatim (no terminator is appended)
// @param process_id   tag prefixed to every log line
// @param timeout_sec  overall time limit for the send, in seconds
// @return true once every byte has been accepted by SSL_write(); false on error,
//         connection close, or timeout (a message is printed to stderr).
bool send_request(SSL* ssl, const std::string& request, int process_id, int timeout_sec = CONNECT_TIMEOUT_SEC) {
    // ERR_reason_error_string() may return nullptr (empty queue, e.g. on
    // SSL_ERROR_SYSCALL); streaming a null char* is undefined behaviour.
    auto ssl_reason = []() -> const char* {
        const char* reason = ERR_reason_error_string(ERR_get_error());
        return reason != nullptr ? reason : "unknown error";
    };

    const char* const buf = request.data();
    const size_t total_to_send = request.size();
    size_t total_sent = 0;
    const time_t start = time(nullptr);
    while (total_sent < total_to_send) {
        if (time(nullptr) - start > timeout_sec) {
            std::cerr << "Process " << process_id << " Send timeout (" << timeout_sec << "s)" << std::endl;
            return false;
        }
        // SSL_write() takes an int length; requests built by this tool are tiny.
        const int bytes_sent = SSL_write(ssl, buf + total_sent, static_cast<int>(total_to_send - total_sent));
        if (bytes_sent > 0) {
            total_sent += static_cast<size_t>(bytes_sent);
            std::cout << "Process " << process_id << " Sent " << bytes_sent << " bytes (total: " << total_sent << "/" << total_to_send << ")" << std::endl;
        } else if (bytes_sent == 0) {
            std::cerr << "Process " << process_id << " SSL write closed" << std::endl;
            return false;
        } else {
            const int ssl_err = SSL_get_error(ssl, bytes_sent);
            if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
                // Wait up to 100 ms for the direction OpenSSL needs, then retry.
                wait_socket_ready(SSL_get_fd(ssl), ssl_err == SSL_ERROR_WANT_WRITE, 100);
                continue;
            }
            std::cerr << "Process " << process_id << " Failed to send request: " << ssl_reason() << std::endl;
            return false;
        }
    }
    std::cout << "Process " << process_id << " Sent request: " << request << std::endl;
    return true;
}
// Receive one response terminated by "\r\n" (the server's end-of-message marker,
// per the original note) from a non-blocking SSL connection.
//
// Reads chunks of up to BUFFER_SIZE bytes until one of these happens:
//   - the terminator appears;
//   - the peer closes the connection;
//   - a non-retryable SSL error occurs;
//   - RECV_TIMEOUT_SEC seconds elapse.
// The timeout is measured with a monotonic clock, so wall-clock adjustments
// cannot shorten or extend the wait.
//
// @return the text before the first "\r\n" (anything after it is dropped);
//         otherwise whatever arrived before giving up, possibly empty.
std::string recv_response(SSL* ssl, int process_id) {
    // ERR_reason_error_string() may return nullptr (empty queue, e.g. on
    // SSL_ERROR_SYSCALL); streaming a null char* is undefined behaviour.
    auto ssl_reason = []() -> const char* {
        const char* reason = ERR_reason_error_string(ERR_get_error());
        return reason != nullptr ? reason : "unknown error";
    };

    char buffer[BUFFER_SIZE];  // no zeroing needed: append() uses the read length
    std::string response;
    const auto start_time = std::chrono::steady_clock::now();
    const int fd = SSL_get_fd(ssl);
    std::cout << "Process " << process_id << " Start receiving response (timeout: " << RECV_TIMEOUT_SEC << "s)..." << std::endl;
    while (true) {
        const long elapsed_ms = static_cast<long>(
            std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start_time).count());
        if (elapsed_ms > RECV_TIMEOUT_SEC * 1000L) {
            std::cerr << "Process " << process_id << " Recv timeout (elapsed: " << elapsed_ms << "ms)" << std::endl;
            break;
        }
        const int bytes_read = SSL_read(ssl, buffer, static_cast<int>(sizeof(buffer)));
        if (bytes_read > 0) {
            // Only the new bytes, plus one preceding byte in case "\r\n" straddles
            // two reads, need to be searched for the terminator.
            const size_t search_from = response.empty() ? 0 : response.size() - 1;
            response.append(buffer, static_cast<size_t>(bytes_read));
            std::cout << "Process " << process_id << " Read " << bytes_read << " bytes (total: " << response.size() << ")" << std::endl;
            const size_t crlf_pos = response.find("\r\n", search_from);
            if (crlf_pos != std::string::npos) {
                response.resize(crlf_pos);  // strip the terminator and anything after it
                std::cout << "Process " << process_id << " Full response received: " << response << std::endl;
                return response;
            }
        } else if (bytes_read == 0) {
            std::cerr << "Process " << process_id << " Server closed connection" << std::endl;
            break;
        } else {
            const int ssl_err = SSL_get_error(ssl, bytes_read);
            if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
                // Not an error on a non-blocking socket: wait up to 200 ms and retry.
                wait_socket_ready(fd, ssl_err == SSL_ERROR_WANT_WRITE, 200);
                continue;
            }
            std::cerr << "Process " << process_id << " SSL read failed: " << ssl_reason() << std::endl;
            break;
        }
    }
    // Return whatever arrived, to help debugging.
    if (!response.empty()) {
        std::cout << "Process " << process_id << " Partial response: " << response << std::endl;
    }
    return response;
}
// 解析响应(匹配服务端的响应JSON格式)
void parse_response(const std::string& response, int process_id) {
if (response.empty()) {
std::cerr << "Process " << process_id << " Empty response" << std::endl;
return;
}
try {
json resp = json::parse(response);
std::string type = resp["type"];
if (type == "ERROR FROM ATF") {
std::cerr << "Process " << process_id << " Server error: " << resp.dump() << std::endl;
} else if (type == "QUERY FROM ATF") {
std::string ip = resp["data"]["ip"];
std::string role = resp["data"]["role"];
std::cout << "Process " << process_id << " Query success: ip=" << ip << ", role=" << role << std::endl;
} else {
std::cerr << "Process " << process_id << " Unknown response type: " << type << std::endl;
}
} catch (const std::exception& e) {
std::cerr << "Process " << process_id << " Parse failed: " << e.what() << " (raw: " << response << ")" << std::endl;
}
}
// Close a TLS connection, then release the SSL object and its socket.
//
// The order matters:
//   1. SSL_shutdown() runs first, so the close_notify alert is written while the
//      socket can still send. The previous version called shutdown(SHUT_WR)
//      first, so that write failed with EPIPE and raised SIGPIPE, which
//      terminates the process by default.
//   2. If the peer's close_notify has not arrived yet (SSL_shutdown() == 0),
//      leftover records are drained through SSL_read() for at most ~1 s. Reading
//      through OpenSSL keeps the TLS stream consistent; a raw read() on the fd
//      would steal records from it.
//   3. SSL_free(), then close() the descriptor, because SSL_free() does not
//      close a descriptor attached with SSL_set_fd().
// If the server has already closed or reset the connection, writing
// close_notify can still raise SIGPIPE unless the process ignores it.
//
// @param ssl  connection to close; nullptr is ignored.
void safe_ssl_shutdown(SSL* ssl) {
    if (ssl == nullptr) {
        return;
    }
    const int sockfd = SSL_get_fd(ssl);
    if (SSL_shutdown(ssl) == 0) {
        char scratch[1024];
        const int kDrainAttempts = 10;  // each WANT_* wait is <= 100 ms
        for (int attempt = 0; attempt < kDrainAttempts; ++attempt) {
            const int n = SSL_read(ssl, scratch, static_cast<int>(sizeof(scratch)));
            if (n > 0) {
                continue;  // discard leftover application data
            }
            const int ssl_err = SSL_get_error(ssl, n);
            if (ssl_err == SSL_ERROR_WANT_READ) {
                wait_socket_ready(sockfd, false, 100);
            } else if (ssl_err == SSL_ERROR_WANT_WRITE) {
                wait_socket_ready(sockfd, true, 100);
            } else {
                break;  // SSL_ERROR_ZERO_RETURN (peer's close_notify) or a hard error
            }
        }
    }
    // Errors raised during this best-effort shutdown are of no use to the caller.
    ERR_clear_error();
    SSL_free(ssl);
    if (sockfd >= 0) {
        close(sockfd);
    }
}
// 单进程测试任务
void test_task(int process_id, const std::string& role) {
SSL_CTX* ctx = init_ssl_context();
if (ctx == nullptr) {
return;
}
SSL* ssl = create_ssl_connection(ctx, SERVER_IP, SERVER_PORT, process_id);
if (ssl == nullptr) {
SSL_CTX_free(ctx);
return;
}
std::string request = build_request(role);
if (!send_request(ssl, request, process_id)) {
safe_ssl_shutdown(ssl);
SSL_CTX_free(ctx);
return;
}
std::string response = recv_response(ssl, process_id);
parse_response(response, process_id);
safe_ssl_shutdown(ssl);
SSL_CTX_free(ctx);
// 休眠3秒,确保服务端完成响应发送(适配服务端的超时清理)
std::cout << "Process " << process_id << " Sleep 3s to wait server..." << std::endl;
sleep(3);
}
// Entry point. Usage: atf_test [primary|standby]   (default role: primary)
// Only the first argument is examined; further arguments are ignored. An
// invalid role exits with status 1. Otherwise one test task runs and the
// program exits with 0.
int main(int argc, char* argv[]) {
    // Ignore SIGPIPE. Writing to a connection the server has already closed or
    // reset (SSL_write, or the close_notify sent by SSL_shutdown) would
    // otherwise kill the process. With SIGPIPE ignored, the write fails with
    // EPIPE and is reported through the normal OpenSSL error path.
    std::signal(SIGPIPE, SIG_IGN);

    std::string role = "primary";
    if (argc == 2) {
        role = argv[1];
        if (role != "primary" && role != "standby") {
            std::cerr << "Invalid role! Use: primary/standby" << std::endl;
            return 1;
        }
    }
    test_task(1, role);
    std::cout << "Test completed" << std::endl;
    return 0;
}