* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "pyascir_common_utils.h"
#include <string>
#include <utility>
#include <vector>
#include <google/protobuf/text_format.h>
#include "proto/af_ir.pb.h"
#include "nlohmann/json.hpp"
#include "graph/ascendc_ir/utils/asc_graph_utils.h"
#include "graph/detail/model_serialize_imp.h"
#include "common/ge_common/debug/log.h"
#include "attribute_group/attr_group_shape_env.h"
#include "common/platform_context.h"
#include "pyascir_types.h"
#include "ascgen_log.h"
#include "common/scope_tracing_recorder.h"
namespace pyascir {
bool ShapeInfoDeserialize(const std::string to_be_deserialized, PyObject *py_obj) {
try {
auto shape_info = reinterpret_cast<pyascir::ShapeInfo::Object *>(py_obj);
nlohmann::json j = nlohmann::json::parse(to_be_deserialized);
if (j.is_object()) {
for (const auto &[dim_name, dim_value] : j.items()) {
if (dim_value.is_string()) {
shape_info->shape_info[dim_name] = dim_value.get<std::string>();
LOG_PRINT("parse shape info %s : %s", dim_name.c_str(), shape_info->shape_info[dim_name].c_str());
} else {
ERROR_PRINT("parse shape info not string %s : %s", dim_name.c_str(), dim_value.dump().c_str());
return false;
}
}
}
return true;
} catch (const nlohmann::json::parse_error &e) {
PyErr_SetString(PyExc_RuntimeError, "ShapeInfo parse fail");
return false;
}
}
// Converts a Python list of lists of str into `output_shape`, appending one
// std::vector<std::string> per inner list (presumably one inner list per output, holding
// that output's symbolic dimension expressions — confirm against callers).
// `output_shape` is appended to, not cleared first.
//
// Returns false (logging via ERROR_PRINT) when:
//   - `output_shape_obj` is null or not a list,
//   - an element of it is not a list,
//   - an inner element is not a str, or
//   - an inner str cannot be encoded as UTF-8 (the CPython encoding exception stays set).
// Rows completed before the failure remain in `output_shape`.
bool OutputSymbolShapeDeserialize(PyObject *output_shape_obj, std::vector<std::vector<std::string>> &output_shape) {
  // Without this check PyList_Size would return -1 (wrapping to a huge size_t) and the loop
  // would pass the nullptr returned by PyList_GetItem straight into PyList_Check.
  if ((output_shape_obj == nullptr) || (PyList_Check(output_shape_obj) == kPythonFail)) {
    ERROR_PRINT("OutputSymbolShape error, expected a list of lists");
    return false;
  }
  // PyList_Size cannot fail now that the object is known to be a list.
  const size_t outer_size = static_cast<size_t>(PyList_Size(output_shape_obj));
  output_shape.reserve(output_shape.size() + outer_size);
  for (size_t i = 0UL; i < outer_size; ++i) {
    PyObject *inner_list = PyList_GetItem(output_shape_obj, i);  // borrowed reference
    if (PyList_Check(inner_list) == kPythonFail) {
      ERROR_PRINT("OutputSymbolShape inner error, expected a list of lists");
      return false;
    }
    const size_t inner_size = static_cast<size_t>(PyList_Size(inner_list));
    std::vector<std::string> dims;
    dims.reserve(inner_size);
    for (size_t j = 0UL; j < inner_size; ++j) {
      PyObject *item = PyList_GetItem(inner_list, j);  // borrowed reference
      if (PyUnicode_Check(item) == kPythonFail) {
        ERROR_PRINT("OutputSymbolShape inner error, expected a unicode string");
        return false;
      }
      // PyUnicode_AsUTF8 returns nullptr (with an exception set) for strings that cannot be
      // encoded, e.g. lone surrogates; building a std::string from nullptr is undefined.
      const char *utf8 = PyUnicode_AsUTF8(item);
      if (utf8 == nullptr) {
        ERROR_PRINT("OutputSymbolShape inner error, string cannot be encoded as UTF-8");
        return false;
      }
      dims.emplace_back(utf8);
    }
    output_shape.push_back(std::move(dims));
  }
  return true;
}
bool ComputeGraphDeserialize(const std::string to_be_deserialized, PyObject* py_obj) {
pyascir::HintComputeGraph::Init(py_obj, nullptr, nullptr);
auto compute_graph = reinterpret_cast<pyascir::HintComputeGraph::Object *>(py_obj);
GE_CHK_BOOL_RET_SPECIAL_STATUS(compute_graph->compute_graph == nullptr, false, "compute_graph is nullptr");
af::ModelSerializeImp serialize_imp;
af::proto::GraphDef graph_def;
GE_CHK_BOOL_RET_SPECIAL_STATUS(!google::protobuf::TextFormat::ParseFromString(to_be_deserialized, &graph_def), false,
"ComputeGraph ParseFromString fail");
GE_CHK_BOOL_RET_SPECIAL_STATUS(!serialize_imp.UnserializeGraph(compute_graph->compute_graph, graph_def), false,
"ModelSerializeImp deserialize ComputeGraph fail");
const auto shape_env_attr = compute_graph->compute_graph->GetAttrsGroup<af::ShapeEnvAttr>();
if (shape_env_attr != nullptr) {
SetCurShapeEnvContext(shape_env_attr);
}
LOG_PRINT("ComputeGraphDeserialize finish");
return true;
}
namespace {
// Builds a new HintGraph Python object from the readable text form of an AscGraph.
// Returns a new reference, or nullptr with a Python exception set.
PyObject *DeserializeAscGraphToPyObject(const std::string &obj_str) {
  af::AscGraph tmp_graph("fused_graph");
  const auto ret = af::AscGraphUtils::DeserializeFromReadable(obj_str, tmp_graph);
  if (ret != 0) {
    return PyErr_Format(PyExc_TypeError, "HintGraph DeserializeFromReadable fail");
  }
  PyObject *hint_graph_obj = pyascir::HintGraph::New(&pyascir::HintGraph::type, nullptr, nullptr);
  GE_CHK_BOOL_RET_SPECIAL_STATUS(hint_graph_obj == nullptr, PyErr_Format(PyExc_TypeError, "HintGraph new fail"),
                                 "HintGraph new fail");
  // HintGraph::Init receives a one-element argument tuple holding the deserialized graph's name.
  PyObject *name_obj = PyUnicode_FromString(tmp_graph.GetName().c_str());
  if (name_obj == nullptr) {
    Py_DECREF(hint_graph_obj);
    return nullptr;  // PyUnicode_FromString has already set the Python exception
  }
  PyObject *init_args = PyTuple_Pack(1, name_obj);
  Py_DECREF(name_obj);  // PyTuple_Pack holds its own reference; drop ours so the name is not leaked
  if (init_args == nullptr) {
    Py_DECREF(hint_graph_obj);
    return nullptr;  // PyTuple_Pack has already set the Python exception
  }
  const auto ret_init = pyascir::HintGraph::Init(hint_graph_obj, init_args, nullptr);
  Py_DECREF(init_args);
  if (ret_init != 0) {
    Py_DECREF(hint_graph_obj);
    ERROR_PRINT("HintGraph init fail");
    return PyErr_Format(PyExc_TypeError, "HintGraph init fail");
  }
  auto hint_graph = reinterpret_cast<pyascir::HintGraph::Object *>(hint_graph_obj);
  // NOTE(review): if this assertion fails, hint_graph_obj is not released here; confirm what
  // PY_ASSERT does on failure before relying on it for cleanup.
  PY_ASSERT(hint_graph->graph->CopyFrom(tmp_graph));
  return hint_graph_obj;
}

// Builds a new ShapeInfo Python object from a JSON object whose values are strings.
// Returns a new reference, or nullptr with a TypeError set.
PyObject *DeserializeShapeInfoToPyObject(const std::string &obj_str) {
  PyObject *shape_info = pyascir::ShapeInfo::New(&pyascir::ShapeInfo::type, nullptr, nullptr);
  GE_CHK_BOOL_RET_SPECIAL_STATUS(shape_info == nullptr, PyErr_Format(PyExc_TypeError, "ShapeInfo new fail"),
                                 "ShapeInfo new fail");
  if (!pyascir::ShapeInfoDeserialize(obj_str, shape_info)) {
    Py_DECREF(shape_info);  // release the half-filled object instead of leaking it
    ERROR_PRINT("ShapeInfo Deserialize fail");
    return PyErr_Format(PyExc_TypeError, "ShapeInfo Deserialize fail");
  }
  return shape_info;
}

// Builds a new HintComputeGraph Python object from protobuf text-format GraphDef.
// Returns a new reference, or nullptr with a TypeError set.
PyObject *DeserializeComputeGraphToPyObject(const std::string &obj_str) {
  PyObject *hint_compute_graph_obj =
      pyascir::HintComputeGraph::New(&pyascir::HintComputeGraph::type, nullptr, nullptr);
  GE_CHK_BOOL_RET_SPECIAL_STATUS(hint_compute_graph_obj == nullptr,
                                 PyErr_Format(PyExc_TypeError, "HintComputeGraph new fail"), "HintComputeGraph new fail");
  if (!pyascir::ComputeGraphDeserialize(obj_str, hint_compute_graph_obj)) {
    Py_DECREF(hint_compute_graph_obj);  // release the half-built object instead of leaking it
    ERROR_PRINT("HintComputeGraph Deserialize fail");
    return PyErr_Format(PyExc_TypeError, "HintComputeGraph Deserialize fail");
  }
  return hint_compute_graph_obj;
}
}  // namespace

// Python binding: deserializes `obj` according to `type` and returns the matching object.
//   "asc_graph"          -> HintGraph from AscGraphUtils::DeserializeFromReadable text
//   "symbol_source_info" -> ShapeInfo from a JSON object of string values
//   "compute_graph"      -> HintComputeGraph from protobuf text-format GraphDef
// Expects two positional str arguments (type, obj); `kwds` is ignored.
// Returns a new reference, or nullptr with a Python exception set (TypeError for argument,
// type and deserialization failures; the underlying CPython error if an allocation fails).
PyObject *UtilsDeserialize(PyObject *self_pyobject, PyObject *args, PyObject *kwds)
{
  (void)self_pyobject;
  (void)kwds;
  constexpr const char *kTypeGraph = "asc_graph";
  constexpr const char *kTypeShapeInfo = "symbol_source_info";
  constexpr const char *kTypeComputeGraph = "compute_graph";
  const char *type = nullptr;
  const char *obj = nullptr;
  if (PyArg_ParseTuple(args, "ss", &type, &obj) == kPythonFail) {
    return PyErr_Format(PyExc_TypeError, "UtilsDeserialize param parse failed");
  }
  const std::string type_str(type);
  const std::string obj_str(obj);
  LOG_PRINT("UtilsDeserialize type: %s, obj: %s", type_str.c_str(), obj_str.c_str());
  if (type_str == kTypeGraph) {
    return DeserializeAscGraphToPyObject(obj_str);
  }
  if (type_str == kTypeShapeInfo) {
    return DeserializeShapeInfoToPyObject(obj_str);
  }
  if (type_str == kTypeComputeGraph) {
    return DeserializeComputeGraphToPyObject(obj_str);
  }
  return PyErr_Format(PyExc_TypeError, "value of type is invalid");
}
// Appends every element of the Python list `list` to `vec` as a UTF-8 std::string.
// Returns false if `list` is not a list, if an element is not a str, or if an element
// cannot be encoded as UTF-8 (the CPython encoding exception then stays set).
// Elements appended before a failure remain in `vec`.
bool PyListToVector(PyObject *list, std::vector<std::string> &vec) {
  if (PyList_Check(list) == kPythonFail) {
    return false;
  }
  // PyList_Size cannot fail here because `list` is known to be a list.
  const size_t list_size = static_cast<size_t>(PyList_Size(list));
  vec.reserve(vec.size() + list_size);
  for (size_t i = 0UL; i < list_size; ++i) {
    PyObject *item = PyList_GetItem(list, i);  // borrowed reference
    if (PyUnicode_Check(item) == kPythonFail) {
      return false;
    }
    // PyUnicode_AsUTF8 returns nullptr (with an exception set) for strings that cannot be
    // encoded, e.g. lone surrogates; building a std::string from nullptr is undefined.
    const char *utf8 = PyUnicode_AsUTF8(item);
    if (utf8 == nullptr) {
      return false;
    }
    vec.emplace_back(utf8);
  }
  return true;
}
// Python binding: calls ReportTracingRecordDuration for
// ge::TracingModule::kAutoFuseBackend and returns None. All arguments are ignored.
PyObject *UtilsReportDurations(PyObject *self_pyobject, PyObject *args, PyObject *kwds) {
  static_cast<void>(self_pyobject);
  static_cast<void>(args);
  static_cast<void>(kwds);
  ReportTracingRecordDuration(ge::TracingModule::kAutoFuseBackend);
  Py_RETURN_NONE;
}
// Python binding: records one tracing duration under ge::TracingModule::kAutoFuseBackend.
// Positional arguments: (targets: list of str, start: int, duration: int); `kwds` is ignored.
// Raises TypeError if the arguments cannot be parsed or if start/duration is negative;
// an invalid target list is rejected through PY_ASSERT. Returns None on success.
PyObject *UtilsDurationRecord(PyObject *self_pyobject, PyObject *args, PyObject *kwds) {
  static_cast<void>(self_pyobject);
  static_cast<void>(kwds);
  PyObject *target_list_obj = nullptr;
  long long start_time = 0;
  long long elapsed = 0;
  const bool parse_failed =
      (PyArg_ParseTuple(args, "OLL", &target_list_obj, &start_time, &elapsed) == kPythonFail);
  if (parse_failed) {
    return PyErr_Format(PyExc_TypeError, "UtilsDurationRecord param parse failed");
  }
  std::vector<std::string> va_args;
  PY_ASSERT(PyListToVector(target_list_obj, va_args), "target param is invalid");
  // Both values are cast to uint64_t below, so negative inputs must be rejected first.
  const bool has_negative = (start_time < 0L) || (elapsed < 0L);
  if (has_negative) {
    return PyErr_Format(PyExc_TypeError, "duration param is invalid");
  }
  const auto start_u64 = static_cast<uint64_t>(start_time);
  const auto elapsed_u64 = static_cast<uint64_t>(elapsed);
  TracingRecordDuration(ge::TracingModule::kAutoFuseBackend, va_args, start_u64, elapsed_u64);
  Py_RETURN_NONE;
}
// Python binding: stores platform information in the ge::PlatformContext singleton.
// Positional arguments: (soc version: str, vector core count: int, ub size: int);
// `kwds` is ignored. The counts are forwarded as int64_t without range validation.
// Raises TypeError if the arguments cannot be parsed. Returns None on success.
PyObject *UtilsSetPlatform(const PyObject *self_pyobject, PyObject *args, const PyObject *kwds) {
  static_cast<void>(self_pyobject);
  static_cast<void>(kwds);
  const char *platform = nullptr;
  long long core_num = 0;
  long long ub_size_value = 0;
  const bool parse_failed = (PyArg_ParseTuple(args, "sLL", &platform, &core_num, &ub_size_value) == kPythonFail);
  if (parse_failed) {
    return PyErr_Format(PyExc_TypeError, "UtilsSetPlatform param parse failed, expected string, long, long");
  }
  PY_ASSERT_NOTNULL(platform);
  ge::PlatformInfo info;
  info.aiv_num = static_cast<int64_t>(core_num);
  info.ub_size = static_cast<int64_t>(ub_size_value);
  info.soc_ver = std::string(platform);
  ge::PlatformContext::GetInstance().SetPlatformInfo(info);
  Py_RETURN_NONE;
}
}