aclnnSparseFlashAttention
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A2 推理系列产品 | √ |
| Atlas A3 推理系列产品 | √ |
功能说明
-
接口功能:sparse_flash_attention(SFA)是针对大序列长度推理场景的高效注意力计算模块,该模块通过“只计算关键部分”大幅减少计算量,然而会引入大量的离散访存,造成数据搬运时间增加,进而影响整体性能。
-
计算公式:
softmax(Q@K~Tdk)@V~\text{softmax}(\frac{Q@\tilde{K}^T}{\sqrt{d_k}})@\tilde{V}
其中K~,V~\tilde{K},\tilde{V}为基于某种选择算法(如lightning_indexer)得到的重要性较高的Key和Value,一般具有稀疏或分块稀疏的特征,dkd_k为Q,K~Q,\tilde{K}每一个头的维度。
函数原型
每个算子分为两段式接口,必须先调用“aclnnSparseFlashAttentionGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnSparseFlashAttention”接口执行计算。
aclnnStatus aclnnSparseFlashAttentionGetWorkspaceSize(
const aclTensor *query,
const aclTensor *key,
const aclTensor *value,
const aclTensor *sparseIndices,
const aclTensor *blockTable,
const aclTensor *actualSeqLengthsQuery,
const aclTensor *actualSeqLengthsKv,
const aclTensor *queryRope,
const aclTensor *keyRope,
double scaleValue,
int64_t sparseBlockSize,
char *layoutQuery,
char *layoutKv,
int64_t sparseMode,
int64_t preTokens,
int64_t nextTokens,
int64_t attentionMode,
bool returnSoftmaxLse,
const aclTensor *attentionOutOut,
const aclTensor *softmaxMaxOut,
const aclTensor *softmaxSumOut,
uint64_t *workspaceSize,
aclOpExecutor **executor)
aclnnStatus aclnnSparseFlashAttention(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
const aclrtStream stream)
aclnnSparseFlashAttentionGetWorkspaceSize
-
参数说明:
Note
- query、key、value参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
- Q_S和S1表示query shape中的S,KV_S和S2表示key shape中的S,Q_N和N1表示num_query_heads,KV_N和N2表示num_key_value_heads,T1表示query shape中的T,T2表示key shape中的输入样本序列长度的累加和。
参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor query(aclTensor) 输入 attention结构的Query输入。 不支持空tensor。 FLOAT16、BFLOAT16 ND - layout_query为BSND时,shape为(B,S1,N1,D)。
- layout_query为TND时,shape为(T1,N1,D)。
x key(aclTensor) 输入 attention结构的Key输入 - 不支持空tensor。
- block_num为PageAttention时block总数。
FLOAT16、BFLOAT16 ND - layout_kv为PA_BSND时,shape为(block_num, block_size, KV_N, D)。
- layout_kv为BSND时,shape为(B, S2, KV_N, D)。
- layout_kv为TND时,shape为(T2, KV_N, D)。
x value(aclTensor) 输入 attention结构的Value输入。 不支持空tensor。 FLOAT16、BFLOAT16 ND shape与key的shape一致。 x sparseIndices(aclTensor) 输入 离散取kvCache的索引。 - 不支持空tensor。
- sparse_size为一次离散选取的block数,需要保证每行有效值均在前半部分,无效值均在后半部分,且需要满足sparse_size大于0。
INT32 ND - layout_query为BSND时,shape为(B, Q_S, KV_N, sparse_size)。
- layout_query为TND时,shape为(Q_T, KV_N, sparse_size)。
x blockTable(aclTensor) 输入 表示PageAttention中kvCache存储使用的block映射表。 - 不支持空tensor。
- 第二维长度不小于所有batch中最大的S2对应的block数量,即S2_max / block_size向上取整。
INT32 ND shape支持(B,S2/block_size)。 x actualSeqLengthsQuery(aclTensor) 输入 表示不同Batch中query的有效token数。 - 不支持空tensor。
- 如果不指定seqlen可传入None,表示和query的shape的S长度相同。
- 该入参中每个Batch的有效token数不超过query中的维度S大小且不小于0。支持长度为B的一维tensor。
- layout_query为TND时,该入参必须传入,且以该入参元素的数量作为B值,该参数中每个元素的值表示当前batch与之前所有batch的token数总和。
INT32 ND (B,) x actualSeqLengthsKv(aclTensor) 输入 表示不同Batch中key和value的有效token数。 - 不支持空tensor。
- 如果不指定seqlen可传入None,表示和key的shape的S长度相同。
- 该参数中每个Batch的有效token数不超过key/value中的维度S大小且不小于0。支持长度为B的一维tensor。
- 当layout_kv为TND或PA_BSND时,该入参必须传入。
- layout_kv为TND,该参数中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值。
INT32 ND (B,) x queryRope(aclTensor) 输入 表示MLA结构中的query的rope信息。 不支持空tensor。 FLOAT16、BFLOAT16 ND - layout_query为TND时,shape为(B,S1,N1,Dr)。
- layout_query为BSND时,shape为(T1,N1,Dr)。
x keyRope(aclTensor) 输入 表示MLA结构中的key的rope信息。 不支持空tensor。 FLOAT16、BFLOAT16 ND - layout_kv为TND时,shape为(B,S1,N1,Dr)。
- layout_kv为BSND时,shape为(T1,N1,Dr)。
- layout_kv为PA_BSND时,shape为(block_num,block_size,N2,Dr)。
x scaleValue(double) 输入 代表缩放系数。 - FLOAT16 - - - sparseBlockSize(int64_t) 输入 代表sparse阶段的block大小。 - sparse_block_size为1时,为Token-wise稀疏化场景,将每个token视为独立单元,在计算重要性分数时,评估每个查询token与每个键值token之间的独立关联程度。
- sparse_block_size为大于1小于等于128时,为Block-wise稀疏化场景,将token序列划分为固定大小的连续块,以块为单位进行重要性评估,块内token共享相同的稀疏化决策。
INT64 - - - layoutQuery(char) 输入 标识输入query的数据排布格式。 - 用户不特意指定时可传入默认值"BSND"。
- 支持传入BSND和TND。
STRING - - - layoutKv(char) 输入 标识输入key的数据排布格式。 - 用户不特意指定时可传入默认值"BSND"。
- 支持传入TND、BSND和PA_BSND,其中PA_BSND在使能PageAttention时使用。
STRING - - - sparseMode(int64_t) 输入 表示sparse的模式。 - sparse_mode为0时,代表全部计算。
- sparse_mode为3时,代表rightDownCausal模式的mask,对应以右下顶点往左上为划分线的下三角场景。
INT64 - - - preTokens(int64_t) 输入 用于稀疏计算,表示attention需要和前几个Token计算关联。 仅支持默认值2^63-1。 INT64 - - - nextTokens(int64_t) 输入 用于稀疏计算,表示attention需要和后几个Token计算关联。 仅支持默认值2^63-1。 INT64 - - - attentionMode(int64_t) 输入 - 仅支持传入2,表示MLA-absorb模式。 INT64 - - - returnSoftmaxLse(bool) 输入 用于表示是否返回softmax_max和softmax_sum。 - True表示返回,False表示不返回;默认值为False。
- 该参数仅在训练且layout_kv不为PA_BSND场景支持。
BOOL - - - attentionOut(aclTensor) 输出 公式中的输出。 不支持空tensor。 FLOAT16、BFLOAT16 ND - layout_query为BSND时,shape为(B,S1,N1,D)。
- layout_query为TND时shape为(T1,N1,D)。
x softmaxMaxOut(aclTensor) 输出 Attention算法对query乘key的结果,取max得到softmax_max。 不支持空tensor。 FLOAT ND - layout_query为BSND时,shape为(B,N2,S1,N1/N2)。
- layout_query为TND时shape为(N2,T1,N1/N2)。
x softmaxSumOut(aclTensor) 输出 Attention算法query乘key的结果减去softmax_max, 再取exp,接着求sum,得到softmax_sum。 不支持空tensor。 FLOAT ND - layout_query为BSND时,shape为(B,N2,S1,N1/N2)。
- layout_query为TND时shape为(N2,T1,N1/N2)。
x workspaceSize(uint64_t*) 输出 返回需要在Device侧申请的workspace大小。 - - - - - executor(aclOpExecutor) 输出 返回op执行器,包含了算子计算流程。 - - - - - -
返回值:
aclnnStatus:返回状态码,具体参见aclnn返回码。
第一段接口会完成入参校验,出现以下场景时报错:
返回值 错误码 描述 ACLNN_ERR_PARAM_NULLPTR 161001 如果传入参数是必选输入,输出或者必选属性,且是空指针,则返回161001。 ACLNN_ERR_PARAM_INVALID 161002 query、key、value、sparseIndices、blockTable、actualSeqLengthsQuery、actualSeqLengthsKv、queryRope、keyRope、scaleValue、sparseBlockSize、layoutQuery、layoutKv、sparseMode、attentionMode、returnSoftmaxLse、attentionOut、softmaxMaxOut、softmaxSumOut的数据类型和数据格式不在支持的范围内。
aclnnSparseFlashAttention
-
参数说明:
参数名 输入/输出 描述 workspace 输入 在Device侧申请的workspace内存地址。 workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnSparseFlashAttentionGetWorkspaceSize获取。 executor 输入 op执行器,包含了算子计算流程。 stream 输入 指定执行任务的Stream。 -
返回值:
aclnnStatus:返回状态码,具体参见aclnn返回码。
约束说明
- 确定性计算:aclnnSparseFlashAttention默认确定性实现。
- 该接口支持推理场景下使用。
- N1支持1~64和128。
- block_size为一个block的token数,block_size取值为16的倍数,且最大支持1024。
- 参数query中的D和key、value的D值相等为512,参数query_rope中的Dr和key_rope的Dr值相等为64。
- 参数query、key、value的数据类型必须保持一致。
- 支持sparse_block_size整除block_size。
- Ascend 950PR/Ascend 950DT:
- 只支持sparse_block_size为1。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
- 支持[1,128],且要求是2的幂次方,在PageAttention场景下要求sparse_block_size整除block_size
- Ascend 950PR/Ascend 950DT:
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
/**
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file test_incre_flash_attention_v4.cpp
* \brief
*/
#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
#include "securec.h"
#include "acl/acl.h"
#include "aclnnop/aclnn_sparse_flash_attention.h"
using namespace std;
namespace {
#define CHECK_RET(cond) ((cond) ? true :(false))
#define LOG_PRINT(message, ...) \
do { \
(void)printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
int Init(int32_t deviceId, aclrtStream* stream) {
auto ret = aclInit(nullptr);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclInit failed. ERROR: %d\n", ret);
return ret;
}
ret = aclrtSetDevice(deviceId);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret);
return ret;
}
ret = aclrtCreateStream(stream);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret);
return ret;
}
return 0;
}
template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
aclDataType dataType, aclTensor** tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret);
return ret;
}
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret);
return ret;
}
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
return 0;
}
struct TensorResources {
void* queryDeviceAddr = nullptr;
void* keyDeviceAddr = nullptr;
void* valueDeviceAddr = nullptr;
void* sparseIndicesDeviceAddr = nullptr;
void* attentionOutDeviceAddr = nullptr;
void* softmaxMaxDeviceAddr = nullptr;
void* softmaxSumDeviceAddr = nullptr;
void* queryRopeDeviceAddr = nullptr;
void* keyRopeDeviceAddr = nullptr;
aclTensor* queryTensor = nullptr;
aclTensor* keyTensor = nullptr;
aclTensor* valueTensor = nullptr;
aclTensor* sparseIndicesTensor = nullptr;
aclTensor* attentionOutTensor = nullptr;
aclTensor* softmaxMaxTensor = nullptr;
aclTensor* softmaxSumTensor = nullptr;
aclTensor* queryRopeTensor = nullptr;
aclTensor* keyRopeTensor = nullptr;
};
int InitializeTensors(TensorResources& resources) {
std::vector<int64_t> queryShape = {1, 2, 1, 512};
std::vector<int64_t> keyShape = {1, 2, 1, 512};
std::vector<int64_t> valueShape = {1, 2, 1, 512};
std::vector<int64_t> sparseIndicesShape = {1, 2, 1, 2};
std::vector<int64_t> attentionOutShape = {1, 2, 1, 512};
std::vector<int64_t> softmaxMaxShape = {1, 2, 1, 16};
std::vector<int64_t> softmaxSumShape = {1, 2, 1, 16};
std::vector<int64_t> queryRopeShape = {1, 2, 1, 64};
std::vector<int64_t> keyRopeShape = {1, 2, 1, 64};
int64_t queryShapeSize = GetShapeSize(queryShape);
int64_t keyShapeSize = GetShapeSize(keyShape);
int64_t valueShapeSize = GetShapeSize(valueShape);
int64_t sparseIndicesShapeSize = GetShapeSize(sparseIndicesShape);
int64_t attentionOutShapeSize = GetShapeSize(attentionOutShape);
int64_t softmaxMaxShapeSize = GetShapeSize(softmaxMaxShape);
int64_t softmaxSumShapeSize = GetShapeSize(softmaxSumShape);
int64_t queryRopeShapeSize = GetShapeSize(queryRopeShape);
int64_t keyRopeShapeSize = GetShapeSize(keyRopeShape);
std::vector<float> queryHostData(queryShapeSize, 1);
std::vector<float> keyHostData(keyShapeSize, 1);
std::vector<float> valueHostData(valueShapeSize, 1);
std::vector<int32_t> sparseIndicesHostData(sparseIndicesShapeSize, 1);
std::vector<float> attentionOutHostData(attentionOutShapeSize, 1);
std::vector<float> softmaxMaxHostData(softmaxMaxShapeSize, 1);
std::vector<float> softmaxSumHostData(softmaxSumShapeSize, 1);
std::vector<float> queryRopeHostData(queryRopeShapeSize, 1);
std::vector<float> keyRopeHostData(keyRopeShapeSize, 1);
// Create query aclTensor.
int ret = CreateAclTensor(queryHostData, queryShape, &resources.queryDeviceAddr,
aclDataType::ACL_FLOAT16, &resources.queryTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create key aclTensor.
ret = CreateAclTensor(keyHostData, keyShape, &resources.keyDeviceAddr,
aclDataType::ACL_FLOAT16, &resources.keyTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create value aclTensor.
ret = CreateAclTensor(valueHostData, valueShape, &resources.valueDeviceAddr,
aclDataType::ACL_FLOAT16, &resources.valueTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create sparseIndices aclTensor.
ret = CreateAclTensor(sparseIndicesHostData, sparseIndicesShape, &resources.sparseIndicesDeviceAddr,
aclDataType::ACL_INT32, &resources.sparseIndicesTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create queryRope aclTensor.
ret = CreateAclTensor(queryRopeHostData, queryRopeShape, &resources.queryRopeDeviceAddr,
aclDataType::ACL_FLOAT16, &resources.queryRopeTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create keyRope aclTensor.
ret = CreateAclTensor(keyRopeHostData, keyRopeShape, &resources.keyRopeDeviceAddr,
aclDataType::ACL_FLOAT16, &resources.keyRopeTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create attention_out aclTensor.
ret = CreateAclTensor(attentionOutHostData, attentionOutShape, &resources.attentionOutDeviceAddr,
aclDataType::ACL_FLOAT16, &resources.attentionOutTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create softmax_max aclTensor.
ret = CreateAclTensor(softmaxMaxHostData, softmaxMaxShape, &resources.softmaxMaxDeviceAddr,
aclDataType::ACL_FLOAT, &resources.softmaxMaxTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
// Create softmax_sum aclTensor.
ret = CreateAclTensor(softmaxSumHostData, softmaxSumShape, &resources.softmaxSumDeviceAddr,
aclDataType::ACL_FLOAT, &resources.softmaxSumTensor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
return ret;
}
return ACL_SUCCESS;
}
int ExecuteSparseFlashAttention(TensorResources& resources, aclrtStream stream,
void** workspaceAddr, uint64_t* workspaceSize) {
int64_t d = 2;
double scaleValue = 1 / sqrt(d);
int64_t sparseBlockSize = 64;
constexpr const char layerOutStr[] = "BSND";
constexpr size_t layerOutLen = sizeof(layerOutStr);
char layoutQuery[layerOutLen];
char layoutKv[layerOutLen];
errno_t memcpyRet = memcpy_s(layoutQuery, sizeof(layoutQuery), layerOutStr, layerOutLen);
if (memcpyRet != 0) {
LOG_PRINT("memcpy_s layoutQuery failed. ERROR: %d\n", memcpyRet);
return -1;
}
memcpyRet = memcpy_s(layoutKv, sizeof(layoutKv), layerOutStr, layerOutLen);
if (memcpyRet != 0) {
LOG_PRINT("memcpy_s layoutKv failed. ERROR: %d\n", memcpyRet);
return -1;
}
int64_t sparseMode = 3;
int64_t preTokens = 9223372036854775807;
int64_t nextTokens = 9223372036854775807;
int64_t attentionMode = 2;
bool returnSoftmaxLse = false;
aclOpExecutor* executor;
int ret = aclnnSparseFlashAttentionGetWorkspaceSize(resources.queryTensor, resources.keyTensor, resources.valueTensor, resources.sparseIndicesTensor, nullptr, nullptr, nullptr, resources.queryRopeTensor, resources.keyRopeTensor,
scaleValue, sparseBlockSize, layoutQuery, layoutKv, sparseMode, preTokens,
nextTokens, attentionMode, returnSoftmaxLse, resources.attentionOutTensor, resources.softmaxMaxTensor, resources.softmaxSumTensor, workspaceSize, &executor);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclnnSparseFlashAttentionGetWorkspaceSize failed. ERROR: %d\n", ret);
return ret;
}
if (*workspaceSize > 0ULL) {
ret = aclrtMalloc(workspaceAddr, *workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret);
return ret;
}
}
ret = aclnnSparseFlashAttention(*workspaceAddr, *workspaceSize, executor, stream);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclnnSparseFlashAttention failed. ERROR: %d\n", ret);
return ret;
}
return ACL_SUCCESS;
}
int PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
auto size = GetShapeSize(shape);
std::vector<aclFloat16> resultData(size, 0);
auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]),
*deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret);
return ret;
}
for (int64_t i = 0; i < size; i++) {
LOG_PRINT("mean result[%ld] is: %f\n", i, aclFloat16ToFloat(resultData[i]));
}
return ACL_SUCCESS;
}
void CleanupResources(TensorResources& resources, void* workspaceAddr,
aclrtStream stream, int32_t deviceId) {
if (resources.queryTensor) {
aclDestroyTensor(resources.queryTensor);
}
if (resources.keyTensor) {
aclDestroyTensor(resources.keyTensor);
}
if (resources.valueTensor) {
aclDestroyTensor(resources.valueTensor);
}
if (resources.sparseIndicesTensor) {
aclDestroyTensor(resources.sparseIndicesTensor);
}
if (resources.attentionOutTensor) {
aclDestroyTensor(resources.attentionOutTensor);
}
if (resources.softmaxMaxTensor) {
aclDestroyTensor(resources.softmaxMaxTensor);
}
if (resources.softmaxSumTensor) {
aclDestroyTensor(resources.softmaxSumTensor);
}
if (resources.queryRopeTensor) {
aclDestroyTensor(resources.queryRopeTensor);
}
if (resources.keyRopeTensor) {
aclDestroyTensor(resources.keyRopeTensor);
}
if (resources.queryDeviceAddr) {
aclrtFree(resources.queryDeviceAddr);
}
if (resources.keyDeviceAddr) {
aclrtFree(resources.keyDeviceAddr);
}
if (resources.valueDeviceAddr) {
aclrtFree(resources.valueDeviceAddr);
}
if (resources.sparseIndicesDeviceAddr) {
aclrtFree(resources.sparseIndicesDeviceAddr);
}
if (resources.attentionOutDeviceAddr) {
aclrtFree(resources.attentionOutDeviceAddr);
}
if (resources.softmaxMaxDeviceAddr) {
aclrtFree(resources.softmaxMaxDeviceAddr);
}
if (resources.softmaxSumDeviceAddr) {
aclrtFree(resources.softmaxSumDeviceAddr);
}
if (resources.queryRopeDeviceAddr) {
aclrtFree(resources.queryRopeDeviceAddr);
}
if (resources.keyRopeDeviceAddr) {
aclrtFree(resources.keyRopeDeviceAddr);
}
if (workspaceAddr) {
aclrtFree(workspaceAddr);
}
if (stream) {
aclrtDestroyStream(stream);
}
aclrtResetDevice(deviceId);
aclFinalize();
}
} // namespace
int main() {
int32_t deviceId = 0;
aclrtStream stream = nullptr;
TensorResources resources = {};
void* workspaceAddr = nullptr;
uint64_t workspaceSize = 0;
std::vector<int64_t> attentionOutShape = {1, 2, 1, 16};
std::vector<int64_t> softmaxMaxShape = {1, 2, 1, 16};
std::vector<int64_t> softmaxSumShape = {1, 2, 1, 16};
int ret = ACL_SUCCESS;
// 1. Initialize device and stream
ret = Init(deviceId, &stream);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("Init acl failed. ERROR: %d\n", ret);
return ret;
}
// 2. Initialize tensors
ret = InitializeTensors(resources);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
CleanupResources(resources, workspaceAddr, stream, deviceId);
return ret;
}
// 3. Execute the operation
ret = ExecuteSparseFlashAttention(resources, stream, &workspaceAddr, &workspaceSize);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
CleanupResources(resources, workspaceAddr, stream, deviceId);
return ret;
}
// 4. Synchronize stream
ret = aclrtSynchronizeStream(stream);
if (!CHECK_RET(ret == ACL_SUCCESS)) {
LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret);
CleanupResources(resources, workspaceAddr, stream, deviceId);
return ret;
}
// 5. Process results
printf("-----------attentionOut输出-----------\n");
PrintOutResult(attentionOutShape, &resources.attentionOutDeviceAddr);
printf("-----------softmaxMax输出-----------\n");
PrintOutResult(softmaxMaxShape, &resources.softmaxMaxDeviceAddr);
printf("-----------softmaxSum输出-----------\n");
PrintOutResult(softmaxSumShape, &resources.softmaxSumDeviceAddr);
// 6. Cleanup resources
CleanupResources(resources, workspaceAddr, stream, deviceId);
return 0;
}