/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
/*
* This header implements the address-mapping module for the forward vector_online pass.
*/
#ifndef ADDRESS_MODULE_ADDRESSMAPPING_VECTOR_FORWARD_ONLINE_H
#define ADDRESS_MODULE_ADDRESSMAPPING_VECTOR_FORWARD_ONLINE_H
#include <cstdint>
#include "address_const.h"
namespace Address {
class AddressMappingVectorForwardOnline {
public:
// ---- Problem shape, copied from the arguments of init() ----
int64_t batchSize_;
int64_t headNum_;
int64_t querySequenceLen_;
int64_t keyValueSequenceLen_;
int64_t maskSequenceLen_;
// ---- Core / vector identity, written by set_core_info() ----
int64_t coreNum_;
// Always coreNum_ * 2 (see set_core_info()).
int64_t vectorNum_;
int64_t coreIndex_;
int64_t vectorIndex_;
// Selects addrMapping_mask() (true) or addrMapping_nomask() (false) in get_section_info().
bool isTriangle_;
// Stored by init(); not read by any method of this class.
int64_t sparseMode_;
// Stored by init(); not read by any method of this class.
int64_t windowSize_;
// querySequenceLen_ / BASE_BLOCK_LENGTH / 2 % 2, computed in init().
int64_t isOdd_;
// ---- Block layout, derived in set_global_info() ----
int64_t blockNumPerCol_;
int64_t blockNumPerRow_;
// blockNumPerCol_ * blockNumPerRow_.
int64_t blockNumPerHead_;
// blockNumPerHead_ * headNum_.
int64_t blockNumPerBatch_;
// Same value as blockNumPerCol_.
int64_t blockRowsPerHead_;
// blockRowsPerHead_ * headNum_.
int64_t blockRowsPerBatch_;
// blockRowsPerBatch_ * batchSize_.
int64_t totalRows_;
// blockNumPerBatch_ * batchSize_.
int64_t totalBlocks_;
// Number of rounds, returned by get_total_round().
int64_t totalRounds_;
// ---- Per-round tiling, set by init() ----
int64_t blockNumPerCore_;
// block_num_per_core / ky.
int64_t kx_;
int64_t ky_;
// SIZE_128 / 2, set in set_local_info().
int64_t processLineNum_;
// ---- Base offsets, computed in set_init_offset() ----
int64_t coreOffset_;
int64_t startLine_;
int64_t startLineOffset_;
// coreOffset_ + startLineOffset_.
int64_t vectorStartOffset_;
// startLine_ * maskSequenceLen_.
int64_t maskOffset_;
// ---- Normalization split, computed in set_local_info() ----
// batchSize_ * headNum_ * querySequenceLen_.
int64_t totalLines_;
// Lines assigned to this vector; returned by get_norm_process_lines().
int64_t normProcessLine_;
// First line assigned to this vector; returned by get_norm_offset().
int64_t normLineOffset_;
public:
* 预处理:提前分好段: vector分段略微和cube不同,遇到跳变点也会提前切分
* @param addr
* @param addr_len
* @param round_id
*/
__aicore__ __inline__ void
forward_addrMapping_pre(ForWardAddrOnline *addr, int64_t &addr_len, int64_t round_id)
{
if (this->coreNum_ == 0) {
return;
}
if (this->ky_ == 0) {
return;
}
if (this->blockNumPerRow_ == 0) {
return;
}
if (this->blockNumPerCol_ == 0) {
return;
}
if (this->blockNumPerHead_ == 0) {
return;
}
if (this->blockNumPerBatch_ == 0) {
return;
}
if (this->headNum_ == 0) {
return;
}
int64_t skip_block = this->coreNum_ * this->blockNumPerRow_ * this->ky_;
int64_t outer_row = (round_id * this->kx_) / this->blockNumPerRow_ * skip_block;
int64_t inner_row = this->coreIndex_ * this->ky_ * this->blockNumPerRow_;
int64_t inner_col = (round_id * this->kx_) % this->blockNumPerRow_;
int64_t cur_block_id = outer_row + inner_row + inner_col;
int64_t row_num_per_round = this->ky_;
int64_t col_num_per_round = this->kx_;
int64_t cur_core_total_blocks = this->blockNumPerRow_ * this->ky_ *
(this->totalRows_ / this->ky_ / this->coreNum_);
int64_t remain_block_num = (this->totalRows_ % (this->coreNum_ * this->ky_)) / this->ky_;
if (this->coreIndex_ < remain_block_num) {
cur_core_total_blocks += this->ky_ * this->blockNumPerRow_;
}
int64_t remain = this->blockNumPerCore_;
if ((round_id + 1) * blockNumPerCore_ > cur_core_total_blocks) {
remain = this->blockNumPerCore_ -
((round_id + 1) * this->blockNumPerCore_ - cur_core_total_blocks);
}
int64_t Ky = this->ky_;
int64_t Kx = remain / Ky;
int64_t b = cur_block_id / this->blockNumPerBatch_;
int64_t n = cur_block_id % this->blockNumPerBatch_ / this->blockNumPerHead_;
int64_t ir = (outer_row / this->blockNumPerRow_ + this->coreIndex_ * this->ky_) %
this->blockNumPerCol_;
int64_t ic = inner_col;
addr[0].b = b;
addr[0].n = n;
addr[0].iR = ir;
addr[0].iC = ic;
addr[0].kx = Kx;
addr[0].ky = Ky;
addr[0].k = remain;
int index = 0;
for (; remain > 0;) {
int64_t switch_index = addr[index].iR + this->blockNumPerCol_ + (addr[index].iR + 1) % 2;
switch_index = this->isTriangle_ ? switch_index - 2 * this->isOdd_:switch_index;
if (this->isTriangle_ && (addr[index].iC <= switch_index) &&
(addr[index].iC + addr[index].kx - 1 > switch_index)) {
addr[index].kx = switch_index - addr[index].iC + 1;
addr[index].k = addr[index].kx * addr[index].ky;
addr[index + 1].b = addr[index].b;
addr[index + 1].n = addr[index].n;
addr[index + 1].iR = addr[index].iR;
addr[index + 1].iC = switch_index + 1;
addr[index + 1].k = remain - addr[index].k;
addr[index + 1].ky = addr[index].ky;
addr[index + 1].kx = addr[index + 1].k / addr[index + 1].ky;
}
if (addr[index].iC + addr[index].kx > this->blockNumPerRow_) {
addr[index].kx = this->blockNumPerRow_ - addr[index].iC;
addr[index].k = addr[index].kx * addr[index].ky;
addr[index + 1].b = addr[index].b;
addr[index + 1].n = addr[index].n;
addr[index + 1].iR = addr[index].iR + addr[index].ky * coreNum_;
addr[index + 1].iC = 0;
addr[index + 1].k = remain - addr[index].k;
addr[index + 1].ky = addr[index].ky;
addr[index + 1].kx = addr[index + 1].k / addr[index + 1].ky;
if (addr[index + 1].iR >= this->blockNumPerCol_) {
int64_t skip_head = addr[index + 1].iR / this->blockNumPerCol_;
addr[index + 1].n = addr[index].n + skip_head;
addr[index + 1].iR = addr[index + 1].iR % this->blockNumPerCol_;
int64_t skip_batch = addr[index + 1].n / this->headNum_;
if (addr[index + 1].n >= this->headNum_) {
addr[index + 1].b = addr[index].b + skip_batch;
addr[index + 1].n = addr[index + 1].n % this->headNum_;
}
}
}
remain -= addr[index].k;
++index;
}
int64_t pos = 0;
for (size_t i = 0; i < index; ++i) {
if (addr[i].k == 0) {
continue;
}
addr[pos++] = addr[i];
}
addr_len = pos;
}
* 设置全局的信息:轮次、总块数等
* @return
*/
__aicore__ __inline__ void set_global_info()
{
this->blockNumPerCol_ = this->querySequenceLen_ / SIZE_128;
this->blockNumPerRow_ = this->keyValueSequenceLen_ / SIZE_128;
if (this->isTriangle_) {
this->blockNumPerCol_ = this->querySequenceLen_ / SIZE_128 / 2 + this->isOdd_;
this->blockNumPerRow_ = this->keyValueSequenceLen_ / SIZE_128 + 2 * (1 - this->isOdd_);
}
this->blockNumPerHead_ = this->blockNumPerCol_ * this->blockNumPerRow_;
this->blockNumPerBatch_ = this->blockNumPerHead_ * this->headNum_;
this->totalBlocks_ = this->blockNumPerBatch_ * this->batchSize_;
this->blockRowsPerHead_ = this->blockNumPerCol_;
this->blockRowsPerBatch_ = this->blockRowsPerHead_ * this->headNum_;
this->totalRows_ = this->blockRowsPerBatch_ * this->batchSize_;
int64_t segment_line_per_round = this->ky_ * this->coreNum_;
int64_t totalRounds_segment_line = (this->totalRows_ + segment_line_per_round - 1) /
segment_line_per_round;
int64_t total_block_num = totalRounds_segment_line * this->ky_ * this->blockNumPerRow_;
this->totalRounds_ = (total_block_num + this->blockNumPerCore_ - 1) /
this->blockNumPerCore_;
}
* 设置vector本地的信息,当前vector的序号等
* @param vector_index
* @return
*/
__aicore__ __inline__ void set_local_info()
{
if (this->vectorNum_ == 0) {
return;
}
this->processLineNum_ = SIZE_128 / 2;
this->totalLines_ = this->batchSize_ * this->headNum_ * this->querySequenceLen_;
int64_t vector_id = this->coreIndex_ * 2 + this->vectorIndex_;
this->normProcessLine_ = this->totalLines_ / this->vectorNum_;
this->normLineOffset_ = this->normProcessLine_ * vector_id;
int64_t rows_remain = this->totalLines_ % this->vectorNum_;
if (rows_remain > 0 && vector_id < rows_remain) {
this->normProcessLine_ += 1;
}
this->normLineOffset_ += vector_id < rows_remain ? vector_id : rows_remain;
}
* 基本偏移量的设置
*/
__aicore__ __inline__ void set_init_offset()
{
this->coreOffset_ = this->coreIndex_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
this->startLine_ = this->processLineNum_ * this->vectorIndex_;
this->startLineOffset_ = this->startLine_ * SIZE_128;
this->vectorStartOffset_ = this->coreOffset_ + this->startLineOffset_;
this->maskOffset_ = this->startLine_ * this->maskSequenceLen_;
}
* nomask场景的偏移量设置
* @param round_id
* @param section
*/
__aicore__ __inline__ void
addrMapping_nomask(const ForWardAddrOnline *addr, int64_t &src_len, int64_t round_id,
FORWARD_SECTION_INFO §ion)
{
int64_t diag_out_offset = this->coreIndex_ * this->ky_ * ATTENTION_SCORE_BLOCK_SIZE * MAX_SWITCH_TIME *2;
for (int64_t i = 0; i < src_len; ++i) {
int64_t b = addr[i].b;
int64_t n = addr[i].n;
int64_t ir = addr[i].iR;
int64_t ic = addr[i].iC;
int64_t Kx = addr[i].kx;
int64_t row_max_bn_offset = (b * this->headNum_ + n) * this->querySequenceLen_;
int64_t row_max_inner_offset = ir * SIZE_128 + SIZE_128 / 2 * this->vectorIndex_;
section.rowmaxOffset[2 * i] = row_max_bn_offset + row_max_inner_offset;
section.rowmaxOffset[2 * i + 1] = section.rowmaxOffset[2 * i] + SIZE_128;
section.sectionBlockNums[2 * i] = Kx;
section.sectionBlockOffset[2 * i] =
i == 0 ? this->vectorStartOffset_ : section.sectionBlockOffset[2 * (i - 1) + 1] +
section.sectionBlockNums[2 * (i - 1) + 1] *
ATTENTION_SCORE_BLOCK_SIZE;
section.sectionBlockNums[2 * i + 1] = Kx;
section.sectionBlockOffset[2 * i + 1] = section.sectionBlockOffset[2 * i] +
section.sectionBlockNums[2 * i] *
ATTENTION_SCORE_BLOCK_SIZE;
section.diagOffset[2 * i] = diag_out_offset +
((round_id % 2) * MAX_SWITCH_TIME+ i) * 2 * ATTENTION_SCORE_BLOCK_SIZE +
this->vectorIndex_ * ATTENTION_SCORE_BLOCK_SIZE / 2;
section.diagOffset[2 * i + 1] = section.diagOffset[2 * i] + ATTENTION_SCORE_BLOCK_SIZE;
section.isHeadSection[2 * i] = ic == 0 ? true : false;
section.isTailSection[2 * i] = (ic + Kx >= this->blockNumPerRow_ - 1) ? true : false;
section.isHeadSection[2 * i + 1] = ic == 0 ? true : false;
section.isTailSection[2 * i + 1] = (ic + Kx >= this->blockNumPerRow_ - 1) ? true : false;
}
section.sectionNum = src_len * 2;
section.maskNum = 0;
section.matrixMaskOffset = this->maskOffset_;
section.processLineNum = this->processLineNum_;
section.sparseFlag = false;
section.isTriangle = false;
section.attentionScoreOffset = (round_id % 2) *
this->coreNum_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
}
* mask场景的偏移量设置
* @param round_id
* @param section
*/
__aicore__ __inline__ void addrMapping_mask(const ForWardAddrOnline *addr,
int64_t &src_len, int64_t round_id, FORWARD_SECTION_INFO §ion)
{
int64_t index = 0;
int64_t diag_out_offset = this->coreIndex_ * this->ky_ * ATTENTION_SCORE_BLOCK_SIZE * MAX_SWITCH_TIME*2;
int64_t tri_block_num_per_column = this->blockNumPerCol_ - 2 * this->isOdd_;
if (this->vectorIndex_ == 0) {
return;
}
for (int64_t i = 0; i < src_len; ++i) {
int64_t b = addr[i].b;
int64_t n = addr[i].n;
int64_t i_r = addr[i].iR;
int64_t i_c = addr[i].iC;
int64_t kx = addr[i].kx;
int64_t ky = addr[i].ky;
int64_t k = addr[i].k;
int64_t switch_index = tri_block_num_per_column + i_r + (i_r + 1) % 2;
int64_t row_offset = (i_r + 1) % 2 == 1 ? -1 : 1;
int64_t row_index_left_section = tri_block_num_per_column + i_r;
int64_t row_index_right_section = tri_block_num_per_column - 1 - i_r + row_offset;
int64_t col_index_left_section = i_c;
int64_t col_index_right_section = i_c - switch_index - 1;
int64_t row_max_bn_offset = (b * this->headNum_ + n) * this->querySequenceLen_;
if (switch_index < i_c) {
int64_t row_max_inner_offset =
row_index_right_section * SIZE_128 + SIZE_128 / 2 * this->vectorIndex_;
section.rowmaxOffset[index] = row_max_bn_offset + row_max_inner_offset;
section.rowmaxOffset[index + 1] = section.rowmaxOffset[index] + SIZE_128;
section.sectionBlockNums[index] = (i_c + kx >= this->blockNumPerRow_ - 1) ? kx - 1 : kx;
section.sectionBlockOffset[index] =
index == 0 ? this->vectorStartOffset_ : section.sectionBlockOffset[index - 1] +
section.sectionBlockNums[index - 1] *
ATTENTION_SCORE_BLOCK_SIZE;
section.sectionBlockNums[index + 1] = kx;
section.sectionBlockOffset[index + 1] =
section.sectionBlockOffset[index] + kx * ATTENTION_SCORE_BLOCK_SIZE;
section.isHeadSection[index] = (i_c == switch_index + 1) ? true : false;
section.isHeadSection[index + 1] = section.isHeadSection[index];
section.isTailSection[index] = (i_c + kx >= this->blockNumPerRow_ - 1) ? true : false;
section.isTailSection[index + 1] = section.isTailSection[index];
section.diagOffset[index] = diag_out_offset +
((round_id % 2) * MAX_SWITCH_TIME * 2 + index) * ATTENTION_SCORE_BLOCK_SIZE +
this->vectorIndex_ * ATTENTION_SCORE_BLOCK_SIZE / 2;
section.diagOffset[index + 1] = section.diagOffset[index] + ATTENTION_SCORE_BLOCK_SIZE;
index += 2;
} else {
int64_t row_max_inner_offset =
row_index_left_section * SIZE_128 + SIZE_128 / 2 * this->vectorIndex_;
section.rowmaxOffset[index] = row_max_bn_offset + row_max_inner_offset;
section.rowmaxOffset[index + 1] = section.rowmaxOffset[index] + SIZE_128;
section.sectionBlockNums[index] = (i_c + kx >= switch_index) ? kx - 1 : kx;
section.sectionBlockNums[index + 1] = kx;
section.sectionBlockOffset[index] =
index == 0 ? this->vectorStartOffset_ : section.sectionBlockOffset[index - 1] +
section.sectionBlockNums[index - 1] *
ATTENTION_SCORE_BLOCK_SIZE;
section.sectionBlockOffset[index + 1] =
section.sectionBlockOffset[index] + kx * ATTENTION_SCORE_BLOCK_SIZE;
section.isHeadSection[index] = (i_c == 0) ? true : false;
section.isHeadSection[index + 1] = section.isHeadSection[index];
section.isTailSection[index] = (i_c + kx - 1 >= switch_index) ? true : false;
section.isTailSection[index + 1] = section.isTailSection[index];
section.diagOffset[index] = diag_out_offset +
((round_id % 2) * MAX_SWITCH_TIME * 2 + index) * ATTENTION_SCORE_BLOCK_SIZE +
this->vectorIndex_ * ATTENTION_SCORE_BLOCK_SIZE / 2;
section.diagOffset[index + 1] = section.diagOffset[index] + ATTENTION_SCORE_BLOCK_SIZE;
index += 2;
}
}
int64_t pos = 0;
for (int64_t i = 0; i < index; ++i) {
if (section.sectionBlockNums[i] == 0) {
continue;
}
section.sectionBlockNums[pos] = section.sectionBlockNums[i];
section.sectionBlockOffset[pos] = section.sectionBlockOffset[i];
section.rowmaxOffset[pos] = section.rowmaxOffset[i];
section.isHeadSection[pos] = section.isHeadSection[i];
section.isTailSection[pos] = section.isTailSection[i];
++pos;
}
section.sectionNum = pos;
section.maskNum = pos;
section.matrixMaskOffset = this->maskOffset_;
section.isTriangle = true;
section.sparseFlag = false;
section.processLineNum = this->processLineNum_;
section.attentionScoreOffset =
(round_id % 2) * this->coreNum_ * this->blockNumPerCore_ * ATTENTION_SCORE_BLOCK_SIZE;
}
public:
* 类的初始化
* @param batch_size
* @param head_num
* @param query_sequence_len
* @param key_value_sequence_len
* @param mask_sequence_len
* @param is_triangle
* @param window_size
* @param sparse_mode
* @param block_num_per_core
* @param ky
*/
__aicore__ __inline__ void init(int64_t batch_size, int64_t head_num, int64_t query_sequence_len,
int64_t key_value_sequence_len, int64_t mask_sequence_len,
bool is_triangle, int64_t window_size, int64_t sparse_mode,
int64_t block_num_per_core, int64_t ky)
{
this->batchSize_ = batch_size;
this->headNum_ = head_num;
this->querySequenceLen_ = query_sequence_len;
this->keyValueSequenceLen_ = key_value_sequence_len;
this->maskSequenceLen_ = mask_sequence_len;
this->isTriangle_ = is_triangle;
this->windowSize_ = window_size;
this->sparseMode_ = sparse_mode;
this->blockNumPerCore_ = block_num_per_core;
this->ky_ = ky;
if (ky == 0) {
return;
}
this->kx_ = (ky != 0) ? (block_num_per_core / ky) : 0;
this->isOdd_ = this->querySequenceLen_ / BASE_BLOCK_LENGTH / 2 % 2;
}
* 设置核组信息
* @param core_num
* @param cur_core_index
* @param vector_index
*/
__aicore__ __inline__ void set_core_info(int64_t core_num, int64_t cur_core_index, int64_t vector_index)
{
this->coreNum_ = core_num;
this->coreIndex_ = cur_core_index;
this->vectorNum_ = this->coreNum_ * 2;
this->vectorIndex_ = vector_index;
}
* 前向vector寻址启动,计算一些基本的偏移量
* @return
*/
__aicore__ __inline__ void start()
{
set_global_info();
set_local_info();
set_init_offset();
}
* 总轮次
* @return
*/
__aicore__ __inline__ int64_t get_total_round()
{
return this->totalRounds_;
}
* 归一化时需要处理的行数
* @return
*/
__aicore__ __inline__ int64_t get_norm_process_lines()
{
return this->normProcessLine_;
}
* 归一化的偏移量
* @return
*/
__aicore__ __inline__ int64_t get_norm_offset()
{
return this->normLineOffset_;
}
* 判断当前轮次、当前核是否要计算
* @param round_id
* @return
*/
__aicore__ __inline__ bool is_running(int64_t round_id)
{
if (this->blockNumPerRow_ == 0) {
return false;
}
int64_t skip_block = this->coreNum_ * this->blockNumPerRow_ * this->ky_;
int64_t outer_row = (round_id * this->kx_) / this->blockNumPerRow_ * skip_block;
int64_t inner_row = this->coreIndex_ * this->ky_ * this->blockNumPerRow_;
int64_t inner_col = (round_id * this->kx_) % this->blockNumPerRow_;
int64_t cur_block_id = outer_row + inner_row + inner_col;
return (cur_block_id < this->totalBlocks_);
}
* 获取当前轮次的section信息
* @param round_id
* @param section
*/
__aicore__ __inline__ void get_section_info(int64_t round_id, FORWARD_SECTION_INFO §ion)
{
int64_t src_len = 0;
ForWardAddrOnline forward_addr[MAX_SWITCH_TIME];
forward_addrMapping_pre(forward_addr, src_len, round_id);
if (this->isTriangle_) {
return addrMapping_mask(forward_addr, src_len, round_id, section);
}
return addrMapping_nomask(forward_addr, src_len, round_id, section);
}
};
}
#endif