* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file deepseek_mla.cpp
* \brief
*/
#include "operator/models/deepseek/deepseek_mla.h"
namespace npu::tile_fwk {
/**
 * \brief Construct the attention layer from a hyper-parameter map and a bundle of projection weights.
 *
 * Each hyper-parameter is fetched with std::get<int>(config[key]). operator[] default-inserts a variant
 * holding `bool` for an absent key, so a missing or non-int entry surfaces as std::bad_variant_access.
 *
 * \param config     Hyper-parameters keyed by name (taken by value; the caller's map is not modified).
 * \param aw         Weight tensors, copied member by member.
 * \param inLayerIdx Value stored in layerIdx.
 */
DeepseekAttention::DeepseekAttention(
    std::map<std::string, std::variant<bool, int, float, std::string>> config, AttentionW aw, const int inLayerIdx)
    : layerIdx(inLayerIdx)
{
    const auto intOption = [&config](const char* key) { return std::get<int>(config[key]); };

    attentionDropout = intOption("attentionDropout");
    hiddenSize = intOption("hiddenSize");
    numHeads = intOption("numAttentionHeads");
    maxPositionEmbeddings = intOption("maxPositionEmbeddings");
    ropeTheta = intOption("ropeTheta");
    qLoraRank = intOption("qLoraRank");
    qkRopeHeadDim = intOption("qkRopeHeadDim");
    kvLoraRank = intOption("kvLoraRank");
    vHeadDim = intOption("vHeadDim");
    qkNopeHeadDim = intOption("qkNopeHeadDim");
    qHeadDim = qkNopeHeadDim + qkRopeHeadDim;
    isCausal = true;

    qAProjW = aw.qAProjW;
    qBProjW = aw.qBProjW;
    qBProjWScale = aw.qBProjWScale;
    kvAProjWithMqaW = aw.kvAProjWithMqaW;
    kvBProjWK = aw.kvBProjWK;
    kvBProjWV = aw.kvBProjWV;
    oProjW = aw.oProjW;

    softmaxScale = static_cast<float>(1.0 / std::sqrt(qHeadDim));
    if (intOption("ropeScaling") == 1) {
        // YaRN magnitude correction: mscale = 0.1 * mscaleAllDim * ln(factor) + 1 whenever factor > 1.
        // The guard used to test `mscaleAllDim > 1`, which is never true for mscaleAllDim == 1.0, so the
        // correction was silently skipped. The reference formula guards on the scaling factor instead.
        constexpr int factor = 40;
        constexpr float mscaleAllDim = 1.0F;
        constexpr double valuePointOne = 0.1;
        float mscale = 1.0F;
        if (factor > 1) {
            mscale = static_cast<float>(valuePointOne * mscaleAllDim * std::log(factor) + 1.0);
        }
        softmaxScale = softmaxScale * mscale * mscale;
    }
}
/**
 * \brief Scaled, masked attention where one tensor supplies both operands of the two batch matmuls.
 *
 * Steps: BatchMatmul(q, kv) with flags (false, true) -- presumably "transpose kv", confirm against
 * Matrix::BatchMatmul -- then cast to FP32, multiply by softmaxScale, add \p attenMask, cast back to the
 * input dtype, SoftmaxNew, and finally BatchMatmul with a zero-offset view of \p kv whose last dimension is
 * g_deepseekConfig["kvLoraRank"].
 *
 * \param q         Query tensor; dimensions 0 and 2 are read.
 * \param kv        Key/value tensor; dimensions 1 and 2 are read.
 * \param attenMask Tensor added to the scaled scores.
 * \return Result of the second batch matmul.
 */
Tensor DeepseekAttention::Attention(Tensor q, Tensor kv, Tensor attenMask)
{
    const int batch = q.GetShape()[0];
    const int kvHeads = kv.GetShape()[1];
    const int qLen = q.GetShape()[2];
    const int kvSeq = kv.GetShape()[2];
    const int latentDim = std::get<int>(g_deepseekConfig["kvLoraRank"]);
    const DataType inType = q.GetStorage()->Datatype();

    // Both matmuls use the same cube tiling, bounded by the query length.
    const int cubeM = std::min(NUM_128, qLen);

    TileShape::Current().SetCubeTile({cubeM, cubeM}, {NUM_64, NUM_64}, {NUM_128, NUM_128});
    Tensor scores = Matrix::BatchMatmul(inType, q, kv, false, true);

    TileShape::Current().SetVecTile({1, 1, NUM_128, NUM_64});
    Tensor scoresFp32 = Cast(scores, DataType::DT_FP32);
    scoresFp32 = Mul(scoresFp32, Element(DataType::DT_FP32, static_cast<double>(softmaxScale)));
    scoresFp32 = Add(scoresFp32, attenMask);
    Tensor scoresLow = Cast(scoresFp32, inType);
    Tensor probs = SoftmaxNew(scoresLow);

    // Value operand: the leading latentDim channels of kv.
    Tensor values = View(kv, {batch, kvHeads, kvSeq, latentDim}, {0, 0, 0, 0});
    TileShape::Current().SetCubeTile({cubeM, cubeM}, {NUM_64, NUM_64}, {NUM_128, NUM_128});
    return Matrix::BatchMatmul(inType, probs, values);
}
/**
 * \brief Project the attention result through kvBProjWV and then through oProjW.
 *
 * The input is transposed on axes (1, 2), reshaped to {b * s, n, kvLoraRank}, transposed on axes (0, 1)
 * and batch-multiplied by kvBProjWV. The product is transposed back, reshaped to {b, s, n * vHeadDim} and
 * batch-multiplied by oProjW with a leading unit axis added.
 *
 * Tile settings are issued in exactly the order the ops consume them, including back-to-back settings.
 *
 * \param attenRes Attention output; dimensions 0, 1, 2 are read as b, n, s.
 * \return Result of the final batch matmul.
 */
Tensor DeepseekAttention::AttentionPost(Tensor attenRes)
{
    const int batch = attenRes.GetShape()[0];
    const int heads = attenRes.GetShape()[1];
    const int seqLen = attenRes.GetShape()[2];
    const int tokens = batch * seqLen;
    const DataType inType = attenRes.GetStorage()->Datatype();

    TileShape::Current().SetVecTile({1, 1, 1, NUM_512});
    Tensor swapped = Transpose(attenRes, {1, 2});
    TileShape::Current().SetVecTile({1, 1, NUM_128, NUM_64});
    Tensor flattened = Reshape(swapped, {tokens, heads, kvLoraRank});
    TileShape::Current().SetVecTile({1, 1, NUM_512});
    Tensor headMajor = Transpose(flattened, {0, 1});
    TileShape::Current().SetVecTile({1, NUM_128, NUM_64});

    const int tokenTile = std::min(NUM_128, tokens);
    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor valueProj = Matrix::BatchMatmul(inType, headMajor, kvBProjWV);

    TileShape::Current().SetVecTile(1, 1, NUM_128);
    Tensor tokenMajor = Transpose(valueProj, {0, 1});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    Tensor merged = Reshape(tokenMajor, {batch, seqLen, heads * vHeadDim});

    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor outWeight = Unsqueeze(oProjW, 0);

    const int seqTile = std::min(NUM_128, seqLen);
    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    return Matrix::BatchMatmul(inType, merged, outWeight);
}
/**
 * \brief Variant of AttentionPost whose tile sizes are additionally bounded by oProjW's dimension 1.
 *
 * The op sequence is: Transpose(1, 2) -> Reshape {b * s, n, kvLoraRank} -> Transpose(0, 1) ->
 * BatchMatmul with kvBProjWV -> Transpose(0, 1) -> Reshape {b, s, n * vHeadDim} -> BatchMatmul with
 * Unsqueeze(oProjW, 0).
 *
 * \param attenRes Attention output; dimensions 0, 1, 2 are read as b, n, s.
 * \return Result of the final batch matmul.
 */
Tensor DeepseekAttention::AttentionPost2(Tensor attenRes)
{
    const int batch = attenRes.GetShape()[0];
    const int heads = attenRes.GetShape()[1];
    const int seqLen = attenRes.GetShape()[2];
    const int tokens = batch * seqLen;
    const int outDim = oProjW.GetShape()[1];
    const DataType inType = attenRes.GetStorage()->Datatype();

    // Cube tile extents reused by both matmuls.
    const int tokenTile = std::min(NUM_128, tokens);
    const int seqTile = std::min(NUM_128, seqLen);
    const int outTile = std::min(NUM_128, outDim);

    TileShape::Current().SetVecTile({NUM_16, NUM_16, 1, NUM_128});
    Tensor swapped = Transpose(attenRes, {1, 2});
    TileShape::Current().SetVecTile({NUM_16, 1, NUM_16, NUM_128});
    Tensor flattened = Reshape(swapped, {tokens, heads, kvLoraRank});
    TileShape::Current().SetVecTile({NUM_16, NUM_16, NUM_128});
    Tensor headMajor = Transpose(flattened, {0, 1});

    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_128, NUM_128}, {outTile, outTile});
    Tensor valueProj = Matrix::BatchMatmul(inType, headMajor, kvBProjWV);

    TileShape::Current().SetVecTile(NUM_16, NUM_16, NUM_128);
    Tensor tokenMajor = Transpose(valueProj, {0, 1});
    TileShape::Current().SetVecTile(NUM_16, NUM_16, NUM_128);
    Tensor merged = Reshape(tokenMajor, {batch, seqLen, heads * vHeadDim});

    TileShape::Current().SetVecTile(NUM_128, std::min(NUM_256, outDim));
    Tensor outWeight = Unsqueeze(oProjW, 0);

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {outTile, outTile});
    return Matrix::BatchMatmul(inType, merged, outWeight);
}
/**
 * \brief Produce the per-head query tensor and the joint compressed KV tensor via batch matmuls.
 *
 * q path: hiddenStates x qAProjW -> RmsNorm -> x qBProjW -> Reshape {b, s, numHeads, qHeadDim}.
 * kv path: hiddenStates x kvAProjWithMqaW.
 * Each weight gets a leading unit axis (Unsqueeze at 0) before it is used.
 *
 * \param hiddenStates Input; dimensions 0 and 1 are read as b and s.
 * \return {reshaped q, compressed kv}.
 */
std::tuple<Tensor, Tensor> DeepseekAttention::QkvPre(Tensor hiddenStates)
{
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor qDownW = Unsqueeze(qAProjW, 0);
    Tensor qUpW = Unsqueeze(qBProjW, 0);
    Tensor kvDownW = Unsqueeze(kvAProjWithMqaW, 0);

    // All three matmuls share one cube tiling.
    const int seqTile = std::min(NUM_128, seqLen);

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_64, NUM_64});
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor qLatent = Matrix::BatchMatmul(inType, hiddenStates, qDownW);

    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    Tensor qLatentNorm = RmsNorm(qLatent);

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_64, NUM_64});
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor qFlat = Matrix::BatchMatmul(inType, qLatentNorm, qUpW);

    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    Tensor qHeads = Reshape(qFlat, {batch, seqLen, numHeads, qHeadDim});

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_64, NUM_64});
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor kvJoint = Matrix::BatchMatmul(inType, hiddenStates, kvDownW);

    return {qHeads, kvJoint};
}
/**
 * \brief Same projections as QkvPre but with a different vector tiling around RmsNorm and Reshape.
 *
 * q path: hiddenStates x qAProjW -> RmsNorm -> x qBProjW -> Reshape {b, s, numHeads, qHeadDim}.
 * kv path: hiddenStates x kvAProjWithMqaW.
 *
 * \param hiddenStates Input; dimensions 0 and 1 are read as b and s.
 * \return {reshaped q, compressed kv}.
 */
std::tuple<Tensor, Tensor> DeepseekAttention::QkvPreCv(Tensor hiddenStates)
{
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor qDownW = Unsqueeze(qAProjW, 0);
    Tensor qUpW = Unsqueeze(qBProjW, 0);
    Tensor kvDownW = Unsqueeze(kvAProjWithMqaW, 0);

    const int seqTile = std::min(NUM_128, seqLen);

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_64, NUM_64});
    Tensor qLatent = Matrix::BatchMatmul(inType, hiddenStates, qDownW);

    TileShape::Current().SetVecTile(NUM_2, 1, NUM_512);
    Tensor qLatentNorm = RmsNorm(qLatent);

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_64, NUM_64});
    Tensor qFlat = Matrix::BatchMatmul(inType, qLatentNorm, qUpW);

    TileShape::Current().SetVecTile(NUM_2, 1, NUM_384);
    Tensor qHeads = Reshape(qFlat, {batch, seqLen, numHeads, qHeadDim});

    TileShape::Current().SetCubeTile({seqTile, seqTile}, {NUM_128, NUM_128}, {NUM_64, NUM_64});
    Tensor kvJoint = Matrix::BatchMatmul(inType, hiddenStates, kvDownW);

    return {qHeads, kvJoint};
}
/**
 * \brief 2-D (token-flattened) q / kv projections with an optional quantised q up-projection.
 *
 * hiddenStates is reshaped to {b * s, h}. The q path is Matmul(qAProjW) -> RmsNorm -> [Quant] ->
 * Matmul(qBProjW); the kv path is Matmul(kvAProjWithMqaW) reshaped to {b, s, kvLoraRank + qkRopeHeadDim}.
 * When \p isQuant is set the q matmul output dtype is DT_INT32 and element 1 of the Quant result is kept.
 *
 * \param hiddenStates Input; dimensions 0, 1, 2 are read as b, s, h.
 * \param isQuant      Selects the quantised q path.
 * \return {q, compressed kv} and, only when \p isQuant is true, a third entry with the Quant scale tensor.
 */
std::vector<Tensor> DeepseekAttention::QkvPre2(Tensor hiddenStates, bool isQuant)
{
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const int hidden = hiddenStates.GetShape()[2];
    const int tokens = batch * seqLen;
    const DataType inType = hiddenStates.GetStorage()->Datatype();
    const DataType qOutType = isQuant ? DataType::DT_INT32 : inType;

    std::vector<Tensor> results;
    Tensor flatInput = Reshape(hiddenStates, {tokens, hidden});

    // Round min(NUM_32, tokens) up to a multiple of NUM_16; the smaller tile is capped at NUM_16.
    const int align = NUM_16;
    const int alignedM = (std::min(NUM_32, tokens) + align - 1) / align * align;
    const int cappedM = std::min(NUM_16, alignedM);

    TileShape::Current().SetCubeTile({cappedM, cappedM}, {NUM_256, NUM_256}, {NUM_128, NUM_128});
    Tensor qLatent = Matrix::Matmul(inType, flatInput, qAProjW, false, false);

    TileShape::Current().SetVecTile(std::min(NUM_16, tokens), NUM_128);
    Tensor qLatentNorm = RmsNorm(qLatent);

    Tensor dequantScale;
    if (isQuant) {
        auto quantized = Quant(qLatentNorm);
        qLatentNorm = std::get<0>(quantized);
        dequantScale = std::get<1>(quantized);
        TileShape::Current().SetCubeTile({cappedM, cappedM}, {NUM_256, NUM_256}, {NUM_256, NUM_256});
    } else {
        TileShape::Current().SetCubeTile({alignedM, alignedM}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
    }
    Tensor qOut = Matrix::Matmul(qOutType, qLatentNorm, qBProjW, false, false);
    results.emplace_back(qOut);

    TileShape::Current().SetCubeTile({alignedM, alignedM}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
    Tensor kvJoint = Matrix::Matmul(inType, flatInput, kvAProjWithMqaW, false, false);
    Tensor kvJoint3d = Reshape(kvJoint, {batch, seqLen, kvLoraRank + qkRopeHeadDim});
    results.emplace_back(kvJoint3d);

    if (isQuant) {
        results.emplace_back(dequantScale);
    }
    return results;
}
/**
 * \brief q / kv projections whose matmuls all emit DT_FP32.
 *
 * hiddenStates is reshaped to {b * s, h}. The q path is Matmul(qAProjW) -> RmsNorm -> Cast to the input
 * dtype -> Matmul(qBProjW) -> Reshape {b, s, numHeads, qHeadDim}. The kv path is Matmul(kvAProjWithMqaW)
 * -> Reshape {b, s, kvLoraRank + qkRopeHeadDim}.
 *
 * \param hiddenStates Input; dimensions 0, 1, 2 are read as b, s, h.
 * \return {q, compressed kv}.
 */
std::tuple<Tensor, Tensor> DeepseekAttention::QkvPreFp32(Tensor hiddenStates)
{
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const int hidden = hiddenStates.GetShape()[2];
    const int tokens = batch * seqLen;
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    Tensor flatInput = Reshape(hiddenStates, {tokens, hidden});
    const int tokenTile = std::min(NUM_64, tokens);

    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_256, NUM_256}, {NUM_128, NUM_128});
    Tensor qLatentFp32 = Matrix::Matmul(DataType::DT_FP32, flatInput, qAProjW, false, false);

    TileShape::Current().SetVecTile(NUM_32, NUM_128);
    Tensor qLatentNormFp32 = RmsNorm(qLatentFp32);

    std::vector<int64_t> castTile = {NUM_32, NUM_128};
    TileShape::Current().SetVecTile(castTile);
    Tensor qLatentNorm = Cast(qLatentNormFp32, inType);

    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
    Tensor qFlatFp32 = Matrix::Matmul(DataType::DT_FP32, qLatentNorm, qBProjW, false, false);
    Tensor qHeads = Reshape(qFlatFp32, {batch, seqLen, numHeads, qHeadDim});

    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
    Tensor kvJointFp32 = Matrix::Matmul(DataType::DT_FP32, flatInput, kvAProjWithMqaW, false, false);
    Tensor kvJoint3d = Reshape(kvJointFp32, {batch, seqLen, kvLoraRank + qkRopeHeadDim});

    return {qHeads, kvJoint3d};
}
/**
 * \brief Full attention layer: prologue, attention core, output projection.
 *
 * The prologue (QkvPre, the no-position query projection through kvBProjWK, RmsNorm of the compressed KV,
 * ApplyRotaryPosEmb, concatenation and ScatterUpdate into \p pastKeyStates) used to be written out here as
 * a statement-for-statement copy of AtentionPreForward. It now delegates to that member so the two paths
 * cannot drift apart; the sequence of ops and tile settings issued is unchanged.
 *
 * \param hiddenStates        Layer input.
 * \param attenMask           Mask forwarded to Attention (AtentionPreForward ignores it).
 * \param positionIds         Forwarded to ApplyRotaryPosEmb by the prologue.
 * \param cos                 Forwarded to ApplyRotaryPosEmb by the prologue.
 * \param sin                 Forwarded to ApplyRotaryPosEmb by the prologue.
 * \param kvLen               Index argument of the prologue's ScatterUpdate.
 * \param pastKeyStates       Cache tensor that the prologue's ScatterUpdate writes the new keys into.
 * \param ropeTileShapeConfig Tiling configuration forwarded to ApplyRotaryPosEmb.
 * \return Output of AttentionPost applied to the attention result.
 */
Tensor DeepseekAttention::Forward(
    Tensor hiddenStates, Tensor attenMask, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen,
    Tensor pastKeyStates, const RoPETileShapeConfig& ropeTileShapeConfig)
{
    auto prologue = AtentionPreForward(
        hiddenStates, attenMask, positionIds, cos, sin, kvLen, pastKeyStates, ropeTileShapeConfig);
    Tensor queryStates = std::get<0>(prologue);
    Tensor pastKeyStatesNew = std::get<1>(prologue);

    Tensor attenRes = Attention(queryStates, pastKeyStatesNew, attenMask);
    return AttentionPost(attenRes);
}
/**
 * \brief Attention prologue: build the query states and write the new key states into the cache tensor.
 *
 * Query side: q from QkvPre is split by View into a leading qkNopeHeadDim slice and a trailing
 * qkRopeHeadDim slice. The leading slice is batch-multiplied by kvBProjWK and rearranged to
 * {b, numHeads, s, kvLoraRank}; the trailing slice goes through ApplyRotaryPosEmb. The two are
 * concatenated on the last axis.
 * Key side: the compressed KV is split at kvLoraRank; the leading slice is RmsNorm-ed, the trailing slice
 * goes through ApplyRotaryPosEmb, and the concatenation is scattered into \p pastKeyStates on axis -2.
 *
 * \param hiddenStates        Layer input; dimensions 0 and 1 are read as b and s.
 * \param attenMask           Unused.
 * \param positionIds         Forwarded to ApplyRotaryPosEmb.
 * \param cos                 Forwarded to ApplyRotaryPosEmb.
 * \param sin                 Forwarded to ApplyRotaryPosEmb.
 * \param kvLen               Index argument of ScatterUpdate.
 * \param pastKeyStates       Cache tensor passed to ScatterUpdate.
 * \param ropeTileShapeConfig Tiling configuration forwarded to ApplyRotaryPosEmb.
 * \return {query states, result of ScatterUpdate}.
 */
std::tuple<Tensor, Tensor> DeepseekAttention::AtentionPreForward(
    Tensor hiddenStates, Tensor attenMask, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen,
    Tensor pastKeyStates, const RoPETileShapeConfig& ropeTileShapeConfig)
{
    (void)attenMask;
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const int tokens = batch * seqLen;
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    auto projected = QkvPre(hiddenStates);
    Tensor qAll = std::get<0>(projected);
    Tensor kvJoint = std::get<1>(projected);

    // Split q along its last axis.
    Tensor qNopePart = View(qAll, {batch, seqLen, numHeads, qkNopeHeadDim}, {0, 0, 0, 0});
    Tensor qRopePart = View(qAll, {batch, seqLen, numHeads, qkRopeHeadDim}, {0, 0, 0, qkNopeHeadDim});
    TileShape::Current().SetVecTile(1, 1, 1, NUM_64);
    qRopePart = Transpose(qRopePart, {1, 2});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);

    // Split the compressed KV at kvLoraRank; the trailing slice is taken before the tensor is narrowed.
    Tensor kRopePart = View(kvJoint, {batch, seqLen, qkRopeHeadDim}, {0, 0, kvLoraRank});
    kvJoint = View(kvJoint, {batch, seqLen, kvLoraRank}, {0, 0, 0});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    kRopePart = Reshape(kRopePart, {batch, 1, seqLen, qkRopeHeadDim});

    TileShape::Current().SetVecTile(1, NUM_128, 1, NUM_64);
    Tensor qNopeFlat = Reshape(qNopePart, {tokens, numHeads, qkNopeHeadDim});
    TileShape::Current().SetVecTile(1, 1, NUM_128);
    Tensor qNopeHeadMajor = Transpose(qNopeFlat, {0, 1});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);

    const int tokenTile = std::min(NUM_128, tokens);
    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    Tensor qAbsorbed = Matrix::BatchMatmul(inType, qNopeHeadMajor, kvBProjWK);

    TileShape::Current().SetVecTile(1, 1, NUM_512);
    Tensor qLatent = Transpose(qAbsorbed, {0, 1});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    qLatent = Reshape(qLatent, {batch, seqLen, numHeads, kvLoraRank});
    TileShape::Current().SetVecTile(1, 1, 1, NUM_512);
    qLatent = Transpose(qLatent, {1, 2});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);

    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    Tensor kLatent = RmsNorm(kvJoint);
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);
    kLatent = Reshape(kLatent, {batch, 1, seqLen, kvLoraRank});

    // Output tensors filled by ApplyRotaryPosEmb.
    Tensor qPeRope(qRopePart.GetStorage()->Datatype(), {batch, numHeads, seqLen, qkRopeHeadDim}, "qPeRope");
    Tensor kPeRope(kRopePart.GetStorage()->Datatype(), {batch, 1, seqLen, qkRopeHeadDim}, "kPeRope");
    ApplyRotaryPosEmb(qRopePart, kRopePart, cos, sin, positionIds, qPeRope, kPeRope, 1, ropeTileShapeConfig);

    TileShape::Current().SetVecTile(1, 1, NUM_128, NUM_64);
    Tensor queryStates = Cat({qLatent, qPeRope}, -1);
    Tensor keyStates = Cat({kLatent, kPeRope}, -1);
    auto pastKeyStatesNew = ScatterUpdate(pastKeyStates, kvLen, keyStates, -2);
    return {queryStates, pastKeyStatesNew};
}
/**
 * \brief Attention prologue built on QkvPreCv, with its own vector tilings.
 *
 * The data flow matches AtentionPreForward: q is split into a qkNopeHeadDim slice (batch-multiplied by
 * kvBProjWK and rearranged to {b, numHeads, s, kvLoraRank}) and a qkRopeHeadDim slice (ApplyRotaryPosEmb).
 * The compressed KV is split at kvLoraRank into an RmsNorm-ed slice and a rotary slice. The concatenated
 * key states are scattered into \p pastKeyStates on axis -2.
 *
 * \param hiddenStates        Layer input; dimensions 0 and 1 are read as b and s.
 * \param attenMask           Unused.
 * \param positionIds         Forwarded to ApplyRotaryPosEmb.
 * \param cos                 Forwarded to ApplyRotaryPosEmb.
 * \param sin                 Forwarded to ApplyRotaryPosEmb.
 * \param kvLen               Index argument of ScatterUpdate.
 * \param pastKeyStates       Cache tensor passed to ScatterUpdate.
 * \param ropeTileShapeConfig Tiling configuration forwarded to ApplyRotaryPosEmb.
 * \return {query states, result of ScatterUpdate}.
 */
std::tuple<Tensor, Tensor> DeepseekAttention::AtentionPreForwardCv(
    Tensor hiddenStates, Tensor attenMask, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen,
    Tensor pastKeyStates, const RoPETileShapeConfig& ropeTileShapeConfig)
{
    (void)attenMask;
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const int tokens = batch * seqLen;
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    auto projected = QkvPreCv(hiddenStates);
    Tensor qAll = std::get<0>(projected);
    Tensor kvJoint = std::get<1>(projected);

    // Split q along its last axis.
    Tensor qNopePart = View(qAll, {batch, seqLen, numHeads, qkNopeHeadDim}, {0, 0, 0, 0});
    Tensor qRopePart = View(qAll, {batch, seqLen, numHeads, qkRopeHeadDim}, {0, 0, 0, qkNopeHeadDim});
    TileShape::Current().SetVecTile(NUM_2, 1, NUM_32, NUM_64);
    qRopePart = Transpose(qRopePart, {1, 2});

    // Split the compressed KV at kvLoraRank; the trailing slice is taken before the tensor is narrowed.
    Tensor kRopePart = View(kvJoint, {batch, seqLen, qkRopeHeadDim}, {0, 0, kvLoraRank});
    kvJoint = View(kvJoint, {batch, seqLen, kvLoraRank}, {0, 0, 0});
    TileShape::Current().SetVecTile(NUM_2, 1, NUM_64);
    kRopePart = Reshape(kRopePart, {batch, 1, seqLen, qkRopeHeadDim});

    TileShape::Current().SetVecTile(NUM_2, 1, NUM_32, NUM_128);
    Tensor qNopeFlat = Reshape(qNopePart, {tokens, numHeads, qkNopeHeadDim});
    TileShape::Current().SetVecTile(NUM_2, NUM_32, NUM_128);
    Tensor qNopeHeadMajor = Transpose(qNopeFlat, {0, 1});
    TileShape::Current().SetVecTile(1, NUM_128, NUM_64);

    const int tokenTile = std::min(NUM_128, tokens);
    TileShape::Current().SetCubeTile({tokenTile, tokenTile}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    Tensor qAbsorbed = Matrix::BatchMatmul(inType, qNopeHeadMajor, kvBProjWK);

    TileShape::Current().SetVecTile(NUM_16, NUM_2, NUM_512);
    Tensor qLatent = Transpose(qAbsorbed, {0, 1});
    TileShape::Current().SetVecTile(1, NUM_32, NUM_512);
    qLatent = Reshape(qLatent, {batch, seqLen, numHeads, kvLoraRank});
    TileShape::Current().SetVecTile(NUM_2, 1, NUM_32, NUM_256);
    qLatent = Transpose(qLatent, {1, 2});

    TileShape::Current().SetVecTile(NUM_2, 1, NUM_512);
    Tensor kLatent = RmsNorm(kvJoint);
    TileShape::Current().SetVecTile(NUM_2, 1, NUM_512);
    kLatent = Reshape(kLatent, {batch, 1, seqLen, kvLoraRank});

    // Output tensors filled by ApplyRotaryPosEmb.
    Tensor qPeRope(qRopePart.GetStorage()->Datatype(), {batch, numHeads, seqLen, qkRopeHeadDim}, "qPeRope");
    Tensor kPeRope(kRopePart.GetStorage()->Datatype(), {batch, 1, seqLen, qkRopeHeadDim}, "kPeRope");
    ApplyRotaryPosEmb(qRopePart, kRopePart, cos, sin, positionIds, qPeRope, kPeRope, 1, ropeTileShapeConfig);

    TileShape::Current().SetVecTile(NUM_2, NUM_32, 1, NUM_64);
    Tensor queryStates = Cat({qLatent, qPeRope}, -1);
    Tensor keyStates = Cat({kLatent, kPeRope}, -1);

    TileShape::Current().SetVecTile(NUM_2, 1, NUM_128, NUM_64);
    auto pastKeyStatesNew = ScatterUpdate(pastKeyStates, kvLen, keyStates, -2);
    return {queryStates, pastKeyStatesNew};
}
/**
 * \brief Prologue variant that takes an already-rotated query slice and returns the raw compressed KV.
 *
 * q and kv come from QkvPre2. With \p isQuant set, q is dequantised first: Cast to FP32, multiply by the
 * scale tensor returned by QkvPre2 (entry 2), multiply by qBProjWScale, Cast back to the input dtype.
 * The leading qkNopeHeadDim slice of q is batch-multiplied by kvBProjWK, rearranged to
 * {b, numHeads, s, kvLoraRank} and concatenated with \p qPeRope on the last axis.
 *
 * \param hiddenStates Layer input; dimensions 0 and 1 are read as b and s.
 * \param qPeRope      Tensor concatenated after the projected query slice.
 * \param isQuant      Selects the quantised q path in QkvPre2 and the dequantisation here.
 * \return {query states, entry 1 of the QkvPre2 result}.
 */
std::tuple<Tensor, Tensor> DeepseekAttention::MlaPrologAbForward(Tensor hiddenStates, Tensor qPeRope, bool isQuant)
{
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const int tokens = batch * seqLen;
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    auto projected = QkvPre2(hiddenStates, isQuant);
    Tensor qFlat = projected[0];
    Tensor kvJoint = projected[1];

    if (isQuant) {
        std::vector<int64_t> dequantTile = {std::min(NUM_32, tokens), NUM_64};
        TileShape::Current().SetVecTile(dequantTile);
        auto qFp32 = Cast(qFlat, DataType::DT_FP32);
        auto perTokenScale = projected[2];
        auto qPerToken = Mul(qFp32, perTokenScale);
        auto qPerChannel = Mul(qPerToken, qBProjWScale);
        qFlat = Cast(qPerChannel, inType);
    }

    auto qHeads = Reshape(qFlat, {batch, seqLen, numHeads, qHeadDim});
    Tensor qNopePart = View(qHeads, {batch, seqLen, numHeads, qkNopeHeadDim}, {0, 0, 0, 0});

    std::vector<int64_t> vecTile = {NUM_2, 1, NUM_32, NUM_128};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qNopeFlat = Reshape(qNopePart, {tokens, numHeads, qkNopeHeadDim});

    vecTile = {NUM_2, NUM_32, qkNopeHeadDim};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qNopeHeadMajor = Transpose(qNopeFlat, {0, 1});

    // Round min(NUM_32, tokens) up to a multiple of NUM_16 for the cube tile.
    const int align = NUM_16;
    const int alignedM = (std::min(NUM_32, tokens) + align - 1) / align * align;
    TileShape::Current().SetCubeTile({alignedM, alignedM}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    Tensor qAbsorbed = Matrix::BatchMatmul(inType, qNopeHeadMajor, kvBProjWK);

    vecTile = {NUM_16, NUM_2, kvLoraRank};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qLatentTokenMajor = Transpose(qAbsorbed, {0, 1});
    Tensor qLatent4d = Reshape(qLatentTokenMajor, {batch, seqLen, numHeads, kvLoraRank});

    vecTile = {NUM_2, 1, NUM_32, kvLoraRank};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qLatent = Transpose(qLatent4d, {1, 2});

    vecTile = {NUM_2, NUM_32, 1, NUM_64};
    TileShape::Current().SetVecTile(vecTile);
    Tensor queryStates = Cat({qLatent, qPeRope}, -1);

    return std::make_tuple(queryStates, kvJoint);
}
/**
 * \brief Attention prologue on top of QkvPre2, with optional dequantisation of q.
 *
 * Query side: q (dequantised when \p isQuant) is reshaped to {b, s, numHeads, qHeadDim}. Its leading
 * qkNopeHeadDim slice is batch-multiplied by kvBProjWK and rearranged to {b, numHeads, s, kvLoraRank};
 * its trailing qkRopeHeadDim slice goes through ApplyRotaryPosEmb. The two are concatenated.
 * Key side: the compressed KV is split at kvLoraRank into an RmsNorm-ed slice and a rotary slice; the
 * concatenation is scattered into \p pastKeyStates on axis -2.
 *
 * \param hiddenStates        Layer input; dimensions 0 and 1 are read as b and s.
 * \param positionIds         Forwarded to ApplyRotaryPosEmb.
 * \param cos                 Forwarded to ApplyRotaryPosEmb.
 * \param sin                 Forwarded to ApplyRotaryPosEmb.
 * \param kvLen               Index argument of ScatterUpdate.
 * \param pastKeyStates       Cache tensor passed to ScatterUpdate.
 * \param ropeTileShapeConfig Tiling configuration forwarded to ApplyRotaryPosEmb.
 * \param isQuant             Selects the quantised q path in QkvPre2 and the dequantisation here.
 * \return {query states, updated cache, projected no-position query slice, rotated query slice}.
 */
std::vector<Tensor> DeepseekAttention::MlaPrologFoward(
    Tensor hiddenStates, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen, Tensor pastKeyStates,
    const RoPETileShapeConfig& ropeTileShapeConfig, bool isQuant)
{
    const int batch = hiddenStates.GetShape()[0];
    const int seqLen = hiddenStates.GetShape()[1];
    const int tokens = batch * seqLen;
    const DataType inType = hiddenStates.GetStorage()->Datatype();

    auto projected = QkvPre2(hiddenStates, isQuant);
    Tensor qFlat = projected[0];
    Tensor kvJoint = projected[1];

    if (isQuant) {
        std::vector<int64_t> dequantTile = {std::min(NUM_32, tokens), NUM_64};
        TileShape::Current().SetVecTile(dequantTile);
        auto qFp32 = Cast(qFlat, DataType::DT_FP32);
        auto perTokenScale = projected[2];
        auto qPerToken = Mul(qFp32, perTokenScale);
        auto qPerChannel = Mul(qPerToken, qBProjWScale);
        qFlat = Cast(qPerChannel, inType);
    }

    auto qHeads = Reshape(qFlat, {batch, seqLen, numHeads, qHeadDim});

    // ---- query, no-position slice ----
    Tensor qNopePart = View(qHeads, {batch, seqLen, numHeads, qkNopeHeadDim}, {0, 0, 0, 0});
    std::vector<int64_t> vecTile = {NUM_32, 1, 1, NUM_128};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qNopeFlat = Reshape(qNopePart, {tokens, numHeads, qkNopeHeadDim});

    vecTile = {NUM_2, NUM_32, qkNopeHeadDim};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qNopeHeadMajor = Transpose(qNopeFlat, {0, 1});

    // Round min(NUM_32, tokens) up to a multiple of NUM_16 for the cube tile.
    const int align = NUM_16;
    const int alignedM = (std::min(NUM_32, tokens) + align - 1) / align * align;
    TileShape::Current().SetCubeTile({alignedM, alignedM}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    Tensor qAbsorbed = Matrix::BatchMatmul(inType, qNopeHeadMajor, kvBProjWK);

    vecTile = {NUM_16, NUM_2, kvLoraRank};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qLatentTokenMajor = Transpose(qAbsorbed, {0, 1});
    Tensor qLatent4d = Reshape(qLatentTokenMajor, {batch, seqLen, numHeads, kvLoraRank});

    vecTile = {NUM_2, 1, NUM_32, kvLoraRank};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qLatent = Transpose(qLatent4d, {1, 2});

    // ---- key, normalised slice ----
    Tensor kvLatent = View(kvJoint, {batch, seqLen, kvLoraRank}, {0, 0, 0});
    vecTile = {NUM_2, 1, NUM_512};
    TileShape::Current().SetVecTile(vecTile);
    Tensor kvLatentNorm = RmsNorm(kvLatent);
    Tensor kLatent = Reshape(kvLatentNorm, {batch, 1, seqLen, kvLoraRank});

    // ---- rotary slices of q and k ----
    Tensor qRopePart = View(qHeads, {batch, seqLen, numHeads, qkRopeHeadDim}, {0, 0, 0, qkNopeHeadDim});
    vecTile = {NUM_2, 1, NUM_32, qkNopeHeadDim};
    TileShape::Current().SetVecTile(vecTile);
    Tensor qRopeT = Transpose(qRopePart, {1, 2});

    Tensor kRopePart = View(kvJoint, {batch, seqLen, qkRopeHeadDim}, {0, 0, kvLoraRank});
    vecTile = {std::min(NUM_32, tokens), 1, NUM_64};
    TileShape::Current().SetVecTile(vecTile);
    Tensor kRope4d = Reshape(kRopePart, {batch, 1, seqLen, qkRopeHeadDim});

    // Output tensors filled by ApplyRotaryPosEmb.
    Tensor qPeRope(qRopeT.GetStorage()->Datatype(), {batch, numHeads, seqLen, qkRopeHeadDim}, "qPeRope");
    Tensor kPeRope(kRope4d.GetStorage()->Datatype(), {batch, 1, seqLen, qkRopeHeadDim}, "kPeRope");
    ApplyRotaryPosEmb(qRopeT, kRope4d, cos, sin, positionIds, qPeRope, kPeRope, 1, ropeTileShapeConfig);

    // ---- assemble outputs ----
    vecTile = {NUM_2, NUM_32, 1, NUM_64};
    TileShape::Current().SetVecTile(vecTile);
    Tensor queryStates = Cat({qLatent, qPeRope}, -1);

    vecTile = {1, 1, 1, NUM_64};
    TileShape::Current().SetVecTile(vecTile);
    Tensor keyStates = Cat({kLatent, kPeRope}, -1);

    vecTile = {1, 1, NUM_256, NUM_64};
    TileShape::Current().SetVecTile(vecTile);
    pastKeyStates = ScatterUpdate(pastKeyStates, kvLen, keyStates, -2);

    std::vector<Tensor> outputs = {queryStates, pastKeyStates, qLatent, qPeRope};
    return outputs;
}
/**
 * \brief MoE inference: sort tokens by expert id, run the expert on equal slices, restore order, and
 *        reduce over the experts-per-token axis with \p topkWeight.
 *
 * The slice sizes are NOT taken from the routing result: each of the first NUM_8 slots gets
 * sortedTokens.shape[0] / NUM_8 rows and every slice is fed to the same `expert` member with the same
 * weights. The `tokensPerExpert` tensor is computed but not consulted; it is kept because dropping the
 * Sum call might change the set of ops issued (NOTE(review): confirm whether it can be removed).
 *
 * Changes relative to the previous revision:
 *  - the element count of the gathered output is folded in int64_t rather than through an `int`
 *    accumulator and `int` lambda parameters, which truncated the shape values;
 *  - the unused copy of topkWeight's shape (`wShapes`) is gone;
 *  - begin()/end() of the fold now come from a single GetShape() result.
 * Diagnostic output on std::cout is unchanged.
 *
 * \param x               Token activations indexed by the sorted token ids.
 * \param topkIds         Expert ids; dimensions 0 and 1 are read as bs and expertPerTok.
 * \param topkWeight      Weights multiplied into the per-expert outputs before the reduction.
 * \param ffnWeight1      Forwarded to expert.Forward.
 * \param ffnWeight2      Forwarded to expert.Forward.
 * \param ffnWeight3      Forwarded to expert.Forward.
 * \param nRoutedExperts  Second dimension of the count tensor.
 * \return Tensor reshaped to {bs, elements / (bs * expertPerTok)}.
 */
Tensor DeepseekV2MoE::MoeInfer(
    Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
    int nRoutedExperts)
{
    int bs = topkIds.GetShape(0);
    int expertPerTok = topkIds.GetShape(1);

    std::vector<int64_t> zerosShape(NUM_2);
    zerosShape[0] = bs;
    zerosShape[1] = nRoutedExperts;
    Tensor randoms(topkIds.GetStorage()->Datatype(), zerosShape);
    // Multiplying by F_0 and scattering F_1 at topkIds builds the per-expert count tensor.
    Tensor cnts = Mul(randoms, Element(DataType::DT_FP32, F_0));
    cnts = Scatter(cnts, topkIds, Element(DataType::DT_FP32, F_1), 1);
    Tensor tokensPerExpert = Sum(cnts, 0, true);
    (void)tokensPerExpert;

    TileShape::Current().SetVecTile(NUM_128);
    Tensor idxs =
        ArgSort(Cast(Reshape(topkIds, {bs * expertPerTok}), DataType::DT_FP32), -1, false);

    TileShape::Current().SetVecTile({NUM_128, NUM_128});
    // idxs / expertPerTok (truncated) maps each sorted slot back to its source token row in x.
    Tensor sortedTokens = TensorIndex(
        x, Cast(
               Div(Cast(idxs, DataType::DT_FP32), Element(DataType::DT_FP32, static_cast<double>(expertPerTok))),
               DataType::DT_INT32, CAST_TRUNC));
    auto& sortedTokensShape = sortedTokens.GetShape();

    // Fixed split: the first NUM_8 slots share the rows equally, the remaining slots stay empty.
    std::vector<int> tokensPerExpertCpu(NUM_256, 0);
    std::fill_n(tokensPerExpertCpu.begin(), NUM_8, static_cast<int>(sortedTokensShape[0] / NUM_8));

    std::vector<Tensor> outputs;
    int startIdx = 0;
    for (const int numTokens : tokensPerExpertCpu) {
        if (numTokens == 0) {
            continue;
        }
        Tensor tokensForThisExpert = View(sortedTokens, {numTokens, sortedTokensShape[1]}, {startIdx, 0});
        std::cout << "=numTokens====" << numTokens << std::endl;
        for (auto n : tokensForThisExpert.GetShape()) {
            std::cout << "=tokensForThisExpert.GetShape()" << n << std::endl;
        }
        auto expertOut = expert.Forward(tokensForThisExpert, ffnWeight1, ffnWeight2, ffnWeight3);
        outputs.emplace_back(expertOut);
        startIdx += numTokens;
    }

    Tensor outs = Cat(outputs, 0);
    Tensor newX(outs.GetDataType(), outs.GetShape());
    for (auto n : outs.GetShape()) {
        std::cout << "=outs.GetShape()" << n << std::endl;
    }

    TileShape::Current().SetVecTile(NUM_16);
    auto newIdxs = Reshape(idxs, {idxs.GetShape(0)});
    // Write the expert outputs back to the positions given by the sort indices.
    IndexPut_(newX, {newIdxs}, outs);

    const auto& newXDims = newX.GetShape();
    const int64_t newXSize = std::accumulate(
        newXDims.begin(), newXDims.end(), int64_t{1}, [](int64_t acc, int64_t dim) { return acc * dim; });
    std::cout << "===newXSize" << newXSize << std::endl;

    // Elements per (token, expert) pair; bs * expertPerTok must be non-zero here.
    const int featureDim = static_cast<int>(newXSize / (static_cast<int64_t>(bs) * expertPerTok));
    std::vector<int64_t> newShape = {bs, expertPerTok, featureDim};
    auto newXShape = Reshape(newX, newShape);

    TileShape::Current().SetVecTile(NUM_16, NUM_128, NUM_128);
    auto newW = Unsqueeze(topkWeight, NUM_2);
    auto newMul = Mul(newXShape, newW);
    auto reduceRes = Sum(newMul, 1, true);
    for (auto n : reduceRes.GetShape()) {
        std::cout << "=reduceRes.GetShape().shape" << n << std::endl;
    }
    auto fOut = Reshape(reduceRes, {bs, featureDim});
    return fOut;
}
/**
 * \brief Run the expert MLP on a single slice of the expert-sorted tokens.
 *
 * Tokens are gathered in ArgSort order of \p topkIds. Only slot 0 of the split table is populated, with
 * sortedTokens.shape[0] / NUM_8 rows, so at most one slice starting at row 0 reaches expert.Forward.
 * The `tokensPerExpert` tensor is computed but not consulted.
 *
 * \param x               Token activations indexed by the sorted token ids.
 * \param topkIds         Expert ids; dimensions 0 and 1 are read as bs and expertPerTok.
 * \param topkWeight      Unused.
 * \param ffnWeight1      Forwarded to expert.Forward.
 * \param ffnWeight2      Forwarded to expert.Forward.
 * \param ffnWeight3      Forwarded to expert.Forward.
 * \param nRoutedExperts  Second dimension of the count tensor.
 * \return Concatenation (axis 0) of the expert outputs.
 */
Tensor DeepseekV2MoE::MoeInferSingleMlp(
    Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
    int nRoutedExperts)
{
    (void)topkWeight;
    int tokenCount = topkIds.GetShape(0);
    int expertsPerToken = topkIds.GetShape(1);

    std::vector<int64_t> countShape(NUM_2);
    countShape[0] = tokenCount;
    countShape[1] = nRoutedExperts;
    Tensor randoms(topkIds.GetStorage()->Datatype(), countShape);
    Tensor counts = Mul(randoms, Element(DataType::DT_FP32, F_0));
    counts = Scatter(counts, topkIds, Element(DataType::DT_FP32, F_1), 1);
    Tensor tokensPerExpert = Sum(counts, 0, true);

    TileShape::Current().SetVecTile(NUM_128);
    Tensor sortOrder =
        ArgSort(Cast(Reshape(topkIds, {tokenCount * expertsPerToken}), DataType::DT_FP32), -1, false);

    TileShape::Current().SetVecTile({NUM_128, NUM_128});
    Tensor sortedTokens = TensorIndex(
        x,
        Cast(
            Div(Cast(sortOrder, DataType::DT_FP32),
                Element(DataType::DT_FP32, static_cast<double>(expertsPerToken))),
            DataType::DT_INT32, CAST_TRUNC));
    auto& sortedDims = sortedTokens.GetShape();

    // Split table: only the first slot receives rows.
    std::vector<int> rowsPerSlot(NUM_256, 0);
    rowsPerSlot[0] = sortedDims[0] / NUM_8;

    std::vector<Tensor> expertOutputs;
    int rowBegin = 0;
    for (const int rows : rowsPerSlot) {
        if (rows != 0) {
            Tensor slice = View(sortedTokens, {rows, sortedDims[1]}, {rowBegin, 0});
            std::cout << "=numTokens====" << rows << std::endl;
            for (auto dim : slice.GetShape()) {
                std::cout << "=tokensForThisExpert.GetShape().shape" << dim << std::endl;
            }
            auto sliceOut = expert.Forward(slice, ffnWeight1, ffnWeight2, ffnWeight3);
            expertOutputs.emplace_back(sliceOut);
            rowBegin += rows;
        }
    }
    return Cat(expertOutputs, 0);
}
/**
 * \brief Quantised counterpart of MoeInferSingleMlp: one slice of the sorted tokens goes through
 *        expert.ForwardWithQuant.
 *
 * Tokens are gathered in ArgSort order of \p topkIds. Only slot 0 of the split table is populated, with
 * sortedTokens.shape[0] / NUM_8 rows. The `tokensPerExpert` tensor is computed but not consulted.
 *
 * \param x               Token activations indexed by the sorted token ids.
 * \param topkIds         Expert ids; dimensions 0 and 1 are read as bs and expertPerTok.
 * \param topkWeight      Unused.
 * \param ffnWeight1      Forwarded to expert.ForwardWithQuant.
 * \param ffnWeight2      Forwarded to expert.ForwardWithQuant.
 * \param ffnWeight3      Forwarded to expert.ForwardWithQuant.
 * \param ffnwight1Scale  Forwarded to expert.ForwardWithQuant.
 * \param ffnwight2Scale  Forwarded to expert.ForwardWithQuant.
 * \param ffnwight3Scale  Forwarded to expert.ForwardWithQuant.
 * \param nRoutedExperts  Second dimension of the count tensor.
 * \return Concatenation (axis 0) of the expert outputs.
 */
Tensor DeepseekV2MoE::MoeInferSingleMlpQuant(
    Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
    Tensor ffnwight1Scale, Tensor ffnwight2Scale, Tensor ffnwight3Scale, int nRoutedExperts)
{
    (void)topkWeight;
    int tokenCount = topkIds.GetShape(0);
    int expertsPerToken = topkIds.GetShape(1);

    std::vector<int64_t> countShape(NUM_2);
    countShape[0] = tokenCount;
    countShape[1] = nRoutedExperts;
    Tensor randoms(topkIds.GetStorage()->Datatype(), countShape);
    Tensor counts = Mul(randoms, Element(DataType::DT_FP32, F_0));
    counts = Scatter(counts, topkIds, Element(DataType::DT_FP32, F_1), 1);
    Tensor tokensPerExpert = Sum(counts, 0, true);

    TileShape::Current().SetVecTile(NUM_128);
    Tensor sortOrder =
        ArgSort(Cast(Reshape(topkIds, {tokenCount * expertsPerToken}), DataType::DT_FP32), -1, false);

    TileShape::Current().SetVecTile({NUM_32, NUM_512});
    Tensor sortedTokens = TensorIndex(
        x,
        Cast(
            Div(Cast(sortOrder, DataType::DT_FP32),
                Element(DataType::DT_FP32, static_cast<double>(expertsPerToken))),
            DataType::DT_INT32, CAST_TRUNC));
    TileShape::Current().SetVecTile({NUM_256, NUM_256});
    auto& sortedDims = sortedTokens.GetShape();

    // Split table: only the first slot receives rows.
    std::vector<int> rowsPerSlot(NUM_256, 0);
    rowsPerSlot[0] = sortedDims[0] / NUM_8;

    std::vector<Tensor> expertOutputs;
    int rowBegin = 0;
    for (const int rows : rowsPerSlot) {
        if (rows != 0) {
            Tensor slice = View(sortedTokens, {rows, sortedDims[1]}, {rowBegin, 0});
            std::cout << "=numTokens====" << rows << std::endl;
            for (auto dim : slice.GetShape()) {
                std::cout << "=tokensForThisExpert.GetShape().shape" << dim << std::endl;
            }
            auto sliceOut = expert.ForwardWithQuant(
                slice, ffnWeight1, ffnWeight2, ffnWeight3, ffnwight1Scale, ffnwight2Scale, ffnwight3Scale);
            expertOutputs.emplace_back(sliceOut);
            rowBegin += rows;
        }
    }
    return Cat(expertOutputs, 0);
}
/**
 * \brief MoE inference that also exposes its intermediate tensors through out-parameters.
 *
 * Tokens are sorted by expert id, the expert runs on equal slices, the outputs are scattered back to the
 * original order and reduced over the experts-per-token axis with \p topkWeight.
 *
 * The slice sizes are NOT taken from the routing result: each of the first NUM_8 slots gets
 * sortedTokens.shape[0] / NUM_8 rows and every slice is fed to the same `expert` member with the same
 * weights. The `tokensPerExpert` tensor is computed but not consulted; it is kept because dropping the
 * Sum call might change the set of ops issued (NOTE(review): confirm whether it can be removed).
 *
 * Changes relative to the previous revision:
 *  - the element count of the gathered output is folded in int64_t rather than through an `int`
 *    accumulator and `int` lambda parameters, which truncated the shape values;
 *  - the unused copy of topkWeight's shape (`wShape`) is gone;
 *  - begin()/end() of the fold now come from a single GetShape() result.
 * Diagnostic output on std::cout is unchanged.
 *
 * \param x               Token activations indexed by the sorted token ids.
 * \param topkIds         Expert ids; dimensions 0 and 1 are read as bs and expertPerTok.
 * \param topkWeight      Weights multiplied into the per-expert outputs before the reduction.
 * \param ffnWeight1      Forwarded to expert.Forward.
 * \param ffnWeight2      Forwarded to expert.Forward.
 * \param ffnWeight3      Forwarded to expert.Forward.
 * \param[out] idxs         Receives the ArgSort result over the flattened \p topkIds.
 * \param[out] sortedTokens Receives the rows of \p x gathered in sorted order.
 * \param[out] outs         Receives the concatenated expert outputs (still in sorted order).
 * \param nRoutedExperts  Second dimension of the count tensor.
 * \return Tensor reshaped to {bs, elements / (bs * expertPerTok)}.
 */
Tensor DeepseekV2MoE::MoeInfer(
    Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3, Tensor& idxs,
    Tensor& sortedTokens, Tensor& outs, int nRoutedExperts)
{
    int bs = topkIds.GetShape(0);
    int expertPerTok = topkIds.GetShape(1);

    std::vector<int64_t> zerosShape(NUM_2);
    zerosShape[0] = bs;
    zerosShape[1] = nRoutedExperts;
    Tensor randoms(topkIds.GetStorage()->Datatype(), zerosShape);
    // Multiplying by F_0 and scattering F_1 at topkIds builds the per-expert count tensor.
    Tensor cnts = Mul(randoms, Element(DataType::DT_FP32, F_0));
    cnts = Scatter(cnts, topkIds, Element(DataType::DT_FP32, F_1), 1);
    Tensor tokensPerExpert = Sum(cnts, 0, true);
    (void)tokensPerExpert;

    TileShape::Current().SetVecTile(NUM_128);
    idxs = ArgSort(Cast(Reshape(topkIds, {bs * expertPerTok}), DataType::DT_FP32), -1, false);

    TileShape::Current().SetVecTile({NUM_64, NUM_64});
    // idxs / expertPerTok (truncated) maps each sorted slot back to its source token row in x.
    sortedTokens = TensorIndex(
        x, Cast(
               Div(Cast(idxs, DataType::DT_FP32), Element(DataType::DT_FP32, static_cast<double>(expertPerTok))),
               DataType::DT_INT32, CAST_TRUNC));
    auto& sortedTokensShape = sortedTokens.GetShape();

    // Fixed split: the first NUM_8 slots share the rows equally, the remaining slots stay empty.
    std::vector<int> tokensPerExpertCpu(NUM_256, 0);
    std::fill_n(tokensPerExpertCpu.begin(), NUM_8, static_cast<int>(sortedTokensShape[0] / NUM_8));

    std::vector<Tensor> outputs;
    int startIdx = 0;
    for (const int numTokens : tokensPerExpertCpu) {
        if (numTokens == 0) {
            continue;
        }
        Tensor tokensForThisExpert = View(sortedTokens, {numTokens, sortedTokensShape[1]}, {startIdx, 0});
        std::cout << "=numTokens====" << numTokens << std::endl;
        for (auto n : tokensForThisExpert.GetShape()) {
            std::cout << "=tokensForThisExpert.GetShape().shape" << n << std::endl;
        }
        auto expertOut = expert.Forward(tokensForThisExpert, ffnWeight1, ffnWeight2, ffnWeight3);
        outputs.emplace_back(expertOut);
        startIdx += numTokens;
    }

    outs = Cat(outputs, 0);
    Tensor newX(outs.GetDataType(), outs.GetShape());
    for (auto n : outs.GetShape()) {
        std::cout << "=outs.GetShape().shape" << n << std::endl;
    }

    TileShape::Current().SetVecTile({NUM_128, NUM_128});
    auto newIdxs = Reshape(idxs, {idxs.GetShape(0)});
    TileShape::Current().SetVecTile(NUM_16);
    // Write the expert outputs back to the positions given by the sort indices.
    IndexPut_(newX, {newIdxs}, outs);

    const auto& newXDims = newX.GetShape();
    const int64_t newXSize = std::accumulate(
        newXDims.begin(), newXDims.end(), int64_t{1}, [](int64_t acc, int64_t dim) { return acc * dim; });
    std::cout << "===newXSize" << newXSize << std::endl;

    // Elements per (token, expert) pair; bs * expertPerTok must be non-zero here.
    const int featureDim = static_cast<int>(newXSize / (static_cast<int64_t>(bs) * expertPerTok));
    std::vector<int64_t> newShape = {bs, expertPerTok, featureDim};
    auto newXShape = Reshape(newX, newShape);

    TileShape::Current().SetVecTile(NUM_16, NUM_64, NUM_64);
    auto newW = Unsqueeze(topkWeight, NUM_2);
    auto newMul = Mul(newXShape, newW);
    auto reduceRes = Sum(newMul, 1, true);
    for (auto n : reduceRes.GetShape()) {
        std::cout << "=reduceRes.GetShape().shape" << n << std::endl;
    }
    auto fOut = Reshape(reduceRes, {bs, featureDim});
    return fOut;
}
/**
 * \brief Runs the routed-expert path with the expert's own weights and combines the expert
 *        outputs with the top-k routing weights.
 *
 * \param x              Input tokens; rows are gathered with TensorIndex using idxs / expertPerTok.
 * \param topkIds        Expert ids; dimension 0 is read as bs, dimension 1 as expertPerTok.
 * \param topkWeight     Routing weights; unsqueezed at axis 2 and multiplied into the outputs.
 * \param nRoutedExperts Size of dimension 1 of the count tensor built from topkIds.
 * \return The weighted sum over axis 1 (computed with keep-dim argument true), cast back to the
 *         data type of the scattered expert output.
 */
Tensor DeepseekV2MoE::MoeInfer(Tensor x, Tensor topkIds, Tensor topkWeight, int nRoutedExperts)
{
    int bs = topkIds.GetShape(0);
    int expertPerTok = topkIds.GetShape(1);

    // Count tensor of shape {bs, nRoutedExperts}: a freshly constructed tensor is multiplied by
    // F_0, then F_1 is scattered at the positions given by topkIds along axis 1.
    std::vector<int64_t> zerosShape = {bs, nRoutedExperts};
    Tensor randoms(topkIds.GetStorage()->Datatype(), zerosShape);
    Tensor cnts = Mul(randoms, Element(DataType::DT_FP32, F_0));
    cnts = Scatter(cnts, topkIds, Element(DataType::DT_FP32, F_1), 1);
    // NOTE(review): tokensPerExpert is never read below; the per-expert counts are hard-coded
    // instead. The call is kept because it may register an operation - confirm before removing.
    Tensor tokensPerExpert = Sum(cnts, 0, true);

    // Sort the flattened expert ids and gather the matching rows of x.
    Tensor idxs = ArgSort(Cast(Reshape(topkIds, {bs * expertPerTok}), DataType::DT_FP32), -1);
    Tensor sortedTokens = TensorIndex(
        x, Cast(
               Div(Cast(idxs, DataType::DT_FP32), Element(DataType::DT_FP32, static_cast<double>(expertPerTok))),
               DataType::DT_INT32, CAST_TRUNC));
    auto& sortedTokensShape = sortedTokens.GetShape();

    // The first NUM_8 slots each receive an equal share of the sorted rows; the remaining slots
    // stay zero and are skipped in the loop below.
    std::vector<int> tokensPerExpertCpu(NUM_256, 0);
    for (int i = 0; i < NUM_8; i++) {
        tokensPerExpertCpu[i] = sortedTokensShape[0] / NUM_8;
    }

    // Run the expert on each consecutive slice of sortedTokens.
    std::vector<Tensor> outputs;
    int startIdx = 0;
    for (int numTokens : tokensPerExpertCpu) {
        if (numTokens == 0) {
            continue;
        }
        Tensor tokensForThisExpert = View(sortedTokens, {numTokens, sortedTokensShape[1]}, {startIdx, 0});
        auto expertOut = expert.Forward(tokensForThisExpert);
        outputs.emplace_back(expertOut);
        startIdx += numTokens;
    }
    auto outs = Cat(outputs, 0);

    // Write the expert outputs into newX at the positions given by idxs.
    Tensor newX(outs.GetDataType(), outs.GetShape());
    auto newIdxs = Reshape(idxs, {idxs.GetShape(0)});
    TileShape::Current().SetVecTile(NUM_8);
    IndexPut_(newX, {newIdxs}, outs);

    // Total element count of newX, used to derive the trailing dimension of the 3-D view.
    int newXSize = std::accumulate(
        newX.GetShape().begin(), newX.GetShape().end(), 1, [](const int& a, const int& b) { return a * b; });
    std::vector<int64_t> newShape = {bs, expertPerTok, newXSize / (bs * expertPerTok)};
    auto newXShape = Reshape(newX, newShape);

    // Multiply by the routing weights in topkWeight's data type, sum over axis 1, cast back.
    TileShape::Current().SetVecTile(NUM_128, NUM_64, NUM_64);
    auto newl = Cast(newXShape, topkWeight.GetDataType());
    auto newW = Unsqueeze(topkWeight, 2);
    auto newMul = Mul(newl, newW);
    auto fOut = Cast(Sum(newMul, 1, true), newX.GetDataType());
    TileShape::Current().SetVecTile(NUM_128, NUM_64);
    return fOut;
}
/**
 * \brief MoE layer forward: adds the routed-expert result to the shared-expert result.
 *
 * \param hiddenStates Input tensor; its shape is restored on the routed result before the add.
 * \return Add(routed result reshaped to the input shape, sharedExpert.Forward(input)).
 */
Tensor DeepseekV2MoE::Forward(Tensor hiddenStates)
{
    const Tensor sharedInput = hiddenStates;
    const std::vector<int64_t>& inputShape = hiddenStates.GetShape();

    // The gate yields (top-k expert ids, top-k weights).
    auto gateOutput = moeGate.Forward(hiddenStates);
    const auto& [expertIds, expertWeights] = gateOutput;

    // Routed experts, brought back to the input shape.
    Tensor routed = MoeInfer(hiddenStates, expertIds, expertWeights);
    routed = Reshape(routed, inputShape);

    // Shared expert runs on the same input.
    const Tensor& shared = sharedExpert.Forward(sharedInput);
    return Add(routed, shared);
}
std::tuple<Tensor, Tensor> MoEGate::Forward(const Tensor& hiddenStates)
{
int bs = hiddenStates.GetShape()[0];
auto logits = Matrix::Matmul(
DataType::DT_FP32, hiddenStates, weight, false, true);
auto scores = Sigmoid(logits);
auto scoresForChoice = Add(scores, eScoreCorrectionBias);
std::vector<int64_t> shape = {
scoresForChoice.GetShape()[0] * nGroup, scoresForChoice.GetShape()[1] / nGroup};
auto scoresForChoiceNewShape = Reshape(scoresForChoice, shape);
auto scoresForChoiceIndex = std::get<0>(TopK(scoresForChoiceNewShape, 2, -1));
auto groupScores = Sum(scoresForChoiceIndex, 1, true);
auto groupScoresReshape = Reshape(groupScores, {groupScores.GetShape()[0] / nGroup, nGroup});
auto groupIdx = std::get<1>(TopK(groupScoresReshape, topkGroup, 1));
auto groupMask = Mul(groupScoresReshape, Element(DataType::DT_FP32, F_0));
auto groupMaskScatter = Scatter(groupMask, groupIdx, Element(DataType::DT_FP32, F_1), 1);
int dim0 = groupMaskScatter.GetShape()[0] * groupMaskScatter.GetShape()[1];
auto scoreMask =
Expand(Reshape(groupMaskScatter, {dim0, 1}), {dim0, nRoutedExperts / nGroup});
scoreMask = Reshape(scoreMask, {bs, nRoutedExperts});
auto scoreMaskNot = Mul(scoreMask, Element(DataType::DT_FP32, F_NEGA_1));
auto tmpScores = Mul(scoresForChoice, scoreMaskNot);
auto topkIdx = std::get<1>(TopK(tmpScores, numExpertsPerTok, -1));
auto topkWeight = GatherElements(scores, topkIdx, 1);
auto topkWeightSum = Sum(topkWeight, 1, true);
auto denominator = Add(topkWeightSum, Element(DataType::DT_FP32, DF_1E_20));
topkWeight = Div(topkWeight, denominator);
return std::make_tuple(topkIdx, topkWeight);
}
/**
 * \brief Gated MLP forward using the layer's own projection weights.
 *
 * Computes down(gate / (exp(gate * F_NEGA_1) + F_1) * up), where gate, up and down are FP32
 * matmuls with gateProjW, upProjW and downProjW.
 *
 * \param x Input tensor. When it has more than NUM_2 dimensions, all leading dimensions are
 *          folded into one before the matmuls and the original shape is restored on the result.
 * \return The down-projection result, reshaped to the input shape if the input was flattened.
 */
Tensor DeepseekV2MLP::Forward(Tensor x)
{
    // Snapshot the shape by value: x is reassigned below, so a reference obtained from
    // x.GetShape() would depend on the old tensor's storage staying alive and unchanged.
    const auto inputShape = x.GetShape();
    const bool needsFlatten = inputShape.size() > NUM_2;

    if (needsFlatten) {
        // 64-bit accumulator so the product of the leading dimensions is not truncated to int.
        const int64_t leadingSize = std::accumulate(
            inputShape.begin(), inputShape.end() - 1, static_cast<int64_t>(1),
            [](int64_t a, int64_t b) { return a * b; });
        const std::vector<int64_t> flatShape = {leadingSize, inputShape[inputShape.size() - 1]};
        x = Reshape(x, flatShape);
    }

    const Tensor& gateProj = Matrix::Matmul(DataType::DT_FP32, x, gateProjW, false, false);
    // gate / (exp(gate * F_NEGA_1) + F_1)
    const Tensor& gateSilu =
        Div(gateProj, Add(Exp(Mul(gateProj, Element(DataType::DT_FP32, F_NEGA_1))), Element(DataType::DT_FP32, F_1)));
    const Tensor& upProj = Matrix::Matmul(DataType::DT_FP32, x, upProjW, false, false);
    const Tensor& mul = Mul(gateSilu, upProj);
    Tensor downProj = Matrix::Matmul(DataType::DT_FP32, mul, downProjW, false, false);

    if (needsFlatten) {
        downProj = Reshape(downProj, inputShape);
    }
    return downProj;
}
/**
 * \brief Gated MLP forward using caller-supplied weights, with FP16 matmul inputs.
 *
 * \param x          Input tensor; cast to DT_FP16 before the first two matmuls.
 * \param ffnWeight1 Weight of the gate matmul.
 * \param ffnWeight2 Weight of the up matmul.
 * \param ffnWeight3 Weight of the final matmul (last Matmul argument is true).
 * \return FP32 result of the final matmul.
 */
Tensor DeepseekV2MLP::Forward(Tensor x, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3)
{
    auto xFp16 = Cast(x, DataType::DT_FP16);

    // Gate branch: gate / (exp(gate * F_NEGA_1) + F_1).
    auto gateOut = Matrix::Matmul(DataType::DT_FP32, xFp16, ffnWeight1, false, false);
    auto activated = Mul(gateOut, Element(DataType::DT_FP32, F_NEGA_1));
    activated = Exp(activated);
    activated = Add(activated, Element(DataType::DT_FP32, F_1));
    activated = Div(gateOut, activated);

    // Up branch, multiplied element-wise into the activated gate.
    auto upOut = Matrix::Matmul(DataType::DT_FP32, xFp16, ffnWeight2, false, false);
    activated = Mul(activated, upOut);

    // Final projection takes an FP16 input again.
    auto activatedFp16 = Cast(activated, DataType::DT_FP16);
    Tensor projected = Matrix::Matmul(DataType::DT_FP32, activatedFp16, ffnWeight3, false, true);
    return projected;
}
/**
 * \brief Gated MLP forward on quantized activations with INT32 matmuls and explicit dequantization.
 *
 * Each matmul result is cast to FP32, multiplied by the activation scale returned by Quant and
 * then by the matching weight scale.
 *
 * \param x              Input tensor; quantized with Quant before the gate and up matmuls.
 * \param ffnWeight1     Weight of the gate matmul.
 * \param ffnWeight2     Weight of the up matmul.
 * \param ffnWeight3     Weight of the final matmul (last Matmul argument is true).
 * \param ffnwight1Scale Scale multiplied into the gate result.
 * \param ffnwight2Scale Scale multiplied into the up result.
 * \param ffnwight3Scale Scale multiplied into the final result after Transpose(..., {0, 1}).
 * \return The dequantized result of the final matmul.
 */
Tensor DeepseekV2MLP::ForwardWithQuant(
    Tensor x, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3, Tensor ffnwight1Scale, Tensor ffnwight2Scale,
    Tensor ffnwight3Scale)
{
    // Quantize the input under its own vector tile, then switch tiles for the rest.
    TileShape::Current().SetVecTile({NUM_32, NUM_512});
    auto inputQuant = Quant(x);
    TileShape::Current().SetVecTile({NUM_256, NUM_256});
    Tensor xQuant = std::get<0>(inputQuant);
    Tensor xScale = std::get<1>(inputQuant);
    TileShape::Current().SetCubeTile({NUM_64, NUM_64}, {NUM_128, NUM_128}, {NUM_128, NUM_128});

    // Gate branch: INT32 matmul -> FP32 -> activation scale -> weight scale.
    auto gateAcc = Matrix::Matmul(DataType::DT_INT32, xQuant, ffnWeight1, false, false);
    auto gateFp32 = Cast(gateAcc, DataType::DT_FP32);
    auto gatePerToken = Mul(gateFp32, xScale);
    auto gateOut = Mul(gatePerToken, ffnwight1Scale);

    // gate / (exp(gate * F_NEGA_1) + F_1)
    auto activated = Mul(gateOut, Element(DataType::DT_FP32, F_NEGA_1));
    activated = Exp(activated);
    activated = Add(activated, Element(DataType::DT_FP32, F_1));
    activated = Div(gateOut, activated);

    // Up branch, dequantized the same way and multiplied into the activated gate.
    auto upAcc = Matrix::Matmul(DataType::DT_INT32, xQuant, ffnWeight2, false, false);
    auto upFp32 = Cast(upAcc, DataType::DT_FP32);
    auto upPerToken = Mul(upFp32, xScale);
    auto upOut = Mul(upPerToken, ffnwight2Scale);
    activated = Mul(activated, upOut);

    // Re-quantize the intermediate before the final projection.
    TileShape::Current().SetVecTile({NUM_32, NUM_512});
    auto hiddenQuant = Quant(activated);
    TileShape::Current().SetVecTile({NUM_256, NUM_256});
    Tensor hiddenQ = std::get<0>(hiddenQuant);
    Tensor hiddenScale = std::get<1>(hiddenQuant);

    // Final projection and dequantization.
    Tensor outAcc = Matrix::Matmul(DataType::DT_INT32, hiddenQ, ffnWeight3, false, true);
    auto outFp32 = Cast(outAcc, DataType::DT_FP32);
    auto outPerToken = Mul(outFp32, hiddenScale);
    Tensor weight3ScaleT = Transpose(ffnwight3Scale, {0, 1});
    auto dequantized = Mul(outPerToken, weight3ScaleT);
    return dequantized;
}
}