* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "framework/common/helper/model_helper.h"
#include <limits>
#include "external/graph/custom_op.h"
#include "graph/custom_op_factory.h"
namespace ge {
/**
 * @brief Collects the node types, across the root graph and all sub-models of ge_root_model,
 *        for which CustomOpFactory::IsExistOp() returns true.
 *
 * @param ge_root_model         root model to scan; a null model is reported as an error.
 * @param used_custom_op_types  output set; matching op types are inserted (existing entries are kept).
 * @return SUCCESS on completion, FAILED if ge_root_model is null.
 */
Status ModelHelper::CollectUsedCustomOpTypes(const GeRootModelPtr &ge_root_model,
                                             std::set<std::string> &used_custom_op_types) const {
  if (ge_root_model == nullptr) {
    GELOGE(FAILED, "[CUSTOM OP] ge root model is null, cannot collect used custom op types.");
    return FAILED;
  }
  // Inserts the type of every node in `graph` that is registered with the custom op factory.
  const auto collect_from_graph = [&used_custom_op_types](const auto &graph) {
    for (const auto &node : graph->GetAllNodes()) {
      const std::string op_type = node->GetType();
      if (CustomOpFactory::IsExistOp(AscendString(op_type.c_str()))) {
        used_custom_op_types.insert(op_type);
      }
    }
  };
  const auto &root_graph = ge_root_model->GetRootGraph();
  if (root_graph != nullptr) {
    collect_from_graph(root_graph);
  }
  const auto &subgraph_map = ge_root_model->GetSubgraphInstanceNameToModel();
  for (const auto &subgraph_pair : subgraph_map) {
    const auto &ge_model = subgraph_pair.second;
    if ((ge_model == nullptr) || (ge_model->GetGraph() == nullptr)) {
      continue;
    }
    const auto &graph = ge_model->GetGraph();
    // If the root graph also appears in the sub-model map it has already been scanned above.
    if (graph == root_graph) {
      continue;
    }
    collect_from_graph(graph);
  }
  return SUCCESS;
}
/**
 * @brief Serializes one portable custom op and appends it to merged_buffers as a single record.
 *
 * Record layout (appended back-to-back):
 *   1. the raw bytes of a CustomKernelItemHeader (magic = kCustomKernelItemMagic, name_len, bin_len),
 *      copied directly from memory, i.e. in host byte order;
 *   2. the op type string bytes (name_len bytes, no terminator);
 *   3. the bytes produced by PortableOp::Serialize() (bin_len bytes).
 *
 * @param serializable_op  op to serialize; null is an error.
 * @param op_type_str      op type written as the record name.
 * @param merged_buffers   buffer the record is appended to.
 * @return SUCCESS when a record was appended, or when Serialize() produced an empty buffer (nothing is
 *         appended in that case); the Serialize() status if it fails; FAILED for a null op or when the
 *         name or serialized binary is too long to be described by the uint32_t header fields.
 */
Status ModelHelper::SerializeCustomOpKernel(PortableOp *serializable_op, const std::string &op_type_str,
                                            std::vector<uint8_t> &merged_buffers) const {
  if (serializable_op == nullptr) {
    GELOGE(FAILED, "[CUSTOM OP] serializable custom op is null, op_type:%s", op_type_str.c_str());
    return FAILED;
  }
  std::vector<uint8_t> buffer;
  const auto ret = serializable_op->Serialize(buffer);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(ret, "[CUSTOM OP] serialize failed, op_type:%s", op_type_str.c_str());
    return ret;
  }
  if (buffer.empty()) {
    GELOGW("[CUSTOM OP] serialized buffer is empty, skip, op_type:%s", op_type_str.c_str());
    return SUCCESS;
  }
  // name_len/bin_len are stored as uint32_t. A larger length would be silently truncated, leaving a
  // header that disagrees with the bytes actually appended, so reject the record instead.
  constexpr size_t kMaxItemFieldLen = static_cast<size_t>(std::numeric_limits<uint32_t>::max());
  if ((op_type_str.size() > kMaxItemFieldLen) || (buffer.size() > kMaxItemFieldLen)) {
    GELOGE(FAILED, "[CUSTOM OP] custom op item exceeds uint32 length limit, op_type:%s, name len:%zu, bin size:%zu",
           op_type_str.c_str(), op_type_str.size(), buffer.size());
    return FAILED;
  }
  // Value-initialize so any header member not assigned below is zero rather than indeterminate;
  // the whole struct is copied into the output byte-for-byte.
  CustomKernelItemHeader header{};
  header.magic = kCustomKernelItemMagic;
  header.name_len = static_cast<uint32_t>(op_type_str.size());
  header.bin_len = static_cast<uint32_t>(buffer.size());
  const auto *header_ptr = reinterpret_cast<const uint8_t *>(&header);
  merged_buffers.insert(merged_buffers.end(), header_ptr, header_ptr + sizeof(header));
  merged_buffers.insert(merged_buffers.end(), op_type_str.begin(), op_type_str.end());
  merged_buffers.insert(merged_buffers.end(), buffer.begin(), buffer.end());
  GELOGD("[CUSTOM OP] Serialized custom op '%s', bin size:%zu", op_type_str.c_str(), buffer.size());
  return SUCCESS;
}
/**
 * @brief Serializes every custom op used by ge_root_model and stores the result as the CUSTOM_OPS
 *        partition of the OM file.
 *
 * The op types come from CollectUsedCustomOpTypes(). Each type is resolved through
 * CustomOpFactory::CreateOrGetCustomOp(). Ops that are PortableOp instances are serialized in the
 * iteration order of a std::set (sorted by op type) via SerializeCustomOpKernel().
 *
 * @param om_file_save_helper  helper that receives the partition; only dereferenced when there is data to save.
 * @param ge_root_model        model whose graphs are scanned for custom ops.
 * @return SUCCESS when the partition was added or there was nothing to save. Returns FAILED when an op cannot
 *         be created, when serializable and non-serializable custom ops are mixed, or when the save helper is
 *         null while data must be saved. Otherwise returns the failing status of a nested step.
 */
Status ModelHelper::SaveCustomOpsPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper,
                                           const GeRootModelPtr &ge_root_model) const {
  std::set<std::string> used_custom_op_types;
  GE_ASSERT_SUCCESS(CollectUsedCustomOpTypes(ge_root_model, used_custom_op_types));
  if (used_custom_op_types.empty()) {
    GELOGI("[CUSTOM OP] No custom ops used in graph, skip saving custom kernels partition.");
    return SUCCESS;
  }
  // Resolve all used op types first. A model may use only serializable (PortableOp) or only
  // non-serializable custom ops; a mix is rejected before any serialization happens.
  bool has_serializable_custom_op = false;
  bool has_non_serializable_custom_op = false;
  std::vector<std::pair<std::string, PortableOp *>> serializable_ops;
  serializable_ops.reserve(used_custom_op_types.size());
  for (const auto &op_type_str : used_custom_op_types) {
    auto op = CustomOpFactory::CreateOrGetCustomOp(AscendString(op_type_str.c_str()));
    if (op == nullptr) {
      GELOGE(FAILED, "[CUSTOM OP] create custom op failed, op_type:%s", op_type_str.c_str());
      return FAILED;
    }
    auto *serializable_op = dynamic_cast<PortableOp *>(op);
    if (serializable_op == nullptr) {
      has_non_serializable_custom_op = true;
    } else {
      has_serializable_custom_op = true;
      serializable_ops.emplace_back(op_type_str, serializable_op);
    }
    if (has_serializable_custom_op && has_non_serializable_custom_op) {
      GELOGE(FAILED,
             "[CUSTOM OP] graph contains both serializable and non-serializable custom ops, conflict found at "
             "op_type:%s.",
             op_type_str.c_str());
      return FAILED;
    }
  }
  std::vector<uint8_t> merged_buffers;
  for (const auto &serializable_op : serializable_ops) {
    GE_ASSERT_SUCCESS(SerializeCustomOpKernel(serializable_op.second, serializable_op.first, merged_buffers));
  }
  if (merged_buffers.empty()) {
    GELOGI("[CUSTOM OP] no custom ops serialized, skip saving custom_ops partition.");
    return SUCCESS;
  }
  // Checked only here so the "nothing to save" paths above keep returning SUCCESS as before.
  if (om_file_save_helper == nullptr) {
    GELOGE(FAILED, "[CUSTOM OP] om file save helper is null, cannot save custom_ops partition.");
    return FAILED;
  }
  GELOGI("[CUSTOM OP] custom ops partition size:%zu", merged_buffers.size());
  return om_file_save_helper->AddOwnedPartition(ModelPartitionType::CUSTOM_OPS, std::move(merged_buffers), 0U);
}
/**
 * @brief Hands the CUSTOM_OPS partition of a loaded OM file to CustomOpFactory::LoadCustomOpsPartition().
 *
 * A missing partition (GetModelPartition() not returning SUCCESS) or an empty one (null data or zero size)
 * is not an error; it is logged and skipped.
 *
 * @param om_load_helper  helper providing access to the OM file partitions.
 * @return SUCCESS when the partition was loaded or skipped. If LoadCustomOpsPartition() fails, the failure
 *         is propagated through GE_CHK_STATUS_RET.
 */
Status ModelHelper::LoadCustomOps(const OmFileLoadHelper &om_load_helper) const {
  ModelPartition partition;
  const bool partition_found =
      (om_load_helper.GetModelPartition(ModelPartitionType::CUSTOM_OPS, partition, 0U) == SUCCESS);
  if (!partition_found) {
    GELOGI("[CUSTOM OP] custom ops partition not found, skip load.");
    return SUCCESS;
  }

  const bool partition_empty = (partition.data == nullptr) || (partition.size == 0U);
  if (partition_empty) {
    GELOGI("[CUSTOM OP] custom ops partition is empty, skip load.");
    return SUCCESS;
  }

  GE_CHK_STATUS_RET(CustomOpFactory::LoadCustomOpsPartition(partition.data, partition.size),
                    "[CUSTOM OP] Load custom ops partition failed.");
  return SUCCESS;
}
}