/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "graph/manager/graph_context.h"

#include "graph/utils/graph_utils_ex.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"

namespace ge {
// Builds a context from graph_node: stores the node's compute graph and graph id, and when the
// node has no compute graph, falls back to the one obtained from the node's ge::Graph.
// A constructor cannot return a status, so failures are only logged and the object is left
// with whatever was set up to that point.
GraphContext::GraphContext(const GraphNodePtr &graph_node) {
  if (graph_node == nullptr) {
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "graphNode is NULL!");
    return;
  }

  compute_graph_ = graph_node->GetComputeGraph();
  current_graph_id_ = graph_node->GetGraphId();
  if (compute_graph_ != nullptr) {
    // The node already supplied a compute graph; nothing more to do.
    return;
  }

  // The node has no compute graph: obtain one from the node's ge::Graph instead.
  const std::shared_ptr<const ge::Graph> user_graph = graph_node->GetGraph();
  if (user_graph == nullptr) {
    GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[Get][Graph] failed, compute_graph by graphNode is NULL!");
    return;
  }
  compute_graph_ = GraphUtilsEx::GetComputeGraph(*user_graph);
}

// Re-targets this context at graph_node: stores the node's compute graph and graph id, and when
// the node has no compute graph, falls back to the one obtained from the node's ge::Graph.
// Returns GE_GRAPH_PARAM_NULLPTR when graph_node is null,
// GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL when the node has neither a compute graph nor a ge::Graph,
// and SUCCESS otherwise.
Status GraphContext::SetComputeGraph(const GraphNodePtr &graph_node) {
  if (graph_node == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "Param graph_node is nullptr, check invalid");
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[Check][Param] graphNode is NULL!");
    return GE_GRAPH_PARAM_NULLPTR;
  }

  compute_graph_ = graph_node->GetComputeGraph();
  current_graph_id_ = graph_node->GetGraphId();

  if (compute_graph_ == nullptr) {
    // The node has no compute graph: obtain one from the node's ge::Graph instead.
    const std::shared_ptr<const ge::Graph> user_graph = graph_node->GetGraph();
    if (user_graph == nullptr) {
      REPORT_INNER_ERR_MSG("E19999", "Param graph in graph_node is nullptr, check invalid");
      GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[Get][Graph] failed, compute_graph by graphNode is NULL!");
      return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL;
    }
    compute_graph_ = GraphUtilsEx::GetComputeGraph(*user_graph);
  }

  return SUCCESS;
}

// No-op initialization: the options map is accepted but not consulted, and SUCCESS is
// always returned.
Status GraphContext::Initialize(const std::map<std::string, std::string> &options) const {
  static_cast<void>(options);  // intentionally unused
  return SUCCESS;
}

// No-op finalization: always returns SUCCESS.
Status GraphContext::Finalize() const {
  return SUCCESS;
}

// Looks up the first record in the variable-node tensor table whose key's first element equals
// var_data_name, and sets returned_tensor's tensor desc and data from that record.
//
// Returns:
//   SUCCESS                               - a record matched and its data was set;
//   GE_GRAPH_EMPTY_STRING_NAME            - var_data_name is empty;
//   GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE  - the table holds no records;
//   the status returned by SetData        - setting the data on returned_tensor failed
//                                           (the tensor desc has already been set by then);
//   GE_GRAPH_VARIABLE_DOES_NOT_EXIST      - no record matched var_data_name.
Status GraphContext::GetVariableTensor(const std::string &var_data_name, GeTensor &returned_tensor) const {
  if (var_data_name.empty()) {
    REPORT_INNER_ERR_MSG("E19999", "Param var_data_name is empty, check invalid");
    GELOGE(GE_GRAPH_EMPTY_STRING_NAME, "[Check][Param] Variable data name is empty!");
    return GE_GRAPH_EMPTY_STRING_NAME;
  }

  auto &tensor_table = GetVarNodeTensorTable();
  if (tensor_table.empty()) {
    REPORT_INNER_ERR_MSG("E19999", "VarNodeTensorTable is empty, var_data_name:%s, check invalid",
                         var_data_name.c_str());
    GELOGE(GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE, "[Check][Param] VarNodeTensorTable is empty, var_data_name:%s",
           var_data_name.c_str());
    return GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE;
  }

  for (auto &entry : tensor_table) {
    if (!(var_data_name == std::get<0>(entry.first))) {
      continue;
    }

    // First matching record wins: set its desc, then its data, on the output tensor.
    returned_tensor.SetTensorDesc(entry.second.GetTensorDesc());
    const auto set_status = returned_tensor.SetData(entry.second.GetData());
    if (set_status != SUCCESS) {
      REPORT_INNER_ERR_MSG("E19999", "SetData to tensor fail, var_data_name:%s", var_data_name.c_str());
      GELOGE(set_status, "[Set][Data] to Tensor failed, var_data_name:%s", var_data_name.c_str());
      return set_status;
    }
    return SUCCESS;
  }

  // Scanned the whole table without finding the requested name.
  REPORT_INNER_ERR_MSG("E19999", "VarRecord with data_name:%s does not exist, check invalid",
                       var_data_name.c_str());
  GELOGE(GE_GRAPH_VARIABLE_DOES_NOT_EXIST, "[Check][Param] VarRecord with data_name %s does NOT exist!",
         var_data_name.c_str());
  return GE_GRAPH_VARIABLE_DOES_NOT_EXIST;
}

// Returns a reference to the VarNodeTensorTable held in a function-local static, so every call
// yields the same table object.
VarNodeTensorTable &GraphContext::GetVarNodeTensorTable() {
  static VarNodeTensorTable var_node_tensor_table;
  return var_node_tensor_table;
}
}  // namespace ge