* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/inference_context.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_context.h"
#include "graph/resource_context_mgr.h"
namespace af {
// Private storage behind ShapeAndType: one Shape plus one DataType.
// ShapeAndType is a friend and reads/writes the two members directly.
class ShapeAndTypeImpl {
 public:
  // Leaves shape_ default-constructed and data_type_ at DT_UNDEFINED.
  ShapeAndTypeImpl() = default;
  // Copies the given shape and records the given data type.
  ShapeAndTypeImpl(const Shape &shape, const DataType data_type) : shape_(shape), data_type_(data_type) {}
  ~ShapeAndTypeImpl() = default;

 private:
  friend class ShapeAndType;

  Shape shape_;
  DataType data_type_ = DT_UNDEFINED;
};
// Backing state of InferenceContext. Every member is private; only the
// friend class InferenceContext touches them.
struct InnerInferenceContext {
 private:
  friend class InferenceContext;

  // Strings stored by InferenceContext::SetMarks.
  std::vector<std::string> marks;
  // Per-input / per-output lists of ShapeAndType stored by the handle setters.
  std::vector<std::vector<ShapeAndType>> input_handle_shapes_and_types;
  std::vector<std::vector<ShapeAndType>> output_handle_shapes_and_types;
  // Keys collected by AddChangedResourceKey and RegisterReliedOnResourceKey respectively.
  std::set<AscendString> changed_resource_keys;
  std::set<AscendString> relied_resource_keys;
  // Assigned in InferenceContext::Create; may stay null, and this struct never deletes it.
  ResourceContextMgr *resource_context_mgr = nullptr;
};
// Creates a ShapeAndType backed by a default-constructed ShapeAndTypeImpl.
// NOTE(review): ComGraphMakeShared presumably yields nullptr on allocation failure;
// the accessors of this class null-check the impl pointer accordingly.
ShapeAndType::ShapeAndType() : shape_and_type_impl_(ComGraphMakeShared<ShapeAndTypeImpl>()) {}
// Creates a ShapeAndType whose impl is constructed from the given shape and data type.
ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type)
    : shape_and_type_impl_(ComGraphMakeShared<ShapeAndTypeImpl>(shape, data_type)) {}
// Stores a copy of `shape` in the impl; silently does nothing when the impl is null.
void ShapeAndType::SetShape(const Shape &shape) {
  if (shape_and_type_impl_ == nullptr) {
    return;
  }
  shape_and_type_impl_->shape_ = shape;
}
// Records `data_type` in the impl; silently does nothing when the impl is null.
void ShapeAndType::SetType(DataType data_type) {
  if (shape_and_type_impl_ == nullptr) {
    return;
  }
  shape_and_type_impl_->data_type_ = data_type;
}
// Returns a copy of the stored shape, or a default-constructed Shape when the impl is null.
Shape ShapeAndType::GetShape() const {
  if (shape_and_type_impl_ == nullptr) {
    return Shape();
  }
  return shape_and_type_impl_->shape_;
}
// Returns the stored data type, or DT_UNDEFINED when the impl is null.
DataType ShapeAndType::GetDataType() const {
  return (shape_and_type_impl_ == nullptr) ? DT_UNDEFINED : shape_and_type_impl_->data_type_;
}
// Takes ownership of the inner context: the caller's unique_ptr is moved from
// and is left empty afterwards.
InferenceContext::InferenceContext(std::unique_ptr<InnerInferenceContext> &inner_context)
    : inner_inference_context_(std::move(inner_context)) {}
std::unique_ptr<InferenceContext> InferenceContext::Create(void *resource_context_mgr) {
std::unique_ptr<InnerInferenceContext> inner_context = ComGraphMakeUnique<InnerInferenceContext>();
if (inner_context == nullptr) {
return nullptr;
}
inner_context->resource_context_mgr = PtrToPtr<void, ResourceContextMgr>(resource_context_mgr);
return std::unique_ptr<InferenceContext>(new (std::nothrow) InferenceContext(inner_context));
}
// Replaces the stored input handle shapes/types by swapping with the argument,
// so the argument ends up holding the previously stored contents.
void InferenceContext::SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) {
  auto &stored_inputs = inner_inference_context_->input_handle_shapes_and_types;
  stored_inputs.swap(shapes_and_types);
}
// Returns a const reference to the input handle shapes/types held by the inner context.
const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetInputHandleShapesAndTypes() const {
  const auto &stored_inputs = inner_inference_context_->input_handle_shapes_and_types;
  return stored_inputs;
}
// Returns a const reference to the output handle shapes/types held by the inner context.
const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetOutputHandleShapesAndTypes() const {
  const auto &stored_outputs = inner_inference_context_->output_handle_shapes_and_types;
  return stored_outputs;
}
// Copy-assigns `shapes_and_types` over the stored output handle shapes/types.
void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types) {
  auto &stored_outputs = inner_inference_context_->output_handle_shapes_and_types;
  stored_outputs = shapes_and_types;
}
// Replaces the stored output handle shapes/types by swapping with the argument,
// so the argument ends up holding the previously stored contents.
void InferenceContext::SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) {
  auto &stored_outputs = inner_inference_context_->output_handle_shapes_and_types;
  stored_outputs.swap(shapes_and_types);
}
// Copy-assigns `marks` over the marks stored in the inner context.
void InferenceContext::SetMarks(const std::vector<std::string> &marks) {
  auto &stored_marks = inner_inference_context_->marks;
  stored_marks = marks;
}
void InferenceContext::SetMarks(const std::vector<AscendString> &marks) {
std::vector<std::string> impl_marks;
for (const auto &mark : marks) {
if (mark.GetString() != nullptr) {
impl_marks.emplace_back(mark.GetString());
}
}
inner_inference_context_->marks = impl_marks;
}
// Returns a const reference to the marks stored in the inner context.
const std::vector<std::string> &InferenceContext::GetMarks() const {
  const auto &stored_marks = inner_inference_context_->marks;
  return stored_marks;
}
// Appends every stored mark to `marks` as an AscendString built from its C string.
// `marks` is not cleared first, so existing elements are kept.
void InferenceContext::GetMarks(std::vector<AscendString> &marks) const {
  const std::vector<std::string> &stored_marks = inner_inference_context_->marks;
  for (const std::string &stored_mark : stored_marks) {
    marks.emplace_back(stored_mark.c_str());
  }
}
// Looks up the resource context registered under `key` in the resource context manager.
// Returns nullptr when no manager is attached to this context or when the key
// has no string (GetString() yields nullptr); otherwise returns whatever the
// manager's GetResourceContext returns for that key.
ResourceContext *InferenceContext::GetResourceContext(const ge::AscendString &key) {
  ResourceContextMgr *const resource_mgr = inner_inference_context_->resource_context_mgr;
  if (resource_mgr == nullptr) {
    return nullptr;
  }
  // SetMarks treats a null GetString() as possible; forwarding a null pointer
  // as the key must be avoided.
  // NOTE(review): assumes the manager takes the key as std::string, where a
  // null pointer would be undefined behaviour — confirm against its declaration.
  const char *const key_str = key.GetString();
  if (key_str == nullptr) {
    return nullptr;
  }
  return resource_mgr->GetResourceContext(key_str);
}
// Registers `resource_context` under `key` in the resource context manager.
// Returns GRAPH_PARAM_INVALID when the key is null or empty, GRAPH_FAILED when
// no manager is attached to this context, and GRAPH_SUCCESS otherwise.
graphStatus InferenceContext::SetResourceContext(const ge::AscendString &key, ResourceContext *resource_context) {
  // Check the raw pointer first: building a std::string from a null pointer is
  // undefined behaviour, and no temporary is needed to test for emptiness.
  const char *const key_str = key.GetString();
  if ((key_str == nullptr) || (key_str[0] == '\0')) {
    GELOGE(GRAPH_PARAM_INVALID, "Resource key is null, invalid param.");
    return GRAPH_PARAM_INVALID;
  }
  if (inner_inference_context_->resource_context_mgr == nullptr) {
    GELOGE(GRAPH_FAILED, "No resource context mgr exist, resource context can not deliver in graph."
                         "Please check session initialized success or not.");
    return GRAPH_FAILED;
  }
  // The manager's result is deliberately discarded, as in the original code.
  (void)inner_inference_context_->resource_context_mgr->SetResourceContext(key_str, resource_context);
  return GRAPH_SUCCESS;
}
// Adds `key` to the set of changed resource keys.
// Returns GRAPH_PARAM_INVALID when the key is null or empty, GRAPH_SUCCESS otherwise
// (also when the key was already in the set).
graphStatus InferenceContext::AddChangedResourceKey(const ge::AscendString &key) {
  // Check the raw pointer first: building a std::string from a null pointer is
  // undefined behaviour, and no temporary is needed to test for emptiness.
  const char *const key_str = key.GetString();
  if ((key_str == nullptr) || (key_str[0] == '\0')) {
    GELOGE(GRAPH_PARAM_INVALID, "Resource key is null, invalid param.");
    return GRAPH_PARAM_INVALID;
  }
  (void)inner_inference_context_->changed_resource_keys.insert(key_str);
  return GRAPH_SUCCESS;
}
// Empties the set of changed resource keys.
void InferenceContext::ClearChangedResourceKeys() {
  auto &changed_keys = inner_inference_context_->changed_resource_keys;
  changed_keys.clear();
}
// Returns a const reference to the set of changed resource keys held by the inner context.
const std::set<ge::AscendString> &InferenceContext::GetChangedResourceKeys() const {
  const auto &changed_keys = inner_inference_context_->changed_resource_keys;
  return changed_keys;
}
// Adds `key` to the set of relied-on resource keys.
// Returns GRAPH_PARAM_INVALID when the key is null or empty, GRAPH_SUCCESS otherwise
// (also when the key was already in the set).
graphStatus InferenceContext::RegisterReliedOnResourceKey(const ge::AscendString &key) {
  // Check the raw pointer first: building a std::string from a null pointer is
  // undefined behaviour, and no temporary is needed to test for emptiness.
  const char *const key_str = key.GetString();
  if ((key_str == nullptr) || (key_str[0] == '\0')) {
    GELOGE(GRAPH_PARAM_INVALID, "Resource key is null, invalid param.");
    return GRAPH_PARAM_INVALID;
  }
  (void)inner_inference_context_->relied_resource_keys.insert(key_str);
  return GRAPH_SUCCESS;
}
// Returns a const reference to the set of relied-on resource keys held by the inner context.
const std::set<ge::AscendString> &InferenceContext::GetReliedOnResourceKeys() const {
  const auto &relied_keys = inner_inference_context_->relied_resource_keys;
  return relied_keys;
}
}