/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#ifndef SEQUENTIAL_CHANNEL_STATE_READER_IMPL_H
#define SEQUENTIAL_CHANNEL_STATE_READER_IMPL_H

#include <fstream>
#include <memory>
#include <vector>
#include <map>
#include "ChannelStateSerializer.h"
#include "RecoveredChannelStateHandler.h"
#include "SequentialChannelStateReader.h"
#include "checkpoint/OperatorSubtaskState.h"
#include "runtime/checkpoint/TaskStateSnapshot.h"

#include "core/memory/MemorySegmentFactory.h"
#include "runtime/buffer/NetworkBuffer.h"
#include "runtime/buffer/BufferRecycler.h"
#include "runtime/buffer/OriginalNetworkBufferRecycler.h"

namespace omnistream {
template <typename Info>
class RescaledOffset {
public:
    long long offset;
    Info channelInfo;
    int oldSubtaskIndex;

    RescaledOffset(long long offset, Info channelInfo, int oldSubtaskIndex)
        :offset(offset), channelInfo(channelInfo), oldSubtaskIndex(oldSubtaskIndex) {}

    std::string toString() const {
        std::ostringstream oss;
        oss << "RescaledOffset{offset=" << offset
            << ", channelInfo=" << (channelInfo ? "Valid" : "Null")
            << ", oldSubtaskIndex=" << oldSubtaskIndex << "}";
        return oss.str();
    }
};

class ChannelStateChunkReader {
public:
    std::shared_ptr<ChannelStateSerializer> serializer;

    ChannelStateChunkReader(std::shared_ptr<ChannelStateSerializer> &serializer) : serializer(serializer) {}

    void readChunkByByteStreamForInputChannel(
        std::shared_ptr<ByteStateHandleInputStream>& source,
        long long sourceOffset,
        std::shared_ptr<InputChannelRecoveredStateHandler> stateHandler,
        const InputChannelInfo &channelInfo,
        int oldSubtaskIndex)
    {
        if (source->GetPos() != sourceOffset) {
            source->Seek(sourceOffset);
        }

        int length = serializer->ReadLength2(source);

        while (length > 0) {
            auto bufferWithContext = newTempInputBufferWithContext();

            try {
                while (length > 0 && bufferWithContext.buffer_->isWritable()) {
                    length -= serializer->ReadData2(source, bufferWithContext.buffer_, length);
                }
            } catch (const std::exception& e){
                bufferWithContext.close();
                throw;
            }
            stateHandler->recover(channelInfo, oldSubtaskIndex, bufferWithContext);
        }
    }

    static RecoveredChannelStateHandler<InputChannelInfo, Buffer *>::BufferWithContext newTempInputBufferWithContext() {
        constexpr int kTempRestoreBufferSize = 32 * 1024;

        auto* memorySegment = MemorySegmentFactory::wrap(kTempRestoreBufferSize);

        auto buffer =new NetworkBuffer(memorySegment, std::make_shared<OriginalNetworkBufferRecycler>(),
            true);

        return {ChannelStateByteBuffer::wrap(buffer), buffer};
    }
    
    void readChunkByByteStreamForResultSubpartition(
        std::shared_ptr<ByteStateHandleInputStream>& source,
        long long sourceOffset,
        std::shared_ptr<ResultSubpartitionRecoveredStateHandler> stateHandler,
        const ResultSubpartitionInfoPOD &channelInfo,
        int oldSubtaskIndex)
    {
        if (source->GetPos() != sourceOffset) {
            source->Seek(sourceOffset);
        }

        int length = serializer->ReadLength2(source);

        while (length > 0) {
            auto bufferWithContext = stateHandler->getBuffer(channelInfo);
            
            try {
                while (length > 0 && bufferWithContext.buffer_->isWritable()) {
                    length -= serializer->ReadData2(source, bufferWithContext.buffer_, length);
                }
            } catch (const std::exception& e){
                bufferWithContext.close();
                throw;
            }
            stateHandler->recover(channelInfo, oldSubtaskIndex, bufferWithContext);
        }
    }

    void readChunkByFsForInputChannel(
        std::ifstream& source,
        long long sourceOffset,
        std::shared_ptr<InputChannelRecoveredStateHandler> stateHandler,
        const InputChannelInfo &channelInfo,
        int oldSubtaskIndex)
    {
        if (source.tellg() != sourceOffset) {
            source.seekg(sourceOffset);
        }

        int length = serializer->ReadLength(source);

        while (length > 0) {
            auto bufferWithContext = newTempInputBufferWithContext();
            try {
                while (length > 0 && bufferWithContext.buffer_->isWritable()) {
                    length -= serializer->ReadData(source, bufferWithContext.buffer_, length);
                }
            } catch (const std::exception& e){
                bufferWithContext.close();
                throw;
            }
            stateHandler->recover(channelInfo, oldSubtaskIndex, bufferWithContext);
        }
    }

    void readChunkByFsForResultSubpartition(
        std::ifstream& source,
        long long sourceOffset,
        std::shared_ptr<ResultSubpartitionRecoveredStateHandler> stateHandler,
        const ResultSubpartitionInfoPOD &channelInfo,
        int oldSubtaskIndex)
    {
        if (source.tellg() != sourceOffset) {
            source.seekg(sourceOffset);
        }

        int length = serializer->ReadLength(source);

        while (length > 0) {
            auto bufferWithContext = stateHandler->getBuffer(channelInfo);

            try {
                while (length > 0 && bufferWithContext.buffer_->isWritable()) {
                    length -= serializer->ReadData(source, bufferWithContext.buffer_, length);
                }
            } catch (const std::exception& e){
                bufferWithContext.close();
                throw;
            }
            stateHandler->recover(channelInfo, oldSubtaskIndex, bufferWithContext);
        }
    }
};
class SequentialChannelStateReaderOPImpl : public SequentialChannelStateReader {
public:
    void readInputData(const std::vector<std::shared_ptr<InputGate>> &inputGates) override
    {}
    void readOutputData(const std::vector<std::shared_ptr<ResultPartitionWriter>> &writers, bool notifyAndBlockOnCompletion) override
    {}
    void close() override
    {}
};

using InputChannelHandleGroup =
            std::map<std::string, std::pair<std::shared_ptr<StreamStateHandle>, std::vector<InputChannelStateHandle>>>;

using ResultSubpartitionHandleGroup =
        std::map<std::string, std::pair<std::shared_ptr<StreamStateHandle>, std::vector<ResultSubpartitionStateHandle>>>;
class SequentialChannelStateReaderImpl : public SequentialChannelStateReader {
public:
    explicit SequentialChannelStateReaderImpl(std::shared_ptr<TaskStateSnapshot> taskStateSnapshot, std::shared_ptr<OmniTaskBridge> omniTaskBridge)
            :taskStateSnapshot(taskStateSnapshot), serializer(std::make_shared<ChannelStateSerializerImpl>()),
            omniTaskBridge_(omniTaskBridge)
    {
        chunkReader = std::make_shared<ChannelStateChunkReader>(serializer);
    }

    void readInputData(const std::vector<std::shared_ptr<InputGate>> &inputGates) override;

    void readOutputData(const std::vector<std::shared_ptr<ResultPartitionWriter>> &writers,
        bool notifyAndBlockOnCompletion) override;

    void close() override;

    void readByInputChannel(
        std::shared_ptr<InputChannelRecoveredStateHandler> stateHandler,
        const InputChannelHandleGroup &streamStateHandleListMap);

    void readByResultSubpartition(
        std::shared_ptr<ResultSubpartitionRecoveredStateHandler> stateHandler,
        const ResultSubpartitionHandleGroup &streamStateHandleListMap);
    void readSequentiallyByInputChannel(
        std::shared_ptr<StreamStateHandle> streamStateHandle,
        std::vector<InputChannelStateHandle> channelStateHandles,
        std::shared_ptr<InputChannelRecoveredStateHandler> stateHandle);
    void readSequentiallyByResultSubpartition(
        std::shared_ptr<StreamStateHandle> streamStateHandle,
        std::vector<ResultSubpartitionStateHandle> channelStateHandles,
        std::shared_ptr<ResultSubpartitionRecoveredStateHandler> stateHandle);
    InputChannelHandleGroup groupByDelegateByInputChannel(
        std::vector<std::shared_ptr<OperatorSubtaskState>> states);
    ResultSubpartitionHandleGroup groupByDelegateByResultSubpartition(
        std::vector<std::shared_ptr<OperatorSubtaskState>> states);
    std::vector<std::shared_ptr<OperatorSubtaskState>> streamSubtaskStates();
private:
    std::vector<RescaledOffset<InputChannelInfo>> extractOffsetsSortedByInputChannel(
    const std::vector<InputChannelStateHandle>& channelStateHandles);
    std::vector<RescaledOffset<ResultSubpartitionInfoPOD>> extractOffsetsSortedByResultSubpartition(
    const std::vector<ResultSubpartitionStateHandle>& channelStateHandles);
    std::vector<RescaledOffset<InputChannelInfo>> extractOffsetsByInputChannel(const InputChannelStateHandle& handle);
    std::vector<RescaledOffset<ResultSubpartitionInfoPOD>> extractOffsetsByResultSubpartition(const ResultSubpartitionStateHandle& handle);
    std::shared_ptr<TaskStateSnapshot> taskStateSnapshot;
    std::shared_ptr<ChannelStateSerializer> serializer;
    std::shared_ptr<ChannelStateChunkReader> chunkReader;
    std::shared_ptr<OmniTaskBridge> omniTaskBridge_;
};
}
#endif // SEQUENTIAL_CHANNEL_STATE_READER_IMPL_H