* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <vector>
#include "utils.h"
#include "acl/acl.h"
#include "model_utils.h"
#include "aclnnop/aclnn_add.h"
using namespace std;
namespace {
// Device id handed to aclrtSetDevice / aclrtCreateContext / aclrtResetDeviceForce in main().
constexpr int32_t kDeviceId = 0;

// Number of float elements per tensor; matches the 4 x 2 extents listed in kShape.
constexpr int64_t kElementCount = 8;

// Shape shared by the two source tensors and the destination tensor.
const std::vector<int64_t> kShape{4, 2};

// Host-side contents uploaded to the two source tensors.
const std::vector<float> kSrcAHost{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
const std::vector<float> kSrcBHost{0.5f, 1.0f, 1.5f, 2.0f, 2.5f, 3.0f, 3.5f, 4.0f};

// Size in bytes of one tensor's payload; used as both source and destination size in the memcpy calls.
constexpr int64_t kSize = static_cast<int64_t>(kElementCount * sizeof(float));
} // namespace
// Captures a parent model that contains one IF conditional task with two sub-models, then executes the
// parent model twice: once with the condition word set to 1 and once with it set to 0. After each run the
// destination tensor is copied back to the host and printed.
//
// Sub-model 0 records the aclnnAdd prepared with scalarTrue (2.0); sub-model 1 records the aclnnAdd
// prepared with scalarFalse (0.5). Both adds read srcA/srcB and write dst.
//
// Returns 0 when every step succeeds.
// NOTE(review): CHECK_ERROR comes from utils.h and presumably logs and returns early on a non-success
// status; on such an early return the resources created so far are not released — confirm this is
// acceptable for the sample.
int32_t TestCondIf()
{
    INFO_LOG("========== IF condition ==========");

    // `stream` records the parent model, `subStream` records each branch sub-model.
    aclrtStream stream = nullptr;
    aclrtStream subStream = nullptr;
    CHECK_ERROR(aclrtCreateStream(&stream));
    CHECK_ERROR(aclrtCreateStream(&subStream));

    // Device buffers and tensor descriptors for the two inputs and the output.
    void* srcADev = nullptr;
    void* srcBDev = nullptr;
    void* dstDev = nullptr;
    aclTensor* srcA = nullptr;
    aclTensor* srcB = nullptr;
    aclTensor* dst = nullptr;
    // NOTE(review): CreateAclTensor is declared in model_utils.h (not visible here); its result is not
    // checked because its return type cannot be confirmed from this file.
    ModelUtils::CreateAclTensor(kShape, &srcADev, aclDataType::ACL_FLOAT, &srcA);
    ModelUtils::CreateAclTensor(kShape, &srcBDev, aclDataType::ACL_FLOAT, &srcB);
    ModelUtils::CreateAclTensor(kShape, &dstDev, aclDataType::ACL_FLOAT, &dst);
    CHECK_ERROR(aclrtMemcpy(srcADev, kSize, kSrcAHost.data(), kSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ERROR(aclrtMemcpy(srcBDev, kSize, kSrcBHost.data(), kSize, ACL_MEMCPY_HOST_TO_DEVICE));

    // One scalar per branch so the two branches produce distinguishable outputs.
    float alphaTrue = 2.0f;
    float alphaFalse = 0.5f;
    aclScalar* scalarTrue = aclCreateScalar(&alphaTrue, aclDataType::ACL_FLOAT);
    aclScalar* scalarFalse = aclCreateScalar(&alphaFalse, aclDataType::ACL_FLOAT);

    // Prepare one executor (and workspace) per branch. The returned status is checked so that a failed
    // preparation is reported here instead of surfacing later during capture.
    uint64_t addWsSize0 = 0;
    uint64_t addWsSize1 = 0;
    aclOpExecutor* addExecutor0 = nullptr;
    aclOpExecutor* addExecutor1 = nullptr;
    CHECK_ERROR(aclnnAddGetWorkspaceSize(srcA, srcB, scalarTrue, dst, &addWsSize0, &addExecutor0));
    CHECK_ERROR(aclnnAddGetWorkspaceSize(srcA, srcB, scalarFalse, dst, &addWsSize1, &addExecutor1));

    // Workspace is only allocated when the operator asks for a non-zero size.
    void* wsAddr0 = nullptr;
    if (addWsSize0 > 0) {
        CHECK_ERROR(aclrtMalloc(&wsAddr0, addWsSize0, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    void* wsAddr1 = nullptr;
    if (addWsSize1 > 0) {
        CHECK_ERROR(aclrtMalloc(&wsAddr1, addWsSize1, ACL_MEM_MALLOC_HUGE_FIRST));
    }

    // Start capturing the parent model and fetch its handle so a condition handle can be attached to it.
    CHECK_ERROR(aclmdlRICaptureBegin(stream, ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
    aclmdlRICaptureStatus status = ACL_MODEL_RI_CAPTURE_STATUS_NONE;
    aclmdlRI parentModelRI = nullptr;
    CHECK_ERROR(aclmdlRICaptureGetInfo(stream, &status, &parentModelRI));

    aclmdlRICondHandle condHandle = nullptr;
    CHECK_ERROR(aclmdlRICondHandleCreate(parentModelRI, 0, ACL_MODEL_RI_COND_HANDLE_ASSIGN_DEFAULT, &condHandle));
    // Device address of the condition word; written from the host before each execution below.
    uint64_t* condDevPtr = nullptr;
    CHECK_ERROR(aclmdlRICondHandleGetCondPtr(condHandle, &condDevPtr));

    // Value-initialize the parameter struct so any field not assigned below is zero rather than
    // indeterminate.
    aclmdlRI subModels[2] = {};
    aclmdlRICondTaskParams params = {};
    params.handle = condHandle;
    params.type = ACL_MODEL_RI_COND_TYPE_IF;
    params.size = 2;
    params.modelRIArray = subModels;
    // NOTE(review): subModels is presumably filled in by aclmdlRIAddCondTask — confirm against the API docs.
    CHECK_ERROR(aclmdlRIAddCondTask(params, stream, 0));

    // Record sub-model 0: add prepared with scalarTrue.
    CHECK_ERROR(aclmdlRICaptureToModelRIBegin(subStream, subModels[0], ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
    CHECK_ERROR(aclnnAdd(wsAddr0, addWsSize0, addExecutor0, subStream));
    CHECK_ERROR(aclmdlRICaptureEnd(subStream, &subModels[0]));

    // Record sub-model 1: add prepared with scalarFalse.
    CHECK_ERROR(aclmdlRICaptureToModelRIBegin(subStream, subModels[1], ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
    CHECK_ERROR(aclnnAdd(wsAddr1, addWsSize1, addExecutor1, subStream));
    CHECK_ERROR(aclmdlRICaptureEnd(subStream, &subModels[1]));

    CHECK_ERROR(aclmdlRICaptureEnd(stream, &parentModelRI));

    vector<float> outHost(kElementCount, 0.0f);
    aclrtStream execStream = nullptr;
    CHECK_ERROR(aclrtCreateStream(&execStream));

    // Run with the condition word set to 1.
    static const uint64_t kCondTrue = 1;
    CHECK_ERROR(aclrtMemcpy(condDevPtr, sizeof(uint64_t), &kCondTrue, sizeof(uint64_t), ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ERROR(aclmdlRIExecuteAsync(parentModelRI, execStream));
    CHECK_ERROR(aclrtSynchronizeStream(execStream));
    CHECK_ERROR(aclrtMemcpy(outHost.data(), kSize, dstDev, kSize, ACL_MEMCPY_DEVICE_TO_HOST));
    INFO_LOG("IF true branch result:");
    ModelUtils::PrintArray(outHost);

    // Run again with the condition word set to 0.
    static const uint64_t kCondFalse = 0;
    CHECK_ERROR(aclrtMemcpy(condDevPtr, sizeof(uint64_t), &kCondFalse, sizeof(uint64_t), ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ERROR(aclmdlRIExecuteAsync(parentModelRI, execStream));
    CHECK_ERROR(aclrtSynchronizeStream(execStream));
    CHECK_ERROR(aclrtMemcpy(outHost.data(), kSize, dstDev, kSize, ACL_MEMCPY_DEVICE_TO_HOST));
    INFO_LOG("IF false branch result:");
    ModelUtils::PrintArray(outHost);

    // Tear down: model first, then streams, descriptors, and device memory.
    CHECK_ERROR(aclmdlRIDestroy(parentModelRI));
    CHECK_ERROR(aclrtDestroyStream(subStream));
    CHECK_ERROR(aclrtDestroyStream(stream));
    CHECK_ERROR(aclrtDestroyStream(execStream));
    CHECK_ERROR(aclDestroyTensor(srcA));
    CHECK_ERROR(aclDestroyTensor(srcB));
    CHECK_ERROR(aclDestroyTensor(dst));
    CHECK_ERROR(aclDestroyScalar(scalarTrue));
    CHECK_ERROR(aclDestroyScalar(scalarFalse));
    CHECK_ERROR(aclrtFree(srcADev));
    CHECK_ERROR(aclrtFree(srcBDev));
    CHECK_ERROR(aclrtFree(dstDev));
    if (wsAddr0 != nullptr) {
        CHECK_ERROR(aclrtFree(wsAddr0));
    }
    if (wsAddr1 != nullptr) {
        CHECK_ERROR(aclrtFree(wsAddr1));
    }

    INFO_LOG("IF condition PASSED");
    return 0;
}
// Captures a parent model that contains one WHILE conditional task with a single loop-body sub-model,
// executes it once with the condition word preset to 1, and prints the destination tensor.
//
// The loop body records an aclnnAdd (scalar 1.0) followed by a device-to-device copy that overwrites the
// condition word with 0, which is why the log line calls this a single iteration.
// NOTE(review): the exact repeat semantics of ACL_MODEL_RI_COND_TYPE_WHILE are defined by the runtime —
// presumably the body repeats while the condition word is non-zero; confirm against the API docs.
//
// Returns 0 when every step succeeds.
// NOTE(review): CHECK_ERROR comes from utils.h and presumably logs and returns early on a non-success
// status; on such an early return the resources created so far are not released.
int32_t TestCondWhile()
{
    INFO_LOG("========== WHILE condition ==========");

    // `stream` records the parent model, `subStream` records the loop body.
    aclrtStream stream = nullptr;
    aclrtStream subStream = nullptr;
    CHECK_ERROR(aclrtCreateStream(&stream));
    CHECK_ERROR(aclrtCreateStream(&subStream));

    // Device buffers and tensor descriptors for the two inputs and the output.
    void* srcADev = nullptr;
    void* srcBDev = nullptr;
    void* dstDev = nullptr;
    aclTensor* srcA = nullptr;
    aclTensor* srcB = nullptr;
    aclTensor* dst = nullptr;
    // NOTE(review): CreateAclTensor is declared in model_utils.h (not visible here); its result is not
    // checked because its return type cannot be confirmed from this file.
    ModelUtils::CreateAclTensor(kShape, &srcADev, aclDataType::ACL_FLOAT, &srcA);
    ModelUtils::CreateAclTensor(kShape, &srcBDev, aclDataType::ACL_FLOAT, &srcB);
    ModelUtils::CreateAclTensor(kShape, &dstDev, aclDataType::ACL_FLOAT, &dst);
    CHECK_ERROR(aclrtMemcpy(srcADev, kSize, kSrcAHost.data(), kSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ERROR(aclrtMemcpy(srcBDev, kSize, kSrcBHost.data(), kSize, ACL_MEMCPY_HOST_TO_DEVICE));

    float alpha = 1.0f;
    aclScalar* scalarAlpha = aclCreateScalar(&alpha, aclDataType::ACL_FLOAT);

    // Prepare the add executor; the returned status is checked so a failed preparation is reported here.
    uint64_t addWsSize = 0;
    aclOpExecutor* addExecutor = nullptr;
    CHECK_ERROR(aclnnAddGetWorkspaceSize(srcA, srcB, scalarAlpha, dst, &addWsSize, &addExecutor));
    void* wsAddr = nullptr;
    if (addWsSize > 0) {
        CHECK_ERROR(aclrtMalloc(&wsAddr, addWsSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }

    // Device-resident zero that the loop body copies into the condition word.
    static const uint64_t kCondZero = 0;
    void* condZeroDev = nullptr;
    CHECK_ERROR(aclrtMalloc(&condZeroDev, sizeof(uint64_t), ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ERROR(aclrtMemcpy(condZeroDev, sizeof(uint64_t), &kCondZero, sizeof(uint64_t), ACL_MEMCPY_HOST_TO_DEVICE));

    // Start capturing the parent model and fetch its handle so a condition handle can be attached to it.
    CHECK_ERROR(aclmdlRICaptureBegin(stream, ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
    aclmdlRICaptureStatus status = ACL_MODEL_RI_CAPTURE_STATUS_NONE;
    aclmdlRI parentModelRI = nullptr;
    CHECK_ERROR(aclmdlRICaptureGetInfo(stream, &status, &parentModelRI));

    aclmdlRICondHandle condHandle = nullptr;
    CHECK_ERROR(aclmdlRICondHandleCreate(parentModelRI, 0, ACL_MODEL_RI_COND_HANDLE_ASSIGN_DEFAULT, &condHandle));
    // Device address of the condition word.
    uint64_t* condDevPtr = nullptr;
    CHECK_ERROR(aclmdlRICondHandleGetCondPtr(condHandle, &condDevPtr));

    // Value-initialize the parameter struct so any field not assigned below is zero rather than
    // indeterminate.
    aclmdlRI loopBodyModel[1] = {};
    aclmdlRICondTaskParams params = {};
    params.handle = condHandle;
    params.type = ACL_MODEL_RI_COND_TYPE_WHILE;
    params.size = 1;
    params.modelRIArray = loopBodyModel;
    // NOTE(review): loopBodyModel is presumably filled in by aclmdlRIAddCondTask — confirm against the API docs.
    CHECK_ERROR(aclmdlRIAddCondTask(params, stream, 0));

    // Record the loop body: the add, then clear the condition word.
    CHECK_ERROR(aclmdlRICaptureToModelRIBegin(subStream, loopBodyModel[0], ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
    CHECK_ERROR(aclnnAdd(wsAddr, addWsSize, addExecutor, subStream));
    CHECK_ERROR(aclrtMemcpyAsync(
        condDevPtr, sizeof(uint64_t), condZeroDev, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_DEVICE, subStream));
    CHECK_ERROR(aclmdlRICaptureEnd(subStream, &loopBodyModel[0]));

    CHECK_ERROR(aclmdlRICaptureEnd(stream, &parentModelRI));

    vector<float> outHost(kElementCount, 0.0f);
    aclrtStream execStream = nullptr;
    CHECK_ERROR(aclrtCreateStream(&execStream));

    // Preset the condition word to 1, run the model, and read the result back.
    static const uint64_t kCondOne = 1;
    CHECK_ERROR(aclrtMemcpy(condDevPtr, sizeof(uint64_t), &kCondOne, sizeof(uint64_t), ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ERROR(aclmdlRIExecuteAsync(parentModelRI, execStream));
    CHECK_ERROR(aclrtSynchronizeStream(execStream));
    CHECK_ERROR(aclrtMemcpy(outHost.data(), kSize, dstDev, kSize, ACL_MEMCPY_DEVICE_TO_HOST));
    INFO_LOG("WHILE single iteration result:");
    ModelUtils::PrintArray(outHost);

    // Tear down: model first, then streams, descriptors, and device memory.
    CHECK_ERROR(aclmdlRIDestroy(parentModelRI));
    CHECK_ERROR(aclrtDestroyStream(subStream));
    CHECK_ERROR(aclrtDestroyStream(stream));
    CHECK_ERROR(aclrtDestroyStream(execStream));
    CHECK_ERROR(aclDestroyTensor(srcA));
    CHECK_ERROR(aclDestroyTensor(srcB));
    CHECK_ERROR(aclDestroyTensor(dst));
    CHECK_ERROR(aclDestroyScalar(scalarAlpha));
    CHECK_ERROR(aclrtFree(srcADev));
    CHECK_ERROR(aclrtFree(srcBDev));
    CHECK_ERROR(aclrtFree(dstDev));
    CHECK_ERROR(aclrtFree(condZeroDev));
    if (wsAddr != nullptr) {
        CHECK_ERROR(aclrtFree(wsAddr));
    }

    INFO_LOG("WHILE condition PASSED");
    return 0;
}
int32_t TestCondSwitch()
{
INFO_LOG("========== SWITCH condition ==========");
aclrtStream stream = nullptr;
aclrtStream subStream = nullptr;
CHECK_ERROR(aclrtCreateStream(&stream));
CHECK_ERROR(aclrtCreateStream(&subStream));
void* srcADev = nullptr;
void* srcBDev = nullptr;
void* dstDev = nullptr;
aclTensor* srcA = nullptr;
aclTensor* srcB = nullptr;
aclTensor* dst = nullptr;
ModelUtils::CreateAclTensor(kShape, &srcADev, aclDataType::ACL_FLOAT, &srcA);
ModelUtils::CreateAclTensor(kShape, &srcBDev, aclDataType::ACL_FLOAT, &srcB);
ModelUtils::CreateAclTensor(kShape, &dstDev, aclDataType::ACL_FLOAT, &dst);
CHECK_ERROR(aclrtMemcpy(srcADev, kSize, kSrcAHost.data(), kSize, ACL_MEMCPY_HOST_TO_DEVICE));
CHECK_ERROR(aclrtMemcpy(srcBDev, kSize, kSrcBHost.data(), kSize, ACL_MEMCPY_HOST_TO_DEVICE));
float alphaCase[3] = {1.0f, 2.0f, 0.5f};
aclScalar* scalarCase0 = aclCreateScalar(&alphaCase[0], aclDataType::ACL_FLOAT);
aclScalar* scalarCase1 = aclCreateScalar(&alphaCase[1], aclDataType::ACL_FLOAT);
aclScalar* scalarCase2 = aclCreateScalar(&alphaCase[2], aclDataType::ACL_FLOAT);
uint64_t wsSizes[3] = {};
aclOpExecutor* executors[3] = {};
void* wsAddrs[3] = {};
aclnnAddGetWorkspaceSize(srcA, srcB, scalarCase0, dst, &wsSizes[0], &executors[0]);
aclnnAddGetWorkspaceSize(srcA, srcB, scalarCase1, dst, &wsSizes[1], &executors[1]);
aclnnAddGetWorkspaceSize(srcA, srcB, scalarCase2, dst, &wsSizes[2], &executors[2]);
for (int i = 0; i < 3; ++i) {
if (wsSizes[i] > 0) {
CHECK_ERROR(aclrtMalloc(&wsAddrs[i], wsSizes[i], ACL_MEM_MALLOC_HUGE_FIRST));
}
}
CHECK_ERROR(aclmdlRICaptureBegin(stream, ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
aclmdlRICaptureStatus status = ACL_MODEL_RI_CAPTURE_STATUS_NONE;
aclmdlRI parentModelRI = nullptr;
CHECK_ERROR(aclmdlRICaptureGetInfo(stream, &status, &parentModelRI));
aclmdlRICondHandle condHandle = nullptr;
CHECK_ERROR(aclmdlRICondHandleCreate(parentModelRI, 0, ACL_MODEL_RI_COND_HANDLE_ASSIGN_DEFAULT, &condHandle));
uint64_t* condDevPtr = nullptr;
CHECK_ERROR(aclmdlRICondHandleGetCondPtr(condHandle, &condDevPtr));
aclmdlRI caseModels[3] = {};
aclmdlRICondTaskParams params;
params.handle = condHandle;
params.type = ACL_MODEL_RI_COND_TYPE_SWITCH;
params.size = 3;
params.modelRIArray = caseModels;
CHECK_ERROR(aclmdlRIAddCondTask(params, stream, 0));
for (int i = 0; i < 3; ++i) {
CHECK_ERROR(aclmdlRICaptureToModelRIBegin(subStream, caseModels[i], ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL));
aclnnAdd(wsAddrs[i], wsSizes[i], executors[i], subStream);
CHECK_ERROR(aclmdlRICaptureEnd(subStream, &caseModels[i]));
}
CHECK_ERROR(aclmdlRICaptureEnd(stream, &parentModelRI));
vector<float> outHost(kElementCount, 0.0f);
aclrtStream execStream = nullptr;
CHECK_ERROR(aclrtCreateStream(&execStream));
static const uint64_t kCondValues[3] = {0, 1, 2};
for (int i = 0; i < 3; ++i) {
CHECK_ERROR(
aclrtMemcpy(condDevPtr, sizeof(uint64_t), &kCondValues[i], sizeof(uint64_t), ACL_MEMCPY_HOST_TO_DEVICE));
CHECK_ERROR(aclmdlRIExecuteAsync(parentModelRI, execStream));
CHECK_ERROR(aclrtSynchronizeStream(execStream));
CHECK_ERROR(aclrtMemcpy(outHost.data(), kSize, dstDev, kSize, ACL_MEMCPY_DEVICE_TO_HOST));
INFO_LOG("SWITCH case %d result:", i);
ModelUtils::PrintArray(outHost);
}
CHECK_ERROR(aclmdlRIDestroy(parentModelRI));
CHECK_ERROR(aclrtDestroyStream(subStream));
CHECK_ERROR(aclrtDestroyStream(stream));
CHECK_ERROR(aclrtDestroyStream(execStream));
CHECK_ERROR(aclDestroyTensor(srcA));
CHECK_ERROR(aclDestroyTensor(srcB));
CHECK_ERROR(aclDestroyTensor(dst));
CHECK_ERROR(aclDestroyScalar(scalarCase0));
CHECK_ERROR(aclDestroyScalar(scalarCase1));
CHECK_ERROR(aclDestroyScalar(scalarCase2));
CHECK_ERROR(aclrtFree(srcADev));
CHECK_ERROR(aclrtFree(srcBDev));
CHECK_ERROR(aclrtFree(dstDev));
for (int i = 0; i < 3; ++i) {
if (wsAddrs[i] != nullptr) {
CHECK_ERROR(aclrtFree(wsAddrs[i]));
}
}
INFO_LOG("SWITCH condition PASSED");
return 0;
}
// Entry point: initializes the runtime, selects device kDeviceId, creates a context, runs the IF, WHILE
// and SWITCH scenarios, then tears the runtime down.
//
// Returns 0 only when all three scenarios return 0; returns 1 when any scenario reports a failure.
// NOTE(review): CHECK_ERROR comes from utils.h and presumably returns early on a non-success status, in
// which case the remaining teardown calls are skipped.
int32_t main()
{
    CHECK_ERROR(aclInit(nullptr));
    CHECK_ERROR(aclrtSetDevice(kDeviceId));
    aclrtContext context = nullptr;
    CHECK_ERROR(aclrtCreateContext(&context, kDeviceId));

    // Every scenario is run regardless of earlier outcomes; the results are kept so the summary line and
    // the process exit code reflect what actually happened instead of always claiming success.
    const int32_t ifRet = TestCondIf();
    const int32_t whileRet = TestCondWhile();
    const int32_t switchRet = TestCondSwitch();
    const bool allPassed = (ifRet == 0) && (whileRet == 0) && (switchRet == 0);
    if (allPassed) {
        INFO_LOG("========== All tests PASSED ==========");
    } else {
        INFO_LOG("========== Some tests FAILED (if=%d, while=%d, switch=%d) ==========", ifRet, whileRet, switchRet);
    }

    CHECK_ERROR(aclrtDestroyContext(context));
    CHECK_ERROR(aclrtResetDeviceForce(kDeviceId));
    CHECK_ERROR(aclFinalize());
    return allPassed ? 0 : 1;
}