/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
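/*!
 * \file win_attention.cpp
 * \brief Sliding-window attention over a block-table-indexed KV cache: single-pass, flash (tiled running-softmax)
 *        and debug variants, plus the FUNCTION wrappers that record them.
 */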
#include "interface/inner/tilefwk.h"
#include "win_attention.h"
using namespace npu::tile_fwk;
namespace npu::tile_fwk {
/**
 * \brief Records the single-pass sliding-window attention computation.
 *
 * For every (batch, query position, KV head, head-group tile) the body gathers NUM_9 consecutive cache blocks,
 * looked up through blockTable, into a scratch tensor "kPart" (nope columns first, rope columns after them),
 * takes a window of rows from it, and computes
 *     softmax(q x k^T x softmaxScale) x v
 * where v is the first dNopeSize columns of the same scratch tensor. The result is assembled into attentionOut.
 *
 * \param qNope        Query nope part. Dim 1 is read as dNopeSize; dim 0 is divided by bSize and nQ to get s1Size
 *                     and is indexed by row offset bIdx * s1Size * nQ + s1Idx * nQ + head.
 * \param vNopeCache   Cache viewed in {blockSize, dNopeSize} pieces at {blockIdx * blockSize, n2Idx * dNopeSize}.
 *                     It feeds both the key nope columns and the value operand.
 * \param qRope        Query rope part. Dim 1 is read as dRopeSize; indexed with the same row offset as qNope.
 * \param kRopeCache   Cache viewed in {blockSize, dRopeSize} pieces at {blockIdx * blockSize, n2Idx * dRopeSize}.
 * \param nQ           Number of query heads; must be non-zero.
 * \param nKv          Number of KV heads; must be non-zero. gGroup = nQ / nKv.
 * \param blockTable   Dim 0 is read as the batch size; element {bIdx, i} is used as a cache block index.
 * \param actSeqs      Element {bIdx} is read as the actual sequence size of batch entry bIdx.
 * \param windowSize   Upper bound on the number of key rows attended to.
 * \param blockSize    Rows per cache block; must be non-zero.
 * \param softmaxScale Scalar multiplied into q x k^T before the softmax.
 * \param attentionOut Output; written at 4-D offset {bIdx, s1Idx, head, 0} in {1, 1, gTile, dNopeSize} pieces.
 * \param tileConfig   Supplies gTile and the vector/cube tile shapes set before each op; gTile must be non-zero.
 */
void WinAttentionCompute(
const Tensor& qNope, Tensor& vNopeCache, const Tensor& qRope, Tensor& kRopeCache, int nQ, int nKv,
Tensor& blockTable, Tensor& actSeqs, int windowSize, int blockSize, float softmaxScale, Tensor& attentionOut,
WinAttenTileShapeConfig& tileConfig)
{
auto dtype = qNope.GetStorage()->Datatype();
int dNopeSize = qNope.GetShape()[1];
int dRopeSize = qRope.GetShape()[1];
ASSERT(nKv != 0) << "nKv cant't be zero!";
// Number of query heads that share one KV head.
auto gGroup = nQ / nKv;
int gTile = tileConfig.gTile;
ASSERT(blockSize != 0) << "blockSize can't be zero!";
auto nopeTile = tileConfig.vNopeTileShape;
auto ropeTile = tileConfig.vRopeTileShape;
auto c1Tile = tileConfig.c1TileShape;
auto v1Tile = tileConfig.v1TileShape;
auto c2Tile = tileConfig.c2TileShape;
auto v2Tile = tileConfig.v2TileShape;
// Loop extents: batch, query position (s1), KV head (n2) and head-group tile (g).
SymbolicScalar bSize = blockTable.GetShape()[0];
SymbolicScalar bTile = 1;
ASSERT(bTile != 0) << "bTile can't be zero!";
ASSERT(nQ != 0) << "nQ can't be zero!";
SymbolicScalar bLoop = bSize / bTile;
SymbolicScalar s1Size = qNope.GetShape()[0] / bSize / nQ;
SymbolicScalar s1Tile = 1;
ASSERT(s1Tile != 0) << "s1Tile can't be zero!";
SymbolicScalar s1Loop = s1Size / s1Tile;
SymbolicScalar n2Tile = 1;
ASSERT(n2Tile != 0) << "n2Tile can't be zero!";
SymbolicScalar n2Loop = nKv / n2Tile;
ASSERT(gTile != 0) << "gTile can't be zero!";
SymbolicScalar gLoop = gGroup / gTile;
SymbolicScalar blockStartIndex = 0;
SymbolicScalar blockStartOffset = 0;
SymbolicScalar blockEndIndex = 0;
SymbolicScalar winActualSize = 0;
SymbolicScalar tableLoop = 0;
LOOP("LOOP_L0_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bLoop, 1), {}, true)
{
SymbolicScalar curActualSeqSize = GetTensorData(actSeqs, {bIdx});
LOOP("LOOP_L1_s1Idx", FunctionType::DYNAMIC_LOOP, s1Idx, LoopRange(0, s1Loop, 1))
{
// Window length for this query position, capped at windowSize.
winActualSize = std::min(windowSize, (curActualSeqSize - s1Size + s1Idx + 1));
// blockEndIndex and tableLoop are computed here but not read again in this function;
// the gather below always copies NUM_9 blocks.
blockEndIndex = (curActualSeqSize + blockSize - 1) / blockSize - 1;
// Block that holds the first row of the window, and that row's offset inside the block.
blockStartIndex = std::max(0, ((curActualSeqSize - winActualSize - s1Size + 1 + s1Idx) / blockSize));
blockStartOffset = (curActualSeqSize - winActualSize - s1Size + 1 + s1Idx) % blockSize;
tableLoop = blockEndIndex - blockStartIndex + 1;
LOOP("LOOP_L2_n2Idx", FunctionType::DYNAMIC_LOOP, n2Idx, LoopRange(0, n2Loop, 1))
{
LOOP("LOOP_L2_gIdx", FunctionType::DYNAMIC_LOOP, gIdx, LoopRange(0, gLoop, 1))
{
// Destination offset in attentionOut and matching row offset in qNope/qRope.
std::vector<SymbolicScalar> oiOffset = {bIdx, s1Idx, n2Idx * gGroup + gIdx * gTile, 0};
SymbolicScalar curOffset = bIdx * s1Size * nQ + s1Idx * nQ + n2Idx * gGroup + gIdx * gTile;
// Scratch tensor holding NUM_9 blocks of [nope | rope] rows.
Tensor kPart(dtype, {NUM_9 * blockSize, (dNopeSize + dRopeSize)}, "kPart");
// NOTE(review): this reads blockTable entries blockStartIndex .. blockStartIndex + NUM_9 - 1
// regardless of tableLoop; presumably the table is padded far enough — confirm.
for (auto tIdx = 0; tIdx < NUM_9; tIdx++) {
SymbolicScalar curidx = blockStartIndex + tIdx;
SymbolicScalar curBlockIdx = GetTensorData(blockTable, {bIdx, curidx});
TileShape::Current().SetVecTile(nopeTile[0], nopeTile[1]);
auto kNope =
View(vNopeCache, {blockSize, dNopeSize}, {curBlockIdx * blockSize, n2Idx * dNopeSize});
// NOTE(review): the cast to FP32 and back to dtype looks like a way to turn the view into a
// computed tensor before Assemble — confirm the intent.
auto tmpK1 = Cast(kNope, DataType::DT_FP32);
auto tmpK2 = Cast(tmpK1, dtype);
Assemble(tmpK2, {tIdx * blockSize, 0}, kPart);
TileShape::Current().SetVecTile(ropeTile[0], ropeTile[1]);
auto kRope =
View(kRopeCache, {blockSize, dRopeSize}, {curBlockIdx * blockSize, n2Idx * dRopeSize});
auto tmpKR1 = Cast(kRope, DataType::DT_FP32);
auto tmpKR2 = Cast(tmpKR1, dtype);
// Rope columns are placed right after the nope columns.
Assemble(tmpKR2, {tIdx * blockSize, dNopeSize}, kPart);
}
auto startOffset = blockStartOffset;
// Window rows of kPart: all columns act as K, the first dNopeSize columns act as V.
// Presumably the second list is the static shape and the third the valid shape — confirm.
auto kActualPart = View(
kPart, {windowSize, dNopeSize + dRopeSize}, {winActualSize, dNopeSize + dRopeSize},
{startOffset, 0});
auto vActualPart =
View(kPart, {windowSize, dNopeSize}, {winActualSize, dNopeSize}, {startOffset, 0});
// Query tile laid out as [nope | rope], matching the kPart column layout.
Tensor qPart(dtype, {gTile, dNopeSize + dRopeSize}, "qPart");
auto qNopeL = View(qNope, {gTile, dNopeSize}, {gTile, dNopeSize}, {curOffset, 0});
Assemble(qNopeL, {0, 0}, qPart);
auto qRopeR = View(qRope, {gTile, dRopeSize}, {gTile, dRopeSize}, {curOffset, 0});
Assemble(qRopeR, {0, dNopeSize}, qPart);
TileShape::Current().SetCubeTile(
{c1Tile[0], c1Tile[1]}, {c1Tile[2], c1Tile[3]}, {c1Tile[4], c1Tile[5]});
// Scores in FP32; the trailing (false, true) presumably transposes the second operand — confirm.
auto qKT = Matrix::Matmul(DataType::DT_FP32, qPart, kActualPart, false, true);
TileShape::Current().SetVecTile(v1Tile[0], v1Tile[1]);
auto qKTScale = Mul(qKT, Element(qKT.GetStorage()->Datatype(), softmaxScale));
// Softmax numerator: subtract the max over the last axis, then exponentiate.
auto tileMax = Amax(qKTScale, -1, true);
auto tileSub = Sub(qKTScale, tileMax);
auto tileExp = Exp(tileSub);
auto tileExpF16 = Cast(tileExp, dtype);
auto tileSum = Sum(tileExp, -1, true);
TileShape::Current().SetCubeTile(
{c2Tile[0], c2Tile[1]}, {c2Tile[2], c2Tile[3]}, {c2Tile[4], c2Tile[5]});
// Unnormalised probabilities times V; the division by the sum follows.
auto oiTmp = Matrix::Matmul(DataType::DT_FP32, tileExpF16, vActualPart, false, false);
TileShape::Current().SetVecTile(v2Tile[0], v2Tile[1]);
auto out = Div(oiTmp, tileSum);
TileShape::Current().SetVecTile(1, 1, v2Tile[0], v2Tile[1]);
// NOTE(review): adding zero after the reshape looks like a way to get a computed 4-D tensor
// for Assemble — confirm the intent.
auto outFinal =
Add(Reshape(out, {bTile, s1Tile, gTile, dNopeSize}),
Element(out.GetStorage()->Datatype(), float(0)));
Assemble(outFinal, oiOffset, attentionOut);
}
}
}
}
}
/**
 * \brief Records the tiled ("flash") sliding-window attention computation.
 *
 * For every (batch, query position, KV head, head-group tile) the body gathers tableLoop cache blocks, looked up
 * through blockTable, into a slice of the scratch tensor "kPart" (nope columns first, rope columns after them).
 * It then walks the window in chunks of s2Tile rows. Each chunk produces a local max, exp-sum and P x V product;
 * from the second chunk on these are merged into the running max (miUpdate), running sum (liUpdate) and running
 * output (oiUpdate) by rescaling with exp(old max - new max) and exp(chunk max - new max). On the last chunk the
 * running output is divided by the running sum and assembled into attentionOut.
 *
 * \param qNope        Query nope part. Dim 1 is read as dNopeSize; dim 0 is divided by bSize and nQ to get s1Size
 *                     and is indexed by row offset bIdx * s1Size * nQ + s1Idx * nQ + head.
 * \param vNopeCache   Cache viewed in {blockSize, dNopeSize} pieces at {blockIdx * blockSize, n2Idx * dNopeSize}.
 *                     It feeds both the key nope columns and the value operand.
 * \param qRope        Query rope part. Dim 1 is read as dRopeSize; indexed with the same row offset as qNope.
 * \param kRopeCache   Cache viewed in {blockSize, dRopeSize} pieces at {blockIdx * blockSize, n2Idx * dRopeSize}.
 * \param nQ           Number of query heads; must be non-zero.
 * \param nKv          Number of KV heads; must be non-zero. gGroup = nQ / nKv.
 * \param blockTable   Dim 0 is read as the batch size; element {bIdx, i} is used as a cache block index.
 * \param actSeqs      Element {bIdx} is read as the actual sequence size of batch entry bIdx.
 * \param windowSize   Upper bound on the number of key rows attended to.
 * \param blockSize    Rows per cache block; must be non-zero.
 * \param softmaxScale Scalar multiplied into q x k^T before the softmax.
 * \param attentionOut Output; written at 4-D offset {bIdx, s1Idx, head, 0} in {1, 1, gTile, dNopeSize} pieces.
 * \param tileConfig   Supplies gTile, skvTile (the chunk length s2Tile) and the vector/cube tile shapes;
 *                     gTile and skvTile must be non-zero.
 */
void WinAttentionComputeFlash(
    const Tensor& qNope, Tensor& vNopeCache, const Tensor& qRope, Tensor& kRopeCache, int nQ, int nKv,
    Tensor& blockTable, Tensor& actSeqs, int windowSize, int blockSize, float softmaxScale, Tensor& attentionOut,
    WinAttenTileShapeConfig& tileConfig)
{
    auto dtype = qNope.GetStorage()->Datatype();
    int dNopeSize = qNope.GetShape()[1];
    int dRopeSize = qRope.GetShape()[1];
    ASSERT(nKv != 0) << "nKv cant't be zero!";
    // Number of query heads that share one KV head.
    auto gGroup = nQ / nKv;
    int gTile = tileConfig.gTile;
    int s2Tile = tileConfig.skvTile;
    // s2Tile is the divisor of the chunk count below; guard it like every other divisor in this function.
    ASSERT(s2Tile != 0) << "s2Tile can't be zero!";
    ASSERT(blockSize != 0) << "blockSize can't be zero!";
    auto nopeTile = tileConfig.vNopeTileShape;
    auto ropeTile = tileConfig.vRopeTileShape;
    auto c1Tile = tileConfig.c1TileShape;
    auto v1Tile = tileConfig.v1TileShape;
    auto c2Tile = tileConfig.c2TileShape;
    auto v2Tile = tileConfig.v2TileShape;
    // Loop extents: batch, query position (s1), KV head (n2) and head-group tile (g).
    SymbolicScalar bSize = blockTable.GetShape()[0];
    SymbolicScalar bTile = 1;
    ASSERT(bTile != 0) << "bTile can't be zero!";
    ASSERT(nQ != 0) << "nQ can't be zero!";
    SymbolicScalar bLoop = bSize / bTile;
    SymbolicScalar s1Size = qNope.GetShape()[0] / bSize / nQ;
    SymbolicScalar s1Tile = 1;
    ASSERT(s1Tile != 0) << "s1Tile can't be zero!";
    SymbolicScalar s1Loop = s1Size / s1Tile;
    SymbolicScalar n2Tile = 1;
    ASSERT(n2Tile != 0) << "n2Tile can't be zero!";
    SymbolicScalar n2Loop = nKv / n2Tile;
    ASSERT(gTile != 0) << "gTile can't be zero!";
    SymbolicScalar gLoop = gGroup / gTile;
    SymbolicScalar blockStartIndex = 0;
    SymbolicScalar blockStartOffset = 0;
    SymbolicScalar blockEndIndex = 0;
    SymbolicScalar winActualSize = 0;
    SymbolicScalar tableLoop = 0;
    LOOP("LOOP_L0_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bLoop, 1), {}, true)
    {
        SymbolicScalar curActualSeqSize = GetTensorData(actSeqs, {bIdx});
        // Scratch tensor with one NUM_9-block slice per (batch, query position); see kvTensorIdx below.
        Tensor kPart(dtype, {bSize * s1Size * NUM_9 * blockSize, (dNopeSize + dRopeSize)}, "kPart");
        LOOP("LOOP_L1_s1Idx", FunctionType::DYNAMIC_LOOP, s1Idx, LoopRange(0, s1Loop, 1))
        {
            // Window length for this query position, capped at windowSize.
            winActualSize = std::min(windowSize, (curActualSeqSize - s1Size + s1Idx + 1));
            // Index of the last block covered by the actual sequence.
            blockEndIndex = (curActualSeqSize + blockSize - 1) / blockSize - 1;
            // Block that holds the first row of the window, and that row's offset inside the block.
            blockStartIndex = std::max(0, ((curActualSeqSize - winActualSize - s1Size + 1 + s1Idx) / blockSize));
            blockStartOffset = (curActualSeqSize - winActualSize - s1Size + 1 + s1Idx) % blockSize;
            // Number of blocks to gather.
            tableLoop = blockEndIndex - blockStartIndex + 1;
            LOOP("LOOP_L2_n2Idx", FunctionType::DYNAMIC_LOOP, n2Idx, LoopRange(0, n2Loop, 1))
            {
                LOOP("LOOP_L2_gIdx", FunctionType::DYNAMIC_LOOP, gIdx, LoopRange(0, gLoop, 1))
                {
                    // Destination offset in attentionOut and matching row offset in qNope/qRope.
                    std::vector<SymbolicScalar> oiOffset = {bIdx, s1Idx, n2Idx * gGroup + gIdx * gTile, 0};
                    SymbolicScalar curOffset = bIdx * s1Size * nQ + s1Idx * nQ + n2Idx * gGroup + gIdx * gTile;
                    // Number of s2Tile-row chunks needed to cover the window (ceiling division).
                    SymbolicScalar s2Loop = (winActualSize + s2Tile - 1) / s2Tile;
                    // Running output, running exp-sum and running max carried across chunks.
                    Tensor oiUpdate(DT_FP32, {gTile, dNopeSize}, "oiUpdate");
                    Tensor liUpdate(DT_FP32, {gTile, 1}, "liUpdate");
                    Tensor miUpdate(DT_FP32, {gTile, 1}, "miUpdate");
                    // First row of this (batch, query position) slice inside kPart.
                    SymbolicScalar kvTensorIdx = (bIdx * s1Size + s1Idx) * NUM_9 * blockSize;
                    // NOTE(review): the slice holds NUM_9 blocks; presumably tableLoop never exceeds NUM_9 — confirm.
                    LOOP("LOOP_L2_tIdx", FunctionType::DYNAMIC_LOOP, tIdx, LoopRange(0, tableLoop, 1))
                    {
                        SymbolicScalar curidx = blockStartIndex + tIdx;
                        SymbolicScalar curBlockIdx = GetTensorData(blockTable, {bIdx, curidx});
                        TileShape::Current().SetVecTile(nopeTile[0], nopeTile[1]);
                        auto kNope =
                            View(vNopeCache, {blockSize, dNopeSize}, {curBlockIdx * blockSize, n2Idx * dNopeSize});
                        // NOTE(review): the cast to FP32 and back to dtype looks like a way to turn the view into
                        // a computed tensor before Assemble — confirm the intent.
                        auto tmpK1 = Cast(kNope, DataType::DT_FP32);
                        auto tmpK2 = Cast(tmpK1, dtype);
                        Assemble(tmpK2, {kvTensorIdx + tIdx * blockSize, 0}, kPart);
                        TileShape::Current().SetVecTile(ropeTile[0], ropeTile[1]);
                        auto kRope =
                            View(kRopeCache, {blockSize, dRopeSize}, {curBlockIdx * blockSize, n2Idx * dRopeSize});
                        auto tmpKR1 = Cast(kRope, DataType::DT_FP32);
                        auto tmpKR2 = Cast(tmpKR1, dtype);
                        // Rope columns are placed right after the nope columns.
                        Assemble(tmpKR2, {kvTensorIdx + tIdx * blockSize, dNopeSize}, kPart);
                    }
                    LOOP("LOOP_L2_s2Idx", FunctionType::DYNAMIC_LOOP, s2Idx, LoopRange(0, s2Loop, 1))
                    {
                        // First kPart row of this chunk.
                        auto startOffset = blockStartOffset + s2Idx * s2Tile + kvTensorIdx;
                        // Chunk rows of kPart: all columns act as K, the first dNopeSize columns act as V.
                        // The row count in the third list is clamped to the rows left in the window.
                        auto kActualPart = View(
                            kPart, {s2Tile, dNopeSize + dRopeSize},
                            {std::min(winActualSize - s2Idx * s2Tile, s2Tile), dNopeSize + dRopeSize},
                            {startOffset, 0});
                        auto vActualPart = View(
                            kPart, {s2Tile, dNopeSize}, {std::min(winActualSize - s2Idx * s2Tile, s2Tile), dNopeSize},
                            {startOffset, 0});
                        // Query tile laid out as [nope | rope], matching the kPart column layout.
                        Tensor qPart(dtype, {gTile, dNopeSize + dRopeSize}, "qPart");
                        auto qNopeL = View(qNope, {gTile, dNopeSize}, {gTile, dNopeSize}, {curOffset, 0});
                        Assemble(qNopeL, {0, 0}, qPart);
                        auto qRopeR = View(qRope, {gTile, dRopeSize}, {gTile, dRopeSize}, {curOffset, 0});
                        Assemble(qRopeR, {0, dNopeSize}, qPart);
                        TileShape::Current().SetCubeTile(
                            {c1Tile[0], c1Tile[1]}, {c1Tile[2], c1Tile[3]}, {c1Tile[4], c1Tile[5]});
                        // Chunk scores in FP32; (false, true) presumably transposes the second operand — confirm.
                        auto qKT = Matrix::Matmul(DataType::DT_FP32, qPart, kActualPart, false, true);
                        TileShape::Current().SetVecTile(v1Tile[0], v1Tile[1]);
                        auto qKTScale = Mul(qKT, Element(qKT.GetStorage()->Datatype(), softmaxScale));
                        // Chunk-local softmax numerator and its sum over the last axis.
                        auto tileMax = Amax(qKTScale, -1, true);
                        auto tileSub = Sub(qKTScale, tileMax);
                        auto tileExp = Exp(tileSub);
                        auto tileExpF16 = Cast(tileExp, dtype);
                        auto tileSum = Sum(tileExp, -1, true);
                        // First chunk: the chunk statistics become the running state as they are.
                        IF(IsLoopBegin(s2Idx, 0))
                        {
                            TileShape::Current().SetCubeTile(
                                {c2Tile[0], c2Tile[1]}, {c2Tile[2], c2Tile[3]}, {c2Tile[4], c2Tile[5]});
                            auto oiTmp = Matrix::Matmul(DataType::DT_FP32, tileExpF16, vActualPart, false, false);
                            TileShape::Current().SetVecTile(v2Tile[0], v2Tile[1]);
                            // A single-chunk window is normalised and written out right away.
                            IF(IsLoopEnd(s2Idx, s2Loop))
                            {
                                oiUpdate = Div(oiTmp, tileSum);
                                TileShape::Current().SetVecTile(1, 1, v2Tile[0], v2Tile[1]);
                                auto outFinal =
                                    Add(Reshape(oiUpdate, {bTile, s1Tile, gTile, dNopeSize}),
                                        Element(oiUpdate.GetStorage()->Datatype(), float(0)));
                                Assemble(outFinal, oiOffset, attentionOut);
                            }
                            ELSE { oiUpdate = oiTmp; }
                            liUpdate = tileSum;
                            miUpdate = tileMax;
                        }
                        // Later chunks: merge the chunk statistics into the running state.
                        ELSE
                        {
                            auto oi = oiUpdate;
                            auto li = liUpdate;
                            auto mi = miUpdate;
                            auto miNew = Maximum(mi, tileMax);
                            // t2 rescales the old running state to the new max.
                            auto t1 = Sub(mi, miNew);
                            auto t2 = Exp(t1);
                            // t4 rescales this chunk's contribution to the new max.
                            auto t3 = Sub(tileMax, miNew);
                            auto t4 = Exp(t3);
                            auto t5 = Mul(t4, tileSum);
                            auto t6 = Mul(t2, li);
                            auto liNew = Add(t6, t5);
                            auto q3 = Mul(oi, t2);
                            TileShape::Current().SetCubeTile(
                                {c2Tile[0], c2Tile[1]}, {c2Tile[2], c2Tile[3]}, {c2Tile[4], c2Tile[5]});
                            auto q1 = Matrix::Matmul(DataType::DT_FP32, tileExpF16, vActualPart, false, false);
                            TileShape::Current().SetVecTile(v2Tile[0], v2Tile[1]);
                            auto q2 = Mul(q1, t4);
                            auto oiTmp = Add(q3, q2);
                            // Last chunk: normalise by the running sum and write the result.
                            IF(IsLoopEnd(s2Idx, s2Loop))
                            {
                                oiUpdate = Div(oiTmp, liNew);
                                TileShape::Current().SetVecTile(1, 1, v2Tile[0], v2Tile[1]);
                                auto outFinal =
                                    Add(Reshape(oiUpdate, {bTile, s1Tile, gTile, dNopeSize}),
                                        Element(oiUpdate.GetStorage()->Datatype(), float(0)));
                                Assemble(outFinal, oiOffset, attentionOut);
                            }
                            ELSE
                            {
                                oiUpdate = oiTmp;
                            }
                            liUpdate = liNew;
                            miUpdate = miNew;
                        }
                    }
                }
            }
        }
    }
}
/**
 * \brief Records a debug variant of the sliding-window attention computation.
 *
 * For every (batch, query position, KV head, head-group tile) the body gathers tableLoop cache blocks, looked up
 * through blockTable, into two scratch tensors: "kPart" ([nope | rope] columns) and "vPart" (nope columns only,
 * read from the same vNopeCache rows). It then computes softmax(q x k^T x softmaxScale) x v over the window rows.
 * Unlike the other variants in this file, the softmax is fully normalised before the second matmul.
 *
 * \param qNope        Query nope part. Dim 1 is read as dNopeSize; dim 0 is divided by bSize and nQ to get s1Size
 *                     and is indexed by row offset bIdx * s1Size * nQ + s1Idx * nQ + head.
 * \param vNopeCache   Cache viewed in {blockSize, dNopeSize} pieces at {blockIdx * blockSize, n2Idx * dNopeSize}.
 *                     It feeds both the key nope columns and the value tensor.
 * \param qRope        Query rope part. Dim 1 is read as dRopeSize; indexed with the same row offset as qNope.
 * \param kRopeCache   Cache viewed in {blockSize, dRopeSize} pieces at {blockIdx * blockSize, n2Idx * dRopeSize}.
 * \param nQ           Number of query heads; must be non-zero.
 * \param nKv          Number of KV heads; must be non-zero. gGroup = nQ / nKv.
 * \param blockTable   Dim 0 is read as the batch size; element {bIdx, i} is used as a cache block index.
 * \param actSeqs      Element {bIdx} is read as the actual sequence size of batch entry bIdx.
 * \param windowSize   Upper bound on the number of key rows attended to.
 * \param blockSize    Rows per cache block; must be non-zero.
 * \param softmaxScale Scalar multiplied into q x k^T before the softmax.
 * \param attentionOut Output; written at 4-D offset {bIdx, s1Idx, head, 0} in {1, 1, gTile, dNopeSize} pieces.
 * \param tileConfig   Supplies gTile and the vector/cube/output tile shapes; gTile must be non-zero.
 */
void WinAttentionDebugCompute(
    const Tensor& qNope, Tensor& vNopeCache, const Tensor& qRope, Tensor& kRopeCache, int nQ, int nKv,
    Tensor& blockTable, Tensor& actSeqs, int windowSize, int blockSize, float softmaxScale, Tensor& attentionOut,
    WinAttenTileShapeConfig& tileConfig)
{
    auto dtype = qNope.GetStorage()->Datatype();
    int dNopeSize = qNope.GetShape()[1];
    int dRopeSize = qRope.GetShape()[1];
    ASSERT(nKv != 0) << "nKv cant't be zero!";
    // Number of query heads that share one KV head.
    auto gGroup = nQ / nKv;
    int gTile = tileConfig.gTile;
    ASSERT(blockSize != 0) << "blockSize can't be zero!";
    auto nopeTile = tileConfig.vNopeTileShape;
    auto ropeTile = tileConfig.vRopeTileShape;
    auto outTile = tileConfig.outTileShape;
    auto c1Tile = tileConfig.c1TileShape;
    auto v1Tile = tileConfig.v1TileShape;
    auto c2Tile = tileConfig.c2TileShape;
    auto v2Tile = tileConfig.v2TileShape;
    // Loop extents: batch, query position (s1), KV head (n2) and head-group tile (g).
    SymbolicScalar bSize = blockTable.GetShape()[0];
    SymbolicScalar bTile = 1;
    ASSERT(bTile != 0) << "bTile can't be zero!";
    ASSERT(nQ != 0) << "nQ can't be zero!";
    SymbolicScalar bLoop = bSize / bTile;
    SymbolicScalar s1Size = qNope.GetShape()[0] / bSize / nQ;
    SymbolicScalar s1Tile = 1;
    ASSERT(s1Tile != 0) << "s1Tile can't be zero!";
    SymbolicScalar s1Loop = s1Size / s1Tile;
    SymbolicScalar n2Tile = 1;
    ASSERT(n2Tile != 0) << "n2Tile can't be zero!";
    SymbolicScalar n2Loop = nKv / n2Tile;
    ASSERT(gTile != 0) << "gTile can't be zero!";
    SymbolicScalar gLoop = gGroup / gTile;
    SymbolicScalar blockStartIndex = 0;
    SymbolicScalar blockStartOffset = 0;
    SymbolicScalar blockEndIndex = 0;
    SymbolicScalar winActualSize = 0;
    SymbolicScalar tableLoop = 0;
    LOOP("LOOP_L0_bIdx", FunctionType::DYNAMIC_LOOP, bIdx, LoopRange(0, bLoop, 1), {}, true)
    {
        SymbolicScalar curActualSeqSize = GetTensorData(actSeqs, {bIdx});
        LOOP("LOOP_L1_s1Idx", FunctionType::DYNAMIC_LOOP, s1Idx, LoopRange(0, s1Loop, 1))
        {
            // Window length for this query position, capped at windowSize.
            winActualSize = std::min(windowSize, (curActualSeqSize - s1Size + s1Idx + 1));
            // Index of the last block covered by the actual sequence.
            blockEndIndex = (curActualSeqSize + blockSize - 1) / blockSize - 1;
            // Block that holds the first row of the window, and that row's offset inside the block.
            blockStartIndex = std::max(0, ((curActualSeqSize - winActualSize - s1Size + 1 + s1Idx) / blockSize));
            blockStartOffset = (curActualSeqSize - winActualSize - s1Size + 1 + s1Idx) % blockSize;
            // Number of blocks to gather.
            tableLoop = blockEndIndex - blockStartIndex + 1;
            LOOP("LOOP_L2_n2Idx", FunctionType::DYNAMIC_LOOP, n2Idx, LoopRange(0, n2Loop, 1))
            {
                LOOP("LOOP_L2_gIdx", FunctionType::DYNAMIC_LOOP, gIdx, LoopRange(0, gLoop, 1))
                {
                    std::vector<SymbolicScalar> outOffset = {bIdx, s1Idx, n2Idx * gGroup + gIdx * gTile, 0};
                    // NOTE(review): both scratch tensors hold 5 blocks; presumably tableLoop never exceeds 5 in the
                    // debug configurations — confirm.
                    Tensor kPart(dtype, {5 * blockSize, (dNopeSize + dRopeSize)}, "kPart");
                    Tensor vPart(dtype, {5 * blockSize, dNopeSize}, "vPart");
                    LOOP("LOOP_L2_tIdx", FunctionType::DYNAMIC_LOOP, tIdx, LoopRange(0, tableLoop, 1), {}, true)
                    {
                        SymbolicScalar curidx = blockStartIndex + tIdx;
                        SymbolicScalar curBlockIdx = GetTensorData(blockTable, {bIdx, curidx});
                        auto kNope =
                            View(vNopeCache, {blockSize, dNopeSize}, {curBlockIdx * blockSize, n2Idx * dNopeSize});
                        TileShape::Current().SetVecTile(nopeTile[0], nopeTile[1]);
                        // NOTE(review): the cast to FP32 and back to dtype looks like a way to turn the view into
                        // a computed tensor before Assemble — confirm the intent.
                        auto tmpK1 = Cast(kNope, DataType::DT_FP32);
                        auto tmpK2 = Cast(tmpK1, dtype);
                        Assemble(tmpK2, {tIdx * blockSize, 0}, kPart);
                        auto kRope =
                            View(kRopeCache, {blockSize, dRopeSize}, {curBlockIdx * blockSize, n2Idx * dRopeSize});
                        TileShape::Current().SetVecTile(ropeTile[0], ropeTile[1]);
                        auto tmpKR1 = Cast(kRope, DataType::DT_FP32);
                        auto tmpKR2 = Cast(tmpKR1, dtype);
                        // Rope columns are placed right after the nope columns.
                        Assemble(tmpKR2, {tIdx * blockSize, dNopeSize}, kPart);
                        // V is gathered from the same vNopeCache rows as the key nope columns.
                        auto vNope =
                            View(vNopeCache, {blockSize, dNopeSize}, {curBlockIdx * blockSize, n2Idx * dNopeSize});
                        TileShape::Current().SetVecTile(nopeTile[0], nopeTile[1]);
                        auto tmpV1 = Cast(vNope, DataType::DT_FP32);
                        auto tmpV2 = Cast(tmpV1, dtype);
                        Assemble(tmpV2, {tIdx * blockSize, 0}, vPart);
                    }
                    // Single-iteration loop around the compute stage; the index itself is unused.
                    LOOP("LOOP_L2_Idx", FunctionType::DYNAMIC_LOOP, oIdx, LoopRange(1), {}, true)
                    {
                        (void)oIdx;
                        SymbolicScalar curOffset = bIdx * s1Size * nQ + s1Idx * nQ + n2Idx * gGroup + gIdx * gTile;
                        // Window rows of the gathered K and V.
                        auto kActualPart = View(
                            kPart, {windowSize, dNopeSize + dRopeSize}, {winActualSize, dNopeSize + dRopeSize},
                            {blockStartOffset, 0});
                        auto vActualPart =
                            View(vPart, {windowSize, dNopeSize}, {winActualSize, dNopeSize}, {blockStartOffset, 0});
                        // Query tile laid out as [nope | rope], matching the kPart column layout.
                        Tensor qPart(dtype, {gTile, dNopeSize + dRopeSize}, "qPart");
                        auto qNopeL = View(qNope, {gTile, dNopeSize}, {gTile, dNopeSize}, {curOffset, 0});
                        Assemble(qNopeL, {0, 0}, qPart);
                        // The rope view is dRopeSize wide so that it fits the dRopeSize columns left in qPart
                        // after offset dNopeSize.
                        auto qRopeR = View(qRope, {gTile, dRopeSize}, {gTile, dRopeSize}, {curOffset, 0});
                        Assemble(qRopeR, {0, dNopeSize}, qPart);
                        TileShape::Current().SetCubeTile(
                            {c1Tile[0], c1Tile[1]}, {c1Tile[2], c1Tile[3]}, {c1Tile[4], c1Tile[5]}, true);
                        // Scores in FP32; (false, true) presumably transposes the second operand — confirm.
                        auto qKT = Matrix::Matmul(DataType::DT_FP32, qPart, kActualPart, false, true);
                        TileShape::Current().SetVecTile(v1Tile[0], v1Tile[1]);
                        auto qKTScale = Mul(qKT, Element(qKT.GetStorage()->Datatype(), softmaxScale));
                        // Full softmax over the last axis: subtract max, exponentiate, divide by the sum.
                        auto tileMax = Amax(qKTScale, -1, true);
                        auto tileSub = Sub(qKTScale, tileMax);
                        auto tileExp = Exp(tileSub);
                        auto tilSum = Sum(tileExp, -1, true);
                        auto tileSoftmx = Div(tileExp, tilSum);
                        auto valueType16 = Cast(tileSoftmx, dtype);
                        TileShape::Current().SetCubeTile(
                            {c2Tile[0], c2Tile[1]}, {c2Tile[2], c2Tile[3]}, {c2Tile[4], c2Tile[5]});
                        auto out = Matrix::Matmul(DataType::DT_FP32, valueType16, vActualPart, false, false);
                        TileShape::Current().SetVecTile(v2Tile[0], v2Tile[1]);
                        auto outNew = Reshape(out, {bTile, s1Tile, gTile, dNopeSize});
                        TileShape::Current().SetVecTile(1, 1, outTile[0], outTile[1]);
                        // NOTE(review): adding zero after the reshape looks like a way to get a computed 4-D
                        // tensor for Assemble — confirm the intent.
                        auto outFinal = Add(outNew, Element(outNew.GetStorage()->Datatype(), 0.0));
                        Assemble(outFinal, outOffset, attentionOut);
                    }
                }
            }
        }
    }
}
/**
 * \brief Entry point for the debug variant: records WinAttentionDebugCompute inside a FUNCTION("main", ...) scope.
 *
 * The FUNCTION macro is given {qNope, vNopeCache, qRope, kRopeCache, blockTable, actSeqs} as its first tensor list
 * and {attentionOut} as its second (presumably inputs and outputs — confirm against tilefwk.h).
 * All arguments are forwarded unchanged to WinAttentionDebugCompute.
 */
void WinAttentionDebug(
const Tensor& qNope, Tensor& vNopeCache, const Tensor& qRope, Tensor& kRopeCache, int nQ, int nKv,
Tensor& blockTable, Tensor& actSeqs, int windowSize, int blockSize, float softmaxScale, Tensor& attentionOut,
WinAttenTileShapeConfig& tileConfig)
{
FUNCTION("main", {qNope, vNopeCache, qRope, kRopeCache, blockTable, actSeqs}, {attentionOut})
{
WinAttentionDebugCompute(
qNope, vNopeCache, qRope, kRopeCache, nQ, nKv, blockTable, actSeqs, windowSize, blockSize, softmaxScale,
attentionOut, tileConfig);
}
}
/**
 * \brief Entry point for the single-pass variant: records WinAttentionCompute inside a FUNCTION("main", ...) scope.
 *
 * The FUNCTION macro is given {qNope, vNopeCache, qRope, kRopeCache, blockTable, actSeqs} as its first tensor list
 * and {attentionOut} as its second (presumably inputs and outputs — confirm against tilefwk.h).
 * All arguments are forwarded unchanged to WinAttentionCompute.
 */
void WinAttention(
const Tensor& qNope, Tensor& vNopeCache, const Tensor& qRope, Tensor& kRopeCache, int nQ, int nKv,
Tensor& blockTable, Tensor& actSeqs, int windowSize, int blockSize, float softmaxScale, Tensor& attentionOut,
WinAttenTileShapeConfig& tileConfig)
{
FUNCTION("main", {qNope, vNopeCache, qRope, kRopeCache, blockTable, actSeqs}, {attentionOut})
{
WinAttentionCompute(
qNope, vNopeCache, qRope, kRopeCache, nQ, nKv, blockTable, actSeqs, windowSize, blockSize, softmaxScale,
attentionOut, tileConfig);
}
}
/**
 * \brief Entry point for the flash variant: records WinAttentionComputeFlash inside a FUNCTION("main", ...) scope.
 *
 * The FUNCTION macro is given {qNope, vNopeCache, qRope, kRopeCache, blockTable, actSeqs} as its first tensor list
 * and {attentionOut} as its second (presumably inputs and outputs — confirm against tilefwk.h).
 * All arguments are forwarded unchanged to WinAttentionComputeFlash.
 */
void WinAttentionFlash(
const Tensor& qNope, Tensor& vNopeCache, const Tensor& qRope, Tensor& kRopeCache, int nQ, int nKv,
Tensor& blockTable, Tensor& actSeqs, int windowSize, int blockSize, float softmaxScale, Tensor& attentionOut,
WinAttenTileShapeConfig& tileConfig)
{
FUNCTION("main", {qNope, vNopeCache, qRope, kRopeCache, blockTable, actSeqs}, {attentionOut})
{
WinAttentionComputeFlash(
qNope, vNopeCache, qRope, kRopeCache, nQ, nKv, blockTable, actSeqs, windowSize, blockSize, softmaxScale,
attentionOut, tileConfig);
}
}
}