/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#include "RecoveredChannelStateHandler.h"
#include "event/SubtaskConnectionDescriptor.h"
#include "io/network/api/serialization/EventSerializer.h"

namespace omnistream {

/**
 * Builds a recovered-state handler for result subpartitions.
 *
 * @param writers                    result partition writers; getSubpartition() looks them up by
 *                                   partition index and close() iterates over all of them
 * @param notifyAndBlockOnCompletion flag forwarded to finishReadRecoveredState() in close()
 * @param channelMapping             rescaling descriptor; may be null, in which case
 *                                   calculateMapping() uses IdentityRescaleMappings::SYMMETRIC_IDENTITY
 */
ResultSubpartitionRecoveredStateHandler::ResultSubpartitionRecoveredStateHandler(
    std::vector<std::shared_ptr<ResultPartitionWriter>> writers,
    bool notifyAndBlockOnCompletion,
    std::shared_ptr<InflightDataRescalingDescriptor> channelMapping)
    // writers and channelMapping are taken by value and moved into the members.
    : writers_(std::move(writers)),
      notifyAndBlockOnCompletion_(notifyAndBlockOnCompletion),
      channelMapping_(std::move(channelMapping)) {}

/**
 * Runs close() when the handler is destroyed.
 *
 * close() calls into the partition writers, which are not known to be non-throwing from
 * this file. A destructor is implicitly noexcept, so an exception escaping from here would
 * call std::terminate; failures are therefore logged and suppressed instead.
 */
ResultSubpartitionRecoveredStateHandler::~ResultSubpartitionRecoveredStateHandler()
{
    try {
        this->close();
    } catch (const std::exception& e) {
        INFO_RELEASE("ResultSubpartitionRecoveredStateHandler destructor: close exception:" << e.what());
    } catch (...) {
        INFO_RELEASE("ResultSubpartitionRecoveredStateHandler destructor: close unknown exception");
    }
}

/**
 * Obtains a buffer builder into which recovered state for the given subpartition can be read.
 *
 * The builder is requested from the first channel mapped to subpartitionInfo via
 * requestBufferBuilderBlocking().
 *
 * @param subpartitionInfo identifies the subpartition whose state is being restored
 * @return a BufferWithContext wrapping the builder, with the builder itself as context
 * @throws std::runtime_error if no channel is mapped to subpartitionInfo
 */
ResultSubpartitionRecoveredStateHandler::BufferWithContext ResultSubpartitionRecoveredStateHandler::getBuffer(const ResultSubpartitionInfoPOD& subpartitionInfo)
{
    const auto mappedChannels = getMappedChannels(subpartitionInfo);
    if (mappedChannels.empty()) {
        throw std::runtime_error("No mapped channels found");
    }

    // Only the first mapped channel is asked for a builder.
    BufferBuilder *builder = mappedChannels.front()->requestBufferBuilderBlocking();
    return BufferWithContext(ChannelStateByteBuffer::wrap(builder), builder);
}

/**
 * Hands a filled recovered-state buffer to every channel mapped to the given subpartition.
 *
 * The builder stored in bufferWithContext is finished and a consumer covering it from the
 * beginning is created. If the consumer has no data the call returns without touching any
 * channel. Otherwise each mapped channel first receives a serialized
 * SubtaskConnectionDescriptor(subPartitionIdx, oldSubtaskIndex) and then the data consumer.
 *
 * A failure while feeding the channels is logged and rethrown, so that a partially applied
 * recovery is never reported as a success (this matches
 * InputChannelRecoveredStateHandler::recover, which also propagates failures).
 *
 * @param subpartitionInfo  subpartition the recovered buffer belongs to
 * @param oldSubtaskIndex   subtask index recorded with the recovered state
 * @param bufferWithContext buffer previously obtained from getBuffer(); its context_ is the builder
 * @throws std::runtime_error if the builder is null or no channel is mapped
 * @throws std::exception     whatever the channels throw while receiving the buffers
 */
void ResultSubpartitionRecoveredStateHandler::recover(const ResultSubpartitionInfoPOD& subpartitionInfo, int oldSubtaskIndex, const BufferWithContext& bufferWithContext)
{
    BufferBuilder *bufferBuilder = bufferWithContext.context_;
    if (bufferBuilder == nullptr) {
        throw std::runtime_error("Null buffer builder in ResultSubpartitionRecoveredStateHandler::recover");
    }

    auto bufferConsumer = bufferBuilder->createBufferConsumerFromBeginning();
    bufferBuilder->finish();

    if (!bufferConsumer->isDataAvailable()) {
        return;
    }

    auto channels = getMappedChannels(subpartitionInfo);
    if (channels.empty()) {
        throw std::runtime_error("No mapped channels found in recover()");
    }

    try {
        for (const auto &item : channels) {
            auto channelSelector = std::make_shared<SubtaskConnectionDescriptor>(
                subpartitionInfo.getSubPartitionIdx(), oldSubtaskIndex);
            INFO_RELEASE("send recover buffer :" << item->getSubpartitionInfo().toString());
            item->addRecovered(EventSerializer::ToBufferConsumer(channelSelector, false));
            // NOTE(review): the same consumer object is handed to every mapped channel;
            // confirm that sharing it is safe when more than one channel is mapped.
            item->addRecovered(bufferConsumer);
        }
    } catch (const std::exception& e) {
        INFO_RELEASE("ResultSubpartitionRecoveredStateHandler::recover exception:" << e.what());
        // Propagate instead of swallowing: the caller must learn that state was not restored.
        throw;
    }

    INFO_RELEASE("Recover state for partition" << ", subpartition " << subpartitionInfo.getSubPartitionIdx()
                                         << ", size " << bufferConsumer->getBufferSize()
                                         << ", mappedChannels=" << channels.size());
}

/**
 * Signals the end of state recovery to every writer that is a CheckpointedResultPartition.
 *
 * Writers of any other type are skipped. notifyAndBlockOnCompletion_ is forwarded unchanged
 * to finishReadRecoveredState().
 */
void ResultSubpartitionRecoveredStateHandler::close()
{
    for (const auto& partitionWriter : writers_) {
        const auto checkpointed = std::dynamic_pointer_cast<CheckpointedResultPartition>(partitionWriter);
        if (!checkpointed) {
            continue;
        }
        checkpointed->finishReadRecoveredState(notifyAndBlockOnCompletion_);
    }
    LOG("Close ResultSubpartitionRecoveredStateHandler, finishReadRecoveredState writers size:" << writers_.size());
}

std::shared_ptr<CheckpointedResultSubpartition> ResultSubpartitionRecoveredStateHandler::getSubpartition(
        int partitionIndex,
        int subPartitionIdx)
{
    LOG("ResultSubpartitionRecoveredStateHandler getSubpartition111")
    auto writer = writers_.at(partitionIndex);

    auto checkpointedWriter = std::dynamic_pointer_cast<omnistream::CheckpointedResultPartition>(writer);
    if (!checkpointedWriter) {
        LOG("ResultSubpartitionRecoveredStateHandler getSubpartition222")
        throw std::runtime_error(
                "Cannot restore state to a non-checkpointable partition type, partitionIndex=" +
                std::to_string(partitionIndex));
    }
    LOG("ResultSubpartitionRecoveredStateHandler getSubpartition333")
    auto checkpointedSubpartition = checkpointedWriter->getCheckpointedSubpartition(subPartitionIdx);
    if (!checkpointedSubpartition) {
        LOG("ResultSubpartitionRecoveredStateHandler getSubpartition555")
        throw std::runtime_error(
                "Checkpointed subpartition is not a PipelinedSubpartition, partitionIndex=" +
                std::to_string(partitionIndex) + ", subPartitionIdx=" + std::to_string(subPartitionIdx));
    }
    LOG("ResultSubpartitionRecoveredStateHandler getSubpartition666")
    return checkpointedSubpartition;
}

std::vector<std::shared_ptr<CheckpointedResultSubpartition>> ResultSubpartitionRecoveredStateHandler::getMappedChannels(const ResultSubpartitionInfoPOD& subpartitionInfo)
{
    LOG("ResultSubpartitionRecoveredStateHandler getMappedChannels111")
    auto it = rescaledChannels_.find(subpartitionInfo);
    if (it != rescaledChannels_.end()) {
        return it->second;
    }
    LOG("getMappedChannels add InfoPOD: " << subpartitionInfo.toString());
    auto pipelinedSubpartitions = calculateMapping(subpartitionInfo);
    rescaledChannels_.emplace(subpartitionInfo, pipelinedSubpartitions);
    return pipelinedSubpartitions;
}

std::vector<std::shared_ptr<CheckpointedResultSubpartition>>
    ResultSubpartitionRecoveredStateHandler::calculateMapping(const ResultSubpartitionInfoPOD& info)
{
    int pIdx = info.getPartitionIdx();

    auto mapping = channelMapping_ ? channelMapping_->GetChannelMapping(pIdx)
                                   : IdentityRescaleMappings::SYMMETRIC_IDENTITY;

    // 纯恢复 / 未 rescale:直接按 identity 映射,不要去 invert SYMMETRIC_IDENTITY
    if (!mapping || mapping->isIdentity()) {
        return { getSubpartition(pIdx, info.getSubPartitionIdx()) };
    }

    if (oldToNewMappings_.find(pIdx) == oldToNewMappings_.end()) {
        oldToNewMappings_.emplace(pIdx, mapping->invert());
    }

    const auto& oldToNewMapping = oldToNewMappings_.at(pIdx);
//        std::vector<std::shared_ptr<PipelinedSubpartition>> subpartitions;
    std::vector<std::shared_ptr<CheckpointedResultSubpartition>> subpartitions;

    auto mappedIndexes = oldToNewMapping.getMappedIndexes(info.getSubPartitionIdx());
    for (int newIndex : mappedIndexes) {
        subpartitions.push_back(getSubpartition(pIdx, newIndex));
    }

    if (subpartitions.empty()) {
        LOG("ERROR: Recovered a buffer that has no mapping, partitionIdx=" << std::to_string(info.getPartitionIdx())
            << ", subPartitionIdx=" << std::to_string(info.getSubPartitionIdx()));
        throw std::runtime_error(
                "Recovered a buffer that has no mapping, partitionIdx=" +
                std::to_string(info.getPartitionIdx()) +
                ", subPartitionIdx=" + std::to_string(info.getSubPartitionIdx()));
    }
    return subpartitions;
}

/**
 * Runs close() when the handler is destroyed.
 *
 * close() calls into the input gates, which are not known to be non-throwing from this
 * file. A destructor is implicitly noexcept, so an exception escaping from here would call
 * std::terminate; failures are therefore logged and suppressed instead.
 */
InputChannelRecoveredStateHandler::~InputChannelRecoveredStateHandler()
{
    try {
        this->close();
    } catch (const std::exception& e) {
        INFO_RELEASE("InputChannelRecoveredStateHandler destructor: close exception:" << e.what());
    } catch (...) {
        INFO_RELEASE("InputChannelRecoveredStateHandler destructor: close unknown exception");
    }
}

/**
 * Obtains a buffer into which recovered state for the given input channel can be read.
 *
 * The buffer is requested from the first channel mapped to inputChannelInfo via
 * requestBufferBlocking().
 *
 * @param inputChannelInfo identifies the input channel whose state is being restored
 * @return a BufferWithContext wrapping the buffer, with the raw buffer pointer as context
 * @throws std::runtime_error if no channel is mapped to inputChannelInfo
 */
RecoveredChannelStateHandler<InputChannelInfo, Buffer *>::BufferWithContext
    InputChannelRecoveredStateHandler::getBuffer(const InputChannelInfo &inputChannelInfo)
{
    auto channels = getMappedChannels(inputChannelInfo);
    // Guard the first-element access, as ResultSubpartitionRecoveredStateHandler::getBuffer does.
    if (channels.empty()) {
        throw std::runtime_error("No mapped channels found in InputChannelRecoveredStateHandler::getBuffer");
    }

    auto buffer = channels.front()->requestBufferBlocking();
    // NOTE(review): only the raw address of *buffer is returned. If requestBufferBlocking()
    // returns an owning smart pointer, confirm something else keeps the buffer alive after
    // this function returns.
    return BufferWithContext(ChannelStateByteBuffer::wrap(&*buffer), &*buffer);
}

/**
 * Hands a filled recovered-state buffer to every input channel mapped to inputChannelInfo.
 *
 * When the buffer is non-empty, each mapped channel first receives a serialized
 * SubtaskConnectionDescriptor(oldSubtaskIndex, inputChannelIdx) and then the buffer itself.
 * On failure the buffer is recycled and a std::runtime_error carrying the original message
 * is thrown.
 *
 * NOTE(review): a buffer whose size is 0 is neither forwarded nor recycled here — confirm
 * that the caller releases it, otherwise it is leaked.
 *
 * @param inputChannelInfo  channel the recovered buffer belongs to
 * @param oldSubtaskIndex   subtask index recorded with the recovered state
 * @param bufferWithContext buffer previously obtained from getBuffer(); its context_ is the buffer
 * @throws std::runtime_error if the buffer is null, no channel is mapped, or a channel fails
 */
void InputChannelRecoveredStateHandler::recover(const InputChannelInfo &inputChannelInfo,
    int oldSubtaskIndex,
    const BufferWithContext &bufferWithContext)
{
    auto buffer = bufferWithContext.context_;
    if (buffer == nullptr) {
        throw std::runtime_error("Null buffer in InputChannelRecoveredStateHandler::recover");
    }

    try {
        if (buffer->GetSize() > 0) {
            auto channels = getMappedChannels(inputChannelInfo);
            if (channels.empty()) {
                throw std::runtime_error("No mapped channels found in InputChannelRecoveredStateHandler::recover");
            }

            for (const auto &item : channels) {
                INFO_RELEASE("send input recover:" << item->getChannelInfo().toString());
                item->onRecoveredStateBuffer(EventSerializer::toBuffer(
                    std::make_shared<SubtaskConnectionDescriptor>(
                        oldSubtaskIndex, inputChannelInfo.getInputChannelIdx()),
                    false));
                item->onRecoveredStateBuffer2(buffer);
            }

            INFO_RELEASE("Recovered state for gate " << inputChannelInfo.getGateIdx()
                                            << ", channel " << inputChannelInfo.getInputChannelIdx()
                                            << ", size " << buffer->GetSize()
                                            << ", mappedChannels=" << channels.size());
        }
    } catch (const std::exception& e) {
        // Log first so the cause is recorded even if recycling itself fails.
        INFO_RELEASE("InputChannelRecoveredStateHandler::recover exception:" << e.what());
        buffer->RecycleBuffer();
        // Carry the original cause in the rethrown error instead of dropping it.
        throw std::runtime_error(
            std::string("failed to InputChannelRecoveredStateHandler recover:") + e.what());
    }
}

/**
 * Calls FinishReadRecoveredState() on every input gate, then logs how many gates there are.
 */
void InputChannelRecoveredStateHandler::close()
{
    for (auto gateIt = inputGates.begin(); gateIt != inputGates.end(); ++gateIt) {
        (*gateIt)->FinishReadRecoveredState();
    }
    LOG("Close InputChannelRecoveredStateHandler, finishReadRecoveredState inputGate size:" << inputGates.size());
}

/**
 * Resolves the recovered input channel at (gateIndex, subPartitionIndex).
 *
 * @param gateIndex         index into inputGates (looked up with at())
 * @param subPartitionIndex channel index passed to the gate's getChannel()
 * @return the channel as a RecoveredInputChannel; never null
 * @throws std::runtime_error if the channel is not a RecoveredInputChannel
 */
std::shared_ptr<RecoveredInputChannel> InputChannelRecoveredStateHandler::getChannel(int gateIndex,
    int subPartitionIndex)
{
    auto recovered = std::dynamic_pointer_cast<RecoveredInputChannel>(
        inputGates.at(gateIndex)->getChannel(subPartitionIndex));
    if (recovered) {
        return recovered;
    }

    INFO_RELEASE("ERROR: Cannot restore state to a non-checkpointable partition type");
    throw std::runtime_error("Cannot restore state to a non-checkpointable partition type");
}

/**
 * Computes which current input channels receive the data recovered for `info`.
 *
 * With no mapping, or an identity mapping, the result is the single channel with the same
 * gate and channel indexes as `info`. Otherwise the gate's mapping is inverted once (cached
 * in oldToNewMappings) and every index mapped from the recorded channel index is resolved
 * through getChannel().
 *
 * @param info gate / channel indexes recorded with the recovered state
 * @return the non-empty list of target channels
 * @throws std::runtime_error if the inverted mapping yields no target
 */
std::vector<std::shared_ptr<RecoveredInputChannel>>
InputChannelRecoveredStateHandler::calculateMapping(InputChannelInfo info)
{
    LOG("InputChannelRecoveredStateHandler calculateMapping111")
    const int gateIdx = info.getGateIdx();
    const auto oldChannelIdx = info.getInputChannelIdx();

    auto mapping = channelMapping ? channelMapping->GetChannelMapping(gateIdx)
                                  : IdentityRescaleMappings::SYMMETRIC_IDENTITY;

    // No rescaling: the recorded channel maps onto itself.
    const bool identity = !mapping || mapping->isIdentity();
    if (identity) {
        return { getChannel(gateIdx, oldChannelIdx) };
    }

    // Invert the mapping once per gate and reuse it afterwards.
    if (oldToNewMappings.count(gateIdx) == 0) {
        oldToNewMappings.emplace(gateIdx, mapping->invert());
    }
    const auto& inverse = oldToNewMappings.at(gateIdx);

    std::vector<std::shared_ptr<RecoveredInputChannel>> targets;
    for (int newIdx : inverse.getMappedIndexes(oldChannelIdx)) {
        targets.push_back(getChannel(gateIdx, newIdx));
    }

    if (targets.empty()) {
        throw std::runtime_error("Recovered a buffer that has no mapping");
    }
    LOG("InputChannelRecoveredStateHandler calculateMapping end")
    return targets;
}

/**
 * Returns the input channels that recovered data for channelInfo must be delivered to.
 *
 * Results are cached in rescaledChannels; calculateMapping() runs only the first time a
 * given channelInfo is seen.
 *
 * @param channelInfo input channel recorded with the recovered state
 * @return a copy of the mapped channel list
 */
std::vector<std::shared_ptr<RecoveredInputChannel>> InputChannelRecoveredStateHandler::getMappedChannels(
    InputChannelInfo channelInfo)
{
    const auto cached = rescaledChannels.find(channelInfo);
    if (cached == rescaledChannels.end()) {
        // First request for this channel: compute the mapping and remember it.
        LOG("getMappedChannels add ChannelInfo: " << channelInfo.toString());
        auto mapped = calculateMapping(channelInfo);
        rescaledChannels.emplace(channelInfo, mapped);
        return mapped;
    }
    return cached->second;
}
} // namespace omnistream