/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "common/model/ge_model.h"
#include <utility>
#include "graph/debug/ge_attr_define.h"

namespace ge {
void GeModel::Init() {
  (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0);
  (void)AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0);
  (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0);
  (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0);
  (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0);
  (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0);
  (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI);
  version_ = 0U;
  ClearWeightDataBuf();
  // default attrSize = 5
}

GeModel::GeModel() : AttrHolder() {
  Init();
}

const ComputeGraphPtr &GeModel::GetGraph() const { return this->graph_; }

void GeModel::SetGraph(const ComputeGraphPtr &graph) { this->graph_ = graph; }

std::shared_ptr<domi::ModelTaskDef> GeModel::GetModelTaskDefPtr() const { return this->task_; }

TBEKernelStore &GeModel::GetTBEKernelStore() { return this->tbe_kernel_store_; }

CustAICPUKernelStore &GeModel::GetCustAICPUKernelStore() { return this->cust_aicpu_kernel_store_; }

// use GetWeightData and GetWeightSize instead
Buffer GeModel::GetWeight() const { return this->weights_buffer_; }

uint8_t* GeModel::GetWeightData() const {
  if (this->weight_data_buffer_.data != nullptr) {
    return reinterpret_cast<uint8_t *>(this->weight_data_buffer_.data);
  }
  return GetWeight().GetData();
}

size_t GeModel::GetWeightSize() const {
  if (this->weight_data_buffer_.data != nullptr) {
    return this->weight_data_buffer_.length;
  }
  return GetWeight().GetSize();
}

void GeModel::SetWeightDataBuf(const DataBuffer &data_buffer) {
  this->weight_data_buffer_ = data_buffer;
}

void GeModel::ClearWeightDataBuf() {
  this->weight_data_buffer_.data = nullptr;
  this->weight_data_buffer_.length = 0U;
}

std::string GeModel::GetName() const { return this->name_; }

uint32_t GeModel::GetVersion() const { return this->version_; }

std::string GeModel::GetPlatformVersion() const { return this->platform_version_; }

uint8_t GeModel::GetPlatformType() const { return this->platform_type_; }

void GeModel::SetModelTaskDef(const std::shared_ptr<domi::ModelTaskDef> &task) { this->task_ = task; }

void GeModel::SetTBEKernelStore(const TBEKernelStore &tbe_kernel_store) {
  this->tbe_kernel_store_ = tbe_kernel_store;
}

void GeModel::SetCustAICPUKernelStore(const CustAICPUKernelStore &cust_aicpu_kernel_store) {
  this->cust_aicpu_kernel_store_ = cust_aicpu_kernel_store;
}

bool GeModel::LoadTBEKernelStore(const uint8_t *const data, const size_t len) {
  return tbe_kernel_store_.Load(data, len);
}

bool GeModel::LoadAICPUKernelStore(const uint8_t *const data, const size_t len) {
  return cust_aicpu_kernel_store_.Load(data, len);
}

void GeModel::SetWeight(const Buffer &weights_buffer) { this->weights_buffer_ = weights_buffer; }

void GeModel::SetName(const std::string &name) { this->name_ = name; }

void GeModel::SetVersion(const uint32_t version) { this->version_ = version; }

void GeModel::SetPlatformVersion(const std::string &platform_version) { this->platform_version_ = platform_version; }

void GeModel::SetPlatformType(const uint8_t platform_type) { this->platform_type_ = platform_type; }

void GeModel::SetAttrMap(const ProtoAttrMap &attrs) { attrs_ = attrs; }

ProtoAttrMap &GeModel::MutableAttrMap() { return attrs_; }

ConstProtoAttrMap &GeModel::GetAttrMap() const {
  return attrs_;
}

Status GeModel::GetSessionId(const uint32_t model_id, uint64_t &session_id) const {
  const auto it = model_id_to_session_id_map_.find(model_id);
  if (it != model_id_to_session_id_map_.end()) {
    session_id = it->second;
    return SUCCESS;
  }
  GELOGW("No session id were found with model id [%u].", model_id);
  return INTERNAL_ERROR;
}
}  // namespace ge