/**
 * Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file infer_discontinuous_input_checker.cpp
 * \brief
 */

#include "infer_discontinuous_input_checker.h"
#include <queue>
#include <set>
#include "passes/pass_log/pass_log.h"
#include "passes/pass_utils/graph_utils.h"
#include "tilefwk/error_code.h"

#define MODULE_NAME "InferDiscontinuousInputChecker"

namespace npu {
namespace tile_fwk {
/**
 * \brief Opcodes consulted by checkTensor when classifying the producers of a tensor.
 *
 * checkTensor ignores producers whose opcode is not listed here. A listed producer other than
 * OP_ASSEMBLE makes checkTensor skip the assemble coverage check for that tensor.
 */
std::unordered_set<Opcode> inplaceNodes = {
    Opcode::OP_VIEW,
    Opcode::OP_ASSEMBLE,
    Opcode::OP_RESHAPE,
    Opcode::OP_INDEX_OUTCAST,
};

/**
 * \brief Validates the inputs gathered from the ASSEMBLE producers of one tensor.
 *
 * For every assembled input:
 *   - its element count (product of its shape) is subtracted from the remaining size of its raw tensor;
 *   - its per-dimension shift (assemble to-offset minus the input's own offset, over the raw tensor's
 *     rank) must equal the shift of every other input that shares the same raw tensor.
 * Afterwards every entry of \p rawTensorSize must have reached exactly zero.
 *
 * \param tensorMap     assembled input tensor -> raw magic of its raw tensor
 * \param offsetMap     assembled input tensor -> (input offset, assemble to-offset); must contain every
 *                      key of \p tensorMap (looked up with at())
 * \param rawTensorSize raw magic -> raw tensor element count; decremented in place
 * \return SUCCESS if all checks pass, FAILED otherwise (an error is logged)
 */
Status checkAssemble(
    const std::unordered_map<LogicalTensorPtr, int64_t>& tensorMap,
    const std::unordered_map<LogicalTensorPtr, std::pair<Offset, Offset>>& offsetMap,
    std::unordered_map<int64_t, int64_t>& rawTensorSize)
{
    // Keyed by int64_t to match the raw magic type; an int key would silently narrow it.
    std::unordered_map<int64_t, Offset> rawMagicToRawOffset;
    for (const auto& [logicTensor, rawMagic] : tensorMap) {
        const auto& shape = logicTensor->GetShape();
        // The int64_t seed makes std::accumulate multiply in int64_t; an int seed would truncate
        // (and overflow) the element count of large tensors.
        const int64_t shapeSize = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
        rawTensorSize[rawMagic] -= shapeSize;

        const size_t rawRank = logicTensor->GetRawTensor()->GetRawShape().size();
        const auto& [fromOffset, toOffset] = offsetMap.at(logicTensor);
        // Both offsets are indexed up to the raw rank below; reject shorter ones instead of reading
        // out of bounds.
        if (fromOffset.size() < rawRank || toOffset.size() < rawRank) {
            APASS_LOG_ERROR_C(
                TensorErr::TENSOR_SHAPE_MISMATCH, Elements::Tensor,
                "LogicTensor(%d) offset rank is smaller than the rank of rawTensor(%ld).",
                logicTensor->GetMagic(), static_cast<long>(rawMagic));
            return FAILED;
        }
        Offset rawOffset(rawRank, 0);
        for (size_t dim = 0; dim < rawRank; dim++) {
            rawOffset[dim] = toOffset[dim] - fromOffset[dim];
        }

        // The first input of a raw tensor fixes the expected shift; later ones must agree with it.
        const auto known = rawMagicToRawOffset.find(rawMagic);
        if (known == rawMagicToRawOffset.end()) {
            rawMagicToRawOffset.emplace(rawMagic, rawOffset);
        } else if (known->second != rawOffset) {
            APASS_LOG_ERROR_C(
                TensorErr::TENSOR_SHAPE_MISMATCH, Elements::Operation,
                "LogicTensor(%d) relative position to rawTensor(%ld) changed after the assemble op.",
                logicTensor->GetMagic(), static_cast<long>(rawMagic));
            return FAILED;
        }
    }
    // Any non-zero remainder means the assembled inputs do not account for exactly the raw tensor size.
    for (const auto& [rawMagic, remaining] : rawTensorSize) {
        if (remaining != 0) {
            APASS_LOG_ERROR_C(
                TensorErr::TENSOR_SHAPE_MISMATCH, Elements::Tensor, "RawTensor(%ld) is not fully covered.",
                static_cast<long>(rawMagic));
            return FAILED;
        }
    }
    return SUCCESS;
}

/**
 * \brief Rejects input operands that come straight from a single VIEW and are originally in device DDR.
 *
 * An input operand of \p op is inspected only if all of the following hold:
 *   - it has exactly one producer;
 *   - that producer is an OP_VIEW;
 *   - none of its shape dimensions is negative.
 * For such an operand, an original memory type of MEM_DEVICE_DDR is reported as an error.
 *
 * \param op operation whose input operands are inspected (checkTensor passes ASSEMBLE producers)
 * \return FAILED on the first offending operand, SUCCESS otherwise
 */
Status checkView(Operation* op)
{
    const auto comesFromSingleView = [](const auto& operand) {
        const auto& producerSet = operand->GetProducers();
        return producerSet.size() == 1 && (*producerSet.begin())->GetOpcode() == Opcode::OP_VIEW;
    };
    const auto hasNegativeDim = [](const auto& operand) {
        const auto& dims = operand->GetShape();
        return std::any_of(dims.begin(), dims.end(), [](int64_t extent) { return extent < 0; });
    };

    for (const auto& operand : op->GetIOperands()) {
        const bool isCandidate = comesFromSingleView(operand) && !hasNegativeDim(operand);
        if (!isCandidate || operand->GetMemoryTypeOriginal() != MemoryType::MEM_DEVICE_DDR) {
            continue;
        }
        APASS_LOG_ERROR_C(
            TensorErr::TENSOR_INVALID_MEMORY_TYPE, Elements::Tensor,
            "Tensor(%d) memory type is MEM_DEVICE_DDR, which is not supported for VIEW->ASSEMBLE case.",
            operand->GetMagic());
        return FAILED;
    }
    return SUCCESS;
}

/**
 * \brief Runs the discontinuous-input checks for the producers of one tensor.
 *
 * Only producers whose opcode is in inplaceNodes are considered. Each ASSEMBLE producer must:
 *   - pass checkView;
 *   - carry an AssembleOpAttribute;
 *   - have at least one input operand.
 * Its first input operand is recorded (raw size, raw magic, input offset and to-offset) for
 * checkAssemble. If any considered producer is not an ASSEMBLE, checkAssemble is skipped and SUCCESS
 * is returned; the per-producer checks above still apply to every ASSEMBLE producer.
 *
 * \param tensor tensor whose producers are inspected
 * \return SUCCESS if all applicable checks pass, FAILED otherwise (an error is logged)
 */
Status checkTensor(const LogicalTensorPtr& tensor)
{
    std::unordered_map<int64_t, int64_t> rawTensorSize;
    std::unordered_map<LogicalTensorPtr, int64_t> tensorMap;
    std::unordered_map<LogicalTensorPtr, std::pair<Offset, Offset>> offsetMap;
    bool allAssemble = true;
    for (auto producer : tensor->GetProducers()) {
        const auto opcode = producer->GetOpcode();
        if (inplaceNodes.find(opcode) == inplaceNodes.end()) {
            continue;
        }
        if (opcode != Opcode::OP_ASSEMBLE) {
            allAssemble = false;
            continue;
        }
        if (checkView(producer) != SUCCESS) {
            APASS_LOG_ERROR_F(Elements::Function, "CheckView Failed.");
            return FAILED;
        }
        const auto attr = std::dynamic_pointer_cast<AssembleOpAttribute>(producer->GetOpAttribute());
        if (attr == nullptr) {
            APASS_LOG_ERROR_C(
                OperationErr::OP_NULL_POINTER, Elements::Operation, "Assemble op %d do not have attribute. %s",
                producer->GetOpMagic(), GetFormatBacktrace(producer).c_str());
            return FAILED;
        }
        const auto& inputs = producer->GetIOperands();
        // Dereferencing begin() of an empty operand list is undefined behaviour; report it instead.
        if (inputs.empty()) {
            APASS_LOG_ERROR_F(
                Elements::Operation, "Assemble op %d has no input operand. %s", producer->GetOpMagic(),
                GetFormatBacktrace(producer).c_str());
            return FAILED;
        }
        const LogicalTensorPtr& inputTensor = *inputs.begin();
        // checkAssemble uses the tensorMap value as the key into rawTensorSize, so both must come
        // from the same raw tensor handle.
        const auto& rawTensor = inputTensor->tensor;
        const int64_t rawMagic = rawTensor->GetRawMagic();
        rawTensorSize[rawMagic] = rawTensor->GetRawShapeSize();
        tensorMap[inputTensor] = rawMagic;
        offsetMap[inputTensor] = std::make_pair(inputTensor->GetOffset(), attr->GetToOffset());
    }
    if (!allAssemble) {
        return SUCCESS;
    }
    if (checkAssemble(tensorMap, offsetMap, rawTensorSize) != SUCCESS) {
        APASS_LOG_ERROR_F(Elements::Function, "CheckAssemble Failed.");
        return FAILED;
    }

    return SUCCESS;
}

/**
 * \brief Post-pass check for discontinuous inputs.
 *
 * Fails immediately if CheckGraphLoop does not return SUCCESS. Otherwise runs checkTensor on every
 * tensor returned by GraphUtils::GetAllTensors and stops at the first tensor that fails.
 *
 * \param function function whose graph is checked
 * \return SUCCESS if every check passes, FAILED otherwise (an error is logged)
 */
Status InferDisContinuousInputChecker::DoPostCheck(Function& function)
{
    APASS_LOG_INFO_F(Elements::Function, "PostCheck for DisContinuousInput.");
    const bool loopCheckFailed = CheckGraphLoop(function) != SUCCESS;
    if (loopCheckFailed) {
        APASS_LOG_ERROR_F(Elements::Function, "Find loop");
        return FAILED;
    }

    auto allTensors = GraphUtils::GetAllTensors(function);
    for (const auto& candidate : allTensors) {
        if (checkTensor(candidate) == SUCCESS) {
            continue;
        }
        APASS_LOG_ERROR_F(Elements::Tensor, "Tensor(%d) CheckTensor Failed.", candidate->GetMagic());
        return FAILED;
    }
    return SUCCESS;
}
} // namespace tile_fwk
} // namespace npu