/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*
 * NOTE: Portions of this code were AI-generated and have been
 * technically reviewed for functional accuracy and security.
 */
/**
 * @file test_aclnn_log_space.cpp
 * @brief Example of the two-phase aclnn invocation of the LogSpace operator.
 *
 * Formula:
 *   out[i] = base ^ (start + i * (end - start) / (steps - 1)),  i in [0, steps)
 *   steps == 1: out[0] = base ^ start
 *   steps == 0: empty tensor, no kernel is launched
 *
 * This example runs three cases in sequence, covering the FLOAT / FLOAT16 / BFLOAT16 output dtypes:
 *   Case1: dtype=FLOAT    start=0.0  end=2.0 steps=5 base=10.0 -> [1, 10, 100, 1000, 10000]
 *   Case2: dtype=FLOAT16  start=-1.0 end=1.0 steps=5 base=2.0  -> [0.5, ~0.707, 1.0, ~1.414, 2.0]
 *   Case3: dtype=BFLOAT16 start=0.0  end=3.0 steps=4 base=10.0 -> [1, 10, 100, 1000]
 */
#include <cstdint>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <string>
#include "acl/acl.h"
#include "aclnn_log_space.h"
// CHECK_RET(cond, msg, ...): early-exit helper for status checks.
//
// When `cond` evaluates to false, the macro prints "[ERROR] " followed by the
// printf-style message `msg` (formatted with the optional trailing arguments)
// and a newline to stderr, then executes `return 1;` in the *calling* function.
// It can therefore only be used inside functions whose return type accepts 1.
//
// `msg` must be a string literal: it is joined with "[ERROR] " and "\n" by
// adjacent-literal concatenation. `##__VA_ARGS__` is the GNU extension that
// drops the preceding comma when no variadic arguments are supplied.
// The do { ... } while (0) wrapper makes the expansion behave as one statement.
#define CHECK_RET(cond, msg, ...) \
do { \
if (!(cond)) { \
fprintf(stderr, "[ERROR] " msg "\n", ##__VA_ARGS__); \
return 1; \
} \
} while (0)
/**
 * @brief Converts an IEEE-754 binary16 (half precision) bit pattern to float.
 *
 * The 16-bit input is split into sign (1 bit), exponent (5 bits, bias 15) and
 * mantissa (10 bits) and re-encoded as a binary32 bit pattern (exponent bias
 * 127, 23-bit mantissa). Every half value is exactly representable as float,
 * so the conversion is lossless.
 *
 * @param h Raw half-precision bit pattern.
 * @return The float with the same numeric value (signed zero, subnormals,
 *         infinities and NaNs are all handled).
 */
static float Fp16ToFloat(uint16_t h) {
    const uint32_t sign = (static_cast<uint32_t>(h) >> 15) & 0x1u;
    const uint32_t exp = (static_cast<uint32_t>(h) >> 10) & 0x1Fu;
    uint32_t mant = static_cast<uint32_t>(h) & 0x3FFu;

    uint32_t f = sign << 31;
    if (exp == 0) {
        if (mant != 0) {
            // Subnormal half: value = (mant / 1024) * 2^-14. Shift the mantissa
            // left until its leading 1 reaches the implicit-bit position (bit 10),
            // lowering the unbiased exponent by one for every shift.
            int unbiased = -14;
            while ((mant & 0x400u) == 0) {
                mant <<= 1;
                --unbiased;
            }
            mant &= 0x3FFu;  // drop the now-implicit leading 1
            const uint32_t fexp = static_cast<uint32_t>(127 + unbiased);
            f |= (fexp << 23) | (mant << 13);
        }
        // mant == 0: signed zero, only the sign bit is set.
    } else if (exp == 0x1Fu) {
        // Infinity (mant == 0) or NaN (mant != 0): all-ones float exponent,
        // mantissa payload carried over into the top mantissa bits.
        f |= (0xFFu << 23) | (mant << 13);
    } else {
        // Normal number: re-bias the exponent from 15 to 127.
        const uint32_t fexp = exp + 127u - 15u;
        f |= (fexp << 23) | (mant << 13);
    }

    float out;
    std::memcpy(&out, &f, sizeof(out));  // bit-level reinterpretation without aliasing UB
    return out;
}
/**
 * @brief Converts a bfloat16 bit pattern to float.
 *
 * The 16 input bits are placed in the upper half of a 32-bit word (the lower
 * 16 bits are zero) and that word is reinterpreted as a float.
 *
 * @param b Raw bfloat16 bit pattern.
 * @return The float whose bit pattern is `b` followed by sixteen zero bits.
 */
static float Bf16ToFloat(uint16_t b) {
    const uint32_t widened = uint32_t{b} << 16;
    float value = 0.0f;
    static_assert(sizeof(value) == sizeof(widened), "float is expected to be 32 bits wide");
    std::memcpy(&value, &widened, sizeof(value));
    return value;
}
// Parameters of one LogSpace example run.
// Instances are brace-initialised positionally, so the member order below is
// part of the interface and must not be changed.
struct ExampleCase {
// Label printed in the banner for this case.
std::string name;
// Element type of the output tensor; also selects how results are decoded for printing.
aclDataType dtype;
// Exponent of the first output element (converted to float before being passed to the operator).
double start;
// Exponent of the last output element (converted to float before being passed to the operator).
double end;
// Number of output elements; used as the length of the 1-D output tensor.
int64_t steps;
// Base of the power, forwarded to the operator as-is.
double base;
};
static int RunOne(const ExampleCase& c, aclrtStream stream) {
printf("\n========================================\n");
printf("[%s] dtype=%d start=%g end=%g steps=%lld base=%g\n",
c.name.c_str(), (int)c.dtype, c.start, c.end, (long long)c.steps, c.base);
printf("========================================\n");
float start_f = static_cast<float>(c.start);
float end_f = static_cast<float>(c.end);
aclScalar* startScalar = aclCreateScalar(&start_f, ACL_FLOAT);
aclScalar* endScalar = aclCreateScalar(&end_f, ACL_FLOAT);
CHECK_RET(startScalar && endScalar, "aclCreateScalar failed");
size_t elem_size = (c.dtype == ACL_FLOAT) ? sizeof(float) : sizeof(uint16_t);
int64_t shape[1] = { c.steps };
int64_t strides[1] = { 1 };
size_t buf_size = static_cast<size_t>(c.steps) * elem_size;
if (buf_size == 0) buf_size = elem_size;
void* result_dev = nullptr;
aclError ret = aclrtMalloc(&result_dev, buf_size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, "aclrtMalloc result failed: %d", ret);
aclrtMemset(result_dev, buf_size, 0, buf_size);
aclTensor* resultTensor = aclCreateTensor(
shape, 1, c.dtype, strides, 0,
aclFormat::ACL_FORMAT_ND, shape, 1, result_dev);
CHECK_RET(resultTensor != nullptr, "aclCreateTensor failed");
uint64_t workspaceSize = 0;
aclOpExecutor* executor = nullptr;
ret = aclnnLogSpaceGetWorkspaceSize(
startScalar, endScalar, c.steps, c.base, resultTensor,
&workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, "aclnnLogSpaceGetWorkspaceSize failed: %d", ret);
printf(" workspaceSize = %llu\n", (unsigned long long)workspaceSize);
void* workspace = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, "aclrtMalloc workspace failed: %d", ret);
}
ret = aclnnLogSpace(workspace, workspaceSize, executor, stream);
CHECK_RET(ret == ACL_SUCCESS, "aclnnLogSpace failed: %d", ret);
ret = aclrtSynchronizeStream(stream);
CHECK_RET(ret == ACL_SUCCESS, "aclrtSynchronizeStream failed: %d", ret);
std::vector<uint8_t> host_buf(buf_size);
ret = aclrtMemcpy(host_buf.data(), buf_size, result_dev, buf_size,
ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, "aclrtMemcpy D2H failed: %d", ret);
printf(" result = [");
for (int64_t i = 0; i < c.steps; ++i) {
float v = 0.0f;
if (c.dtype == ACL_FLOAT) {
v = reinterpret_cast<const float*>(host_buf.data())[i];
} else if (c.dtype == ACL_FLOAT16) {
v = Fp16ToFloat(reinterpret_cast<const uint16_t*>(host_buf.data())[i]);
} else {
v = Bf16ToFloat(reinterpret_cast<const uint16_t*>(host_buf.data())[i]);
}
printf("%s%g", (i == 0 ? "" : ", "), v);
}
printf("]\n");
if (workspace) aclrtFree(workspace);
aclDestroyTensor(resultTensor);
aclrtFree(result_dev);
aclDestroyScalar(startScalar);
aclDestroyScalar(endScalar);
return 0;
}
/**
 * @brief Entry point: initialises ACL, runs all example cases and tears ACL down again.
 *
 * Initialisation order is aclInit -> aclrtSetDevice -> aclrtCreateStream. If a
 * later step fails, the steps that already succeeded are undone before
 * returning, so the process does not exit with the device or the ACL runtime
 * still initialised.
 *
 * @return 0 when every case succeeded, 1 if initialisation or any case failed.
 */
int main(int /*argc*/, char** /*argv*/) {
    printf("LogSpace aclnn invoke example\n");

    const int32_t deviceId = 0;
    aclError ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, "aclInit failed: %d", ret);

    ret = aclrtSetDevice(deviceId);
    if (ret != ACL_SUCCESS) {
        fprintf(stderr, "[ERROR] aclrtSetDevice failed: %d\n", ret);
        aclFinalize();
        return 1;
    }

    aclrtStream stream = nullptr;
    ret = aclrtCreateStream(&stream);
    if (ret != ACL_SUCCESS) {
        fprintf(stderr, "[ERROR] aclrtCreateStream failed: %d\n", ret);
        aclrtResetDevice(deviceId);
        aclFinalize();
        return 1;
    }

    const std::vector<ExampleCase> cases = {
        {"Case1_FLOAT", ACL_FLOAT, 0.0, 2.0, 5, 10.0},
        {"Case2_FLOAT16", ACL_FLOAT16, -1.0, 1.0, 5, 2.0},
        {"Case3_BFLOAT16", ACL_BF16, 0.0, 3.0, 4, 10.0},
    };

    // A failing case does not stop the run; the remaining cases still execute.
    int failed = 0;
    for (const auto& c : cases) {
        if (RunOne(c, stream) != 0) {
            ++failed;
        }
    }

    // Teardown is best-effort: the exit code reflects the case results only.
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    printf("\n========================================\n");
    printf("All cases: %zu failed: %d\n", cases.size(), failed);
    printf("========================================\n");
    return failed == 0 ? 0 : 1;
}