/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "graph/operator_factory_impl.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
// Creates an operator named `operator_name` of type `operator_type`.
// Thin wrapper: construction is delegated entirely to OperatorFactoryImpl::CreateOperator,
// and its result is returned as-is.
Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) {
  return OperatorFactoryImpl::CreateOperator(operator_name,
                                             operator_type);
}

// C-string overload of CreateOperator.
// If either argument is nullptr, an E18888 inner error is reported and a default-constructed
// Operator is returned. Otherwise both C strings are converted to std::string and the call is
// forwarded to OperatorFactoryImpl::CreateOperator.
Operator OperatorFactory::CreateOperator(const char_t *const operator_name, const char_t *const operator_type) {
  const bool has_null_arg = (operator_name == nullptr) || (operator_type == nullptr);
  if (has_null_arg) {
    REPORT_INNER_ERR_MSG("E18888", "Create Operator input parameter is nullptr, check invalid.");
    GELOGE(GRAPH_FAILED, "[Check][Param] Create Operator input parameter is nullptr.");
    return Operator();
  }
  return OperatorFactoryImpl::CreateOperator(std::string(operator_name), std::string(operator_type));
}

// Queries the registered operator types into `all_ops` (std::string flavour).
// Delegates to OperatorFactoryImpl::GetOpsTypeList and returns its status unchanged.
graphStatus OperatorFactory::GetOpsTypeList(std::vector<std::string> &all_ops) {
  const graphStatus status = OperatorFactoryImpl::GetOpsTypeList(all_ops);
  return status;
}

// AscendString overload of GetOpsTypeList.
// Fetches the registered operator types via the std::string-based OperatorFactoryImpl query and
// appends each one to `all_ops` as an AscendString. Existing elements of `all_ops` are kept; this
// function does not clear the vector.
// Returns GRAPH_SUCCESS on success. If the underlying query fails, an E18888 inner error is
// reported, GRAPH_FAILED is returned, and `all_ops` is left untouched.
graphStatus OperatorFactory::GetOpsTypeList(std::vector<AscendString> &all_ops) {
  std::vector<std::string> all_op_types;
  if (OperatorFactoryImpl::GetOpsTypeList(all_op_types) != GRAPH_SUCCESS) {
    REPORT_INNER_ERR_MSG("E18888", "Get ops type list failed.");
    GELOGE(GRAPH_FAILED, "[Get][OpsTypeList] failed.");
    return GRAPH_FAILED;
  }
  // The final element count is known up front; reserve once instead of growing per element.
  all_ops.reserve(all_ops.size() + all_op_types.size());
  for (const auto &op_type : all_op_types) {
    all_ops.emplace_back(op_type.c_str());
  }
  return GRAPH_SUCCESS;
}

// Reports whether `operator_type` is known to the factory.
// Delegates to OperatorFactoryImpl::IsExistOp and returns its answer unchanged.
bool OperatorFactory::IsExistOp(const std::string &operator_type) {
  const bool exists = OperatorFactoryImpl::IsExistOp(operator_type);
  return exists;
}

bool OperatorFactory::IsExistOp(const char_t *const operator_type) {
  if (operator_type == nullptr) {
    REPORT_INNER_ERR_MSG("E18888", "Operator type is nullptr, check invalid.");
    GELOGE(GRAPH_FAILED, "[Check][Param] Operator type is nullptr.");
    return false;
  }
  const std::string op_type = operator_type;
  return OperatorFactoryImpl::IsExistOp(op_type);
}

// Registers `op_creator` as the creator for `operator_type`.
// A constructor cannot return a status, so the registry's return value is discarded.
OperatorCreatorRegister::OperatorCreatorRegister(const std::string &operator_type, OpCreator const &op_creator) {
  static_cast<void>(OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator));
}

// C-string overload registering an OpCreatorV2 for `operator_type`.
// A nullptr `operator_type` is not rejected here. It is mapped to an empty type name and
// registration is still attempted. The registry's return value is discarded.
OperatorCreatorRegister::OperatorCreatorRegister(const char_t *const operator_type, OpCreatorV2 const &op_creator) {
  const std::string op_type = (operator_type == nullptr) ? std::string() : std::string(operator_type);
  static_cast<void>(OperatorFactoryImpl::RegisterOperatorCreator(op_type, op_creator));
}

// Registers `infer_shape_func` for `operator_type`; the registry's return value is discarded.
InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type,
                                               const InferShapeFunc &infer_shape_func) {
  static_cast<void>(OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func));
}

// C-string overload: registers `infer_shape_func` for `operator_type`.
// A nullptr `operator_type` becomes an empty type name and registration is still attempted.
// The registry's return value is discarded.
InferShapeFuncRegister::InferShapeFuncRegister(const char_t *const operator_type,
                                               const InferShapeFunc &infer_shape_func) {
  const std::string op_type = (operator_type == nullptr) ? std::string() : std::string(operator_type);
  static_cast<void>(OperatorFactoryImpl::RegisterInferShapeFunc(op_type, infer_shape_func));
}

// Registers `infer_format_func` for `operator_type`; the registry's return value is discarded.
InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type,
                                                 const InferFormatFunc &infer_format_func) {
  static_cast<void>(OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func));
}

// C-string overload: registers `infer_format_func` for `operator_type`.
// A nullptr `operator_type` becomes an empty type name and registration is still attempted.
// The registry's return value is discarded.
InferFormatFuncRegister::InferFormatFuncRegister(const char_t *const operator_type,
                                                 const InferFormatFunc &infer_format_func) {
  const std::string op_type = (operator_type == nullptr) ? std::string() : std::string(operator_type);
  static_cast<void>(OperatorFactoryImpl::RegisterInferFormatFunc(op_type, infer_format_func));
}

// Registers `infer_value_range_func` for `operator_type`, to be used according to `when_call`.
// A nullptr `operator_type` becomes an empty type name and registration is still attempted.
// The registry's return value is discarded.
InferValueRangeFuncRegister::InferValueRangeFuncRegister(const char_t *const operator_type,
                                                         const WHEN_CALL when_call,
                                                         const InferValueRangeFunc &infer_value_range_func) {
  const std::string op_type = (operator_type == nullptr) ? std::string() : std::string(operator_type);
  // NOTE(review): the third argument is always the literal `false`; its meaning is defined by
  // OperatorFactoryImpl::RegisterInferValueRangeFunc — confirm there before changing it.
  static_cast<void>(
      OperatorFactoryImpl::RegisterInferValueRangeFunc(op_type, when_call, false, infer_value_range_func));
}

// Registers `operator_type` through the single-argument RegisterInferValueRangeFunc overload
// (no function object is supplied by this constructor).
// A nullptr `operator_type` becomes an empty type name and registration is still attempted.
// The registry's return value is discarded.
InferValueRangeFuncRegister::InferValueRangeFuncRegister(const char_t *const operator_type) {
  const std::string op_type = (operator_type == nullptr) ? std::string() : std::string(operator_type);
  static_cast<void>(OperatorFactoryImpl::RegisterInferValueRangeFunc(op_type));
}

// Registers `verify_func` for `operator_type`; the registry's return value is discarded.
VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) {
  static_cast<void>(OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func));
}

// C-string overload: registers `verify_func` for `operator_type`.
// A nullptr `operator_type` becomes an empty type name and registration is still attempted.
// The registry's return value is discarded.
VerifyFuncRegister::VerifyFuncRegister(const char_t *const operator_type, const VerifyFunc &verify_func) {
  const std::string op_type = (operator_type == nullptr) ? std::string() : std::string(operator_type);
  static_cast<void>(OperatorFactoryImpl::RegisterVerifyFunc(op_type, verify_func));
}
}  // namespace ge