/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*
* This sample demonstrates kernel loading and execution.
* When the mode is 'simple', all kernel inputs are regular pointer-type arguments,
* requiring users to manually copy input data to the device.
* When the mode is 'placeholder', the kernel inputs include placeholder-type arguments;
* for these arguments, the device address does not need to be specified when adding arguments,
* and input arguments are automatically transferred to the device during kernel launch.
*/
#include <cstddef>
#include <cstdint>
#include <string>
#include "acl/acl.h"
#include "utils.h"
#include "kernel_utils.h"
namespace {
/**
 * Raw buffer pointers used by the sample, one host/device pair per tensor.
 * x and y are the kernel inputs, z is the kernel output.
 * Every pointer starts out null, which is how the cleanup code recognises
 * buffers that were never allocated.
 */
struct KernelBuffers {
    // Host-side buffers (allocated with aclrtMallocHost).
    uint8_t *xHost{nullptr};
    uint8_t *yHost{nullptr};
    uint8_t *zHost{nullptr};
    // Device-side buffers (allocated with aclrtMalloc).
    uint8_t *xDevice{nullptr};
    uint8_t *yDevice{nullptr};
    uint8_t *zDevice{nullptr};
};
/**
 * Runtime-level handles plus one flag per setup step.
 * A flag is raised only after the corresponding setup call has succeeded, so
 * teardown can release exactly what was acquired.
 */
struct RuntimeResources {
    int32_t deviceId{0};                // device passed to aclrtSetDevice
    aclrtStream stream{nullptr};        // stream used for the kernel launch
    aclrtBinHandle binHandle{nullptr};  // handle of the loaded kernel binary
    bool aclInitialized{false};         // aclInit succeeded
    bool deviceSet{false};              // aclrtSetDevice succeeded
    bool streamCreated{false};          // aclrtCreateStream succeeded
    bool binLoaded{false};              // aclrtBinaryLoadFromFile succeeded
};
/**
 * Records a failed cleanup call without aborting the remaining cleanup.
 *
 * @param apiName      Text identifying the call, used only in the log message.
 * @param ret          Status returned by the call.
 * @param finalResult  Set to -1 when ret is not ACL_SUCCESS; left untouched otherwise.
 */
void UpdateFinalResultOnError(const char *apiName, aclError ret, int32_t &finalResult)
{
    const bool failed = (ret != ACL_SUCCESS);
    if (failed) {
        ERROR_LOG("Operation failed: %s returned error code %d", apiName, static_cast<int32_t>(ret));
        finalResult = -1;
    }
}
/**
 * Brings up the ACL runtime in three steps: aclInit, aclrtSetDevice(runtime->deviceId),
 * aclrtCreateStream.
 *
 * The matching flag in `runtime` is set right after each step passes its CHECK_ERROR;
 * ReleaseKernelResources reads those flags to decide which teardown calls to make.
 *
 * @param runtime  deviceId is read from it; stream and the three progress flags are written to it.
 * @return 0 once all three steps have completed.
 *         NOTE(review): CHECK_ERROR is defined outside this file; presumably it logs and
 *         returns a non-zero value from this function on failure — confirm in utils.h.
 */
int32_t InitializeRuntime(RuntimeResources *runtime)
{
// nullptr: no configuration file path is passed to aclInit.
CHECK_ERROR(aclInit(nullptr));
runtime->aclInitialized = true;
CHECK_ERROR(aclrtSetDevice(runtime->deviceId));
runtime->deviceSet = true;
CHECK_ERROR(aclrtCreateStream(&runtime->stream));
runtime->streamCreated = true;
return 0;
}
/**
 * Allocates the three host buffers and the three device buffers used by the sample.
 *
 * x and y buffers are inputByteSize bytes each; z buffers are outputByteSize bytes.
 * Pointers are written into `buffers` as each allocation is made, so buffers obtained
 * before a failing call remain recorded there for the caller to free.
 *
 * @param inputByteSize   Size in bytes of each input buffer (x and y).
 * @param outputByteSize  Size in bytes of the output buffer (z).
 * @param buffers         Receives the allocated host and device pointers.
 * @return 0 once all six allocations have completed.
 *         NOTE(review): CHECK_ERROR is defined outside this file; presumably it logs and
 *         returns a non-zero value from this function on failure — confirm in utils.h.
 */
int32_t AllocateKernelBuffers(size_t inputByteSize, size_t outputByteSize, KernelBuffers *buffers)
{
// Host-side buffers.
CHECK_ERROR(aclrtMallocHost(reinterpret_cast<void **>(&buffers->xHost), inputByteSize));
CHECK_ERROR(aclrtMallocHost(reinterpret_cast<void **>(&buffers->yHost), inputByteSize));
CHECK_ERROR(aclrtMallocHost(reinterpret_cast<void **>(&buffers->zHost), outputByteSize));
// Device-side buffers, requested with the ACL_MEM_MALLOC_HUGE_FIRST policy.
CHECK_ERROR(aclrtMalloc(reinterpret_cast<void **>(&buffers->xDevice), inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ERROR(aclrtMalloc(reinterpret_cast<void **>(&buffers->yDevice), inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ERROR(aclrtMalloc(reinterpret_cast<void **>(&buffers->zDevice), outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST));
return 0;
}
/**
 * Loads both input files into the host buffers and copies them to the device buffers.
 *
 * Each file must load successfully and report exactly inputByteSize bytes; otherwise an
 * error is logged and -1 is returned before any host-to-device copy is made.
 *
 * @param inputByteSize  Expected size in bytes of each input file and buffer.
 * @param buffers        Supplies xHost/yHost (read targets) and xDevice/yDevice (copy targets).
 * @return 0 on success, -1 when an input file cannot be read or has the wrong size.
 */
int32_t PrepareInputData(size_t inputByteSize, const KernelBuffers &buffers)
{
    // NOTE(review): kernel::ReadFile presumably updates the size argument with the number of
    // bytes read — confirm in kernel_utils.h. It is seeded with the expected size here.
    size_t xBytes = inputByteSize;
    if (!(kernel::ReadFile("./input/input_x.bin", xBytes, buffers.xHost, inputByteSize) &&
          xBytes == inputByteSize)) {
        ERROR_LOG("Read input_x.bin failed or file size is invalid.");
        return -1;
    }

    size_t yBytes = inputByteSize;
    if (!(kernel::ReadFile("./input/input_y.bin", yBytes, buffers.yHost, inputByteSize) &&
          yBytes == inputByteSize)) {
        ERROR_LOG("Read input_y.bin failed or file size is invalid.");
        return -1;
    }

    // Both inputs are valid on the host; push them to the device.
    CHECK_ERROR(aclrtMemcpy(buffers.xDevice, inputByteSize, buffers.xHost, inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ERROR(aclrtMemcpy(buffers.yDevice, inputByteSize, buffers.yHost, inputByteSize, ACL_MEMCPY_HOST_TO_DEVICE));
    return 0;
}
/**
 * Appends the three pointer-type kernel arguments, in the order x, y, z.
 *
 * Each call passes the address of the local pointer variable together with
 * sizeof(uintptr_t), so the argument payload is the device address itself.
 * The parameter handles returned by the API are not used afterwards.
 *
 * @param argsHandle  Argument list to append to.
 * @param xDevice     Device address of the first input.
 * @param yDevice     Device address of the second input.
 * @param zDevice     Device address of the output.
 * @return 0 once all three arguments have been appended.
 *         NOTE(review): CHECK_ERROR is defined outside this file; presumably it logs and
 *         returns a non-zero value from this function on failure — confirm in utils.h.
 */
int32_t AppendCommonKernelArgs(aclrtArgsHandle argsHandle, uint8_t *xDevice, uint8_t *yDevice, uint8_t *zDevice)
{
    aclrtParamHandle paramHandle1 = nullptr;
    aclrtParamHandle paramHandle2 = nullptr;
    aclrtParamHandle paramHandle3 = nullptr;
    CHECK_ERROR(aclrtKernelArgsAppend(argsHandle, reinterpret_cast<void **>(&xDevice), sizeof(uintptr_t), &paramHandle1));
    CHECK_ERROR(aclrtKernelArgsAppend(argsHandle, reinterpret_cast<void **>(&yDevice), sizeof(uintptr_t), &paramHandle2));
    CHECK_ERROR(aclrtKernelArgsAppend(argsHandle, reinterpret_cast<void **>(&zDevice), sizeof(uintptr_t), &paramHandle3));
    return 0;
}
/**
 * Appends two placeholder-type arguments and fills in their values.
 *
 * For each placeholder, a buffer of sizeof(int32_t) bytes is obtained from the argument
 * list and the value is written into it: TOTAL_LENGTH for the first, TILE_NUM for the
 * second. No device address is supplied for these arguments.
 *
 * @param argsHandle  Argument list to append to.
 * @return 0 once both placeholders have been appended and filled.
 *         NOTE(review): CHECK_ERROR is defined outside this file; presumably it logs and
 *         returns a non-zero value from this function on failure — confirm in utils.h.
 */
int32_t ConfigurePlaceholderArgs(aclrtArgsHandle argsHandle)
{
    constexpr int32_t TOTAL_LENGTH = 8 * 2048;
    constexpr int32_t TILE_NUM = 8;
    int32_t *lengthHost = nullptr;
    int32_t *numHost = nullptr;
    aclrtParamHandle paramHandle4 = nullptr;
    aclrtParamHandle paramHandle5 = nullptr;
    // Reserve both argument slots first, then request their buffers.
    CHECK_ERROR(aclrtKernelArgsAppendPlaceHolder(argsHandle, &paramHandle4));
    CHECK_ERROR(aclrtKernelArgsAppendPlaceHolder(argsHandle, &paramHandle5));
    CHECK_ERROR(aclrtKernelArgsGetPlaceHolderBuffer(
        argsHandle, paramHandle4, sizeof(TOTAL_LENGTH), reinterpret_cast<void **>(&lengthHost)));
    CHECK_ERROR(aclrtKernelArgsGetPlaceHolderBuffer(
        argsHandle, paramHandle5, sizeof(TILE_NUM), reinterpret_cast<void **>(&numHost)));
    *lengthHost = TOTAL_LENGTH;
    *numHost = TILE_NUM;
    return 0;
}
/**
 * Loads the kernel binary for the requested mode, looks up "add_custom" and builds its
 * argument list.
 *
 * The three device pointers are always appended; the two placeholder arguments are added
 * only when mode is "placeholder". Any other mode value selects the "simple" binary.
 *
 * @param mode        Run mode; "placeholder" selects the placeholder kernel binary.
 * @param xDevice     Device address of the first input.
 * @param yDevice     Device address of the second input.
 * @param zDevice     Device address of the output.
 * @param runtime     Receives binHandle; binLoaded is set once the binary is loaded.
 * @param funcHandle  Receives the handle of the "add_custom" function.
 * @param argsHandle  Receives the finalized argument list.
 * @return 0 on success, -1 when appending arguments fails.
 */
int32_t BuildKernelArgs(
    const std::string &mode,
    uint8_t *xDevice,
    uint8_t *yDevice,
    uint8_t *zDevice,
    RuntimeResources *runtime,
    aclrtFuncHandle *funcHandle,
    aclrtArgsHandle *argsHandle)
{
    // Default to the simple kernel binary and switch only for placeholder mode.
    const char *filePath = "./out/fatbin/ascendc_kernels_simple/ascendc_kernels_simple.o";
    const bool usePlaceholder = (mode == "placeholder");
    if (usePlaceholder) {
        filePath = "./out/fatbin/ascendc_kernels_placeholder/ascendc_kernels_placeholder.o";
    }

    CHECK_ERROR(aclrtBinaryLoadFromFile(filePath, nullptr, &runtime->binHandle));
    runtime->binLoaded = true;
    CHECK_ERROR(aclrtBinaryGetFunction(runtime->binHandle, "add_custom", funcHandle));
    CHECK_ERROR(aclrtKernelArgsInit(*funcHandle, argsHandle));

    const int32_t commonStatus = AppendCommonKernelArgs(*argsHandle, xDevice, yDevice, zDevice);
    if (commonStatus != 0) {
        return -1;
    }
    if (usePlaceholder) {
        if (ConfigurePlaceholderArgs(*argsHandle) != 0) {
            return -1;
        }
    }

    CHECK_ERROR(aclrtKernelArgsFinalize(*argsHandle));
    return 0;
}
/**
 * Launches the kernel, waits for the stream, copies the result back to the host and
 * writes it to ./output/output_z.bin.
 *
 * @param funcHandle      Kernel function to launch.
 * @param argsHandle      Finalized argument list for the kernel.
 * @param blockDim        Block dimension passed to the launch call.
 * @param stream          Stream the kernel is launched on and then synchronized.
 * @param zDevice         Device buffer holding the kernel output.
 * @param zHost           Host buffer that receives the output.
 * @param outputByteSize  Number of bytes copied back and written to the file.
 * @return 0 on success, -1 when the output file cannot be written.
 */
int32_t LaunchKernelAndWriteOutput(
    aclrtFuncHandle funcHandle,
    aclrtArgsHandle argsHandle,
    uint32_t blockDim,
    aclrtStream stream,
    uint8_t *zDevice,
    uint8_t *zHost,
    size_t outputByteSize)
{
    CHECK_ERROR(aclrtLaunchKernelWithConfig(funcHandle, blockDim, stream, nullptr, argsHandle, nullptr));
    // The result is only read back after the stream has been synchronized.
    CHECK_ERROR(aclrtSynchronizeStream(stream));
    CHECK_ERROR(aclrtMemcpy(zHost, outputByteSize, zDevice, outputByteSize, ACL_MEMCPY_DEVICE_TO_HOST));

    if (kernel::WriteFile("./output/output_z.bin", zHost, outputByteSize)) {
        return 0;
    }
    ERROR_LOG("Write output_z.bin failed.");
    return -1;
}
/**
 * Releases everything recorded in `runtime` and `buffers`, in the reverse of the order in
 * which the sample acquires it: kernel binary, device buffers, host buffers, stream,
 * device, ACL.
 *
 * Only resources whose flag is set (or whose pointer is non-null) are released. A failing
 * release call is logged and sets finalResult to -1, but the remaining releases still run.
 * After each release attempt the corresponding pointer, handle and flag are cleared, so no
 * dangling pointer is left behind and calling this function again is a no-op.
 *
 * @param runtime      Runtime handles and progress flags; cleared as resources are released.
 * @param buffers      Host and device buffers; each pointer is reset to nullptr once freed.
 * @param finalResult  Set to -1 if any release call fails; otherwise left unchanged.
 */
void ReleaseKernelResources(RuntimeResources &runtime, KernelBuffers &buffers, int32_t &finalResult)
{
    if (runtime.binLoaded) {
        UpdateFinalResultOnError(
            "aclrtBinaryUnLoad(runtime.binHandle)", aclrtBinaryUnLoad(runtime.binHandle), finalResult);
        runtime.binHandle = nullptr;
        runtime.binLoaded = false;
    }

    // Device buffers.
    if (buffers.zDevice != nullptr) {
        UpdateFinalResultOnError("aclrtFree(buffers.zDevice)", aclrtFree(buffers.zDevice), finalResult);
        buffers.zDevice = nullptr;
    }
    if (buffers.yDevice != nullptr) {
        UpdateFinalResultOnError("aclrtFree(buffers.yDevice)", aclrtFree(buffers.yDevice), finalResult);
        buffers.yDevice = nullptr;
    }
    if (buffers.xDevice != nullptr) {
        UpdateFinalResultOnError("aclrtFree(buffers.xDevice)", aclrtFree(buffers.xDevice), finalResult);
        buffers.xDevice = nullptr;
    }

    // Host buffers.
    if (buffers.zHost != nullptr) {
        UpdateFinalResultOnError("aclrtFreeHost(buffers.zHost)", aclrtFreeHost(buffers.zHost), finalResult);
        buffers.zHost = nullptr;
    }
    if (buffers.yHost != nullptr) {
        UpdateFinalResultOnError("aclrtFreeHost(buffers.yHost)", aclrtFreeHost(buffers.yHost), finalResult);
        buffers.yHost = nullptr;
    }
    if (buffers.xHost != nullptr) {
        UpdateFinalResultOnError("aclrtFreeHost(buffers.xHost)", aclrtFreeHost(buffers.xHost), finalResult);
        buffers.xHost = nullptr;
    }

    // Runtime objects, innermost first.
    if (runtime.streamCreated) {
        UpdateFinalResultOnError(
            "aclrtDestroyStreamForce(runtime.stream)", aclrtDestroyStreamForce(runtime.stream), finalResult);
        runtime.stream = nullptr;
        runtime.streamCreated = false;
    }
    if (runtime.deviceSet) {
        UpdateFinalResultOnError(
            "aclrtResetDeviceForce(runtime.deviceId)", aclrtResetDeviceForce(runtime.deviceId), finalResult);
        runtime.deviceSet = false;
    }
    if (runtime.aclInitialized) {
        UpdateFinalResultOnError("aclFinalize()", aclFinalize(), finalResult);
        runtime.aclInitialized = false;
    }
}
/**
 * Runs the whole sample for the given mode: runtime setup, buffer allocation, input
 * loading, argument construction, kernel launch and output writing, followed by cleanup.
 *
 * The steps run in order and stop at the first one that reports a non-zero status.
 * Cleanup always runs afterwards, and a cleanup failure also turns the result into -1.
 *
 * @param mode  Run mode forwarded to BuildKernelArgs and echoed in the log.
 * @return 0 when every step and the cleanup succeed, -1 otherwise.
 */
int32_t RunKernelLaunchSample(const std::string &mode)
{
    const uint32_t blockDim = 8;
    const size_t inputByteSize = 8 * 2048 * sizeof(uint16_t);
    const size_t outputByteSize = 8 * 2048 * sizeof(uint16_t);

    RuntimeResources runtime;
    runtime.deviceId = 0;
    KernelBuffers buffers;
    aclrtFuncHandle funcHandle = nullptr;
    aclrtArgsHandle argsHandle = nullptr;

    // Short-circuit evaluation: a later step (and its arguments) is only evaluated
    // after every earlier step has returned 0.
    const bool stepsSucceeded =
        InitializeRuntime(&runtime) == 0 &&
        AllocateKernelBuffers(inputByteSize, outputByteSize, &buffers) == 0 &&
        PrepareInputData(inputByteSize, buffers) == 0 &&
        BuildKernelArgs(
            mode, buffers.xDevice, buffers.yDevice, buffers.zDevice, &runtime, &funcHandle, &argsHandle) == 0 &&
        LaunchKernelAndWriteOutput(
            funcHandle, argsHandle, blockDim, runtime.stream, buffers.zDevice, buffers.zHost, outputByteSize) == 0;
    if (stepsSucceeded) {
        INFO_LOG("Kernel launch sample runs in %s mode.", mode.c_str());
    }

    // Cleanup runs regardless of the outcome above and may downgrade the result.
    int32_t finalResult = stepsSucceeded ? 0 : -1;
    ReleaseKernelResources(runtime, buffers, finalResult);
    if (finalResult == 0) {
        INFO_LOG("Run the launch_kernel sample successfully.");
    }
    return finalResult;
}
}
/**
 * Entry point. The optional first command-line argument selects the run mode
 * ("simple" or "placeholder"); "simple" is used when no argument is given.
 *
 * @return The result of RunKernelLaunchSample, or -1 for an unrecognised mode.
 */
int32_t main(int32_t argc, char *argv[])
{
    std::string mode("simple");
    if (argc > 1) {
        mode = argv[1];
    }

    const bool supported = (mode == "simple") || (mode == "placeholder");
    if (!supported) {
        ERROR_LOG("Invalid run mode: %s. Mode must be simple or placeholder.", mode.c_str());
        return -1;
    }
    return RunKernelLaunchSample(mode);
}