* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/build/run_context_util.h"
#include <cinttypes>
#include "framework/common/util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "common/omg_util/omg_util.h"
#include "base/err_msg.h"
namespace ge {
Status RunContextUtil::InitMemInfo(uint8_t *data_mem_base, uint64_t data_mem_size,
std::map<int64_t, uint8_t *> mem_type_to_data_mem_base,
std::map<int64_t, uint64_t> mem_type_to_data_mem_size, uint8_t *weight_mem_base,
uint64_t weight_mem_size) {
if ((data_mem_size > 0) && (data_mem_base == nullptr)) {
REPORT_INNER_ERR_MSG("E19999", "InitMemInfo param data_mem_base is null but data_mem_size = %" PRIu64 "", data_mem_size);
GELOGE(PARAM_INVALID, "[Check][Param] InitMemInfo param data_mem_base is null but data_mem_size = %" PRIu64 ".",
data_mem_size);
return PARAM_INVALID;
}
if ((weight_mem_size > 0) && (weight_mem_base == nullptr)) {
REPORT_INNER_ERR_MSG("E19999", "InitMemInfo param weight_mem_base is null but weight_mem_size = %" PRIu64 "",
weight_mem_size);
GELOGE(PARAM_INVALID, "[Check][Param] InitMemInfo param weight_mem_base is null but weight_mem_size = %" PRIu64 ".",
weight_mem_size);
return PARAM_INVALID;
}
if (mem_type_to_data_mem_base.empty() || mem_type_to_data_mem_size.empty() ||
mem_type_to_data_mem_base.size() != mem_type_to_data_mem_size.size()) {
REPORT_INNER_ERR_MSG("E19999", "InitMemInfo param mem_type_to_data_mem_base size[%zu] "
"is not equal to the size of mem_type_to_data_mem_size[%zu].",
mem_type_to_data_mem_base.size(), mem_type_to_data_mem_size.size());
GELOGE(PARAM_INVALID,
"[Check][Param] InitMemInfo param mem_type_to_data_mem_base size[%zu] is not equal to the size of "
"mem_type_to_data_mem_size[%zu].", mem_type_to_data_mem_base.size(), mem_type_to_data_mem_size.size());
return PARAM_INVALID;
}
data_mem_base_ = data_mem_base;
data_mem_size_ = data_mem_size;
weight_mem_base_ = weight_mem_base;
weight_mem_size_ = weight_mem_size;
mem_type_to_data_mem_base_ = mem_type_to_data_mem_base;
mem_type_to_data_mem_size_ = mem_type_to_data_mem_size;
return SUCCESS;
}
Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &graph, Buffer &buffer,
const uint64_t session_id) {
GELOGD("Begin to Create RunContext, session_id = %" PRIu64 "", session_id);
if (graph == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Check param graph nullptr, session_id:%" PRIu64 ",", session_id);
GELOGE(PARAM_INVALID, "[Check][Param] CreateRunContext param graph is null. session_id=%" PRIu64 "", session_id);
return PARAM_INVALID;
}
uint32_t notify_num = 0;
if (!AttrUtils::GetInt(&model, ATTR_MODEL_NOTIFY_NUM, notify_num)) {
REPORT_INNER_ERR_MSG("E19999", "Get Attr:%s failed from model, session_id:%" PRIu64 ",", ATTR_MODEL_NOTIFY_NUM.c_str(),
session_id);
GELOGE(INTERNAL_ERROR, "[Get][Attr] %s failed from model, session_id:%" PRIu64 ",", ATTR_MODEL_NOTIFY_NUM.c_str(),
session_id);
return INTERNAL_ERROR;
}
GELOGD("Notify_num = %u", notify_num);
uint32_t event_num = 0;
if (!AttrUtils::GetInt(&model, ATTR_MODEL_EVENT_NUM, event_num)) {
REPORT_INNER_ERR_MSG("E19999", "Get Attr:%s failed from model, session_id:%" PRIu64 ",",
ATTR_MODEL_EVENT_NUM.c_str(), session_id);
GELOGE(INTERNAL_ERROR, "[Get][Attr] %s failed from model, session_id:%" PRIu64 ",",
ATTR_MODEL_EVENT_NUM.c_str(), session_id);
return INTERNAL_ERROR;
}
GELOGD("Event_num = %u", event_num);
uint32_t label_num = 0;
if (!AttrUtils::GetInt(&model, ATTR_MODEL_LABEL_NUM, label_num)) {
REPORT_INNER_ERR_MSG("E19999", "Get Attr:%s failed from model, session_id:%" PRIu64 ",",
ATTR_MODEL_LABEL_NUM.c_str(), session_id);
GELOGE(INTERNAL_ERROR, "[Get][Attr] %s failed from model, session_id:%" PRIu64 ",",
ATTR_MODEL_LABEL_NUM.c_str(), session_id);
return INTERNAL_ERROR;
}
GELOGD("Label_num = %u", label_num);
GELOGI("CreateRunContext: data_mem_base_ = %p, weight_mem_base_ = %p, memory_size = %" PRIu64 ", weight_size = %" PRIu64 "",
data_mem_base_, weight_mem_base_, data_mem_size_, weight_mem_size_);
PrintMemInfo();
run_context_ = {};
run_context_.sessionId = session_id;
run_context_.dataMemSize = data_mem_size_;
run_context_.dataMemBase = data_mem_base_;
run_context_.mem_type_data_mem_size = mem_type_to_data_mem_size_;
run_context_.mem_type_data_mem_base = mem_type_to_data_mem_base_;
run_context_.weightMemSize = weight_mem_size_;
run_context_.weightMemBase = weight_mem_base_;
run_context_.weightsBuffer = buffer;
return SUCCESS;
}
void RunContextUtil::PrintMemInfo() const {
for (auto iter : mem_type_to_data_mem_base_) {
GELOGD("CreateRunContext: memory type = %ld, data memory base = %p", iter.first, iter.second);
}
for (auto iter : mem_type_to_data_mem_size_) {
GELOGD("CreateRunContext: memory type = %ld, data memory size = %" PRIu64 "", iter.first, iter.second);
}
}
RunContext &RunContextUtil::GetRunContext() { return run_context_; }
}