// Copyright (c) 2025 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef INC_EXTERNAL_ATB_OPERATION_H
#define INC_EXTERNAL_ATB_OPERATION_H
#include <cstdint>
#include <functional>
#include <string>
#include "./types.h"
#include "./svector.h"
#include "./context.h"

//!
//! \file operation.h
//!
//! \brief 定义加速库Operation类
//!

namespace atb {

//!
//! \class Operation.
//!
//! \brief 加速库Operation类.
//!
//! 该接口类定义了算子准备与执行的需要的一系列的接口,通过创建Operation可以执行算子
//!
class Operation {
public:
    //! \brief 默认构造函数.
    Operation() = default;

    //! \brief 默认析构函数.
    virtual ~Operation() = default;
    //!
    //! \brief 获取创建的Operation的名字
    //!
    //! \return 返回字符串
    //!
    virtual std::string GetName() const = 0;

    //!
    //! \brief 根据输入Tensor描述信息推导出输出Tensor的描述信息。
    //!
    //! \param inTensorDescs 存放所有输入tensor描述信息的SVector
    //! \param outTensorDescs 存放所有输出tensor描述信息的SVector
    //!
    //! \return 状态值,如果成功,返回NO_ERROR
    //!
    virtual Status InferShape(const SVector<TensorDesc> &inTensorDescs, SVector<TensorDesc> &outTensorDescs) const = 0;

    //!
    //! \brief 获取Op/GraphOp输入Tensor个数接口。
    //!
    //! \return 整数值
    //!
    virtual uint32_t GetInputNum() const = 0;

    //!
    //! \brief 获取Op/GraphOp输出Tensor个数接口。
    //!
    //! \return 整数值
    //!
    virtual uint32_t GetOutputNum() const = 0;

    //!
    //! \brief Operation执行前的一系列准备工作
    //!
    //! 主要是计算Operation执行过程需要分配的内存空间workspaceSize
    //!
    //! \param variantPack 输入与输出Tensor
    //! \param workspaceSize 获取Operation执行需要分配的内存空间
    //! \param context Operation执行准备工作所在的上下文
    //!
    //! \return 状态值,如果成功,返回NO_ERROR
    //!
    virtual Status Setup(const VariantPack &variantPack, uint64_t &workspaceSize, Context *context) = 0;

    //!
    //! \brief Operation执行的流程
    //!
    //! 根据setup过程中得到的workspaceSize为Operation执行分配实际的内存,并执行Operation
    //!
    //! \param variantPack 输入与输出Tensor
    //! \param workspace Operation执行分配的内存地址
    //! \param workspaceSize Operation执行需要分配的内存空间
    //! \param context Operation执行所在的上下文
    //!
    //! \return 状态值,如果成功,返回NO_ERROR
    //!
    virtual Status Execute(const VariantPack &variantPack, uint8_t *workspace, uint64_t workspaceSize,
                           Context *context) = 0;
};

//!
//! \brief 创建Operation
//!
//! \param opParam 根据参数来指定调用的Operation
//! \param operation Operation指针地址
//!
//! \return 状态值,如果成功,返回NO_ERROR
//!
template <typename OpParam> Status CreateOperation(const OpParam &opParam, Operation **operation);

//!
//! \brief 销毁Operation
//!
//! \param operation Operation指针
//!
//! \return 状态值,如果成功,返回NO_ERROR
//!
//! \note 调用CreateOperation接口创建Operation,执行完Operation后需要调用DestroyOperation接口进行销毁。否则将导致内存泄漏。
//!
Status DestroyOperation(Operation *operation);

//!
//! \brief 拷贝Operation的Param参数
//!
//! \param operation Operation指针
//! \param opParam OpParam的引用,将返回operation的opParam浅拷贝
//!
//! \return 状态值,如果成功,返回NO_ERROR
//!
template <typename OpParam> Status CloneOperationParam(const Operation *operation, OpParam &opParam);

//!
//! \brief 更新Operation的Param参数
//!
//! \param operation Operation指针
//! \param opParam Operation新的param值
//!
//! \return 状态值,如果成功,返回NO_ERROR
//!
template <typename OpParam> Status UpdateOperationParam(Operation *operation, const OpParam &opParam);

} // namespace atb
#endif