/**
 * This program is free software, you can redistribute it and/or modify it.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "msg_handler_plugin.h"
#include <netdb.h>
#include <netinet/tcp.h>
#include <csignal>
#include <cstring>
#include <utility>
#include "common/llm_utils.h"
#include "common/llm_checker.h"
#include "common/llm_scope_guard.h"

namespace llm {
namespace {
// Length of the pending-connection queue handed to listen() in StartDaemon().
constexpr int32_t kListenBacklog{128};
// Seconds the listener loop sleeps after DoAccept() reports a failure.
constexpr int64_t kDefaultSleepTime{1};
}  // namespace

/**
 * Process-wide setup: ignore SIGPIPE so that writing to a peer that has closed its end makes
 * write() fail with an error instead of the default signal action terminating the process.
 */
void MsgHandlerPlugin::Initialize() {
  static_cast<void>(std::signal(SIGPIPE, SIG_IGN));
}

/**
 * Determines the address family of a numeric IP string.
 * @param ip        textual address, e.g. "127.0.0.1" or "::1".
 * @param ai_family [out] AF_INET or AF_INET6; left untouched on failure.
 * @return ge::SUCCESS when ip parses as IPv4 or IPv6, ge::LLM_PARAM_INVALID otherwise.
 */
ge::Status MsgHandlerPlugin::GetAiFamily(const std::string &ip, int32_t &ai_family) {
  // IPv4 is probed first, then IPv6. The parsed address is discarded, so one in6_addr-sized
  // scratch buffer (large enough for either family) serves both probes.
  const int32_t probe_order[] = {AF_INET, AF_INET6};
  for (const int32_t family : probe_order) {
    struct in6_addr scratch{};
    if (inet_pton(family, ip.c_str(), &scratch) == 1) {
      ai_family = family;
      return ge::SUCCESS;
    }
  }
  return ge::LLM_PARAM_INVALID;
}

/**
 * Connects to ip:port, trying every address getaddrinfo() returns until one succeeds.
 * @param ip          numeric IPv4/IPv6 address of the peer.
 * @param port        peer TCP port.
 * @param conn_fd     [out] connected socket on success.
 * @param timeout     overall budget in milliseconds; also used as the per-socket send/receive timeout.
 * @param default_err status returned by DoConnect() when a single attempt fails.
 * @return ge::SUCCESS on success, ge::LLM_TIMEOUT when the budget is exhausted between attempts,
 *         ge::LLM_PARAM_INVALID for an unparsable address, otherwise the last attempt's status.
 */
ge::Status MsgHandlerPlugin::Connect(const std::string &ip, uint32_t port, int32_t &conn_fd,
                                     int32_t timeout, ge::Status default_err) {
  auto start = std::chrono::high_resolution_clock::now();
  struct ::addrinfo hints;
  struct ::addrinfo *result = nullptr;
  (void)memset_s(&hints, sizeof(hints), 0, sizeof(hints));
  LLM_CHK_STATUS_RET(GetAiFamily(ip, hints.ai_family), "Failed to get ai_family, ip:%s", ip.c_str());
  hints.ai_socktype = SOCK_STREAM;

  const auto socket_ret = getaddrinfo(ip.c_str(), std::to_string(port).c_str(), &hints, &result);
  // getaddrinfo() reports failures through its EAI_* return code, not errno, so the readable message
  // must come from gai_strerror(). errno is still logged because it is meaningful for EAI_SYSTEM.
  LLM_CHK_BOOL_RET_STATUS(socket_ret == 0,
                          ge::LLM_PARAM_INVALID,
                          "Failed to get IP address of peer %s:%u, please check addr and port, "
                          "socket_ret:%d, error msg:%s, errno:%d",
                          ip.c_str(), port, socket_ret, gai_strerror(socket_ret), errno);
  LLM_MAKE_GUARD(free_addr, ([result]() { freeaddrinfo(result); }));

  ge::Status ret = ge::SUCCESS;
  int32_t err_no = 0;  // errno recorded by the most recent failed attempt
  for (struct ::addrinfo *rp = result; rp != nullptr; rp = rp->ai_next) {
    ret = DoConnect(rp, conn_fd, err_no, timeout, default_err);
    if (ret == ge::SUCCESS) {
      break;
    }
    LLM_CHK_BOOL_RET_STATUS(!LLMUtils::IsTimeout(start, timeout), ge::LLM_TIMEOUT,
                            "connect to the peer %s:%u timed out, timeout:%d ms.",
                            ip.c_str(), port, timeout);
  }
  if (ret != ge::SUCCESS) {
    LLMLOGE(ret, "Failed to connect peer %s:%u, error msg:%s, errno:%d",
            ip.c_str(), port, strerror(err_no), err_no);
  }
  return ret;
}

/**
 * Performs one connection attempt against a single getaddrinfo() candidate.
 * @param addr        candidate address (family/socktype/protocol/sockaddr).
 * @param conn_fd     [out] connected socket on success; -1 after any failure.
 * @param err_no      [out] errno captured when the attempt fails; untouched on success.
 * @param timeout     send/receive timeout in milliseconds applied to the socket (also bounds connect()).
 * @param default_err status returned on any failure.
 */
ge::Status MsgHandlerPlugin::DoConnect(struct ::addrinfo *addr, int32_t &conn_fd, int32_t &err_no,
                                       int32_t timeout, ge::Status default_err) {
  int32_t on = 1;
  LLMLOGI("Attempting to create socket with family:%d, type:%d, protocol:%d",
          addr->ai_family, addr->ai_socktype, addr->ai_protocol);
  // Hand the errno of whatever step failed back to the caller, which logs it.
  LLM_DISMISSABLE_GUARD(record_err, ([&err_no]() { err_no = errno; }));
  conn_fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(conn_fd == -1, default_err,
                                  "Try to create socket, error msg:%s, errno:%d", strerror(errno), errno);

  // On failure, close the half-configured socket and reset conn_fd so the caller never holds a
  // descriptor number that is already closed (and may be reused). This guard is declared after
  // record_err, so it is destroyed first; errno is saved around close() so that record_err still
  // sees the errno of the failing step.
  LLM_DISMISSABLE_GUARD(close_fd, ([&conn_fd]() {
    const int32_t saved_errno = errno;
    (void)close(conn_fd);
    conn_fd = -1;
    errno = saved_errno;
  }));
  auto socket_ret = setsockopt(conn_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(socket_ret != 0, default_err,
                                  "Try to setsockopt(SO_REUSEADDR), socket_ret:%d, error msg:%s, errno:%d",
                                  socket_ret, strerror(errno), errno);
  // timeout is in milliseconds: split it into whole seconds plus the remainder in microseconds.
  constexpr int32_t kMsPerSec = 1000;
  constexpr int32_t kUsPerMs = 1000;
  struct timeval socket_timeout;
  socket_timeout.tv_sec = timeout / kMsPerSec;
  socket_timeout.tv_usec = (timeout % kMsPerSec) * kUsPerMs;
  socket_ret = setsockopt(conn_fd, SOL_SOCKET, SO_RCVTIMEO, &socket_timeout, sizeof(socket_timeout));
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(socket_ret != 0, default_err,
                                  "Try to setsockopt(SO_RCVTIMEO), socket_ret:%d, error msg:%s, errno:%d",
                                  socket_ret, strerror(errno), errno);
  int32_t flag = 1;
  socket_ret = setsockopt(conn_fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(socket_ret != 0, default_err,
                                  "Try to setsockopt(TCP_NODELAY), socket_ret:%d, error msg:%s, errno:%d",
                                  socket_ret, strerror(errno), errno);
  socket_ret = setsockopt(conn_fd, SOL_SOCKET, SO_SNDTIMEO, &socket_timeout, sizeof(socket_timeout));
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(socket_ret != 0, default_err,
                                  "Try to setsockopt(SO_SNDTIMEO), socket_ret:%d, error msg:%s, errno:%d",
                                  socket_ret, strerror(errno), errno);
  socket_ret = connect(conn_fd, addr->ai_addr, addr->ai_addrlen);
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(socket_ret != 0, default_err,
                                  "Try to socket connect, socket_ret:%d, error msg:%s, errno:%d",
                                  socket_ret, strerror(errno), errno);
  LLM_DISMISS_GUARD(close_fd);
  LLM_DISMISS_GUARD(record_err);
  return ge::SUCCESS;
}

/// Closes a connection descriptor (e.g. one produced by Connect()); the close() result is ignored.
void MsgHandlerPlugin::Disconnect(int32_t conn_fd) {
  static_cast<void>(close(conn_fd));
}

/// Closes the listening socket if one is open and marks it as closed (-1); safe to call repeatedly.
void MsgHandlerPlugin::ListenClose() {
  if (listen_fd_ < 0) {
    return;
  }
  close(listen_fd_);
  listen_fd_ = -1;
}

/**
 * Writes exactly len bytes from buf to fd, looping over partial writes.
 * Calls failing with EAGAIN or EINTR are retried.
 * @return len on success; the number of bytes written so far if write() returns 0;
 *         the negative write() result on any other error.
 */
ssize_t MsgHandlerPlugin::Write(int32_t fd, const void *buf, size_t len) {
  const char *cursor = static_cast<const char *>(buf);
  size_t remaining = len;
  while (remaining != 0U) {
    const ssize_t written = write(fd, cursor, remaining);
    if (written > 0) {
      cursor += written;
      remaining -= static_cast<size_t>(written);
      continue;
    }
    if (written == 0) {
      LLMLOGW("Socket write incompleted: expected %zu bytes, actual %zu bytes", len, len - remaining);
      return static_cast<ssize_t>(len - remaining);
    }
    if ((errno == EAGAIN) || (errno == EINTR)) {
      continue;
    }
    LLMLOGE(ge::FAILED, "Socket write failed, error msg:%s, errno:%d", strerror(errno), errno);
    return written;
  }
  LLMLOGI("Socket write completed: %zu bytes", len);
  return static_cast<ssize_t>(len);
}

/**
 * Reads exactly len bytes from fd into buf, looping over partial reads.
 * Calls failing with EAGAIN or EINTR are retried.
 * NOTE(review): with SO_RCVTIMEO set on the socket, an expired receive timeout also surfaces as
 * EAGAIN, so this loop keeps waiting instead of giving up — confirm that is intended.
 * @return len on success; the number of bytes read so far if the peer closes (read() returns 0);
 *         the negative read() result on any other error.
 */
ssize_t MsgHandlerPlugin::Read(int32_t fd, void *buf, size_t len) {
  uint8_t *cursor = static_cast<uint8_t *>(buf);
  size_t remaining = len;
  while (remaining != 0U) {
    const ssize_t got = read(fd, cursor, remaining);
    if (got > 0) {
      cursor += got;
      remaining -= static_cast<size_t>(got);
      continue;
    }
    if (got == 0) {
      LLMLOGW("Socket read incompleted: expected %zu bytes, actual %zu bytes", len, len - remaining);
      return static_cast<ssize_t>(len - remaining);
    }
    if ((errno == EAGAIN) || (errno == EINTR)) {
      continue;
    }
    LLMLOGE(ge::FAILED, "Socket read failed, error msg:%s, errno:%d", strerror(errno), errno);
    return got;
  }
  return static_cast<ssize_t>(len);
}

/**
 * Registers the callback invoked for every accepted connection (see DoConnectedProcess()).
 * proc is a by-value sink, so it is moved into place rather than copied a second time
 * (presumably ConnectedProcess is a std::function, whose copy may allocate — confirm in header).
 */
void MsgHandlerPlugin::RegisterConnectedProcess(ConnectedProcess proc) {
  connected_process_ = std::move(proc);
}

/**
 * Serves one accepted connection; DoAccept() submits this to the thread pool.
 * Installs the runtime context captured in StartDaemon(), sets a 60s receive timeout and
 * TCP_NODELAY, then runs the registered callback. Unless the callback sets keep_fd, the
 * connection is shut down, a read() waits for the peer's close, and conn_fd is closed.
 * @param conn_fd accepted socket; owned by this function unless the callback keeps it.
 */
ge::Status MsgHandlerPlugin::DoConnectedProcess(int32_t conn_fd) {
  LLM_DISMISSABLE_GUARD(close_fd, ([conn_fd]() { close(conn_fd); }));
  LLM_CHK_BOOL_RET_STATUS(rtCtxSetCurrent(rt_context_) == RT_ERROR_NONE, ge::LLM_PARAM_INVALID,
                          "Set runtime context failed.");
  constexpr int32_t kTimeInSec = 60;
  struct timeval timeout;
  timeout.tv_sec = kTimeInSec;
  timeout.tv_usec = 0;
  auto socket_ret = setsockopt(conn_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
  LLM_CHK_BOOL_RET_STATUS(socket_ret == 0, ge::FAILED,
                          "Failed to setsockopt(SO_RCVTIMEO), socket_ret:%d, error msg:%s, errno:%d",
                          socket_ret, strerror(errno), errno);
  int32_t flag = 1;
  socket_ret = setsockopt(conn_fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag));
  LLM_CHK_BOOL_RET_SPECIAL_STATUS(socket_ret != 0, ge::FAILED,
                                  "Try to setsockopt(TCP_NODELAY), socket_ret:%d, error msg:%s, errno:%d",
                                  socket_ret, strerror(errno), errno);
  bool keep_fd = false;
  connected_process_(conn_fd, keep_fd);
  if (keep_fd) {
    // The callback took ownership of the descriptor.
    LLM_DISMISS_GUARD(close_fd);
    return ge::SUCCESS;
  }
  socket_ret = shutdown(conn_fd, SHUT_RDWR);
  LLM_CHK_BOOL_RET_STATUS(socket_ret == 0, ge::FAILED,
                          "Failed to shutdown conn_fd, connection may be incomplete, "
                          "socket_ret:%d, error msg:%s, errno:%d",
                          socket_ret, strerror(errno), errno);
  // Wait for the client to close the connection; an orderly close makes read() return 0.
  // byte is zero-initialised because it is logged even when read() stores nothing (rc <= 0),
  // and rc is printed signed so an error (-1) is not shown as SIZE_MAX.
  char byte = 0;
  const auto rc = read(conn_fd, &byte, sizeof(byte));
  LLM_CHK_BOOL_RET_STATUS(rc == 0,
                          ge::FAILED, "Failed to wait client close, byte = %d, rc = %zd, error msg:%s, errno:%d",
                          static_cast<int32_t>(byte), rc, strerror(errno), errno);
  return ge::SUCCESS;
}

/**
 * Accepts one pending connection on listen_fd_ and hands IPv4/IPv6 connections to the thread pool.
 * @return ge::SUCCESS when a connection was dispatched, a non-inet connection was dropped, or accept()
 *         failed with EWOULDBLOCK/EINTR/ECONNABORTED; ge::FAILED for any other accept() error.
 */
ge::Status MsgHandlerPlugin::DoAccept() {
  struct sockaddr_storage peer_addr;
  socklen_t peer_addr_len = sizeof(peer_addr);
  const auto conn_fd = accept(listen_fd_, reinterpret_cast<sockaddr *>(&peer_addr), &peer_addr_len);
  if (conn_fd < 0) {
    // These errors are benign for the listener loop: it simply tries again.
    const bool benign = (errno == EWOULDBLOCK) || (errno == EINTR) || (errno == ECONNABORTED);
    LLM_CHK_BOOL_RET_STATUS(benign, ge::FAILED,
                            "Failed to accept, error msg=%s, errno=%d",
                            strerror(errno), errno);
    return ge::SUCCESS;
  }
  // Close the connection on every path that does not hand it to a worker.
  LLM_DISMISSABLE_GUARD(close_fd, ([conn_fd]() { close(conn_fd); }));
  LLMLOGI("accept success, fd:%d, addr.sin_family:%d", conn_fd, peer_addr.ss_family);
  const bool is_inet = (peer_addr.ss_family == AF_INET) || (peer_addr.ss_family == AF_INET6);
  if (!is_inet) {
    return ge::SUCCESS;
  }
  (void)thread_pool_->commit([this, conn_fd]() -> void { (void)DoConnectedProcess(conn_fd); });
  LLM_DISMISS_GUARD(close_fd);
  return ge::SUCCESS;
}

/**
 * Fills server_addr/addr_len with ip:listen_port for binding the listening socket.
 * For any family other than AF_INET the address is treated as IPv6, and IPV6_V6ONLY is set
 * (best effort) on listen_fd_, which must therefore already be open (StartDaemon creates it first).
 * @return ge::LLM_PARAM_INVALID if listen_port does not fit in 16 bits, ge::SUCCESS otherwise.
 */
ge::Status MsgHandlerPlugin::SockAddrInit(const std::string &ip, uint32_t listen_port, int32_t ai_family,
                                          struct sockaddr_storage &server_addr, socklen_t &addr_len) {
  // htons() takes a 16-bit value: without this check a port such as 65536 would silently bind to 0.
  constexpr uint32_t kMaxPort = 65535U;
  LLM_CHK_BOOL_RET_STATUS(listen_port <= kMaxPort, ge::LLM_PARAM_INVALID,
                          "Invalid listen port:%u, must be in range [0, %u]", listen_port, kMaxPort);
  const uint16_t net_port = htons(static_cast<uint16_t>(listen_port));
  if (ai_family == AF_INET) {
    struct sockaddr_in *ipv4_addr = reinterpret_cast<struct sockaddr_in *>(&server_addr);
    (void)memset_s(ipv4_addr, sizeof(*ipv4_addr), 0, sizeof(*ipv4_addr));
    ipv4_addr->sin_family = AF_INET;
    ipv4_addr->sin_port = net_port;
    (void)inet_pton(AF_INET, ip.c_str(), &ipv4_addr->sin_addr);
    addr_len = sizeof(*ipv4_addr);
  } else {
    struct sockaddr_in6 *ipv6_addr = reinterpret_cast<struct sockaddr_in6 *>(&server_addr);
    (void)memset_s(ipv6_addr, sizeof(*ipv6_addr), 0, sizeof(*ipv6_addr));
    ipv6_addr->sin6_family = AF_INET6;
    ipv6_addr->sin6_port = net_port;
    (void)inet_pton(AF_INET6, ip.c_str(), &ipv6_addr->sin6_addr);
    addr_len = sizeof(*ipv6_addr);
    int v6only = 1;
    (void)setsockopt(listen_fd_, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only));
  }
  return ge::SUCCESS;
}

/**
 * Opens a listening TCP socket on ip:listen_port and starts the listener thread, which repeatedly
 * calls DoAccept() and dispatches connections to a 16-thread pool.
 * The current runtime context is captured so that DoConnectedProcess() can install it on workers.
 * On any failure the listening socket is closed again.
 */
ge::Status MsgHandlerPlugin::StartDaemon(const std::string &ip, uint32_t listen_port) {
  LLM_ASSERT_RT_OK(rtCtxGetCurrent(&rt_context_));

  int32_t family = 0;
  LLM_CHK_STATUS_RET(GetAiFamily(ip, family), "Failed to get ai_family, ip:%s", ip.c_str());
  listen_fd_ = socket(family, SOCK_STREAM, 0);
  LLM_CHK_BOOL_RET_STATUS(listen_fd_ >= 0, ge::FAILED, "Failed to create socket.");
  LLM_DISMISSABLE_GUARD(fail_guard, ([this]() { ListenClose(); }));

  struct sockaddr_storage bind_addr;
  socklen_t bind_addr_len;
  LLM_CHK_STATUS_RET(SockAddrInit(ip, listen_port, family, bind_addr, bind_addr_len),
                     "Failed to init sockaddr, ip:%s", ip.c_str());

  // 1s receive timeout on the listening socket: bounds how long each accept() blocks, so the
  // listener loop gets to re-check listener_running_ regularly.
  struct timeval accept_timeout;
  accept_timeout.tv_sec = 1;
  accept_timeout.tv_usec = 0;
  auto rc = setsockopt(listen_fd_, SOL_SOCKET, SO_RCVTIMEO, &accept_timeout, sizeof(accept_timeout));
  LLM_CHK_BOOL_RET_STATUS(rc == 0, ge::FAILED,
                          "Failed to set socket opt timeout, socket_ret:%d, error msg:%s, errno:%d",
                          rc, strerror(errno), errno);
  const int32_t reuse_addr = 1;
  rc = setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr, sizeof(reuse_addr));
  LLM_CHK_BOOL_RET_STATUS(rc == 0, ge::FAILED,
                          "Failed to set socket opt SO_REUSEADDR, socket_ret:%d, error msg:%s, errno:%d",
                          rc, strerror(errno), errno);
  const auto bind_ret = bind(listen_fd_, reinterpret_cast<sockaddr *>(&bind_addr), bind_addr_len);
  LLM_CHK_BOOL_RET_STATUS(bind_ret >= 0,
                          ge::FAILED, "Failed to bind port:%u, error msg:%s, errno:%d.",
                          listen_port, strerror(errno), errno);
  rc = listen(listen_fd_, kListenBacklog);
  LLM_CHK_BOOL_RET_STATUS(rc == 0, ge::FAILED, "Failed to listen, socket_ret:%d, error msg:%s, errno:%d",
                          rc, strerror(errno), errno);

  constexpr uint32_t kWorkerCount = 16U;
  thread_pool_ = MakeUnique<LLMThreadPool>("ge_llm_mhp", kWorkerCount);
  LLM_CHECK_NOTNULL(thread_pool_);

  listener_running_ = true;
  listener_ = std::thread([this]() {
    while (listener_running_) {
      if (DoAccept() != ge::SUCCESS) {
        // Back off after an unexpected accept() failure instead of spinning on it.
        std::this_thread::sleep_for(std::chrono::seconds(kDefaultSleepTime));
      }
    }
  });

  LLM_DISMISS_GUARD(fail_guard);
  return ge::SUCCESS;
}

/**
 * Sends one framed message: [uint64 length][int32 msg_type][payload], fields in host byte order,
 * where length = sizeof(msg_type) + payload size.
 * The two header fields are packed into one buffer and sent with a single Write(). Sockets set up
 * by this plugin enable TCP_NODELAY, so separate writes would go out as separate tiny segments.
 * The bytes on the wire are identical to writing the fields one by one.
 * @return ge::SUCCESS if every byte was written, ge::FAILED otherwise.
 */
ge::Status MsgHandlerPlugin::SendMsg(int32_t fd, int32_t msg_type, const std::string &msg_str) {
  const uint64_t length = msg_str.size() + sizeof(msg_type);
  char header[sizeof(length) + sizeof(msg_type)];
  std::memcpy(header, &length, sizeof(length));
  std::memcpy(header + sizeof(length), &msg_type, sizeof(msg_type));
  auto len = Write(fd, header, sizeof(header));
  LLM_CHK_BOOL_RET_STATUS(len == static_cast<ssize_t>(sizeof(header)), ge::FAILED,
                          "Failed to send msg header, msg len:%lu, msg type:%d, expect write len:%zu, "
                          "actually write len:%zd",
                          length, msg_type, sizeof(header), len);
  len = Write(fd, msg_str.c_str(), msg_str.size());
  LLM_CHK_BOOL_RET_STATUS(len == static_cast<ssize_t>(msg_str.size()),
                          ge::FAILED, "Failed to send msg:%s, expect write len:%zu, actually write len:%zd",
                          msg_str.c_str(), msg_str.size(), len);
  return ge::SUCCESS;
}

/**
 * Receives one framed message written by SendMsg(): reads the uint64 length prefix, validates it,
 * then delegates the type + payload read to the RecvMsg overload that takes an explicit length.
 * @param msg_type [out] message type of the frame.
 * @param msg      [out] payload followed by a terminating '\0' (size = payload + 1).
 * @return ge::SUCCESS, or ge::FAILED on short reads or a length outside (sizeof(int32_t), 1 MiB].
 */
ge::Status MsgHandlerPlugin::RecvMsg(int32_t fd, int32_t &msg_type, std::vector<char> &msg) {
  // Largest accepted frame length (type + payload); guards against hostile or corrupted prefixes.
  constexpr size_t kMaxLength = 1ULL << 20;
  uint64_t length = 0;
  const auto n = Read(fd, &length, sizeof(length));
  LLM_CHK_BOOL_RET_STATUS(n == static_cast<ssize_t>(sizeof(length)),
                          ge::FAILED, "Failed to recv msg len:%zd, expect len:%zu", n, sizeof(length));
  LLM_CHK_BOOL_RET_STATUS(length <= kMaxLength && length > sizeof(int32_t),
                          ge::FAILED, "Failed to check msg len:%lu, must in range: (%zu, %zu]",
                          length, sizeof(int32_t), kMaxLength);
  return RecvMsg(fd, msg_type, msg, length);
}

/**
 * Receives the type + payload part of a frame whose length prefix the caller has already read.
 * @param length   frame length (sizeof(int32_t) type field + payload bytes) as sent by the peer.
 * @param msg_type [out] message type of the frame.
 * @param msg      [out] payload followed by a terminating '\0' (size = payload + 1).
 * @return ge::SUCCESS, or ge::FAILED on an invalid length or short reads.
 * NOTE(review): no upper bound is enforced here; callers presumably bound length — confirm.
 */
ge::Status MsgHandlerPlugin::RecvMsg(int32_t fd, int32_t &msg_type, std::vector<char> &msg, uint64_t length) {
  // A length smaller than the type field would make msg_len below wrap around, leading to a
  // huge (or zero-sized) buffer and an out-of-bounds read/write; reject it before touching the stream.
  LLM_CHK_BOOL_RET_STATUS(length >= sizeof(int32_t),
                          ge::FAILED, "Failed to check msg len:%lu, must not be less than %zu",
                          length, sizeof(int32_t));
  int32_t type = 0;
  auto n = Read(fd, &type, sizeof(type));
  LLM_CHK_BOOL_RET_STATUS(n == static_cast<ssize_t>(sizeof(type)),
                          ge::FAILED, "Failed to recv msg type len:%zd, expect len:%zu", n, sizeof(type));
  msg_type = type;
  const size_t msg_len = static_cast<size_t>(length) - sizeof(int32_t);
  msg.resize(msg_len + 1U);
  n = Read(fd, msg.data(), msg_len);
  LLM_CHK_BOOL_RET_STATUS(n == static_cast<ssize_t>(msg_len),
                          ge::FAILED, "Failed to check recv msg type:%d, recv msg len:%zd, expect len:%zu",
                          type, n, msg_len);
  msg[msg_len] = '\0';
  return ge::SUCCESS;
}

/// Stops the daemon if it is still running; a prior Finalize() makes this a no-op.
MsgHandlerPlugin::~MsgHandlerPlugin() {
  if (!listener_running_) {
    return;
  }
  Finalize();
}

/**
 * Stops the listener thread and closes the listening socket.
 * The thread is stopped and joined BEFORE the socket is closed. Closing first would let the thread
 * race on listen_fd_ (read by DoAccept() while ListenClose() resets it) and possibly call accept()
 * on a descriptor number already reused elsewhere. Joining does not hang: each accept() is bounded
 * by the 1s SO_RCVTIMEO set in StartDaemon(), after which the loop re-checks listener_running_.
 */
void MsgHandlerPlugin::Finalize() {
  if (listener_running_) {
    listener_running_ = false;
    if (listener_.joinable()) {
      listener_.join();
    }
  }
  ListenClose();
}
}  // namespace llm