* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <cstdint>
#include "acl/acl.h"
#include "utils.h"
namespace {
void UpdateFinalResultOnError(const char *apiName, aclError ret, int32_t &finalResult)
{
if (ret == ACL_SUCCESS) {
return;
}
ERROR_LOG("Operation failed: %s returned error code %d", apiName, static_cast<int32_t>(ret));
finalResult = -1;
}
}
int main()
{
const int32_t deviceId = 0;
const uint32_t branchIndex = 1;
const uint32_t labelCount = 2;
const uint32_t headStreamFlag = 0x00000000U;
aclrtContext context = nullptr;
aclrtStream labelStream = nullptr;
aclrtStream executeStream = nullptr;
aclrtLabel labels[labelCount] = {nullptr, nullptr};
aclrtLabelList labelList = nullptr;
uint32_t *branchIndexDevice = nullptr;
aclmdlRI modelRI = nullptr;
bool aclInitialized = false;
bool deviceSet = false;
bool contextCreated = false;
bool labelStreamCreated = false;
bool executeStreamCreated = false;
bool branchIndexAllocated = false;
bool labelCreated[labelCount] = {false, false};
bool labelListCreated = false;
bool modelRICreated = false;
bool labelStreamBound = false;
const int32_t result = [&]() -> int32_t {
CHECK_ERROR(aclInit(nullptr));
aclInitialized = true;
INFO_LOG("ACL initialized.");
CHECK_ERROR(aclrtSetDevice(deviceId));
deviceSet = true;
INFO_LOG("Device %d selected.", deviceId);
CHECK_ERROR(aclrtCreateContext(&context, deviceId));
contextCreated = true;
INFO_LOG("Context created on device %d.", deviceId);
CHECK_ERROR(aclrtCreateStreamWithConfig(&labelStream, 0x00U, ACL_STREAM_PERSISTENT));
labelStreamCreated = true;
INFO_LOG("Persistent label stream created.");
CHECK_ERROR(aclrtCreateStream(&executeStream));
executeStreamCreated = true;
INFO_LOG("Execute stream created.");
CHECK_ERROR(aclmdlRIBuildBegin(&modelRI, 0x00U));
modelRICreated = true;
INFO_LOG("Model runtime instance build started.");
CHECK_ERROR(aclmdlRIBindStream(modelRI, labelStream, headStreamFlag));
labelStreamBound = true;
INFO_LOG("Persistent label stream bound to the model runtime instance.");
CHECK_ERROR(
aclrtMalloc(reinterpret_cast<void **>(&branchIndexDevice), sizeof(branchIndex), ACL_MEM_MALLOC_HUGE_FIRST));
branchIndexAllocated = true;
INFO_LOG("Allocated device memory for branch index.");
CHECK_ERROR(aclrtMemcpy(
branchIndexDevice,
sizeof(branchIndex),
&branchIndex,
sizeof(branchIndex),
ACL_MEMCPY_HOST_TO_DEVICE));
INFO_LOG("Copied branch index %u from host to device.", branchIndex);
for (uint32_t index = 0; index < labelCount; ++index) {
CHECK_ERROR(aclrtCreateLabel(&labels[index]));
labelCreated[index] = true;
INFO_LOG("Created label %u.", index);
}
CHECK_ERROR(aclrtCreateLabelList(labels, labelCount, &labelList));
labelListCreated = true;
INFO_LOG("Created label list with %u labels.", labelCount);
CHECK_ERROR(aclrtSwitchLabelByIndex(branchIndexDevice, labelCount, labelList, labelStream));
INFO_LOG("Submitted switch-label task with branch index %u.", branchIndex);
for (uint32_t index = 0; index < labelCount; ++index) {
CHECK_ERROR(aclrtSetLabel(labels[index], labelStream));
INFO_LOG("Set label %u on the persistent stream.", index);
}
CHECK_ERROR(aclmdlRIEndTask(modelRI, labelStream));
CHECK_ERROR(aclmdlRIBuildEnd(modelRI, nullptr));
INFO_LOG("Model runtime instance build finished.");
CHECK_ERROR(aclmdlRIExecuteAsync(modelRI, executeStream));
CHECK_ERROR(aclrtSynchronizeStream(executeStream));
INFO_LOG("Switch label executed successfully with branch index %u.", branchIndex);
return 0;
}();
int32_t finalResult = result;
if (labelStreamBound) {
UpdateFinalResultOnError(
"aclmdlRIUnbindStream(modelRI, labelStream)", aclmdlRIUnbindStream(modelRI, labelStream), finalResult);
}
if (modelRICreated) {
UpdateFinalResultOnError("aclmdlRIDestroy(modelRI)", aclmdlRIDestroy(modelRI), finalResult);
}
if (labelListCreated) {
UpdateFinalResultOnError("aclrtDestroyLabelList(labelList)", aclrtDestroyLabelList(labelList), finalResult);
}
if (labelCreated[1]) {
UpdateFinalResultOnError("aclrtDestroyLabel(labels[1])", aclrtDestroyLabel(labels[1]), finalResult);
}
if (labelCreated[0]) {
UpdateFinalResultOnError("aclrtDestroyLabel(labels[0])", aclrtDestroyLabel(labels[0]), finalResult);
}
if (branchIndexAllocated) {
UpdateFinalResultOnError("aclrtFree(branchIndexDevice)", aclrtFree(branchIndexDevice), finalResult);
}
if (executeStreamCreated) {
UpdateFinalResultOnError("aclrtDestroyStream(executeStream)", aclrtDestroyStream(executeStream), finalResult);
}
if (labelStreamCreated) {
UpdateFinalResultOnError("aclrtDestroyStream(labelStream)", aclrtDestroyStream(labelStream), finalResult);
}
if (contextCreated) {
UpdateFinalResultOnError("aclrtDestroyContext(context)", aclrtDestroyContext(context), finalResult);
}
if (deviceSet) {
UpdateFinalResultOnError("aclrtResetDeviceForce(deviceId)", aclrtResetDeviceForce(deviceId), finalResult);
}
if (aclInitialized) {
UpdateFinalResultOnError("aclFinalize()", aclFinalize(), finalResult);
}
if (finalResult == 0) {
INFO_LOG("Run the simple_label sample successfully.");
}
return finalResult;
}