/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
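/*!
* \file gen_kv_slc.cpp
* \brief Gathers selected KV-cache rows (looked up through a block table) into the k_slcOut / v_slcOut tensors.
*/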
#include "interface/operation/operation_impl.h"
#include "interface/operation/operation.h"
#include "interface/function/function.h"
#include "tilefwk/tensor.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/tensor/raw_tensor.h"
#include "interface/tensor/tensormap.h"
#include "interface/configs/config_manager.h"
#include "interface/utils/common.h"
#include "interface/utils/id_gen.h"
#include "gen_kv_slc.h"
using namespace npu::tile_fwk;
namespace npu::tile_fwk {
/**
 * \brief Emits the ops that copy the selected cache rows of every (batch, slc, kv-head) triple into
 *        k_slcOut / v_slcOut and writes the copied length into kvSlcActSeqs.
 *
 * For each of the `topk` slots a logical start row ("position") is chosen:
 *   - slot < front:                      position = slot * l_prime
 *   - slot > topk - near - front:        position counted back from s_slc (read from topK_tensor_shape)
 *   - otherwise:                         position = topK_indcies[batch, slc, slot - front] * l_prime
 * The position is split into (position / blockSize, position % blockSize); the block number is translated
 * through blockTable and `l_prime` rows are viewed from both caches starting at that physical row. The
 * views go through FP32, a multiply by 1, and a cast to the output dtype before being assembled.
 *
 * \param topK_indcies       index tensor; dims 0 and 1 give the batch and slc loop bounds, and it is read
 *                           at {batch, slc, slot - front}
 * \param topK_tensor_shape  read at {batch, slc}; the value (s_slc) anchors the trailing "near" slots
 * \param kvNopeCache        2-D cache; dim 1 is divided by n2 to get the per-head column width
 * \param kRopeCache         2-D cache; dim 1 is divided by n2 to get the per-head column width
 * \param kvActSeqs          read once per batch; the value is not used by this function
 * \param front              number of leading slots that take fixed positions
 * \param near               size term of the trailing slot window (see the review note in the body)
 * \param topk               number of slots gathered per (batch, slc, kv-head)
 * \param l_prime            rows copied per slot
 * \param n2                 number of kv heads iterated by the innermost LOOP
 * \param blockTable         read at {batch, position / blockSize} to obtain the physical block number
 * \param blockSize          rows per block in the caches
 * \param k_slcOut           receives the nope slice at column 0 and the rope slice at column kv_lora_rank
 * \param v_slcOut           receives the nope slice at column 0
 * \param kvSlcActSeqs       written at {batch, slc} with the accumulated length (topk * l_prime)
 * \param tileConfig         v0TileShape[0] / [1] are used as the vector tile shape
 * \param debug              when true, the vector tile is set to (1, 1, NUM16) before topK_indcies is read
 */
void KvSlcCompute(
    Tensor& topK_indcies, Tensor& topK_tensor_shape, Tensor& kvNopeCache, Tensor& kRopeCache, Tensor& kvActSeqs,
    int front, int near, int topk, int l_prime, int n2, Tensor& blockTable, int blockSize, Tensor& k_slcOut,
    Tensor& v_slcOut, Tensor& kvSlcActSeqs, KvSlcTileShapeConfig& tileConfig, bool debug)
{
    auto vecTile = tileConfig.v0TileShape;
    SymbolicScalar batchCount = topK_indcies.GetShape()[0];
    SymbolicScalar slcCount = topK_indcies.GetShape()[1];
    SymbolicScalar kv_lora_rank = kvNopeCache.GetShape(1) / n2;
    SymbolicScalar rope_dim = kRopeCache.GetShape(1) / n2;

    LOOP("LOOP_L0_batchIdx", FunctionType::DYNAMIC_LOOP, batchIdx, LoopRange(0, batchCount, 1), {}, true)
    {
        // The value is not consumed below. NOTE(review): the read is kept because dropping it may change
        // what the framework records for kvActSeqs - confirm before removing.
        SymbolicScalar curActSeq = GetTensorData(kvActSeqs, {batchIdx});
        LOOP("LOOP_L1_slcIdx", FunctionType::DYNAMIC_LOOP, slcIdx, LoopRange(0, slcCount, 1))
        {
            LOOP("LOOP_L2_kvSlcIdx", FunctionType::DYNAMIC_LOOP, nkvIdx, LoopRange(0, n2, 1))
            {
                TileShape::Current().SetVecTile(vecTile[0], vecTile[1]);
                SymbolicScalar s_slc = GetTensorData(topK_tensor_shape, {batchIdx, slcIdx});
                SymbolicScalar positions = 0;
                SymbolicScalar prime_value = l_prime;
                SymbolicScalar slcSeqLen = 0;

                // First output row of this (batch, slc, kv-head) triple; the per-slot offset is added below.
                SymbolicScalar headRowBase = batchIdx * slcCount * n2 * topk * l_prime +
                                             slcIdx * n2 * topk * l_prime + nkvIdx * topk * l_prime;

                for (int topKIdx = 0; topKIdx < topk; topKIdx++) {
                    if (topKIdx < front) {
                        positions = topKIdx * l_prime;
                    } else if (topKIdx > (topk - near - front)) {
                        // NOTE(review): this branch covers (near + front - 1) slots and its last slot maps to
                        // block (s_slc + front - 2). That is `near` slots ending at (s_slc - 1) only when
                        // front == 1 - confirm no caller passes another value.
                        positions = (s_slc - near + (topKIdx - (topk - front - near)) - 1) * l_prime;
                    } else {
                        if (debug) {
                            TileShape::Current().SetVecTile(1, 1, NUM16);
                        }
                        SymbolicScalar topk_index =
                            GetTensorData(topK_indcies, {batchIdx, slcIdx, topKIdx - front});
                        positions = topk_index * prime_value;
                    }
                    slcSeqLen = slcSeqLen + prime_value;

                    // Translate the logical position into a physical cache row through the block table.
                    SymbolicScalar blockIdxInBatch = positions / blockSize;
                    SymbolicScalar tail = positions % blockSize;
                    SymbolicScalar slcBlockIdx = GetTensorData(blockTable, {batchIdx, blockIdxInBatch});
                    SymbolicScalar cacheRow = slcBlockIdx * blockSize + tail;

                    // Restores the tile shape that the debug branch above may have replaced.
                    TileShape::Current().SetVecTile(vecTile[0], vecTile[1]);
                    auto kv_slcBlock = View(kvNopeCache, {l_prime, kv_lora_rank}, {cacheRow, nkvIdx * kv_lora_rank});
                    auto kRope_slcBlock = View(kRopeCache, {l_prime, rope_dim}, {cacheRow, nkvIdx * rope_dim});

                    TileShape::Current().SetVecTile(vecTile[0], vecTile[1]);
                    auto kv_slcBlock_fp32 = Cast(kv_slcBlock, DataType::DT_FP32);
                    auto kRope_slcBlock_fp32 = Cast(kRope_slcBlock, DataType::DT_FP32);

                    // Multiplying by 1 leaves the values unchanged.
                    // NOTE(review): presumably here to force a computed tile rather than a pure view - confirm.
                    TileShape::Current().SetVecTile(vecTile[0], vecTile[1]);
                    auto kv_slcBlock_tiled =
                        Mul(kv_slcBlock_fp32, Element(kv_slcBlock_fp32.GetStorage()->Datatype(), float(1)));
                    auto kRope_slcBlock_tiled =
                        Mul(kRope_slcBlock_fp32, Element(kRope_slcBlock_fp32.GetStorage()->Datatype(), float(1)));

                    TileShape::Current().SetVecTile(vecTile[0], vecTile[1]);
                    auto kv_slcBlock_out = Cast(kv_slcBlock_tiled, k_slcOut.GetStorage()->Datatype());
                    // The rope slice is written only into k_slcOut, so it is cast to k_slcOut's dtype
                    // (it was previously cast to v_slcOut's dtype).
                    auto kRope_slcBlock_out = Cast(kRope_slcBlock_tiled, k_slcOut.GetStorage()->Datatype());

                    TileShape::Current().SetVecTile(vecTile[0], vecTile[1]);
                    SymbolicScalar output_axis1_value = headRowBase + topKIdx * l_prime;
                    Assemble(kv_slcBlock_out, {output_axis1_value, 0}, k_slcOut);
                    Assemble(kRope_slcBlock_out, {output_axis1_value, kv_lora_rank}, k_slcOut);
                    // NOTE(review): the nope slice carries k_slcOut's dtype here as well; this assumes
                    // k_slcOut and v_slcOut share a dtype - confirm.
                    Assemble(kv_slcBlock_out, {output_axis1_value, 0}, v_slcOut);
                }
                SetTensorData(slcSeqLen, {batchIdx, slcIdx}, kvSlcActSeqs);
            }
        }
    }
}
/**
 * \brief Entry point that wraps KvSlcCompute in a FUNCTION scope named "main_slc".
 *
 * The FUNCTION scope lists topK_indcies, topK_tensor_shape, kvNopeCache, kRopeCache, kvActSeqs and
 * blockTable as its first tensor group and k_slcOut, v_slcOut and kvSlcActSeqs as its second
 * (presumably inputs and outputs respectively - confirm against the FUNCTION macro definition).
 * All arguments are forwarded unchanged to KvSlcCompute.
 *
 * \param topK_indcies       forwarded to KvSlcCompute
 * \param topK_tensor_shape  forwarded to KvSlcCompute
 * \param kvNopeCache        forwarded to KvSlcCompute
 * \param kRopeCache         forwarded to KvSlcCompute
 * \param kvActSeqs          forwarded to KvSlcCompute
 * \param front              forwarded to KvSlcCompute
 * \param near               forwarded to KvSlcCompute
 * \param topk               forwarded to KvSlcCompute
 * \param l_prime            forwarded to KvSlcCompute
 * \param n2                 forwarded to KvSlcCompute
 * \param blockTable         forwarded to KvSlcCompute
 * \param blockSize          forwarded to KvSlcCompute
 * \param k_slcOut           output tensor, forwarded to KvSlcCompute
 * \param v_slcOut           output tensor, forwarded to KvSlcCompute
 * \param kvSlcActSeqs       output tensor, forwarded to KvSlcCompute
 * \param tileConfig         forwarded to KvSlcCompute
 */
void GenKvSlc(
Tensor& topK_indcies, Tensor& topK_tensor_shape, Tensor& kvNopeCache, Tensor& kRopeCache, Tensor& kvActSeqs,
int front, int near, int topk, int l_prime, int n2, Tensor& blockTable, int blockSize, Tensor& k_slcOut,
Tensor& v_slcOut, Tensor& kvSlcActSeqs, KvSlcTileShapeConfig& tileConfig)
{
FUNCTION(
"main_slc", {topK_indcies, topK_tensor_shape, kvNopeCache, kRopeCache, kvActSeqs, blockTable},
{k_slcOut, v_slcOut, kvSlcActSeqs})
{
// The trailing `debug` parameter of KvSlcCompute is not passed here.
// NOTE(review): this relies on a default argument declared outside this file (presumably in
// gen_kv_slc.h) - confirm its value.
KvSlcCompute(
topK_indcies, topK_tensor_shape, kvNopeCache, kRopeCache, kvActSeqs, front, near, topk, l_prime, n2,
blockTable, blockSize, k_slcOut, v_slcOut, kvSlcActSeqs, tileConfig);
}
}
}