* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_aclnn_mhc_pre.cpp
* \brief
*/
#include <iostream>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_mhc_pre.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
int64_t size = 1;
for (int64_t dim : shape) {
size *= dim;
}
return size;
}
void PrintTensorDataFloat(const std::vector<int64_t> &shape, void *device_addr)
{
int64_t size = GetShapeSize(shape);
std::vector<float> host_data(size, 0.0f);
aclError ret = aclrtMemcpy(host_data.data(), size * sizeof(float), device_addr, size * sizeof(float),
ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Memcpy device to host failed, error: %d\n", ret); return);
LOG_PRINT("Tensor data (first 10 elements): ");
for (int64_t i = 0; i < std::min((int64_t)10, size); ++i) {
LOG_PRINT("%f ", host_data[i]);
}
LOG_PRINT("\n");
}
void PrintTensorDataFloat16(const std::vector<int64_t> &shape, void *device_addr)
{
int64_t size = GetShapeSize(shape);
std::vector<aclFloat16> host_fp16(size);
std::vector<float> host_data(size, 0.0f);
aclError ret = aclrtMemcpy(host_fp16.data(), size * sizeof(aclFloat16), device_addr, size * sizeof(aclFloat16),
ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Memcpy device to host failed, error: %d\n", ret); return);
for (int64_t i = 0; i < size; ++i) {
host_data[i] = aclFloat16ToFloat(host_fp16[i]);
}
LOG_PRINT("Tensor data (first 10 elements): ");
for (int64_t i = 0; i < std::min((int64_t)10, size); ++i) {
LOG_PRINT("%f ", host_data[i]);
}
LOG_PRINT("\n");
}
int InitAcl(int32_t device_id, aclrtContext &context, aclrtStream &stream)
{
aclError ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed, error: %d\n", ret); return -1);
ret = aclrtSetDevice(device_id);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed, error: %d\n", ret); return -1);
ret = aclrtCreateContext(&context, device_id);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed, error: %d\n", ret); return -1);
ret = aclrtSetCurrentContext(context);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed, error: %d\n", ret); return -1);
ret = aclrtCreateStream(&stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed, error: %d\n", ret); return -1);
return 0;
}
int CreateAclTensorFloat32(const std::vector<float> &host_data, const std::vector<int64_t> &shape, void *&device_addr,
aclTensor *&tensor)
{
int64_t size = GetShapeSize(shape) * sizeof(float);
aclError ret = aclrtMalloc(&device_addr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed, error: %d\n", ret); return -1);
ret = aclrtMemcpy(device_addr, size, host_data.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed, error: %d\n", ret); return -1);
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
tensor = aclCreateTensor(shape.data(), shape.size(), ACL_FLOAT, strides.data(), 0, ACL_FORMAT_ND, shape.data(),
shape.size(), device_addr);
CHECK_RET(tensor != nullptr, LOG_PRINT("aclCreateTensor failed\n"); return -1);
return 0;
}
int CreateAclTensorFloat16(const std::vector<float> &host_data, const std::vector<int64_t> &shape, void *&device_addr,
aclTensor *&tensor)
{
int64_t size = GetShapeSize(shape);
std::vector<aclFloat16> host_data_fp16(size);
for (int64_t i = 0; i < size; ++i) {
host_data_fp16[i] = aclFloat16(host_data[i]);
}
int64_t byte_size = size * sizeof(aclFloat16);
aclError ret = aclrtMalloc(&device_addr, byte_size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed, error: %d\n", ret); return -1);
ret = aclrtMemcpy(device_addr, byte_size, host_data_fp16.data(), byte_size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed, error: %d\n", ret); return -1);
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
tensor = aclCreateTensor(shape.data(), shape.size(), ACL_FLOAT16, strides.data(), 0, ACL_FORMAT_ND, shape.data(),
shape.size(), device_addr);
CHECK_RET(tensor != nullptr, LOG_PRINT("aclCreateTensor failed\n"); return -1);
return 0;
}
int CreateAclTensorFloat16Output(const std::vector<int64_t> &shape, void *&device_addr, aclTensor *&tensor)
{
int64_t size = GetShapeSize(shape);
int64_t byte_size = size * sizeof(aclFloat16);
aclError ret = aclrtMalloc(&device_addr, byte_size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed, error: %d\n", ret); return -1);
tensor = aclCreateTensor(shape.data(), shape.size(), ACL_FLOAT16, nullptr, 0, ACL_FORMAT_ND, shape.data(),
shape.size(), device_addr);
CHECK_RET(tensor != nullptr, LOG_PRINT("aclCreateTensor failed\n"); return -1);
return 0;
}
int CreateAclTensorFloat32Output(const std::vector<int64_t> &shape, void *&device_addr, aclTensor *&tensor)
{
int64_t size = GetShapeSize(shape);
int64_t byte_size = size * sizeof(float);
aclError ret = aclrtMalloc(&device_addr, byte_size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed, error: %d\n", ret); return -1);
tensor = aclCreateTensor(shape.data(), shape.size(), ACL_FLOAT, nullptr, 0, ACL_FORMAT_ND, shape.data(),
shape.size(), device_addr);
CHECK_RET(tensor != nullptr, LOG_PRINT("aclCreateTensor failed\n"); return -1);
return 0;
}
struct Tensors {
void *x_addr = nullptr, *phi_addr = nullptr, *alpha_addr = nullptr, *bias_addr = nullptr, *gamma_addr = nullptr;
void *hin_addr = nullptr, *h_post_addr = nullptr, *h_res_addr = nullptr, *inv_rms_addr = nullptr;
void *h_mix_addr = nullptr, *h_pre_addr = nullptr;
aclTensor *x = nullptr, *phi = nullptr, *alpha = nullptr, *bias = nullptr, *gamma = nullptr;
aclTensor *hin = nullptr, *h_post = nullptr, *h_res = nullptr, *inv_rms = nullptr;
aclTensor *h_mix = nullptr, *h_pre = nullptr;
};
int CreateInputTensors(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &phi_shape,
const std::vector<int64_t> &alpha_shape, const std::vector<int64_t> &bias_shape,
const std::vector<int64_t> &gamma_shape, Tensors &tensors)
{
std::vector<float> x_host_data(GetShapeSize(x_shape), 1.0f);
std::vector<float> phi_host_data(GetShapeSize(phi_shape), 1.0f);
std::vector<float> alpha_host_data(3, 1.0f);
std::vector<float> bias_host_data(GetShapeSize(bias_shape), 1.0f);
std::vector<float> gamma_host_data(GetShapeSize(gamma_shape), 1.0f);
int ret = CreateAclTensorFloat16(x_host_data, x_shape, tensors.x_addr, tensors.x);
CHECK_RET(ret == 0, LOG_PRINT("Create x_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32(phi_host_data, phi_shape, tensors.phi_addr, tensors.phi);
CHECK_RET(ret == 0, LOG_PRINT("Create phi_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32(alpha_host_data, alpha_shape, tensors.alpha_addr, tensors.alpha);
CHECK_RET(ret == 0, LOG_PRINT("Create alpha_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32(bias_host_data, bias_shape, tensors.bias_addr, tensors.bias);
CHECK_RET(ret == 0, LOG_PRINT("Create bias_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32(gamma_host_data, gamma_shape, tensors.gamma_addr, tensors.gamma);
CHECK_RET(ret == 0, LOG_PRINT("Create gamma_tensor failed\n"); return -1);
return 0;
}
int CreateOutputTensors(const std::vector<int64_t> &hin_shape, const std::vector<int64_t> &h_post_shape,
const std::vector<int64_t> &h_res_shape, const std::vector<int64_t> &inv_rms_shape,
const std::vector<int64_t> &h_mix_shape, const std::vector<int64_t> &h_pre_shape,
Tensors &tensors)
{
int ret = CreateAclTensorFloat16Output(hin_shape, tensors.hin_addr, tensors.hin);
CHECK_RET(ret == 0, LOG_PRINT("Create hin_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32Output(h_post_shape, tensors.h_post_addr, tensors.h_post);
CHECK_RET(ret == 0, LOG_PRINT("Create h_post_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32Output(h_res_shape, tensors.h_res_addr, tensors.h_res);
CHECK_RET(ret == 0, LOG_PRINT("Create h_res_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32Output(inv_rms_shape, tensors.inv_rms_addr, tensors.inv_rms);
CHECK_RET(ret == 0, LOG_PRINT("Create inv_rms_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32Output(h_mix_shape, tensors.h_mix_addr, tensors.h_mix);
CHECK_RET(ret == 0, LOG_PRINT("Create h_mix_tensor failed\n"); return -1);
ret = CreateAclTensorFloat32Output(h_pre_shape, tensors.h_pre_addr, tensors.h_pre);
CHECK_RET(ret == 0, LOG_PRINT("Create h_pre_tensor failed\n"); return -1);
return 0;
}
void DestroyTensors(Tensors &tensors)
{
aclDestroyTensor(tensors.x);
aclDestroyTensor(tensors.phi);
aclDestroyTensor(tensors.alpha);
aclDestroyTensor(tensors.bias);
aclDestroyTensor(tensors.gamma);
aclDestroyTensor(tensors.hin);
aclDestroyTensor(tensors.h_post);
aclDestroyTensor(tensors.h_res);
aclDestroyTensor(tensors.inv_rms);
aclDestroyTensor(tensors.h_mix);
aclDestroyTensor(tensors.h_pre);
}
void FreeDeviceMemory(Tensors &tensors)
{
aclrtFree(tensors.x_addr);
aclrtFree(tensors.phi_addr);
aclrtFree(tensors.alpha_addr);
aclrtFree(tensors.bias_addr);
aclrtFree(tensors.gamma_addr);
aclrtFree(tensors.hin_addr);
aclrtFree(tensors.h_post_addr);
aclrtFree(tensors.h_res_addr);
aclrtFree(tensors.inv_rms_addr);
aclrtFree(tensors.h_mix_addr);
aclrtFree(tensors.h_pre_addr);
}
int main()
{
int32_t device_id = 0;
aclrtContext context = nullptr;
aclrtStream stream = nullptr;
Tensors tensors;
int B = 1, S = 2048, n = 4, D = 2560;
std::vector<int64_t> x_shape = {B * S, n, D}, phi_shape = {n * n + 2 * n, n * D}, alpha_shape = {3},
bias_shape = {n * n + 2 * n}, gamma_shape = {n, D};
std::vector<int64_t> hin_shape = {B * S, D}, h_post_shape = {B * S, n}, h_res_shape = {B * S, n, n},
inv_rms_shape = {B * S}, h_mix_shape = {B * S, n * n + 2 * n}, h_pre_shape = {B * S, n};
int ret = InitAcl(device_id, context, stream);
CHECK_RET(ret == 0, LOG_PRINT("InitAcl failed, error: %d\n", ret); return -1);
ret = CreateInputTensors(x_shape, phi_shape, alpha_shape, bias_shape, gamma_shape, tensors);
CHECK_RET(ret == 0, return -1);
ret = CreateOutputTensors(hin_shape, h_post_shape, h_res_shape, inv_rms_shape, h_mix_shape, h_pre_shape, tensors);
CHECK_RET(ret == 0, return -1);
uint64_t workspace_size = 0;
aclOpExecutor *executor = nullptr;
aclnnStatus aclnn_ret = aclnnMhcPreGetWorkspaceSize(
tensors.x, tensors.phi, tensors.alpha, tensors.bias, tensors.gamma, 1e-6, 1e-6, tensors.hin, tensors.h_post,
tensors.h_res, tensors.inv_rms, tensors.h_mix, tensors.h_pre, &workspace_size, &executor);
CHECK_RET(aclnn_ret == ACL_SUCCESS, LOG_PRINT("aclnnMhcPreGetWorkspaceSize failed, error: %d\n", aclnn_ret);
return -1);
void *workspace_addr = nullptr;
if (workspace_size > 0) {
ret = aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc workspace failed, error: %d\n", ret); return -1);
}
aclnn_ret = aclnnMhcPre(workspace_addr, workspace_size, executor, stream);
CHECK_RET(aclnn_ret == ACL_SUCCESS, LOG_PRINT("aclnnMhcPre failed, error: %d\n", aclnn_ret); return -1);
CHECK_RET(aclrtSynchronizeStream(stream) == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed\n"); return -1);
LOG_PRINT("MhcPre compute success!\nOutput tensor data: \n");
PrintTensorDataFloat16(hin_shape, tensors.hin_addr);
PrintTensorDataFloat(h_post_shape, tensors.h_post_addr);
PrintTensorDataFloat(h_res_shape, tensors.h_res_addr);
PrintTensorDataFloat(inv_rms_shape, tensors.inv_rms_addr);
PrintTensorDataFloat(h_mix_shape, tensors.h_mix_addr);
PrintTensorDataFloat(h_pre_shape, tensors.h_pre_addr);
DestroyTensors(tensors);
FreeDeviceMemory(tensors);
if (workspace_size > 0)
aclrtFree(workspace_addr);
aclrtDestroyStream(stream);
aclrtDestroyContext(context);
aclrtResetDevice(device_id);
aclFinalize();
LOG_PRINT("All resources released successfully!\n");
return 0;
}