* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <acl/acl.h>
#include <iostream>
#include <vector>
#include "atb/atb_infer.h"
static void CreateInTensorDescs(atb::SVector<atb::TensorDesc> &intensorDescs)
{
for (size_t i = 0; i < intensorDescs.size(); i++) {
intensorDescs.at(i).dtype = ACL_FLOAT16;
intensorDescs.at(i).format = ACL_FORMAT_ND;
intensorDescs.at(i).shape.dimNum = 2;
intensorDescs.at(i).shape.dims[0] = 2;
intensorDescs.at(i).shape.dims[1] = 2;
}
}
static aclError CreateInTensors(atb::SVector<atb::Tensor> &inTensors, atb::SVector<atb::TensorDesc> &intensorDescs)
{
std::vector<char> zeroData(8, 0);
int ret;
for (size_t i = 0; i < inTensors.size(); i++) {
inTensors.at(i).desc = intensorDescs.at(i);
inTensors.at(i).dataSize = atb::Utils::GetTensorSize(inTensors.at(i));
ret = aclrtMalloc(&inTensors.at(i).deviceData, inTensors.at(i).dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != 0) {
std::cout << "alloc error!";
return ret;
}
ret = aclrtMemcpy(inTensors.at(i).deviceData, inTensors.at(i).dataSize, zeroData.data(), zeroData.size(), ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != 0) {
std::cout << "memcpy error!";
}
}
return ret;
}
static aclError CreateOutTensors(atb::SVector<atb::Tensor> &outTensors, atb::SVector<atb::TensorDesc> &outtensorDescs)
{
int ret;
for (size_t i = 0; i < outTensors.size(); i++) {
outTensors.at(i).desc = outtensorDescs.at(i);
outTensors.at(i).dataSize = atb::Utils::GetTensorSize(outTensors.at(i));
ret = aclrtMalloc(&outTensors.at(i).deviceData, outTensors.at(i).dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != 0) {
std::cout << "alloc error!";
}
}
return ret;
}
static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
{
opGraph.inTensorNum = 2;
opGraph.outTensorNum = 1;
opGraph.internalTensorNum = 2;
opGraph.nodes.resize(3);
size_t nodeId = 0;
atb::Node &addNode = opGraph.nodes.at(nodeId++);
atb::Node &addNode2 = opGraph.nodes.at(nodeId++);
atb::Node &addNode3 = opGraph.nodes.at(nodeId++);
atb::infer::ElewiseParam addParam;
addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
atb::CreateOperation(addParam, &addNode.operation);
addNode.inTensorIds = {0, 1};
addNode.outTensorIds = {3};
atb::infer::ElewiseParam addParam2;
addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
atb::CreateOperation(addParam2, &addNode2.operation);
addNode2.inTensorIds = {3, 1};
addNode2.outTensorIds = {4};
atb::infer::ElewiseParam addParam3;
addParam3.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
CreateOperation(addParam3, &addNode3.operation);
addNode3.inTensorIds = {4, 1};
addNode3.outTensorIds = {2};
atb::CreateOperation(opGraph, operation);
}
static void CreateGraphOperationForMultiStream(atb::GraphParam &opGraph, atb::Operation **operation)
{
opGraph.inTensorNum = 2;
opGraph.outTensorNum = 2;
opGraph.internalTensorNum = 2;
opGraph.nodes.resize(4);
size_t nodeId = 0;
atb::Node &mulNode = opGraph.nodes.at(nodeId++);
atb::Node &addNode2 = opGraph.nodes.at(nodeId++);
atb::Node &graphNode = opGraph.nodes.at(nodeId++);
atb::Node &mulNode1 = opGraph.nodes.at(nodeId++);
atb::infer::ElewiseParam mulParam;
mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
atb::CreateOperation(mulParam, &mulNode.operation);
mulNode.inTensorIds = {0, 1};
mulNode.outTensorIds = {3};
atb::infer::ElewiseParam addParam2;
addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
atb::CreateOperation(addParam2, &addNode2.operation);
addNode2.inTensorIds = {0, 1};
addNode2.outTensorIds = {4};
atb::GraphParam graphParam;
CreateMiniGraphOperation(graphParam, &graphNode.operation);
graphNode.inTensorIds = {4, 1};
graphNode.outTensorIds = {5};
SetExecuteStreamId(graphNode.operation, 1);
atb::CreateOperation(mulParam, &mulNode1.operation);
mulNode1.inTensorIds = {5, 1};
mulNode1.outTensorIds = {2};
SetExecuteStreamId(mulNode1.operation, 1);
atb::CreateOperation(opGraph, operation);
}
static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
{
opGraph.inTensorNum = 2;
opGraph.outTensorNum = 1;
opGraph.internalTensorNum = 2;
opGraph.nodes.resize(5);
size_t nodeId = 0;
atb::Node &mulNode = opGraph.nodes.at(nodeId++);
atb::Node &waitNode = opGraph.nodes.at(nodeId++);
atb::Node &addNode = opGraph.nodes.at(nodeId++);
atb::Node &graphNode = opGraph.nodes.at(nodeId++);
atb::Node &recordNode = opGraph.nodes.at(nodeId++);
atb::infer::ElewiseParam mulParam;
mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
atb::CreateOperation(mulParam, &mulNode.operation);
mulNode.inTensorIds = {0, 1};
mulNode.outTensorIds = {3};
atb::common::EventParam waitParam;
waitParam.event = event;
waitParam.operatorType = atb::common::EventParam::OperatorType::WAIT;
atb::CreateOperation(waitParam, &waitNode.operation);
atb::GraphParam graphParam;
CreateMiniGraphOperation(graphParam, &graphNode.operation);
graphNode.inTensorIds = {3, 4};
graphNode.outTensorIds = {2};
atb::infer::ElewiseParam addParam;
addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
atb::CreateOperation(addParam, &addNode.operation);
addNode.inTensorIds = {0, 1};
addNode.outTensorIds = {4};
SetExecuteStreamId(addNode.operation, 1);
atb::common::EventParam recordParam;
recordParam.event = event;
recordParam.operatorType = atb::common::EventParam::OperatorType::RECORD;
atb::CreateOperation(recordParam, &recordNode.operation);
SetExecuteStreamId(recordNode.operation, 1);
atb::CreateOperation(opGraph, operation);
}
static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
{
opGraph.inTensorNum = 2;
opGraph.outTensorNum = 1;
opGraph.internalTensorNum = 2;
opGraph.nodes.resize(5);
size_t nodeId = 0;
atb::Node &mulNode = opGraph.nodes.at(nodeId++);
atb::Node &waitNode = opGraph.nodes.at(nodeId++);
atb::Node &addNode = opGraph.nodes.at(nodeId++);
atb::Node &graphNode = opGraph.nodes.at(nodeId++);
atb::Node &recordNode = opGraph.nodes.at(nodeId++);
atb::infer::ElewiseParam mulParam;
mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
atb::CreateOperation(mulParam, &mulNode.operation);
mulNode.inTensorIds = {0, 1};
mulNode.outTensorIds = {3};
atb::common::EventParam waitParam;
waitParam.event = event;
waitParam.operatorType = atb::common::EventParam::OperatorType::RECORD;
atb::CreateOperation(waitParam, &waitNode.operation);
atb::GraphParam graphParam;
CreateMiniGraphOperation(graphParam, &graphNode.operation);
graphNode.inTensorIds = {3, 4};
graphNode.outTensorIds = {2};
atb::infer::ElewiseParam addParam;
addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
atb::CreateOperation(addParam, &addNode.operation);
addNode.inTensorIds = {0, 1};
addNode.outTensorIds = {4};
SetExecuteStreamId(addNode.operation, 1);
atb::common::EventParam recordParam;
recordParam.event = event;
recordParam.operatorType = atb::common::EventParam::OperatorType::WAIT;
atb::CreateOperation(recordParam, &recordNode.operation);
SetExecuteStreamId(recordNode.operation, 1);
atb::CreateOperation(opGraph, operation);
}
int main()
{
aclInit(nullptr);
uint32_t deviceId = 1;
aclrtSetDevice(deviceId);
aclrtStream stream1 = nullptr;
aclrtCreateStream(&stream1);
aclrtStream stream2 = nullptr;
aclrtCreateStream(&stream2);
std::vector<aclrtStream> streams = {stream1, stream2};
atb::Context *context = nullptr;
atb::CreateContext(&context);
context->SetExecuteStreams(streams);
atb::Operation *operation = nullptr;
atb::GraphParam opGraph;
CreateGraphOperationForMultiStream(opGraph, &operation);
atb::VariantPack pack;
atb::SVector<atb::TensorDesc> intensorDescs;
atb::SVector<atb::TensorDesc> outtensorDescs;
uint32_t inTensorNum = opGraph.inTensorNum;
uint32_t outTensorNum = opGraph.outTensorNum;
inTensorNum = operation->GetInputNum();
outTensorNum = operation->GetOutputNum();
pack.inTensors.resize(inTensorNum);
intensorDescs.resize(inTensorNum);
CreateInTensorDescs(intensorDescs);
outtensorDescs.resize(outTensorNum);
pack.outTensors.resize(outTensorNum);
operation->InferShape(intensorDescs, outtensorDescs);
aclError ret = CreateOutTensors(pack.outTensors, outtensorDescs);
if (ret != 0) {
exit(ret);
}
uint64_t workspaceSize = 0;
void *workSpace = nullptr;
std::cout << "Single graph multi-stream demo start" << std::endl;
ret = CreateInTensors(pack.inTensors, intensorDescs);
if (ret != 0) {
exit(ret);
}
operation->Setup(pack, workspaceSize, context);
if (workspaceSize != 0 && workSpace == nullptr) {
ret = aclrtMalloc(&workSpace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != 0) {
std::cout << "alloc error!\n";
exit(1);
}
}
operation->Execute(pack, (uint8_t*)workSpace, workspaceSize, context);
ret = aclrtSynchronizeStream(stream1);
if (ret != 0) {
std::cout << "sync error!";
exit(1);
}
ret = aclrtSynchronizeStream(stream2);
if (ret != 0) {
std::cout << "sync error!";
exit(1);
}
atb::DestroyOperation(operation);
atb::DestroyContext(context);
for (size_t i = 0; i < pack.inTensors.size(); i++) {
aclrtFree(pack.inTensors.at(i).deviceData);
}
for (size_t i = 0; i < pack.outTensors.size(); i++) {
aclrtFree(pack.outTensors.at(i).deviceData);
}
aclrtFree(workSpace);
aclrtDestroyStream(stream1);
aclrtDestroyStream(stream2);
aclrtResetDevice(deviceId);
aclFinalize();
}