* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "CubeForward.h"
#include "VectorForward.h"
using namespace AscendC;
/**
 * Kernel entry point: reads the tiling data, carves the user workspace into
 * scratch regions, and runs either CUBE_FORWARD_ONLINE::CubeForward or
 * VectorForward depending on which __DAV_C220_* macro the build defines.
 *
 * @param q_gm                   Passed as the first input to both operators' Init.
 * @param k_gm                   Passed as the second input to both operators' Init.
 * @param v_gm                   Passed as the third input to both operators' Init.
 * @param atten_mask_gm          Forwarded to VectorForward::Init only.
 * @param alibi_mask_gm          Not referenced by this function.
 * @param drop_mask_gm           Not referenced by this function.
 * @param softmax_log_max_sum_gm Reinterpreted as float and forwarded to both operators.
 * @param attention_out_gm       Reinterpreted as float and forwarded to both operators.
 * @param workspace              Handed to SetSysWorkspace / GetUserWorkspace; the
 *                               user part holds the scratch regions laid out below.
 * @param tiling_para_gm         Source of the LaserAttentionTilingData read by
 *                               GET_TILING_DATA.
 *
 * If neither __DAV_C220_CUBE__ nor __DAV_C220_VEC__ is defined, no operator is run.
 */
extern "C" __global__ __aicore__ void laser_attention(
    __gm__ uint8_t *__restrict__ q_gm,
    __gm__ uint8_t *__restrict__ k_gm,
    __gm__ uint8_t *__restrict__ v_gm,
    __gm__ uint8_t *__restrict__ atten_mask_gm,
    __gm__ uint8_t *__restrict__ alibi_mask_gm,
    __gm__ uint8_t *__restrict__ drop_mask_gm,
    __gm__ uint8_t *__restrict__ softmax_log_max_sum_gm,
    __gm__ uint8_t *__restrict__ attention_out_gm,
    __gm__ uint8_t *__restrict__ workspace,
    __gm__ uint8_t *__restrict__ tiling_para_gm)
{
    GET_TILING_DATA(tiling_data_in, tiling_para_gm);
    const LaserAttentionTilingData *__restrict tiling_data = &tiling_data_in;

    SetSysWorkspace(workspace);
    __gm__ uint8_t *user = GetUserWorkspace(workspace);

    // Tiling values needed by both operator variants. They are copied into
    // int32_t locals (as before) so the argument types seen by Init do not
    // depend on the field types of the tiling struct.
    int32_t y = tiling_data->coreNumPerGroup;
    int32_t f = tiling_data->coreGroupNum;
    int32_t b = tiling_data->batchSize;
    int32_t n = tiling_data->headNum;
    int32_t s1 = tiling_data->qSeqLength;
    int32_t s2 = tiling_data->kSeqLength;
    int32_t qkTriangle = tiling_data->isTriangle;
    int32_t windowLen = tiling_data->windowLen;
    auto aicNum = y * f;

    // Element count (in floats) of each of the two leading workspace regions.
    // Widened to 64 bits before multiplying so the product cannot overflow
    // 32-bit int arithmetic; the value is unchanged whenever it fit before.
    const int64_t statRegionElems = static_cast<int64_t>(256) * 128 * 2 * MAX_SWITCH_TIME * aicNum;

    __gm__ float *__restrict__ gm_attention_out = (__gm__ float *__restrict__)attention_out_gm;
    __gm__ float *__restrict__ softmax_log_max_sum = (__gm__ float *__restrict__)softmax_log_max_sum_gm;

    // User workspace layout: [rowmax region][rowsum region][score buffer ...].
    __gm__ float *__restrict__ gm_rowmax_diag = (__gm__ float *__restrict__)user;
    __gm__ float *__restrict__ gm_rowsum_diag =
        (__gm__ float *__restrict__)(gm_rowmax_diag + statRegionElems);
    __gm__ uint8_t *__restrict__ score_gm =
        (__gm__ uint8_t *__restrict__)(gm_rowsum_diag + statRegionElems);

#ifdef __DAV_C220_CUBE__
    // Values consumed only by the cube operator.
    int32_t d = tiling_data->headDim;
    int32_t g = tiling_data->headGroupSize;
    int32_t sparseMode = tiling_data->sparseMode;

    CUBE_FORWARD_ONLINE::CubeForward<half, false, half> op;
    op.Init(q_gm, k_gm, v_gm, score_gm, gm_attention_out, gm_rowsum_diag, gm_rowmax_diag, softmax_log_max_sum,
            y, f, b, n, s1, s2, d, g, qkTriangle, sparseMode, windowLen);
    op.Run();
#elif defined(__DAV_C220_VEC__)
    // Values consumed only by the vector operator.
    int32_t maskSeqLength = tiling_data->maskSeqLength;
    float scale = tiling_data->scale;
    // High-precision mode is always enabled; named so the call site is self-describing.
    constexpr bool isHighPrecision = true;

    VectorForward<half, false, half> op;
    op.Init(q_gm, k_gm, v_gm, atten_mask_gm, score_gm, gm_attention_out, softmax_log_max_sum, gm_rowsum_diag,
            gm_rowmax_diag, s1, s2, n, b, y, qkTriangle, windowLen / BASE_BLOCK_SIDE_LEN, maskSeqLength, scale,
            windowLen);
    op.SetHighPrecision(isHighPrecision);
    op.Run();
#endif
}