/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file test_aclnn_flash_attn.cpp
* \brief Example of invoking the FlashAttn + FlashAttnMetadata operators
*
*/
#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
#include <random>
#include <algorithm>
#include <cstdio>
#include <string>
#include <sstream>
#include "acl/acl.h"
#include "../op_api/aclnn_flash_attn.h"
#include "../../flash_attn_metadata/op_host/op_api/aclnn_flash_attn_metadata.h"
namespace {
/**
 * Decodes an IEEE 754 binary16 (half precision) bit pattern into a float.
 *
 * Handles signed zero, subnormals, normals, infinities and NaN. Every half value is exactly
 * representable as a float, so the conversion is lossless.
 *
 * @param h  raw 16-bit half-precision pattern
 * @return   the float with the same value
 */
static float fp16_to_float(uint16_t h) {
    const uint32_t sign = (static_cast<uint32_t>(h) >> 15) & 0x1u;
    const uint32_t exp = (static_cast<uint32_t>(h) >> 10) & 0x1fu;
    uint32_t mant = static_cast<uint32_t>(h) & 0x3ffu;
    uint32_t f;
    if (exp == 0) {
        if (mant == 0) {
            // Signed zero.
            f = sign << 31;
        } else {
            // Half subnormal: value = mant * 2^-24. It is a *normal* float, so shift the
            // mantissa up until the implicit leading 1 appears and lower the exponent to match.
            // 113 is the float biased exponent of 2^-14 (127 - 14).
            uint32_t floatExp = 113u;
            while ((mant & 0x400u) == 0) {
                mant <<= 1;
                --floatExp;
            }
            f = (sign << 31) | (floatExp << 23) | ((mant & 0x3ffu) << 13);
        }
    } else if (exp == 31) {
        // Infinity (mant == 0) or NaN (mant != 0); the payload bits are carried over.
        f = (sign << 31) | 0x7f800000u | (mant << 13);
    } else {
        // Normal number: re-bias the exponent from 15 to 127.
        f = (sign << 31) | ((exp + 127u - 15u) << 23) | (mant << 13);
    }
    float result;
    std::memcpy(&result, &f, sizeof(result));
    return result;
}
// CHECK_RET(cond): yields true when `cond` holds and false otherwise. Callers in this file
// negate it (`if (!CHECK_RET(ret == ACL_SUCCESS))`) to branch on a failed call.
#define CHECK_RET(cond) ((cond) ? true : (false))
// LOG_PRINT(fmt, ...): printf-style logging to stdout. `##__VA_ARGS__` removes the trailing
// comma when no variadic arguments are supplied (a GNU/Clang preprocessor extension).
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
int Init(int32_t deviceId, aclrtStream *stream)
{
auto ret = aclInit(nullptr);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclInit failed. ERROR: %d\n", ret);
return ret;
}
ret = aclrtSetDevice(deviceId);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret);
return ret;
}
ret = aclrtCreateStream(stream);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret);
return ret;
}
return 0;
}
template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
aclDataType dataType, aclTensor **tensor)
{
auto size = GetShapeSize(shape) * sizeof(T);
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret);
return ret;
}
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret);
return ret;
}
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
return 0;
}
/**
 * Allocates a zero-filled device buffer for a contiguous ND tensor and wraps it in an aclTensor.
 *
 * @param shape       tensor dimensions
 * @param deviceAddr  receives the device buffer pointer
 * @param dataType    ACL element type recorded on the tensor
 * @param elemSize    size in bytes of one element of `dataType`
 * @param tensor      receives the created aclTensor
 * @return 0 on success; the ACL error code of a failed runtime call; -1 for a negative element
 *         count or when aclCreateTensor returns nullptr.
 *
 * Ownership: this function never frees. After a failure that follows a successful aclrtMalloc,
 * *deviceAddr is still set and must be released by the caller.
 */
static int CreateAclTensorDeviceZero(const std::vector<int64_t> &shape, void **deviceAddr,
                                     aclDataType dataType, size_t elemSize, aclTensor **tensor)
{
    const int64_t elemCount = GetShapeSize(shape);
    if (elemCount < 0) {
        // A negative product would turn into an enormous size_t below.
        LOG_PRINT("CreateAclTensorDeviceZero failed. invalid element count %lld\n",
                  static_cast<long long>(elemCount));
        return -1;
    }
    const size_t size = static_cast<size_t>(elemCount) * elemSize;

    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    if (!CHECK_RET(ret == ACL_SUCCESS)) {
        LOG_PRINT("aclrtMalloc (device zero) failed. ERROR: %d\n", ret);
        return ret;
    }
    ret = aclrtMemset(*deviceAddr, size, 0, size);
    if (!CHECK_RET(ret == ACL_SUCCESS)) {
        LOG_PRINT("aclrtMemset failed. ERROR: %d\n", ret);
        return ret;
    }

    // Contiguous row-major strides.
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0,
                              aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr);
    if (*tensor == nullptr) {
        LOG_PRINT("aclCreateTensor (device zero) failed.\n");
        return -1;
    }
    return 0;
}
/**
 * Encodes a float as an IEEE 754 binary16 (half precision) bit pattern.
 *
 * - NaN stays NaN (encoded as a quiet NaN with the input's sign) instead of collapsing to Inf.
 * - Normal values are rounded to nearest, ties to even; rounding may carry into the exponent
 *   and, at the top of the range, produce infinity.
 * - Magnitudes too large for half precision become signed infinity.
 * - Magnitudes below the smallest normal half (2^-14) are flushed to signed zero; half
 *   subnormals are deliberately never produced.
 *
 * @param f  value to encode
 * @return   raw 16-bit half-precision pattern
 */
static uint16_t float_to_fp16(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    const uint32_t sign = (bits >> 16) & 0x8000u;
    const uint32_t absBits = bits & 0x7fffffffu;

    // Anything above the infinity pattern is a NaN.
    if (absBits > 0x7f800000u) {
        return static_cast<uint16_t>(sign | 0x7e00u);
    }

    // Re-bias the exponent from 127 (float) to 15 (half).
    const int32_t exp = static_cast<int32_t>(absBits >> 23) - 127 + 15;
    if (exp <= 0) {
        return static_cast<uint16_t>(sign);  // underflow: flush to signed zero
    }
    if (exp >= 31) {
        return static_cast<uint16_t>(sign | 0x7c00u);  // overflow or infinity
    }

    uint32_t half = (static_cast<uint32_t>(exp) << 10) | ((absBits >> 13) & 0x3ffu);
    // The 13 low mantissa bits do not fit; 0x1000 is exactly half a unit in the last place.
    const uint32_t dropped = absBits & 0x1fffu;
    if (dropped > 0x1000u || (dropped == 0x1000u && (half & 1u) != 0)) {
        ++half;  // a carry out of the mantissa correctly bumps the exponent (up to 0x7c00 = Inf)
    }
    return static_cast<uint16_t>(sign | half);
}
/**
 * Straightforward CPU reference for scaled-dot-product attention.
 *
 * Q and Out are indexed as [B, N_q, S_q, D]; K and V as [B, N_kv, S_kv, D], all row-major.
 * Query head nq reads KV head nq * N_kv / N_q. Scores are scaled by 1/sqrt(D) and put through
 * a max-subtracted softmax over the S_kv axis. With maskMode == 1, key positions
 * s2 > s1 + (S_kv - S_q) get a score of -1e9 before the softmax; any other maskMode applies
 * no mask.
 *
 * @param Q, K, V  inputs as described above
 * @param Out      output buffer, must already hold B * N_q * S_q * D elements
 * @param B, N_q, N_kv, S_q, S_kv, D  problem dimensions
 * @param maskMode 1 for the causal mask described above, anything else for none
 */
static void cpu_flash_attention(
    const std::vector<float> &Q,
    const std::vector<float> &K,
    const std::vector<float> &V,
    std::vector<float> &Out,
    int64_t B, int64_t N_q, int64_t N_kv,
    int64_t S_q, int64_t S_kv, int64_t D,
    int64_t maskMode = 0)
{
    const float invSqrtD = 1.0f / std::sqrt((float)D);
    const bool causal = (maskMode == 1);
    // Only consulted when `causal` is set.
    const int64_t diagShift = causal ? (S_kv - S_q) : (S_kv - 1);

    std::vector<float> probs(S_q * S_kv);

    for (int64_t batch = 0; batch < B; ++batch) {
        for (int64_t head = 0; head < N_q; ++head) {
            const int64_t kvHead = head * N_kv / N_q;
            const int64_t qHeadBase = (batch * N_q + head) * S_q;
            const int64_t kvHeadBase = (batch * N_kv + kvHead) * S_kv;

            // Rows are independent, so each query row is taken through all three stages
            // (scores, softmax, weighted sum of V) before moving on to the next one.
            for (int64_t row = 0; row < S_q; ++row) {
                float *const rowProbs = probs.data() + row * S_kv;
                const int64_t qOff = (qHeadBase + row) * D;

                // Stage 1: scaled Q.K^T for this row, with optional causal masking.
                for (int64_t col = 0; col < S_kv; ++col) {
                    if (causal && col > row + diagShift) {
                        rowProbs[col] = -1e9f;
                        continue;
                    }
                    const int64_t kOff = (kvHeadBase + col) * D;
                    float dot = 0.0f;
                    for (int64_t d = 0; d < D; ++d) {
                        dot += Q[qOff + d] * K[kOff + d];
                    }
                    rowProbs[col] = dot * invSqrtD;
                }

                // Stage 2: softmax, subtracting the row maximum before exponentiating.
                const float rowMax = *std::max_element(rowProbs, rowProbs + S_kv);
                float denom = 0.0f;
                for (int64_t col = 0; col < S_kv; ++col) {
                    rowProbs[col] = std::exp(rowProbs[col] - rowMax);
                    denom += rowProbs[col];
                }
                for (int64_t col = 0; col < S_kv; ++col) {
                    rowProbs[col] /= denom;
                }

                // Stage 3: Out[row, :] = probs[row, :] . V
                for (int64_t d = 0; d < D; ++d) {
                    float acc = 0.0f;
                    for (int64_t col = 0; col < S_kv; ++col) {
                        acc += rowProbs[col] * V[(kvHeadBase + col) * D + d];
                    }
                    Out[qOff + d] = acc;
                }
            }
        }
    }
}
static int CreateAclTensorRandom(const std::vector<int64_t> &shape, void **deviceAddr,
aclDataType dataType, aclTensor **tensor,
std::mt19937 &rng,
float randMin = 0.0f, float randMax = 1.0f,
std::vector<float> *floatRef = nullptr,
int64_t refElems = 0)
{
const int64_t CHUNK = 1 << 20;
int64_t total = GetShapeSize(shape);
auto ret = aclrtMalloc(deviceAddr, (size_t)total * sizeof(uint16_t), ACL_MEM_MALLOC_HUGE_FIRST);
if (!CHECK_RET(ret == ACL_SUCCESS)) { LOG_PRINT("aclrtMalloc random failed. ERROR: %d\n", ret); return ret; }
if (floatRef && refElems > 0) floatRef->resize(refElems);
std::uniform_real_distribution<float> dist(randMin, randMax);
std::vector<uint16_t> chunk(std::min(CHUNK, total));
int64_t offset = 0;
while (offset < total) {
int64_t count = std::min(CHUNK, total - offset);
for (int64_t i = 0; i < count; ++i) {
float v = dist(rng);
uint16_t h = float_to_fp16(v);
chunk[i] = h;
if (floatRef && offset + i < refElems)
(*floatRef)[offset + i] = fp16_to_float(h);
}
ret = aclrtMemcpy(reinterpret_cast<uint8_t*>(*deviceAddr) + offset * sizeof(uint16_t),
(size_t)count * sizeof(uint16_t), chunk.data(),
(size_t)count * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
if (!CHECK_RET(ret == ACL_SUCCESS)) { LOG_PRINT("aclrtMemcpy chunk failed. ERROR: %d\n", ret); return ret; }
offset += count;
}
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = (int64_t)shape.size() - 2; i >= 0; --i)
strides[i] = shape[i + 1] * strides[i + 1];
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0,
aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr);
return 0;
}
// Number of 32-bit words allocated for the metadata tensor and copied back to the host.
static constexpr int64_t METADATA_MAX_ELEMS = 8192;
// Rows per section that PrintMetadata walks: META_AIC rows in the FA table, META_AIV rows in
// the FD table. NOTE(review): presumably the AIC/AIV core counts of the target device — confirm.
static constexpr uint32_t META_AIC = 36, META_AIV = 72;
// Layout as decoded by PrintMetadata: META_HDR header words first, then the FA table with
// META_FA_SLOT words per core, then the FD table with META_FD_SLOT words per core.
static constexpr uint32_t META_FA_SLOT = 16, META_FD_SLOT = 16, META_HDR = 16;
static int CreateMetadataTensor(void **deviceAddr, aclTensor **tensor)
{
size_t size = METADATA_MAX_ELEMS * sizeof(uint32_t);
int ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_SUCCESS) { LOG_PRINT("aclrtMalloc metadata failed. ERROR: %d\n", ret); return ret; }
ret = aclrtMemset(*deviceAddr, size, 0, size);
if (ret != ACL_SUCCESS) { LOG_PRINT("aclrtMemset metadata failed. ERROR: %d\n", ret); return ret; }
int64_t shape[1] = {METADATA_MAX_ELEMS};
int64_t strides[1] = {1};
*tensor = aclCreateTensor(shape, 1, aclDataType::ACL_INT32, strides, 0,
aclFormat::ACL_FORMAT_ND, shape, 1, *deviceAddr);
return ACL_SUCCESS;
}
/**
 * Pretty-prints a metadata buffer that has been copied back to the host.
 *
 * Layout decoded here: META_HDR header words (word 0 = sectionNum, word 1 = mBaseSize,
 * word 2 = s2BaseSize), followed by sectionNum FA tables of META_AIC x META_FA_SLOT words,
 * followed by sectionNum FD tables of META_AIV x META_FD_SLOT words. Only the first 7 words of
 * an FA slot and the first 6 words of an FD slot are printed. FD rows whose word 5 (M_NUM) is
 * zero are skipped.
 *
 * sectionNum comes from the buffer itself, so it is validated against `totalElems` before any
 * table is read; an inconsistent value prints the header plus a warning and nothing else.
 *
 * @param meta        host copy of the metadata words
 * @param totalElems  number of uint32_t words available at `meta`
 */
static void PrintMetadata(const uint32_t *meta, int64_t totalElems)
{
    if (meta == nullptr) {
        LOG_PRINT("[PrintMetadata] meta 为空指针\n");
        return;
    }
    if (totalElems < (int64_t)META_HDR) {
        LOG_PRINT("[PrintMetadata] buffer 太小 (%ld 元素)\n", totalElems);
        return;
    }
    const uint32_t sectionNum = meta[0];
    const uint32_t mBaseSize = meta[1];
    const uint32_t s2BaseSize = meta[2];
    static const char *FA_FIELDS[7] = {
        "BN2_START", "M_START", "S2_START",
        "BN2_END", "M_END", "S2_END", "FIRST_FD_WS"
    };
    static const char *FD_FIELDS[6] = {
        "BN2_IDX", "M_IDX", "WS_IDX", "WS_NUM", "M_START", "M_NUM"
    };
    const char *SEP = "==============================================================================";
    const char *DASH = "------------------------------------------------------------------------------";
    LOG_PRINT("\n%s\n", SEP);
    LOG_PRINT(" [Metadata Header]\n");
    LOG_PRINT(" sectionNum = %u\n", sectionNum);
    LOG_PRINT(" mBaseSize = %u\n", mBaseSize);
    LOG_PRINT(" s2BaseSize = %u\n", s2BaseSize);

    // Refuse to walk tables that would run past the end of the buffer.
    const uint64_t wordsPerSection = static_cast<uint64_t>(META_AIC) * META_FA_SLOT +
                                     static_cast<uint64_t>(META_AIV) * META_FD_SLOT;
    const uint64_t maxSections = (static_cast<uint64_t>(totalElems) - META_HDR) / wordsPerSection;
    if (sectionNum > maxSections) {
        LOG_PRINT(" [WARN] sectionNum=%u 超出 buffer 可容纳的 section 数 (%llu),跳过明细打印\n",
                  sectionNum, static_cast<unsigned long long>(maxSections));
        LOG_PRINT("%s\n\n", SEP);
        return;
    }

    for (uint32_t sec = 0; sec < sectionNum; sec++) {
        LOG_PRINT("\n%s\n", DASH);
        LOG_PRINT(" [Section %u] FA Metadata — AIC cores (36 × 16 slots, 7 字段有效)\n", sec);
        LOG_PRINT("%s\n", DASH);
        LOG_PRINT(" %-8s", "Core");
        for (int f = 0; f < 7; f++) LOG_PRINT("%14s", FA_FIELDS[f]);
        LOG_PRINT("\n %-8s", "");
        for (int f = 0; f < 7; f++) {
            char buf[8];
            snprintf(buf, sizeof(buf), "[%d]", f);
            LOG_PRINT("%14s", buf);
        }
        LOG_PRINT("\n %s\n", "--------" "--------------------------------------------------------------"
                  "----------------------------------------------------------");
        for (uint32_t core = 0; core < META_AIC; core++) {
            const size_t off = static_cast<size_t>(META_HDR) +
                               static_cast<size_t>(sec) * META_AIC * META_FA_SLOT +
                               static_cast<size_t>(core) * META_FA_SLOT;
            uint32_t vals[7];
            bool allZero = true;
            for (int f = 0; f < 7; f++) {
                vals[f] = meta[off + f];
                if (vals[f]) allZero = false;
            }
            LOG_PRINT(" AIC%02u ", core);
            for (int f = 0; f < 7; f++) LOG_PRINT("%14u", vals[f]);
            if (allZero) LOG_PRINT(" (inactive)");
            LOG_PRINT("\n");
        }
    }

    // The FD tables start right after all FA tables.
    const size_t fdBase = static_cast<size_t>(META_HDR) +
                          static_cast<size_t>(sectionNum) * META_AIC * META_FA_SLOT;
    for (uint32_t sec = 0; sec < sectionNum; sec++) {
        LOG_PRINT("\n%s\n", DASH);
        LOG_PRINT(" [Section %u] FD Metadata — 活跃 AIV cores (M_NUM>0, 72 × 16 slots, 6 字段有效)\n", sec);
        LOG_PRINT("%s\n", DASH);
        LOG_PRINT(" %-8s", "Core");
        for (int f = 0; f < 6; f++) LOG_PRINT("%12s", FD_FIELDS[f]);
        LOG_PRINT("\n %-8s", "");
        for (int f = 0; f < 6; f++) {
            char buf[8];
            snprintf(buf, sizeof(buf), "[%d]", f);
            LOG_PRINT("%12s", buf);
        }
        LOG_PRINT("\n %s\n",
                  "--------" "------------------------------------------------------------"
                  "----------------------------------");
        int active = 0;
        for (uint32_t core = 0; core < META_AIV; core++) {
            const size_t off = fdBase + static_cast<size_t>(sec) * META_AIV * META_FD_SLOT +
                               static_cast<size_t>(core) * META_FD_SLOT;
            if (meta[off + 5] == 0) continue;  // M_NUM == 0: this core has no FD work
            LOG_PRINT(" AIV%02u ", core);
            for (int f = 0; f < 6; f++) LOG_PRINT("%12u", meta[off + f]);
            LOG_PRINT("\n");
            ++active;
        }
        if (active == 0) LOG_PRINT(" (no active FD cores)\n");
    }
    LOG_PRINT("%s\n\n", SEP);
}
/**
 * Writes `n` floats to a text file, one value per line with 8 decimal places.
 *
 * @param path    output file, opened in "w" mode (truncates an existing file)
 * @param data    values to write
 * @param n       number of values
 * @param header  optional first line, written as "# <header>"
 *
 * If the file cannot be opened a warning is logged and nothing is written.
 */
static void SaveFloatToTxt(const char *path, const float *data, int64_t n,
                           const char *header = nullptr)
{
    FILE *out = fopen(path, "w");
    if (out == nullptr) {
        LOG_PRINT("[WARN] 无法创建文件: %s\n", path);
        return;
    }

    if (header != nullptr) {
        fprintf(out, "# %s\n", header);
    }

    int64_t idx = 0;
    while (idx < n) {
        fprintf(out, "%.8f\n", data[idx]);
        ++idx;
    }

    fclose(out);
}
}
/** Parameters of a single FlashAttn test run, filled in by main and consumed by RunTestCase. */
struct TestCase {
    const char *name;   // label used in log output and in dumped file names

    // Problem dimensions.
    int64_t B;          // batch size
    int64_t N_q;        // number of query heads
    int64_t N_kv;       // number of key/value heads
    int64_t S_q;        // query sequence length
    int64_t S_kv;       // key/value sequence length
    int64_t D;          // head dimension

    // Masking options forwarded to the operators.
    int64_t maskMode;
    int64_t winLeft;
    int64_t winRight;

    // Bounds of the uniform distributions the inputs are drawn from.
    float qMin;
    float qMax;
    float kMin;
    float kMax;
    float vMin;
    float vMax;
};
// Per-element tolerance factor: RunTestCase flags an element when its absolute error exceeds
// max((|cpu| + 1e-6) * RATIO_THRESHOLD, 0.000025).
static constexpr float RATIO_THRESHOLD = 0.005f;
// A case passes when the fraction of flagged elements is at most this limit.
static constexpr float FAIL_RATIO_LIMIT = 0.005f;
static bool RunTestCase(const TestCase &tc, aclrtStream stream,
bool dumpOutput, bool verbose, uint32_t seed)
{
const char *SEP64 = "──────────────────────────────────────────────────────────────────";
LOG_PRINT("\n┌%s┐\n", SEP64);
LOG_PRINT("│ Case: %-10s B=%ld N_q=%ld N_kv=%ld S_q=%ld S_kv=%ld D=%ld "
"mask=%ld\n",
tc.name, tc.B, tc.N_q, tc.N_kv, tc.S_q, tc.S_kv, tc.D, tc.maskMode);
LOG_PRINT("└%s┘\n", SEP64);
int64_t B = tc.B, N_q = tc.N_q, N_kv = tc.N_kv;
int64_t S_q = tc.S_q, S_kv = tc.S_kv, D = tc.D;
std::vector<int64_t> qShape = {B, N_q, S_q, D};
std::vector<int64_t> kShape = {B, N_kv, S_kv, D};
std::vector<int64_t> vShape = {B, N_kv, S_kv, D};
std::vector<int64_t> outShape = {B, N_q, S_q, D};
int64_t qElem = GetShapeSize(qShape), kElem = GetShapeSize(kShape);
int64_t outElem = GetShapeSize(outShape); (void)GetShapeSize(vShape);
int64_t qRefElems = N_q * S_q * D;
int64_t kvRefElems = N_kv * S_kv * D;
LOG_PRINT(" Tensors: Q[%ld,%ld,%ld,%ld]=%.2fMB K/V[%ld,%ld,%ld,%ld]=%.2fMB\n",
B,N_q,S_q,D, qElem*2.0/1e6, B,N_kv,S_kv,D, kElem*2.0/1e6);
std::mt19937 rng(seed ? seed : static_cast<uint32_t>(std::random_device{}()));
std::vector<float> qRef, kRef, vRef;
void *qDev=nullptr, *kDev=nullptr, *vDev=nullptr, *outDev=nullptr;
void *metaDev = nullptr;
aclTensor *qT=nullptr, *kT=nullptr, *vT=nullptr, *outT=nullptr, *metaT=nullptr;
auto cleanup = [&]() {
if (qT) aclDestroyTensor(qT);
if (kT) aclDestroyTensor(kT);
if (vT) aclDestroyTensor(vT);
if (outT) aclDestroyTensor(outT);
if (metaT) aclDestroyTensor(metaT);
if (qDev) aclrtFree(qDev);
if (kDev) aclrtFree(kDev);
if (vDev) aclrtFree(vDev);
if (outDev) aclrtFree(outDev);
if (metaDev) aclrtFree(metaDev);
};
{
std::uniform_real_distribution<float> dist(tc.qMin, tc.qMax);
std::vector<uint16_t> h(qElem);
qRef.resize(qRefElems);
for (int64_t i = 0; i < qElem; ++i) {
uint16_t fp = float_to_fp16(dist(rng));
h[i] = fp;
if (i < qRefElems) qRef[i] = fp16_to_float(fp);
}
if (CreateAclTensor(h, qShape, &qDev, aclDataType::ACL_FLOAT16, &qT) != ACL_SUCCESS)
{ cleanup(); return false; }
}
if (CreateAclTensorRandom(kShape, &kDev, aclDataType::ACL_FLOAT16, &kT,
rng, tc.kMin, tc.kMax, &kRef, kvRefElems) != ACL_SUCCESS)
{ cleanup(); return false; }
if (CreateAclTensorRandom(vShape, &vDev, aclDataType::ACL_FLOAT16, &vT,
rng, tc.vMin, tc.vMax, &vRef, kvRefElems) != ACL_SUCCESS)
{ cleanup(); return false; }
if (CreateAclTensorDeviceZero(outShape, &outDev, aclDataType::ACL_FLOAT16,
sizeof(uint16_t), &outT) != ACL_SUCCESS)
{ cleanup(); return false; }
if (CreateMetadataTensor(&metaDev, &metaT) != ACL_SUCCESS)
{ cleanup(); return false; }
double softmaxScale = 1.0 / std::sqrt((double)D);
int64_t winLeft = tc.winLeft;
int64_t winRight = tc.winRight;
LOG_PRINT("\n [A] 调用 flash_attn_metadata 算子...\n");
{
uint64_t wsSize = 0;
aclOpExecutor *exec = nullptr;
int ret = aclnnFlashAttnMetadataGetWorkspaceSize(
nullptr,
nullptr,
nullptr,
nullptr,
B, S_q, S_kv, N_q, N_kv, D,
tc.maskMode, winLeft, winRight,
"BNSD", "BNSD", "BNSD",
metaT, &wsSize, &exec);
if (ret != ACL_SUCCESS) {
LOG_PRINT(" aclnnFlashAttnMetadataGetWorkspaceSize FAILED. ret=%d\n", ret);
cleanup(); return false;
}
void *ws = nullptr;
if (wsSize > 0) {
aclrtMalloc(&ws, wsSize, ACL_MEM_MALLOC_HUGE_FIRST);
}
ret = aclnnFlashAttnMetadata(ws, wsSize, exec, stream);
if (ret != ACL_SUCCESS) {
LOG_PRINT(" aclnnFlashAttnMetadata FAILED. ret=%d\n", ret);
if (ws) aclrtFree(ws);
cleanup(); return false;
}
aclrtSynchronizeStream(stream);
if (ws) aclrtFree(ws);
}
std::vector<uint32_t> metaHost(METADATA_MAX_ELEMS, 0);
aclrtMemcpy(metaHost.data(), METADATA_MAX_ELEMS * sizeof(uint32_t),
metaDev, METADATA_MAX_ELEMS * sizeof(uint32_t),
ACL_MEMCPY_DEVICE_TO_HOST);
LOG_PRINT("\n ▼ flash_attn_metadata 输出(sectionNum=%u mBase=%u s2Base=%u)\n",
metaHost[0], metaHost[1], metaHost[2]);
PrintMetadata(metaHost.data(), METADATA_MAX_ELEMS);
LOG_PRINT(" [B] 调用 flash_attn 算子(使用预计算 metadata)...\n");
{
uint64_t wsSize = 0;
aclOpExecutor *exec = nullptr;
int ret = aclnnFlashAttnGetWorkspaceSize(
qT, kT, vT,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
metaT,
softmaxScale,
(int32_t)tc.maskMode, (int32_t)winLeft, (int32_t)winRight,
(int32_t)S_q, (int32_t)S_kv,
"BNSD", "BNSD", "BNSD",
(int32_t)0,
outT,
nullptr,
&wsSize, &exec);
if (ret != ACL_SUCCESS) {
LOG_PRINT(" aclnnFlashAttnGetWorkspaceSize FAILED. ret=%d\n", ret);
cleanup(); return false;
}
void *ws = nullptr;
if (wsSize > 0) aclrtMalloc(&ws, wsSize, ACL_MEM_MALLOC_HUGE_FIRST);
ret = aclnnFlashAttn(ws, wsSize, exec, stream);
if (ret != ACL_SUCCESS) {
LOG_PRINT(" aclnnFlashAttn FAILED. ret=%d\n", ret);
if (ws) aclrtFree(ws);
cleanup(); return false;
}
aclrtSynchronizeStream(stream);
if (ws) aclrtFree(ws);
LOG_PRINT(" flash_attn 执行完毕。\n");
}
std::vector<uint16_t> outHost(outElem, 0);
aclrtMemcpy(outHost.data(), (size_t)outElem * sizeof(uint16_t),
outDev, (size_t)outElem * sizeof(uint16_t), ACL_MEMCPY_DEVICE_TO_HOST);
LOG_PRINT("\n [C] CPU 参考计算(batch 0, mask=%ld)...\n", tc.maskMode);
std::vector<float> cpuOut(qRefElems, 0.0f);
cpu_flash_attention(qRef, kRef, vRef, cpuOut, 1, N_q, N_kv, S_q, S_kv, D, tc.maskMode);
double maxAbsErr=0, sumAbsErr=0, maxRelErr=0, sumRelErr=0;
int64_t failCount = 0;
for (int64_t i = 0; i < qRefElems; ++i) {
float npu = fp16_to_float(outHost[i]);
float cpu = cpuOut[i];
float ae = std::abs(npu - cpu);
float ref = std::abs(cpu) + 1e-6f;
float re = ae / ref;
float thr = std::max(ref * RATIO_THRESHOLD, 0.000025f);
if (ae > maxAbsErr) maxAbsErr = ae;
if (re > maxRelErr) maxRelErr = re;
sumAbsErr += ae;
sumRelErr += re;
if (ae > thr) ++failCount;
}
double meanAbsErr = sumAbsErr / (double)qRefElems;
double meanRelErr = sumRelErr / (double)qRefElems;
double failRatio = (double)failCount / (double)qRefElems;
bool passed = (failRatio <= (double)FAIL_RATIO_LIMIT);
LOG_PRINT("\n┌%s┐\n", SEP64);
LOG_PRINT("│ 精度报告: %-10s (batch 0, %ld 元素)\n", tc.name, qRefElems);
LOG_PRINT("├%s┤\n", SEP64);
LOG_PRINT("│ Shape : [%ld,%ld,%ld,%ld]\n", 1L, N_q, S_q, D);
LOG_PRINT("│ MaxAbsErr : %.8f\n", maxAbsErr);
LOG_PRINT("│ MeanAbsErr : %.8f\n", meanAbsErr);
LOG_PRINT("│ MaxRelErr : %.8f\n", maxRelErr);
LOG_PRINT("│ MeanRelErr : %.8f\n", meanRelErr);
LOG_PRINT("│ FailElems : %ld / %ld (%.4f%%)\n",
failCount, qRefElems, failRatio * 100.0);
LOG_PRINT("│ Threshold : elemDynamic failRatio<=%.2f%%\n",
FAIL_RATIO_LIMIT * 100.0f);
LOG_PRINT("│ 结论 : %s\n", passed ? "✓ PASS" : "✗ FAIL");
if (failCount > 0) {
LOG_PRINT("├%s┤\n", SEP64);
int64_t printMax = verbose ? failCount : std::min(failCount, (int64_t)10);
LOG_PRINT("│ %s%ld 个超阈值元素:\n", verbose ? "全部 " : "前", printMax);
LOG_PRINT("│ %8s %14s %14s %12s %12s\n",
"idx", "NPU", "CPU", "absErr", "relErr");
int64_t printed = 0;
for (int64_t i = 0; i < qRefElems && printed < printMax; ++i) {
float npu = fp16_to_float(outHost[i]);
float cpu = cpuOut[i];
float ae = std::abs(npu - cpu);
float ref = std::abs(cpu) + 1e-6f;
float thr = std::max(ref * RATIO_THRESHOLD, 0.000025f);
if (ae > thr) {
LOG_PRINT("│ %8ld %+14.8f %+14.8f %12.8f %12.6f\n",
i, (double)npu, (double)cpu, (double)ae, (double)(ae/ref));
++printed;
}
}
}
LOG_PRINT("└%s┘\n", SEP64);
if (dumpOutput) {
char buf[256];
std::vector<float> npuFloat(qRefElems);
for (int64_t i = 0; i < qRefElems; ++i) npuFloat[i] = fp16_to_float(outHost[i]);
snprintf(buf, sizeof(buf), "fa_npu_out_%s.txt", tc.name);
char hdr[256];
snprintf(hdr, sizeof(hdr), "shape=[1,%ld,%ld,%ld] case=%s elems=%ld",
N_q, S_q, D, tc.name, qRefElems);
SaveFloatToTxt(buf, npuFloat.data(), qRefElems, hdr);
LOG_PRINT(" NPU output saved: %s\n", buf);
snprintf(buf, sizeof(buf), "fa_cpu_ref_%s.txt", tc.name);
SaveFloatToTxt(buf, cpuOut.data(), qRefElems, hdr);
LOG_PRINT(" CPU ref saved: %s\n", buf);
snprintf(buf, sizeof(buf), "fa_prec_cmp_%s.txt", tc.name);
FILE *f = fopen(buf, "w");
if (f) {
fprintf(f, "case=%s B=%ld N_q=%ld N_kv=%ld S_q=%ld S_kv=%ld D=%ld mask=%ld\n",
tc.name, B, N_q, N_kv, S_q, S_kv, D, tc.maskMode);
fprintf(f, "maxAbsErr=%.8f meanAbsErr=%.8f maxRelErr=%.8f "
"failElems=%ld/%ld %s\n",
maxAbsErr, meanAbsErr, maxRelErr,
failCount, qRefElems, passed ? "PASS" : "FAIL");
fprintf(f, "%-8s %-14s %-14s %-14s %-12s %s\n",
"idx", "npu", "cpu", "absErr", "relErr", "result");
for (int64_t i = 0; i < qRefElems; ++i) {
float npu = fp16_to_float(outHost[i]);
float ae = std::abs(npu - cpuOut[i]);
float ref = std::abs(cpuOut[i]) + 1e-6f;
float thr = std::max(ref * RATIO_THRESHOLD, 0.000025f);
fprintf(f, "%-8ld %+.8f %+.8f %.8f %.6f %s\n",
i, (double)npu, (double)cpuOut[i], (double)ae,
(double)(ae/ref), ae > thr ? "FAIL" : "pass");
}
fclose(f);
LOG_PRINT(" Report saved: %s\n", buf);
}
}
cleanup();
return passed;
}
/**
 * Prints the command-line help to stdout.
 *
 * The defaults shown for N1, S1, S2, mask and the Q/K/V value ranges are the values main()
 * initialises its variables with.
 *
 * @param prog  program name (argv[0]) substituted into the usage line and the examples
 */
static void PrintUsage(const char *prog)
{
    printf(
        "用法: %s [选项]\n"
        "\n"
        "形状参数(对标 test_case.py 各字段):\n"
        " --B <int> batch_size (默认 1)\n"
        " --N1 <int> query head 数 (默认 19)\n"
        " --N2 <int> kv head 数;N2<N1 为 GQA/MQA (默认 = N1)\n"
        " --S1 <int> query 序列长度;S1=1 触发 FD 模式 (默认 640)\n"
        " --S2 <int> key/value 序列长度 (默认 1024)\n"
        " --D <int> head_dim (默认 128)\n"
        "\n"
        "掩码参数:\n"
        " --mask <int> mask_mode: 0=全量 1=因果下三角 (默认 0)\n"
        " --wL <int> win_left / pre_tokens;-1=无限制 (默认 -1)\n"
        " --wR <int> win_right / next_tokens;-1=无限制(默认 -1)\n"
        "\n"
        "QKV 值域参数(fp16 均匀随机分布):\n"
        " --qMin <float> Q 下界 (默认 -10.0)\n"
        " --qMax <float> Q 上界 (默认 10.0)\n"
        " --kMin <float> K 下界 (默认 -10.0)\n"
        " --kMax <float> K 上界 (默认 10.0)\n"
        " --vMin <float> V 下界 (默认 -10.0)\n"
        " --vMax <float> V 上界 (默认 10.0)\n"
        "\n"
        "其他选项:\n"
        " --name <str> 输出文件名前缀 (默认 \"test\")\n"
        " --seed <int> 随机种子;0=每次随机 (默认 42)\n"
        " --device <int> NPU device id (默认 0)\n"
        " --dump 保存 npu_out/cpu_ref/prec_cmp.txt\n"
        " --verbose 打印全部超阈值点(而非仅前 10 个)\n"
        " --help 打印此帮助\n"
        "\n"
        "示例:\n"
        " %s # 默认参数运行\n"
        " %s --B 1 --N1 8 --N2 2 --S1 1 --S2 512 --D 128 # FD+GQA\n"
        " %s --B 4 --N1 16 --N2 4 --S1 256 --S2 256 --D 128 --mask 1\n"
        " %s --S1 64 --S2 256 --qMin 10 --qMax 10 --kMin 10 --kMax 10\n",
        prog, prog, prog, prog, prog);
}
int main(int argc, char *argv[])
{
std::string caseName = "test";
int64_t B = 1;
int64_t N_q = 19;
int64_t N_kv = 19;
int64_t S_q = 640;
int64_t S_kv = 1024;
int64_t D = 128;
int64_t maskMode = 0;
int64_t winLeft = -1;
int64_t winRight = -1;
float qMin = -10.0f, qMax = 10.0f;
float kMin = -10.0f, kMax = 10.0f;
float vMin = -10.0f, vMax = 10.0f;
uint32_t seed = 42u;
int32_t deviceId = 0;
bool dumpOutput = false;
bool verbose = false;
for (int i = 1; i < argc; ++i) {
std::string a = argv[i];
auto nextStr = [&]() -> std::string {
return (i + 1 < argc) ? argv[++i] : "";
};
if (a == "--help") { PrintUsage(argv[0]); return 0; }
else if (a == "--dump") { dumpOutput = true; }
else if (a == "--verbose") { verbose = true; }
else if (a == "--name") { caseName = nextStr(); }
else if (a == "--seed") { seed = (uint32_t)std::stoul(nextStr()); }
else if (a == "--device") { deviceId = (int32_t) std::stoi(nextStr()); }
else if (a == "--B") { B = std::stoll(nextStr()); }
else if (a == "--N1") { N_q = std::stoll(nextStr()); }
else if (a == "--N2") { N_kv = std::stoll(nextStr()); }
else if (a == "--S1") { S_q = std::stoll(nextStr()); }
else if (a == "--S2") { S_kv = std::stoll(nextStr()); }
else if (a == "--D") { D = std::stoll(nextStr()); }
else if (a == "--mask") { maskMode = std::stoll(nextStr()); }
else if (a == "--wL") { winLeft = std::stoll(nextStr()); }
else if (a == "--wR") { winRight = std::stoll(nextStr()); }
else if (a == "--qMin") { qMin = std::stof(nextStr()); }
else if (a == "--qMax") { qMax = std::stof(nextStr()); }
else if (a == "--kMin") { kMin = std::stof(nextStr()); }
else if (a == "--kMax") { kMax = std::stof(nextStr()); }
else if (a == "--vMin") { vMin = std::stof(nextStr()); }
else if (a == "--vMax") { vMax = std::stof(nextStr()); }
else { LOG_PRINT("[WARN] 未知参数: %s,已忽略(--help 查看用法)\n", a.c_str()); }
}
if (N_kv < 0) N_kv = N_q;
TestCase tc;
tc.name = caseName.c_str();
tc.B = B;
tc.N_q = N_q;
tc.N_kv = N_kv;
tc.S_q = S_q;
tc.S_kv = S_kv;
tc.D = D;
tc.maskMode = maskMode;
tc.winLeft = winLeft;
tc.winRight = winRight;
tc.qMin = qMin; tc.qMax = qMax;
tc.kMin = kMin; tc.kMax = kMax;
tc.vMin = vMin; tc.vMax = vMax;
const char *SEP = "══════════════════════════════════════════════════════════════════";
LOG_PRINT("\n%s\n", SEP);
LOG_PRINT(" FlashAttn 测试参数配置\n");
LOG_PRINT("%s\n", SEP);
LOG_PRINT(" B=%ld N1(N_q)=%ld N2(N_kv)=%ld S1(S_q)=%ld S2(S_kv)=%ld D=%ld\n",
B, N_q, N_kv, S_q, S_kv, D);
LOG_PRINT(" maskMode=%ld winLeft=%ld winRight=%ld\n",
maskMode, winLeft, winRight);
LOG_PRINT(" Q∈[%.3g, %.3g] K∈[%.3g, %.3g] V∈[%.3g, %.3g]\n",
(double)qMin, (double)qMax, (double)kMin, (double)kMax, (double)vMin, (double)vMax);
LOG_PRINT(" seed=%u device=%d dump=%s verbose=%s\n",
seed, deviceId, dumpOutput ? "yes" : "no", verbose ? "yes" : "no");
LOG_PRINT("%s\n\n", SEP);
aclrtStream stream;
if (Init(deviceId, &stream) != ACL_SUCCESS) {
LOG_PRINT("Init acl failed.\n");
return 1;
}
bool passed = false;
try {
passed = RunTestCase(tc, stream, dumpOutput, verbose, seed);
} catch (const std::exception &e) {
LOG_PRINT("[ERROR] 运行异常: %s\n", e.what());
} catch (...) {
LOG_PRINT("[ERROR] 未知异常\n");
}
LOG_PRINT("\n%s\n 最终结论: %s\n%s\n\n",
SEP, passed ? "✓ PASS" : "✗ FAIL", SEP);
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return passed ? 0 : 1;
}