/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Description: Defines the worker class to communicate with the worker service.
 */
#include "datasystem/worker/object_cache/worker_request_manager.h"
#include <atomic>
#include <memory>
#include <mutex>

#include "datasystem/common/iam/tenant_auth_manager.h"
#include "datasystem/common/metrics/kv_metrics.h"
#include "datasystem/common/object_cache/object_base.h"
#include "datasystem/common/object_cache/shm_guard.h"
#include "datasystem/common/os_transport_pipeline/os_transport_pipeline_worker_api.h"
#include "datasystem/common/rdma/fast_transport_manager_wrapper.h"
#include "datasystem/common/util/raii.h"
#include "datasystem/common/util/status_helper.h"
#include "datasystem/common/util/thread_local.h"
#include "datasystem/object/buffer.h"
#include "datasystem/utils/status.h"
#include "datasystem/worker/client_manager/client_manager.h"
#include "datasystem/worker/object_cache/cache_hit_info.h"
#include "datasystem/worker/object_cache/service/worker_oc_service_crud_common_api.h"

namespace datasystem {
namespace object_cache {
// Callback used by DeleteObjects to delete one object by (key, version); empty until SetDeleteObjectsFunc is called.
std::function<Status(const std::string &, uint64_t)> WorkerRequestManager::deleteFunc_{};

// Populates this GetRequest from a client GetReqPb.
// Validates the batch size and the optional per-object read ranges: read_offset_list and read_size_list must both be
// empty or both hold one entry per object key. Records request options, then builds objects_ (one entry per unique
// key) and orderedObjectInfos_ (one pointer per requested key, in request order). A key may repeat only with an
// identical read range. For pipelined H2D requests, each object is also parsed into the chunk manager using the
// requesting client's device id.
Status GetRequest::Init(const std::string &tenantId, const GetReqPb &req,
                        std::shared_ptr<SharedMemoryRefTable> shmRefTable,
                        std::shared_ptr<ServerUnaryWriterReader<GetRspPb, GetReqPb>> api,
                        std::shared_ptr<ThreadPool> threadPool, const ClientKey &clientId)
{
    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(Validator::IsBatchSizeUnderLimit(req.object_keys_size()),
                                         StatusCode::K_INVALID, "invalid object size");

    rawObjectKeys_ = TenantAuthManager::ConstructNamespaceUriWithTenantId(tenantId, req.object_keys());

    // Get offset and size.
    uint64_t objectsCount = rawObjectKeys_.size();
    uint64_t readOffsetCount = static_cast<uint64_t>(req.read_offset_list_size());
    uint64_t readSizeCount = static_cast<uint64_t>(req.read_size_list_size());
    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
        objectsCount == readOffsetCount || readOffsetCount == 0, K_INVALID,
        FormatString("Invalid readOffsetCount %zu, should be 0 or eqeal to objectCount %zu", readOffsetCount,
                     objectsCount));
    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
        objectsCount == readSizeCount || readSizeCount == 0, K_INVALID,
        FormatString("Invalid readSizeCount %zu, should be 0 or eqeal to objectCount %zu", readSizeCount,
                     objectsCount));

    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
        readOffsetCount == readSizeCount, K_INVALID,
        FormatString("readOffsetCount %zu should be the same with readSizeCount %zu", readOffsetCount, readSizeCount));

    clientId_ = clientId;
    subTimeout_ = req.sub_timeout();
    requestTimeoutMs_ = req.request_timeout();
    shmRefTable_ = std::move(shmRefTable);
    serverApi_ = std::move(api);
    noQueryL2Cache_ = req.no_query_l2cache();
    enableReturnObjectIndex_ = req.return_object_index();
    clientCommId_ = req.comm_id();
    const bool isPipelineH2D = OsXprtPipln::IsPiplnH2DRequest(req);
    // Keys are echoed back only when the client asked for keys (not indexes) and this is not a pipelined H2D request.
    const bool needResponseObjectKeys = !enableReturnObjectIndex_ && !isPipelineH2D;
    hasUbGetInfo_ = req.has_urma_info();
    if (hasUbGetInfo_) {
        ubUrmaInfo_ = req.urma_info();
        ubBufferSize_ = req.ub_buffer_size();
        if (ubBufferSize_ == 0) {
            LOG(WARNING) << "Disable UB Get for client " << clientId_ << " due to empty ub_buffer_size.";
            hasUbGetInfo_ = false;
        }
    }
    objects_.reserve(objectsCount);
    orderedObjectInfos_.reserve(objectsCount);
    if (needResponseObjectKeys) {
        responseObjectKeys_.reserve(objectsCount);
    }

    // The device id used for pipelined H2D parsing belongs to the requesting client, not to an individual object,
    // so resolve the client once for the whole batch. No lookup is done for an empty key list.
    std::shared_ptr<worker::ClientInfo> pipelineClientInfo;
    if (isPipelineH2D && objectsCount > 0) {
        pipelineClientInfo = worker::ClientManager::Instance().GetClientInfo(ClientKey::Intern(req.client_id()));
        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(pipelineClientInfo, K_RUNTIME_ERROR,
                                             "no clientInfo for client id " + req.client_id());
        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(!pipelineClientInfo->GetDeviceId().empty(), K_RUNTIME_ERROR,
                                             "device id is empty for pipeline rh2d");
    }

    for (size_t i = 0; i < objectsCount; i++) {
        const auto &objectKey = rawObjectKeys_[i];
        OffsetInfo offsetInfo;
        if (readOffsetCount > 0 && readSizeCount > 0) {
            offsetInfo.readOffset = req.read_offset_list(static_cast<int>(i));
            offsetInfo.readSize = req.read_size_list(static_cast<int>(i));
        }
        GetObjInfo info{ .offsetInfo = offsetInfo, .params = nullptr, .rc = Status::OK() };
        auto [iter, insert] = objects_.emplace(objectKey, std::move(info));
        if (!insert) {
            // A repeated key shares the first entry, so its read range must match.
            CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(iter->second.offsetInfo == offsetInfo, K_INVALID,
                                                 FormatString("Duplicate offset read for objectKey %s", objectKey));
        }
        // Pointers into objects_ stay valid across rehashing (unordered_map keeps element addresses stable).
        orderedObjectInfos_.emplace_back(&(iter->second));
        if (needResponseObjectKeys) {
            ObjectKey responseObjectKey;
            TenantAuthManager::Instance()->NamespaceUriToObjectKey(objectKey, responseObjectKey);
            responseObjectKeys_.emplace_back(std::move(responseObjectKey));
        }
        if (isPipelineH2D) {
            RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
                OsXprtPipln::ParsePiplnH2DRequest(req, GetH2DChunkManager(), objectKey, i,
                                                  pipelineClientInfo->GetDeviceId()),
                "ParsePiplnH2DRequest failed");
        }
        VLOG(1) << "objectKey " << objectKey << " add to GetRequest success";
    }
    threadPool_ = std::move(threadPool);
    return Status::OK();
}

// Records the outcome of the local lookup phase. Must be called before Register().
// rc: error from the local phase, if any; an error is stored as lastRc_.
// remoteObjectCount: number of unique keys still outstanding. Every other unique key is counted as ready.
// Returns to the client immediately when nothing is outstanding.
Status GetRequest::UpdateAfterLocalGet(Status rc, size_t remoteObjectCount)
{
    CHECK_FAIL_RETURN_STATUS(!Registered(), K_RUNTIME_ERROR,
                             FormatString("UpdateAfterLocalGet called after GetRequest Register"));
    auto uniqueObjectCount = objects_.size();
    // Supply the value for the %zu placeholder; the message previously had no matching argument.
    const size_t currentReadyCount = readyCount_.load();
    CHECK_FAIL_RETURN_STATUS(currentReadyCount == 0, K_RUNTIME_ERROR,
                             FormatString("Invalid readyCount_ %zu, should be 0 when call UpdateAfterLocalGet",
                                          currentReadyCount));
    CHECK_FAIL_RETURN_STATUS(uniqueObjectCount >= remoteObjectCount, K_RUNTIME_ERROR,
                             FormatString("The remote object key count %zu exceed the request object count %zu",
                                          remoteObjectCount, uniqueObjectCount));
    // exist in local or failed when get from local
    readyCount_ = uniqueObjectCount - remoteObjectCount;
    if (rc.IsError()) {
        // MarkFailed writes and ReturnToClient reads lastRc_ under mutex_; use the same lock here.
        std::lock_guard<std::mutex> locker(mutex_);
        lastRc_ = std::move(rc);
    }

    // Direct return to client if get all objects.
    return remoteObjectCount == 0 ? ReturnToClient() : Status::OK();
}

// Marks objectKey as found: builds entry parameters from safeObj and stores them via MarkSuccessImpl.
Status GetRequest::MarkSuccess(const ObjectKey &objectKey, SafeObjType &safeObj)
{
    VLOG(1) << "MarkSuccess for object key " << objectKey;
    return MarkSuccessImpl(objectKey, GetObjEntryParams::Create(objectKey, safeObj));
}

// Marks objectKey as failed. Stores rc on the object entry and as lastRc_, then counts the object as ready.
// rc must be an error status, and objectKey must belong to this request.
// Does not itself trigger ReturnToClient.
Status GetRequest::MarkFailed(const ObjectKey &objectKey, const Status &rc)
{
    VLOG(1) << "MarkFailed for object key " << objectKey;
    CHECK_FAIL_RETURN_STATUS(rc.IsError(), K_RUNTIME_ERROR, "Invalid Status when MarkFailed");
    auto iter = objects_.find(objectKey);
    CHECK_FAIL_RETURN_STATUS(iter != objects_.cend(), K_RUNTIME_ERROR,
                             FormatString("Not found object key %s in GetRequest", objectKey));
    {
        std::lock_guard<std::mutex> locker(mutex_);
        lastRc_ = rc;
        iter->second.rc = rc;
        // Count the object only after its error is recorded, and inside the same critical section.
        // Any thread that observes the new count and then reads lastRc_ under mutex_ (as ReturnToClient
        // does) is therefore guaranteed to see this error rather than a stale OK.
        readyCount_.fetch_add(1, std::memory_order_release);
    }
    return Status::OK();
}

// Notification path: stores params for objectKey on an already-registered request.
// Replies to the client once no object is left outstanding.
Status GetRequest::MarkSuccessForNotify(const ObjectKey &objectKey, std::unique_ptr<GetObjEntryParams> params)
{
    VLOG(1) << "MarkSuccessForNotify for object key " << objectKey;
    CHECK_FAIL_RETURN_STATUS(Registered(), K_RUNTIME_ERROR,
                             FormatString("MarkSuccessForNotify called before GetRequest Register"));
    RETURN_IF_NOT_OK(MarkSuccessImpl(objectKey, std::move(params)));
    if (GetNotReadyCount() != 0) {
        // Other objects are still pending; keep waiting for their notifications.
        return Status::OK();
    }
    return ReturnToClient();
}

// Stores params for objectKey and counts the object as ready. A repeated success for the same key is ignored.
// In the remote-H2D case (no shm unit, non-empty host-info map, non-empty client comm id), returns K_TRY_AGAIN
// until the host-info map contains an entry for this client's comm id.
Status GetRequest::MarkSuccessImpl(const ObjectKey &objectKey, std::unique_ptr<GetObjEntryParams> params)
{
    auto iter = objects_.find(objectKey);
    CHECK_FAIL_RETURN_STATUS(iter != objects_.cend(), K_RUNTIME_ERROR,
                             FormatString("Not found object key %s in GetRequest", objectKey));
    // In RH2D scenario, check that the client communicator id corresponds before returning the get response.
    // When data is found local, either shm unit will exist, or the comm id mapping is empty.
    if (IsRemoteH2DEnabled() && params->shmUnit == nullptr && params->remoteH2DHostInfo
        && !params->remoteH2DHostInfo->empty() && !GetClientCommUuid().empty()) {
        RemoteH2DHostInfoMap::const_accessor constAccessor;
        bool found = params->remoteH2DHostInfo->find(constAccessor, GetClientCommUuid());
        CHECK_FAIL_RETURN_STATUS(
            found, K_TRY_AGAIN,
            FormatString("Response is not ready yet for object %s, comm id %s", objectKey, GetClientCommUuid()));
    }
    {
        std::lock_guard<std::mutex> locker(mutex_);
        RETURN_OK_IF_TRUE(iter->second.params != nullptr);
        iter->second.params = std::move(params);
    }
    // readyCount_ is what other threads check (through GetNotReadyCount()) before building the response, and
    // ConstructResponse reads params without taking mutex_. With release ordering, the params stored above are
    // visible to any thread whose load observes this increment. A relaxed increment gives no such guarantee.
    readyCount_.fetch_add(1, std::memory_order_release);
    return Status::OK();
}

void GetRequest::SetStatus(const Status &rc)
{
    if (rc.IsError()) {
        lastRc_ = rc;
    }
}

// Number of unique objects already resolved (found or failed).
size_t GetRequest::GetReadyCount() const
{
    return readyCount_.load();
}

// Number of unique objects still outstanding (unique object count minus ready count).
size_t GetRequest::GetNotReadyCount() const
{
    const size_t resolved = readyCount_.load();
    return objects_.size() - resolved;
}

// True once ReturnToClient has claimed the reply (isReturn_ is set exactly once).
bool GetRequest::AlreadyReturn() const
{
    return isReturn_.load();
}

const std::string &GetRequest::GetClientId() const
{
    return clientId_;
}

bool GetRequest::NoQueryL2Cache() const
{
    return noQueryL2Cache_;
}

const std::string &GetRequest::GetClientCommUuid() const
{
    return clientCommId_;
}

// Mutable access to the chunk manager filled for pipelined H2D requests.
H2DChunkManager &GetRequest::GetH2DChunkManager()
{
    return this->chunkManager_;
}

const std::vector<ObjectKey> &GetRequest::GetRawObjectKeys() const
{
    return rawObjectKeys_;
}

std::unordered_map<ObjectKey, GetObjInfo> &GetRequest::GetObjects()
{
    return objects_;
}

std::vector<ObjectKey> GetRequest::GetUniqueObjectkeys() const
{
    std::vector<ObjectKey> objectKeys;
    objectKeys.reserve(objects_.size());
    for (const auto &kv : objects_) {
        objectKeys.emplace_back(kv.first);
    }
    return objectKeys;
}

// Returns (a shared reference to) the RPC stream used to reply to the client.
std::shared_ptr<ServerUnaryWriterReader<GetRspPb, GetReqPb>> GetRequest::GetServerApi() const
{
    return this->serverApi_;
}

// Attaches this request to workerRequestManager and registers it as waiting on every object that is still
// unresolved: no params stored and no error recorded.
void GetRequest::Register(WorkerRequestManager *workerRequestManager)
{
    workerRequestManager_ = workerRequestManager;
    auto request = shared_from_this();
    for (auto &[objectKey, objectInfo] : objects_) {
        VLOG(1) << "Register GetRequest for objectKey " << objectKey << ", params "
                << (objectInfo.params == nullptr ? "is null" : "not null") << ", status: " << objectInfo.rc.ToString();
        // The object key not found in local and remote
        if (objectInfo.params == nullptr && objectInfo.rc.IsOk()) {
            // A failed registration means this object will never be notified and the request can only finish via
            // its timeout, so surface it instead of silently dropping the status.
            Status addRc = workerRequestManager_->AddRequest(objectKey, request);
            if (addRc.IsError()) {
                LOG(WARNING) << "Failed to register GetRequest for objectKey " << objectKey << ", client id: "
                             << clientId_ << ", rc: " << addRc.ToString();
            }
        }
    }
}

// Removes this request from the manager's pending table; a no-op if Register() was never called.
void GetRequest::UnRegister()
{
    if (!Registered()) {
        return;
    }
    workerRequestManager_->RemoveGetRequest(shared_from_this());
}

// Stores the request-timeout timer. ReturnToClient cancels and resets timer_ under mutex_, so store it under the
// same lock.
void GetRequest::SetTimer(std::unique_ptr<TimerQueue::TimerImpl> timer)
{
    std::unique_lock<std::mutex> guard(mutex_);
    timer_ = std::move(timer);
}

// True once Register() has attached a manager.
bool GetRequest::Registered() const
{
    return static_cast<bool>(workerRequestManager_);
}

// Sends the result of this Get request to the client exactly once. Only the first caller proceeds (isReturn_
// compare-exchange); later calls return OK immediately. An error rc overrides the error accumulated in lastRc_.
// If the request deadline has already passed, only a status is sent: K_RPC_DEADLINE_EXCEEDED, unless an error was
// already recorded. Otherwise the response and payloads come from ConstructResponse and are written to the stream.
Status GetRequest::ReturnToClient(const Status &rc)
{
    INJECT_POINT("worker.Get.beforeReturn");
    bool expected = false;
    RETURN_OK_IF_TRUE(!isReturn_.compare_exchange_strong(expected, true));
    VLOG(1) << "Begin to ReturnToClient, client id: " << clientId_;
    Status lastRc;
    {
        std::lock_guard<std::mutex> locker(mutex_);
        lastRc = lastRc_;
    }
    if (rc.IsError()) {
        lastRc = rc;
    }
    uint64_t totalSize = 0;
    // Record the access entry on every exit path from here on; K_NOT_FOUND is recorded as a successful access.
    Raii raii([this, &totalSize, &lastRc] {
        Status accessRc = (lastRc.GetCode() == K_NOT_FOUND) ? Status::OK() : lastRc;
        recorder_->ObjectKeysSummaryRef(rawObjectKeys_).SubTimeoutMs(subTimeout_).DataSize(totalSize)
            .Result(accessRc).Record();
    });
    // Objects that ConstructResponse flags for deletion are deleted asynchronously on the thread pool.
    // When nothing was flagged, no task is submitted.
    std::map<std::string, uint64_t> needDeleteObjects;
    Raii deleteRaii([this, &needDeleteObjects] {
        if (needDeleteObjects.empty()) {
            return;
        }
        threadPool_->Submit(
            [keysWithVersion = std::move(needDeleteObjects)] { WorkerRequestManager::DeleteObjects(keysWithVersion); });
    });
    int64_t remainingTimeMs = reqTimeoutDuration.CalcRealRemainingTime();
    if (remainingTimeMs <= 0) {
        LOG(ERROR) << "ReturnFromGetRequest timeout when get object: " << VectorToString(rawObjectKeys_);
        UnRegister();
        Status timeoutRc = lastRc.IsOk() ? Status(K_RPC_DEADLINE_EXCEEDED, "Rpc timeout") : lastRc;
        Status sendStatusRc = serverApi_->SendStatus(timeoutRc);
        this->GetServerApi()->SetRequestComplete();
        return sendStatusRc;
    }
    GetRspPb resp;
    std::vector<RpcMessage> payloads;
    PerfPoint constructPoint(PerfKey::WORKER_RETURN_TO_CLIENT_CONSTRUCT_RESPONSE);
    auto constructRc = ConstructResponse(totalSize, resp, payloads, needDeleteObjects);
    constructPoint.Record();
    // An out-of-memory error already recorded takes precedence over a construction error.
    if (constructRc.IsError() && lastRc.GetCode() != K_OUT_OF_MEMORY) {
        lastRc = constructRc;
    }
    // Remove the get request from each of the relevant object_get_requests hash
    // tables if it is present there. It should only be present there if the get request timed out.
    PerfPoint unregisterPoint(PerfKey::WORKER_RETURN_TO_CLIENT_UNREGISTER);
    UnRegister();
    unregisterPoint.Record();

    {
        // Close the request time out event.
        std::lock_guard<std::mutex> locker(mutex_);
        if (timer_ != nullptr) {
            if (!TimerQueue::GetInstance()->Cancel(*timer_)) {
                LOG(ERROR) << "Failed to Cancel the timer: " << timer_->GetId();
            }
            timer_.reset();
        }
    }
    resp.mutable_last_rc()->set_error_code(lastRc.GetCode());
    resp.mutable_last_rc()->set_error_msg(lastRc.GetMsg());
    PerfPoint writePoint(PerfKey::WORKER_RETURN_TO_CLIENT_WRITE);
    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi_->Write(resp), "Write reply to client stream failed.");
    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(serverApi_->SendPayload(payloads), "SendPayload to client stream failed");
    this->GetServerApi()->SetRequestComplete();
    writePoint.Record();
    return Status::OK();
}

// Builds the Get response for every requested key, in request order.
// totalSize: accumulates dataSize of every object that has params.
// payloads: receives TCP payload messages for non-shm clients.
// needDeleteObjects: receives (key, version) of objects that must be deleted after this read.
// Keys with no result, or whose addition fails, get a default (-1 filled) ObjectInfoPb.
// Returns the last per-object error, or OK.
Status GetRequest::ConstructResponse(uint64_t &totalSize, GetRspPb &resp, std::vector<RpcMessage> &payloads,
                                     std::map<std::string, uint64_t> &needDeleteObjects)
{
    const auto clientInfo = worker::ClientManager::Instance().GetClientInfo(clientId_);
    const bool shmEnabled = (clientInfo != nullptr) && clientInfo->ShmEnabled();
    const bool useUbGet = IsUrmaEnabled() && !shmEnabled && hasUbGetInfo_;
    uint64_t ubWriteOffset = 0;

    // Pipelined remote-H2D requests use their own response format.
    if (OsXprtPipln::IsPiplnH2DRequest(chunkManager_)) {
        return OsXprtPipln::ConstructPipelineRH2DResponse(resp, chunkManager_, rawObjectKeys_);
    }

    const size_t keyCount = rawObjectKeys_.size();
    if (keyCount != 0) {
        // Avoid per-key protobuf repeated-field growth for large batch get responses.
        resp.mutable_objects()->Reserve(static_cast<int>(keyCount));
        if (!shmEnabled) {
            resp.mutable_payload_info()->Reserve(static_cast<int>(keyCount));
        }
    }

    Status lastRc;
    for (size_t idx = 0; idx < keyCount; ++idx) {
        const auto &objectKeyUri = rawObjectKeys_[idx];
        auto *objectInfo = orderedObjectInfos_[idx];
        const bool missing = (objectInfo == nullptr) || (objectInfo->params == nullptr);
        if (missing) {
            LOG(INFO) << FormatString("Can't find object %s, clientId %s", objectKeyUri, clientId_);
            CacheHitInfo::Instance().IncMissHit(1);
            SetDefaultObjectInfoPb(objectKeyUri, idx, *resp.add_objects());
            continue;
        }

        const auto &params = objectInfo->params;
        totalSize += params->dataSize;
        Status addRc = AddObjectToResponse(objectKeyUri, *objectInfo, idx, shmEnabled, useUbGet, ubWriteOffset, resp,
                                           payloads);
        if (shmEnabled) {
            // Objects served through remote-H2D host info carry no shm unit to reference-count.
            const bool servedByRemoteH2D = IsRemoteH2DEnabled() && params->shmUnit == nullptr
                                           && params->remoteH2DHostInfo && !params->remoteH2DHostInfo->empty();
            if (!servedByRemoteH2D) {
                // If object is shm, we increase the refCnt for client.
                // The client will be using this object and be responsible for releasing this object.
                shmRefTable_->AddShmUnit(clientId_, params->shmUnit, requestTimeoutMs_);
            }
        }

        bool needDeleted = params->objectState.IsNeedToDelete();
        INJECT_POINT("worker.AddEntryToGetResponse", [&needDeleted] {
            needDeleted = true;
            return Status::OK();
        });
        if (needDeleted) {
            needDeleteObjects.emplace(objectKeyUri, params->version);
        }

        if (addRc.IsError()) {
            LOG(ERROR) << FormatString("Can't find object %s or AddObjectToResponse failed, clientId %s, rc %s",
                                       objectKeyUri, clientId_, addRc.ToString());
            lastRc = addRc;
            SetDefaultObjectInfoPb(objectKeyUri, idx, *resp.add_objects());
        }
    }
    VLOG(1) << FormatString("The total size of the currently get is %llu", totalSize);
    return lastRc;
}

// Writes readSize bytes starting at readOffset of shmUnit directly into the client's UB buffer, at ubWriteOffset.
// On success, advances ubWriteOffset, appends a PayloadInfoPb to resp, and returns OK.
// If the remaining buffer is too small (K_INVALID) or the URMA write fails (its status), returns an error so the
// caller can fall back to a TCP payload.
Status GetRequest::UbWriteHelper(const ObjectKey &objectKeyUri, uint64_t metaSize, uint64_t readSize,
                                 uint64_t readOffset, std::shared_ptr<ShmUnit> shmUnit, GetObjInfo &objectInfo,
                                 size_t objectIndex, uint64_t &ubWriteOffset, GetRspPb &resp)
{
    // The first comparison keeps ubBufferSize_ - ubWriteOffset from wrapping around.
    const bool fits = ubWriteOffset <= ubBufferSize_ && readSize <= ubBufferSize_ - ubWriteOffset;
    if (!fits) {
        LOG(WARNING) << "UB get comm buffer insufficient for object " << objectKeyUri << ", readSize " << readSize
                     << ", used " << ubWriteOffset << ", capacity " << ubBufferSize_
                     << ", fallback to TCP payload.";
        return Status(K_INVALID, "UB get comm buffer insufficient");
    }

    METRIC_TIMER(metrics::KvMetricId::WORKER_URMA_WRITE_LATENCY);
    const uint64_t objectBase = reinterpret_cast<uint64_t>(shmUnit->GetPointer());
    uint64_t segAddress = 0;
    uint64_t segSize = 0;
    GetSegmentInfoFromShmUnit(shmUnit, objectBase, segAddress, segSize);

    // Target the next unused region of the client's registered buffer.
    UrmaRemoteAddrPb remoteAddr = ubUrmaInfo_;
    remoteAddr.set_seg_data_offset(ubUrmaInfo_.seg_data_offset() + ubWriteOffset);
    const uint8_t srcChipId = NumaIdToChipId(shmUnit->GetNumaId());
    const uint8_t dstChipId =
        ubUrmaInfo_.has_chip_id() ? static_cast<uint8_t>(ubUrmaInfo_.chip_id()) : INVALID_CHIP_ID;
    std::vector<uint64_t> eventKeys;
    Status writeRc = UrmaWritePayload(remoteAddr, segAddress, segSize, objectBase + readOffset, 0, readSize, metaSize,
                                      srcChipId, dstChipId, true, eventKeys);
    if (!writeRc.IsOk()) {
        LOG(WARNING) << "UB get write failed for object " << objectKeyUri
                     << ", fallback to TCP payload: " << writeRc.ToString();
        return writeRc;
    }

    ubWriteOffset += readSize;
    METRIC_ADD(metrics::KvMetricId::CLIENT_GET_URMA_READ_TOTAL_BYTES, readSize);
    METRIC_ADD(metrics::KvMetricId::WORKER_TO_CLIENT_TOTAL_BYTES, readSize);
    SetNoShmObjectInfoPb(objectKeyUri, objectIndex, objectInfo, *resp.add_payload_info());
    INJECT_POINT_NO_RETURN("worker.get.urma_write_ok");
    return Status::OK();
}

// Adds one resolved object to resp. It takes one of three paths:
//  - shm clients, and objects carrying non-empty remote-H2D host info: metadata only (ObjectInfoPb);
//  - UB get: the data is written directly into the client's buffer (UbWriteHelper);
//  - otherwise, or when UB fails: the data is copied into outPayloads and a PayloadInfoPb lists the part indexes.
// objectInfo.offsetInfo.readSize is adjusted against the object's data size before reading.
Status GetRequest::AddObjectToResponse(const ObjectKey &objectKeyUri, GetObjInfo &objectInfo, size_t objectIndex,
                                       bool shmEnabled, bool useUbGet, uint64_t &ubWriteOffset, GetRspPb &resp,
                                       std::vector<RpcMessage> &outPayloads)
{
    const auto &params = objectInfo.params;
    const bool metadataOnly =
        shmEnabled
        || (IsRemoteH2DEnabled() && params->remoteH2DHostInfo && !params->remoteH2DHostInfo->empty());
    if (metadataOnly) {
        SetShmObjectInfoPb(objectKeyUri, objectIndex, *params, *resp.add_objects());
        return Status::OK();
    }

    const uint64_t metaSize = params->metaSize;
    const uint64_t dataSize = params->dataSize;
    objectInfo.offsetInfo.AdjustReadSize(dataSize);
    const uint64_t readOffset = objectInfo.offsetInfo.readOffset;
    const uint64_t readSize = objectInfo.offsetInfo.readSize;

    ShmGuard shmGuard(params->shmUnit, dataSize, metaSize);
    if (WorkerOcServiceCrudCommonApi::ShmEnable()) {
        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(
            shmGuard.TryRLatch(),
            FormatString("Try read latch failed while getting object %s from shmUnit.", objectKeyUri));
    }

    Status ubRc = Status::OK();
    if (useUbGet) {
        ubRc = UbWriteHelper(objectKeyUri, metaSize, readSize, readOffset, params->shmUnit, objectInfo, objectIndex,
                             ubWriteOffset, resp);
        RETURN_OK_IF_TRUE(ubRc.IsOk());
    }

    // TCP payload path.
    const auto firstPartIndex = outPayloads.size();
    LOG(INFO) << FormatString("CopyShmUnitToPayloads, objectKey: %s, read offset: %ld, read size: %ld", objectKeyUri,
                              readOffset, readSize);
    METRIC_TIMER(metrics::KvMetricId::WORKER_TCP_WRITE_LATENCY);
    const bool isUrmaFallback = ubRc.IsError() || (IsUrmaEnabled() && !shmEnabled);
    if (isUrmaFallback) {
        const Status transportStatus =
            ubRc.IsError() ? ubRc : Status(K_URMA_ERROR, "UB get request fallback to TCP payload before worker UB");
        auto fallbackRc = shmGuard.TrackUrmaFallbackTcp(readSize, transportStatus, "worker->client");
        if (fallbackRc.IsError()) {
            LOG(WARNING) << "Worker-to-client TCP fallback payload rejected for object " << objectKeyUri
                         << ": " << fallbackRc.ToString();
            return fallbackRc;
        }
    }
    RETURN_IF_NOT_OK(shmGuard.TransferTo(outPayloads, readOffset, readSize));
    METRIC_ADD(metrics::KvMetricId::CLIENT_GET_TCP_READ_TOTAL_BYTES, readSize);
    METRIC_ADD(metrics::KvMetricId::WORKER_TO_CLIENT_TOTAL_BYTES, readSize);
    const auto endPartIndex = outPayloads.size();

    GetRspPb::PayloadInfoPb &payloadInfo = *resp.add_payload_info();
    SetNoShmObjectInfoPb(objectKeyUri, objectIndex, objectInfo, payloadInfo);
    // Record which entries of outPayloads belong to this object.
    for (auto partIndex = firstPartIndex; partIndex != endPartIndex; ++partIndex) {
        payloadInfo.add_part_index(partIndex);
    }
    return Status::OK();
}

void GetRequest::SetShmObjectInfoPb(const ObjectKey &, size_t objectIndex, GetObjEntryParams &safeEntry,
                                    GetRspPb::ObjectInfoPb &info)
{
    if (enableReturnObjectIndex_) {
        info.set_object_index(objectIndex);
    } else {
        info.set_object_key(responseObjectKeys_[objectIndex]);
    }
    // The existence should have been checked at MarkSuccessImpl.
    RemoteH2DHostInfoMap::const_accessor constAccessor;
    if (IsRemoteH2DEnabled() && safeEntry.remoteH2DHostInfo
        && safeEntry.remoteH2DHostInfo->find(constAccessor, clientCommId_)) {
        *(info.mutable_host_info()) = *(constAccessor->second);
        // Leave the shm unit stuff empty to be clearer.
        info.set_store_fd(0);
        info.set_offset(0);
        info.set_mmap_size(0);
        info.set_shm_id(std::string{});
    } else {
        auto &shmUnit = safeEntry.shmUnit;
        info.set_store_fd(shmUnit->GetFd());
        info.set_offset(static_cast<int64_t>(shmUnit->GetOffset()));
        info.set_mmap_size(static_cast<int64_t>(shmUnit->GetMmapSize()));
        info.set_shm_id(shmUnit->id);
    }
    info.set_data_size(static_cast<int64_t>(safeEntry.dataSize));
    info.set_metadata_size(static_cast<int64_t>(safeEntry.metaSize));
    info.set_version(static_cast<int64_t>(safeEntry.createTime));
    info.set_is_seal(safeEntry.isSealed);
    info.set_write_mode(static_cast<uint32_t>(safeEntry.objectMode.GetWriteMode()));
    info.set_consistency_type(static_cast<uint32_t>(safeEntry.objectMode.GetConsistencyType()));
}

// Fills a PayloadInfoPb for an object whose bytes are delivered as payload (TCP or UB write).
// data_size is the effective read size from offsetInfo, not the object's full size.
void GetRequest::SetNoShmObjectInfoPb(const ObjectKey &, size_t objectIndex, const GetObjInfo &objectInfo,
                                      GetRspPb::PayloadInfoPb &info)
{
    const auto &entry = *objectInfo.params;
    if (!enableReturnObjectIndex_) {
        info.set_object_key(responseObjectKeys_[objectIndex]);
    } else {
        info.set_object_index(objectIndex);
    }
    info.set_data_size(static_cast<int64_t>(objectInfo.offsetInfo.readSize));
    info.set_version(static_cast<int64_t>(entry.createTime));
    info.set_is_seal(entry.isSealed);
    info.set_write_mode(static_cast<uint32_t>(entry.objectMode.GetWriteMode()));
    info.set_consistency_type(static_cast<uint32_t>(entry.objectMode.GetConsistencyType()));
}

// Fills an ObjectInfoPb for an object that could not be returned. Every numeric field is -1, and the seal flag,
// write mode and consistency type are set to fixed defaults.
void GetRequest::SetDefaultObjectInfoPb(const ObjectKey &, size_t objectIndex, GetRspPb::ObjectInfoPb &info)
{
    if (!enableReturnObjectIndex_) {
        info.set_object_key(responseObjectKeys_[objectIndex]);
    } else {
        info.set_object_index(objectIndex);
    }
    constexpr int kMissing = -1;
    info.set_store_fd(kMissing);
    info.set_offset(kMissing);
    info.set_data_size(kMissing);
    info.set_metadata_size(kMissing);
    info.set_mmap_size(kMissing);
    info.set_version(kMissing);
    info.set_is_seal(false);
    info.set_write_mode(static_cast<uint32_t>(WriteMode::NONE_L2_CACHE));
    info.set_consistency_type(static_cast<uint32_t>(ConsistencyType::PRAM));
}

// Registers request as waiting for objectKey in the pending-request table.
Status WorkerRequestManager::AddRequest(const std::string &objectKey, std::shared_ptr<GetRequest> &request)
{
    Status addRc = requestTable_.AddRequest(objectKey, request);
    return addRc;
}

// Forwards a newly available object to the requests waiting on its key. Fails with K_INVALID when the object
// entry is null.
Status WorkerRequestManager::NotifyPendingGetRequest(ObjectKV &objectKV)
{
    SafeObjType &entry = objectKV.GetObjEntry();
    CHECK_FAIL_RETURN_STATUS(entry.Get() != nullptr, K_INVALID,
                             "The pointer of entry and memoryRefApi for UpdateRequest is null.");
    const auto &objectKey = objectKV.GetObjKey();
    return requestTable_.NotifyPendingGetRequest(objectKey, GetObjEntryParams::Create(objectKey, entry));
}

// Removes request from the pending-request table.
void WorkerRequestManager::RemoveGetRequest(const std::shared_ptr<GetRequest> &request)
{
    const auto &clientId = request->GetClientId();
    VLOG(1) << "Begin to RemoveGetRequest, client id: " << clientId;
    requestTable_.RemoveRequest(request);
}

// Installs the callback DeleteObjects uses to delete an object by (key, version).
void WorkerRequestManager::SetDeleteObjectsFunc(std::function<Status(const std::string &, uint64_t)> deleteFunc)
{
    WorkerRequestManager::deleteFunc_ = std::move(deleteFunc);
}

// Deletes each (key, version) in objects through deleteFunc_. K_NOT_FOUND errors are ignored; other errors are
// logged.
void WorkerRequestManager::DeleteObjects(const std::map<std::string, uint64_t> &objects)
{
    // An empty map is a no-op. Check it first so a missing deleteFunc_ is only reported when there is actually
    // something to delete, not on every call.
    if (objects.empty()) {
        return;
    }

    if (deleteFunc_ == nullptr) {
        LOG(ERROR) << "WorkerRequestManager deleteFunc not set.";
        return;
    }

    LOG(INFO) << "Start to delete objects " << VectorToString(objects);
    for (const auto &[objectKey, version] : objects) {
        // If two clients get the same object key at the same time, the object may be deleted twice. The second
        // call then fails because the object no longer exists in the object table, so K_NOT_FOUND is not treated
        // as an error here.
        LOG_IF_ERROR_EXCEPT(deleteFunc_(objectKey, version), FormatString("delete object %s failed", objectKey),
                            K_NOT_FOUND);
    }
}
}  // namespace object_cache
}  // namespace datasystem