* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
* AddressMapping头文件: 用于寻址的前反向功能,与cube的寻址解耦开来
*/
#ifndef __ADDRESS_MAPPING_FORWARD_ONLINE_H__
#define __ADDRESS_MAPPING_FORWARD_ONLINE_H__
#include <cstdint>
#include "address_const.h"
namespace Address {
template<typename TYPE>
class AddressMappingForwardOnline {
public:
int64_t batchSize_;
int64_t headNum_;
int64_t querySequenceLen_;
int64_t keyValueSequenceLen_;
int64_t headDim_;
int64_t gqaGroupNum_;
int64_t isOdd_;
int64_t blockRows_;
int64_t blockCols_;
bool isTriangle_;
int64_t sparseMode_;
int64_t windowLength_;
int64_t windowSize_;
int64_t coreNum_;
int64_t localCoreIndex_;
int64_t blockNumPerCore_;
int64_t ky_;
int64_t kx_;
int64_t blockNumPerRow_;
int64_t blockNumPerColumn_;
int64_t blockNumPerHead_;
int64_t blockNumPerBatch_;
int64_t totalBlocks_;
int64_t totalRounds_;
int64_t totalLines_;
public:
* 预处理:提前分好段
* @param addr
* @param addr_len
* @param round_id
*/
__aicore__ __inline__ void
forward_addrMapping_pre(ForWardAddrOnline *addr, int64_t &addr_len, int64_t round_id)
{
if (this->coreNum_ == 0) {
return;
}
if (this->ky_ == 0) {
return;
}
if (this->blockNumPerRow_ == 0) {
return;
}
if (this->blockNumPerColumn_ == 0) {
return;
}
if (this->blockNumPerHead_ == 0) {
return;
}
if (this->blockNumPerBatch_ == 0) {
return;
}
if (this->headNum_ == 0) {
return;
}
int64_t skip_block = this->coreNum_ * this->blockNumPerRow_ * this->ky_;
int64_t outer_row = (round_id * this->kx_) / this->blockNumPerRow_ * skip_block;
int64_t inner_row = this->localCoreIndex_ * this->ky_ * this->blockNumPerRow_;
int64_t inner_col = (round_id * this->kx_) % this->blockNumPerRow_;
int64_t cur_block_id = outer_row + inner_row + inner_col;
int64_t row_num_per_round = this->ky_;
int64_t col_num_per_round = this->kx_;
int64_t cur_core_totalBlocks_ = this->blockNumPerRow_ * this->ky_ *
(this->totalLines_ / this->ky_ / this->coreNum_);
int64_t remain_block_num = (this->totalLines_ % (this->coreNum_ * this->ky_)) / this->ky_;
if (this->localCoreIndex_ < remain_block_num) {
cur_core_totalBlocks_ += this->ky_ * this->blockNumPerRow_;
}
int64_t remain = this->blockNumPerCore_;
if ((round_id + 1) * blockNumPerCore_ > cur_core_totalBlocks_) {
remain = blockNumPerCore_ - ((round_id + 1) * blockNumPerCore_ - cur_core_totalBlocks_);
}
int64_t Ky = this->ky_;
int64_t Kx = remain / Ky;
int64_t b = cur_block_id / this->blockNumPerBatch_;
int64_t n = cur_block_id % this->blockNumPerBatch_ / this->blockNumPerHead_;
int64_t block_row = cur_block_id % this->blockNumPerHead_ / (row_num_per_round * this->blockNumPerRow_);
int64_t ir =
(outer_row / this->blockNumPerRow_ + this->localCoreIndex_ * this->ky_) % this->blockNumPerColumn_;
int64_t ic = inner_col;
addr[0].b = b;
addr[0].n = n;
addr[0].iR = ir;
addr[0].iC = ic;
addr[0].kx = Kx;
addr[0].ky = Ky;
addr[0].k = remain;
int index = 0;
for (; remain > 0;) {
if (addr[index].iC + addr[index].kx > this->blockNumPerRow_) {
addr[index].kx = this->blockNumPerRow_ - addr[index].iC;
addr[index].k = addr[index].kx * addr[index].ky;
addr[index + 1].b = addr[index].b;
addr[index + 1].n = addr[index].n;
addr[index + 1].iR = addr[index].iR + addr[index].ky * coreNum_;
addr[index + 1].iC = 0;
addr[index + 1].k = remain - addr[index].k;
addr[index + 1].ky = addr[index].ky;
addr[index + 1].kx = addr[index + 1].k / addr[index + 1].ky;
if (addr[index + 1].iR >= this->blockNumPerColumn_) {
int64_t skip_head = addr[index + 1].iR / this->blockNumPerColumn_;
addr[index + 1].n = addr[index].n + skip_head;
addr[index + 1].iR = addr[index + 1].iR % this->blockNumPerColumn_;
int64_t skip_batch = addr[index + 1].n / this->headNum_;
if (addr[index + 1].n >= this->headNum_) {
addr[index + 1].b = addr[index].b + skip_batch;
addr[index + 1].n = addr[index + 1].n % this->headNum_;
}
}
}
remain -= addr[index].k;
++index;
}
addr_len = index;
}
* no-mask场景的cube1寻址
* @tparam T_LEFT
* @tparam T_RIGHT
* @tparam T_OUTPUT
* @param left
* @param right
* @param out
* @param addr
* @param src
* @param src_len
* @param round_id
*/
template<typename T_LEFT, typename T_RIGHT, typename T_OUTPUT>
__aicore__ __inline__ void addrMapping_rectangular_cube1(__gm__ T_LEFT *__restrict__ left,
__gm__ T_RIGHT *__restrict__ right, __gm__ T_OUTPUT *__restrict__ out, const ForWardAddrOnline *addr,
PhyAddrForwardCube1Online<T_LEFT, T_RIGHT, T_OUTPUT> *src, int64_t src_len, int64_t round_id) {
auto out_offset_round_even = out +
this->localCoreIndex_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
auto out_offset_round_odd =
out + this->coreNum_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE +
this->localCoreIndex_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
if (this->gqaGroupNum_ == 0) {
return;
}
if (this->headNum_ == 0) {
return;
}
for (int64_t i = 0; i < src_len; ++i) {
int64_t b = addr[i].b;
int64_t n = addr[i].n;
int64_t ir = addr[i].iR;
int64_t ic = addr[i].iC;
int64_t Kx = addr[i].kx;
int64_t Ky = addr[i].ky;
int64_t k = addr[i].k;
auto bn_left_offset = left + (b * this->headNum_ + n) * this->querySequenceLen_ * this->headDim_;
auto bn_right_offset = right + (b * this->headNum_ + n) * this->keyValueSequenceLen_ * this->headDim_;
int64_t g_index = n / (this->headNum_ / this->gqaGroupNum_);
auto bn_right_offset_gqa =
right + (b * this->gqaGroupNum_ + g_index) * this->keyValueSequenceLen_ * this->headDim_;
src[i].left = bn_left_offset + ir * ATTENTION_SCORE_BLOCK_SIZE;
src[i].right = bn_right_offset_gqa + ic * ATTENTION_SCORE_BLOCK_SIZE;
src[i].out = ((round_id + 1) % 2) ? out_offset_round_even : out_offset_round_odd;
src[i].kx = Kx;
src[i].ky = Ky;
src[i].k = k;
src[i].lineStride = Kx * ATTENTION_SCORE_BLOCK_SIZE;
src[i].lowerLeft = false;
src[i].upperRight = false;
src[i].onStartSection = ic == 0 ? true : false;
src[i].onEndSection = (ic + Kx >= this->blockNumPerRow_ - 1) ? true : false;
out_offset_round_even += k * ATTENTION_SCORE_BLOCK_SIZE;
out_offset_round_odd += k * ATTENTION_SCORE_BLOCK_SIZE;
}
}
* mask场景的cube1寻址
* @tparam T_LEFT
* @tparam T_RIGHT
* @tparam T_OUTPUT
* @param left
* @param right
* @param out
* @param addr
* @param src
* @param src_len
* @param round_id
*/
template<typename T_LEFT, typename T_RIGHT, typename T_OUTPUT>
__aicore__ __inline__ void addrMapping_triangular_cube1(__gm__ T_LEFT *__restrict__ left,
__gm__ T_RIGHT *__restrict__ right, __gm__ T_OUTPUT *__restrict__ out, const ForWardAddrOnline *addr,
PhyAddrForwardCube1Online<T_LEFT, T_RIGHT, T_OUTPUT> *src, int64_t &src_len, int64_t round_id) {
int64_t index = 0;
int64_t tri_block_num_per_column = this->blockNumPerColumn_ - 2 * this->isOdd_;
if (this->gqaGroupNum_ == 0) {
return;
}
if (this->headNum_ == 0) {
return;
}
for (int64_t i = 0; i < src_len && index < src_len; ++i) {
int64_t iR = addr[i].iR;
int64_t iC = addr[i].iC;
int64_t kx = addr[i].kx;
int64_t ky = addr[i].ky;
int64_t k = addr[i].k;
int64_t switch_index = tri_block_num_per_column + iR + (iR + 1) % 2;
int64_t row_offset = (iR + 1) % 2 == 1 ? -1 : 1;
int64_t row_index_left_section = tri_block_num_per_column + iR;
int64_t row_index_right_section = tri_block_num_per_column - 1 - iR + row_offset;
int64_t col_index_left_section = iC;
int64_t col_index_right_section = iC - switch_index - 1;
int64_t g_index = addr[i].n / (this->headNum_ / this->gqaGroupNum_);
int64_t bn_offset_gqa_right_matrix = (addr[i].b * this->gqaGroupNum_ + g_index) *
this->keyValueSequenceLen_ * this->headDim_;
int64_t bn_offset_right_matrix =
(addr[i].b * this->headNum_ + addr[i].n) * this->keyValueSequenceLen_ * this->headDim_;
int64_t bn_offset_left_matrix =
(addr[i].b * this->headNum_ + addr[i].n) * this->querySequenceLen_ * this->headDim_;
int64_t q_left_offset_section = row_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
int64_t q_right_offset_section = row_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
int64_t k_left_offset_section = col_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
int64_t k_right_offset_section = col_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
bool sparse_flag = false;
int64_t window_block_size = this->windowLength_ / 128;
if (this->isTriangle_ && this->sparseMode_ == 1) {
sparse_flag = iR > ((window_block_size - 1) / 2) ? true : false;
switch_index = (window_block_size / 2) + iR;
row_index_left_section = (window_block_size / 2) + iR;
row_index_right_section = (window_block_size / 2) - 1 - iR;
col_index_left_section = iC;
col_index_right_section = iC - switch_index - 1;
q_left_offset_section = row_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
q_right_offset_section = row_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
k_left_offset_section = col_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
k_right_offset_section = col_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
}
int64_t row_index_sparse_section = iR + (window_block_size / 2);
int64_t col_index_sparse_section = iR + iC - (window_block_size / 2);
int64_t q_sparse_offset_section = row_index_sparse_section * ATTENTION_SCORE_BLOCK_SIZE;
int64_t k_sparse_offset_section = col_index_sparse_section * ATTENTION_SCORE_BLOCK_SIZE;
int64_t out_offset = ((addr[i].b * this->headNum_ + addr[i].n) * this->blockNumPerRow_ *
this->blockNumPerColumn_ + (iR * this->blockNumPerRow_)) * ATTENTION_SCORE_BLOCK_SIZE;
int64_t db_offset =
(round_id % 2) * (this->coreNum_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE);
if (index == 0) {
src[index].out = out + db_offset +
this->localCoreIndex_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
} else {
src[index].out = src[index - 1].out + src[index - 1].k * ATTENTION_SCORE_BLOCK_SIZE;
}
if (!sparse_flag && switch_index < iC) {
if (index >= src_len) {
break;
}
src[index].left = left + bn_offset_left_matrix + q_right_offset_section;
src[index].right = right + bn_offset_gqa_right_matrix + k_right_offset_section;
src[index].kx = kx;
src[index].ky = ky;
src[index].k = k;
src[index].upperRight = (iC + src[index].kx >= blockNumPerRow_ - 1) ? true : false;
src[index].lowerLeft = false;
src[index].lineStride = src[index].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index].onEndSection = false;
src[index].onStartSection = false;
src[index].onStartSection = iC == 0 ? true : false;
src[index].onEndSection = src[index].upperRight;
index++;
} else if (!sparse_flag && iC <= switch_index && iC + kx - 1 > switch_index) {
if (index + 1 >= src_len) {
break;
}
src[index].left = left + bn_offset_left_matrix + q_left_offset_section;
src[index].right = right + bn_offset_gqa_right_matrix + k_left_offset_section;
src[index].kx = switch_index - iC + 1;
src[index].ky = ky;
src[index].k = src[index].kx * src[index].ky;
src[index].upperRight = true;
src[index].lowerLeft = false;
src[index].lineStride = src[index].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index].onEndSection = true;
src[index].onStartSection = iC == 0 ? true : false;
src[index + 1].left = left + bn_offset_left_matrix + q_right_offset_section;
src[index + 1].right = right + bn_offset_gqa_right_matrix;
src[index + 1].out = src[index].out + src[index].k * ATTENTION_SCORE_BLOCK_SIZE;
src[index + 1].kx = kx - src[index].kx;
src[index + 1].ky = ky;
src[index + 1].k = src[index + 1].kx * src[index + 1].ky;
src[index + 1].upperRight =
switch_index + src[index + 1].kx >= this->blockNumPerRow_ - 1 ? true : false;
src[index + 1].lowerLeft = false;
src[index + 1].lineStride = src[index + 1].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index + 1].onStartSection = true;
src[index + 1].onEndSection = src[index + 1].upperRight;
index += 2;
} else if (!sparse_flag && iC <= switch_index && iC + kx - 1 <= switch_index) {
if (index >= src_len) {
break;
}
src[index].left = left + bn_offset_left_matrix + q_left_offset_section;
src[index].right = right + bn_offset_gqa_right_matrix + k_left_offset_section;
src[index].kx = kx;
src[index].ky = ky;
src[index].k = k;
src[index].upperRight = iC + src[index].kx - 1 >= switch_index ? true : false;
src[index].lowerLeft = false;
src[index].lineStride = src[index].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index].onStartSection = iC == 0 ? true : false;
src[index].onEndSection = src[index].upperRight;
index++;
} else {
if (index >= src_len) {
break;
}
src[index].left = left + bn_offset_left_matrix + q_sparse_offset_section;
src[index].right = right + bn_offset_gqa_right_matrix + k_sparse_offset_section;
src[index].k = k;
index++;
}
}
src_len = index;
}
* no-mask场景的cube2寻址
* @tparam T_LEFT
* @tparam T_RIGHT
* @tparam T_OUTPUT
* @param left
* @param right
* @param out
* @param addr
* @param src
* @param src_len
* @param round_id
*/
template<typename T_LEFT, typename T_RIGHT, typename T_OUTPUT>
__aicore__ __inline__ void addrMapping_rectangular_cube2(__gm__ T_LEFT *__restrict__ left,
__gm__ T_RIGHT *__restrict__ right, __gm__ T_OUTPUT *__restrict__ out, const ForWardAddrOnline *addr,
PhyAddrForwardCube2Online<T_LEFT, T_RIGHT, T_OUTPUT> *src, int64_t src_len, int64_t round_id) {
auto left_offset_round_even =
left + this->localCoreIndex_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
auto left_offset_round_odd = left + this->coreNum_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE +
this->localCoreIndex_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
if (this->gqaGroupNum_ == 0) {
return;
}
if (this->headNum_ == 0) {
return;
}
for (int64_t i = 0; i < src_len; ++i) {
auto b = addr[i].b;
auto n = addr[i].n;
auto ir = addr[i].iR;
auto ic = addr[i].iC;
auto Kx = addr[i].kx;
auto Ky = addr[i].ky;
auto k = addr[i].k;
auto bn_right_offset = right + (b * this->headNum_ + n) * this->keyValueSequenceLen_ * this->headDim_;
int64_t g_index = n / (this->headNum_ / this->gqaGroupNum_);
auto bn_right_offset_gqa =
right + (b * this->gqaGroupNum_ + g_index) * this->keyValueSequenceLen_ * this->headDim_;
auto bn_out_offset = out + (b * this->headNum_ + n) * this->querySequenceLen_ * this->headDim_;
src[i].left = ((round_id + 1) % 2) ? left_offset_round_even : left_offset_round_odd;
src[i].right = bn_right_offset_gqa + ic * ATTENTION_SCORE_BLOCK_SIZE;
src[i].out = bn_out_offset + ir * ATTENTION_SCORE_BLOCK_SIZE;
src[i].kx = Kx;
src[i].ky = Ky;
src[i].k = k;
src[i].lineStride = Kx * ATTENTION_SCORE_BLOCK_SIZE;
src[i].lowerLeft = false;
src[i].upperRight = false;
src[i].onStartSection = ic == 0 ? true : false;
src[i].onEndSection = (ic + Kx >= this->blockNumPerRow_ - 1) ? true : false;
left_offset_round_even += k * ATTENTION_SCORE_BLOCK_SIZE;
left_offset_round_odd += k * ATTENTION_SCORE_BLOCK_SIZE;
}
}
* mask场景饿cube2寻址
* @tparam T_LEFT
* @tparam T_RIGHT
* @tparam T_OUTPUT
* @param left
* @param right
* @param out
* @param addr
* @param src
* @param src_len
* @param round_id
*/
template<typename T_LEFT, typename T_RIGHT, typename T_OUTPUT>
__aicore__ __inline__ void addrMapping_triangular_cube2(__gm__ T_LEFT *__restrict__ left,
__gm__ T_RIGHT *__restrict__ right, __gm__ T_OUTPUT *__restrict__ out, const ForWardAddrOnline *addr,
PhyAddrForwardCube2Online<T_LEFT, T_RIGHT, T_OUTPUT> *src, int64_t &src_len, int64_t round_id) {
int64_t index = 0;
int64_t tri_blocks_per_column = this->blockNumPerColumn_ - 2 * this->isOdd_;
if (this->gqaGroupNum_ == 0) {
return;
}
if (this->headNum_ == 0) {
return;
}
for (int64_t i = 0; i < src_len && index < src_len; ++i) {
int64_t iR = addr[i].iR;
int64_t iC = addr[i].iC;
int64_t kx = addr[i].kx;
int64_t ky = addr[i].ky;
int64_t k = addr[i].k;
int64_t switch_index = tri_blocks_per_column + iR + (iR + 1) % 2;
int64_t row_index_left_section = tri_blocks_per_column + iR;
int64_t row_offset = (iR + 1) % 2 == 1 ? -1 : 1;
int64_t row_index_right_section = tri_blocks_per_column - 1 - iR + row_offset;
int64_t col_index_left_section = iC;
int64_t col_index_right_section = iC - switch_index - 1;
int64_t g_index = addr[i].n / (this->headNum_ / this->gqaGroupNum_);
int64_t bn_offset_gqa_right_matrix =
(addr[i].b * this->gqaGroupNum_ + g_index) * this->keyValueSequenceLen_ * this->headDim_;
int64_t bn_offset_left_matrix = ((addr[i].b * this->headNum_ + addr[i].n) * this->blockNumPerRow_ *
this->blockNumPerColumn_) * ATTENTION_SCORE_BLOCK_SIZE;
int64_t bn_offset_right_matrix = (addr[i].b * this->headNum_ + addr[i].n) *
this->keyValueSequenceLen_ * this->headDim_;
int64_t bn_offset_out = (addr[i].b * this->headNum_ + addr[i].n) * this->querySequenceLen_ *
this->headDim_;
bool sparse_flag = false;
int64_t window_block_size = this->windowLength_ / 128;
if (this->isTriangle_ && this->sparseMode_ == 1) {
sparse_flag = iR > ((window_block_size - 1) / 2) ? true : false;
switch_index = (window_block_size / 2) + iR;
row_index_left_section = (window_block_size / 2) + iR;
row_index_right_section = (window_block_size / 2) - 1 - iR;
col_index_left_section = iC;
col_index_right_section = iC - switch_index - 1;
}
int64_t row_index_sparse_section = iR + (window_block_size / 2);
int64_t col_index_sparse_section = iR + iC - (window_block_size / 2);
int64_t db_offset =
(round_id % 2) * (this->coreNum_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE);
if (index == 0) {
src[index].left = left + db_offset + this->localCoreIndex_ * this->blockNumPerCore_ *
ATTENTION_SCORE_BLOCK_SIZE;
} else {
src[index].left = src[index - 1].left + src[index - 1].k * ATTENTION_SCORE_BLOCK_SIZE;
}
if (!sparse_flag && switch_index < iC) {
if (index >= src_len) {
break;
}
src[index].right = right + bn_offset_gqa_right_matrix +
col_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].out = out + bn_offset_out + row_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].kx = kx;
src[index].ky = ky;
src[index].k = k;
src[index].upperRight = (iC + src[index].kx >= blockNumPerRow_ - 1) ? true : false;
src[index].lowerLeft = false;
src[index].lineStride = src[index].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index].onEndSection = false;
src[index].onStartSection = (iC == switch_index + 1) ? true : false;
if (src[index].upperRight) {
src[index].onEndSection = true;
}
if (iC == 0) {
src[index].onStartSection = true;
}
++index;
} else if (!sparse_flag && iC <= switch_index && iC + kx - 1 > switch_index) {
if (index + 1 >= src_len) {
break;
}
src[index].right = right + bn_offset_gqa_right_matrix +
col_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].out = out + bn_offset_out + row_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].kx = switch_index - iC + 1;
src[index].ky = ky;
src[index].k = src[index].kx * src[index].ky;
src[index].upperRight = true;
src[index].lowerLeft = false;
src[index].lineStride = src[index].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index].onStartSection = iC == 0 ? true : false;
src[index].onEndSection = true;
src[index + 1].left = src[index].left + src[index].k * ATTENTION_SCORE_BLOCK_SIZE;
src[index + 1].right = right + bn_offset_gqa_right_matrix;
src[index + 1].out = out + bn_offset_out + row_index_right_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index + 1].kx = kx - src[index].kx;
src[index + 1].ky = ky;
src[index + 1].k = src[index + 1].kx * src[index + 1].ky;
src[index + 1].upperRight = switch_index + src[index + 1].kx >= blockNumPerRow_ - 1 ? true : false;
src[index + 1].lowerLeft = false;
src[index + 1].lineStride = src[index + 1].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index + 1].onStartSection = true;
src[index + 1].onEndSection = src[index + 1].upperRight;
index += 2;
} else if (!sparse_flag && iC <= switch_index && iC + kx - 1 <= switch_index) {
if (index >= src_len) {
break;
}
src[index].right = right + bn_offset_gqa_right_matrix +
col_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].out = out + bn_offset_out + row_index_left_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].kx = kx;
src[index].ky = ky;
src[index].k = k;
src[index].lowerLeft = false;
src[index].upperRight = iC + src[index].kx - 1 >= switch_index ? true : false;
src[index].lineStride = src[index].kx * ATTENTION_SCORE_BLOCK_SIZE;
src[index].onStartSection = iC == 0 ? true : false;
src[index].onEndSection = src[index].upperRight;
index++;
} else {
if (index >= src_len) {
break;
}
src[index].right = right + bn_offset_gqa_right_matrix +
col_index_sparse_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].out = out + bn_offset_out + row_index_sparse_section * ATTENTION_SCORE_BLOCK_SIZE;
src[index].k = k;
++index;
}
}
src_len = index;
}
public:
* 类的初始化
* @param batch_size
* @param head_num
* @param query_sequence_len
* @param key_value_sequence_len
* @param head_dim
* @param gqa_group_num
* @param is_triangle
* @param sparse_mode
* @param window_length
*/
__aicore__ __inline__ void init(int64_t batch_size, int64_t head_num, int64_t query_sequence_len,
int64_t key_value_sequence_len, int64_t head_dim, int64_t gqa_group_num, bool is_triangle,
int64_t sparse_mode, int64_t window_length)
{
this->batchSize_ = batch_size;
this->headNum_ = head_num;
this->querySequenceLen_ = query_sequence_len;
this->keyValueSequenceLen_ = key_value_sequence_len;
this->headDim_ = head_dim;
this->gqaGroupNum_ = gqa_group_num;
this->isOdd_ = this->querySequenceLen_ / BASE_BLOCK_LENGTH / 2 % 2;
this->blockRows_ = this->querySequenceLen_ / BASE_BLOCK_LENGTH;
this->blockCols_ = this->keyValueSequenceLen_ / BASE_BLOCK_LENGTH;
this->isTriangle_ = is_triangle;
this->sparseMode_ = sparse_mode;
this->windowLength_ = window_length;
this->windowSize_ = this->windowLength_ / SIZE_128;
}
* 总轮次
* @return
*/
__aicore__ __inline__ int64_t get_total_rounds()
{
return this->totalRounds_;
}
* 判断此轮次是否参与计算
* @param round_id
* @return
*/
__aicore__ __inline__ bool is_running(int64_t round_id)
{
if (this->blockNumPerRow_ == 0) {
return false;
}
int64_t skip_block = this->coreNum_ * this->blockNumPerRow_ * this->ky_;
int64_t outer_row = (round_id * this->kx_) / this->blockNumPerRow_ * skip_block;
int64_t inner_row = this->localCoreIndex_ * this->ky_ * this->blockNumPerRow_;
int64_t inner_col = (round_id * this->kx_) % this->blockNumPerRow_;
int64_t cur_block_id = outer_row + inner_row + inner_col;
return (cur_block_id < this->totalBlocks_);
}
* 设置tiling信息
* @param core_num
* @param local_core_index
* @param block_num_per_core
* @param ky
*/
__aicore__ __inline__ void set_tiling(int64_t core_num, int64_t local_core_index,
int64_t block_num_per_core, int64_t ky)
{
if (this->ky_ == 0) {
return;
}
if (this->blockNumPerCore_ == 0) {
return;
}
this->coreNum_ = core_num;
this->localCoreIndex_ = local_core_index;
this->blockNumPerCore_ = block_num_per_core;
this->ky_ = ky;
this->kx_ = this->blockNumPerCore_ / this->ky_;
this->blockNumPerColumn_ = this->blockRows_;
this->blockNumPerRow_ = this->blockCols_;
if (this->isTriangle_) {
this->blockNumPerColumn_ = this->blockRows_ / 2 + this->isOdd_;
this->blockNumPerRow_ = this->blockCols_ + 2 * (1 - this->isOdd_);
if (this->sparseMode_ == 1) {
this->blockNumPerColumn_ = this->blockRows_ - this->windowSize_ / 2;
this->blockNumPerRow_ = this->windowSize_ + 1;
}
}
this->blockNumPerHead_ = this->blockNumPerColumn_ * this->blockNumPerRow_;
this->blockNumPerBatch_ = this->headNum_ * this->blockNumPerHead_;
this->totalBlocks_ = this->batchSize_ * this->blockNumPerBatch_;
this->totalLines_ = this->blockNumPerColumn_ * this->batchSize_ * this->headNum_;
int64_t segment_line_per_round = this->ky_ * this->coreNum_;
int64_t totalRounds_segment_line = (this->totalLines_ + segment_line_per_round - 1) /
segment_line_per_round;
int64_t total_block_num = totalRounds_segment_line * this->ky_ * this->blockNumPerRow_;
this->totalRounds_ = (total_block_num + this->blockNumPerCore_ - 1) / this->blockNumPerCore_;
}
* cube1的接口
* @tparam T_LEFT
* @tparam T_RIGHT
* @tparam T_OUTPUT
* @param left
* @param right
* @param out
* @param src
* @param src_len
* @param round_id
*/
template<typename T_LEFT, typename T_RIGHT, typename T_OUTPUT>
__aicore__ __inline__ void
addrMapping_cube1(__gm__ T_LEFT *__restrict__ left, __gm__ T_RIGHT *__restrict__ right,
__gm__ T_OUTPUT *__restrict__ out, PhyAddrForwardCube1Online<T_LEFT, T_RIGHT, T_OUTPUT> *src,
int64_t &src_len, int64_t round_id) {
ForWardAddrOnline forward_addr[MAX_SWITCH_TIME];
forward_addrMapping_pre(forward_addr, src_len, round_id);
if (this->isTriangle_) {
addrMapping_triangular_cube1(left, right, out, forward_addr, src, src_len, round_id);
} else {
addrMapping_rectangular_cube1(left, right, out, forward_addr, src, src_len, round_id);
}
}
* cube2的接口
* @tparam T_LEFT
* @tparam T_RIGHT
* @tparam T_OUTPUT
* @param left
* @param right
* @param out
* @param src
* @param src_len
* @param round_id
*/
template<typename T_LEFT, typename T_RIGHT, typename T_OUTPUT>
__aicore__ __inline__ void
addrMapping_cube2(__gm__ T_LEFT *__restrict__ left, __gm__ T_RIGHT *__restrict__ right,
__gm__ T_OUTPUT *__restrict__ out, PhyAddrForwardCube2Online<T_LEFT, T_RIGHT, T_OUTPUT> *src,
int64_t &src_len, int64_t round_id) {
ForWardAddrOnline forward_addr[MAX_SWITCH_TIME];
forward_addrMapping_pre(forward_addr, src_len, round_id);
if (this->isTriangle_) {
addrMapping_triangular_cube2(left, right, out, forward_addr, src, src_len, round_id);
} else {
addrMapping_rectangular_cube2(left, right, out, forward_addr, src, src_len, round_id);
}
}
};
}
#endif