/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#ifndef OMNISTREAM_RAWKEYEDSTATEINPUTSTREAMPROXY_H
#define OMNISTREAM_RAWKEYEDSTATEINPUTSTREAMPROXY_H

#pragma once

#include <algorithm>
#include <cstdint>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "core/fs/FSDataInputStream.h"
#include "core/include/common.h"
#include "core/memory/DataInputView.h"
#include "runtime/checkpoint/TaskStateSnapshotSerializer.h"
#include "runtime/state/KeyGroupsStateHandle.h"
#include "runtime/state/bridge/OmniTaskBridge.h"

/**
 * Input counterpart of KeyedStateCheckpointOutputStream, used to restore legacy raw keyed timer state.
 *
 * Some OmniStream serializers, such as StringValue::readString(), read directly from the
 * DataInputView backing buffer via getData()/getPosition()/setPosition(), so this proxy must be
 * buffer-backed rather than a thin stream wrapper: the whole state is loaded into memory up front.
 * In-memory state handles are copied into the buffer directly.  For remote checkpoint files, the
 * buffer is filled by opening the Java Flink FSDataInputStream through OmniTaskBridge (mirroring how
 * CheckpointStateOutputStreamProxy writes through OmniAdaptor), and that stream is closed as soon as
 * it has been fully read.
 */
class RawKeyedStateInputStreamProxy : public DataInputView {
public:
    /**
     * Buffers the entire content of a native FSDataInputStream.
     *
     * @param inputStream stream to drain; must not be null.  It is read until Read() reports a
     *                    non-positive count.  This proxy does not close it.
     * @throws (via THROW_LOGIC_EXCEPTION) if inputStream is null.
     */
    explicit RawKeyedStateInputStreamProxy(std::shared_ptr<FSDataInputStream> inputStream)
    {
        loadFromFsInputStream(std::move(inputStream));
    }

    /**
     * Buffers the raw keyed state referenced by a KeyGroupsStateHandle.
     *
     * In-memory handles are copied directly.  Otherwise the handle is serialized to JSON and a
     * Java input stream is opened through the bridge, fully read into memory, and closed again.
     *
     * @param omniTaskBridge       bridge used to open/read/close remote state; may be null only
     *                             when the handle is in-memory.
     * @param keyGroupsStateHandle handle to restore from; must not be null.
     * @throws (via THROW_LOGIC_EXCEPTION) on a null handle, a missing bridge for a remote handle,
     *         or a failure to open the remote stream.
     */
    RawKeyedStateInputStreamProxy(
        std::shared_ptr<omnistream::OmniTaskBridge> omniTaskBridge,
        const std::shared_ptr<KeyGroupsStateHandle> &keyGroupsStateHandle)
        : omniTaskBridge_(std::move(omniTaskBridge))
    {
        if (keyGroupsStateHandle == nullptr) {
            INFO_RELEASE("Error: RawKeyedStateInputStreamProxy Raw keyed state handle is null.");
            THROW_LOGIC_EXCEPTION("Raw keyed state handle is null.")
        }

        auto inMemoryBytes = keyGroupsStateHandle->AsBytesIfInMemory();
        if (inMemoryBytes.has_value()) {
            data_ = std::move(inMemoryBytes.value());
            position_ = 0;
            return;
        }

        if (omniTaskBridge_ == nullptr) {
            INFO_RELEASE(
                "Error: RawKeyedStateInputStreamProxy Cannot restore raw keyed state without OmniTaskBridge for remote state handle.");
            THROW_LOGIC_EXCEPTION("Cannot restore raw keyed state without OmniTaskBridge for remote state handle.")
        }

        auto handleJson = TaskStateSnapshotSerializer::parseKeyGroupsStateHandle(keyGroupsStateHandle);
        inputStream_ = omniTaskBridge_->getSavepointInputStream(to_string(handleJson));
        if (inputStream_ == nullptr) {
            INFO_RELEASE(
                "Error: RawKeyedStateInputStreamProxy Failed to open raw keyed state input stream through OmniTaskBridge.");
            THROW_LOGIC_EXCEPTION("Failed to open raw keyed state input stream through OmniTaskBridge.")
        }

        // If a constructor exits via an exception the destructor is never invoked, so the opened
        // Java stream must be released here; otherwise a failed read would leak it.
        try {
            loadFromOmniAdaptorStream();
        } catch (...) {
            closeInputStream();
            throw;
        }
        closeInputStream();
    }

    ~RawKeyedStateInputStreamProxy() override
    {
        // Normally a no-op: the constructor already closes the bridge stream after buffering.
        closeInputStream();
    }

    /**
     * Moves the read cursor to an absolute offset in the buffered state.
     *
     * @param offset target position; 0 <= offset <= buffered size (the end is a valid position).
     * @throws (via THROW_LOGIC_EXCEPTION) if offset is out of range.
     */
    void seek(int64_t offset)
    {
        if (offset < 0 || static_cast<size_t>(offset) > data_.size()) {
            INFO_RELEASE("Error: seek Invalid raw keyed state seek offset: " << offset << ", dataSize=" << data_.size());
            THROW_LOGIC_EXCEPTION("Invalid raw keyed state seek offset: " << offset
                << ", dataSize=" << data_.size())
        }
        position_ = static_cast<size_t>(offset);
    }

    /** Reads one byte and returns it as a value in [0, 255]. */
    int readUnsignedByte() override
    {
        ensureAvailable(1, "readUnsignedByte");
        return data_[position_++] & 0xff;
    }

    /** Reads one byte. */
    uint8_t readByte() override
    {
        return static_cast<uint8_t>(readUnsignedByte());
    }

    /** Reads a big-endian 16-bit unsigned value in [0, 65535]. */
    int readUnsignedShort() override
    {
        ensureAvailable(2, "readUnsignedShort");
        int value = (static_cast<int>(data_[position_]) << 8) | static_cast<int>(data_[position_ + 1]);
        position_ += 2;
        return value;
    }

    /** Reads a big-endian 32-bit signed integer. */
    int readInt() override
    {
        ensureAvailable(4, "readInt");
        uint32_t value = (static_cast<uint32_t>(data_[position_]) << 24) |
                         (static_cast<uint32_t>(data_[position_ + 1]) << 16) |
                         (static_cast<uint32_t>(data_[position_ + 2]) << 8) |
                         static_cast<uint32_t>(data_[position_ + 3]);
        position_ += 4;
        return static_cast<int>(value);
    }

    /** Reads a big-endian 64-bit signed integer. */
    int64_t readLong() override
    {
        ensureAvailable(8, "readLong");
        uint64_t value = 0;
        for (int i = 0; i < 8; ++i) {
            value = (value << 8) | static_cast<uint64_t>(data_[position_++]);
        }
        return static_cast<int64_t>(value);
    }

    /** Not supported for raw keyed timer restore; always throws. */
    double readDouble() override
    {
        INFO_RELEASE("Error: Raw keyed state timer restore does not support readDouble.");
        THROW_LOGIC_EXCEPTION("Raw keyed state timer restore does not support readDouble.")
    }

    /** Reads one byte; any non-zero value is true. */
    bool readBoolean() override
    {
        return readUnsignedByte() != 0;
    }

    /**
     * Copies exactly `length` bytes into buffer[offset, offset + length).
     *
     * @param buffer   destination; must not be null.
     * @param capacity size of the destination buffer.
     * @param offset   start index in buffer; must satisfy 0 <= offset <= capacity.
     * @param length   number of bytes; must satisfy 0 <= length <= capacity - offset.
     * @throws (via THROW_LOGIC_EXCEPTION) on invalid bounds or if fewer than `length` bytes remain.
     */
    void readFully(uint8_t *buffer, int capacity, int offset, int length) override
    {
        // Written as `length > capacity - offset` (after checking offset <= capacity) so the check
        // itself cannot overflow int, unlike `offset + length > capacity`.
        if (buffer == nullptr || offset < 0 || length < 0 || offset > capacity || length > capacity - offset) {
            INFO_RELEASE("Error: readFully Invalid readFully bounds for raw keyed state.");
            THROW_LOGIC_EXCEPTION("Invalid readFully bounds for raw keyed state.")
        }
        ensureAvailable(static_cast<size_t>(length), "readFully");
        std::copy(data_.begin() + static_cast<std::ptrdiff_t>(position_),
            data_.begin() + static_cast<std::ptrdiff_t>(position_ + static_cast<size_t>(length)),
            buffer + offset);
        position_ += static_cast<size_t>(length);
    }

    /**
     * Reads a 16-bit big-endian length prefix followed by that many bytes.
     *
     * NOTE(review): the bytes are returned verbatim; no modified-UTF-8 decoding is performed,
     * so non-ASCII content written by Java's writeUTF is passed through undecoded.
     */
    std::string readUTF() override
    {
        int utflen = readUnsignedShort();
        if (utflen == 0) {
            return "";
        }
        ensureAvailable(static_cast<size_t>(utflen), "readUTF");
        std::string result(reinterpret_cast<const char *>(data_.data() + position_), static_cast<size_t>(utflen));
        position_ += static_cast<size_t>(utflen);
        return result;
    }

    /** Mutable pointer to the start of the buffered state, or nullptr when it is empty. */
    void *GetBuffer() override
    {
        return data_.empty() ? nullptr : data_.data();
    }

    /** Read-only pointer to the start of the buffered state, or nullptr when it is empty. */
    const uint8_t *getData() override
    {
        return data_.empty() ? nullptr : data_.data();
    }

    /** Current read offset from the start of the buffer. */
    size_t getPosition() override
    {
        return position_;
    }

    /**
     * Sets the read offset; the end of the buffer is a valid position.
     *
     * @throws (via THROW_LOGIC_EXCEPTION) if position exceeds the buffered size.
     */
    void setPosition(size_t position) override
    {
        if (position > data_.size()) {
            INFO_RELEASE("Error: setPosition Invalid raw keyed state position: "
                << position << ", dataSize=" << data_.size());
            THROW_LOGIC_EXCEPTION("Invalid raw keyed state position: " << position
                << ", dataSize=" << data_.size())
        }
        position_ = position;
    }

private:
    static constexpr size_t READ_CHUNK_SIZE = 4096;

    /**
     * Throws unless at least `bytes` bytes remain after the cursor.
     *
     * Takes a C string so the per-read hot path does not build a temporary std::string
     * from each literal operation name.
     */
    void ensureAvailable(size_t bytes, const char *operation)
    {
        if (position_ > data_.size() || bytes > data_.size() - position_) {
            INFO_RELEASE("Error: ensureAvailable EOF while " << operation << " raw keyed state. position="
                << position_ << ", required=" << bytes << ", dataSize=" << data_.size());
            THROW_LOGIC_EXCEPTION("EOF while " << operation << " raw keyed state. position="
                << position_ << ", required=" << bytes << ", dataSize=" << data_.size())
        }
    }

    /** Appends the whole content of a native stream to data_ and rewinds the cursor. */
    void loadFromFsInputStream(std::shared_ptr<FSDataInputStream> inputStream)
    {
        if (inputStream == nullptr) {
            INFO_RELEASE("Error: loadFromFsInputStream Raw keyed state input stream is null.");
            THROW_LOGIC_EXCEPTION("Raw keyed state input stream is null.")
        }

        std::vector<uint8_t> chunk(READ_CHUNK_SIZE);
        while (true) {
            int read = inputStream->Read(chunk, 0, static_cast<int>(chunk.size()));
            if (read <= 0) {
                // A non-positive count marks end of stream.
                break;
            }
            data_.insert(data_.end(), chunk.begin(), chunk.begin() + read);
        }
        position_ = 0;
    }

    /** Appends the whole content of the bridge-opened stream (inputStream_) to data_ and rewinds. */
    void loadFromOmniAdaptorStream()
    {
        std::vector<uint8_t> chunk(READ_CHUNK_SIZE);
        while (true) {
            int read = omniTaskBridge_->ReadSavepointInputStream(inputStream_,
                reinterpret_cast<int8_t *>(chunk.data()), 0, chunk.size());
            if (read <= 0) {
                // A non-positive count marks end of stream.
                break;
            }
            data_.insert(data_.end(), chunk.begin(), chunk.begin() + read);
        }
        position_ = 0;
    }

    /** Closes the bridge-opened stream once; safe to call repeatedly. */
    void closeInputStream()
    {
        if (inputStream_ != nullptr && omniTaskBridge_ != nullptr) {
            omniTaskBridge_->closeSavepointInputStream(inputStream_);
            inputStream_ = nullptr;
        }
    }

    std::shared_ptr<omnistream::OmniTaskBridge> omniTaskBridge_;
    jobject inputStream_ = nullptr;   // Java stream handle opened via the bridge; null when not open.
    std::vector<uint8_t> data_;       // Entire buffered raw keyed state.
    size_t position_ = 0;             // Read cursor into data_.
};

#endif // OMNISTREAM_RAWKEYEDSTATEINPUTSTREAMPROXY_H