/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 */

#ifndef OMNISTREAM_RECOVEREDINPUTCHANNEL_H
#define OMNISTREAM_RECOVEREDINPUTCHANNEL_H

#include "deque"
#include "stdint.h"
#include "InputChannel.h"
#include "mutex"
#include "runtime/buffer/Buffer.h"
#include "runtime/buffer/ObjectBufferDataType.h"
#include "runtime/checkpoint/channel/ChannelStateWriter.h"
#include "runtime/partition/consumer/SingleInputGate.h"
#include "runtime/executiongraph/descriptor/ResultPartitionIDPOD.h"
#include "runtime/metrics/Counter.h"
#include "core/include/common.h"
#include "runtime/io/network/api/serialization/EventSerializer.h"
#include "runtime/event/EndOfChannelStateEvent.h"
#include "BufferAndAvailability.h"
#include "BufferManager.h"
#include "core/utils/threads/CompletableFutureV2.h"

class RecoveredInputChannel : public omnistream::InputChannel {
public:

    RecoveredInputChannel(std::shared_ptr<omnistream::SingleInputGate> inputGate, int channelIndex,
                          omnistream::ResultPartitionIDPOD partitionId,
                          int consumedSubpartitionIndex, int initialBackoff, int maxBackoff,
                          std::shared_ptr<omnistream::Counter> numBytesIn,
                          std::shared_ptr<omnistream::Counter> numBuffersIn, int networkBuffersPerChannel)
            : omnistream::InputChannel(inputGate, channelIndex, partitionId, initialBackoff, maxBackoff, numBytesIn,
                                       numBuffersIn) {
        setConsumedSubpartitionIndex(consumedSubpartitionIndex);
        bufferManager = std::make_shared<BufferManager>(inputGate->getSegmentProvider(), this, 1);
        this->networkBuffersPerChannel =networkBuffersPerChannel;
        stateConsumedFuture = std::make_shared<CompletableFutureV2<void>>();
    }

    void SetChannelStateWriter(std::shared_ptr<ChannelStateWriter> stateWriter) {
        if (stateWriter == nullptr){
            LOG("invalid param, nullptr!");
            return;
        }
        if(this->channelStateWriter != nullptr){
            LOG("channel state writer already init!");
            return;
        }
        this->channelStateWriter = stateWriter;
    }

    std::shared_ptr<ChannelStateWriter> GetChannelStateWriter() const {
        return channelStateWriter;
    }

    std::shared_ptr<omnistream::InputChannel> toInputChannel();

    virtual std::shared_ptr<omnistream::InputChannel> toInputChannelInternal() = 0;

    void CheckpointStopped(long checkpointId) override
    {
        lastStoppedCheckpointId = checkpointId;
    }

    std::shared_ptr<CompletableFutureV2<void>> getStateConsumedFuture()
    {
        return stateConsumedFuture;
    }

    bool getStateConsumedFuture1()
    {
        return stateConsumedFuture1.load();
    }

    void onRecoveredStateBuffer(Buffer *buffer);
    void onRecoveredStateBuffer2(Buffer *buffer);

    void finishReadRecoveredState();

    std::optional<omnistream::BufferAndAvailability>  getNextRecoveredStateBuffer();

    omnistream::ObjectBufferDataType peekDataTypeUnsafe();

    bool isEndOfChannelStateEvent(Buffer *buffer);

    std::optional<BufferAndAvailability> getNextBuffer() override;

    int getBuffersInUseCount() override
    {
        std::lock_guard<std::mutex> lock(bufferLock);
        return static_cast<int>(receivedBuffers.size());
    }

    void resumeConsumption() override
    {
        throw std::invalid_argument("RecoveredInputChannel should never be blocked.");
    }

    void acknowledgeAllRecordsProcessed() override
    {
        throw std::invalid_argument("RecoveredInputChannel should not need acknowledge all records processed.");
    }

    void requestSubpartition(int subpartitionIndex) override
    {
        throw std::invalid_argument("RecoveredInputChannel should never request partition.");
    }

    void sendTaskEvent(std::shared_ptr<TaskEvent> event) override {
        throw std::invalid_argument("RecoveredInputChannel should never send any task events.");
    }

    bool isReleased() override
    {
        std::lock_guard<std::mutex> lock(bufferLock);
        return released;
    }

    void releaseAllResources() override;

    int getNumberOfQueuedBuffers()
    {
        std::lock_guard<std::mutex> lock(bufferLock);
        return static_cast<int>(receivedBuffers.size());
    }

    std::shared_ptr<omnistream::Buffer> requestBufferBlocking();

    void CheckpointStarted(const CheckpointBarrier& barrier) override
    {
        throw std::invalid_argument("Checkpoint was declined (tasks not ready)");
    }

    void announceBufferSize(int newBufferSize) override {}

    int getNetworkBuffersPerChannel()
    {
        return networkBuffersPerChannel;
    }

    void SetIsOmniChannel(bool isOmniChannel)
    {
        toOmniChannel_ = isOmniChannel;
    }

    bool IsOmniChannel() const
    {
        return toOmniChannel_;
    }

private:
    struct RecoveredBufferEntry {
        Buffer* buffer;
        std::shared_ptr<omnistream::Buffer> owner;

        RecoveredBufferEntry(Buffer* b, std::shared_ptr<omnistream::Buffer> o = nullptr)
            : buffer(b), owner(std::move(o)) {}
    };

    std::deque<RecoveredBufferEntry> receivedBuffers;
    std::deque<std::shared_ptr<omnistream::Buffer>> consumedRecoveredBufferOwners;

    std::shared_ptr<CompletableFutureV2<void>> stateConsumedFuture;
    std::atomic<bool> stateConsumedFuture1{false};
    std::shared_ptr<BufferManager> bufferManager;
    bool released = false;
    std::shared_ptr<ChannelStateWriter> channelStateWriter;
    int sequenceNumber = std::numeric_limits<int>::min();
    int networkBuffersPerChannel;
    bool exclusiveBuffersAssigned = false;
    long lastStoppedCheckpointId = -1;
    bool toOmniChannel_ = false;
    std::mutex bufferLock;
};


#endif // OMNISTREAM_RECOVEREDINPUTCHANNEL_H