/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_
#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "graph/operator_factory.h"
#include "register/infer_data_slice_registry.h"
#include "register/infer_axis_slice_registry.h"
#include "register/op_impl_kernel_registry.h"
#include "graph/op_desc.h"

namespace ge {
class ShapeInferOp;

// ---------------------------------------------------------------------------
// Single, process-wide hook signatures (not keyed by operator type). All are
// plain function pointers; a uint32_t result is presumably a graphStatus-style
// code -- confirm against the registering implementation.
// ---------------------------------------------------------------------------

// Hooks that receive both the operator and its descriptor.
using InferShapeV2Func = uint32_t (*)(const ge::Operator &op, const OpDescPtr &op_desc);
using InferShapeRangeFunc = uint32_t (*)(const ge::Operator &op, const OpDescPtr &op_desc);
using InferFormatV2Func = uint32_t (*)(const ge::Operator &op, const OpDescPtr &op_desc);

// Hook that receives only the operator descriptor.
using InferDataTypeFunc = uint32_t (*)(const OpDescPtr &op_desc);

// Predicates queried with an operator descriptor.
using IsInferFormatV2RegisteredFunc = bool (*)(const OpDescPtr &op_desc);
using IsInferShapeV2RegisteredFunc = bool (*)(const OpDescPtr &op_desc);

// Custom-operator hooks: like the ones above, but additionally handed a
// ShapeInferOp (only forward-declared here, so passed by pointer).
using CustomOpInferShapeFunc = uint32_t (*)(ShapeInferOp *shape_infer_op, const Operator &op,
                                            const OpDescPtr &op_desc);
using CustomOpInferDataTypeFunc = uint32_t (*)(ShapeInferOp *shape_infer_op, const OpDescPtr &op_desc);

struct InferValueRangePara {
 public:
  InferValueRangePara() = default;
  InferValueRangePara(const WHEN_CALL call, const bool cpu_kernel, const InferValueRangeFunc func) {
    is_initialized = true;
    use_cpu_kernel = cpu_kernel;
    when_call = call;
    infer_value_func = func;
  }
  friend class OpDescImpl;
  friend class InferValueRangePass;
  friend class OpDescUtilsEx;
  ~InferValueRangePara() = default;
private:
  bool is_initialized = false;
  bool use_cpu_kernel = false;
  WHEN_CALL when_call = INPUT_IS_DYNAMIC;
  InferValueRangeFunc infer_value_func = nullptr;
};

/**
 * Static registry of operator creators and per-operator hook functions.
 *
 * Every member is static, so all state below is shared program-wide; the class is
 * never meant to be instantiated for its data. GE_FUNC_DEV_VISIBILITY /
 * GE_FUNC_HOST_VISIBILITY are presumably symbol-visibility/export macros -- confirm
 * in their defining header.
 *
 * Two kinds of registrations are held:
 *  - per-operator-type entries, stored in std::map<std::string, ...> containers and
 *    registered/queried with an `operator_type` string;
 *  - single global hooks (the *V2*, data-type, shape-range, Is*Registered and
 *    CustomOp* function pointers), registered and fetched without an operator type.
 *
 * NOTE(review): no synchronization is visible in this header; whether concurrent
 * registration/lookup is safe must be confirmed in the implementation.
 * Definitions live outside this file, so these signatures must stay in sync with them.
 */
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl {
 public:
  // --- Operator creation and discovery ---------------------------------------

  // Creates an Operator named `operator_name` of type `operator_type` via the registered creator.
  // Behavior when no creator is registered is defined in the implementation -- TODO confirm.
  static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type);

  // Fills `all_ops` (out-parameter) with registered operator types; returns a status code.
  static graphStatus GetOpsTypeList(std::vector<std::string> &all_ops);

  // Presumably true when a creator is registered for `operator_type` -- confirm.
  static bool IsExistOp(const std::string &operator_type);

  // --- Per-operator-type lookups (return the hook registered for operator_type) ---
  // Return value for an unregistered type (e.g. empty/nullptr) is implementation-defined -- confirm.

  static InferShapeFunc GetInferShapeFunc(const std::string &operator_type);

  // --- Global hook getters (no operator_type: one hook shared by all operators) ---

  static InferShapeV2Func GetInferShapeV2Func();

  static InferDataTypeFunc GetInferDataTypeFunc();

  static InferShapeRangeFunc GetInferShapeRangeFunc();

  static InferFormatFunc GetInferFormatFunc(const std::string &operator_type);

  // Returned by value; a default-constructed (uninitialized) record presumably signals "not registered".
  static InferValueRangePara GetInferValueRangePara(const std::string &operator_type);

  static VerifyFunc GetVerifyFunc(const std::string &operator_type);

  static InferDataSliceFunc GetInferDataSliceFunc(const std::string &operator_type);

  static InferAxisSliceFunc GetInferAxisSliceFunc(const std::string &operator_type);

  static InferAxisTypeInfoFunc GetInferAxisTypeInfoFunc(const std::string &operator_type);

  // Controls whether a later registration may replace an existing one (inferred from the name -- confirm).
  // The bool is taken by const reference; kept as-is because changing it would alter the
  // mangled symbol and break linkage with the out-of-view definition.
  static void SetRegisterOverridable(const bool &is_overridable);

  // --- Per-operator-type registration (return a status code) ------------------

  static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator);

  static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreatorV2 const &op_creator);

  static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func);

  // --- Global hook registration (single slot each) ----------------------------

  static void RegisterInferShapeV2Func(InferShapeV2Func const infer_shape_func);

  static void RegisterInferDataTypeFunc(InferDataTypeFunc const infer_data_type_func);

  static void RegisterInferShapeRangeFunc(InferShapeRangeFunc const infer_shape_range_func);

  static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func);

  static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func);

  static graphStatus RegisterInferDataSliceFunc(const std::string &operator_type,
                                                InferDataSliceFunc const infer_data_slice_func);

  // Overload without a function: what it records for operator_type is defined in the implementation -- TODO confirm.
  static graphStatus RegisterInferValueRangeFunc(const std::string &operator_type);

  // Registers value-range inference for operator_type together with its WHEN_CALL policy and
  // use_cpu_kernel flag (the same fields held by InferValueRangePara).
  static graphStatus RegisterInferValueRangeFunc(const std::string &operator_type,
                                                 const WHEN_CALL when_call,
                                                 const bool use_cpu_kernel,
                                                 const InferValueRangeFunc &infer_value_range_func);

  static graphStatus RegisterInferAxisSliceFunc(const std::string &operator_type,
                                                const InferAxisSliceFunc &infer_axis_slice_func);

  static graphStatus RegisterInferAxisTypeInfoFunc(const std::string &operator_type,
                                                   const InferAxisTypeInfoFunc &infer_axis_type_info_func);

  // --- Global V2 / custom-op hooks: register/get pairs ------------------------

  static void RegisterInferFormatV2Func(InferFormatV2Func const infer_format_func);

  static InferFormatV2Func GetInferFormatV2Func();

  static void RegisterIsInferFormatV2RegisteredFunc(
      IsInferFormatV2RegisteredFunc const is_infer_format_v2_registered_func);

  static IsInferFormatV2RegisteredFunc GetIsInferFormatV2RegisteredFunc();

  static void RegisterIsInferShapeV2RegisteredFunc(IsInferShapeV2RegisteredFunc const is_infer_shape_v2_registered_func);

  static IsInferShapeV2RegisteredFunc GetIsInferShapeV2RegisteredFunc();

  static void RegisterCustomOpInferShapeFunc(CustomOpInferShapeFunc const custom_op_infer_shape_func);

  static CustomOpInferShapeFunc GetCustomOpInferShapeFunc();

  static void RegisterCustomOpInferDataTypeFunc(CustomOpInferDataTypeFunc const custom_op_infer_datatype_func);

  static CustomOpInferDataTypeFunc GetCustomOpInferDataTypeFunc();

  // --- Registry lifecycle -----------------------------------------------------
  // Exact semantics are in the implementation. The names suggest: release all registrations;
  // back up then clear them (at most once); merge backed-up creators back (at most once). TODO confirm.

  static void ReleaseRegInfo();

  static void BackupAndClearRegInfoOnce();

  static void MergeBackupCreatorsOnce();

  // --- Registry storage ---------------------------------------------------------
  // Public static data: code can read/modify these directly, bypassing the functions above.
  // Per-operator-type maps (std::string key, presumably the operator type):
  static std::shared_ptr<std::map<std::string, OpCreator>> operator_creators_;
  static std::shared_ptr<std::map<std::string, OpCreatorV2>> operator_creators_v2_;
  static std::shared_ptr<std::map<std::string, InferShapeFunc>> operator_infershape_funcs_;
  static std::shared_ptr<std::map<std::string, InferFormatFunc>> operator_inferformat_funcs_;
  static std::shared_ptr<std::map<std::string, VerifyFunc>> operator_verify_funcs_;
  static std::shared_ptr<std::map<std::string, InferDataSliceFunc>> operator_infer_data_slice_funcs_;
  static std::shared_ptr<std::map<std::string, InferValueRangePara>> operator_infer_value_range_paras_;
  static std::shared_ptr<std::map<std::string, InferAxisSliceFunc>> operator_infer_axis_slice_funcs_;
  static std::shared_ptr<std::map<std::string, InferAxisTypeInfoFunc>> operator_infer_axis_type_info_funcs_;
  // Single global hook slots (one function pointer each):
  static InferShapeV2Func operator_infer_shape_v2_func_;
  static InferDataTypeFunc operator_infer_datatype_func_;
  static InferShapeRangeFunc operator_infer_shape_range_func_;
  static InferFormatV2Func operator_infer_format_v2_func_;
  static IsInferFormatV2RegisteredFunc is_infer_format_v2_registered_func_;
  static IsInferShapeV2RegisteredFunc is_infer_shape_v2_registered_func_;
  static CustomOpInferShapeFunc custom_op_infer_shape_func_;
  static CustomOpInferDataTypeFunc custom_op_infer_datatype_func_;
};
}  // namespace ge

#endif  // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_