/*
 * Copyright (c) 2025 Huawei Technologies Co.,Ltd.
 *
 * ATF is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
 *
 * atf_test.cpp
 *    ATF Server Functional Test (only tests whether ATF can correctly obtain the cluster state,
 *                                and does not involve joint testing with the client)
 *
 * IDENTIFICATION
 *    ATF/atf_test.cpp
 *
 * -------------------------------------------------------------------------
 */
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <cerrno>
#include <chrono>
#include <climits>
#include <csignal>
#include <cstring>
#include <ctime>
#include <iostream>
#include <string>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

#include <openssl/ssl.h>
#include <openssl/err.h>
#include <nlohmann/json.hpp>

// Short alias for the nlohmann JSON type used to build requests and parse responses.
using json = nlohmann::json;

// Test-client configuration (intended for a server that uses non-blocking, edge-triggered I/O).
const std::string SERVER_IP{"172.19.0.131"};  // ATF server address; replace with your server's IP
constexpr int SERVER_PORT = 12345;             // port the ATF server listens on
constexpr int CONNECT_TIMEOUT_SEC = 10;        // TCP connect timeout, in seconds
constexpr int RECV_TIMEOUT_SEC = 100;          // overall budget for receiving one response, in seconds
constexpr int SSL_HANDSHAKE_TIMEOUT_SEC = 10;  // TLS handshake timeout, in seconds
constexpr int BUFFER_SIZE = 8192;              // scratch buffer size for each SSL_read, in bytes
constexpr int TCP_RECV_BUF_SIZE = 16384;       // requested SO_RCVBUF size, in bytes (keeps the server's sends from stalling)

// Waits until `fd` is readable (for_write == false) or writable (for_write == true).
//
// Used to pace the non-blocking connect() and the SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE
// retry loops. poll() is used instead of select(): select() is undefined for descriptors
// outside [0, FD_SETSIZE). A signal interruption (EINTR) is retried with the time that is
// left instead of being reported as a failure.
//
// Error / hang-up conditions (POLLERR, POLLHUP) count as "ready" so that the caller's next
// I/O call (or SO_ERROR check) surfaces the real error, mirroring how select() reports such
// descriptors as readable/writable.
//
// @param fd          descriptor to wait on
// @param for_write   true to wait for writability, false for readability
// @param timeout_ms  maximum wait in milliseconds (negative values are treated as 0)
// @return true if the descriptor became ready; false on timeout (not logged) or on an
//         invalid descriptor / poll() failure (logged to stderr).
bool wait_socket_ready(int fd, bool for_write, int timeout_ms) {
    if (fd < 0) {
        std::cerr << "[wait_socket] invalid fd: " << fd << std::endl;
        return false;
    }

    const short wanted = static_cast<short>(for_write ? POLLOUT : POLLIN);
    const auto deadline = std::chrono::steady_clock::now() +
                          std::chrono::milliseconds(timeout_ms > 0 ? timeout_ms : 0);

    while (true) {
        const auto remaining = std::chrono::duration_cast<std::chrono::milliseconds>(
                                   deadline - std::chrono::steady_clock::now())
                                   .count();
        const int wait_ms = remaining > 0 ? static_cast<int>(remaining) : 0;

        struct pollfd pfd;
        pfd.fd = fd;
        pfd.events = wanted;
        pfd.revents = 0;

        const int ret = poll(&pfd, 1, wait_ms);
        if (ret < 0) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal: retry with the remaining budget
            }
            std::cerr << "[wait_socket] poll failed: " << strerror(errno) << std::endl;
            return false;
        }
        if (ret == 0) {
            return false;  // timeout: report "not ready" without logging an error
        }
        if ((pfd.revents & POLLNVAL) != 0) {
            std::cerr << "[wait_socket] invalid fd: " << fd << std::endl;
            return false;
        }
        return (pfd.revents & (wanted | POLLERR | POLLHUP)) != 0;
    }
}

// Creates the client-side TLS context: TLS 1.2 through TLS 1.3 only, with client-side
// session caching enabled.
//
// @return a new SSL_CTX owned by the caller (release it with SSL_CTX_free), or nullptr on
//         failure (the reason is logged to stderr).
//
// NOTE(review): no SSL_CTX_set_verify() / CA loading happens here, so the server certificate
// is not authenticated. That may be intentional for a functional test against a self-signed
// server. Confirm before reusing this code outside tests.
SSL_CTX* init_ssl_context() {
    // OpenSSL >= 1.1.0 (already required by TLS_client_method) initialises itself; this call
    // makes sure human-readable error strings are available for the diagnostics below.
    OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, nullptr);

    // ERR_reason_error_string() returns nullptr for an empty queue or an unknown code, and
    // streaming a null char* into std::cerr is undefined behaviour, so always fall back.
    auto last_ssl_error = []() -> std::string {
        const char* reason = ERR_reason_error_string(ERR_get_error());
        return reason != nullptr ? reason : "unknown OpenSSL error";
    };

    const SSL_METHOD* method = TLS_client_method();
    if (method == nullptr) {
        std::cerr << "Failed to create TLS client method: " << last_ssl_error() << std::endl;
        return nullptr;
    }

    SSL_CTX* ctx = SSL_CTX_new(method);
    if (ctx == nullptr) {
        std::cerr << "Failed to create SSL context: " << last_ssl_error() << std::endl;
        return nullptr;
    }

    // Match the server's protocol range (older versions disabled). These calls can fail, and
    // continuing silently would negotiate a wider range than intended.
    if (SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION) != 1 ||
        SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) != 1) {
        std::cerr << "Failed to restrict TLS protocol versions: " << last_ssl_error() << std::endl;
        SSL_CTX_free(ctx);
        return nullptr;
    }
    SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_CLIENT);

    return ctx;
}

// Switches `fd` into non-blocking mode by adding O_NONBLOCK to its existing status flags.
//
// @param fd  an open file descriptor
// @return true on success; false if either fcntl() call fails (the cause is logged to stderr).
bool set_nonblocking(int fd) {
    const int current = fcntl(fd, F_GETFL, 0);
    if (current == -1) {
        std::cerr << "Failed to get socket flags: " << strerror(errno) << std::endl;
        return false;
    }

    const bool updated = fcntl(fd, F_SETFL, current | O_NONBLOCK) != -1;
    if (!updated) {
        std::cerr << "Failed to set socket non-blocking: " << strerror(errno) << std::endl;
    }
    return updated;
}

// Applies the client's socket options to `fd`, in this order:
//   1. SO_RCVBUF = TCP_RECV_BUF_SIZE (best effort: a failure is only a warning);
//   2. SO_RCVTIMEO and SO_SNDTIMEO = timeout_sec seconds;
//   3. TCP keepalive: first probe after 5 s idle, then every 2 s, and the connection is
//      dropped after 3 unanswered probes.
// Stops at the first mandatory option that fails.
//
// NOTE(review): create_ssl_connection() puts the socket into O_NONBLOCK mode before calling
// this, so the SO_RCVTIMEO / SO_SNDTIMEO values presumably never take effect there (its
// explicit wait loops do the timing). Confirm before relying on them.
//
// @param fd           socket descriptor
// @param timeout_sec  send/receive timeout, in seconds
// @return false (with a message on stderr) if any option other than SO_RCVBUF fails.
bool set_socket_options(int fd, int timeout_sec) {
    // Best effort: a larger receive buffer keeps the server's SSL_write from stalling.
    const int rcvbuf = TCP_RECV_BUF_SIZE;
    if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &rcvbuf, sizeof(rcvbuf)) < 0) {
        std::cerr << "Warning: Failed to set SO_RCVBUF: " << strerror(errno) << std::endl;
    }

    // Sets one mandatory option and logs "Failed to set <label>: <reason>" if it fails.
    auto apply = [fd](int level, int name, const void* value, socklen_t len, const char* label) {
        if (setsockopt(fd, level, name, value, len) >= 0) {
            return true;
        }
        std::cerr << "Failed to set " << label << ": " << strerror(errno) << std::endl;
        return false;
    };

    struct timeval io_timeout;
    io_timeout.tv_sec = timeout_sec;
    io_timeout.tv_usec = 0;

    const int keepalive_on = 1;
    const int idle_sec = 5;      // idle time before the first keepalive probe
    const int interval_sec = 2;  // gap between probes
    const int probe_count = 3;   // unanswered probes before the connection is considered dead

    // && short-circuits, so options are applied in order and the first failure stops the chain.
    return apply(SOL_SOCKET, SO_RCVTIMEO, &io_timeout, sizeof(io_timeout), "recv timeout") &&
           apply(SOL_SOCKET, SO_SNDTIMEO, &io_timeout, sizeof(io_timeout), "send timeout") &&
           apply(SOL_SOCKET, SO_KEEPALIVE, &keepalive_on, sizeof(keepalive_on), "SO_KEEPALIVE") &&
           apply(IPPROTO_TCP, TCP_KEEPIDLE, &idle_sec, sizeof(idle_sec), "TCP_KEEPIDLE") &&
           apply(IPPROTO_TCP, TCP_KEEPINTVL, &interval_sec, sizeof(interval_sec), "TCP_KEEPINTVL") &&
           apply(IPPROTO_TCP, TCP_KEEPCNT, &probe_count, sizeof(probe_count), "TCP_KEEPCNT");
}

// Opens a TCP connection to server_ip:port and completes a TLS client handshake on it.
//
// The socket is made non-blocking and configured through set_socket_options(). connect() and
// SSL_connect() are then driven with wait_socket_ready(), bounded by CONNECT_TIMEOUT_SEC and
// SSL_HANDSHAKE_TIMEOUT_SEC respectively.
//
// @param ctx         TLS context from init_ssl_context(); not owned
// @param server_ip   dotted-quad IPv4 address of the server
// @param port        TCP port of the server
// @param process_id  tag prefixed to every log line
// @return a connected SSL* on success, or nullptr on any failure. On failure all resources
//         are released and the reason is logged. On success the caller owns both the SSL
//         object and its socket (SSL_get_fd); this function closes the descriptor separately
//         from SSL_free(), so callers must do the same (safe_ssl_shutdown() does both).
SSL* create_ssl_connection(SSL_CTX* ctx, const std::string& server_ip, int port, int process_id) {
    // ERR_reason_error_string() returns nullptr when the error queue is empty or the code is
    // unknown; streaming a null char* is undefined behaviour, so always fall back to text.
    auto ssl_reason = []() -> std::string {
        const char* reason = ERR_reason_error_string(ERR_get_error());
        return reason != nullptr ? reason : "unknown OpenSSL error";
    };

    // 1. Create the socket.
    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd < 0) {
        std::cerr << "Process " << process_id << " Failed to create socket: " << strerror(errno) << std::endl;
        return nullptr;
    }

    // 2. Non-blocking mode, buffer sizes, keepalive and timeouts (both helpers log their own errors).
    if (!set_nonblocking(sockfd) || !set_socket_options(sockfd, CONNECT_TIMEOUT_SEC)) {
        close(sockfd);
        return nullptr;
    }

    // 3. Non-blocking connect.
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    if (inet_pton(AF_INET, server_ip.c_str(), &server_addr.sin_addr) <= 0) {
        std::cerr << "Process " << process_id << " Invalid server IP: " << server_ip << std::endl;
        close(sockfd);
        return nullptr;
    }

    // EINTR: POSIX says an interrupted connect() keeps establishing the connection
    // asynchronously, so it is handled exactly like EINPROGRESS.
    if (connect(sockfd, reinterpret_cast<struct sockaddr*>(&server_addr), sizeof(server_addr)) < 0 &&
        errno != EINPROGRESS && errno != EINTR) {
        std::cerr << "Process " << process_id << " Failed to connect: " << strerror(errno) << std::endl;
        close(sockfd);
        return nullptr;
    }

    if (!wait_socket_ready(sockfd, true, CONNECT_TIMEOUT_SEC * 1000)) {
        std::cerr << "Process " << process_id << " Connect timeout (" << CONNECT_TIMEOUT_SEC << "s)" << std::endl;
        close(sockfd);
        return nullptr;
    }

    // Writability only means the attempt finished; SO_ERROR tells whether it succeeded.
    int err = 0;
    socklen_t err_len = sizeof(err);
    if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &err, &err_len) < 0) {
        err = errno;  // report why getsockopt failed rather than strerror(0) ("Success")
    }
    if (err != 0) {
        std::cerr << "Process " << process_id << " Connect failed: " << strerror(err) << std::endl;
        close(sockfd);
        return nullptr;
    }

    // 4. TLS handshake over the non-blocking socket.
    SSL* ssl = SSL_new(ctx);
    if (ssl == nullptr) {
        std::cerr << "Process " << process_id << " Failed to create SSL object: " << ssl_reason() << std::endl;
        close(sockfd);
        return nullptr;
    }
    if (SSL_set_fd(ssl, sockfd) != 1) {
        std::cerr << "Process " << process_id << " Failed to attach socket to SSL: " << ssl_reason() << std::endl;
        SSL_free(ssl);
        close(sockfd);
        return nullptr;
    }
    SSL_set_connect_state(ssl);

    // Monotonic deadline: unaffected by wall-clock changes and finer than time()'s 1 s ticks.
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(SSL_HANDSHAKE_TIMEOUT_SEC);
    while (true) {
        if (std::chrono::steady_clock::now() > deadline) {
            std::cerr << "Process " << process_id << " SSL handshake timeout (" << SSL_HANDSHAKE_TIMEOUT_SEC << "s)" << std::endl;
            SSL_free(ssl);
            close(sockfd);
            return nullptr;
        }

        // SSL_get_error() is only reliable if the error queue is empty before the I/O call;
        // errno is reset so a SSL_ERROR_SYSCALL below reports this call's errno, not a stale one.
        ERR_clear_error();
        errno = 0;
        const int rc = SSL_connect(ssl);
        if (rc == 1) {
            break;  // handshake complete
        }

        const int saved_errno = errno;
        const int ssl_err = SSL_get_error(ssl, rc);
        if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
            // Wait in 100 ms slices so the deadline above is re-checked regularly.
            wait_socket_ready(sockfd, ssl_err == SSL_ERROR_WANT_WRITE, 100);
            continue;
        }

        std::cerr << "Process " << process_id << " SSL handshake failed: ";
        if (ssl_err == SSL_ERROR_SYSCALL && ERR_peek_error() == 0) {
            // Transport-level failure with nothing on the OpenSSL queue: the cause is in errno
            // (errno == 0 here typically means the peer closed the connection unexpectedly).
            std::cerr << (saved_errno != 0 ? strerror(saved_errno) : "unexpected EOF from peer");
        } else {
            std::cerr << ssl_reason();
        }
        std::cerr << " (ssl_err=" << ssl_err << ")" << std::endl;
        SSL_free(ssl);
        close(sockfd);
        return nullptr;
    }

    std::cout << "Process " << process_id << " SSL handshake success (cipher: " << SSL_get_cipher(ssl) << ")" << std::endl;
    return ssl;
}

// 构造请求(严格匹配服务端的JSON格式)
std::string build_request(const std::string& role) {
    json req = {
        {"command", "query"},
        {"data", {{"role", role}}}
    };
    return req.dump();
}

// Writes the whole request over the non-blocking TLS connection.
//
// On SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE the write is retried with the same buffer
// and length after waiting for the socket, as OpenSSL requires for non-blocking I/O.
//
// @param ssl         connected TLS session
// @param request     bytes to send, verbatim (no terminator is appended)
// @param process_id  tag prefixed to every log line
// @return true once SSL_write() has accepted every byte; false on any TLS or transport error
//         (logged to stderr).
bool send_request(SSL* ssl, const std::string& request, int process_id) {
    const char* const data = request.data();
    const size_t total_to_send = request.size();
    size_t total_sent = 0;
    const int fd = SSL_get_fd(ssl);

    while (total_sent < total_to_send) {
        // SSL_write() takes an int length; clamp so oversized requests go out in chunks.
        const size_t remaining = total_to_send - total_sent;
        const int chunk = remaining > static_cast<size_t>(INT_MAX) ? INT_MAX : static_cast<int>(remaining);

        // SSL_get_error() needs an empty error queue before the I/O call to be reliable.
        ERR_clear_error();
        errno = 0;
        const int bytes_sent = SSL_write(ssl, data + total_sent, chunk);
        if (bytes_sent > 0) {
            total_sent += static_cast<size_t>(bytes_sent);
            std::cout << "Process " << process_id << " Sent " << bytes_sent << " bytes (total: " << total_sent << "/" << total_to_send << ")" << std::endl;
            continue;
        }

        // Any result <= 0 must be classified with SSL_get_error(); 0 does not always mean "closed".
        const int saved_errno = errno;
        const int ssl_err = SSL_get_error(ssl, bytes_sent);
        if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
            wait_socket_ready(fd, ssl_err == SSL_ERROR_WANT_WRITE, 100);  // wait, then retry identically
            continue;
        }
        if (ssl_err == SSL_ERROR_ZERO_RETURN) {
            std::cerr << "Process " << process_id << " SSL write closed" << std::endl;
            return false;
        }

        // ERR_reason_error_string() may return nullptr; never stream a null char*.
        const char* reason = ERR_reason_error_string(ERR_get_error());
        std::cerr << "Process " << process_id << " Failed to send request: ";
        if (reason != nullptr) {
            std::cerr << reason;
        } else if (ssl_err == SSL_ERROR_SYSCALL && saved_errno != 0) {
            std::cerr << strerror(saved_errno);
        } else {
            std::cerr << "ssl_err=" << ssl_err;
        }
        std::cerr << std::endl;
        return false;
    }

    std::cout << "Process " << process_id << " Sent request: " << request << std::endl;
    return true;
}

// Reads from the TLS connection until one of the following happens:
//   - the server's "\r\n" response terminator arrives;
//   - the server closes the connection;
//   - an error occurs;
//   - RECV_TIMEOUT_SEC elapses.
//
// @param ssl         connected TLS session (its socket is non-blocking)
// @param process_id  tag prefixed to every log line
// @return the response text up to (not including) the first "\r\n". If no terminator was
//         seen, returns whatever partial data arrived (possibly empty). Bytes after the
//         terminator are discarded.
std::string recv_response(SSL* ssl, int process_id) {
    char buffer[BUFFER_SIZE];
    std::string response;
    const int fd = SSL_get_fd(ssl);
    // Monotonic clock so wall-clock adjustments cannot shorten or extend the budget.
    const auto start = std::chrono::steady_clock::now();
    const auto budget = std::chrono::seconds(RECV_TIMEOUT_SEC);

    std::cout << "Process " << process_id << " Start receiving response (timeout: " << RECV_TIMEOUT_SEC << "s)..." << std::endl;

    while (true) {
        const auto elapsed = std::chrono::steady_clock::now() - start;
        if (elapsed > budget) {
            const long long elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
            std::cerr << "Process " << process_id << " Recv timeout (elapsed: " << elapsed_ms << "ms)" << std::endl;
            break;
        }

        // SSL_get_error() needs an empty error queue before the I/O call to be reliable;
        // errno is reset so a SSL_ERROR_SYSCALL below reports this call's errno.
        ERR_clear_error();
        errno = 0;
        const int bytes_read = SSL_read(ssl, buffer, BUFFER_SIZE);
        if (bytes_read > 0) {
            // Scan only the new bytes, plus the previous last byte in case a "\r" ended the
            // earlier chunk; rescanning everything would be quadratic in the response size.
            const size_t scan_from = response.empty() ? 0 : response.size() - 1;
            response.append(buffer, static_cast<size_t>(bytes_read));
            std::cout << "Process " << process_id << " Read " << bytes_read << " bytes (total: " << response.size() << ")" << std::endl;

            const size_t crlf_pos = response.find("\r\n", scan_from);
            if (crlf_pos != std::string::npos) {
                response.resize(crlf_pos);  // drop the terminator and anything after it
                std::cout << "Process " << process_id << " Full response received: " << response << std::endl;
                return response;
            }
            continue;
        }

        // Any result <= 0 must be classified with SSL_get_error(); 0 does not always mean "closed".
        const int saved_errno = errno;
        const int ssl_err = SSL_get_error(ssl, bytes_read);
        if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
            // The TLS layer needs more socket I/O before it can return data: wait and retry.
            wait_socket_ready(fd, ssl_err == SSL_ERROR_WANT_WRITE, 200);
            continue;
        }
        if (ssl_err == SSL_ERROR_ZERO_RETURN) {
            std::cerr << "Process " << process_id << " Server closed connection" << std::endl;
            break;
        }
        if (ssl_err == SSL_ERROR_SYSCALL && ERR_peek_error() == 0) {
            std::cerr << "Process " << process_id << " SSL read failed: "
                      << (saved_errno != 0 ? strerror(saved_errno) : "connection closed without close_notify")
                      << std::endl;
            break;
        }
        // ERR_reason_error_string() may return nullptr; never stream a null char*.
        const char* reason = ERR_reason_error_string(ERR_get_error());
        std::cerr << "Process " << process_id << " SSL read failed: "
                  << (reason != nullptr ? reason : "unknown OpenSSL error") << std::endl;
        break;
    }

    // Hand back whatever arrived so the caller can still log / inspect it.
    if (!response.empty()) {
        std::cout << "Process " << process_id << " Partial response: " << response << std::endl;
    }
    return response;
}

// 解析响应(匹配服务端的响应JSON格式)
void parse_response(const std::string& response, int process_id) {
    if (response.empty()) {
        std::cerr << "Process " << process_id << " Empty response" << std::endl;
        return;
    }

    try {
        json resp = json::parse(response);
        std::string type = resp["type"];
        if (type == "ERROR FROM ATF") {
            std::cerr << "Process " << process_id << " Server error: " << resp.dump() << std::endl;
        } else if (type == "QUERY FROM ATF") {
            std::string ip = resp["data"]["ip"];
            std::string role = resp["data"]["role"];
            std::cout << "Process " << process_id << " Query success: ip=" << ip << ", role=" << role << std::endl;
        } else {
            std::cerr << "Process " << process_id << " Unknown response type: " << type << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Process " << process_id << " Parse failed: " << e.what() << " (raw: " << response << ")" << std::endl;
    }
}

// Closes a TLS connection and releases its resources: the SSL object and its socket.
//
// Sends our close_notify first. It then drains the connection through the SSL layer for a
// short, bounded time (at most 10 reads, each waiting up to 100 ms) so that leftover server
// data and the server's close_notify are consumed before the socket is closed. Finally it
// frees the SSL object and closes the descriptor.
//
// Why this order: shutting down the TCP write side before SSL_shutdown() would make the
// close_notify write fail with EPIPE (raising SIGPIPE, which terminates the process by
// default). Reading the raw socket with read() would consume encrypted TLS records
// behind OpenSSL's back.
//
// @param ssl  session to close; nullptr is ignored
void safe_ssl_shutdown(SSL* ssl) {
    if (ssl == nullptr) {
        return;
    }
    const int sockfd = SSL_get_fd(ssl);

    ERR_clear_error();
    if (SSL_shutdown(ssl) == 0) {
        // Our close_notify is out but the peer's has not arrived yet: drain until it does.
        char drain[1024];
        for (int attempt = 0; attempt < 10; ++attempt) {
            ERR_clear_error();
            const int n = SSL_read(ssl, drain, sizeof(drain));
            if (n > 0) {
                continue;  // leftover application data: discard it
            }
            const int err = SSL_get_error(ssl, n);
            if ((err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) || sockfd < 0) {
                break;  // peer's close_notify (SSL_ERROR_ZERO_RETURN) or a hard error
            }
            struct pollfd pfd;
            pfd.fd = sockfd;
            pfd.events = static_cast<short>(err == SSL_ERROR_WANT_WRITE ? POLLOUT : POLLIN);
            pfd.revents = 0;
            if (poll(&pfd, 1, 100) <= 0) {
                break;  // nothing more arrived in time (or poll failed): stop waiting
            }
        }
    }

    SSL_free(ssl);
    if (sockfd >= 0) {
        close(sockfd);
    }
}

// 单进程测试任务
void test_task(int process_id, const std::string& role) {
    SSL_CTX* ctx = init_ssl_context();
    if (ctx == nullptr) {
        return;
    }

    SSL* ssl = create_ssl_connection(ctx, SERVER_IP, SERVER_PORT, process_id);
    if (ssl == nullptr) {
        SSL_CTX_free(ctx);
        return;
    }

    std::string request = build_request(role);
    if (!send_request(ssl, request, process_id)) {
        safe_ssl_shutdown(ssl);
        SSL_CTX_free(ctx);
        return;
    }

    std::string response = recv_response(ssl, process_id);
    parse_response(response, process_id);

    safe_ssl_shutdown(ssl);
    SSL_CTX_free(ctx);

    // 休眠3秒,确保服务端完成响应发送(适配服务端的超时清理)
    std::cout << "Process " << process_id << " Sleep 3s to wait server..." << std::endl;
    sleep(3);
}

// Entry point: runs a single query against the ATF server.
//
// Usage: <program> [primary|standby]   (the role defaults to "primary")
// Returns 1 on invalid arguments, otherwise 0. The exit code does not reflect whether the
// query itself succeeded; results are only logged.
int main(int argc, char* argv[]) {
    // Writing to a connection the peer has already closed raises SIGPIPE, whose default
    // action terminates the process. Ignoring it turns that into an EPIPE error that
    // SSL_write / SSL_shutdown report like any other failure.
    std::signal(SIGPIPE, SIG_IGN);

    if (argc > 2) {
        // Previously extra arguments made the role argument be silently ignored.
        std::cerr << "Usage: " << argv[0] << " [primary|standby]" << std::endl;
        return 1;
    }

    std::string role = "primary";
    if (argc == 2) {
        role = argv[1];
        if (role != "primary" && role != "standby") {
            std::cerr << "Invalid role! Use: primary/standby" << std::endl;
            return 1;
        }
    }

    test_task(1, role);
    std::cout << "Test completed" << std::endl;
    return 0;
}