/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
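/*!
* \file mla_prolog.cpp
* \brief Graph construction for the DeepSeek MLA prolog: QkvPre (query / compressed-KV projections)
*        and MlaProlog (query up-projection, rotary embedding and KV-cache update).
*/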
#include "operator/models/deepseek/mla_prolog.h"
#include "operator/models/deepseek/deepseek_mla.h"
namespace npu::tile_fwk {
std::vector<Tensor> QkvPre(
const Tensor& tokenX, const Tensor& wDq, const Tensor& wUqQr, const Tensor& wDkvKr, const Tensor& gammaCq,
float epsilonCq, MlaQuantInputs quantInputs, bool splitReduceLastDim, bool splitK)
{
Tensor dequantScaleWUqQr = quantInputs.dequantScaleWUqQr;
bool isQuant = (dequantScaleWUqQr.GetStorage() != nullptr);
Tensor smoothScalesCq = quantInputs.smoothScalesCq;
bool hasSmooth = (smoothScalesCq.GetStorage() != nullptr);
int b = tokenX.GetShape()[0];
int s = tokenX.GetShape()[1];
int h = tokenX.GetShape()[2];
int bs = b * s;
int q_lora_rank = wDq.GetShape()[1];
DataType dType = tokenX.GetStorage()->Datatype();
DataType dTypeQuantOut = isQuant ? DataType::DT_INT32 : dType;
std::vector<Tensor> qkvPreRes;
Tensor input = Reshape(tokenX, {bs, h});
int c0 = NUM_16;
int tieM = (bs + c0 - 1) / c0 * c0;
Tensor qMmRes;
if (splitK) {
TileShape::Current().SetCubeTile({tieM, tieM}, {NUM_256, NUM_256}, {NUM_64, NUM_64}, true);
Tensor qMmResF32 = Matrix::Matmul(DT_FP32, input, wDq);
TileShape::Current().SetVecTile(std::min(NUM_32, bs), NUM_128);
qMmRes = Cast(qMmResF32, dType);
} else {
TileShape::Current().SetCubeTile({tieM, tieM}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
qMmRes = Matrix::Matmul(dType, input, wDq);
}
if (splitReduceLastDim) {
TileShape::Current().SetVecTile(std::min(NUM_16, bs), NUM_128);
} else {
TileShape::Current().SetVecTile(std::min(NUM_8, bs), q_lora_rank);
}
Tensor normRes = RmsNorm(qMmRes, gammaCq, epsilonCq);
Tensor normDequantScale;
std::tuple<Tensor, Tensor> normQuantRes;
if (isQuant) {
if (hasSmooth) {
normQuantRes = Quant(normRes, true, true, smoothScalesCq);
} else {
normQuantRes = Quant(normRes);
}
normRes = std::get<0>(normQuantRes);
normDequantScale = std::get<1>(normQuantRes);
TileShape::Current().SetCubeTile({tieM, tieM}, {NUM_256, NUM_256}, {NUM_256, NUM_256});
} else {
TileShape::Current().SetCubeTile({tieM, tieM}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
}
Tensor q = Matrix::Matmul(dTypeQuantOut, normRes, wUqQr, false, false);
qkvPreRes.emplace_back(q);
TileShape::Current().SetCubeTile({tieM, tieM}, {NUM_256, NUM_256}, {NUM_64, NUM_64});
Tensor compressedKv;
if (splitK) {
Tensor kvMmResF32 = Matrix::Matmul(DT_FP32, input, wDkvKr);
TileShape::Current().SetVecTile(std::min(NUM_32, bs), NUM_64);
compressedKv = Cast(kvMmResF32, dType);
} else {
compressedKv = Matrix::Matmul(dType, input, wDkvKr);
}
Tensor compressedKvRes = Reshape(compressedKv, {b, s, wDkvKr.GetShape()[1]});
qkvPreRes.emplace_back(compressedKvRes);
if (isQuant) {
qkvPreRes.emplace_back(normDequantScale);
}
return qkvPreRes;
}
/**
 * \brief Builds the MLA prolog: query projection, rotary embedding and KV-cache update.
 *
 * The function calls QkvPre() to obtain q and the compressed KV tensor, optionally dequantises q,
 * splits q into its "nope" and "rope" parts, multiplies the nope part by wUk, RMS-normalises the
 * compressed KV, applies the rotary embedding and finally scatters the new K/V entries into the caches.
 *
 * \param tokenX              Input tokens; must be rank 3, dimensions 0 and 1 are read as b and s.
 * \param wDq, wUqQr, wDkvKr  Weights forwarded to QkvPre().
 * \param wUk                 Rank-3 weight; its dimensions are read as {n, qkNopeHeadDim, kvLoraRank}.
 * \param gammaCq, epsilonCq  RmsNorm parameters forwarded to QkvPre().
 * \param gammaCkv, epsilonCkv RmsNorm parameters for the compressed KV tensor.
 * \param sin, cos            Rotary tables; sin must be rank 3 and its dimension 2 is read as qkRopeHeadDim.
 * \param cacheIndex          Index tensor passed to ScatterUpdate.
 * \param kvCache, krCache    Rank-4 caches that are updated. In "PA_BSND" mode dimensions 0..2 of kvCache are
 *                            read as {blockNum, blockSize, n2}.
 * \param quantInputs         Optional quantisation tensors; quantisation is active when dequantScaleWUqQr
 *                            has storage.
 * \param ropeConfig          Tile configuration forwarded to ApplyRotaryPosEmbV2.
 * \param queryOut            Receives the up-projected nope query, reshaped to {b, s, n, kvLoraRank}.
 * \param queryRopeOut        Passed to ApplyRotaryPosEmbV2 (presumably as the rotated-query output — confirm).
 * \param kvCacheOut          Receives the ScatterUpdate result for kvCache.
 * \param krCacheOut          Receives the ScatterUpdate result for krCache.
 * \param cacheMode           One of "BNSD", "PA_BSND" or "PA_NZ" (asserted).
 * \param splitReduceLastDim, splitK  Forwarded to QkvPre().
 */
void MlaProlog(
    Tensor tokenX, const Tensor& wDq, const Tensor& wUqQr, const Tensor& wUk, const Tensor& wDkvKr,
    const Tensor& gammaCq, const Tensor& gammaCkv, const Tensor& sin, const Tensor& cos, const Tensor& cacheIndex,
    Tensor& kvCache, Tensor& krCache, MlaQuantInputs quantInputs, const RoPETileShapeConfigNew& ropeConfig,
    Tensor& queryOut, Tensor& queryRopeOut, Tensor& kvCacheOut, Tensor& krCacheOut, float epsilonCq, float epsilonCkv,
    std::string cacheMode, bool splitReduceLastDim, bool splitK)
{
    // Rank preconditions for every shape that is indexed below.
    assert(
        tokenX.GetShape().size() == SHAPE_DIM3 && wUk.GetShape().size() == SHAPE_DIM3 &&
        sin.GetShape().size() == SHAPE_DIM3);
    assert(kvCache.GetShape().size() == SHAPE_DIM4 && krCache.GetShape().size() == SHAPE_DIM4);
    assert(cacheMode == "BNSD" || cacheMode == "PA_BSND" || cacheMode == "PA_NZ");

    // Quantisation is active when the per-channel dequant scale carries storage.
    // (A leftover debug print of this flag to stdout was removed here.)
    Tensor dequantScaleWUqQr = quantInputs.dequantScaleWUqQr;
    const bool isQuant = (dequantScaleWUqQr.GetStorage() != nullptr);

    const DataType dType = tokenX.GetStorage()->Datatype();
    const int b = tokenX.GetShape()[0];
    const int s = tokenX.GetShape()[1];
    const int bs = b * s;
    const int n = wUk.GetShape()[0];
    const int qkNopeHeadDim = wUk.GetShape()[1];
    const int kvLoraRank = wUk.GetShape()[2];
    const int qkRopeHeadDim = sin.GetShape()[2];
    const int qHeadDim = qkNopeHeadDim + qkRopeHeadDim;

    // qKv = {q, compressedKv [, dequant scale when quantised]}.
    auto qKv = QkvPre(tokenX, wDq, wUqQr, wDkvKr, gammaCq, epsilonCq, quantInputs, splitReduceLastDim, splitK);
    Tensor q = qKv[0];
    Tensor kvTmp = qKv[1];

    if (isQuant) {
        // Dequantise q: cast to FP32, multiply by the scale returned from QkvPre and then by
        // dequantScaleWUqQr, and cast back to the input dtype.
        std::vector<int64_t> tileShape = {bs, NUM_64};
        TileShape::Current().SetVecTile(tileShape);
        auto qTmpFp32 = Cast(q, DataType::DT_FP32);
        auto qTmpDequantScale = qKv[2];
        auto qTmpDequantPerToken = Mul(qTmpFp32, qTmpDequantScale);
        auto qTmpDequantChannel = Mul(qTmpDequantPerToken, dequantScaleWUqQr);
        q = Cast(qTmpDequantChannel, dType);
    }

    // View q as {b, s, n, qHeadDim}; the first qkNopeHeadDim entries of the last axis are the nope part.
    auto qTmp = Reshape(q, {b, s, n, qHeadDim});
    std::vector<int64_t> tileShape = {b, 1, 1, NUM_64};
    TileShape::Current().SetVecTile(tileShape);
    Tensor qNope = View(qTmp, {b, s, n, qkNopeHeadDim}, {0, 0, 0, 0});

    tileShape = {b, 1, 1, NUM_128};
    TileShape::Current().SetVecTile(tileShape);
    Tensor qNopeRes = Reshape(qNope, {bs, n, qkNopeHeadDim});

    // Swap the first two axes so n becomes the batch axis of the batched matmul with wUk.
    tileShape = {bs, 1, qkNopeHeadDim};
    TileShape::Current().SetVecTile(tileShape);
    Tensor qNopeTrans = Transpose(qNopeRes, {0, 1});

    // Cube tile M: token count rounded up to the next multiple of 16.
    const int c0 = NUM_16;
    const int m = (bs + c0 - 1) / c0 * c0;
    TileShape::Current().SetCubeTile({m, m}, {NUM_128, NUM_128}, {NUM_128, NUM_128});
    Tensor qNopeNew = Matrix::BatchMatmul(dType, qNopeTrans, wUk);

    // Swap the axes back and expose the result as {b, s, n, kvLoraRank}.
    tileShape = {1, bs, kvLoraRank};
    TileShape::Current().SetVecTile(tileShape);
    Tensor qNopeNewTrans = Transpose(qNopeNew, {0, 1});
    queryOut = Reshape(qNopeNewTrans, {b, s, n, kvLoraRank});

    // The first kvLoraRank entries of the last axis of kvTmp are normalised as the compressed KV.
    Tensor compressedKv = View(kvTmp, {b, s, kvLoraRank}, {0, 0, 0});
    tileShape = {NUM_2, 1, NUM_512};
    TileShape::Current().SetVecTile(tileShape);
    Tensor compressedKvNorm = RmsNorm(compressedKv, gammaCkv, epsilonCkv);

    // Rope parts: the tail of q's last axis, and the entries of kvTmp starting at offset kvLoraRank.
    Tensor qPe = View(qTmp, {b, s, n, qkRopeHeadDim}, {0, 0, 0, qkNopeHeadDim});
    Tensor kPe = View(kvTmp, {b, s, qkRopeHeadDim}, {0, 0, kvLoraRank});
    tileShape = {bs, 1, NUM_64};
    TileShape::Current().SetVecTile(tileShape);
    Tensor kPeRes = Reshape(kPe, {b, s, 1, qkRopeHeadDim});
    Tensor kRope(kPeRes.GetStorage()->Datatype(), {b, s, 1, qkRopeHeadDim}, "kRope");
    ApplyRotaryPosEmbV2(qPe, kPeRes, cos, sin, queryRopeOut, kRope, 2, ropeConfig);

    if (cacheMode == "PA_BSND") {
        // Caches are flattened to 2-D rows, updated, and reshaped back.
        const int blockNum = kvCache.GetShape()[0];
        const int blockSize = kvCache.GetShape()[1];
        const int n2 = kvCache.GetShape()[2];
        Tensor kvCacheRes = Reshape(kvCache, {blockNum * blockSize * n2, kvLoraRank});
        Tensor krCacheRes = Reshape(krCache, {blockNum * blockSize * n2, qkRopeHeadDim});
        Tensor kNope = Reshape(compressedKvNorm, {b * s, kvLoraRank});
        Tensor kRopeRes = Reshape(kRope, {b * s * 1, qkRopeHeadDim});

        tileShape = {1, kvLoraRank};
        TileShape::Current().SetVecTile(tileShape);
        Tensor kvCacheUpdate = ScatterUpdate(kvCacheRes, cacheIndex, kNope, -2, cacheMode);
        kvCacheOut = Reshape(kvCacheUpdate, {blockNum, blockSize, n2, kvLoraRank});

        tileShape = {1, qkRopeHeadDim};
        TileShape::Current().SetVecTile(tileShape);
        Tensor krCacheUpdate = ScatterUpdate(krCacheRes, cacheIndex, kRopeRes, -2, cacheMode);
        krCacheOut = Reshape(krCacheUpdate, {blockNum, blockSize, n2, qkRopeHeadDim});
    } else {
        // NOTE(review): "PA_NZ" passes the assert above but takes this same path as "BNSD",
        // without forwarding cacheMode to ScatterUpdate — confirm this is intended.
        Tensor kNope = Reshape(compressedKvNorm, {b, 1, s, kvLoraRank});
        Tensor kRopeRes = Reshape(kRope, {b, 1, s, qkRopeHeadDim});

        tileShape = {1, 1, 1, kvLoraRank};
        TileShape::Current().SetVecTile(tileShape);
        kvCacheOut = ScatterUpdate(kvCache, cacheIndex, kNope, -2);

        tileShape = {1, 1, 1, qkRopeHeadDim};
        TileShape::Current().SetVecTile(tileShape);
        krCacheOut = ScatterUpdate(krCache, cacheIndex, kRopeRes, -2);
    }
}
}