/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <acl/acl.h>
#include <iostream>
#include <vector>
#include "atb/atb_infer.h"
/**
 * Fills every descriptor in the list with the fixed input layout used by this demo:
 * ACL_FLOAT16 data type, ACL_FORMAT_ND format and a two-dimensional 2 x 2 shape.
 *
 * @param intensorDescs descriptors to fill in place; the current size of the list decides
 *                      how many descriptors are written.
 */
static void CreateInTensorDescs(atb::SVector<atb::TensorDesc> &intensorDescs)
{
    const size_t descCount = intensorDescs.size();
    for (size_t idx = 0; idx < descCount; ++idx) {
        atb::TensorDesc &desc = intensorDescs.at(idx);
        desc.dtype = ACL_FLOAT16;
        desc.format = ACL_FORMAT_ND;
        desc.shape.dimNum = 2;
        desc.shape.dims[0] = 2;
        desc.shape.dims[1] = 2;
    }
}
/**
 * Allocates device memory for each input tensor and fills it with zeros.
 *
 * For every index i the tensor takes the descriptor intensorDescs.at(i), its byte size is
 * computed with atb::Utils::GetTensorSize, a device buffer of that size is allocated and a
 * zeroed host buffer of the same size is copied into it.
 *
 * @param inTensors     tensors to prepare; desc, dataSize and deviceData are overwritten.
 * @param intensorDescs descriptors, read at the same indices as the tensors.
 * @return 0 when every tensor was prepared (an empty list also yields 0); otherwise the code of
 *         the first failing ACL call. Buffers allocated before the failure stay attached to
 *         their tensors and are not released here.
 */
static aclError CreateInTensors(atb::SVector<atb::Tensor> &inTensors, atb::SVector<atb::TensorDesc> &intensorDescs)
{
    // Reused host-side staging buffer; re-sized per tensor so the copy always matches dataSize.
    std::vector<char> zeroData;
    for (size_t i = 0; i < inTensors.size(); i++) {
        atb::Tensor &tensor = inTensors.at(i);
        tensor.desc = intensorDescs.at(i);
        tensor.dataSize = atb::Utils::GetTensorSize(tensor);

        aclError ret = aclrtMalloc(&tensor.deviceData, tensor.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
        if (ret != 0) {
            std::cout << "alloc error!";
            return ret;
        }

        zeroData.assign(static_cast<size_t>(tensor.dataSize), '\0');
        ret = aclrtMemcpy(tensor.deviceData, tensor.dataSize, zeroData.data(), zeroData.size(),
                          ACL_MEMCPY_HOST_TO_DEVICE);
        if (ret != 0) {
            // Stop at the first failure so a later successful copy cannot mask this error code.
            std::cout << "memcpy error!";
            return ret;
        }
    }
    return 0;
}
/**
 * Allocates device memory for each output tensor.
 *
 * For every index i the tensor takes the descriptor outtensorDescs.at(i), its byte size is
 * computed with atb::Utils::GetTensorSize and a device buffer of that size is allocated.
 * The buffer contents are not initialised.
 *
 * @param outTensors     tensors to prepare; desc, dataSize and deviceData are overwritten.
 * @param outtensorDescs descriptors, read at the same indices as the tensors.
 * @return 0 when every allocation succeeded (an empty list also yields 0); otherwise the code of
 *         the first failing allocation. Buffers allocated before the failure stay attached to
 *         their tensors and are not released here.
 */
static aclError CreateOutTensors(atb::SVector<atb::Tensor> &outTensors, atb::SVector<atb::TensorDesc> &outtensorDescs)
{
    for (size_t i = 0; i < outTensors.size(); i++) {
        atb::Tensor &tensor = outTensors.at(i);
        tensor.desc = outtensorDescs.at(i);
        tensor.dataSize = atb::Utils::GetTensorSize(tensor);

        const aclError ret = aclrtMalloc(&tensor.deviceData, tensor.dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
        if (ret != 0) {
            // Stop at the first failure so a later successful allocation cannot mask this error code.
            std::cout << "alloc error!";
            return ret;
        }
    }
    return 0;
}
/**
 * Builds a graph operation made of three chained element-wise ADD nodes.
 *
 * The graph declares 2 inputs, 1 output and 2 internal tensors. Using the tensor ids assigned
 * below, the nodes are wired as:
 *   node 0: ids {0, 1} -> {3}
 *   node 1: ids {3, 1} -> {4}
 *   node 2: ids {4, 1} -> {2}
 * NOTE(review): ids 0-1 look like the graph inputs, 2 the output and 3-4 the internal tensors;
 * confirm against the ATB graph id convention.
 *
 * A failing creation step is reported on stdout; the function still proceeds so the caller
 * sees the same sequence of calls as before.
 *
 * @param opGraph   graph description to fill in; tensor counts and the node list are overwritten.
 * @param operation receives the graph operation produced by atb::CreateOperation.
 */
static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
{
    opGraph.inTensorNum = 2;
    opGraph.outTensorNum = 1;
    opGraph.internalTensorNum = 2;
    opGraph.nodes.resize(3);

    // Prints a diagnostic when a creation call returns a non-zero status.
    auto reportFailure = [](int status, const char *what) {
        if (status != 0) {
            std::cout << "create " << what << " failed, status: " << status << "\n";
        }
    };

    // All three nodes use the same ADD configuration, so one parameter object serves them all.
    atb::infer::ElewiseParam addParam;
    addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;

    atb::Node &addNode = opGraph.nodes.at(0);
    reportFailure(atb::CreateOperation(addParam, &addNode.operation), "mini graph add node 0");
    addNode.inTensorIds = {0, 1};
    addNode.outTensorIds = {3};

    atb::Node &addNode2 = opGraph.nodes.at(1);
    reportFailure(atb::CreateOperation(addParam, &addNode2.operation), "mini graph add node 1");
    addNode2.inTensorIds = {3, 1};
    addNode2.outTensorIds = {4};

    atb::Node &addNode3 = opGraph.nodes.at(2);
    reportFailure(atb::CreateOperation(addParam, &addNode3.operation), "mini graph add node 2");
    addNode3.inTensorIds = {4, 1};
    addNode3.outTensorIds = {2};

    reportFailure(atb::CreateOperation(opGraph, operation), "mini graph operation");
}
/**
 * Builds a five-node graph operation that waits on an event early and records it at the end.
 *
 * Node list, in order:
 *   0: element-wise MUL, ids {0, 1} -> {3}
 *   1: event node of type WAIT on `event` (no tensor ids assigned)
 *   2: element-wise ADD, ids {0, 1} -> {4}
 *   3: nested graph from CreateMiniGraphOperation, ids {3, 4} -> {2}
 *   4: event node of type RECORD on `event` (no tensor ids assigned)
 *
 * A failing creation step is reported on stdout; the function still proceeds so the caller
 * sees the same sequence of calls as before.
 *
 * @param opGraph   graph description to fill in; tensor counts and the node list are overwritten.
 * @param operation receives the graph operation produced by atb::CreateOperation.
 * @param event     event handle stored in both event nodes.
 */
static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
{
    opGraph.inTensorNum = 2;
    opGraph.outTensorNum = 1;
    opGraph.internalTensorNum = 2;
    opGraph.nodes.resize(5);

    atb::Node &mulNode = opGraph.nodes.at(0);
    atb::Node &waitNode = opGraph.nodes.at(1);
    atb::Node &addNode = opGraph.nodes.at(2);
    atb::Node &graphNode = opGraph.nodes.at(3);
    atb::Node &recordNode = opGraph.nodes.at(4);

    // Prints a diagnostic when a creation call returns a non-zero status.
    auto reportFailure = [](int status, const char *what) {
        if (status != 0) {
            std::cout << "create " << what << " failed, status: " << status << "\n";
        }
    };

    atb::infer::ElewiseParam mulParam;
    mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
    reportFailure(atb::CreateOperation(mulParam, &mulNode.operation), "WR graph mul node");
    mulNode.inTensorIds = {0, 1};
    mulNode.outTensorIds = {3};

    atb::common::EventParam waitParam;
    waitParam.event = event;
    waitParam.operatorType = atb::common::EventParam::OperatorType::WAIT;
    reportFailure(atb::CreateOperation(waitParam, &waitNode.operation), "WR graph wait node");

    // The nested graph keeps no reference to graphParam after creation is requested here;
    // its own failures are reported inside CreateMiniGraphOperation.
    atb::GraphParam graphParam;
    CreateMiniGraphOperation(graphParam, &graphNode.operation);
    graphNode.inTensorIds = {3, 4};
    graphNode.outTensorIds = {2};

    atb::infer::ElewiseParam addParam;
    addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
    reportFailure(atb::CreateOperation(addParam, &addNode.operation), "WR graph add node");
    addNode.inTensorIds = {0, 1};
    addNode.outTensorIds = {4};

    atb::common::EventParam recordParam;
    recordParam.event = event;
    recordParam.operatorType = atb::common::EventParam::OperatorType::RECORD;
    reportFailure(atb::CreateOperation(recordParam, &recordNode.operation), "WR graph record node");

    reportFailure(atb::CreateOperation(opGraph, operation), "WR graph operation");
}
/**
 * Builds a five-node graph operation that records an event early and waits on it at the end.
 *
 * Node list, in order:
 *   0: element-wise MUL, ids {0, 1} -> {3}
 *   1: event node of type RECORD on `event` (no tensor ids assigned)
 *   2: element-wise ADD, ids {0, 1} -> {4}
 *   3: nested graph from CreateMiniGraphOperation, ids {3, 4} -> {2}
 *   4: event node of type WAIT on `event` (no tensor ids assigned)
 *
 * A failing creation step is reported on stdout; the function still proceeds so the caller
 * sees the same sequence of calls as before.
 *
 * @param opGraph   graph description to fill in; tensor counts and the node list are overwritten.
 * @param operation receives the graph operation produced by atb::CreateOperation.
 * @param event     event handle stored in both event nodes.
 */
static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
{
    opGraph.inTensorNum = 2;
    opGraph.outTensorNum = 1;
    opGraph.internalTensorNum = 2;
    opGraph.nodes.resize(5);

    atb::Node &mulNode = opGraph.nodes.at(0);
    atb::Node &recordNode = opGraph.nodes.at(1);
    atb::Node &addNode = opGraph.nodes.at(2);
    atb::Node &graphNode = opGraph.nodes.at(3);
    atb::Node &waitNode = opGraph.nodes.at(4);

    // Prints a diagnostic when a creation call returns a non-zero status.
    auto reportFailure = [](int status, const char *what) {
        if (status != 0) {
            std::cout << "create " << what << " failed, status: " << status << "\n";
        }
    };

    atb::infer::ElewiseParam mulParam;
    mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
    reportFailure(atb::CreateOperation(mulParam, &mulNode.operation), "RW graph mul node");
    mulNode.inTensorIds = {0, 1};
    mulNode.outTensorIds = {3};

    atb::common::EventParam recordParam;
    recordParam.event = event;
    recordParam.operatorType = atb::common::EventParam::OperatorType::RECORD;
    reportFailure(atb::CreateOperation(recordParam, &recordNode.operation), "RW graph record node");

    // Failures while building the nested graph are reported inside CreateMiniGraphOperation.
    atb::GraphParam graphParam;
    CreateMiniGraphOperation(graphParam, &graphNode.operation);
    graphNode.inTensorIds = {3, 4};
    graphNode.outTensorIds = {2};

    atb::infer::ElewiseParam addParam;
    addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
    reportFailure(atb::CreateOperation(addParam, &addNode.operation), "RW graph add node");
    addNode.inTensorIds = {0, 1};
    addNode.outTensorIds = {4};

    atb::common::EventParam waitParam;
    waitParam.event = event;
    waitParam.operatorType = atb::common::EventParam::OperatorType::WAIT;
    reportFailure(atb::CreateOperation(waitParam, &waitNode.operation), "RW graph wait node");

    reportFailure(atb::CreateOperation(opGraph, operation), "RW graph operation");
}
int main()
{
aclInit(nullptr);
uint32_t deviceId = 1;
aclrtSetDevice(deviceId);
aclrtStream stream1 = nullptr;
aclrtCreateStream(&stream1);
aclrtStream stream2 = nullptr;
aclrtCreateStream(&stream2);
aclrtEvent event;
aclrtCreateEventWithFlag(&event, ACL_EVENT_SYNC);
atb::Context *contextWR = nullptr;
atb::CreateContext(&contextWR);
contextWR->SetExecuteStream(stream1);
atb::Context *contextRW = nullptr;
atb::CreateContext(&contextRW);
contextRW->SetExecuteStream(stream2);
atb::Operation *operationWR = nullptr;
atb::GraphParam opGraphWR;
CreateGraphOperationWithWREvent(opGraphWR, &operationWR, event);
atb::Operation *operationRW = nullptr;
atb::GraphParam opGraphRW;
CreateGraphOperationWithRWEvent(opGraphRW, &operationRW, event);
atb::VariantPack packWR;
atb::VariantPack packRW;
atb::SVector<atb::TensorDesc> intensorDescs;
atb::SVector<atb::TensorDesc> outtensorDescs;
uint32_t inTensorNum = opGraphWR.inTensorNum;
uint32_t outTensorNum = opGraphWR.outTensorNum;
inTensorNum = operationWR->GetInputNum();
outTensorNum = operationWR->GetOutputNum();
packWR.inTensors.resize(inTensorNum);
packRW.inTensors.resize(inTensorNum);
intensorDescs.resize(inTensorNum);
CreateInTensorDescs(intensorDescs);
outtensorDescs.resize(outTensorNum);
packWR.outTensors.resize(outTensorNum);
packRW.outTensors.resize(outTensorNum);
operationWR->InferShape(intensorDescs, outtensorDescs);
aclError ret = CreateInTensors(packWR.inTensors, intensorDescs);
if (ret != 0) {
exit(ret);
}
ret = CreateOutTensors(packWR.outTensors, outtensorDescs);
if (ret != 0) {
exit(ret);
}
ret = CreateInTensors(packRW.inTensors, intensorDescs);
if (ret != 0) {
exit(ret);
}
ret = CreateOutTensors(packRW.outTensors, outtensorDescs);
if (ret != 0) {
exit(ret);
}
uint64_t workspaceSizeWR = 0;
void *workSpaceWR = nullptr;
uint64_t workspaceSizeRW = 0;
void *workSpaceRW = nullptr;
std::cout << "multi graph multi-stream demo start" << std::endl;
operationWR->Setup(packWR, workspaceSizeWR, contextWR);
operationRW->Setup(packRW, workspaceSizeRW, contextRW);
if (workspaceSizeWR != 0 && workSpaceWR == nullptr) {
ret = aclrtMalloc(&workSpaceWR, workspaceSizeWR, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != 0) {
std::cout << "alloc error!\n";
exit(1);
}
}
if (workspaceSizeRW != 0 && workSpaceRW == nullptr) {
ret = aclrtMalloc(&workSpaceRW, workspaceSizeRW, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != 0) {
std::cout << "alloc error!\n";
exit(1);
}
}
operationWR->Execute(packWR, (uint8_t *)workSpaceWR, workspaceSizeWR, contextWR);
operationRW->Execute(packRW, (uint8_t *)workSpaceRW, workspaceSizeRW, contextRW);
ret = aclrtSynchronizeStream(stream1);
if (ret != 0) {
std::cout << "sync error!";
exit(1);
}
ret = aclrtSynchronizeStream(stream2);
if (ret != 0) {
std::cout << "sync error!";
exit(1);
}
atb::DestroyOperation(operationWR);
atb::DestroyContext(contextWR);
for (size_t i = 0; i < packWR.inTensors.size(); i++) {
aclrtFree(packWR.inTensors.at(i).deviceData);
}
for (size_t i = 0; i < packWR.outTensors.size(); i++) {
aclrtFree(packWR.outTensors.at(i).deviceData);
}
aclrtFree(workSpaceWR);
atb::DestroyOperation(operationRW);
atb::DestroyContext(contextRW);
for (size_t i = 0; i < packRW.inTensors.size(); i++) {
aclrtFree(packRW.inTensors.at(i).deviceData);
}
for (size_t i = 0; i < packRW.outTensors.size(); i++) {
aclrtFree(packRW.outTensors.at(i).deviceData);
}
aclrtFree(workSpaceRW);
aclrtDestroyEvent(event);
aclrtDestroyStream(stream1);
aclrtDestroyStream(stream2);
aclrtResetDevice(deviceId);
aclFinalize();
}