/**

 * Copyright (c) 2025 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */



/*!

 * \file process_atomic_checker.cpp

 * \brief Checker for ProcessAtomic pass

 */



#include "process_atomic_checker.h"

#include "interface/operation/attribute.h"

#include "interface/operation/operation.h"

#include "tilefwk/tilefwk_op.h"

#include "passes/pass_log/pass_log.h"



#define MODULE_NAME "ProcessAtomicChecker"



namespace npu {

namespace tile_fwk {



Status ProcessAtomicChecker::DoPreCheck(Function& function)

{

    APASS_LOG_INFO_F(Elements::Function, "PreCheck for ProcessAtomic.");

    if (CheckGraphLoop(function) != SUCCESS) {

        APASS_LOG_ERROR_F(Elements::Function, "Loopcheck failed; Please check if there is a Loop.");

        return FAILED;

    }

    if (CheckCompleteness(function) != SUCCESS) {

        APASS_LOG_ERROR_F(Elements::Function, "CheckCompleteness for function[%d] failed!", function.GetFuncMagic());

        return FAILED;

    }

    for (const auto& op : function.Operations()) {

        if (ProcessPreCheck(op) != SUCCESS) {

            APASS_LOG_ERROR_F(

                Elements::Operation, "PreCheck for ProcessAtomic failed.%s", GetFormatBacktrace(op).c_str());

            return FAILED;

        }

    }

    APASS_LOG_INFO_F(Elements::Operation, "PreCheck for ProcessAtomic success.");

    return SUCCESS;

}



Status ProcessAtomicChecker::DoPostCheck(Function& function)

{

    APASS_LOG_INFO_F(Elements::Function, "PostCheck for ProcessAtomic.");

    if (CheckCompleteness(function) != SUCCESS) {

        APASS_LOG_ERROR_F(Elements::Function, "CheckCompleteness for function[%d] failed!", function.GetFuncMagic());

        return FAILED;

    }

    if (CheckNoReduceAcc(function) != SUCCESS) {

        return FAILED;

    }

    if (CheckNoAtomicRMW(function) != SUCCESS) {

        return FAILED;

    }

    APASS_LOG_INFO_F(Elements::Operation, "PostCheck for ProcessAtomic success.");

    return SUCCESS;

}



Status ProcessAtomicChecker::ProcessPreCheck(const Operation& op)

{

    if (op.GetOpcode() == Opcode::OP_A_MUL_B || op.GetOpcode() == Opcode::OP_A_MULACC_B) {

        if (CheckMulOpValidity(op) != SUCCESS) {

            return FAILED;

        }

    }

    if (op.GetOpcode() == Opcode::OP_REDUCE_ACC) {

        if (CheckReduceAccOpValidity(op) != SUCCESS) {

            return FAILED;

        }

    }

    if (op.GetOpcode() == Opcode::OP_ATOMIC_RMW) {

        if (ValidateAtomicRMW(op) != SUCCESS) {

            APASS_LOG_ERROR_F(

                Elements::Operation, "Op[%d] validation failed; Please check the atomic rmw operation validity.%s",

                op.GetOpMagic(), GetFormatBacktrace(op).c_str());

            return FAILED;

        }

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckMulOpValidity(const Operation& op)

{

    if (op.GetOOperands().size() != 1) {

        APASS_LOG_ERROR_F(

            Elements::Operation,

            "Invalid op: [%d] has output num not equal to one; Please check if the output num is one.%s",

            op.GetOpMagic(), GetFormatBacktrace(op).c_str());

        return FAILED;

    }

    auto output = op.GetOOperands().front();

    if ((output->GetMemoryTypeOriginal() != MemoryType::MEM_L0C) || (*output->GetConsumers().begin() == nullptr)) {

        APASS_LOG_ERROR_F(

            Elements::Operation, "Op[%d] has invalid output tensor[%d]; Please check if the output tensor is vaild.%s",

            op.GetOpMagic(), output->magic, GetFormatBacktrace(op).c_str());

        return FAILED;

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckReduceAccOpValidity(const Operation& op)

{

    if (op.GetIOperands().size() < 1) {

        APASS_LOG_ERROR_F(

            Elements::Operation, "Op[%d] has input num less than 1; Please check the input num.%s", op.GetOpMagic(),

            GetFormatBacktrace(op).c_str());

        return FAILED;

    }

    if (op.GetOOperands().size() != 1) {

        APASS_LOG_ERROR_F(

            Elements::Operation, "Op[%d] has output num not equal to one; Please check if the output num for is one.%s",

            op.GetOpMagic(), GetFormatBacktrace(op).c_str());

        return FAILED;

    }

    for (const auto& in : op.GetIOperands()) {

        if (in->GetMemoryTypeOriginal() != MemoryType::MEM_DEVICE_DDR) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] has non-DDR input tensor[%d]; Please check the memory type of the input tensor.%s",

                op.GetOpMagic(), in->magic, GetFormatBacktrace(op).c_str());

            return FAILED;

        }

    }

    for (const auto& out : op.GetOOperands()) {

        if (out->GetMemoryTypeOriginal() != MemoryType::MEM_DEVICE_DDR) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] has non-DDR output tensor[%d]; Please check the memory type of the output tensor.%s",

                op.GetOpMagic(), out->magic, GetFormatBacktrace(op).c_str());

            return FAILED;

        }

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::ValidateAtomicRMW(const Operation& op)

{

    if (op.GetIOperands().size() < 1) {

        APASS_LOG_ERROR_F(

            Elements::Operation, "Op[%d] has input producers num less than 1; Please check the input num.",

            op.GetOpMagic());

        return FAILED;

    }

    if (CheckAtomicRMWMemoryType(op) != SUCCESS) {

        return FAILED;

    }

    if (CheckAtomicRMWShape(op) != SUCCESS) {

        return FAILED;

    }

    if (CheckAtomicRMWOffset(op) != SUCCESS) {

        return FAILED;

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckAtomicRMWMemoryType(const Operation& op)

{

    for (const auto& in : op.GetIOperands()) {

        if (in->GetMemoryTypeOriginal() != MemoryType::MEM_DEVICE_DDR) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] has non-DDR input tensor[%d]; Please check the memory type of the input tensor.",

                op.GetOpMagic(), in->magic);

            return FAILED;

        }

    }

    for (const auto& out : op.GetOOperands()) {

        if (out->GetMemoryTypeOriginal() != MemoryType::MEM_DEVICE_DDR) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] has non-DDR output tensor[%d]; Please check the memory type of the output tensor.",

                op.GetOpMagic(), out->magic);

            return FAILED;

        }

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckAtomicRMWShape(const Operation& op)

{

    auto& inputShape = op.GetIOperands().front()->GetShape();

    auto& outputShape = op.GetOOperands().front()->GetShape();

    if (outputShape.size() < inputShape.size()) {

        APASS_LOG_ERROR_F(

            Elements::Operation, "Op[%d] output shape size less than input shape size; Please check shape validity.",

            op.GetOpMagic());

        return FAILED;

    }

    for (size_t i = 0; i < inputShape.size(); ++i) {

        if (outputShape[i] < inputShape[i]) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] output shape[%zu]=%ld less than input shape[%zu]=%ld; Please check shape validity.",

                op.GetOpMagic(), i, outputShape[i], i, inputShape[i]);

            return FAILED;

        }

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckAtomicRMWOffset(const Operation& op)

{

    auto assembleAttr = std::dynamic_pointer_cast<AssembleOpAttribute>(op.GetOpAttribute());

    if (assembleAttr == nullptr) {

        APASS_LOG_ERROR_F(

            Elements::Operation, "Op[%d] missing AssembleOpAttribute; Please check if offset attribute is set.",

            op.GetOpMagic());

        return FAILED;

    }

    auto& toOffset = assembleAttr->GetToOffset();

    for (size_t i = 0; i < toOffset.size(); ++i) {

        if (toOffset[i] < 0) {

            APASS_LOG_ERROR_F(

                Elements::Operation, "Op[%d] offset[%zu]=%ld is negative; Please check offset validity.",

                op.GetOpMagic(), i, toOffset[i]);

            return FAILED;

        }

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckNoReduceAcc(Function& function)

{

    for (const auto& op : function.Operations()) {

        if (op.GetOpcode() == Opcode::OP_REDUCE_ACC) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] OP_REDUCE_ACC still exists after ProcessAtomic pass; "

                "Please check if the ReduceAcc was properly eliminated.%s",

                op.GetOpMagic(), GetFormatBacktrace(op).c_str());

            return FAILED;

        }

    }

    return SUCCESS;

}



Status ProcessAtomicChecker::CheckNoAtomicRMW(Function& function)

{

    for (const auto& op : function.Operations()) {

        if (op.GetOpcode() == Opcode::OP_ATOMIC_RMW) {

            APASS_LOG_ERROR_F(

                Elements::Operation,

                "Op[%d] OP_ATOMIC_RMW still exists after ProcessAtomic pass; "

                "Please check if the AtomicRMW was properly eliminated.%s",

                op.GetOpMagic(), GetFormatBacktrace(op).c_str());

            return FAILED;

        }

    }

    return SUCCESS;

}



} // namespace tile_fwk

} // namespace npu