* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file dynamic_nsa.cpp
* \brief
*/
#include "operator/models/nsa/dynamic_nsa_v1.h"
namespace npu::tile_fwk {
void GenGatedScoreCompute(
const Tensor& x, const Tensor& gateW1, const Tensor& gateW2, const Tensor& gateSimW1, Tensor& gatingScore,
GateMode gateMode)
{
FUNCTION("FusedCompressKvSelect", {x, gateW1, gateW2, gateSimW1}, {gatingScore})
{
GenGatedScore(x, gateW1, gateW2, gateSimW1, gatingScore, gateMode);
}
}
void GenGatedScore(
const Tensor& x, const Tensor& gateW1, const Tensor& gateW2, const Tensor& gateSimW1, Tensor& gatingScore,
GateMode gateMode)
{
(void)gateSimW1;
(void)gateMode;
DataType dType = x.GetStorage()->Datatype();
int b = x.GetShape()[0];
int s = x.GetShape()[1];
int h = x.GetShape()[2];
int n1 = gateW2.GetShape()[1] / 3;
int tileB = b;
int tileS = s;
int tileBS = tileB * tileS;
SymbolicScalar bLoop = b / tileB;
SymbolicScalar sLoop = s / tileS;
LOOP("LOOP_L0_bIdx_gated_score", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bLoop, 1))
{
LOOP("LOOP_L0_sIdx_gated_score", FunctionType::DYNAMIC_LOOP, sIdx, LoopRange(0, sLoop, 1))
{
TileShape::Current().SetVecTile({tileB, tileS, h});
TileShape::Current().SetCubeTile({tileBS, tileBS}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
SymbolicScalar bOfs = bIdx * tileB;
SymbolicScalar sOfs = sIdx * tileS;
SymbolicScalar bsOfs = bOfs * sOfs;
auto xReshape = Reshape(x, {b * s, h});
auto xView = View(xReshape, {tileBS, h}, {bsOfs, 0});
auto mm1Res = Matrix::Matmul(DT_FP32, xReshape, gateW1);
TileShape::Current().SetVecTile({1, h});
auto sigmoidRes = Sigmoid(mm1Res);
sigmoidRes = Cast(sigmoidRes, dType);
TileShape::Current().SetCubeTile({tileBS, tileBS}, {NUM_128, NUM_128}, {NUM_16, NUM_16});
auto mm2Res = Matrix::Matmul(DT_FP32, sigmoidRes, gateW2);
TileShape::Current().SetVecTile({tileBS, n1});
auto res = Reshape(mm2Res, {tileB, tileS, 3, n1});
TileShape::Current().SetVecTile({1, tileS, 3, n1});
res = Transpose(res, {2, 3});
if (gatingScore.GetStorage()->Datatype() != DT_FP32) {
res = Cast(res, dType);
}
Assemble(res, {bOfs, sIdx, 0, 0}, gatingScore);
}
}
}
void GenAttn(Tensor& gatingScore, Tensor& cmpAtten, Tensor& selAtten, Tensor& winAtten, Tensor& attentionOut)
{
int nDimSize = cmpAtten.GetShape()[2];
int vDimSize = cmpAtten.GetShape()[3];
int tileB = 8;
int tileS = 1;
SymbolicScalar bDimSize = GetInputShape(cmpAtten, 0);
SymbolicScalar sDimSize = GetInputShape(cmpAtten, 1);
SymbolicScalar bLoop = bDimSize / tileB;
SymbolicScalar sLoop = sDimSize / tileS;
DataType dType = attentionOut.GetStorage()->Datatype();
LOOP("LOOP_L0_bIdx_gen_attn", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(bLoop), {}, true)
{
SymbolicScalar bOffset = bIdx * tileB;
SymbolicScalar actualBSize = std::min(tileB, (bDimSize - bIdx * tileB));
LOOP("LOOP_L1_sIdx_gen_attn", FunctionType::DYNAMIC_LOOP, sIdx, LoopRange(sLoop))
{
SymbolicScalar sOffset = sIdx * tileS;
std::vector<SymbolicScalar> outOffset = {bOffset, sOffset, 0, 0};
SymbolicScalar actualsSize = std::min(tileS, (sDimSize - sIdx * tileS));
TileShape::Current().SetVecTile(1, 1, NUM_16, vDimSize);
auto cmpAttenTile = View(
cmpAtten, {tileB, tileS, nDimSize, vDimSize}, {actualBSize, actualsSize, nDimSize, vDimSize},
{bOffset, sOffset, 0, 0});
auto selAttenTile = View(
selAtten, {tileB, tileS, nDimSize, vDimSize}, {actualBSize, actualsSize, nDimSize, vDimSize},
{bOffset, sOffset, 0, 0});
auto winAttenTile = View(
winAtten, {tileB, tileS, nDimSize, vDimSize}, {actualBSize, actualsSize, nDimSize, vDimSize},
{bOffset, sOffset, 0, 0});
auto cmpAttenFP32Tile = Cast(cmpAttenTile, DT_FP32);
auto selAttenFP32Tile = Cast(selAttenTile, DT_FP32);
auto winAttenFP32Tile = Cast(winAttenTile, DT_FP32);
TileShape::Current().SetVecTile(1, 1, nDimSize, NUM_3);
auto gatingScoreTile = View(
gatingScore, {tileB, tileS, nDimSize, NUM_3}, {actualBSize, actualsSize, nDimSize, NUM_3},
{bOffset, sOffset, 0, 0});
auto gatingScoreFP32 = Cast(gatingScoreTile, DT_FP32);
auto cmpWeight = View(gatingScoreFP32, {tileB, tileS, nDimSize, 1}, {0, 0, 0, 0});
auto selWeight = View(gatingScoreFP32, {tileB, tileS, nDimSize, 1}, {0, 0, 0, 1});
auto winWeight = View(gatingScoreFP32, {tileB, tileS, nDimSize, 1}, {0, 0, 0, 2});
TileShape::Current().SetVecTile(1, 1, NUM_16, vDimSize);
auto mulCmp = Mul(cmpAttenFP32Tile, cmpWeight);
auto mulSel = Mul(selAttenFP32Tile, selWeight);
auto mulWin = Mul(winAttenFP32Tile, winWeight);
auto addCmpSel = Add(mulCmp, mulSel);
auto outFP32 = Add(addCmpSel, mulWin);
TileShape::Current().SetVecTile(1, 1, NUM_16, vDimSize);
auto attentionOutTile = Cast(outFP32, dType, CAST_RINT);
Assemble(attentionOutTile, outOffset, attentionOut);
}
}
}
void DynamicNsa(
const Tensor& x, const Tensor& wDq, const Tensor& wUqQr, const Tensor& wUk, const Tensor& wDkvKr,
const Tensor& gammaCq, const Tensor& gammaCkv, const Tensor& sin, const Tensor& cos, const Tensor& cacheIndex,
Tensor& kvCache, Tensor& krCache, const MlaQuantInputs& quantInputs, const MlaTileConfig& mlaTileConfig,
float epsilonCq, float epsilonCkv, std::string cacheMode,
Tensor& topkIndices, Tensor& kvActSeqs, Tensor& blockTable, int front, int near, int topk, int slcBlockSize,
int blockSize, float softmaxScale, SATileShapeConfig saTileConfig, const Tensor& gateW1, const Tensor& gateW2,
const Tensor& gateSimW1, GateMode gateMode, Tensor& cmpAtten, int winSize,
WinAttenTileShapeConfig& winAttntileConfig,
PostTensors& postTensors, const PostTileConfig& postConfig,
Tensor& kvCacheOut, Tensor& krCacheOut, Tensor& postOut, const Tensor& cmpKvCache, const Tensor& cmpKrCache,
const Tensor& cmpBlockTable, const Tensor& actSeqLen, const Tensor& actCmpSeqLen, const Tensor& mlpWk1,
const Tensor& mlpWk2, const Tensor& mlpCos, const Tensor& mlpSin, Tensor& cmpAttnOut, Tensor& cmpSoftmax,
Tensor& fullK, Tensor& cmpK, Tensor& firstRope, Tensor& firstRopeInput, Tensor& topkRes, Tensor& topkInput,
const int cmpBlockSize, const int cmpStride, CmpAttnTile& tileConfig_v2, bool debug)
{
ASSERT(gateMode == standard);
FUNCTION(
"main",
{
x,
wDq,
wUqQr,
wUk,
wDkvKr,
gammaCq,
gammaCkv,
sin,
cos,
cacheIndex,
kvCache,
krCache,
cmpKvCache,
cmpKrCache,
blockTable,
cmpBlockTable,
actSeqLen,
actCmpSeqLen,
mlpWk1,
mlpWk2,
mlpCos,
mlpSin,
quantInputs.dequantScaleWUqQr,
quantInputs.smoothScalesCq,
topkIndices,
kvActSeqs,
gateW1,
gateW2,
gateSimW1,
cmpAtten,
postTensors.weightUV,
postTensors.weightO,
postTensors.weightUvScale,
postTensors.smoothScalesWUv,
postTensors.weightOScale,
postTensors.smoothScalesWo,
},
{postOut, cmpAttnOut, cmpSoftmax, fullK, cmpK, topkRes, topkInput},
{{kvCacheOut, kvCache}, {krCacheOut, krCache}})
{
config::SetPassOption(CUBE_L1_REUSE_SETTING, std::map<int64_t, int64_t>{{-1, NUM_4}});
config::SetPassOption(CUBE_NBUFFER_SETTING, std::map<int64_t, int64_t>{{NUM_3, NUM_4}});
config::SetPassOption(MG_COPYIN_UPPER_BOUND, NUM_2 * NUM_1024 * NUM_1024);
int b = x.GetShape()[0];
int s = x.GetShape()[1];
int n1 = gateW2.GetShape()[1] / 3;
int n2 = 1;
int vDim = wUk.GetShape()[2];
int ropeDim = sin.GetShape()[2];
auto dtype = x.GetStorage()->Datatype();
* subgragh 0: mla_prolog
* subgragh 1: gen_win_attn
* subgragh 2: kv_compression
* subgragh 3: gen_cmp_atten
* subgragh 4: gen_slc_atten, 其中包括: gen_kv_slc及slc_attn
* subgragh 5: gen_gated_score
* subgragh 6: gen_attn
* subgragh 7: pa_post
*/
{subgraph-1, subgraph-2, subgraph-4}: {subgraph-0}
subgraph-6: {subgraph-1, subgraph-3, subgraph-4, subgraph-5}
subgraph-3: {subgraph-2}
subgraph-7: {subgraph-6}
*/
Tensor queryOut(dtype, {b, s, n1, vDim}, "queryOut");
Tensor queryRopeOut(dtype, {b, s, n1, ropeDim}, "queryRopeOut");
MlaPrologCompute(
x, wDq, wUqQr, wUk, wDkvKr, gammaCq, gammaCkv, sin, cos, cacheIndex, kvCache, krCache, quantInputs,
mlaTileConfig, queryOut, queryRopeOut, kvCacheOut, krCacheOut, epsilonCq, epsilonCkv, cacheMode);
Tensor qNope(dtype, {b * s * n1, vDim}, "qNope");
Tensor qRope(dtype, {b * s * n1, ropeDim}, "qRope");
LOOP("RESHAPE_LOOP_L0_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, b, 1), {}, true)
{
SymbolicScalar bOffset = bIdx * 1;
LOOP("RESHAPE_LOOP_L1_sIdx", FunctionType::DYNAMIC_LOOP, sIdx, LoopRange(0, s, 1))
{
SymbolicScalar sOffset = sIdx * 1;
Tensor nopeView = View(queryOut, {1, 1, n1, vDim}, {bOffset, sOffset, 0, 0});
TileShape::Current().SetVecTile({1, 1, 32, vDim});
Tensor nopeRes = Reshape(nopeView, {1 * 1 * n1, vDim});
Assemble(nopeRes, {(bOffset * s + sOffset) * n1, 0}, qNope);
Tensor ropeView = View(queryRopeOut, {1, 1, n1, ropeDim}, {bOffset, sOffset, 0, 0});
TileShape::Current().SetVecTile({1, 1, n1, ropeDim});
Tensor ropeRes = Reshape(ropeView, {1 * 1 * n1, ropeDim});
Assemble(ropeRes, {(bOffset * s + sOffset) * n1, 0}, qRope);
}
}
Tensor winAtten(DT_FP32, {b, s, n1, vDim}, "winAtten");
WinAttentionCompute(
qNope, kvCacheOut, qRope, krCacheOut, n1, n2, blockTable, kvActSeqs, winSize, blockSize, softmaxScale,
winAtten, winAttntileConfig);
Tensor cmpAttnOut16Tmp(dtype, {b, s, n1, vDim}, "cmpAttnOut16Tmp");
Tensor topkResTmp(DT_INT32, {b, s, 16}, "topkRes");
if (debug) {
FusedCompressKvSelectCompute(
qNope, qRope, kvCacheOut, krCacheOut, cmpKvCache, cmpKrCache, blockTable, cmpBlockTable, actSeqLen,
actCmpSeqLen, mlpWk1, mlpWk2, mlpCos, mlpSin, cmpAttnOut, cmpAttnOut16Tmp, cmpSoftmax, fullK, cmpK,
firstRope, firstRopeInput, topkResTmp, topkInput, blockSize, cmpBlockSize, cmpStride, softmaxScale, n1,
n2, tileConfig_v2);
} else {
FusedCompressKvSelectCompute(
qNope, qRope, kvCacheOut, krCacheOut, cmpKvCache, cmpKrCache, blockTable, cmpBlockTable, actSeqLen,
actCmpSeqLen, mlpWk1, mlpWk2, mlpCos, mlpSin, cmpAttnOut, cmpAttnOut16Tmp, cmpSoftmax, fullK, cmpK,
firstRope, firstRopeInput, topkRes, topkInput, blockSize, cmpBlockSize, cmpStride, softmaxScale, n1, n2,
tileConfig_v2);
}
Tensor slcAttn(DT_FP32, {b, s, n1, vDim}, "slcAttn");
if (debug) {
SelectedAttentionCompute(
topkResTmp, kvCacheOut, krCacheOut, kvActSeqs, blockTable, qNope, qRope, slcAttn, n1, n2, softmaxScale,
front, near, topk, blockSize, cmpBlockSize, slcBlockSize, saTileConfig, debug);
} else {
SelectedAttentionCompute(
topkIndices, kvCacheOut, krCacheOut, kvActSeqs, blockTable, qNope, qRope, slcAttn, n1, n2, softmaxScale,
front, near, topk, blockSize, cmpBlockSize, slcBlockSize, saTileConfig, debug);
}
Tensor gatingScore(dtype, {b, s, n1, 3}, "gatingScore");
GenGatedScore(x, gateW1, gateW2, gateSimW1, gatingScore, gateMode);
Tensor attentionOut(dtype, {b, s, n1, vDim}, "attentionOut");
GenAttn(gatingScore, cmpAttnOut16Tmp, slcAttn, winAtten, attentionOut);
config::SetPassOption(CUBE_NBUFFER_SETTING, std::map<int64_t, int64_t>{{0, 4}});
PostCompute(attentionOut, postTensors, postConfig, postOut);
}
}
}