/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include <sstream>
#include <mki/operation.h>
#include <asdops/ops.h>
#include <atbops/ops.h>
#include "atb/utils/tensor_util.h"
#include "atb/runner/mki_node_implement.h"
#include "atb/runner/kernel_graph.h"

namespace atb {
std::string KernelGraph::ToString() const
{
    std::stringstream ss;
    for (size_t i = 0; i < inTensors.size(); ++i) {
        ss << "inTensors[" << i << "]: " << TensorUtil::AsdOpsTensorToString(inTensors[i]) << std::endl;
    }
    for (size_t i = 0; i < outTensors.size(); ++i) {
        ss << "outTensors[" << i << "]: " << TensorUtil::AsdOpsTensorToString(outTensors[i]) << std::endl;
    }
    for (size_t i = 0; i < internalTensors.size(); ++i) {
        ss << "internalTensors[" << i << "]: " << TensorUtil::AsdOpsTensorToString(internalTensors[i]) << std::endl;
    }
    for (size_t i = 0; i < nodes.size(); ++i) {
        for (size_t j = 0; j < nodes[i].inTensors.size(); ++j) {
            ss << "node[" << i << "] inTensors[" << j
               << "]: " << TensorUtil::AsdOpsTensorToString(*nodes[i].inTensors[j]) << std::endl;
        }
        for (size_t j = 0; j < nodes[i].outTensors.size(); ++j) {
            ss << "node[" << i << "] outTensors[" << j
               << "]: " << TensorUtil::AsdOpsTensorToString(*nodes[i].outTensors[j]) << std::endl;
        }
    }
    return ss.str();
}

/**
 * @brief Prepare every node of the graph for execution.
 *
 * For each node this:
 *  - resets the node's implementation (via KernelGraphNode::Reset);
 *  - sizes inTensorsType / outTensorsType to match inTensors / outTensors;
 *  - tags each tensor slot by ownership:
 *      - a tensor that is one of this graph's internalTensors becomes INTERMEDIATE_TENSOR;
 *      - any other input becomes IN_TENSOR;
 *      - any other output becomes OUT_TENSOR.
 */
void KernelGraph::Init()
{
    for (auto &node : nodes) {
        node.Reset();

        // reserve() is kept before resize() deliberately: the type vectors may be a
        // project small-vector type where reserving first matters.
        const size_t inCount = node.inTensors.size();
        node.inTensorsType.reserve(inCount);
        node.inTensorsType.resize(inCount);
        const size_t outCount = node.outTensors.size();
        node.outTensorsType.reserve(outCount);
        node.outTensorsType.resize(outCount);

        for (size_t idx = 0; idx < inCount; ++idx) {
            node.inTensorsType.at(idx) = IsInternalTensor(node.inTensors.at(idx)) ?
                TensorType::INTERMEDIATE_TENSOR : TensorType::IN_TENSOR;
        }

        for (size_t idx = 0; idx < outCount; ++idx) {
            node.outTensorsType.at(idx) = IsInternalTensor(node.outTensors.at(idx)) ?
                TensorType::INTERMEDIATE_TENSOR : TensorType::OUT_TENSOR;
        }
    }
}

/**
 * @brief Check whether @p tensor is one of this graph's internalTensors.
 *
 * Membership is decided by address identity, not by comparing tensor contents, so
 * only pointers into this graph's own internalTensors storage match.
 *
 * @param tensor Pointer to test; a null pointer never matches.
 * @return true if @p tensor points at an element of internalTensors.
 */
bool KernelGraph::IsInternalTensor(const Mki::Tensor *tensor) const
{
    bool found = false;
    for (const auto &candidate : internalTensors) {
        if (&candidate == tensor) {
            found = true;
            break;
        }
    }
    return found;
}

/**
 * @brief Forward a reset to the node's implementation, if one has been created.
 *
 * A node without an implementation is left unchanged.
 */
void KernelGraphNode::Reset()
{
    if (impl == nullptr) {
        return;
    }
    impl->Reset();
}

bool KernelGraphNode::CreateImplement()
{
    Mki::Operation *op = AsdOps::Ops::Instance().GetOperationByName(opDesc.opName);
    if (op) {
        impl = std::make_shared<MkiNodeImplement>(op, inferShapePreFunc);
    } else {
        op = AtbOps::Ops::Instance().GetOperationByName(opDesc.opName);
        if (op) {
            impl = std::make_shared<MkiNodeImplement>(op, mkiInferShapePreFunc);
        } else {
            impl.reset();
        }
    }

    if (!impl) {
        ATB_LOG(ERROR) << "node " << opDesc.opName << " implement is null";
        return false;
    }
    return true;
}

/**
 * @brief Name used to identify this node.
 *
 * @return The implementation's own name when an implementation exists; otherwise
 *         the operation name from opDesc.
 */
std::string KernelGraphNode::GetName() const
{
    if (impl != nullptr) {
        return impl->GetName();
    }
    return opDesc.opName;
}
} // namespace atb