/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Description: UcpWorker class that handles the RDMA put process and is managed
 * by UcpWorkerPool. It handles the reuse and removal of resources for previously
 * connected nodes.
 */

#include "datasystem/common/rdma/ucp_worker.h"

#include <chrono>

#include "datasystem/common/flags/flags.h"
#include "datasystem/common/log/log.h"
#include "datasystem/common/perf/perf_manager.h"
#include "datasystem/common/rdma/rdma_util.h"
#include "datasystem/common/util/format.h"
#include "datasystem/common/util/status_helper.h"

namespace datasystem {
namespace {
// Value returned by GetCommSubmitBudget(): the cap on submit requests drained per SubmitLoop pass.
constexpr uint32_t DEFAULT_COMM_SUBMIT_BUDGET{ 128 };
// Timeout, in microseconds, of the condition-variable waits performed by the submit thread.
constexpr uint32_t COMM_IDLE_WAIT_US{ 100 };

// Current std::chrono::steady_clock reading, expressed as nanoseconds since that clock's epoch.
uint64_t GetSteadyClockNs()
{
    using std::chrono::nanoseconds;
    const auto sinceEpoch = std::chrono::steady_clock::now().time_since_epoch();
    const nanoseconds asNanos = std::chrono::duration_cast<nanoseconds>(sinceEpoch);
    return asNanos.count();
}

// Maximum number of submit requests SubmitLoop dequeues in one pass; currently always the
// compile-time default.
uint32_t GetCommSubmitBudget()
{
    const uint32_t budget = DEFAULT_COMM_SUBMIT_BUDGET;
    return budget;
}

}  // namespace

/**
 * Stores the UCP context handle and the worker id. The constructor body is empty; the UCP
 * worker itself is created in Init().
 * @param ucpContext UCP context handle, later passed to ds_ucp_worker_create.
 * @param workerId Numeric id embedded in the log prefix and in the submit thread's name.
 */
UcpWorker::UcpWorker(const ucp_context_h &ucpContext, const uint32_t workerId)
    : context_(ucpContext),
      workerId_(workerId),
      // Built from the constructor argument rather than from workerId_, so the prefix is correct
      // regardless of the order in which the members are declared in the class.
      errorMsgHead_("[UcpWorker " + std::to_string(workerId) + "]")
{
}

/** Destructor: all tear-down is delegated to Clean(). */
UcpWorker::~UcpWorker()
{
    this->Clean();
}

/**
 * Creates the UCP worker in UCS_THREAD_MODE_SERIALIZED, fetches the worker address and caches it
 * as a byte string, then starts the submit thread.
 * @return Status::OK() on success; K_RDMA_ERROR if worker creation or the address query fails.
 */
Status UcpWorker::Init()
{
    ucp_worker_params_t params = {};
    params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
    params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;

    const ucs_status_t createRc = ds_ucp_worker_create(context_, &params, &worker_);
    if (createRc != UCS_OK) {
        LOG(ERROR) << errorMsgHead_ << " Failed to create worker. Status: " << ds_ucs_status_string(createRc);
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Failed to create worker.");
    }

    size_t addrLen = 0;
    const ucs_status_t addrRc = ds_ucp_worker_get_address(worker_, &localWorkerAddr_, &addrLen);
    if (addrRc != UCS_OK) {
        LOG(ERROR) << errorMsgHead_ << " Failed to get worker address. Status: " << ds_ucs_status_string(addrRc);
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Failed to get worker address.");
    }

    // Keep a copy of the raw address bytes; the string is sized by the reported length and is
    // not treated as NUL-terminated.
    const char *rawAddr = reinterpret_cast<const char *>(localWorkerAddr_);
    localWorkerAddrStr_ = std::string(rawAddr, addrLen);

    StartSubmitThread();
    return Status::OK();
}

/**
 * Packages a single-buffer write as a SubmitRequest and queues it for the submit thread.
 * The request is stamped with the steady-clock time at which it was queued.
 * @return The status of EnqueueSubmit.
 */
Status UcpWorker::Write(const std::string &remoteRkey, const uintptr_t remoteSegAddr,
                        const std::string &remoteWorkerAddr, const std::string &ipAddr, const uintptr_t localSegAddr,
                        size_t localSegSize, uint64_t requestID, std::shared_ptr<Event> event)
{
    auto request = std::make_shared<SubmitRequest>();
    SubmitRequest &fields = *request;

    // Remote side of the transfer.
    fields.remoteRkey = remoteRkey;
    fields.remoteAddr = remoteSegAddr;
    fields.remoteWorkerAddr = remoteWorkerAddr;
    fields.ipAddr = ipAddr;

    // Local buffer.
    fields.localSegAddr = localSegAddr;
    fields.localSegSize = localSegSize;

    // Completion tracking.
    fields.requestID = requestID;
    fields.event = std::move(event);
    fields.enqueueNs = GetSteadyClockNs();

    return EnqueueSubmit(std::move(request));
}

/**
 * Posts one contiguous put to the remote peer and queues a flush context for it.
 * Looks up (or creates) the endpoint for ipAddr, obtains the rkey through the endpoint, calls
 * ds_ucp_put_nbx, and on success hands a CallbackContext to EnqueueFlush.
 * @return Status::OK() once the put is posted and the flush context is queued; K_RDMA_ERROR if
 *         the endpoint, its handle, the rkey or the put itself cannot be obtained/posted.
 */
Status UcpWorker::WriteDirect(const std::string &remoteRkey, uintptr_t remoteSegAddr,
                              const std::string &remoteWorkerAddr, const std::string &ipAddr, uintptr_t localSegAddr,
                              size_t localSegSize, uint64_t requestID, std::shared_ptr<Event> event)
{
    PerfPoint totalPoint(PerfKey::RDMA_UCP_WORKER_WRITE_TOTAL);

    const std::shared_ptr<UcpEndpoint> endpoint = GetOrCreateEndpoint(ipAddr, remoteWorkerAddr);
    if (!endpoint) {
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Failed to create Endpoint.");
    }

    const ucp_ep_h epHandle = endpoint->GetEp();
    if (epHandle == nullptr) {
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " UcpEndpoint contained an empty endpoint?");
    }

    ucp_rkey_h unpackedRkey = endpoint->GetOrUnpackRkey(remoteRkey);
    if (unpackedRkey == nullptr) {
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + ipAddr + " Failed to get an unpack rkey.");
    }

    ucp_request_param_t putParam{};
    const void *localBuf = reinterpret_cast<const void *>(localSegAddr);
    void *putRequest = ds_ucp_put_nbx(epHandle, localBuf, localSegSize, remoteSegAddr, unpackedRkey, &putParam);
    if (UCS_PTR_IS_ERR(putRequest)) {
        LOG(ERROR) << errorMsgHead_ << " Failed to execute ucp_put_nbx. Status: "
                   << ds_ucs_status_string(UCS_PTR_STATUS(putRequest));
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Failed to send data immediately.");
    }

    // No IOV array for a contiguous write, hence the nullptr slot. The context is released by
    // FinishContext once the flush batch it ends up in has been finished.
    auto *pendingFlush =
        new CallbackContext{ this, endpoint, requestID, putRequest, nullptr, std::move(event), GetSteadyClockNs() };
    EnqueueFlush(pendingFlush);

    totalPoint.Record();
    return Status::OK();
}

std::vector<ucp_dt_iov_t> *UcpWorker::PrepareIovBuffer(const std::vector<IovSegment> &segments)
{
    std::vector<ucp_dt_iov_t> *iov = new std::vector<ucp_dt_iov_t>(segments.size());
    for (size_t i = 0; i < segments.size(); ++i) {
        (*iov)[i].buffer = reinterpret_cast<void *>(segments[i].localAddr);
        (*iov)[i].length = segments[i].size;
    }
    return iov;
}

/**
 * Packages a multi-segment (IOV) write as a SubmitRequest and queues it for the submit thread.
 * The segment list is copied into the request, which is stamped with the steady-clock time at
 * which it was queued.
 * @return The status of EnqueueSubmit.
 */
Status UcpWorker::WriteN(const std::string &remoteRkey, uintptr_t remoteBaseAddr, const std::string &remoteWorkerAddr,
                         const std::string &ipAddr, const std::vector<IovSegment> &segments, uint64_t requestID,
                         std::shared_ptr<Event> event)
{
    auto request = std::make_shared<SubmitRequest>();
    SubmitRequest &fields = *request;

    // Marks the request so SubmitLoop dispatches it to WriteNDirect.
    fields.isIov = true;
    fields.segments = segments;

    // Remote side of the transfer.
    fields.remoteRkey = remoteRkey;
    fields.remoteAddr = remoteBaseAddr;
    fields.remoteWorkerAddr = remoteWorkerAddr;
    fields.ipAddr = ipAddr;

    // Completion tracking.
    fields.requestID = requestID;
    fields.event = std::move(event);
    fields.enqueueNs = GetSteadyClockNs();

    return EnqueueSubmit(std::move(request));
}

/**
 * Posts one IOV put covering all segments to the remote peer and queues a flush context for it.
 * Looks up (or creates) the endpoint for ipAddr, obtains the rkey through the endpoint, builds
 * the IOV array, calls ds_ucp_put_nbx with the IOV datatype, and on success hands a
 * CallbackContext (which takes over the IOV array) to EnqueueFlush.
 * @return Status::OK() when segments is empty or once the put is posted; K_RDMA_ERROR if the
 *         endpoint, its handle, the rkey or the put itself cannot be obtained/posted.
 */
Status UcpWorker::WriteNDirect(const std::string &remoteRkey, uintptr_t remoteBaseAddr,
                               const std::string &remoteWorkerAddr, const std::string &ipAddr,
                               const std::vector<IovSegment> &segments, uint64_t requestID,
                               std::shared_ptr<Event> event)
{
    if (segments.empty()) {
        // Nothing is posted, so no flush context exists and FinishContext will never run for this
        // request. Notify here so this outcome signals the event like every other one does.
        // NOTE(review): assumes callers wait on the event for every accepted request - confirm.
        if (event != nullptr) {
            event->NotifyAll();
        }
        return Status::OK();
    }

    const auto &ucpEp = GetOrCreateEndpoint(ipAddr, remoteWorkerAddr);
    if (ucpEp == nullptr) {
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Failed to create Endpoint.");
    }
    const ucp_ep_h &ep = ucpEp->GetEp();
    if (ep == nullptr) {
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " UcpEndpoint contained an empty endpoint?");
    }

    ucp_rkey_h rkey = ucpEp->GetOrUnpackRkey(remoteRkey);
    if (rkey == nullptr) {
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + ipAddr + " Failed to get an unpack rkey.");
    }

    // Owned here until it is handed to the CallbackContext, so every early exit below (including
    // an allocation failure while building the context) releases it automatically.
    std::unique_ptr<std::vector<ucp_dt_iov_t>> iov(PrepareIovBuffer(segments));
    ucp_request_param_t putParam{};
    putParam.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE;
    putParam.datatype = ucp_dt_make_iov();

    // With the IOV datatype the count argument is the number of IOV entries, not a byte count.
    void *putRequest = ds_ucp_put_nbx(ep, iov->data(), segments.size(), remoteBaseAddr, rkey, &putParam);
    if (UCS_PTR_IS_ERR(putRequest)) {
        ucs_status_t status = UCS_PTR_STATUS(putRequest);
        LOG(ERROR) << errorMsgHead_
                   << " Failed to execute ucp_put_nbx with IOV. Status: " << ds_ucs_status_string(status);
        RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Failed to send data immediately with IOV.");
    }

    // Ownership of the IOV array passes to the context; FinishContext deletes both.
    EnqueueFlush(new CallbackContext{ this, ucpEp, requestID, putRequest, iov.release(), std::move(event),
                                      GetSteadyClockNs() });

    return Status::OK();
}

/**
 * Drops the cached endpoint for ipAddr, if any, while holding the map's exclusive lock.
 * A miss is only logged at verbose level.
 * @return Always Status::OK().
 */
Status UcpWorker::RemoveEndpointByIp(const std::string &ipAddr)
{
    std::unique_lock writeLock(mapLock_);
    const auto removedCount = remoteEndpointMap_.erase(ipAddr);
    if (removedCount == 0) {
        VLOG(1) << errorMsgHead_ << " Try to remove by IP but never sent to " << ipAddr;
    }
    return Status::OK();
}

/** Launches the submit thread running SubmitLoop; does nothing if it is already marked running. */
void UcpWorker::StartSubmitThread()
{
    const bool alreadyRunning = submitRunning_.load();
    if (!alreadyRunning) {
        // The flag is raised before the thread exists, so SubmitLoop sees it set from its first pass.
        submitRunning_.store(true);
        submitThread_ = std::make_unique<Thread>(&UcpWorker::SubmitLoop, this);
        submitThread_->set_name("UcpWorker_" + std::to_string(workerId_) + "_Submit");
    }
}

/**
 * Clears the running flag, wakes the submit thread and joins it. Does nothing if the thread is
 * not marked running.
 */
void UcpWorker::StopSubmitThread()
{
    if (submitRunning_.load()) {
        submitRunning_.store(false);
        // Wake the loop in case it is parked in one of its timed waits.
        submitCv_.notify_all();

        const bool canJoin = submitThread_ && submitThread_->joinable();
        if (canJoin) {
            submitThread_->join();
            submitThread_.reset();
        }
    }
}

/** Appends ctx to the flush queue under submitMutex_ and wakes one waiter on submitCv_. */
void UcpWorker::EnqueueFlush(CallbackContext *ctx)
{
    std::unique_lock<std::mutex> guard(submitMutex_);
    flushQueue_.push_back(ctx);
    // Release before notifying so the woken thread does not immediately block on the mutex.
    guard.unlock();
    submitCv_.notify_one();
}

/**
 * Queues a submit request for the submit thread and wakes it.
 * @param req Request to queue; ownership is moved into the queue.
 * @return Status::OK() if queued; K_RDMA_ERROR if the submit thread is not running.
 */
Status UcpWorker::EnqueueSubmit(std::shared_ptr<SubmitRequest> req)
{
    {
        std::lock_guard<std::mutex> lock(submitMutex_);
        // The running flag is checked under submitMutex_ because SubmitLoop takes its exit
        // decision under the same mutex: a request is therefore either rejected here, or queued
        // before the loop's final "queues are empty" check and still processed. Checking outside
        // the lock left a window in which a request could be queued after the loop had exited.
        if (!submitRunning_.load()) {
            RETURN_STATUS(K_RDMA_ERROR, errorMsgHead_ + " Submit thread is not running.");
        }
        // Move instead of copy: avoids an extra atomic reference-count round trip.
        submitQueue_.emplace_back(std::move(req));
    }
    submitCv_.notify_one();

    return Status::OK();
}

/**
 * Body of the submit thread. Each iteration:
 *   1. drains up to GetCommSubmitBudget() queued submit requests and posts them through
 *      WriteDirect / WriteNDirect;
 *   2. takes the whole flush queue, groups the contexts by ctx->ep and issues one
 *      ds_ucp_ep_flush_nbx per group;
 *   3. calls ds_ucp_worker_progress until it reports no progress.
 * The loop exits once submitRunning_ is false, both queues are empty and no flush is outstanding.
 */
void UcpWorker::SubmitLoop()
{
    for (;;) {
        std::deque<std::shared_ptr<SubmitRequest>> pendingSubmits;
        {
            std::unique_lock<std::mutex> lock(submitMutex_);
            // Idle (nothing queued, nothing in flight): sleep until notified or the timeout expires.
            if (submitRunning_.load() && submitQueue_.empty() && flushQueue_.empty() && outstandingFlushes_ == 0) {
                submitCv_.wait_for(lock, std::chrono::microseconds(COMM_IDLE_WAIT_US));
            }
            // Leave only after a stop was requested AND all queued / in-flight work has drained.
            if (!submitRunning_.load() && submitQueue_.empty() && flushQueue_.empty() && outstandingFlushes_ == 0) {
                break;
            }

            // Take at most `budget` requests per pass; the remainder stays queued for the next iteration.
            const uint32_t budget = GetCommSubmitBudget();
            while (!submitQueue_.empty() && pendingSubmits.size() < budget) {
                pendingSubmits.emplace_back(std::move(submitQueue_.front()));
                submitQueue_.pop_front();
            }
        }

        // Post the drained requests without holding submitMutex_.
        for (auto &req : pendingSubmits) {
            PerfPoint::RecordElapsed(PerfKey::RDMA_UCP_WORKER_SUBMIT_QUEUE_WAIT,
                                     GetSteadyClockNs() - req->enqueueNs);
            Status status = req->isIov
                                ? WriteNDirect(req->remoteRkey, req->remoteAddr, req->remoteWorkerAddr, req->ipAddr,
                                               req->segments, req->requestID, req->event)
                                : WriteDirect(req->remoteRkey, req->remoteAddr, req->remoteWorkerAddr, req->ipAddr,
                                              req->localSegAddr, req->localSegSize, req->requestID, req->event);
            // A failed submit never reaches EnqueueFlush, so the event is failed and notified here.
            if (status.IsError() && req->event != nullptr) {
                req->event->SetFailed();
                req->event->NotifyAll();
            }
        }

        // Take every queued flush context (added by EnqueueFlush) in a single swap.
        std::deque<CallbackContext *> pendingFlushes;
        {
            std::lock_guard<std::mutex> lock(submitMutex_);
            pendingFlushes.swap(flushQueue_);
        }
        // Group the contexts by ctx->ep so that one flush is issued per key.
        std::unordered_map<ucp_ep_h, std::vector<CallbackContext *>> batches;
        for (auto *ctx : pendingFlushes) {
            batches[ctx->ep].emplace_back(ctx);
        }
        for (auto &entry : batches) {
            // batchCtx is deleted by FinishBatch, reached either directly below or via FlushCallBack.
            auto *batchCtx = new FlushBatchContext{ this, std::move(entry.second), false };
            PerfPoint::RecordElapsed(PerfKey::RDMA_UCP_WORKER_FLUSH_BATCH_SIZE, batchCtx->contexts.size());
            ucp_request_param_t flushParam{};
            flushParam.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA;
            flushParam.cb.send = FlushCallBack;
            flushParam.user_data = batchCtx;

            void *flushRequest = ds_ucp_ep_flush_nbx(entry.first, &flushParam);
            if (UCS_PTR_IS_ERR(flushRequest)) {
                // The flush could not be issued: fail every context in the batch.
                ucs_status_t status = UCS_PTR_STATUS(flushRequest);
                LOG(ERROR) << errorMsgHead_ << " Failed to execute ucp_ep_flush_nbx. Status: "
                           << ds_ucs_status_string(status);
                FinishBatch(batchCtx, true);
            } else if (flushRequest == nullptr) {
                // Null request: run the completion path by hand with UCS_OK. `inflight` is still
                // false, so FlushCallBack leaves outstandingFlushes_ untouched.
                FlushCallBack(flushRequest, UCS_OK, batchCtx);
            } else {
                // Request pending: mark the batch in flight so FlushCallBack decrements the counter.
                batchCtx->inflight = true;
                ++outstandingFlushes_;
            }
        }

        // Drive the worker until ds_ucp_worker_progress reports that nothing more was done.
        bool progressed = false;
        while (ds_ucp_worker_progress(worker_)) {
            progressed = true;
        }
        // Flushes are pending but nothing progressed: wait for the timeout, returning early if new
        // work arrives or a stop is requested.
        if (!progressed && outstandingFlushes_ > 0) {
            std::unique_lock<std::mutex> lock(submitMutex_);
            submitCv_.wait_for(lock, std::chrono::microseconds(COMM_IDLE_WAIT_US), [this]() {
                return !submitRunning_.load() || !submitQueue_.empty() || !flushQueue_.empty();
            });
        }
    }
}

/**
 * Returns the cached endpoint for ipAddr, creating and caching one if none exists.
 * The map is first searched under a shared lock; on a miss it is searched again under the
 * exclusive lock before a new UcpEndpoint is constructed and initialised.
 * @param ipAddr Key of the endpoint cache.
 * @param remoteWorkerAddr Remote worker address handed to the UcpEndpoint constructor.
 * @return The endpoint, or nullptr if initialising a new endpoint failed (nothing is cached then).
 */
std::shared_ptr<UcpEndpoint> UcpWorker::GetOrCreateEndpoint(const std::string &ipAddr,
                                                            const std::string &remoteWorkerAddr)
{
    {
        std::shared_lock readLock(mapLock_);

        auto it = remoteEndpointMap_.find(ipAddr);
        if (it != remoteEndpointMap_.end()) {
            return it->second;
        }
    }

    std::unique_lock writeLock(mapLock_);

    // Another thread may have inserted the endpoint between the two locks.
    auto it = remoteEndpointMap_.find(ipAddr);
    if (it != remoteEndpointMap_.end()) {
        return it->second;
    }

    std::shared_ptr<UcpEndpoint> ep = std::make_shared<UcpEndpoint>(worker_, remoteWorkerAddr);
    Status status = ep->Init();
    if (status.IsError()) {
        LOG(ERROR) << "In " << errorMsgHead_ << ": " << status.ToString();
        return nullptr;
    }

    // emplace() already yields an iterator to the stored element; returning through it avoids the
    // second hash and lookup that operator[] would perform.
    return remoteEndpointMap_.emplace(ipAddr, std::move(ep)).first->second;
}

/**
 * Tears the worker down in order: stop the submit thread, drop all cached endpoints, release
 * the worker address, destroy the UCP worker. Handles that are already null are skipped.
 */
void UcpWorker::Clean()
{
    // The submit thread uses worker_ and the endpoint map, so it is stopped before either goes away.
    StopSubmitThread();
    remoteEndpointMap_.clear();

    if (localWorkerAddr_ != nullptr) {
        ds_ucp_worker_release_address(worker_, localWorkerAddr_);
        localWorkerAddr_ = nullptr;
    }

    if (worker_ == nullptr) {
        return;
    }
    ds_ucp_worker_destroy(worker_);
    worker_ = nullptr;
}

/**
 * Completes one write: records the time elapsed since flush_start_ns (when non-zero), notifies
 * the event (marking it failed first if requested), frees the put request and deletes the IOV
 * array and the context itself.
 * @param ctx Context to finish; it is deleted by this call.
 * @param failed Whether the event should be marked failed before being notified.
 */
void UcpWorker::FinishContext(CallbackContext *ctx, bool failed)
{
    const auto flushStartNs = ctx->flush_start_ns;
    if (flushStartNs != 0) {
        PerfPoint::RecordElapsed(PerfKey::RDMA_UCP_WORKER_FLUSH_CALLBACK, GetSteadyClockNs() - flushStartNs);
    }

    if (const auto &waiter = ctx->event; waiter != nullptr) {
        if (failed) {
            waiter->SetFailed();
        }
        waiter->NotifyAll();
    }

    if (ctx->put_request != nullptr) {
        ds_ucp_request_free(ctx->put_request);
    }
    // Deleting a null pointer is a no-op, so no guard is needed for contexts without an IOV array.
    delete ctx->iov;
    delete ctx;
}

/**
 * Finishes every context in the batch with the same outcome, then deletes the batch.
 * @param batchCtx Batch to finish; it is deleted by this call.
 * @param failed Outcome forwarded to FinishContext for each member.
 */
void UcpWorker::FinishBatch(FlushBatchContext *batchCtx, bool failed)
{
    auto &members = batchCtx->contexts;
    for (auto it = members.begin(); it != members.end(); ++it) {
        FinishContext(*it, failed);
    }
    delete batchCtx;
}

void UcpWorker::FlushCallBack(void *request, ucs_status_t status, void *userData)
{
    auto *batchCtx = static_cast<FlushBatchContext *>(userData);
    const bool failed = status != UCS_OK;
    if (failed) {
        LOG(ERROR) << batchCtx->worker->errorMsgHead_ << " Flush failed. Status: " << ds_ucs_status_string(status);
    }
    if (batchCtx->inflight) {
        if (batchCtx->worker->outstandingFlushes_ == 0) {
            LOG(ERROR) << batchCtx->worker->errorMsgHead_ << " Outstanding flush counter underflow.";
        } else {
            --batchCtx->worker->outstandingFlushes_;
        }
    }
    batchCtx->worker->FinishBatch(batchCtx, failed);
    if (request != nullptr) {
        ds_ucp_request_free(request);
    }
}
}  // namespace datasystem