/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "dflow/compiler/data_flow_graph/data_flow_graph_utils.h"
#include "common/util/mem_utils.h"
#include "framework/common/framework_types_internal.h"
#include "dflow/flow_graph/data_flow_attr_define.h"
#include "graph_metadef/common/ge_common/util.h"

namespace ge {
/**
 * Allocates an OpDesc of type FLOWNODE named `op_name` and gives it
 * `input_num` dynamic inputs and `output_num` dynamic outputs.
 *
 * @param op_name     name given to the new op desc.
 * @param input_num   number of dynamic inputs, registered under dflow::ATTR_NAME_DATA_FLOW_INPUT.
 * @param output_num  number of dynamic outputs, registered under dflow::ATTR_NAME_DATA_FLOW_OUTPUT.
 * @param op_desc     [out] overwritten with the newly created op desc.
 * @return SUCCESS when every step succeeds; otherwise the checking macros are
 *         expected to leave the function early with an error status.
 */
Status DataFlowGraphUtils::CreateFlowNodeOpDesc(const std::string &op_name, uint32_t input_num, uint32_t output_num,
                                                OpDescPtr &op_desc) {
  // Only consumed as a printf-style argument by the failure logs below.
  const char *const node_name = op_name.c_str();

  op_desc = MakeShared<OpDesc>(op_name, FLOWNODE);
  GE_CHECK_NOTNULL(op_desc);

  // Inputs are registered before outputs.
  GE_CHK_STATUS_RET(op_desc->AddDynamicInputDesc(dflow::ATTR_NAME_DATA_FLOW_INPUT, input_num),
                    "Failed to add inputs for op[%s], inputs num[%u].", node_name, input_num);

  GE_CHK_STATUS_RET(op_desc->AddDynamicOutputDesc(dflow::ATTR_NAME_DATA_FLOW_OUTPUT, output_num),
                    "Failed to add outputs for op[%s], outputs num[%u].", node_name, output_num);

  return SUCCESS;
}

/**
 * Allocates an OpDesc of type FLOWFUNC named `op_name`, with `input_num`
 * dynamic inputs named "x" and `output_num` dynamic outputs named "y".
 * The two counts are additionally stored on the op desc as integer attributes.
 *
 * @param op_name     name given to the new op desc.
 * @param input_num   number of dynamic inputs; also written to attr "__dynamic_input_x_cnt".
 * @param output_num  number of dynamic outputs; also written to attr "__dynamic_output_x_cnt".
 * @param op_desc     [out] overwritten with the newly created op desc.
 * @return SUCCESS when every step succeeds; FAILED is the status handed to the
 *         attribute checks, and the other checking macros are expected to leave
 *         the function early with an error status of their own.
 */
Status DataFlowGraphUtils::CreateFlowFuncOpDesc(const std::string &op_name, uint32_t input_num, uint32_t output_num,
                                                OpDescPtr &op_desc) {
  // Only consumed as a printf-style argument by the failure logs below.
  const char *const node_name = op_name.c_str();

  op_desc = MakeShared<OpDesc>(op_name, FLOWFUNC);
  GE_CHECK_NOTNULL(op_desc);

  // Input side: record the count attribute first, then create the "x" input descs.
  GE_CHK_BOOL_RET_STATUS(AttrUtils::SetInt(op_desc, "__dynamic_input_x_cnt", input_num), FAILED,
                         "Failed to set attr[__dynamic_input_x_cnt] for op[%s], value[%u].", node_name, input_num);
  GE_CHK_STATUS_RET(op_desc->AddDynamicInputDesc("x", input_num),
                    "Failed to set FlowFunc inputs desc for op[%s], inputs num [%u].", node_name, input_num);

  // Output side: same sequence for the "y" output descs.
  // NOTE(review): the count attribute is keyed with "x" although the dynamic output is
  // named "y" - confirm against the readers of this attribute that this is intentional.
  GE_CHK_BOOL_RET_STATUS(AttrUtils::SetInt(op_desc, "__dynamic_output_x_cnt", output_num), FAILED,
                         "Failed to set attr[__dynamic_output_x_cnt] for op[%s], value[%u].", node_name, output_num);
  GE_CHK_STATUS_RET(op_desc->AddDynamicOutputDesc("y", output_num),
                    "Failed to set FlowFunc outputs for op[%s], outputs num [%u].", node_name, output_num);

  return SUCCESS;
}

/**
 * Fills `process_point` as a built-in process point of type FUNCTION.
 *
 * @param process_point_name  value stored as the process point's name.
 * @param udf_funcs           functions appended, in order, to the process point's `funcs` list.
 * @param attrs               entries copied into the process point's attribute map; an entry
 *                            whose key is already present there is overwritten.
 * @param process_point       [in,out] message being populated. It is not cleared first, so
 *                            funcs are appended to whatever it already holds.
 * @return SUCCESS when every step succeeds; the null checks are expected to leave the
 *         function early with an error status otherwise.
 */
Status DataFlowGraphUtils::CreateBuiltInFunctionProcessPoint(const std::string &process_point_name,
                                                             const std::vector<dataflow::ProcessFunc> &udf_funcs,
                                                             const std::map<std::string, proto::AttrDef> &attrs,
                                                             dataflow::ProcessPoint &process_point) {
  // Identity fields of the process point.
  process_point.set_name(process_point_name);
  process_point.set_type(dataflow::ProcessPoint_ProcessPointType_FUNCTION);
  process_point.set_is_built_in(true);

  // Append one new func slot per requested function and copy the definition into it.
  for (const auto &src_func : udf_funcs) {
    auto *func = process_point.add_funcs();
    GE_CHECK_NOTNULL(func);
    *func = src_func;
  }

  // Copy every caller-supplied attribute into the process point's attribute map.
  auto *process_point_attrs = process_point.mutable_attrs();
  GE_CHECK_NOTNULL(process_point_attrs);
  auto &dst_attrs = *process_point_attrs;
  for (const auto &entry : attrs) {
    dst_attrs[entry.first] = entry.second;
  }

  return SUCCESS;
}

/**
 * Wires `process_point` to the flow node described by `op_desc` and stores the
 * serialized process point on that op desc.
 *
 * One in-edge is added per input of the op and one out-edge per output; each edge
 * carries the op's own name and the corresponding index. The process point is then
 * serialized to a string and written, as a single-element string list, to the
 * dflow::ATTR_NAME_DATA_FLOW_PROCESS_POINTS attribute of the op desc.
 *
 * @param process_point  [in,out] process point that receives the edges. Edges are appended,
 *                       so any edges it already holds are kept.
 * @param op_desc        op desc of the flow node; must not be null.
 * @return SUCCESS when every step succeeds; FAILED is the status handed to the
 *         serialization and attribute checks, and the null checks are expected to
 *         leave the function early with an error status of their own.
 */
Status DataFlowGraphUtils::BindProcessPointToFlowNode(dataflow::ProcessPoint &process_point, OpDescPtr &op_desc) {
  GE_CHECK_NOTNULL(op_desc);
  const std::string node_name = op_desc->GetName();

  // Input side: edge idx of the process point maps to input idx of the node.
  const auto input_count = op_desc->GetInputsSize();
  for (size_t idx = 0U; idx < input_count; ++idx) {
    auto in_edge = process_point.add_in_edges();
    GE_CHECK_NOTNULL(in_edge);
    in_edge->set_node_name(node_name);
    in_edge->set_index(static_cast<int64_t>(idx));
  }

  // Output side: edge idx of the process point maps to output idx of the node.
  const auto output_count = op_desc->GetOutputsSize();
  for (size_t idx = 0U; idx < output_count; ++idx) {
    auto out_edge = process_point.add_out_edges();
    GE_CHECK_NOTNULL(out_edge);
    out_edge->set_node_name(node_name);
    out_edge->set_index(static_cast<int64_t>(idx));
  }

  // Serialize only after the edges are in place so they are part of the stored message.
  std::string serialized_pp;
  GE_CHK_BOOL_RET_STATUS(process_point.SerializeToString(&serialized_pp), FAILED,
                         "Failed to serialize process point[%s] to string.", process_point.name().c_str());

  // The attribute is a string list; this node gets exactly one entry.
  const std::vector<std::string> pp_list(1U, serialized_pp);
  GE_CHK_BOOL_RET_STATUS(AttrUtils::SetListStr(op_desc, dflow::ATTR_NAME_DATA_FLOW_PROCESS_POINTS, pp_list), FAILED,
                         "Failed to set attr[%s] for node[%s].", dflow::ATTR_NAME_DATA_FLOW_PROCESS_POINTS,
                         node_name.c_str());
  return SUCCESS;
}

/**
 * Makes sure the boolean attribute dflow::ATTR_NAME_DATA_FLOW_CONTAINS_N_MAPPING_NODE
 * is set to true on `graph`.
 *
 * The current value is read first; if the attribute is already true nothing is written.
 * Otherwise (attribute false, or the read left the default `false` untouched) it is set
 * to true and an info log is emitted.
 *
 * @param graph  graph to mark; must not be null.
 * @return SUCCESS when the attribute is already true or was set successfully; FAILED is
 *         the status handed to the check on the write. A null `graph` is rejected by
 *         GE_CHECK_NOTNULL before the graph is touched.
 */
Status DataFlowGraphUtils::EnsureNMappingAttr(const ComputeGraphPtr &graph) {
  // graph->GetName() is dereferenced in the log arguments below, so reject a null graph up front
  // (same guard the sibling helpers apply to their pointer arguments).
  GE_CHECK_NOTNULL(graph);

  bool contains_n_mapping_node = false;
  // Best-effort read: the result is deliberately ignored and the flag keeps its `false` initial value.
  (void)AttrUtils::GetBool(graph, dflow::ATTR_NAME_DATA_FLOW_CONTAINS_N_MAPPING_NODE, contains_n_mapping_node);
  if (contains_n_mapping_node) {
    // Already marked, nothing to write.
    return SUCCESS;
  }

  GE_CHK_BOOL_RET_STATUS(AttrUtils::SetBool(graph, dflow::ATTR_NAME_DATA_FLOW_CONTAINS_N_MAPPING_NODE, true), FAILED,
                         "Failed to set attr[%s] to graph[%s]", dflow::ATTR_NAME_DATA_FLOW_CONTAINS_N_MAPPING_NODE,
                         graph->GetName().c_str());
  GELOGI("set attr[%s] to graph[%s] success.", dflow::ATTR_NAME_DATA_FLOW_CONTAINS_N_MAPPING_NODE,
         graph->GetName().c_str());
  return SUCCESS;
}
}  // namespace ge