/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef __VECTORFORWARD_H__
#define __VECTORFORWARD_H__
#include "matmul_const.h"
#include "AddressMappingVectorForwardOnline.h"
#ifdef __DAV_C220_VEC__
// Rounds x up to the next multiple of 8 (integer arithmetic; x is evaluated once).
#define ROUND_UP_8(x) (((x) + 7) / 8 * 8)
// Lane count at or above which the mask helpers below select a full vector mask.
constexpr size_t MASK_BASE = 128;
// Half of MASK_BASE: boundary between the two 64-bit words passed to set_vector_mask.
constexpr size_t MASK_HALF_BASE = 64;
/**
 * Vector-side worker of the forward attention kernel.
 *
 * Init() binds the global-memory (GM) buffers and configures the address mapper;
 * Run() then walks the rounds and sections reported by that mapper. The nested
 * structs describe unified-buffer (UB) layouts and per-section parameters that
 * the private helpers pass around.
 *
 * @tparam INPUT_T      element type the input/mask GM pointers are cast to in Init()
 * @tparam IF_BF16      not referenced inside this class declaration
 * @tparam WORKSPACE_T  element type of the attention-score workspace in GM
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> class VectorForward {
public:
    __aicore__ inline VectorForward() {};

    /** Stores the GM pointers and shape parameters and starts the address mapper. */
    __aicore__ inline void Init(
        __gm__ uint8_t * __restrict__ a_cube1,
        __gm__ uint8_t * __restrict__ b_cube1,
        __gm__ uint8_t * __restrict__ b_cube2,
        __gm__ uint8_t * __restrict__ mask_gm,
        __gm__ uint8_t * __restrict__ score_gm,
        __gm__ float * __restrict__ c_cube2,
        __gm__ float * __restrict__ log_sum_gm,
        __gm__ float * __restrict__ diag_rowsum_gm,
        __gm__ float * __restrict__ d_rowmax_gm,
        int32_t qSeqLength,
        int32_t kSeqLength,
        int32_t H,
        int32_t B,
        int32_t Y,
        int32_t qk,
        int32_t windows_block_num,
        int32_t maskSeqLength,
        float scale,
        int32_t windowLen
    );

    /**
     * Enables the first `len` lanes of the 128-lane vector mask.
     * As used here, the first argument of set_vector_mask covers lanes 64..127
     * and the second covers lanes 0..63.
     *
     * @param len number of leading lanes to enable; values >= 128 enable all lanes
     */
    __aicore__ __inline__ void __set_mask(int32_t len)
    {
        if (len >= 128) {
            set_vector_mask((uint64_t)-1, (uint64_t)-1);
            return;
        }
        int32_t highMask = len - 64 > 0 ? len - 64 : 0;
        int32_t lowMask = len - 64 >= 0 ? 64 : len;
        if (len < MASK_HALF_BASE) {
            // Only part of the low word is active; the high word is empty.
            set_vector_mask(0x0, ((uint64_t)1 << lowMask) - 1);
        } else {
            // Low word fully active; the first (len - 64) lanes of the high word are active.
            set_vector_mask(((uint64_t)1 << highMask) - 1, 0xffffffffffffffff);
        }
    }

    /**
     * Enables lanes `len`..127 of the vector mask (the complement of __set_mask).
     *
     * @param len index of the first enabled lane. Values >= MASK_BASE select the
     *            full mask; because the comparison is against a size_t, negative
     *            values also take that path.
     */
    __aicore__ __inline__ void __set_reverse_mask(int32_t len)
    {
        if (len >= MASK_BASE) {
            set_vector_mask((uint64_t)-1, (uint64_t)-1);
            return;
        }
        int32_t lowMask = len - 64 > 0 ? len - 64 : 0;
        int32_t highMask = len - 64 >= 0 ? 64 : len;
        if (len < MASK_HALF_BASE) {
            // High word fully active; low word active from bit `len` upwards.
            set_vector_mask(0xffffffffffffffff, ~(((uint64_t)1 << highMask) - 1));
        } else {
            // High word active from bit (len - 64) upwards. lowMask is in [0, 63] here.
            // The previous form `~(uint64_t)1 << (lowMask - 1)` shifted by -1 when
            // len == 64 (undefined behaviour); this expression yields the same mask
            // for len in 65..127 and a fully set high word for len == 64.
            set_vector_mask(~(((uint64_t)1 << lowMask) - 1), 0x0);
        }
    }

    /** Main processing loop; see the out-of-class definition. */
    __aicore__ inline void Run();

    /** Stores the high-precision flag in the member of the same name. */
    __aicore__ inline void SetHighPrecision(bool isHighPrecision)
    {
        this->isHighPrecision = isHighPrecision;
    };

    // UB pointers used by the short-sequence attention-score path.
    struct UB_FOR_SHORT_LEN_ATTN_SCORE {
        __ubuf__ float* cur_buf_for_vbrcb_rowmax_fp32;
        __ubuf__ INPUT_T* buf_for_load_attn_score_fp16;
        __ubuf__ INPUT_T* buf_for_subMaxValueResult_fp16;
        __ubuf__ float* buf_for_diag_fp32;
        __ubuf__ INPUT_T* buf_for_load_one_block_tri_mask_fp16;
        __ubuf__ float* buf_for_cacl_final_rowmax_fp32;
        __ubuf__ INPUT_T* buf_for_cacl_final_rowmax_fp16;
        __ubuf__ float* buf_for_cacl_rowmax_fp32;
        __ubuf__ INPUT_T* buf_for_cacl_rowmax_fp16;
        __ubuf__ float* buf_for_vbrcb_rowmax_fp32;
        __ubuf__ INPUT_T* buf_for_vbrcb_rowmax_fp16;
        __ubuf__ float* buf_for_record_rowmax_fp32;
        __ubuf__ INPUT_T* buf_for_record_rowmax_fp16;
        __ubuf__ INPUT_T* pp_buf_for_attn_score_fp16[2];              // ping/pong halves
        __ubuf__ INPUT_T* pp_buf_for_load_one_block_tri_mask_fp16[2]; // ping/pong halves
        __ubuf__ INPUT_T* maxJVectorUbAddr;
        __ubuf__ INPUT_T* lastMaxJVectorUbAddr;
        __ubuf__ INPUT_T* subMaxValueResultVectorUbAddr;
        __ubuf__ float* subMaxValueResultVectorUbAddr_fp32;
        __ubuf__ INPUT_T* ljVectorUbAddr;
        __ubuf__ float* ljVectorUbAddr_fp32;
        __ubuf__ float* lastLjVectorUbAddr ;
        __ubuf__ float* diagExpMaxJMatPingUbAddr ;
        __ubuf__ float* diagExpMaxJMatPongUbAddr ;
    };

    // UB pointers and ping/pong intervals filled by allocate_ubuf_for_norm().
    struct UB_FOR_NORMALIZE {
        __ubuf__ float* buf_for_load_O_fp32;
        __ubuf__ float* buf_for_load_rowsum_fp32;
        __ubuf__ float* buf_for_brcb_rowsum_fp32;
        int32_t o_ping_pong_interval;
        int32_t rowsum_ping_pong_interval;
        int32_t rowsum_brcb_ping_pong_interval;
    };

    // Per-step parameters for the medium-sequence exp path (not used in the visible code).
    struct PARAM_MEDIUM_SEQ_EXP {
        int32_t block_num_per_step;
        int32_t block_num_for_last;
        int32_t last_padding_block_num;
        int32_t section_start_line_offset;
        int32_t section_mask_offset;
        int32_t total_frag_num;
        int32_t cur_frag;
        bool tail_block;
        int32_t tri_matrix_num;
        int32_t apply_tri_mask;
        int32_t buf_offset;
        int32_t record_rowmax_offset;
    };

    // Per-section parameters consumed by the short-sequence max/sum helpers.
    struct PARAM_SHORT_SEQ_MAX {
        int32_t section_block_num;
        int32_t section_padding_block_num;
        int32_t section_start_line_offset;
        int32_t section_mask_offset;
        int32_t record_rowmax_offset;
        int32_t apply_tri_mask;
        bool is_head_section;
        bool is_tail_section;
    };

    // Per-section parameters for the long-sequence exp path (not used in the visible code).
    struct PARAM_LONG_SEQ_EXP {
        int32_t section_block_num;
        int32_t section_padding_block_num;
        int32_t section_start_line_offset;
        int32_t section_mask_offset;
        int32_t total_frag_num;
        int32_t cur_frag;
        int32_t tri_matrix_num;
        int32_t apply_tri_mask;
    };

    // UB pointers for a row-sum buffer and its result buffer.
    struct UB_FOR_LN_ROWSUM {
        __ubuf__ float* ub_buf_for_rowsum;
        __ubuf__ float* ub_buf_for_rowsum_res;
    };

private:
    __aicore__ __inline__ void initWorkSpace();
    __aicore__ __inline__ void allocate_ubuf_for_norm (UB_FOR_NORMALIZE *ub_norm);
    __aicore__ inline void get_sub_seq_length_per_proc(int32_t k_seq_len,
        int32_t block_num_per_full_line,
        int32_t *sub_seq_length_per_proc);
    __aicore__ __inline__ void get_padding_info_for_row_max(int32_t total_block_num,
        int32_t *padding_block_num);
    __aicore__ __inline__ void padding_for_row_max_or_rowsum(int32_t total_block_num,
        int32_t padding_block_num,
        int32_t ping_pong_flag,
        int32_t paddingType,
        __ubuf__ float * pp_buf_for_attn_score_fp16[]);
    __aicore__ __inline__ void padding_for_row_max_or_rowsum2(int32_t total_block_num,
        int32_t padding_block_num,
        int32_t ping_pong_flag,
        int32_t paddingType,
        __ubuf__ INPUT_T * pp_buf_for_attn_score_fp16[]);
    __aicore__ __inline__ void process_cacl_max(int32_t basic_block_num,
        int32_t padding_block_num,
        bool pp_first_section,
        __ubuf__ INPUT_T * cur_buf_for_attn_score,
        __ubuf__ INPUT_T * cur_buf_for_rowmax,
        __ubuf__ INPUT_T * buf_for_cacl_final_rowmax_fp16
    );
    __aicore__ __inline__ void cacl_max(__ubuf__ INPUT_T * buf_for_cacl, int32_t _block_num);
    __aicore__ __inline__ void process_calc_sum(int32_t qk_triangle,
        PARAM_SHORT_SEQ_MAX param,
        int32_t ping_pong_flag,
        bool first_line,
        __gm__ WORKSPACE_T * attn_score_gm,
        __gm__ INPUT_T * attn_mask_gm,
        UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn,
        bool sparse_flag,
        int32_t lines);
    __aicore__ __inline__ void process_line_phase_one_for_short_seq_max(bool is_head_section,
        bool is_tail_section, int32_t qk_triangle, PARAM_SHORT_SEQ_MAX param,
        int32_t ping_pong_flag, bool first_line, __gm__ WORKSPACE_T * attn_score_gm,
        __gm__ INPUT_T * attn_mask_gm, UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn, bool sparse_flag, int32_t lines,
        __ubuf__ INPUT_T* maxJVectorUbAddr, __ubuf__ INPUT_T* lastMaxJVectorUbAddr,
        __ubuf__ INPUT_T* subMaxValueResultVectorUbAddr, int rowsPerDiag, int rowmax_build_step);
    __aicore__ __inline__ void process_line_phase_one_for_short_seq_exp_and_rowsum(int32_t qk_triangle,
        PARAM_SHORT_SEQ_MAX param,
        int32_t ping_pong_flag,
        __gm__ WORKSPACE_T * attn_score_gm,
        __gm__ INPUT_T * attn_mask_gm,
        UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn,
        int32_t offset,
        int32_t lines,
        bool sparse_flag);
    __aicore__ __inline__ void attention_score_short_double_line_one(int32_t sectionLoop, int32_t sectionNum,
        int32_t qk_triangle, int32_t section_block_nums, int32_t tri_matrix_mask_offset,
        int32_t each_vector_proc_line_num, int32_t local_section_start_line_offset, bool isa_head_section,
        bool is_tail_section, int32_t diag_offset, __gm__ WORKSPACE_T * __restrict__ attn_score_gm,
        __gm__ INPUT_T * __restrict__ attn_mask_gm, UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn,
        bool sparse_flag);
    __aicore__ __inline__ void backupMaxandSum(__ubuf__ float* ljVectorUbAddr,
        __ubuf__ float* lastLjVectorUbAddr,
        __ubuf__ INPUT_T* maxJVectorUbAddr,
        __ubuf__ INPUT_T* lastMaxJVectorUbAddr);
    __aicore__ __inline__ void updateRowsum(__ubuf__ float* ljVectorUbAddr,
        __ubuf__ float* lastLjVectorUbAddr,
        __ubuf__ float* subMaxValueResultVectorUbAddr,
        __ubuf__ float* tempUbAddr);
    __aicore__ __inline__ void cacl_wrap(__ubuf__ INPUT_T * buf_for_cacl, int32_t _block_num, int32_t optcode);
    __aicore__ __inline__ void add_max_wrap(__ubuf__ INPUT_T * dest_addr, __ubuf__ INPUT_T * src_addr,
        __ubuf__ INPUT_T * src_addr2, int32_t rpt, int32_t dst_stride,
        int32_t src_stride, int32_t src_stride2, int32_t dst_rpt_stride,
        int32_t src_rpt_stride, int32_t src_rpt_stride2, int32_t optcode);
    __aicore__ __inline__ void process_cacl_wrap(int32_t basic_block_num,
        int32_t padding_block_num,
        bool pp_first_section,
        __ubuf__ INPUT_T * cur_buf_for_attn_score,
        __ubuf__ INPUT_T * cur_buf_for_rowmax,
        __ubuf__ INPUT_T * buf_for_cacl_final_rowmax_fp16, int32_t optcode);
    __aicore__ __inline__ void get_uni_rowmax_seq_info_per_proc(int32_t block_num_per_full_line,
        int32_t sub_seq_length_per_proc,
        int32_t *ping_block_offset_num,
        int32_t *pong_block_offset_num,
        int32_t *tail_block_offset_num,
        int32_t *tail_block_num,
        int32_t *ping_pong_times
    );
    __aicore__ __inline__ void allocate_ubuf_for_short_seq_attn_score(
        UB_FOR_SHORT_LEN_ATTN_SCORE* ub_attn_score);
    __aicore__ __inline__ void form_diagmat(__ubuf__ float * diagmataddr, __ubuf__ float * vecaddr,
        __gm__ float * gm_addr, int step, int rowsPerDiag, int32_t local_section_start_line_offset,
        int32_t ping_pong_flag, int type);

private:
    // GM buffers bound in Init().
    __gm__ INPUT_T* __restrict__ gm_a_cube1;
    __gm__ INPUT_T* __restrict__ gm_b_cube1;
    __gm__ INPUT_T* __restrict__ gm_b_cube2;
    __gm__ INPUT_T* __restrict__ attn_mask_gm;
    __gm__ WORKSPACE_T* __restrict__ attn_score_gm;
    __gm__ float* __restrict__ gm_c_cube2;
    __gm__ float* __restrict__ log_sum_max_gm;
    __gm__ float* diag_rowsum_gm;
    __gm__ float* diag_rowmax_gm;
    // UB scratch pointers.
    __ubuf__ float* diagExpMaxJMatUbAddr;
    __ubuf__ float* tempUb;
    __ubuf__ float* __restrict__ ub_for_softsync_flags;
    __ubuf__ float* __restrict__ ub_for_softsync_check;
    // Shape parameters copied from the Init() arguments.
    int32_t q_seq_len;
    int32_t k_seq_len;
    int32_t head_num;
    int32_t batch_num;
    int32_t y_cube_num_per_line;
    int32_t qk_triangle;
    int32_t V_num;
    int32_t softSyncTimesCount = 0;
    bool use_soft_sync = false;
    bool isHighPrecision = true;
    int32_t maskSeqLength;
    int32_t rowPerTime = 1;   // rows handled per pass; overwritten inside Run()
    float SCALE;
    INPUT_T FP16_SMALL_NUM = 5.9604644775390625e-8;
    int32_t block_per_core = 64;
    int32_t ky = 2;
    int32_t SIZE_FP32 = 4;    // bytes per fp32 element
    int32_t SIZE_FP16 = 2;    // bytes per 16-bit element
    int32_t BLOCK_NUM_8 = 8;
    int32_t BYTES_PER_BLOCK = 32;
    int32_t data_num_per_blk = BYTES_PER_BLOCK / SIZE_FP16;   // 16-bit elements per 32-byte block
    int32_t windowLen;
    Address::AddressMappingVectorForwardOnline address;
};
/**
 * Binds the global-memory buffers, records the problem shape and starts the
 * address mapper for the calling core.
 *
 * The byte pointers are reinterpreted as INPUT_T / WORKSPACE_T; the float
 * pointers are stored as given. `windows_block_num` is stored as V_num, and the
 * sparse mode handed to the address mapper is 1 whenever V_num is non-zero.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ inline void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>
    ::Init(
    __gm__ uint8_t * __restrict__ a_cube1,
    __gm__ uint8_t * __restrict__ b_cube1,
    __gm__ uint8_t * __restrict__ b_cube2,
    __gm__ uint8_t * __restrict__ mask_gm,
    __gm__ uint8_t * __restrict__ score_gm,
    __gm__ float * __restrict__ c_cube2,
    __gm__ float * __restrict__ log_sum_gm,
    __gm__ float * __restrict__ d_rowsum_gm,
    __gm__ float * __restrict__ d_rowmax_gm,
    int32_t qSeqLength,
    int32_t kSeqLength,
    int32_t H,
    int32_t B,
    int32_t Y,
    int32_t qk,
    int32_t windows_block_num,
    int32_t maskSeqLength,
    float scale,
    int32_t windowLen
    )
{
    // Global-memory buffers.
    this->gm_a_cube1 = reinterpret_cast<__gm__ INPUT_T *>(a_cube1);
    this->gm_b_cube1 = reinterpret_cast<__gm__ INPUT_T *>(b_cube1);
    this->gm_b_cube2 = reinterpret_cast<__gm__ INPUT_T *>(b_cube2);
    this->attn_mask_gm = reinterpret_cast<__gm__ INPUT_T *>(mask_gm);
    this->attn_score_gm = reinterpret_cast<__gm__ WORKSPACE_T *>(score_gm);
    this->gm_c_cube2 = c_cube2;
    this->log_sum_max_gm = log_sum_gm;
    this->diag_rowsum_gm = d_rowsum_gm;
    this->diag_rowmax_gm = d_rowmax_gm;

    // Problem shape and scalar settings.
    this->q_seq_len = qSeqLength;
    this->k_seq_len = kSeqLength;
    this->head_num = H;
    this->batch_num = B;
    this->y_cube_num_per_line = Y;
    this->qk_triangle = qk;
    this->V_num = windows_block_num;
    this->maskSeqLength = maskSeqLength;   // parameter shadows the member
    this->SCALE = scale;
    this->windowLen = windowLen;           // parameter shadows the member

    // Configure the address mapper, then tell it which core/sub-block this is.
    const int32_t sparse_mode = (this->V_num != 0) ? 1 : 0;
    address.init(this->batch_num, this->head_num,
        this->q_seq_len, this->k_seq_len, this->maskSeqLength, this->qk_triangle, this->V_num, sparse_mode,
        this->block_per_core, this->ky);

    const int32_t total_cores = get_block_num();
    const int32_t this_core = get_block_idx();
    const int32_t sub_block = get_subblockid();
    address.set_core_info(total_cores, this_core, sub_block);
    address.start();
}
/**
 * Main loop of the vector side.
 *
 * For every round reported by the address mapper it waits on AIC2AIVFLAGID,
 * processes each section of the round through
 * attention_score_short_double_line_one(), and for tail sections additionally
 * computes 1 / ljVector and hands it to form_diagmat(). Each round ends with an
 * ffts_cross_core_sync carrying AIV2AICFLAGID.
 *
 * Changes relative to the previous revision (results are unchanged):
 *  - The adaptive rowPerTime selection was removed. Its result was always
 *    overwritten by `rowPerTime = 4`, yet it still executed integer divisions by
 *    sectionBlockNum, y_cube_num_per_line and rowPerTime that could divide by
 *    zero (the halving loop could drive rowPerTime to 0).
 *  - Locals that were never read were dropped.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ inline void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>
    ::Run()
{
    set_atomic_none();
    set_mask_norm();
    set_vector_mask((uint64_t)-1, (uint64_t)-1);

    // NOTE(review): ub_norm is filled here but not read anywhere in this function.
    UB_FOR_NORMALIZE ub_norm;
    allocate_ubuf_for_norm(&ub_norm);
    UB_FOR_SHORT_LEN_ATTN_SCORE ub_short_seq_attn;

    for (int times_sync_cube = 0; times_sync_cube < address.get_total_round(); times_sync_cube++) {
        wait_flag_dev(AIC2AIVFLAGID);

        Address::FORWARD_SECTION_INFO section;
        address.get_section_info(times_sync_cube, section);
        if (!address.is_running(times_sync_cube)) {
            // Nothing to compute this round, but the cross-core sync below must still be issued.
            section.sectionNum = 0;
        }

        for (int32_t sectionLoop = 0; sectionLoop < section.sectionNum; sectionLoop++) {
            // Rows processed per pass are fixed at 4; the UB layout below depends on this value.
            rowPerTime = 4;
            allocate_ubuf_for_short_seq_attn_score(&ub_short_seq_attn);
            diagExpMaxJMatUbAddr = ub_short_seq_attn.diagExpMaxJMatPingUbAddr;

            attention_score_short_double_line_one(sectionLoop, section.sectionNum, section.isTriangle,
                section.sectionBlockNums[sectionLoop], section.matrixMaskOffset, section.processLineNum,
                section.sectionBlockOffset[sectionLoop], section.isHeadSection[sectionLoop],
                section.isTailSection[sectionLoop], section.diagOffset[sectionLoop],
                attn_score_gm + section.attentionScoreOffset, attn_mask_gm, ub_short_seq_attn, section.sparseFlag);

            if (section.isTailSection[sectionLoop])
            {
                __ubuf__ float * ljVectorUbAddr = ub_short_seq_attn.ljVectorUbAddr_fp32;
                __ubuf__ float * buf_for_diag_fp32 = ub_short_seq_attn.buf_for_diag_fp32;
                // Scratch buffer holding the constant 1.0f (renamed so it no longer
                // shadows the tempUb member).
                __ubuf__ float * onesUb = ub_short_seq_attn.buf_for_record_rowmax_fp32;

                // buf_for_diag = 1 / ljVector over the first 64 lanes.
                __set_mask(64);
                vector_dup(onesUb,
                    (float)1,
                    1,
                    1,
                    1,
                    8,
                    1);
                pipe_barrier(PIPE_V);
                vdiv(buf_for_diag_fp32, onesUb, ljVectorUbAddr, 1, 1, 1, 1, 8, 8, 8);
                pipe_barrier(PIPE_V);
                set_vector_mask((uint64_t)-1, (uint64_t)-1);

                set_flag(PIPE_V, PIPE_S, EVENT_ID0);
                wait_flag(PIPE_V, PIPE_S, EVENT_ID0);

                int32_t vector_id = get_subblockid();
                form_diagmat(diagExpMaxJMatUbAddr, buf_for_diag_fp32, diag_rowsum_gm +
                    section.diagOffset[sectionLoop], 0, 64, vector_id, 0, 1);

                set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
                wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
            }
        }

        // Signal the round's completion across cores.
        uint64_t mode4 = 2;
        uint64_t config4 = 1 | (mode4 << 4) | (AIV2AICFLAGID << 8);
        ffts_cross_core_sync(PIPE_MTE3, config4);
    }
}
/**
 * Lays out the unified-buffer pointers used by the short-sequence attention
 * score path. Addresses are plain byte offsets into UB; two independent runs of
 * consecutive buffers are carved out:
 *   - a first run starting at byte 512 (row-max work buffers), and
 *   - a second run starting at 512 + MAX_LENG_PER_UB_PROC * SIZE_FP32
 *     (score load buffer, mask buffer and the per-row vectors).
 * Sizes depend on the current value of rowPerTime. Fields of *ub_attn_score
 * that are not listed below are left untouched.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>
    ::allocate_ubuf_for_short_seq_attn_score(UB_FOR_SHORT_LEN_ATTN_SCORE * ub_attn_score)
{
    UB_FOR_SHORT_LEN_ATTN_SCORE &ub = *ub_attn_score;
    const int32_t section_num = 8;
    const auto half_side = BASE_BLOCK_SIDE_LEN / 2;

    // ---- first run: row-max work buffers ----
    int32_t low_cursor = 512;
    ub.buf_for_cacl_rowmax_fp16 = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)low_cursor);
    low_cursor += MAX_LENG_PER_UB_PROC * SIZE_FP16;

    ub.buf_for_vbrcb_rowmax_fp16 = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)low_cursor);
    low_cursor += BASE_BLOCK_SIDE_LEN * SIZE_FP16 * rowPerTime * 2;

    ub.buf_for_cacl_final_rowmax_fp16 = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)low_cursor);
    low_cursor += BASE_BLOCK_SIDE_LEN * 2 * rowPerTime * SIZE_FP16 * 2;

    ub.buf_for_record_rowmax_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)low_cursor);
    low_cursor += half_side * ROUND_UP_8(rowPerTime) * SIZE_FP32 * 2;

    // Last buffer of the first run; nothing is placed after it.
    ub.cur_buf_for_vbrcb_rowmax_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)low_cursor);

    // ---- second run: load buffers and per-row vectors ----
    int32_t high_cursor = 512 + MAX_LENG_PER_UB_PROC * SIZE_FP32;
    ub.buf_for_load_attn_score_fp16 = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)high_cursor);
    high_cursor += MAX_LENG_PER_UB_PROC * SIZE_FP16 * 2;   // room for ping + pong

    ub.buf_for_diag_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)high_cursor);
    high_cursor += half_side * SIZE_FP32 * 2;

    ub.buf_for_load_one_block_tri_mask_fp16 = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)high_cursor);
    high_cursor += BASE_BLOCK_SIDE_LEN * rowPerTime * 2 * SIZE_FP16;   // room for ping + pong

    ub.maxJVectorUbAddr = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)high_cursor);
    high_cursor += half_side * 4 * SIZE_FP16;

    ub.lastMaxJVectorUbAddr = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)high_cursor);
    high_cursor += half_side * section_num * 4 * SIZE_FP16;

    ub.subMaxValueResultVectorUbAddr = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)high_cursor);
    high_cursor += half_side * 4 * SIZE_FP16;

    ub.subMaxValueResultVectorUbAddr_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)high_cursor);
    high_cursor += half_side * 4 * SIZE_FP32;

    ub.ljVectorUbAddr = reinterpret_cast<__ubuf__ INPUT_T *>((uintptr_t)high_cursor);
    high_cursor += BASE_BLOCK_SIDE_LEN * SIZE_FP16;

    ub.ljVectorUbAddr_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)high_cursor);
    high_cursor += BASE_BLOCK_SIDE_LEN * SIZE_FP32;

    ub.lastLjVectorUbAddr = reinterpret_cast<__ubuf__ float *>((uintptr_t)high_cursor);
    // NOTE(review): this is a float buffer but its size uses SIZE_FP16 (and no x4 factor
    // as lastMaxJVectorUbAddr has) - confirm it cannot run into diagExpMaxJMatPingUbAddr.
    high_cursor += half_side * section_num * SIZE_FP16;

    // Last buffer of the second run; nothing is placed after it.
    ub.diagExpMaxJMatPingUbAddr = reinterpret_cast<__ubuf__ float *>((uintptr_t)high_cursor);

    // Ping/pong views: slot 0 is the base, slot 1 follows one buffer length later.
    ub.pp_buf_for_attn_score_fp16[0] = ub.buf_for_load_attn_score_fp16;
    ub.pp_buf_for_attn_score_fp16[1] = ub.buf_for_load_attn_score_fp16 + MAX_LENG_PER_UB_PROC;
    ub.pp_buf_for_load_one_block_tri_mask_fp16[0] = ub.buf_for_load_one_block_tri_mask_fp16;
    ub.pp_buf_for_load_one_block_tri_mask_fp16[1] =
        ub.buf_for_load_one_block_tri_mask_fp16 + BASE_BLOCK_SIDE_LEN * rowPerTime;
}
/**
 * Fills *ub_norm with three consecutive fp32 UB buffers starting at byte 512
 * (O, row-sum, broadcast row-sum) and the element intervals between their
 * ping and pong halves.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::allocate_ubuf_for_norm (UB_FOR_NORMALIZE *ub_norm)
{
    // Byte sizes of the first two buffers; each holds a ping and a pong half of 4-byte elements.
    const int32_t o_bytes = MAX_LENG_PER_UB_PROC * 4 * 2;
    const int32_t rowsum_bytes = MAX_LENG_PER_UB_PROC / HEAD_DIM * 4 * 2;

    const int32_t o_addr = 512;
    const int32_t rowsum_addr = o_addr + o_bytes;
    const int32_t brcb_addr = rowsum_addr + rowsum_bytes;

    ub_norm->buf_for_load_O_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)o_addr);
    ub_norm->buf_for_load_rowsum_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)rowsum_addr);
    ub_norm->buf_for_brcb_rowsum_fp32 = reinterpret_cast<__ubuf__ float *>((uintptr_t)brcb_addr);

    // Intervals are expressed in elements, not bytes.
    ub_norm->o_ping_pong_interval = MAX_LENG_PER_UB_PROC;
    ub_norm->rowsum_ping_pong_interval = MAX_LENG_PER_UB_PROC / HEAD_DIM;
    ub_norm->rowsum_brcb_ping_pong_interval = ub_norm->rowsum_ping_pong_interval * (32 /4);
}
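/**
 * Length of the signal processed by the UB in one pass.
 * After the triangular matrices are spliced together, space is reserved for the
 * longest case even when the computation is split into left and right sections.
 **/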
/**
 * Writes the sub-sequence length handled per pass to *sub_seq_length_per_proc:
 * k_seq_len capped at MAX_LENG_PER_UB_PROC, and halved when it is below the cap
 * and a full line spans more than one block.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ inline void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::get_sub_seq_length_per_proc(int32_t k_seq_len,
    int32_t block_num_per_full_line,
    int32_t *sub_seq_length_per_proc) {
    int32_t length = k_seq_len;
    if (k_seq_len > MAX_LENG_PER_UB_PROC * 1) {
        length = MAX_LENG_PER_UB_PROC;
    }
    const bool below_cap = length < MAX_LENG_PER_UB_PROC;
    if (below_cap && block_num_per_full_line > 1) {
        length = length / 2;
    }
    *sub_seq_length_per_proc = length;
}
/*
 * For lengths that are not a power of two, padding is required so that vmax can
 * be computed by repeated halving.
 */
/**
 * Computes how many padding blocks must follow the trailing
 * (total_block_num % BLOCK_NUM_FOR_VMAX) blocks so that their count becomes a
 * power of two (at least 2, and no more than the first power of two that
 * reaches BLOCK_NUM_FOR_VMAX). Writes 0 when there is no trailing remainder.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::get_padding_info_for_row_max(int32_t total_block_num,
    int32_t *padding_block_num)
{
    const auto remainder = total_block_num % BLOCK_NUM_FOR_VMAX;
    if (remainder == 0) {
        *padding_block_num = 0;
        return;
    }

    // Double the target until it covers the remainder or reaches the cap.
    int32_t rounded = 2;
    while (rounded < BLOCK_NUM_FOR_VMAX && remainder > rounded) {
        rounded *= 2;
    }
    *padding_block_num = rounded - remainder;
}
/**
 * fp32 variant: fills the padding blocks that follow the valid blocks in the
 * selected ping/pong buffer. paddingType 0 writes PADDING_FOR_MAX, any other
 * value writes 0. Does nothing when padding_block_num is 0.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::padding_for_row_max_or_rowsum(int32_t total_block_num,
    int32_t padding_block_num, int32_t ping_pong_flag, int32_t paddingType,
    __ubuf__ float * pp_buf_for_attn_score_fp16[]
    )
{
    if (padding_block_num == 0) {
        return;
    }

    const auto valid_blocks = total_block_num % MAX_BLOCK_PER_ONE_PROC;
    __ubuf__ float *pad_begin =
        pp_buf_for_attn_score_fp16[ping_pong_flag] + valid_blocks * BASE_BLOCK_SIDE_LEN * rowPerTime;
    const auto repeat_times = padding_block_num * rowPerTime * 2;

    if (paddingType != 0) {
        vector_dup(pad_begin, float(0), repeat_times, 1, 1, 8, 8);
    } else {
        vector_dup(pad_begin, float(PADDING_FOR_MAX), repeat_times, 1, 1, 8, 8);
    }
    pipe_barrier(PIPE_V);
}
/**
 * INPUT_T variant: fills the padding blocks that follow the valid blocks in the
 * selected ping/pong buffer. paddingType 0 writes PADDING_FOR_MAX2, any other
 * value writes 0. Does nothing when padding_block_num is 0.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::padding_for_row_max_or_rowsum2(int32_t total_block_num,
    int32_t padding_block_num, int32_t ping_pong_flag, int32_t paddingType,
    __ubuf__ INPUT_T * pp_buf_for_attn_score_fp16[])
{
    if (padding_block_num == 0) {
        return;
    }

    const auto valid_blocks = total_block_num % MAX_BLOCK_PER_ONE_PROC;
    __ubuf__ INPUT_T *pad_begin =
        pp_buf_for_attn_score_fp16[ping_pong_flag] + valid_blocks * BASE_BLOCK_SIDE_LEN * rowPerTime;
    const auto repeat_times = padding_block_num * rowPerTime;

    if (paddingType != 0) {
        vector_dup(pad_begin, INPUT_T(0), repeat_times, 1, 1, 8, 8);
    } else {
        vector_dup(pad_begin, INPUT_T(PADDING_FOR_MAX2), repeat_times, 1, 1, 8, 8);
    }
    pipe_barrier(PIPE_V);
}
/**
 * In-place halving reduction with vmax. Starting from a region of
 * 2 * _block_num block-rows, each step folds the upper half onto the lower
 * half; a final step folds the last two block-rows so the result ends up in
 * the first block-row of buf_for_cacl. A "block-row" here is
 * BASE_BLOCK_SIDE_LEN * rowPerTime elements.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>
    ::cacl_max(__ubuf__ INPUT_T * buf_for_cacl, int32_t _block_num)
{
    for (auto blocks = _block_num; blocks > 1; blocks = blocks / 2) {
        const auto span = BASE_BLOCK_SIDE_LEN * blocks * rowPerTime;
        vmax(buf_for_cacl, buf_for_cacl, buf_for_cacl + span,
            span / BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
    }

    // Final fold of the remaining two block-rows.
    const auto last_span = BASE_BLOCK_SIDE_LEN * rowPerTime;
    vmax(buf_for_cacl, buf_for_cacl, buf_for_cacl + last_span,
        BASE_BLOCK_SIDE_LEN * 1 * rowPerTime / BLOCK_NUM_8 / data_num_per_blk,
        1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);
}
/*
 * Compute the max over up to 64 basic blocks.
 */
/**
 * Reduces (basic_block_num + padding_block_num) block-rows of
 * cur_buf_for_attn_score to a single block-row of element-wise maxima in
 * cur_buf_for_rowmax, using vmax halving steps.
 *
 * The block count is split into a multiple-of-16 part, folded first, and a
 * remainder of 8, 4 or 2 blocks folded separately and then merged. Other
 * remainders are not processed. When pp_first_section is true the result is
 * additionally copied to buf_for_cacl_final_rowmax_fp16.
 *
 * NOTE(review): the 64-block and >=48-block branches do not scale their
 * offsets by rowPerTime, unlike every other branch - confirm they are only
 * reached with rowPerTime == 1.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>
    ::process_cacl_max(int32_t basic_block_num,
    int32_t padding_block_num,
    bool pp_first_section,
    __ubuf__ INPUT_T * cur_buf_for_attn_score,
    __ubuf__ INPUT_T * cur_buf_for_rowmax,
    __ubuf__ INPUT_T * buf_for_cacl_final_rowmax_fp16
    )
{
    const int32_t total_blocks = basic_block_num + padding_block_num;
    const int32_t leftover_blocks = total_blocks % 16;
    const int32_t folded_blocks = total_blocks / 16 * 16;
    // Elements in one block-row (one basic block across rowPerTime rows).
    const auto block_span = BASE_BLOCK_SIDE_LEN * rowPerTime;

    // ---- part 1: the multiple-of-16 portion ----
    if (total_blocks == 64) {
        vmax(cur_buf_for_rowmax,
            cur_buf_for_attn_score,
            cur_buf_for_attn_score + BASE_BLOCK_SIDE_LEN * 32,
            32 * BASE_BLOCK_SIDE_LEN/ BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
        cacl_max(cur_buf_for_rowmax, 16);
    } else if (total_blocks >= 48) {
        // Fold blocks [0,16) with [16,32), then with [32,48), then halve 16 -> 8.
        vmax(cur_buf_for_rowmax,
            cur_buf_for_attn_score,
            cur_buf_for_attn_score + BASE_BLOCK_SIDE_LEN * 16,
            16 * BASE_BLOCK_SIDE_LEN/ BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
        vmax(cur_buf_for_rowmax,
            cur_buf_for_rowmax,
            cur_buf_for_attn_score + BASE_BLOCK_SIDE_LEN * 32,
            16 * BASE_BLOCK_SIDE_LEN/ BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
        vmax(cur_buf_for_rowmax,
            cur_buf_for_rowmax,
            cur_buf_for_rowmax + BASE_BLOCK_SIDE_LEN * 8,
            8 * BASE_BLOCK_SIDE_LEN/ BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
        cacl_max(cur_buf_for_rowmax, 4);
    } else if (total_blocks >= 16) {
        // 32..47 blocks: fold 16 onto 16; 16..31 blocks: fold 8 onto 8.
        const int32_t half_blocks = (total_blocks >= 32) ? 16 : 8;
        vmax(cur_buf_for_rowmax,
            cur_buf_for_attn_score,
            cur_buf_for_attn_score + block_span * half_blocks,
            block_span * half_blocks / BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
        cacl_max(cur_buf_for_rowmax, half_blocks / 2);
    }

    // ---- part 2: a remainder of 8, 4 or 2 blocks ----
    if (leftover_blocks == 8 || leftover_blocks == 4 || leftover_blocks == 2) {
        // If part 1 produced a result in block-row 0, reduce the remainder into
        // block-row 1 and merge afterwards; otherwise reduce straight into block-row 0.
        const bool has_folded_part = !(total_blocks < 16);
        __ubuf__ INPUT_T *tail_dst = has_folded_part ? cur_buf_for_rowmax + block_span : cur_buf_for_rowmax;

        // First fold reads the score buffer: lower half of the remainder vs upper half.
        int32_t width = leftover_blocks / 2;
        vmax(tail_dst,
            cur_buf_for_attn_score + block_span * folded_blocks,
            cur_buf_for_attn_score + block_span * (folded_blocks + width),
            block_span * width / BLOCK_NUM_8 / data_num_per_blk,
            1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);

        // Keep halving in place until a single block-row remains at tail_dst.
        for (width = width / 2; width >= 1; width = width / 2) {
            vmax(tail_dst,
                tail_dst,
                tail_dst + block_span * width,
                block_span * width / BLOCK_NUM_8 / data_num_per_blk,
                1, 1, 1, 8, 8, 8);
            pipe_barrier(PIPE_V);
        }

        if (has_folded_part) {
            // Merge the remainder's result (block-row 1) into part 1's result (block-row 0).
            vmax(cur_buf_for_rowmax,
                cur_buf_for_rowmax,
                cur_buf_for_rowmax + block_span,
                block_span * 1 / BLOCK_NUM_8 / data_num_per_blk,
                1, 1, 1, 8, 8, 8);
            pipe_barrier(PIPE_V);
        }
    }

    // ---- part 3: publish the result ----
    if (pp_first_section) {
        copy_ubuf_to_ubuf(buf_for_cacl_final_rowmax_fp16,
            cur_buf_for_rowmax,
            0,
            rowPerTime,
            128/16,
            0,
            0);
        pipe_barrier(PIPE_V);
    }
}
/**
 * Reduces one ping/pong score buffer and accumulates a per-row result into
 * ub_attn.ljVectorUbAddr at element offset `lines`.
 *
 * Steps: pad the buffer with PADDING_TYPE_ROWSUM (only when the section has
 * blocks), run process_cacl_wrap() with optcode 1 so the reduced block-row
 * lands in buf_for_cacl_final_rowmax_fp16, then vcadd that block-row into the
 * lj vector. (Presumably optcode 1 selects addition - confirm against
 * process_cacl_wrap.)
 *
 * Unused locals were removed; one of them read
 * ub_attn.buf_for_record_rowmax_fp16, a field the visible allocator never
 * assigns. The parameters qk_triangle, first_line, attn_score_gm, attn_mask_gm
 * and sparse_flag are not used by this function.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::process_calc_sum(int32_t qk_triangle, PARAM_SHORT_SEQ_MAX param, int32_t ping_pong_flag,
    bool first_line, __gm__ WORKSPACE_T * attn_score_gm,
    __gm__ INPUT_T * attn_mask_gm, UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn,
    bool sparse_flag, int32_t lines)
{
    __ubuf__ INPUT_T * cur_buf_for_attn_score = ub_attn.pp_buf_for_attn_score_fp16[ping_pong_flag];
    __ubuf__ INPUT_T * cur_buf_for_rowmax = ub_attn.buf_for_cacl_rowmax_fp16;
    __ubuf__ INPUT_T * cur_buf_for_final_rowmax = ub_attn.buf_for_cacl_final_rowmax_fp16;
    __ubuf__ INPUT_T * ljVectorUbAddr = ub_attn.ljVectorUbAddr;

    if (param.section_block_num > 0)
    {
        padding_for_row_max_or_rowsum2(param.section_block_num, param.section_padding_block_num, ping_pong_flag,
            PADDING_TYPE_ROWSUM, ub_attn.pp_buf_for_attn_score_fp16);
    }

    process_cacl_wrap(param.section_block_num, param.section_padding_block_num, true,
        cur_buf_for_attn_score, cur_buf_for_rowmax, cur_buf_for_final_rowmax, 1);

    // One vcadd repeat per row handled in this pass.
    vcadd(ljVectorUbAddr + lines, cur_buf_for_final_rowmax, rowPerTime, 1, 1, 8, false);
    pipe_barrier(PIPE_V);
}
/*
 * Compute the max for sequences of length 8192 or less.
 */
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
IF_BF16, WORKSPACE_T>::process_line_phase_one_for_short_seq_max(bool is_head_section,
bool is_tail_section, int32_t qk_triangle, PARAM_SHORT_SEQ_MAX param, int32_t ping_pong_flag,
bool first_line, __gm__ WORKSPACE_T * attn_score_gm, __gm__ INPUT_T * attn_mask_gm,
UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn, bool sparse_flag,
int32_t lines, __ubuf__ INPUT_T* maxJVectorUbAddr,
__ubuf__ INPUT_T* lastMaxJVectorUbAddr,
__ubuf__ INPUT_T* subMaxValueResultVectorUbAddr,
int rowsPerDiag,
int rowmax_build_step)
{
event_t event_ids[] = {EVENT_ID1, EVENT_ID2, EVENT_ID3, EVENT_ID4, EVENT_ID5, EVENT_ID6};
auto event_id = (ping_pong_flag == 0 ? EVENT_ID0 : EVENT_ID1);
__ubuf__ INPUT_T * cur_buf_for_attn_score_fp16 = ub_attn.pp_buf_for_attn_score_fp16[ping_pong_flag];
__ubuf__ INPUT_T * cur_buf_for_rowmax = ub_attn.buf_for_cacl_rowmax_fp16;
__ubuf__ INPUT_T * cur_buf_for_final_rowmax = ub_attn.buf_for_cacl_final_rowmax_fp16;
__ubuf__ INPUT_T * cur_buf_for_vbrcb_rowmax_fp16 = ub_attn.buf_for_vbrcb_rowmax_fp16;
__ubuf__ INPUT_T * cur_buf_for_mask_fp32 = ub_attn.pp_buf_for_load_one_block_tri_mask_fp16[ping_pong_flag];
__ubuf__ INPUT_T * cur_buf_for_mask_fp16 = (ub_attn.pp_buf_for_load_one_block_tri_mask_fp16[ping_pong_flag] +
64 * rowPerTime);
__ubuf__ INPUT_T * cur_buf_for_head_mask = ub_attn.pp_buf_for_load_one_block_tri_mask_fp16[ping_pong_flag] +
256 * rowPerTime;
__ubuf__ INPUT_T * cur_buf_for_head_mask_fp16 = (__ubuf__ INPUT_T *)(cur_buf_for_head_mask + 64 * rowPerTime);
__ubuf__ float * cur_buf_for_record_rowmax = ub_attn.buf_for_record_rowmax_fp32;
__ubuf__ float * cur_buf_for_vbrcb_rowmax_fp32 = ub_attn.cur_buf_for_vbrcb_rowmax_fp32;
__ubuf__ float * cur_buf_for_score_fp32 = (__ubuf__ float *)ub_attn.buf_for_cacl_rowmax_fp16;
copy_gm_to_ubuf(
cur_buf_for_attn_score_fp16,
attn_score_gm + param.section_start_line_offset,
0,
param.section_block_num,
BASE_BLOCK_SIDE_LEN * rowPerTime / 16,
(BASE_BLOCK_DATA_NUM - rowPerTime * BASE_BLOCK_SIDE_LEN) / 16,
0);
int32_t cur_core_index = get_block_idx();
int32_t vector_id = get_subblockid();
bool sparse_tail_mask_flag =
sparse_flag && (param.apply_tri_mask == TRI_MATRIX_TAIL || param.apply_tri_mask == TRI_MATRIX_HEAD_AND_TAIL);
bool sparse_head_mask_flag =
sparse_flag && (param.apply_tri_mask == TRI_MATRIX_HEAD || param.apply_tri_mask == TRI_MATRIX_HEAD_AND_TAIL);
int32_t srcGap = rowPerTime == 1 ? 0: maskSeqLength - BASE_BLOCK_SIDE_LEN;
if ((qk_triangle == 1 && is_tail_section) || sparse_tail_mask_flag)
{
copy_gm_to_ubuf(
cur_buf_for_mask_fp16,
attn_mask_gm + param.section_mask_offset,
0,
rowPerTime,
BASE_BLOCK_SIDE_LEN / 16,
srcGap / 16,
0);
}
if (sparse_head_mask_flag) {
copy_gm_to_ubuf(
cur_buf_for_head_mask_fp16,
attn_mask_gm + param.section_mask_offset + 1,
0,
rowPerTime,
BASE_BLOCK_SIDE_LEN / 16,
srcGap / 16,
0);
}
set_flag(PIPE_MTE2, PIPE_V, event_id);
wait_flag(PIPE_MTE2, PIPE_V, event_id);
vmuls(cur_buf_for_attn_score_fp16, cur_buf_for_attn_score_fp16,
(INPUT_T)SCALE, param.section_block_num * rowPerTime, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
if ((is_tail_section) && (windowLen != 0))
{
int align_size = 256;
int tailLen = (align_size - windowLen) % BASE_BLOCK_SIDE_LEN;
int pad_block_num = 1;
half PADDING_VAL = -60000;
if (windowLen >= 128) {
vector_dup(cur_buf_for_attn_score_fp16 + rowPerTime * (param.section_block_num - 1) * BASE_BLOCK_SIDE_LEN,
(INPUT_T)PADDING_VAL,
rowPerTime,
1,
1,
8,
1);
pipe_barrier(PIPE_V);
pad_block_num = 2;
}
if (tailLen != 0) {
__set_reverse_mask(tailLen);
vector_dup(cur_buf_for_attn_score_fp16 + rowPerTime *
(param.section_block_num - pad_block_num) * BASE_BLOCK_SIDE_LEN,
(INPUT_T)PADDING_VAL,
rowPerTime,
1,
1,
8,
1);
}
pipe_barrier(PIPE_V);
set_vector_mask((uint64_t)-1, (uint64_t)-1);
}
if (param.section_block_num > 0)
{
padding_for_row_max_or_rowsum2(param.section_block_num, param.section_padding_block_num,
ping_pong_flag, PADDING_TYPE_ROWMAX, ub_attn.pp_buf_for_attn_score_fp16);
pipe_barrier(PIPE_V);
}
process_cacl_max(param.section_block_num, param.section_padding_block_num, true, cur_buf_for_attn_score_fp16,
cur_buf_for_rowmax, cur_buf_for_final_rowmax);
pipe_barrier(PIPE_V);
vcmax(maxJVectorUbAddr, cur_buf_for_final_rowmax, rowPerTime, 1, 1, 8, (Order_t)0b10);
pipe_barrier(PIPE_V);
if (!is_head_section) {
int BLOCK_SIZE_64=64;
int BLOCK_SIZE_8=8;
event_t event_ids[] = {EVENT_ID1, EVENT_ID2, EVENT_ID3, EVENT_ID4, EVENT_ID5, EVENT_ID6};
__set_mask(rowPerTime);
vmax(maxJVectorUbAddr,
maxJVectorUbAddr,
lastMaxJVectorUbAddr,
1,
1,
1,
1,
BLOCK_SIZE_64 / BLOCK_SIZE_8,
BLOCK_SIZE_64 / BLOCK_SIZE_8,
BLOCK_SIZE_64 / BLOCK_SIZE_8);
pipe_barrier(PIPE_V);
vsub(subMaxValueResultVectorUbAddr,
lastMaxJVectorUbAddr,
maxJVectorUbAddr,
1,
1,
1,
1,
8,
8,
8);
pipe_barrier(PIPE_V);
vexp(subMaxValueResultVectorUbAddr,
subMaxValueResultVectorUbAddr,
1,
1,
1,
8,
8);
pipe_barrier(PIPE_V);
vadds(subMaxValueResultVectorUbAddr,
subMaxValueResultVectorUbAddr,
FP16_SMALL_NUM,
1,
1,
1,
8,
8);
pipe_barrier(PIPE_V);
set_vector_mask((uint64_t)-1, (uint64_t)-1);
set_flag(PIPE_V, PIPE_S, event_ids[rowmax_build_step]);
}
vconv_f162f32(cur_buf_for_record_rowmax, maxJVectorUbAddr, 1, 1, 1, 8, 4);
pipe_barrier(PIPE_V);
int32_t vbrcbRepeatTimes = (rowPerTime + 7) / 8;
__set_mask(rowPerTime);
vbrcb((__ubuf__ uint32_t *)cur_buf_for_vbrcb_rowmax_fp32, (__ubuf__ uint32_t *)cur_buf_for_record_rowmax,
2, 8, vbrcbRepeatTimes);
pipe_barrier(PIPE_V);
vbrcb((__ubuf__ uint32_t *)cur_buf_for_vbrcb_rowmax_fp32 + 8, (__ubuf__ uint32_t *)cur_buf_for_record_rowmax,
2, 8, vbrcbRepeatTimes);
pipe_barrier(PIPE_V);
set_vector_mask((uint64_t)-1, (uint64_t)-1);
vconv_f322f16(cur_buf_for_vbrcb_rowmax_fp16, cur_buf_for_vbrcb_rowmax_fp32, 1, 1, 1, 4, 8);
pipe_barrier(PIPE_V);
}
// Second half of the per-line-group pipeline for short sequences.
//
// Working in place on the ping/pong score buffer selected by ping_pong_flag, this
//   1. subtracts the values held in ub_attn.buf_for_vbrcb_rowmax_fp16 from the scores (vsub),
//   2. applies vexp to the whole buffer,
//   3. writes the buffer back to attn_score_gm at param.section_start_line_offset, and
//   4. calls process_calc_sum for the same line group.
//
// The vsub addressing depends on rowPerTime: a single call when it is 1, one call per row
// when it is in (1, 16), and two calls per block (one per 64-element half) otherwise.
//
// Parameters:
//   qk_triangle, attn_mask_gm, sparse_flag - only forwarded to process_calc_sum.
//   param          - section descriptor; section_block_num and section_start_line_offset are read here.
//   ping_pong_flag - selects the UB score buffer and the EVENT_ID0 / EVENT_ID1 sync event.
//   attn_score_gm  - global-memory score workspace that receives the exponentiated scores.
//   ub_attn        - unified-buffer layout descriptor.
//   offset         - not used by this function.
//   lines          - forwarded to process_calc_sum; `lines == 0` is passed as its fourth argument.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::
    process_line_phase_one_for_short_seq_exp_and_rowsum(int32_t qk_triangle,
    PARAM_SHORT_SEQ_MAX param, int32_t ping_pong_flag, __gm__ WORKSPACE_T * attn_score_gm,
    __gm__ INPUT_T * attn_mask_gm, UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn, int32_t offset, int32_t lines, bool sparse_flag)
{
    auto sync_event = (ping_pong_flag == 0 ? EVENT_ID0 : EVENT_ID1);
    __ubuf__ INPUT_T * score_ub = ub_attn.pp_buf_for_attn_score_fp16[ping_pong_flag];
    __ubuf__ INPUT_T * rowmax_ub = ub_attn.buf_for_vbrcb_rowmax_fp16;

    if (rowPerTime == 1) {
        // Single row: one vsub covers every block (two repeats per block).
        vsub(score_ub, score_ub, rowmax_ub, param.section_block_num * 2, 1, 1, 0, 8, 8, 0);
        pipe_barrier(PIPE_V);
    } else if (rowPerTime < 16) {
        // A few rows: one vsub per row; each repeat jumps to the same row of the next block.
        for (int32_t row = 0; row < rowPerTime; row++) {
            __ubuf__ INPUT_T * row_scores = score_ub + row * 128;
            vsub(row_scores, row_scores, rowmax_ub + row * 16,
                param.section_block_num,
                1, 1, 0, rowPerTime * 8, rowPerTime * 8, 0);
            pipe_barrier(PIPE_V);
        }
    } else {
        // Many rows: per block, handle the first and second 64-element halves separately,
        // repeating once per row.
        for (int32_t blk = 0; blk < param.section_block_num; blk++) {
            __ubuf__ INPUT_T * blk_scores = score_ub + blk * 128 * rowPerTime;
            vsub(blk_scores, blk_scores, rowmax_ub, rowPerTime, 1, 1, 0, 16, 16, 1);
            pipe_barrier(PIPE_V);
            vsub(blk_scores + 64, blk_scores + 64, rowmax_ub, rowPerTime, 1, 1, 0, 16, 16, 1);
            pipe_barrier(PIPE_V);
        }
    }

    // Exponentiate the whole section in place.
    vexp(score_ub, score_ub,
        param.section_block_num * rowPerTime * BASE_BLOCK_SIDE_LEN / BLOCK_NUM_8 / data_num_per_blk,
        1, 1, 8, 8);
    pipe_barrier(PIPE_V);

    // The vector pipe must finish before MTE3 starts copying the buffer out.
    set_flag(PIPE_V, PIPE_MTE3, sync_event);
    wait_flag(PIPE_V, PIPE_MTE3, sync_event);

    int32_t gm_line_offset = param.section_start_line_offset;
    copy_ubuf_to_gm(((__gm__ half *)(attn_score_gm + gm_line_offset)),
        (__ubuf__ half *)score_ub,
        0,
        param.section_block_num,
        BASE_BLOCK_SIDE_LEN * rowPerTime / 16,
        0,
        (BASE_BLOCK_DATA_NUM - rowPerTime * BASE_BLOCK_SIDE_LEN) / 16);

    process_calc_sum(qk_triangle, param, ping_pong_flag, lines == 0,
        attn_score_gm, attn_mask_gm, ub_attn, sparse_flag, lines);
}
// Issues one element-wise vector instruction selected by optcode:
//   optcode == 1 -> vadd(dest_addr, src_addr, src_addr2, ...)
//   optcode == 2 -> vmax(dest_addr, src_addr, src_addr2, ...)
// Any other optcode issues nothing. The repeat count and the block / repeat strides are
// forwarded to the instruction unchanged.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::add_max_wrap(
    __ubuf__ INPUT_T * dest_addr, __ubuf__ INPUT_T * src_addr, __ubuf__ INPUT_T * src_addr2, int32_t rpt,
    int32_t dst_stride, int32_t src_stride, int32_t src_stride2, int32_t dst_rpt_stride, int32_t src_rpt_stride,
    int32_t src_rpt_stride2, int32_t optcode)
{
    if (optcode == 2) {
        vmax(dest_addr, src_addr, src_addr2, rpt,
            dst_stride, src_stride, src_stride2,
            dst_rpt_stride, src_rpt_stride, src_rpt_stride2);
        return;
    }
    if (optcode == 1) {
        vadd(dest_addr, src_addr, src_addr2, rpt,
            dst_stride, src_stride, src_stride2,
            dst_rpt_stride, src_rpt_stride, src_rpt_stride2);
        return;
    }
    // Unrecognised optcode: intentionally a no-op, as in the two-way dispatch above.
}
// In-place reduction over buf_for_cacl using add_max_wrap (optcode 1 = vadd, 2 = vmax).
//
// Starting from _block_num, each round combines the data at the start of the buffer with the
// data located BASE_BLOCK_SIDE_LEN * n * rowPerTime elements further on, then halves n; rounds
// run while n > 1. A final combine with an offset of BASE_BLOCK_SIDE_LEN * rowPerTime always
// runs, even when _block_num <= 1. Each step is followed by a vector-pipe barrier.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::cacl_wrap(
    __ubuf__ INPUT_T * buf_for_cacl, int32_t _block_num, int32_t optcode)
{
    auto repeats_per_block = BASE_BLOCK_SIDE_LEN / (256 / SIZE_FP16);

    for (auto remaining = _block_num; remaining > 1; remaining = remaining / 2) {
        add_max_wrap(buf_for_cacl,
            buf_for_cacl,
            buf_for_cacl + BASE_BLOCK_SIDE_LEN * remaining * rowPerTime,
            remaining * repeats_per_block * rowPerTime,
            1, 1, 1, 8, 8, 8, optcode);
        pipe_barrier(PIPE_V);
    }

    // Last fold: merge the second block group into the first.
    add_max_wrap(buf_for_cacl,
        buf_for_cacl,
        buf_for_cacl + BASE_BLOCK_SIDE_LEN * rowPerTime,
        repeats_per_block * rowPerTime,
        1, 1, 1, 8, 8, 8, optcode);
    pipe_barrier(PIPE_V);
}
// Reduces the (basic + padding) blocks of cur_buf_for_attn_score into cur_buf_for_rowmax with
// add_max_wrap / cacl_wrap (optcode 1 = vadd, 2 = vmax) and copies rowPerTime bursts of the
// result into buf_for_cacl_final_rowmax_fp16.
//
// The block count is split into a "main" part (multiple of 16 blocks) and a "tail" part
// (block count modulo 16):
//   * main part (only when there are >= 16 blocks): fold the score buffer into
//     cur_buf_for_rowmax, with one extra fold pair for 48..63 blocks, then finish with cacl_wrap;
//   * tail part (only when the tail is exactly 2, 4 or 8 blocks): fold the tail blocks into a
//     scratch area of cur_buf_for_rowmax and, if a main part exists, merge the two results.
//
// NOTE(review): when the tail block count is non-zero but not 2, 4 or 8 the function returns
// early WITHOUT writing buf_for_cacl_final_rowmax_fp16; presumably the caller's padding makes
// that case impossible - confirm.
//
// pp_first_section is not used by this function.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::process_cacl_wrap(
    int32_t basic_block_num, int32_t padding_block_num, bool pp_first_section,
    __ubuf__ INPUT_T * cur_buf_for_attn_score, __ubuf__ INPUT_T * cur_buf_for_rowmax,
    __ubuf__ INPUT_T * buf_for_cacl_final_rowmax_fp16, int32_t optcode)
{
    const int32_t total_blocks = basic_block_num + padding_block_num;
    const int32_t tail_blocks = total_blocks % 16;
    const int32_t main_blocks = total_blocks / 16 * 16;
    const bool has_main_part = total_blocks >= 16;

    if (has_main_part) {
        // Below 48 blocks the first fold spans every row of the group; otherwise a single row.
        const int32_t main_rows = total_blocks < 48 ? rowPerTime : 1;

        // Largest of 64/32/16 that does not exceed the block count, then halved.
        int32_t fold_width = 64;
        while (fold_width > total_blocks) {
            fold_width /= 2;
        }
        fold_width /= 2;

        const int32_t main_repeat = BASE_BLOCK_SIDE_LEN * fold_width / (256 / SIZE_FP16);
        add_max_wrap(cur_buf_for_rowmax,
            cur_buf_for_attn_score,
            cur_buf_for_attn_score + BASE_BLOCK_SIDE_LEN * fold_width * main_rows,
            main_repeat * main_rows,
            1, 1, 1, 8, 8, 8, optcode);
        pipe_barrier(PIPE_V);

        if (total_blocks >= 48 && total_blocks < 64) {
            // 48..63 blocks: also fold in the data starting at block 32, then shrink to 8 blocks.
            add_max_wrap(cur_buf_for_rowmax,
                cur_buf_for_rowmax,
                cur_buf_for_attn_score + BASE_BLOCK_SIDE_LEN * 32,
                main_repeat,
                1, 1, 1, 8, 8, 8, optcode);
            pipe_barrier(PIPE_V);
            fold_width = 8;
            add_max_wrap(cur_buf_for_rowmax,
                cur_buf_for_rowmax,
                cur_buf_for_rowmax + BASE_BLOCK_SIDE_LEN * fold_width,
                16,
                1, 1, 1, 8, 8, 8, optcode);
            pipe_barrier(PIPE_V);
        }
        cacl_wrap(cur_buf_for_rowmax, fold_width / 2, optcode);
    }

    const int32_t tail_rows = rowPerTime;
    if (tail_blocks > 0) {
        // Find the tail size among 8, 4, 2; any other size leaves tail_width at 1.
        int32_t tail_width = 8;
        for (; tail_width >= 2 && tail_width != tail_blocks; tail_width /= 2) {
        }
        if (tail_width == 1) {
            return;
        }
        tail_width /= 2;

        const int32_t tail_repeat = BASE_BLOCK_SIDE_LEN * tail_width / (256 / SIZE_FP16);
        // The tail result goes right behind the main result when there is one, else at the start.
        auto tail_acc = cur_buf_for_rowmax + BASE_BLOCK_SIDE_LEN * (int32_t)(has_main_part) * tail_rows;
        auto tail_src = cur_buf_for_attn_score + BASE_BLOCK_SIDE_LEN * main_blocks * tail_rows;
        add_max_wrap(tail_acc,
            tail_src,
            tail_src + BASE_BLOCK_SIDE_LEN * tail_width * tail_rows,
            tail_repeat * tail_rows,
            1, 1, 1, 8, 8, 8, optcode);
        pipe_barrier(PIPE_V);

        tail_width /= 2;
        if (tail_width >= 1) {
            cacl_wrap(tail_acc, tail_width, optcode);
        }
        pipe_barrier(PIPE_V);

        if (has_main_part) {
            // Merge the tail result into the main result.
            add_max_wrap(cur_buf_for_rowmax, cur_buf_for_rowmax, tail_acc,
                2 * tail_rows, 1, 1, 1, 8, 8, 8, optcode);
            pipe_barrier(PIPE_V);
        }
    }

    copy_ubuf_to_ubuf(buf_for_cacl_final_rowmax_fp16,
        cur_buf_for_rowmax,
        0,
        tail_rows,
        128 / 16,
        0,
        0);
    pipe_barrier(PIPE_V);
}
// Scatters rowsPerDiag scalars from vecaddr into the UB tile at
// diagmataddr + rowsPerDiag * BASE_BLOCK_LENGTH * step, then copies that tile to the same
// offset in gm_addr.
//
// Element `line` of vecaddr is stored at tile index
//   step * rowsPerDiag + vector_id * BASE_BLOCK_LENGTH / 2 + (BASE_BLOCK_LENGTH + 1) * line,
// i.e. on a diagonal if tile rows are BASE_BLOCK_LENGTH floats wide.
//
// type == 0: waits on the MTE3->S event for `step` before the scalar writes and, after the
//            copy-out, sets the MTE3->S event for `step + 1`.
//            NOTE(review): event_ids has six entries, so this requires step <= 4.
// type != 0: performs the same writes and copy without the MTE3->S handshake.
//
// ping_pong_flag is not used by this function.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::form_diagmat(
    __ubuf__ float * diagmataddr, __ubuf__ float * vecaddr,
    __gm__ float * gm_addr, int step, int rowsPerDiag, int32_t vector_id, int32_t ping_pong_flag, int type)
{
    event_t event_ids[] = {EVENT_ID1, EVENT_ID2, EVENT_ID3, EVENT_ID4, EVENT_ID5, EVENT_ID6};
    const int tile_offset = rowsPerDiag * BASE_BLOCK_LENGTH * (step);
    const int first_index = step * rowsPerDiag + vector_id * BASE_BLOCK_LENGTH / 2;
    const bool handshake_with_mte3 = (type == 0);
    __ubuf__ float * tile = diagmataddr + tile_offset;

    if (handshake_with_mte3) {
        wait_flag(PIPE_MTE3, PIPE_S, event_ids[step]);
    }

    // Scalar writes: one value per row, shifted one extra column per row.
    for (int row = 0; row < rowsPerDiag; row++) {
        const int row_start = first_index + (BASE_BLOCK_LENGTH) * row;
        tile[row_start + row] = vecaddr[row];
    }

    // Scalar writes must land before MTE3 reads the tile.
    set_flag(PIPE_S, PIPE_MTE3, event_ids[step]);
    wait_flag(PIPE_S, PIPE_MTE3, event_ids[step]);
    copy_ubuf_to_gm(gm_addr + tile_offset, tile,
        0,
        1,
        BASE_BLOCK_LENGTH * rowsPerDiag * SIZE_FP32 / 32, 0, 0);

    if (handshake_with_mte3) {
        set_flag(PIPE_MTE3, PIPE_S, event_ids[step + 1]);
    }
}
// Updates the running row sum in place:
//   tempUbAddr     = subMaxValueResultVectorUbAddr * lastLjVectorUbAddr   (vmul)
//   ljVectorUbAddr = ljVectorUbAddr + tempUbAddr                          (vadd)
// Each instruction is issued with repeat 1, block strides 1 and repeat strides 8, and is
// followed by a vector-pipe barrier. tempUbAddr is overwritten as scratch space.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::updateRowsum(
    __ubuf__ float* ljVectorUbAddr,
    __ubuf__ float* lastLjVectorUbAddr,
    __ubuf__ float* subMaxValueResultVectorUbAddr,
    __ubuf__ float* tempUbAddr)
{
    // Scale the previous sum.
    vmul(tempUbAddr, subMaxValueResultVectorUbAddr, lastLjVectorUbAddr, 1, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);

    // Accumulate the scaled previous sum into the current one.
    vadd(ljVectorUbAddr, ljVectorUbAddr, tempUbAddr, 1, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);
}
// Saves the current row-sum and row-max vectors into their "last" counterparts:
//   lastLjVectorUbAddr   <- ljVectorUbAddr     (one burst of length 64 / 8)
//   lastMaxJVectorUbAddr <- maxJVectorUbAddr   (one burst of length 256 / 16)
// Each UB-to-UB copy is followed by a vector-pipe barrier.
// NOTE(review): burst lengths are presumably in 32-byte units (64 fp32 and 256 fp16
// elements respectively) - confirm against the copy_ubuf_to_ubuf specification.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::backupMaxandSum(
    __ubuf__ float* ljVectorUbAddr,
    __ubuf__ float* lastLjVectorUbAddr,
    __ubuf__ INPUT_T* maxJVectorUbAddr,
    __ubuf__ INPUT_T* lastMaxJVectorUbAddr)
{
    // Row sums.
    copy_ubuf_to_ubuf(lastLjVectorUbAddr, ljVectorUbAddr, 0, 1, 64 / 8, 0, 0);
    pipe_barrier(PIPE_V);

    // Row maxima.
    copy_ubuf_to_ubuf(lastMaxJVectorUbAddr, maxJVectorUbAddr, 0, 1, 256 / 16, 0, 0);
    pipe_barrier(PIPE_V);
}
/*
 * Short-sequence path: each line group is copied in a single pass.
 */
// Processes one section of a short sequence, rowPerTime lines at a time, alternating between
// the ping and pong UB buffers.
//
// For every line group the loop
//   1. calls process_line_phase_one_for_short_seq_max (load, scale, mask/pad, row max),
//   2. calls process_line_phase_one_for_short_seq_exp_and_rowsum (subtract max, exp, write-back,
//      row sum),
//   3. when this is not the head section, compacts the per-row rescale factors and, every
//      rowsPerDiag (16) lines, converts them to fp32 and emits them through form_diagmat.
// After the loop the fp16 row sums are converted to fp32; non-head sections then fold in the
// previous section's sum (updateRowsum) and non-tail sections back up max and sum for the
// next section (backupMaxandSum).
//
// Parameters:
//   sectionLoop                     - section index; its parity selects which half of the
//                                     "last" sum / max buffers is used.
//   sectionNum                      - not used by this function.
//   qk_triangle                     - non-zero marks the line group as TRI_MATRIX_TAIL.
//   section_block_nums              - block count of the section (same for ping and pong).
//   tri_matrix_mask_offset          - mask offset of the first line group.
//   each_vector_proc_line_num       - number of lines this vector core processes.
//   local_section_start_line_offset - score offset of the first line group.
//   is_head_section/is_tail_section - position of the section; they gate the rescale,
//                                     update and backup steps.
//   diag_offset                     - offset added to diag_rowmax_gm for form_diagmat.
//   attn_score_gm / attn_mask_gm    - global-memory score workspace and mask.
//   ub_attn                         - unified-buffer layout descriptor.
//   sparse_flag                     - true marks the line group as TRI_MATRIX_HEAD_AND_TAIL
//                                     (overrides the qk_triangle setting).
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T> __aicore__ __inline__ void VectorForward<INPUT_T,
    IF_BF16, WORKSPACE_T>::attention_score_short_double_line_one(int32_t sectionLoop, int32_t sectionNum,
    int32_t qk_triangle,
    int32_t section_block_nums,
    int32_t tri_matrix_mask_offset,
    int32_t each_vector_proc_line_num,
    int32_t local_section_start_line_offset,
    bool is_head_section,
    bool is_tail_section,
    int32_t diag_offset,
    __gm__ WORKSPACE_T * __restrict__ attn_score_gm,
    __gm__ INPUT_T * __restrict__ attn_mask_gm,
    UB_FOR_SHORT_LEN_ATTN_SCORE ub_attn, bool sparse_flag)
{
    event_t event_ids[] = {EVENT_ID1, EVENT_ID2, EVENT_ID3, EVENT_ID4, EVENT_ID5, EVENT_ID6};

    // Index 0 describes the ping line group, index 1 the pong line group, which starts one
    // group of rowPerTime lines further on.
    PARAM_SHORT_SEQ_MAX param_ping_pong[2] = {0};
    param_ping_pong[0].section_start_line_offset = local_section_start_line_offset;
    param_ping_pong[0].section_block_num = section_block_nums;
    param_ping_pong[0].record_rowmax_offset = 0;
    param_ping_pong[1].section_start_line_offset =
        local_section_start_line_offset + BASE_BLOCK_SIDE_LEN * rowPerTime;
    param_ping_pong[1].section_block_num = section_block_nums;
    param_ping_pong[1].record_rowmax_offset = 0;
    // The padding block count is written through the pointer; ping and pong share the value.
    get_padding_info_for_row_max(param_ping_pong[0].section_block_num,
        &param_ping_pong[0].section_padding_block_num);
    param_ping_pong[1].section_padding_block_num = param_ping_pong[0].section_padding_block_num;
    // Mask offsets are the same whether or not qk_triangle is set.
    param_ping_pong[0].section_mask_offset = tri_matrix_mask_offset;
    param_ping_pong[1].section_mask_offset = tri_matrix_mask_offset + maskSeqLength * rowPerTime;

    int32_t ping_pong_flag = 0;
    // Prime both MTE3->MTE2 events so the first wait on each buffer passes.
    set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
    int rowsPerDiag = 16;
    int rowmax_build_step = 0;
    // Prime the MTE3->S event that form_diagmat (or the final wait below) consumes.
    set_flag(PIPE_MTE3, PIPE_S, event_ids[rowmax_build_step]);
    int32_t vector_id = get_subblockid();

    __ubuf__ float* lastLjVectorUbAddr = ub_attn.lastLjVectorUbAddr + (sectionLoop % 2) * 64;
    __ubuf__ INPUT_T* ljVectorUbAddr = ub_attn.ljVectorUbAddr;
    __ubuf__ float* ljVectorUbAddr_fp32 = ub_attn.ljVectorUbAddr_fp32;
    __ubuf__ INPUT_T* subMaxValueResultVectorUbAddr = ub_attn.subMaxValueResultVectorUbAddr;
    __ubuf__ float* tempUbAddr = ub_attn.buf_for_record_rowmax_fp32;
    __ubuf__ INPUT_T* maxJVectorUbAddr = ub_attn.maxJVectorUbAddr;
    __ubuf__ INPUT_T* lastMaxJVectorUbAddr = ub_attn.lastMaxJVectorUbAddr + (sectionLoop % 2) * 128 * 2;
    __ubuf__ float* subMaxValueResultVectorUbAddr_fp32 = ub_attn.subMaxValueResultVectorUbAddr_fp32;

    // Zero the accumulators and the diagonal-matrix staging buffer.
    vector_dup(maxJVectorUbAddr, 0, 2, 1, 1, 8, 1);
    pipe_barrier(PIPE_V);
    vector_dup(ljVectorUbAddr, 0, 1, 1, 1, 8, 1);
    pipe_barrier(PIPE_V);
    vector_dup(diagExpMaxJMatUbAddr, 0, BASE_BLOCK_SIDE_LEN/2*BASE_BLOCK_SIDE_LEN/64, 1, 1, 8, 1);
    pipe_barrier(PIPE_V);

    for (int32_t lines = 0; lines < each_vector_proc_line_num; lines+=rowPerTime)
    {
        auto event_id = (ping_pong_flag == 0 ? EVENT_ID0 : EVENT_ID1);
        param_ping_pong[0].apply_tri_mask = TRI_MATRIX_NONE;
        param_ping_pong[1].apply_tri_mask = TRI_MATRIX_NONE;
        if (qk_triangle)
        {
            param_ping_pong[ping_pong_flag].apply_tri_mask = TRI_MATRIX_TAIL;
        }
        if (sparse_flag) {
            param_ping_pong[ping_pong_flag].apply_tri_mask = TRI_MATRIX_HEAD_AND_TAIL;
        }
        // Each line group gets a 16-element slot in the max / rescale vectors.
        int linesRound = lines / rowPerTime * 16;
        // Wait until the previous write-back from this buffer has finished.
        wait_flag(PIPE_MTE3, PIPE_MTE2, event_id);
        process_line_phase_one_for_short_seq_max(is_head_section, is_tail_section, qk_triangle,
            param_ping_pong[ping_pong_flag],
            ping_pong_flag,
            lines == 0,
            attn_score_gm,
            attn_mask_gm,
            ub_attn,
            sparse_flag,
            lines,
            maxJVectorUbAddr + linesRound,
            lastMaxJVectorUbAddr + linesRound,
            subMaxValueResultVectorUbAddr + linesRound,
            rowsPerDiag,
            rowmax_build_step);
        int32_t offset = lines * 128;
        process_line_phase_one_for_short_seq_exp_and_rowsum(qk_triangle,
            param_ping_pong[ping_pong_flag],
            ping_pong_flag,
            attn_score_gm,
            attn_mask_gm,
            ub_attn,
            offset, lines, sparse_flag
            );
        set_flag(PIPE_MTE3, PIPE_MTE2, event_id);
        if (!is_head_section) {
            // Pairs with the V->S set_flag issued by the max phase for non-head sections.
            wait_flag(PIPE_V, PIPE_S, event_ids[rowmax_build_step]);
            // Compact the rescale factors from the 16-wide slot to a dense per-line layout.
            for (int i=0; i < rowPerTime; i++) {
                (subMaxValueResultVectorUbAddr + lines)[i] = (subMaxValueResultVectorUbAddr + linesRound)[i];
            }
            set_flag(PIPE_S, PIPE_V, event_ids[rowmax_build_step]);
            wait_flag(PIPE_S, PIPE_V, event_ids[rowmax_build_step]);
        }
        if ((lines+rowPerTime)%rowsPerDiag ==0 && !is_head_section)
        {
            // A full group of rowsPerDiag lines is ready: convert it to fp32 and emit it.
            int diag_vec_offset = ((lines+rowPerTime)-rowsPerDiag);
            __set_mask(rowsPerDiag);
            vconv_f162f32(subMaxValueResultVectorUbAddr_fp32 + diag_vec_offset,
                subMaxValueResultVectorUbAddr + diag_vec_offset, 1, 1, 1, 8, 4);
            set_vector_mask((uint64_t)-1, (uint64_t)-1);
            set_flag(PIPE_V, PIPE_S, event_ids[rowmax_build_step]);
            wait_flag(PIPE_V, PIPE_S, event_ids[rowmax_build_step]);
            form_diagmat(diagExpMaxJMatUbAddr, subMaxValueResultVectorUbAddr_fp32 + diag_vec_offset,
                diag_rowmax_gm + diag_offset, rowmax_build_step, rowsPerDiag, vector_id, ping_pong_flag, 0);
            rowmax_build_step+=1;
        }
        // Advance this buffer's descriptor past both the ping and the pong line groups.
        param_ping_pong[ping_pong_flag].section_start_line_offset += BASE_BLOCK_SIZE_DOUBLE * rowPerTime;
        param_ping_pong[ping_pong_flag].section_mask_offset += maskSeqLength * 2 * rowPerTime;
        ping_pong_flag = 1 - ping_pong_flag;
    }
    // Consume the last outstanding MTE3->S event (the initial one, or form_diagmat's).
    wait_flag(PIPE_MTE3, PIPE_S, event_ids[rowmax_build_step]);
    __set_mask(64);
    vconv_f162f32(ljVectorUbAddr_fp32, ljVectorUbAddr, 1, 1, 1, 8, 4);
    set_vector_mask((uint64_t)-1, (uint64_t)-1);
    pipe_barrier(PIPE_V);
    if (!is_head_section)
    {
        updateRowsum(ljVectorUbAddr_fp32, lastLjVectorUbAddr, subMaxValueResultVectorUbAddr_fp32, tempUbAddr);
    }
    if (!is_tail_section) {
        backupMaxandSum(ljVectorUbAddr_fp32, lastLjVectorUbAddr, maxJVectorUbAddr, lastMaxJVectorUbAddr);
    }
    // Drain both ping/pong events so the flags are balanced on exit.
    wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
    wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
}
// Splits the blocks of one full line into a ping half, a pong half and an optional odd tail
// block, and reports how many ping/pong rounds are needed.
//
// Outputs (all written through the pointers):
//   tail_block_num        - block_num_per_full_line % 2.
//   tail_block_offset_num - block_num_per_full_line minus the tail block.
//   ping_block_offset_num - always 0.
//   pong_block_offset_num - half of tail_block_offset_num.
//   ping_pong_times       - twice the number of sub_seq_length_per_proc-sized pieces in
//                           pong_block_offset_num * BASE_BLOCK_SIDE_LEN elements, counting a
//                           positive remainder as one extra piece.
//
// sub_seq_length_per_proc is used as a divisor and must be non-zero.
template <typename INPUT_T, bool IF_BF16, typename WORKSPACE_T>
__aicore__ __inline__ void VectorForward<INPUT_T, IF_BF16, WORKSPACE_T>::get_uni_rowmax_seq_info_per_proc(
    int32_t block_num_per_full_line,
    int32_t sub_seq_length_per_proc,
    int32_t *ping_block_offset_num,
    int32_t *pong_block_offset_num,
    int32_t *tail_block_offset_num,
    int32_t *tail_block_num,
    int32_t *ping_pong_times
    )
{
    const int32_t odd_block = block_num_per_full_line % 2;
    const int32_t paired_blocks = block_num_per_full_line - odd_block;
    const int32_t half_blocks = paired_blocks / 2;

    *tail_block_num = odd_block;
    *tail_block_offset_num = paired_blocks;
    *ping_block_offset_num = 0;
    *pong_block_offset_num = half_blocks;

    // Elements covered by one half (ping or pong).
    auto half_span = half_blocks * BASE_BLOCK_SIDE_LEN;
    auto full_pieces = half_span / sub_seq_length_per_proc;
    const bool has_partial_piece = (half_span % sub_seq_length_per_proc) > 0;

    // One ping round plus one pong round per piece.
    *ping_pong_times = (full_pieces + (has_partial_piece ? 1 : 0)) * 2;
}
#endif
#endif