* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "rts_graph_optimizer.h"
#include <string>
#include "securec.h"
#include "common/util.h"
#include "common/constant/constant.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/ge_context.h"
#include "external/ge/ge_api_types.h"
#include "ops_kernel_store/op/op_factory.h"
#include "proto/task.pb.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "common/util/trace_manager/trace_manager.h"
#include "common/util/log.h"
// Error code returned by CheckSupportedOP / CheckConditionOPAndGetInputNodeNum when the op type
// is not found in the corresponding lookup table.
#define RT_ERROR_INVALID_VALUE 0x07110001
// Memory type used to pad the input/output memory type lists for entries that are not overridden.
const uint32_t RT_MEMORY_DEFAULT = 0x0U;
// Memory type written at the selected index by SetMemTypeInputRange / SetMemTypeOutputRange.
// NOTE(review): "TS" presumably stands for task-scheduler memory - confirm against the runtime headers.
const uint32_t RT_MEMORY_TS = 0x40U;
using namespace ge;
namespace cce {
namespace runtime {
using domi::TaskDef;
using std::string;
using std::vector;
// Op types accepted by CheckSupportedOP; OptimizeGraphPrepare sets the "_format_agnostic"
// attribute on every direct node whose type is listed here.
const std::vector<std::string> RTS_SUPPORTED_OP_TYPE = {
"MemcpyAsync",
"MemcpyAddrAsync",
"Cmo",
};
// Condition op types handled by ProcConditionNode. The mapped value is the number of leading
// data inputs (indices 0 .. value-1) in front of which a MemcpyAsync node is inserted.
const std::map<std::string, uint32_t> RTS_SUPPORTED_CONDITION_OP_TYPE = {
{"StreamSwitchN", 1}, {"StreamSwitch", 2}, {"LabelSwitchByIndex", 1}};
// Op types on which SetMemTypeRange(graph) sets ATTR_NAME_MEMORY_TYPE_RANGE to true.
const std::vector<std::string> MEM_TYPE_RANGE_NODES = {"MemcpyAddrAsync", "LabelSwitchByIndex", "LabelGoto",
"LabelGotoEx", "StreamSwitchN"};
// Constructor: performs no work of its own.
RtsGraphOptimizer::RtsGraphOptimizer() = default;
// Destructor: performs no work of its own.
RtsGraphOptimizer::~RtsGraphOptimizer() = default;
// Initializes the optimizer. Both arguments are ignored and SUCCESS is always returned.
ge::Status RtsGraphOptimizer::Initialize(const map<std::string, std::string> &options,
                                         ge::OptimizeUtility *const optimizeUtility) {
  static_cast<void>(options);
  static_cast<void>(optimizeUtility);
  return SUCCESS;
}
// Finalizes the optimizer. Nothing is released here; SUCCESS is always returned.
ge::Status RtsGraphOptimizer::Finalize() {
  const ge::Status result = SUCCESS;
  return result;
}
// Marks every direct node of `graph` whose op type is accepted by CheckSupportedOP with the
// "_format_agnostic" attribute (value RTS_FORMAT_PAIRED_INPUT_OUTPUT).
//
// Null nodes and nodes without an op desc are skipped with a warning.
// Returns INTERNAL_ERROR as soon as setting the attribute fails on a node, SUCCESS otherwise.
ge::Status RtsGraphOptimizer::OptimizeGraphPrepare(ComputeGraph &graph) {
  RTS_LOGI("start rts graph optimizer prepare");
  TraceOwnerGuard guard("RTS", "OptPrepare", graph.GetName());
  // The attribute name does not change between nodes, so build it once.
  const string formatAgnosticAttr = "_format_agnostic";
  // The graph is not modified inside the loop, so nodes can be visited by reference.
  for (const auto &nodePtr : graph.GetDirectNode()) {
    if (!nodePtr) {
      RTS_LOGW("OptimizeGraphPrepare: null node exits");
      continue;
    }
    const std::string nodeName = nodePtr->GetName();
    const auto opDescPtr = nodePtr->GetOpDesc();
    if (!opDescPtr) {
      RTS_LOGW("OptimizeGraphPrepare: desc of node[%s] is null", nodeName.c_str());
      continue;
    }
    if (CheckSupportedOP(opDescPtr->GetType()) != RT_ERROR_NONE) {
      continue;
    }
    if (!AttrUtils::SetInt(opDescPtr, formatAgnosticAttr, RTS_FORMAT_PAIRED_INPUT_OUTPUT)) {
      RTS_LOGE("graph[%s]: node [%s] op type [%s] set format agnostic failed", graph.GetName().c_str(),
               nodeName.c_str(), opDescPtr->GetType().c_str());
      return INTERNAL_ERROR;
    }
    // Report success only once the attribute has actually been written.
    RTS_LOGD("graph[%s]: node [%s] op type [%s] format agnostic value is set", graph.GetName().c_str(),
             nodeName.c_str(), opDescPtr->GetType().c_str());
  }
  RTS_LOGD("end rts graph optimizer prepare");
  return SUCCESS;
}
// Whole-graph optimization hook. The graph is left untouched and SUCCESS is always returned.
ge::Status RtsGraphOptimizer::OptimizeWholeGraph(ComputeGraph &graph) {
  static_cast<void>(graph);
  return SUCCESS;
}
// Fills `attrs` with this optimizer's engine name (RTS_ENGINE_NAME) and scope (ge::UNIT).
// Always returns SUCCESS.
ge::Status RtsGraphOptimizer::GetAttributes(GraphOptimizerAttribute &attrs) const {
  attrs.engineName = RTS_ENGINE_NAME;
  attrs.scope = ge::UNIT;
  return SUCCESS;
}
// Original-graph optimization hook. The graph is left untouched and SUCCESS is always returned.
ge::Status RtsGraphOptimizer::OptimizeOriginalGraph(ComputeGraph &graph) {
  static_cast<void>(graph);
  return SUCCESS;
}
// Fused-graph optimization hook. The graph is left untouched and SUCCESS is always returned.
ge::Status RtsGraphOptimizer::OptimizeFusedGraph(ComputeGraph &graph) {
  static_cast<void>(graph);
  return SUCCESS;
}
// Pre-build optimization hook: delegates to PorcMemtypeRange and forwards its status.
ge::Status RtsGraphOptimizer::OptimizeGraphBeforeBuild(ComputeGraph &graph) {
  RTS_LOGD("Enter OptimizeGraphBeforeBuild.");
  const ge::Status result = PorcMemtypeRange(graph);
  return result;
}
// Returns RT_ERROR_NONE when `sCollectiveType` is listed in RTS_SUPPORTED_OP_TYPE,
// RT_ERROR_INVALID_VALUE otherwise.
rtError_t RtsGraphOptimizer::CheckSupportedOP(const std::string &sCollectiveType) {
  for (const std::string &supportedType : RTS_SUPPORTED_OP_TYPE) {
    if (supportedType == sCollectiveType) {
      return RT_ERROR_NONE;
    }
  }
  return RT_ERROR_INVALID_VALUE;
}
// Applies the memory-type-range marking and the condition-node processing to `graph` and to
// every graph returned by graph.GetAllSubgraphs(), but only when the SoC version string starts
// with one of the listed prefixes. For any other SoC version the graph is left unchanged.
//
// Returns FAILED when the SoC version cannot be read or any step fails, SUCCESS otherwise.
ge::Status RtsGraphOptimizer::PorcMemtypeRange(ge::ComputeGraph &graph) {
  const int32_t kSocVersionLen = 50;
  char_t socVersion[kSocVersionLen] = {0};
  auto status = GetSocVersion(socVersion, kSocVersionLen);
  if (status != SUCCESS) {
    RTS_LOGE("GetSocVersion failed, graphId=%u.", graph.GetGraphID());
    return FAILED;
  }

  // SoC version prefixes for which the processing below is required.
  const char *const kTargetSocPrefixes[] = {"Ascend610", "Ascend310P1", "Ascend310P3",
                                            "BS9SX1AA",  "BS9SX1AB",    "BS9SX1AC"};
  bool isTargetSoc = false;
  for (const char *prefix : kTargetSocPrefixes) {
    if (strncmp(socVersion, prefix, strlen(prefix)) == 0) {
      isTargetSoc = true;
      break;
    }
  }
  if (!isTargetSoc) {
    RTS_LOGI("Soc version is [%s], no need to process the condition node", socVersion);
    return SUCCESS;
  }
  RTS_LOGI("Soc version is [%s]", socVersion);

  // Step 1: mark memory type range on the root graph, then on each subgraph.
  TraceOwnerGuard guard1("RTS", "SetMemTypeRange", graph.GetName());
  status = SetMemTypeRange(graph);
  if (status != SUCCESS) {
    RTS_LOGE("Set mem type range of root graph:%s failed, retCode=%#x.", graph.GetName().c_str(), status);
    return FAILED;
  }
  for (auto &subGraph : graph.GetAllSubgraphs()) {
    status = SetMemTypeRange(*subGraph);
    if (status != SUCCESS) {
      RTS_LOGE("Set mem type range of graph:%s failed, retCode=%#x", subGraph->GetName().c_str(), status);
      return FAILED;
    }
  }

  // Step 2: process condition nodes on the root graph, then on each subgraph.
  TraceOwnerGuard guard2("RTS", "ProcConditionNode", graph.GetName());
  status = ProcConditionNode(graph);
  if (status != SUCCESS) {
    RTS_LOGE("Process condition node of root graph:%s failed, retCode=%#x.", graph.GetName().c_str(), status);
    return FAILED;
  }
  for (auto &subGraph : graph.GetAllSubgraphs()) {
    status = ProcConditionNode(*subGraph);
    if (status != SUCCESS) {
      RTS_LOGE("Process condition node of graph:%s failed, retCode=%#x.", subGraph->GetName().c_str(), status);
      return FAILED;
    }
  }
  return SUCCESS;
}
// For every direct node of `graph` whose op type is listed in RTS_SUPPORTED_CONDITION_OP_TYPE,
// inserts MemcpyAsync nodes in front of its leading inputs via InsertMemcpyAsyncNodeAndSetMemType.
//
// Graphs flagged as unknown are skipped. Null nodes and nodes without an op desc are skipped
// with a warning. Returns FAILED on the first insertion failure, SUCCESS otherwise.
ge::Status RtsGraphOptimizer::ProcConditionNode(ge::ComputeGraph &graph) {
  if (graph.GetGraphUnknownFlag()) {
    RTS_LOGD("Graph[%s] is unknown graph, skip.", graph.GetName().c_str());
    return SUCCESS;
  }
  // The node pointer is intentionally taken by value: nodes are added to the graph inside the loop.
  for (auto currentNode : graph.GetDirectNode()) {
    if (currentNode == nullptr) {
      RTS_LOGW("Node ptr is null.");
      continue;
    }
    const auto currentDesc = currentNode->GetOpDesc();
    if (currentDesc == nullptr) {
      RTS_LOGW("Desc of node[%s] is null", currentNode->GetName().c_str());
      continue;
    }
    uint32_t inputNodeNum = 0;
    const bool isConditionOp =
        (CheckConditionOPAndGetInputNodeNum(currentDesc->GetType(), &inputNodeNum) == RT_ERROR_NONE);
    if (!isConditionOp) {
      continue;
    }
    const auto insertStatus = InsertMemcpyAsyncNodeAndSetMemType(currentNode, graph, inputNodeNum);
    if (insertStatus != SUCCESS) {
      RTS_LOGE("Insert memcpy node failed, inputNodeNum [%u], retCode=%u", inputNodeNum, insertStatus);
      return FAILED;
    }
  }
  return SUCCESS;
}
// Looks up `sCollectiveType` in RTS_SUPPORTED_CONDITION_OP_TYPE.
//
// On a hit, stores the mapped input count in `*inputNum` and returns RT_ERROR_NONE.
// Returns RT_ERROR_INVALID_VALUE when `inputNum` is null or the op type is not in the table;
// `*inputNum` is left unmodified in that case.
rtError_t RtsGraphOptimizer::CheckConditionOPAndGetInputNodeNum(const std::string &sCollectiveType,
                                                                uint32_t *inputNum) {
  // Guard the output pointer before it is written through.
  if (inputNum == nullptr) {
    RTS_LOGE("Output pointer inputNum is null.");
    return RT_ERROR_INVALID_VALUE;
  }
  const auto search = RTS_SUPPORTED_CONDITION_OP_TYPE.find(sCollectiveType);
  if (search != RTS_SUPPORTED_CONDITION_OP_TYPE.end()) {
    *inputNum = search->second;
    return RT_ERROR_NONE;
  }
  RTS_LOGD("Can not find key:%s in op type map.", sCollectiveType.c_str());
  return RT_ERROR_INVALID_VALUE;
}
// Splices `memcpyAsyncNode` between input `index` of `nexNode` and that input's producer:
//   producer -> nexNode[index]   becomes   producer -> memcpyAsyncNode -> nexNode[index]
//
// If `nexNode` carries ATTR_NAME_STREAM_LABEL, the label is copied to `memcpyAsyncOpDesc`.
// All incoming control edges of `nexNode` are moved onto `memcpyAsyncNode`.
//
// Returns FAILED when an argument or anchor is null, or when any attribute or edge operation
// fails. Edges changed before the failure are not restored.
ge::Status RtsGraphOptimizer::InsertMemcpyAsyncNodeFunc(const ge::NodePtr &nexNode, ge::NodePtr &memcpyAsyncNode,
                                                        ge::OpDescPtr &memcpyAsyncOpDesc, uint32_t index) {
  // Both nodes are dereferenced below, so reject null arguments up front.
  if (nexNode == nullptr) {
    RTS_REPORT_CALL_ERROR("Next node is null.");
    return FAILED;
  }
  if (memcpyAsyncNode == nullptr) {
    RTS_REPORT_CALL_ERROR("MemcpyAsync node is null.");
    return FAILED;
  }
  auto inDataAnchor = nexNode->GetInDataAnchor(index);
  if (inDataAnchor == nullptr) {
    RTS_REPORT_CALL_ERROR("First in data anchor is null.");
    return FAILED;
  }
  auto srcOutAnchor = inDataAnchor->GetPeerOutAnchor();
  if (srcOutAnchor == nullptr) {
    RTS_REPORT_CALL_ERROR("Source out anchor is null.");
    return FAILED;
  }
  auto nexDesc = nexNode->GetOpDesc();
  if (nexDesc == nullptr) {
    RTS_REPORT_CALL_ERROR("Next op desc is null.");
    return FAILED;
  }
  // Copy the stream label of the consumer to the inserted node, when the consumer has one.
  if (nexDesc->HasAttr(ATTR_NAME_STREAM_LABEL)) {
    // The op desc is only dereferenced on this path, so it is validated here.
    if (memcpyAsyncOpDesc == nullptr) {
      RTS_REPORT_CALL_ERROR("MemcpyAsync op desc is null.");
      return FAILED;
    }
    std::string nextIterationName;
    if (!AttrUtils::GetStr(nexDesc, ATTR_NAME_STREAM_LABEL, nextIterationName)) {
      RTS_REPORT_CALL_ERROR("Get next op[%s] attribute[ATTR_NAME_STREAM_LABEL] failed.", nexDesc->GetName().c_str());
      return FAILED;
    }
    if (!AttrUtils::SetStr(memcpyAsyncOpDesc, ATTR_NAME_STREAM_LABEL, nextIterationName)) {
      RTS_REPORT_CALL_ERROR("Set memcpyAsyncOpDesc[%s] attribute[ATTR_NAME_STREAM_LABEL] failed.",
                            memcpyAsyncOpDesc->GetName().c_str());
      return FAILED;
    }
  }
  // Detach the producer from the consumer input.
  ge::Status status = GraphUtils::RemoveEdge(srcOutAnchor, inDataAnchor);
  if (status != SUCCESS) {
    RTS_REPORT_CALL_ERROR(
        "Graph remove edge failed, src index[%d], dst index[%d], dst node [%s],"
        "retCode=%u.",
        srcOutAnchor->GetIdx(), inDataAnchor->GetIdx(), inDataAnchor->GetOwnerNode()->GetName().c_str(), status);
    return FAILED;
  }
  // Connect producer -> memcpyAsync input 0.
  status = GraphUtils::AddEdge(srcOutAnchor, memcpyAsyncNode->GetInDataAnchor(0));
  if (status != SUCCESS) {
    RTS_REPORT_CALL_ERROR(
        "Graph add edge failed, src index[%d], dst index[%d], dst node [%s],"
        "retCode=%u.",
        srcOutAnchor->GetIdx(), inDataAnchor->GetIdx(), inDataAnchor->GetOwnerNode()->GetName().c_str(), status);
    return FAILED;
  }
  // Connect memcpyAsync output 0 -> consumer input.
  status = GraphUtils::AddEdge(memcpyAsyncNode->GetOutDataAnchor(0), inDataAnchor);
  if (status != SUCCESS) {
    RTS_REPORT_CALL_ERROR(
        "Graph add edge failed, src index[%d], dst index[%d], dst node [%s],"
        "retCode=%u.",
        srcOutAnchor->GetIdx(), inDataAnchor->GetIdx(), inDataAnchor->GetOwnerNode()->GetName().c_str(), status);
    return FAILED;
  }
  // Move every incoming control edge of the consumer onto the inserted node.
  for (const NodePtr &inNode : nexNode->GetInControlNodes()) {
    status = GraphUtils::RemoveEdge(inNode->GetOutControlAnchor(), nexNode->GetInControlAnchor());
    if (status != SUCCESS) {
      RTS_REPORT_CALL_ERROR(
          "Graph remove edge failed, src index[%d], dst index[%d], dst node [%s],"
          "retCode=%u.",
          srcOutAnchor->GetIdx(), inDataAnchor->GetIdx(), inDataAnchor->GetOwnerNode()->GetName().c_str(), status);
      return FAILED;
    }
    status = GraphUtils::AddEdge(inNode->GetOutControlAnchor(), memcpyAsyncNode->GetInControlAnchor());
    if (status != SUCCESS) {
      RTS_REPORT_CALL_ERROR(
          "Graph add edge failed, src index[%d], dst index[%d], dst node [%s],"
          "retCode=%u.",
          srcOutAnchor->GetIdx(), inDataAnchor->GetIdx(), inDataAnchor->GetOwnerNode()->GetName().c_str(), status);
      return FAILED;
    }
  }
  RTS_LOGI("Insert memcpyAsync op success, src node: %s, dst node: %s", srcOutAnchor->GetOwnerNode()->GetName().c_str(),
           nexNode->GetName().c_str());
  return SUCCESS;
}
// For each input index in [0, inputNodeNum) of `nexNode`: creates a MemcpyAsync op, adds it to
// `graph`, splices it in front of that input, then sets RT_MEMORY_TS on output 0 of the new
// node and on input `index` of `nexNode`.
//
// When FEATURE_GE_API is defined the loop body is compiled out and only the null check remains.
// Returns FAILED when `nexNode` is null or any step fails, SUCCESS otherwise.
ge::Status RtsGraphOptimizer::InsertMemcpyAsyncNodeAndSetMemType(const ge::NodePtr &nexNode, ge::ComputeGraph &graph,
                                                                 uint32_t inputNodeNum) {
  if (nexNode == nullptr) {
    RTS_LOGW("Next node is null.");
    return FAILED;
  }
  for (uint32_t inputIdx = 0; inputIdx < inputNodeNum; ++inputIdx) {
#ifndef FEATURE_GE_API
    ge::OpDescPtr copyDesc = CreateMemcpyAsyncOpByIndex(nexNode, inputIdx);
    if (copyDesc == nullptr) {
      RTS_LOGE("Create memcpy async op failed.");
      return FAILED;
    }
    ge::NodePtr copyNode = graph.AddNode(copyDesc);
    if (copyNode == nullptr) {
      RTS_REPORT_CALL_ERROR("Insert memcpy node failed.");
      return FAILED;
    }

    // Propagate the label-node flag from the consumer, when the consumer carries it.
    bool labelFlag = false;
    const bool hasLabelFlag = AttrUtils::GetBool(nexNode->GetOpDesc(), ATTR_NAME_RTS_LABEL_NODE, labelFlag);
    if (hasLabelFlag && !AttrUtils::SetBool(copyDesc, ATTR_NAME_RTS_LABEL_NODE, labelFlag)) {
      RTS_REPORT_CALL_ERROR("Node [%s] set ATTR_NAME_RTS_LABEL_NODE failed", nexNode->GetName().c_str());
      return FAILED;
    }

    const ge::Status spliceStatus = InsertMemcpyAsyncNodeFunc(nexNode, copyNode, copyDesc, inputIdx);
    if (spliceStatus != SUCCESS) {
      RTS_LOGE("Graph add memcpyAsync node failed, retCode=%u.", spliceStatus);
      return FAILED;
    }

    const ge::Status outputStatus = SetMemTypeOutputRange(copyNode, 0);
    if (outputStatus != SUCCESS) {
      RTS_REPORT_CALL_ERROR("Set mem type output range attr failed, node [%s], retCode=%u.",
                            copyNode->GetName().c_str(), outputStatus);
      return FAILED;
    }

    const ge::Status inputStatus = SetMemTypeInputRange(nexNode, inputIdx);
    if (inputStatus != SUCCESS) {
      RTS_REPORT_CALL_ERROR("Set mem type input range attr failed, node [%s], retCode=%u.", nexNode->GetName().c_str(),
                            inputStatus);
      return FAILED;
    }
#endif
  }
  return SUCCESS;
}
// Creates a "MemcpyAsync" op desc named "<nexNode name><index>_MemcpyAsync" whose single input
// and single output both use the tensor desc of input `index` of `nexNode`.
//
// Returns nullptr when `nexNode` or its op desc is null, when allocation throws, when `index`
// is not below the node's input count, or when adding a tensor desc fails.
ge::OpDescPtr RtsGraphOptimizer::CreateMemcpyAsyncOpByIndex(const ge::NodePtr &nexNode, uint32_t index) {
  if (nexNode == nullptr) {
    RTS_REPORT_CALL_ERROR("Node is null.");
    return nullptr;
  }
  string nodeName = nexNode->GetName() + std::to_string(index) + "_MemcpyAsync";
  OpDescPtr opDesc;
  try {
    opDesc = std::make_shared<OpDesc>(nodeName.c_str(), "MemcpyAsync");
  } catch (const std::bad_alloc &) {
    RTS_REPORT_INNER_ERROR("Make shared OpDesc failed, bad_alloc raised, node name=%s.", nodeName.c_str());
    return nullptr;
  } catch (...) {
    RTS_REPORT_INNER_ERROR("Make shared OpDesc failed, other exception raised, node name=%s.", nodeName.c_str());
    return nullptr;
  }
  RTS_LOGI("Create memcpyAsync op [%s] success.", opDesc->GetName().c_str());
  ge::OpDescPtr nextNodeOpDesc = nexNode->GetOpDesc();
  if (nextNodeOpDesc == nullptr) {
    RTS_REPORT_CALL_ERROR("OpDesc of next node is invalid, node name=%s.", nodeName.c_str());
    return nullptr;
  }
  const size_t inputSize = nextNodeOpDesc->GetInputsSize();
  if (inputSize == 0) {
    RTS_REPORT_CALL_ERROR("The input size of node [%s] is 0.", nexNode->GetName().c_str());
    return nullptr;
  }
  // Same bound as SetMemTypeInputRange applies later in the flow; rejecting the index here
  // fails before the graph is modified instead of after the new node has been wired in.
  if (index >= inputSize) {
    RTS_REPORT_CALL_ERROR("Index[%u] is out of input range[%zu] of node [%s].", index, inputSize,
                          nexNode->GetName().c_str());
    return nullptr;
  }
  // The copy is shape/type transparent: input and output share the consumer's input desc.
  auto ret = opDesc->AddInputDesc(nextNodeOpDesc->GetInputDesc(index));
  if (ret != GRAPH_SUCCESS) {
    RTS_REPORT_CALL_ERROR("Create memcpyAsync op [%s], add input desc failed, retCode=%u.", nodeName.c_str(), ret);
    return nullptr;
  }
  ret = opDesc->AddOutputDesc(nextNodeOpDesc->GetInputDesc(index));
  if (ret != GRAPH_SUCCESS) {
    RTS_REPORT_CALL_ERROR("Create memcpyAsync op[%s], add output desc failed, retCode=%u.", nodeName.c_str(), ret);
    return nullptr;
  }
  return opDesc;
}
// Sets entry `index` of the node's ATTR_NAME_INPUT_MEM_TYPE_LIST attribute to RT_MEMORY_TS.
// An existing list is reused; it is padded with RT_MEMORY_DEFAULT up to the input count first.
//
// Returns FAILED when `node` or its op desc is null, when `index` is not below the input
// count, or when writing the attribute fails; SUCCESS otherwise.
ge::Status RtsGraphOptimizer::SetMemTypeInputRange(const ge::NodePtr &node, uint32_t index) {
  if (node == nullptr) {
    RTS_LOGW("Node is null.");
    return FAILED;
  }
  const auto desc = node->GetOpDesc();
  if (desc == nullptr) {
    RTS_LOGW("Desc of node[%s] is null", node->GetName().c_str());
    return FAILED;
  }
  const size_t inputCount = desc->GetInputsSize();
  if (index >= inputCount) {
    RTS_LOGW("Index[%u] large than input size[%zu]", index, inputCount);
    return FAILED;
  }

  // A missing attribute simply leaves the list empty, hence the ignored result.
  vector<int64_t> memTypes;
  (void)ge::AttrUtils::GetListInt(desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, memTypes);
  if (memTypes.size() < inputCount) {
    memTypes.resize(inputCount, static_cast<int64_t>(RT_MEMORY_DEFAULT));
  }
  memTypes[index] = static_cast<int64_t>(RT_MEMORY_TS);

  if (!ge::AttrUtils::SetListInt(desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, memTypes)) {
    RTS_REPORT_CALL_ERROR("Node [%s] set ATTR_NAME_INPUT_MEM_TYPE_LIST failed", node->GetName().c_str());
    return FAILED;
  }
  RTS_LOGD("Node [%s] set ATTR_NAME_INPUT_MEM_TYPE_LIST success.", node->GetName().c_str());
  return SUCCESS;
}
// Sets entry `index` of the node's ATTR_NAME_OUTPUT_MEM_TYPE_LIST attribute to RT_MEMORY_TS.
// An existing list is reused; it is padded with RT_MEMORY_DEFAULT up to the output count first.
//
// Returns FAILED when `node` or its op desc is null, when `index` is not below the output
// count, or when writing the attribute fails; SUCCESS otherwise.
ge::Status RtsGraphOptimizer::SetMemTypeOutputRange(const ge::NodePtr &node, uint32_t index) {
  if (node == nullptr) {
    RTS_LOGW("Node is null.");
    return FAILED;
  }
  const auto desc = node->GetOpDesc();
  if (desc == nullptr) {
    RTS_LOGW("Desc of node[%s] is null", node->GetName().c_str());
    return FAILED;
  }
  const size_t outputCount = desc->GetOutputsSize();
  if (index >= outputCount) {
    RTS_LOGW("Index[%u] large than output size[%zu]", index, outputCount);
    return FAILED;
  }

  // A missing attribute simply leaves the list empty, hence the ignored result.
  vector<int64_t> memTypes;
  (void)ge::AttrUtils::GetListInt(desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memTypes);
  if (memTypes.size() < outputCount) {
    memTypes.resize(outputCount, static_cast<int64_t>(RT_MEMORY_DEFAULT));
  }
  memTypes[index] = static_cast<int64_t>(RT_MEMORY_TS);

  if (!ge::AttrUtils::SetListInt(desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memTypes)) {
    RTS_REPORT_CALL_ERROR("Node [%s] set ATTR_NAME_OUTPUT_MEM_TYPE_LIST failed", node->GetName().c_str());
    return FAILED;
  }
  RTS_LOGD("Node [%s] set ATTR_NAME_OUTPUT_MEM_TYPE_LIST success.", node->GetName().c_str());
  return SUCCESS;
}
// Sets ATTR_NAME_MEMORY_TYPE_RANGE to true on the op desc of `node`.
// Returns FAILED when `node` or its op desc is null or the attribute cannot be written.
ge::Status RtsGraphOptimizer::SetMemTypeRange(const ge::NodePtr &node) {
  if (node == nullptr) {
    RTS_LOGW("Node is null.");
    return FAILED;
  }
  const auto desc = node->GetOpDesc();
  if (desc == nullptr) {
    RTS_LOGW("Desc of node[%s] is null", node->GetName().c_str());
    return FAILED;
  }
  const bool attrWritten = AttrUtils::SetBool(desc, ge::ATTR_NAME_MEMORY_TYPE_RANGE, true);
  if (!attrWritten) {
    RTS_REPORT_CALL_ERROR("Node [%s] set ATTR_NAME_MEMORY_TYPE_RANGE failed", node->GetName().c_str());
    return FAILED;
  }
  RTS_LOGD("Node [%s] set ATTR_NAME_MEMORY_TYPE_RANGE success.", node->GetName().c_str());
  return SUCCESS;
}
// Sets ATTR_NAME_MEMORY_TYPE_RANGE on every direct node of `graph` whose op type is listed in
// MEM_TYPE_RANGE_NODES.
//
// Graphs flagged as unknown are skipped. Null nodes and nodes without an op desc are skipped
// with a warning. Returns FAILED on the first node that cannot be marked, SUCCESS otherwise.
ge::Status RtsGraphOptimizer::SetMemTypeRange(ge::ComputeGraph &graph) {
  if (graph.GetGraphUnknownFlag()) {
    RTS_LOGD("Graph[%s] is unknown graph, skip.", graph.GetName().c_str());
    return SUCCESS;
  }
  for (const auto &candidate : graph.GetDirectNode()) {
    if (candidate == nullptr) {
      RTS_LOGW("Node is null.");
      continue;
    }
    const auto candidateDesc = candidate->GetOpDesc();
    if (candidateDesc == nullptr) {
      RTS_LOGW("Desc of node[%s] is null", candidate->GetName().c_str());
      continue;
    }
    const auto listedEnd = MEM_TYPE_RANGE_NODES.end();
    if (std::find(MEM_TYPE_RANGE_NODES.begin(), listedEnd, candidateDesc->GetType()) == listedEnd) {
      continue;
    }
    const auto markStatus = SetMemTypeRange(candidate);
    if (markStatus != SUCCESS) {
      RTS_REPORT_CALL_ERROR("Set mem type range attr failed, node[%s], retCode=%u.", candidate->GetName().c_str(),
                            markStatus);
      return FAILED;
    }
  }
  return SUCCESS;
}
}
}