/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#include "LocalObjectBufferPool.h"

#include <algorithm>
#include <chrono>
#include <climits>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <utility>

#include "NetworkObjectBufferPool.h"
#include "ObjectBufferListener.h"
#include "VectorBatchBuffer.h"

namespace omnistream {
    /**
     * Creates a local pool that obtains ObjectSegments from the given NetworkObjectBufferPool.
     *
     * @param networkObjBufferPool           global pool segments are requested from and returned to
     * @param numberOfRequiredObjectSegments minimum number of segments; passed twice to the LocalBufferPool
     *                                       constructor (presumably as required count and initial pool size —
     *                                       confirm against LocalBufferPool). Must end up > 0.
     * @param maxNumberOfMemorySegments      upper bound of segments; must not be smaller than the required number
     * @param numberOfSubpartitions          number of subpartitions; one (initially empty) recycler slot each
     * @param maxBuffersPerChannel           per-subpartition quota; must be > 0 when there are subpartitions
     * @throws std::invalid_argument when one of the constraints above is violated
     *
     * The recycler slots are only filled by postConstruct(), because they need shared_from_this(), which is
     * unusable inside a constructor.
     */
    LocalObjectBufferPool::LocalObjectBufferPool(std::shared_ptr<NetworkObjectBufferPool> networkObjBufferPool,
                                                 int numberOfRequiredObjectSegments,
                                                 int maxNumberOfMemorySegments,
                                                 int numberOfSubpartitions,
                                                 int maxBuffersPerChannel)
        : LocalBufferPool(networkObjBufferPool, numberOfSubpartitions, maxBuffersPerChannel, numberOfRequiredObjectSegments, numberOfRequiredObjectSegments, maxNumberOfMemorySegments, std::make_shared<AvailabilityHelper>()),
          networkObjBufferPool_(networkObjBufferPool),
          maxNumberOfObjectSegments_(maxNumberOfMemorySegments),
          numberOfRequestedObjectSegments_(0),
          subpartitionBufferRecyclers_(numberOfSubpartitions)
    {
        LOG_PART("Beginning of constructor")
        // Each label starts with a space so it is separated from the previous value in the log line.
        LOG_PART(" numberOfRequiredObjectSegments_"  << numberOfRequiredSegments_
            << " maxNumberOfMemorySegments_"  << maxNumberOfObjectSegments_
            << " currentPoolSize_"  << currentPoolSize_
            << " maxBuffersPerChannel_"  << maxBuffersPerChannel_)

        // The required count is validated through the base-class member it was stored in.
        if (numberOfRequiredSegments_ <= 0) {
            throw std::invalid_argument(
                "Required number of memory segments (" + std::to_string(numberOfRequiredSegments_) +
                ") should be larger than 0.");
        }

        if (maxNumberOfMemorySegments < numberOfRequiredSegments_) {
            throw std::invalid_argument(
                "Maximum number of memory segments (" + std::to_string(maxNumberOfMemorySegments) +
                ") should not be smaller than minimum (" + std::to_string(numberOfRequiredSegments_) + ").");
        }

        // A per-channel quota only matters when there are subpartitions to apply it to.
        if (numberOfSubpartitions > 0 && maxBuffersPerChannel <= 0) {
            throw std::invalid_argument(
                "Maximum number of buffers for each channel (" + std::to_string(maxBuffersPerChannel) +
                ") should be larger than 0.");
        }

        {
            // Initial availability state is computed under the same lock used by all segment operations.
            std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
            LOG("constructor get lock")
            checkAndUpdateAvailability();
        }
        LOG("LocalObjectBufferPool constructor end")
    }

    void LocalObjectBufferPool::postConstruct()
    {
        LOG("LocalObjectBufferPool post constructor end")
        for (size_t i = 0; i < subpartitionBufferRecyclers_.size(); i++) {
            subpartitionBufferRecyclers_[i] = std::make_shared<SubpartitionBufferRecycler>(i,   shared_from_this());
        }
    }

    /**
     * Eagerly pulls segments from the global pool until this pool has requested at least
     * numberOfSegmentsToReserve of them. The pulled segments become idle (available) segments.
     *
     * The global pool's blocking request is issued while availableSegmentsLock is held. Availability
     * listeners are notified after the lock is released.
     *
     * @param numberOfSegmentsToReserve target number of requested segments; must not exceed the required number
     * @throws std::invalid_argument if more than the required number of segments would be reserved
     * @throws std::runtime_error if the pool has already been destroyed
     */
    void LocalObjectBufferPool::reserveSegments(int numberOfSegmentsToReserve)
    {
        if (numberOfSegmentsToReserve > numberOfRequiredSegments_) {
            throw std::invalid_argument("Can not reserve more segments than number of required segments.");
        }

        std::shared_ptr<CompletableFuture> toNotify = nullptr;
        {
            std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
            if (isDestroyed_) {
                throw std::runtime_error("Buffer pool has been destroyed.");
            }

            if (numberOfRequestedObjectSegments_ < numberOfSegmentsToReserve) {
                auto segments = networkObjBufferPool_->requestPooledObjectSegmentsBlocking(
                    numberOfSegmentsToReserve - numberOfRequestedObjectSegments_);
                availableSegments.insert(availableSegments.end(), segments.begin(), segments.end());
                // Count every segment obtained from the global pool. returnObjectSegment() decrements this
                // counter for each segment given back, and isRequestedSizeReached()/hasExcessBuffers() rely
                // on it. Previously reserved segments were invisible to the accounting, so the counter could
                // go negative and a second reserveSegments() call requested them again.
                numberOfRequestedObjectSegments_ += static_cast<int>(segments.size());
                toNotify = availabilityHelper_->getUnavailableToResetAvailable();
            }
        }
        mayNotifyAvailable(toNotify);
    }

    bool LocalObjectBufferPool::isDestroyed()
    {
        std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
        return isDestroyed_;
    }


    int LocalObjectBufferPool::getMaxNumberOfSegments() const
    {
        return maxNumberOfObjectSegments_;
    }

    int LocalObjectBufferPool::getNumberOfAvailableSegments()
    {
        std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
        return static_cast<int>(availableSegments.size());
    }

    int LocalObjectBufferPool::getNumBuffers()
    {
        std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
        return static_cast<int>(currentPoolSize_);
    }

    /**
     * Approximate number of segments currently handed out: requested segments minus idle ones, clamped at 0.
     * Reads the counter and the idle queue without taking availableSegmentsLock, hence "best effort".
     */
    int LocalObjectBufferPool::bestEffortGetNumOfUsedBuffers() const
    {
        const int idle = static_cast<int>(availableSegments.size());
        const int used = numberOfRequestedObjectSegments_ - idle;
        return std::max(used, 0);
    }

    std::shared_ptr<Buffer> LocalObjectBufferPool::requestBuffer()
    {
        return requestObjectBuffer();
    }

    BufferBuilder *LocalObjectBufferPool::requestBufferBuilder()
    {
        return requestObjectBufferBuilder();
    }

    BufferBuilder *LocalObjectBufferPool::requestBufferBuilder(int targetChannel)
    {
        return requestObjectBufferBuilder(targetChannel);
    }

    BufferBuilder *LocalObjectBufferPool::requestBufferBuilderBlocking()
    {
        return requestObjectBufferBuilderBlocking();
    }

    BufferBuilder *LocalObjectBufferPool::requestBufferBuilderBlocking(int targetChannel)
    {
        return requestObjectBufferBuilderBlocking(targetChannel);
    }


    std::shared_ptr<ObjectBuffer> LocalObjectBufferPool::requestObjectBuffer()
    {
        LOG(">>>")
        return toObjectBuffer(requestObjectSegment());
    }

    ObjectBufferBuilder *LocalObjectBufferPool::requestObjectBufferBuilder()
    {
        LOG(">>>")
        return toObjectBufferBuilder(requestObjectSegment(UNKNOWN_CHANNEL), UNKNOWN_CHANNEL);
    }

    ObjectBufferBuilder *LocalObjectBufferPool::requestObjectBufferBuilder(int targetChannel)
    {
        LOG(">>>")
        return toObjectBufferBuilder(requestObjectSegment(targetChannel), targetChannel);
    }

    ObjectBufferBuilder *LocalObjectBufferPool::requestObjectBufferBuilderBlocking()
    {
        LOG(">>>")
        return toObjectBufferBuilder(requestObjectSegmentBlocking(), UNKNOWN_CHANNEL);
    }

    ObjectSegment *LocalObjectBufferPool::requestObjectSegmentBlocking()
    {
        LOG(">>>")
        return requestObjectSegmentBlocking(UNKNOWN_CHANNEL);
    }

    ObjectBufferBuilder *LocalObjectBufferPool::requestObjectBufferBuilderBlocking(int targetChannel)
    {
        LOG(">>>")
        return toObjectBufferBuilder(requestObjectSegmentBlocking(targetChannel), targetChannel);
    }

    /**
     * Wraps objectSegment in a VectorBatchBuffer that is also given this pool (presumably as its recycler).
     * A null segment yields a null buffer.
     */
    std::shared_ptr<ObjectBuffer> LocalObjectBufferPool::toObjectBuffer(ObjectSegment *objectSegment)
    {
        LOG(">>>")
        if (objectSegment == nullptr) {
            return nullptr;
        }
        auto buffer = std::make_shared<VectorBatchBuffer>(objectSegment, shared_from_this());
        return buffer;
    }

    /**
     * Wraps memorySegment in a heap-allocated ObjectBufferBuilder; the caller owns the returned pointer.
     *
     * @param memorySegment segment to wrap; nullptr yields nullptr
     * @param targetChannel UNKNOWN_CHANNEL to bind the builder to this pool, otherwise the index of the
     *                      subpartition whose recycler the builder is bound to (not range-checked here)
     * @return the new builder, or nullptr when memorySegment is null
     */
    ObjectBufferBuilder *LocalObjectBufferPool::toObjectBufferBuilder(
            ObjectSegment *memorySegment, int targetChannel)
    {
        LOG("LocalObjectBufferPool::toObjectBufferBuilder running")
        if (!memorySegment) {
            return nullptr;
        }

        if (targetChannel == UNKNOWN_CHANNEL) {
            LOG("ObjectBufferBuilder with subpartitionBufferRecyclers_ with this")
            return new ObjectBufferBuilder(memorySegment, shared_from_this());
        }

        auto &recycler = subpartitionBufferRecyclers_[targetChannel];
        LOG("ObjectBufferBuilder with subpartitionBufferRecyclers_")
        // std::uintptr_t is the integer type intended for holding a pointer value; casting to `long`
        // is ill-formed where long is narrower than a pointer (e.g. LLP64).
        LOG("subpartitionBufferRecyclers_[targetChannel] " << std::to_string(targetChannel)  << "  "
            << (recycler ? std::to_string(reinterpret_cast<std::uintptr_t>(recycler.get())) : std::string("nullptr"))
            << std::endl)

        return new ObjectBufferBuilder(memorySegment, recycler);
    }


    /** Generic blocking request accounted to targetChannel; delegates to the object-segment variant. */
    Segment *LocalObjectBufferPool::requestSegmentBlocking(int targetChannel)
    {
        ObjectSegment *segment = requestObjectSegmentBlocking(targetChannel);
        return segment;
    }


    /**
     * Requests a segment for targetChannel, retrying until one becomes available.
     *
     * This is a polling workaround: between failed attempts the calling thread sleeps for a fixed interval.
     * It does not wait on the availability future.
     *
     * @param targetChannel subpartition the segment is accounted to, or UNKNOWN_CHANNEL
     * @return a non-null segment
     * @throws std::runtime_error if the request is cancelled or the pool is destroyed while waiting
     */
    ObjectSegment *LocalObjectBufferPool::requestObjectSegmentBlocking(int targetChannel)
    {
        // Pause between polling attempts while no segment is available (back pressure).
        constexpr std::chrono::milliseconds retryInterval{100};

        ObjectSegment *segment = nullptr;
        LOG("requestObjectSegment loop will running")
        // getNumberOfAvailableSegments() takes availableSegmentsLock. Reading availableSegments.size()
        // directly here raced with other threads modifying the deque.
        LOG_PART(" Back Pressure possible happens, current segment in pool is " << getNumberOfAvailableSegments())
        while ((segment = requestObjectSegment(targetChannel)) == nullptr) {
            if (cancelled_) {
                throw std::runtime_error("Buffer pool request was cancelled.");
            }
            // Locked read: isDestroyed_ is written under availableSegmentsLock by lazyDestroy().
            if (isDestroyed()) {
                throw std::runtime_error("Buffer pool is destroyed.");
            }
            LOG_PART(
                " Back Pressure happens, current segment in pool is " << getNumberOfAvailableSegments() <<
                " for channel " << targetChannel)
            std::this_thread::sleep_for(retryInterval);
        }
        return segment;
    }


    bool LocalObjectBufferPool::requestSegmentFromGlobal()
    {
        std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
        LOG("requestObjectSegmentFromGlobal get lock")
        if (isRequestedSizeReached()) {
            return false;
        }

        if (isDestroyed_) {
            throw std::runtime_error(
                "Destroyed buffer pools should never acquire segments - this will lead to buffer leaks.");
        }

        ObjectSegment *segment = networkObjBufferPool_->requestPooledObjectSegment();
        if (segment != nullptr) {
            availableSegments.push_back(segment);
            numberOfRequestedObjectSegments_++;

            LOG_PART("requestPooledObjectSegment from networkObjBufferPool_ , numberOfRequestedObjectSegments_  " << numberOfRequestedObjectSegments_
                << " currentPoolSize_ :" << currentPoolSize_)
            return true;
        }
        return false;
    }


    /** Generic non-blocking request without a target subpartition. */
    Segment *LocalObjectBufferPool::requestSegment()
    {
        return this->requestObjectSegment(UNKNOWN_CHANNEL);
    }


    /** Generic non-blocking request accounted to targetChannel. */
    Segment *LocalObjectBufferPool::requestSegment(int targetChannel)
    {
        return this->requestObjectSegment(targetChannel);
    }

    /** Generic blocking request without a target subpartition. */
    Segment *LocalObjectBufferPool::requestSegmentBlocking()
    {
        return this->requestSegmentBlocking(UNKNOWN_CHANNEL);
    }

    /**
     * Non-blocking request of one idle segment, optionally accounted to a subpartition.
     *
     * @param targetChannel subpartition index, or UNKNOWN_CHANNEL for no per-channel accounting
     * @return the segment; nullptr if targetChannel already holds maxBuffersPerChannel_ segments or no idle
     *         segment is queued
     * @throws std::runtime_error if the pool is destroyed, or if the next idle segment is not an ObjectSegment
     */
    ObjectSegment *LocalObjectBufferPool::requestObjectSegment(int targetChannel)
    {
        LOG("requestObjectSegment in LocalObjectBufferPool")
        ObjectSegment *segment = nullptr;
        {
            std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
            LOG("get lock std::this_thread::get_id()" << std::this_thread::get_id())
            if (isDestroyed_) {
                throw std::runtime_error("Buffer pool is destroyed.");
            }

            // Target channel is at its quota: do not hand out another segment.
            if (targetChannel != UNKNOWN_CHANNEL && subpartitionBuffersCount_[targetChannel] >= maxBuffersPerChannel_) {
                return nullptr;
            }

            if (availableSegments.empty()) {
                LOG("availableObjectSegments is empty")
                return nullptr;
            }

            LOG("availableObjectSegments is not empty")
            // Check the type before dequeuing. Previously a failed cast popped (and so lost) the segment,
            // still charged the channel quota, and returned nullptr as if the pool were merely empty.
            segment = dynamic_cast<ObjectSegment*>(availableSegments.front());
            if (segment == nullptr) {
                throw std::runtime_error("Segment is not of type ObjectSegment.");
            }
            LOG("availableObjectSegments is segment.get()" << segment << "segment " << segment)
            availableSegments.pop_front();
            LOG("availableObjectSegments_.size()" << availableSegments.size())
            LOG_PART("requestObjectSegment for targetChannel " << targetChannel
                  << "availableObjectSegments_.size()" << availableSegments.size())

            if (targetChannel != UNKNOWN_CHANNEL) {
                // Reaching the quota counts this subpartition as unavailable.
                if (++subpartitionBuffersCount_[targetChannel] == maxBuffersPerChannel_) {
                    unavailableSubpartitionsCount_++;
                }
            }

            checkAndUpdateAvailability();
            LOG("unlock std::this_thread::get_id()" << std::this_thread::get_id())
        }
        return segment;
    }


    /** Non-blocking request of a segment that is not accounted to any subpartition. */
    ObjectSegment *LocalObjectBufferPool::requestObjectSegment()
    {
        // Use the named sentinel like every other "no channel" call site instead of a bare -1. A mismatch
        // with UNKNOWN_CHANNEL would make requestObjectSegment(int) index subpartitionBuffersCount_[-1].
        return requestObjectSegment(UNKNOWN_CHANNEL);
    }

    void LocalObjectBufferPool::lazyDestroy()
    {
        std::shared_ptr<CompletableFuture> toNotify = nullptr;
        {
            std::lock_guard<std::recursive_mutex> lock(availableSegmentsLock);
            if (!isDestroyed_) {
                ObjectSegment *segment;
                while (!availableObjectSegments_.empty()) {
                    segment = availableObjectSegments_.front();
                    availableObjectSegments_.pop_front();
                    returnObjectSegment(segment);
                }

                std::shared_ptr<ObjectBufferListener> listener;
                while (!registeredListeners_.empty()) {
                    listener = std::dynamic_pointer_cast<ObjectBufferListener>(registeredListeners_.front());
                    registeredListeners_.pop_front();
                    listener->notifyBufferDestroyed();
                }
                if (!isAvailable()) {
                    toNotify = availabilityHelper_->GetAvailableFuture();
                }

                isDestroyed_ = true;
            }
        }

        mayNotifyAvailable(toNotify);

        networkObjBufferPool_->destroyBufferPool(shared_from_this());
    }

    /** Returns the availability future held by this pool's AvailabilityHelper. */
    std::shared_ptr<CompletableFuture> LocalObjectBufferPool::GetAvailableFuture()
    {
        auto &helper = availabilityHelper_;
        return helper->GetAvailableFuture();
    }

    /**
     * Human-readable summary of the pool's sizes, counters and flags, e.g. for logging.
     * NOTE(review): reads the counters and containers without taking availableSegmentsLock, so the snapshot
     * may be inconsistent while other threads use the pool.
     */
    std::string LocalObjectBufferPool::toString() const
    {
        std::string text = "[size: ";
        text += std::to_string(currentPoolSize_);
        text += ", required: ";
        text += std::to_string(numberOfRequiredSegments_);
        text += ", requested: ";
        text += std::to_string(numberOfRequestedObjectSegments_);
        text += ", available: ";
        text += std::to_string(availableSegments.size());
        text += ", max: ";
        text += std::to_string(maxNumberOfObjectSegments_);
        text += ", listeners: ";
        text += std::to_string(registeredListeners_.size());
        text += ", subpartitions: ";
        text += std::to_string(subpartitionBuffersCount_.size());
        text += ", maxBuffersPerChannel: ";
        text += std::to_string(maxBuffersPerChannel_);
        text += ", destroyed: ";
        text += isDestroyed_ ? "true" : "false";
        text += "]";
        return text;
    }

    // void LocalObjectBufferPool::mayNotifyAvailable(std::shared_ptr<CompletableFuture> toNotify)
    // {
    //     {
    //     }
    // }

    /**
     * Generic-interface return path: forwards segment to returnObjectSegment().
     * @throws std::runtime_error if segment is not an ObjectSegment (a null pointer also fails the check)
     */
    void LocalObjectBufferPool::returnSegment(Segment *segment)
    {
        if (auto *objectSegment = dynamic_cast<ObjectSegment*>(segment)) {
            returnObjectSegment(objectSegment);
            return;
        }
        throw std::runtime_error("Segment is not of type ObjectSegment.");
    }

    /**
     * Accounting for a segment leaving this pool: decrements numberOfRequestedObjectSegments_ under the
     * segments lock.
     *
     * NOTE(review): `segment` is unused. It is not handed back to networkObjBufferPool_, so segments passed
     * here (from lazyDestroy() and returnExcessObjectSegments()) may never become reusable by other pools.
     * Confirm how NetworkObjectBufferPool reclaims segments.
     */
    void LocalObjectBufferPool::returnObjectSegment(ObjectSegment *segment)
    {
        const std::lock_guard<std::recursive_mutex> guard{availableSegmentsLock};
        --numberOfRequestedObjectSegments_;
    }

    void LocalObjectBufferPool::returnExcessSegments()
    {
        returnExcessObjectSegments();
    }

    /**
     * While more segments are requested than the current pool size, hands idle segments to
     * returnObjectSegment(). Stops early once no idle segment is left; in-use segments are not reclaimed here.
     */
    void LocalObjectBufferPool::returnExcessObjectSegments()
    {
        std::lock_guard<std::recursive_mutex> guard(availableSegmentsLock);
        while (hasExcessBuffers() && !availableSegments.empty()) {
            auto *excess = dynamic_cast<ObjectSegment*>(availableSegments.front());
            availableSegments.pop_front();
            returnObjectSegment(excess);
        }
    }

    /** True when more segments have been requested than the current pool size allows. */
    bool LocalObjectBufferPool::hasExcessBuffers()
    {
        const bool excess = numberOfRequestedObjectSegments_ > currentPoolSize_;
        return excess;
    }

    /** True when the requested-segment count has reached (or passed) the current pool size. */
    bool LocalObjectBufferPool::isRequestedSizeReached()
    {
        return !(numberOfRequestedObjectSegments_ < currentPoolSize_);
    }

    /**
     * Recycler bound to one subpartition.
     * @param channel    subpartition index passed along with every recycled segment
     * @param bufferPool pool that recycled segments are forwarded to (shared ownership is kept)
     */
    LocalObjectBufferPool::SubpartitionBufferRecycler::SubpartitionBufferRecycler(
        int channel, std::shared_ptr<LocalBufferPool> bufferPool)
        // The by-value shared_ptr is a sink parameter: move it instead of paying a second atomic refcount bump.
        : channel_(channel), bufferPool_(std::move(bufferPool))
    {
    }

    /** Hands segment back to the owning pool, tagged with this recycler's subpartition index. */
    void LocalObjectBufferPool::SubpartitionBufferRecycler::recycle(Segment *segment)
    {
        bufferPool_->recycle(segment, this->channel_);
    }

} ///