* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "atb/operation/graph_operation.h"
#include "atb/types.h"
#include "atb/utils/log.h"
#include "atb/operation/plugin_operation.h"
#include "atb/utils/tensor_util.h"
#include "atb/utils/common_utils.h"
namespace atb {
// Largest number of nodes CheckGraphParam accepts in one GraphParam.
constexpr size_t MAX_NODE_NUM = 1024;
// Largest GraphParam::name length (std::string::length(), i.e. bytes) accepted by IsValidGraphName.
constexpr size_t MAX_GRAPH_NAME_LEN = 128;
static bool CheckNodeTensorNum(const Node &node, size_t nodeId)
{
if (!node.operation) {
ATB_LOG(ERROR) << " node[" << nodeId << "].operation is null";
return false;
}
if (node.operation->GetInputNum() != node.inTensorIds.size()) {
ATB_LOG(ERROR) << " node[" << nodeId << "].inTensorIds.size: " << node.inTensorIds.size()
<< " != operation.inputNum: " << node.operation->GetInputNum();
return false;
}
if (node.operation->GetOutputNum() != node.outTensorIds.size()) {
ATB_LOG(ERROR) << " node[" << nodeId << "].outTensorIds.size: " << node.outTensorIds.size()
<< " != operation.outputNum: " << node.operation->GetOutputNum();
return false;
}
return true;
}
static bool CheckNode(const Node &node, const size_t &nodeId, const uint64_t &totalTensorNum,
std::vector<bool> &tensorIsValued, std::vector<bool> &tensorIsUsed)
{
if (!CheckNodeTensorNum(node, nodeId)) {
return false;
}
for (size_t i = 0; i < node.inTensorIds.size(); ++i) {
uint32_t tensorId = node.inTensorIds.at(i);
if (tensorId >= totalTensorNum) {
ATB_LOG(ERROR) << " node[" << nodeId << "].inTensorIds[" << i << "]: " << tensorId
<< " is invalid, need less than:" << totalTensorNum;
return false;
}
if (!tensorIsValued.at(tensorId)) {
ATB_LOG(ERROR) << " node[" << nodeId << "].inTensorIds[" << i << "]: " << tensorId
<< " is not assigned value yet, please check your graph.";
return false;
}
tensorIsUsed.at(tensorId) = true;
}
for (size_t i = 0; i < node.outTensorIds.size(); ++i) {
uint32_t tensorId = node.outTensorIds.at(i);
if (tensorId >= totalTensorNum) {
ATB_LOG(ERROR) << " node[" << nodeId << "].outTensorIds[" << i << "]: " << tensorId
<< " is invalid, need less than: " << totalTensorNum;
return false;
}
if (tensorIsValued.at(tensorId)) {
bool writeInPlaceFlag = false;
for (size_t j = 0; j < node.inTensorIds.size(); ++j) {
if (node.inTensorIds.at(j) == tensorId) {
writeInPlaceFlag = true;
break;
}
}
if (writeInPlaceFlag) {
ATB_LOG(WARN) << " node[" << nodeId << "].outTensorIds[" << i << "]: " << tensorId
<< " is write in place.";
tensorIsUsed.at(tensorId) = false;
} else {
ATB_LOG(WARN) << " node[" << nodeId << "].outTensorIds[" << i << "]: " << tensorId
<< " has already been assigned value, please check your graph.";
}
} else {
tensorIsValued.at(tensorId) = true;
}
}
return true;
}
/**
 * @brief Check that a graph name is at most MAX_GRAPH_NAME_LEN bytes long and only contains
 *        ASCII letters, digits or '_' (as classified by isalnum in the current locale).
 * @param name candidate name; the empty string is accepted
 * @return true if the name is acceptable
 */
static bool IsValidGraphName(const std::string &name)
{
    if (name.length() > MAX_GRAPH_NAME_LEN) {
        return false;
    }
    for (const char c : name) {
        // isalnum() is undefined for negative arguments other than EOF, and a plain (signed) char is
        // negative for non-ASCII bytes such as UTF-8 sequences, so convert through unsigned char first.
        if (!isalnum(static_cast<unsigned char>(c)) && c != '_') {
            return false;
        }
    }
    return true;
}
static bool CheckAllTensor(const GraphParam ¶m, const std::vector<bool> &tensorIsValued,
const std::vector<bool> &tensorIsUsed)
{
uint64_t totalTensorNum = param.inTensorNum + param.outTensorNum + param.internalTensorNum;
for (size_t tensorId = param.inTensorNum; tensorId < totalTensorNum; tensorId++) {
if (!tensorIsValued.at(tensorId)) {
ATB_LOG(ERROR) << "graph tensorId: " << tensorId << " is not assigned value, please check your graph.";
return false;
}
}
for (size_t tensorId = 0; tensorId < param.inTensorNum; tensorId++) {
if (!tensorIsUsed.at(tensorId)) {
ATB_LOG(WARN) << "graph intensorId: " << tensorId << " is not used.";
}
}
uint64_t internalTensorStartIdx = param.inTensorNum + param.outTensorNum;
for (size_t tensorId = internalTensorStartIdx; tensorId < totalTensorNum; tensorId++) {
if (!tensorIsUsed.at(tensorId)) {
ATB_LOG(WARN) << "graph internal tensorId: " << tensorId << " is not used.";
}
}
return true;
}
static bool CheckGraphParam(const GraphParam ¶m)
{
ATB_LOG(INFO) << "start check " << param.name << " graph.";
if (!IsValidGraphName(param.name)) {
ATB_LOG(ERROR) << "GraphParam.name: " << param.name << " is invalid.";
return false;
}
const uint64_t totalTensorNum = param.inTensorNum + param.outTensorNum + param.internalTensorNum;
if (param.inTensorNum > MAX_SVECTOR_SIZE || param.outTensorNum > MAX_SVECTOR_SIZE ||
param.internalTensorNum > MAX_SVECTOR_SIZE) {
ATB_LOG(ERROR) << "graph intensor: " << param.inTensorNum << ", outtensor: " << param.outTensorNum
<< ", internaltensor: " << param.internalTensorNum << " need less than " << MAX_SVECTOR_SIZE;
return false;
}
if (param.nodes.size() > MAX_NODE_NUM) {
ATB_LOG(ERROR) << "graph nodes num is too large, nodes num: " << param.nodes.size()
<< " > max nodes num: " << MAX_NODE_NUM;
return false;
}
std::vector<bool> tensorIsValued(totalTensorNum, false);
std::vector<bool> tensorIsUsed(totalTensorNum, false);
for (size_t tensorId = 0; tensorId < param.inTensorNum; tensorId++) {
tensorIsValued.at(tensorId) = true;
}
for (size_t nodeId = 0; nodeId < param.nodes.size(); ++nodeId) {
const auto &node = param.nodes.at(nodeId);
if (!CheckNode(node, nodeId, totalTensorNum, tensorIsValued, tensorIsUsed)) {
ATB_LOG(ERROR) << "graph: " << param.name << " check failed.";
return false;
}
}
if (!CheckAllTensor(param, tensorIsValued, tensorIsUsed)) {
return false;
}
ATB_LOG(INFO) << "check " << param.name << " graph success.";
return true;
}
static std::string JoinInts(const SVector<uint32_t> &ids)
{
std::string ret;
for (size_t i = 0; i < ids.size(); ++i) {
if (i == 0) {
ret.append(std::to_string(ids.at(i)));
} else {
ret.append(", " + std::to_string(ids.at(i)));
}
}
return ret;
}
static std::string GraphToString(const GraphParam ¶m)
{
std::stringstream ss;
ss << "inTensorNum:" << param.inTensorNum << ", outTensorNum:" << param.outTensorNum
<< ", internalTensorNum:" << param.internalTensorNum;
for (size_t i = 0; i < param.nodes.size(); ++i) {
ss << "\nnode[" << i << "]: operation:" << param.nodes.at(i).operation << ", inTensorIds:["
<< JoinInts(param.nodes.at(i).inTensorIds) << "], outTensorIds:[" << JoinInts(param.nodes.at(i).outTensorIds)
<< "]";
}
return ss.str();
}
/**
 * @brief Factory specialization that builds a GraphOperation from a GraphParam.
 *
 * On success *operation receives a new GraphOperation. That object keeps its own copy of opParam, and its
 * destructor destroys the node operations referenced by that copy.
 *
 * @param opParam   graph description; validated with CheckGraphParam first
 * @param operation out-parameter for the created operation (must not be null)
 * @return NO_ERROR on success;
 *         ERROR_INVALID_PARAM if operation is null;
 *         ERROR_INVALID_GRAPH if the graph fails validation;
 *         ERROR_OUT_OF_HOST_MEMORY if allocation fails (*operation is then nullptr)
 */
template <> Status CreateOperation(const GraphParam &opParam, Operation **operation)
{
    if (operation == nullptr) {
        ATB_LOG(ERROR) << "invalid param, operation is null";
        return ERROR_INVALID_PARAM;
    }
    if (!CheckGraphParam(opParam)) {
        const std::string graphDump = GraphToString(opParam);
        ATB_LOG(ERROR) << "invalid param, graph param is invalid, graph:" << graphDump;
        return ERROR_INVALID_GRAPH;
    }
    ATB_LOG(INFO) << GraphToString(opParam);

    // Anonymous graphs fall back to a generic operation name.
    std::string opName = "GraphOperation";
    if (!opParam.name.empty()) {
        opName = opParam.name;
    }

    GraphOperation *graphOp = new (std::nothrow) GraphOperation(opName, opParam);
    *operation = graphOp;
    if (graphOp == nullptr) {
        ATB_LOG(ERROR) << "failed to new operation:" << opName;
        return ERROR_OUT_OF_HOST_MEMORY;
    }
    return NO_ERROR;
}
// Creates a graph operation with an empty, value-initialized GraphParam. It delegates to the two-argument
// constructor, which runs the same UsePluginOperations / InitEmptyInTensorPerms / InitEmptyOutTensorPerms
// sequence.
GraphOperation::GraphOperation(const std::string &name) : GraphOperation(name, GraphParam())
{
}
// Creates a graph operation that stores a copy of opGraph. The node operation pointers are copied, not the
// operations themselves; ~GraphOperation() destroys them.
// Call order matters: UsePluginOperations() first replaces every node operation that is not an
// OperationBase with a PluginOperation wrapper. The permission initializers below only consult nodes that
// are OperationBase, so after wrapping they see every wrapped node.
GraphOperation::GraphOperation(const std::string &name, const GraphParam &opGraph)
    : OperationBase(name), opGraph_(opGraph)
{
    UsePluginOperations();
    InitEmptyInTensorPerms();
    InitEmptyOutTensorPerms();
}
// Destroys every node operation the graph still holds and clears the pointer. This includes any
// PluginOperation wrappers that UsePluginOperations() substituted.
GraphOperation::~GraphOperation()
{
    for (auto &node : opGraph_.nodes) {
        if (node.operation == nullptr) {
            continue;
        }
        DestroyOperation(node.operation);
        node.operation = nullptr;
    }
}
// Number of graph-level input tensors, as declared in the GraphParam.
uint32_t GraphOperation::GetInputNum() const
{
    const uint32_t graphInputs = opGraph_.inTensorNum;
    return graphInputs;
}

// Number of graph-level output tensors, as declared in the GraphParam.
uint32_t GraphOperation::GetOutputNum() const
{
    const uint32_t graphOutputs = opGraph_.outTensorNum;
    return graphOutputs;
}
/**
 * @brief Infer graph output descriptors.
 *
 * If the GraphParam supplies inferShapeFunc, it is called, and every resulting descriptor must have a
 * non-zero dimNum and defined dtype/format. Any exception escaping the user callback, std::exception or
 * not, is converted into ERROR_GRAPH_INFERSHAPE_FUNC_FAIL. Without a user function, shapes are propagated
 * node by node by InferShapeImplDefault.
 *
 * @param inTensorDescs  descriptors of the graph inputs
 * @param outTensorDescs receives the descriptors of the graph outputs
 * @return NO_ERROR, the user function's error, or ERROR_GRAPH_INFERSHAPE_FUNC_FAIL
 */
Status GraphOperation::InferShapeImpl(const SVector<TensorDesc> &inTensorDescs,
                                      SVector<TensorDesc> &outTensorDescs) const
{
    if (!opGraph_.inferShapeFunc) {
        return InferShapeImplDefault(inTensorDescs, outTensorDescs);
    }

    ATB_LOG(INFO) << GetLogPrefix() << "call user infer shape func";
    try {
        Status st = opGraph_.inferShapeFunc(inTensorDescs, outTensorDescs);
        if (st != NO_ERROR) {
            ATB_LOG(ERROR) << GetLogPrefix() << "user infer shape func fail, error: " << st;
            return st;
        }
    } catch (const std::exception &e) {
        ATB_LOG(ERROR) << GetLogPrefix() << "user infer shape func throw an exception:" << e.what();
        return ERROR_GRAPH_INFERSHAPE_FUNC_FAIL;
    } catch (...) {
        // User code may throw types not derived from std::exception; do not let them escape the status API.
        ATB_LOG(ERROR) << GetLogPrefix() << "user infer shape func throw an unknown exception";
        return ERROR_GRAPH_INFERSHAPE_FUNC_FAIL;
    }

    // The user function is trusted for values but not for completeness: reject undefined descriptors.
    for (size_t i = 0; i < outTensorDescs.size(); i++) {
        const TensorDesc &desc = outTensorDescs.at(i);
        if (desc.shape.dimNum == 0 || desc.dtype == ACL_DT_UNDEFINED || desc.format == ACL_FORMAT_UNDEFINED) {
            ATB_LOG(ERROR) << GetLogPrefix() << "user infer shape result error, outTensorDesc[" << i
                           << "]: " << TensorUtil::TensorDescToString(desc);
            return ERROR_GRAPH_INFERSHAPE_FUNC_FAIL;
        }
    }
    return NO_ERROR;
}
/**
 * @brief Build a GraphRunner that mirrors this graph.
 *
 * Steps:
 * - size the runner graph's tensor and node storage from the GraphParam;
 * - build an id-indexed table of tensor pointers;
 * - create one child runner per node via CreateRunnerNode.
 *
 * @param context execution context forwarded to every child CreateRunner
 * @return the runner, or an empty pointer if any node fails
 */
std::shared_ptr<Runner> GraphOperation::CreateRunner(Context &context) const
{
    auto runner = std::make_shared<GraphRunner>(GetName() + "Runner");
    if (!runner) {
        ATB_LOG(ERROR) << GetLogPrefix() << "make_shared GraphRunner fail";
        return std::shared_ptr<Runner>();
    }

    GraphRunner::Graph &graph = runner->GetGraph();
    graph.inTensors.reserve(opGraph_.inTensorNum);
    graph.inTensors.resize(opGraph_.inTensorNum);
    graph.outTensors.reserve(opGraph_.outTensorNum);
    graph.outTensors.resize(opGraph_.outTensorNum);
    graph.internalTensors.reserve(opGraph_.internalTensorNum);
    graph.internalTensors.resize(opGraph_.internalTensorNum);
    graph.nodes.resize(opGraph_.nodes.size());

    // Pointer to every runner tensor, indexed by graph tensor id: [inputs | outputs | internals].
    std::vector<Tensor *> tensorById(opGraph_.inTensorNum + opGraph_.outTensorNum + opGraph_.internalTensorNum);
    BuildFullTensorPtrs(tensorById, graph);

    // Each child gets this operation's id path plus one trailing slot holding the node's index.
    std::vector<int64_t> childIds = operationBaseIds_;
    const size_t nodeSlot = childIds.size();
    childIds.push_back(0);
    for (size_t nodeId = 0; nodeId < graph.nodes.size(); ++nodeId) {
        childIds.at(nodeSlot) = static_cast<int64_t>(nodeId);
        if (CreateRunnerNode(nodeId, graph, childIds, tensorById, context) != NO_ERROR) {
            return std::shared_ptr<Runner>();
        }
    }

    ATB_LOG(INFO) << GetLogPrefix() << "create runner success";
    return runner;
}
/**
 * @brief Populate runnerGraph.nodes[nodeId] from opGraph_.nodes[nodeId].
 *
 * Creates the child's runner (also stored in the child's runner_) and tags it with nodeOperationIds.
 * It also copies the reshape functions and chunks, and binds the node's tensors to entries of
 * fullTensorPtrs by graph tensor id.
 *
 * @param nodeId            index of the node in both graphs
 * @param runnerGraph       runner graph being filled
 * @param nodeOperationIds  id path passed to the child runner's SetRunnerInfo
 * @param fullTensorPtrs    runner tensor pointers indexed by graph tensor id
 * @param context           execution context for the child's CreateRunner
 * @return NO_ERROR, or ERROR_INVALID_PARAM if the operation is null, not an OperationBase, or yields no runner
 */
Status GraphOperation::CreateRunnerNode(const size_t nodeId, GraphRunner::Graph &runnerGraph,
                                        std::vector<int64_t> &nodeOperationIds,
                                        const std::vector<Tensor *> &fullTensorPtrs, Context &context) const
{
    const auto &paramNode = opGraph_.nodes.at(nodeId);
    GraphRunner::Node &execNode = runnerGraph.nodes.at(nodeId);
    Operation *const childOp = paramNode.operation;
    if (childOp == nullptr) {
        ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeId << "] operation is invalid.";
        return ERROR_INVALID_PARAM;
    }

    // Non-owning handle: ~GraphOperation() destroys node operations, so the deleter does nothing.
    execNode.op.reset(childOp, [](Operation *) {});

    auto *childBase = dynamic_cast<OperationBase *>(childOp);
    if (childBase == nullptr) {
        ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeId << "] operation is not inherit from OperationBase";
        return ERROR_INVALID_PARAM;
    }
    childBase->runner_ = childBase->CreateRunner(context);
    execNode.runner = childBase->runner_;
    if (!execNode.runner) {
        ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeId << "] runner is null.";
        return ERROR_INVALID_PARAM;
    }
    execNode.runner->SetRunnerInfo(execNode.op->GetName(), nodeOperationIds);

    execNode.inTensorReshapeFuncs = paramNode.inTensorReshapeFuncs;
    execNode.inTensorChunks = paramNode.inTensorChunks;

    const size_t inCount = paramNode.inTensorIds.size();
    const size_t outCount = paramNode.outTensorIds.size();
    execNode.inTensors.reserve(inCount);
    execNode.inTensors.resize(inCount);
    execNode.outTensors.reserve(outCount);
    execNode.outTensors.resize(outCount);
    for (size_t slot = 0; slot < inCount; ++slot) {
        execNode.inTensors.at(slot) = fullTensorPtrs.at(paramNode.inTensorIds.at(slot));
    }
    for (size_t slot = 0; slot < outCount; ++slot) {
        execNode.outTensors.at(slot) = fullTensorPtrs.at(paramNode.outTensorIds.at(slot));
    }
    return NO_ERROR;
}
/**
 * @brief Give every node its operation id path, derived from this graph's operationBaseIds_ and the node index.
 * @return NO_ERROR, ERROR_INVALID_PARAM for a null or non-OperationBase node, or the child's error status
 */
Status GraphOperation::SetNodeOperationIds()
{
    for (size_t nodeIdx = 0; nodeIdx < opGraph_.nodes.size(); nodeIdx++) {
        Operation *childOp = opGraph_.nodes.at(nodeIdx).operation;
        if (childOp == nullptr) {
            ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeIdx << "] operation is invalid.";
            return ERROR_INVALID_PARAM;
        }
        auto *childBase = dynamic_cast<OperationBase *>(childOp);
        if (childBase == nullptr) {
            ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeIdx << "] operation is not inherit from OperationBase";
            return ERROR_INVALID_PARAM;
        }
        const Status st = childBase->SetOperationBaseIds(operationBaseIds_, nodeIdx);
        if (st != NO_ERROR) {
            ATB_LOG(ERROR) << GetLogPrefix() << "set graphoperation node[" << nodeIdx << "] operationBaseId fail";
            return st;
        }
    }
    return NO_ERROR;
}
void GraphOperation::UsePluginOperations()
{
for (size_t i = 0; i < opGraph_.nodes.size(); ++i) {
auto &opNode = opGraph_.nodes.at(i);
OperationBase *opBase = dynamic_cast<OperationBase *>(opNode.operation);
if (!opBase) {
Operation *oldOperation = opNode.operation;
PluginOperation *pluginOp = new PluginOperation(oldOperation);
if (!pluginOp) {
ATB_LOG(ERROR) << GetLogPrefix() << "new PluginOperation Failed!";
return;
}
ATB_LOG(INFO) << GetLogPrefix() << "node[" << i << "] operation is plugin operation, new PluginOperation";
opNode.operation = pluginOp;
}
}
}
/**
 * @brief Infer graph output shapes by running each node's InferShape in declaration order.
 *
 * A table of descriptors indexed by graph tensor id ([inputs | outputs | internals]) is seeded with the
 * graph inputs. Each node reads its inputs from the table, with optional per-input reshape, and writes its
 * outputs back into it. The graph outputs are then read from ids [inTensorNum, inTensorNum + outTensorNum).
 *
 * @param inTensorDescs  graph input descriptors; must contain exactly opGraph_.inTensorNum entries
 * @param outTensorDescs receives opGraph_.outTensorNum descriptors
 * @return NO_ERROR; ERROR_INVALID_PARAM on an input-count mismatch or a null node operation;
 *         or the first failing node's status
 */
Status GraphOperation::InferShapeImplDefault(const SVector<TensorDesc> &inTensorDescs,
                                             SVector<TensorDesc> &outTensorDescs) const
{
    ATB_LOG(INFO) << GetLogPrefix() << "InferShapeImplDefault start, inTensorDescsSize:" << inTensorDescs.size()
                  << ", outTensorDescs:" << outTensorDescs.size();
    const size_t graphInNum = opGraph_.inTensorNum;
    // Graph outputs are located by graph tensor id, so the input count must match the declared layout;
    // otherwise inputs would spill into output slots and outputs would be read from the wrong ids.
    if (inTensorDescs.size() != graphInNum) {
        ATB_LOG(ERROR) << GetLogPrefix() << "inTensorDescs.size: " << inTensorDescs.size()
                       << " != graph inTensorNum: " << graphInNum;
        return ERROR_INVALID_PARAM;
    }

    std::vector<TensorDesc> totalTensorDescs(graphInNum + opGraph_.outTensorNum + opGraph_.internalTensorNum);
    for (size_t i = 0; i < graphInNum; ++i) {
        totalTensorDescs.at(i) = inTensorDescs.at(i);
    }

    // Nodes are visited in order. For graphs built through CreateOperation, CheckGraphParam ensures each
    // node only reads tensors written by graph inputs or earlier nodes.
    for (size_t nodeId = 0; nodeId < opGraph_.nodes.size(); ++nodeId) {
        const auto &opNode = opGraph_.nodes.at(nodeId);
        if (!opNode.operation) {
            ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeId << "] operation is invalid.";
            return ERROR_INVALID_PARAM;
        }
        const uint32_t childInNum = opNode.operation->GetInputNum();
        SVector<TensorDesc> opInTensorDescs;
        opInTensorDescs.reserve(childInNum);
        opInTensorDescs.resize(childInNum);
        for (size_t i = 0; i < opNode.inTensorIds.size(); ++i) {
            opInTensorDescs.at(i) = totalTensorDescs.at(opNode.inTensorIds.at(i));
            // Optional reshape: the node sees the reshaped view; the graph table keeps the original shape.
            if (i < opNode.inTensorReshapeFuncs.size() && opNode.inTensorReshapeFuncs.at(i)) {
                Dims newShape;
                opNode.inTensorReshapeFuncs.at(i)(opInTensorDescs.at(i).shape, newShape);
                opInTensorDescs.at(i).shape = newShape;
            }
        }
        SVector<TensorDesc> opOutTensorDescs;
        const Status st = opNode.operation->InferShape(opInTensorDescs, opOutTensorDescs);
        if (st != NO_ERROR) {
            ATB_LOG(ERROR) << GetLogPrefix() << "node[" << nodeId << "] infer shape fail, error code: " << st;
            return st;
        }
        for (size_t i = 0; i < opNode.outTensorIds.size(); ++i) {
            totalTensorDescs.at(opNode.outTensorIds.at(i)) = opOutTensorDescs.at(i);
        }
    }

    outTensorDescs.reserve(opGraph_.outTensorNum);
    outTensorDescs.resize(opGraph_.outTensorNum);
    for (size_t i = 0; i < opGraph_.outTensorNum; ++i) {
        // Graph outputs occupy the ids immediately after the graph inputs.
        outTensorDescs.at(i) = totalTensorDescs.at(graphInNum + i);
    }
    return NO_ERROR;
}
/**
 * @brief Fill fullTensorPtrs with the address of every runner-graph tensor, ordered by graph tensor id.
 *
 * The inputs come first, then the outputs, then the internal tensors. fullTensorPtrs must already hold at
 * least that many slots; at() throws otherwise.
 */
void GraphOperation::BuildFullTensorPtrs(std::vector<Tensor *> &fullTensorPtrs, GraphRunner::Graph &runnerGraph) const
{
    size_t next = 0;
    // Append the address of each element of `tensors`, continuing from `next`.
    auto appendAll = [&fullTensorPtrs, &next](auto &tensors) {
        for (size_t i = 0; i < tensors.size(); ++i) {
            fullTensorPtrs.at(next) = &tensors.at(i);
            ++next;
        }
    };
    appendAll(runnerGraph.inTensors);
    appendAll(runnerGraph.outTensors);
    appendAll(runnerGraph.internalTensors);
}
SVector<bool> GraphOperation::GetEmptyInTensorPermissions() const
{
return emptyInTensorPerms_;
}
void GraphOperation::InitEmptyInTensorPerms()
{
emptyInTensorPerms_.reserve(opGraph_.inTensorNum);
emptyInTensorPerms_.resize(opGraph_.inTensorNum);
for (size_t i = 0; i < emptyInTensorPerms_.size(); ++i) {
emptyInTensorPerms_.at(i) = false;
}
for (size_t nodeId = 0; nodeId < opGraph_.nodes.size(); ++nodeId) {
auto &node = opGraph_.nodes.at(nodeId);
if (!node.operation) {
ATB_LOG(WARN) << GetLogPrefix() << "node[" << nodeId << "] operation is null";
continue;
}
OperationBase *opBase = dynamic_cast<OperationBase *>(node.operation);
if (!opBase) {
ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] operation is not OperationBase";
continue;
}
SVector<bool> childOpEmptyInTensorPerms = opBase->GetEmptyInTensorPermissions();
if (childOpEmptyInTensorPerms.size() != node.inTensorIds.size()) {
ATB_LOG(WARN) << GetLogPrefix() << "node[" << nodeId
<< "] childOpEmptyInTensorPerms.size:" << childOpEmptyInTensorPerms.size()
<< " != inTensorIds.size:" << node.inTensorIds.size();
continue;
}
for (size_t i = 0; i < childOpEmptyInTensorPerms.size(); ++i) {
uint32_t inTensorId = node.inTensorIds.at(i);
if (childOpEmptyInTensorPerms.at(i) && inTensorId < emptyInTensorPerms_.size()) {
emptyInTensorPerms_.at(inTensorId) = true;
ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] " << node.operation->GetName()
<< " inTensor[" << i << "] is allow empty";
}
}
}
ATB_LOG(INFO) << GetLogPrefix() << "emptyInTensorPerms:" << emptyInTensorPerms_;
}
SVector<bool> GraphOperation::GetEmptyOutTensorPermissions() const
{
return emptyOutTensorPerms_;
}
void GraphOperation::InitEmptyOutTensorPerms()
{
emptyOutTensorPerms_.reserve(opGraph_.outTensorNum);
emptyOutTensorPerms_.resize(opGraph_.outTensorNum);
for (size_t i = 0; i < emptyOutTensorPerms_.size(); ++i) {
emptyOutTensorPerms_.at(i) = false;
}
for (size_t nodeId = 0; nodeId < opGraph_.nodes.size(); ++nodeId) {
auto &node = opGraph_.nodes.at(nodeId);
if (!node.operation) {
ATB_LOG(WARN) << GetLogPrefix() << "node[" << nodeId << "] operation is null";
continue;
}
OperationBase *opBase = dynamic_cast<OperationBase *>(node.operation);
if (!opBase) {
ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] operation is not OperationBase";
continue;
}
SVector<bool> childOpEmptyOutTensorPerms = opBase->GetEmptyOutTensorPermissions();
if (childOpEmptyOutTensorPerms.size() != node.outTensorIds.size()) {
ATB_LOG(WARN) << GetLogPrefix() << "node[" << nodeId
<< "] childOpEmptyOutTensorPerms.size:" << childOpEmptyOutTensorPerms.size()
<< " != outTensorIds.size:" << node.outTensorIds.size();
continue;
}
for (size_t i = 0; i < childOpEmptyOutTensorPerms.size(); ++i) {
uint32_t outTensorId = node.outTensorIds.at(i);
if (childOpEmptyOutTensorPerms.at(i) && outTensorId >= opGraph_.inTensorNum &&
outTensorId < opGraph_.inTensorNum + opGraph_.outTensorNum) {
emptyOutTensorPerms_.at(outTensorId - opGraph_.inTensorNum) = true;
ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] " << node.operation->GetName()
<< " outTensor[" << i << "] is allow empty";
}
}
}
ATB_LOG(INFO) << GetLogPrefix() << "emptyOutTensorPerms:" << emptyOutTensorPerms_;
}
/**
 * @brief Describe this graph as JSON.
 *
 * The output contains "internalTensorNum", an empty "param", and one "nodes" entry per node. An
 * OperationBase node contributes its own GetGraphInfo(); any other node gets {"name": ...}. Each entry is
 * then extended with "inTensorIds" and "outTensorIds".
 *
 * @param graphJson JSON object to fill
 */
void GraphOperation::GetGraphInfoImpl(nlohmann::json &graphJson) const
{
    graphJson["internalTensorNum"] = opGraph_.internalTensorNum;
    graphJson["param"] = "";
    for (size_t nodeIdx = 0; nodeIdx < opGraph_.nodes.size(); nodeIdx++) {
        const auto &node = opGraph_.nodes.at(nodeIdx);
        nlohmann::json graphNodeJson;
        OperationBase *childBase = dynamic_cast<OperationBase *>(node.operation);
        if (childBase != nullptr) {
            graphNodeJson = childBase->GetGraphInfo();
        } else {
            graphNodeJson["name"] = node.operation ? node.operation->GetName() : "Unknown Operation";
        }
        ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeIdx << "] json string:" << graphNodeJson.dump();
        graphNodeJson["inTensorIds"] = SVectorToVector(node.inTensorIds);
        graphNodeJson["outTensorIds"] = SVectorToVector(node.outTensorIds);
        graphJson["nodes"].emplace_back(graphNodeJson);
    }
}
/**
 * @brief Record the graph's stream id and propagate it to child nodes whose stream id is still 0.
 *
 * Children with a non-zero stream id keep it. Propagation stops, with an error log, at the first node
 * that is not an OperationBase.
 *
 * @param streamId stream id to use
 */
void GraphOperation::SetExecuteStreamId(uint32_t streamId)
{
    streamId_ = streamId;
    for (size_t nodeIdx = 0; nodeIdx < opGraph_.nodes.size(); nodeIdx++) {
        auto *childBase = dynamic_cast<OperationBase *>(opGraph_.nodes.at(nodeIdx).operation);
        if (childBase == nullptr) {
            ATB_LOG(ERROR) << GetLogPrefix() << "Graph node internal error! set stream id ERROR!";
            return;
        }
        if (childBase->GetExecuteStreamId() != 0) {
            continue;
        }
        ATB_LOG(INFO) << GetLogPrefix() << "Change node[" << nodeIdx << "] stream id to " << streamId;
        childBase->SetExecuteStreamId(streamId);
    }
}
}