* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file deepseek_mla.h
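* \brief Declarations for the DeepSeek attention, MLP, MoE gate and MoE layer classes, plus shared constants.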
*/
#pragma once
#ifndef DEEPSEEK_MLA_H
#define DEEPSEEK_MLA_H
#include "interface/inner/pre_def.h"
#include "tilefwk/tilefwk.h"
#include "interface/inner/tilefwk.h"
namespace npu::tile_fwk {
// Dimension index used for scatter-update operations.
// NOTE(review): "UPADATE" is a misspelling of "UPDATE"; the identifier is kept unchanged because
// code outside this header may refer to it.
constexpr int SCATTER_UPADATE_DIM{-2};

// Named integer literals, listed in ascending order.
constexpr int NUM_1{1};
constexpr int NUM_2{2};
constexpr int NUM_3{3};
constexpr int NUM_4{4};
constexpr int NUM_8{8};
constexpr int NUM_16{16};
constexpr int NUM_20{20};
constexpr int NUM_24{24};
constexpr int NUM_32{32};
constexpr int NUM_48{48};
constexpr int NUM_64{64};
constexpr int NUM_128{128};
constexpr int NUM_256{256};
constexpr int NUM_384{384};
constexpr int NUM_512{512};
constexpr int NUM_1024{1024};
constexpr int NUM_1536{1536};
constexpr int NUM_1792{1792};
constexpr int NUM_4096{4096};
constexpr int NUM_6144{6144};
constexpr int NUM_7168{7168};
constexpr int NUM_8192{8192};

// Named floating-point literals.
constexpr float F_1{1.0F};
constexpr float F_0{0.0F};
constexpr float F_NEGA_1{-1.0F};
constexpr double DF_1E_20{1e-20};
/**
 * Default DeepSeek model configuration, keyed by option name.
 *
 * Each value is a variant holding a bool, int, float or std::string. Readers retrieve a value with
 * std::get<T> and must name the alternative that was stored here.
 *
 * String values are wrapped in std::string explicitly. With standard libraries that predate P0608R3,
 * initialising this variant from a bare string literal selects the `bool` alternative (pointer-to-bool is
 * a standard conversion and outranks the user-defined conversion to std::string), so the text would be
 * lost and std::get<std::string> would throw.
 *
 * The variable is `static`, so every translation unit that includes this header owns a separate copy.
 * NOTE(review): "attention_bias" is the only snake_case key; it is kept as-is because lookups elsewhere
 * may use that exact spelling.
 */
static std::map<std::string, std::variant<bool, int, float, std::string>> g_deepseekConfig = {
    {"architectures", std::string("DeepseekForCausalLM")},
    {"attention_bias", false},
    {"attentionDropout", 0},
    {"AutoConfig", std::string("DeepseekConfig")},
    {"AutoModel", std::string("DeepseekModel")},
    {"AutoModelForCausalLM", std::string("DeepseekForCausalLM")},
    {"auxLossAlpha", 0.001f},
    {"bosTokenId", 100000},
    {"eosTokenId", 100001},
    {"epSize", 1},
    {"firstKDenseReplace", 3},
    {"hiddenAct", std::string("silu")},
    {"hiddenSize", 7168},
    {"initializerRange", 0.02f},
    {"intermediateSize", 18432},
    {"kvLoraRank", 512},
    {"lmHead", false},
    {"maxPositionEmbeddings", 4096},
    {"modelType", std::string("deepseek_v3")},
    {"moeIntermediateSize", 2048},
    {"moeLayerFreq", 1},
    {"nGroup", 8},
    {"nRoutedExperts", 256},
    {"nSharedExperts", 1},
    {"normTopkProb", true},
    {"numAttentionHeads", 128},
    {"numExpertsPerTok", 8},
    {"numHiddenLayers", 61},
    {"numKeyValueHeads", 128},
    {"pretrainingTp", 1},
    {"qLoraRank", 1536},
    {"qkNopeHeadDim", 128},
    {"qkRopeHeadDim", 64},
    {"rmHead", false},
    {"rmsNormEps", 1e-06f},
    {"ropeScaling", 1},
    {"ropeTheta", 10000},
    {"routedScalingFactor", 2.5f},
    {"scoringFunc", std::string("sigmoid")},
    {"seqAux", true},
    {"tieWordEmbeddings", false},
    {"topkGroup", 4},
    {"topkMethod", std::string("noaux_tc")},
    {"torchDtype", std::string("bfloat16")},
    {"transformersVersion", std::string("4.33.1")},
    {"useCache", true},
    {"vHeadDim", 128},
    {"vocabSize", 129280},
    {"fp8Format", std::string("e4m3")},
    {"initFp8Params", true}};
/**
 * Tiling parameters for the attention computation.
 *
 * Remains an aggregate, so brace initialisation such as `AttenTilingData{{...}, {...}, 512}` keeps working.
 */
struct AttenTilingData {
    // Integer tiling values; presumably for the batched-matmul stages — confirm at the use sites.
    std::vector<int> bmmVec;
    // Integer tiling values; presumably for the remaining (non-matmul) stages — confirm at the use sites.
    std::vector<int> commonVec;
    // Zero-initialised so a default-constructed object never exposes an indeterminate value,
    // matching the `= 0` convention used by the classes in this header.
    int kvLoraRank = 0;
};
// Bundle of weight tensors passed to the DeepseekAttention constructor.
// DeepseekAttention keeps private Tensor members with exactly these names.
// The member order is part of the aggregate-initialisation interface and must not be changed.
struct AttentionW {
// Presumably the weight of the first ("A") query projection — confirm against the implementation.
Tensor qAProjW;
// Presumably the weight of the second ("B") query projection — confirm against the implementation.
Tensor qBProjW;
// Presumably the scale tensor paired with qBProjW on the quantized path — confirm.
Tensor qBProjWScale;
// Presumably the weight of the first ("A") key/value projection — confirm against the implementation.
Tensor kvAProjWithMqaW;
// Presumably the key part of the second ("B") key/value projection weight — confirm.
Tensor kvBProjWK;
// Presumably the value part of the second ("B") key/value projection weight — confirm.
Tensor kvBProjWV;
// Presumably the weight of the output projection — confirm against the implementation.
Tensor oProjW;
};
// Attention layer of the DeepSeek model.
// This header only declares the interface; every member function is defined outside this header.
class DeepseekAttention {
public:
// Constructs the layer from a configuration map, a bundle of weight tensors and a layer index.
// `config` and `aw` are taken by value.
DeepseekAttention(
std::map<std::string, std::variant<bool, int, float, std::string>> config, AttentionW aw, const int inLayerIdx);
// Takes a query tensor, a key/value tensor and an attention mask and returns one tensor.
Tensor Attention(Tensor q, Tensor kv, Tensor attenMask);
// Post-processing of an attention result; presumably applies the output projection — confirm.
Tensor AttentionPost(Tensor attenRes);
// Second variant of AttentionPost with the same signature; the difference is not visible in this header.
Tensor AttentionPost2(Tensor attenRes);
// Pre-processing of the hidden states; returns two tensors (presumably query and key/value — confirm).
std::tuple<Tensor, Tensor> QkvPre(Tensor hiddenStates);
// Variant of QkvPre with the same signature; the meaning of the "Cv" suffix is not visible in this header.
std::tuple<Tensor, Tensor> QkvPreCv(Tensor hiddenStates);
// Variant of QkvPre returning a vector of tensors; `isQuant` presumably selects a quantized path — confirm.
std::vector<Tensor> QkvPre2(Tensor hiddenStates, bool isQuant = false);
// Variant of QkvPre with the same signature; presumably computes in FP32 — confirm.
std::tuple<Tensor, Tensor> QkvPreFp32(Tensor hiddenStates);
// Full forward entry point returning a single tensor.
Tensor Forward(
Tensor hiddenStates, Tensor attenMask, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen,
Tensor pastKeyStates, const RoPETileShapeConfig& ropeTileShapeConfig);
// Takes the same arguments as Forward but returns two tensors.
// NOTE(review): "Atention" is a misspelling of "Attention"; the name is kept because callers and the
// out-of-line definition use it.
std::tuple<Tensor, Tensor> AtentionPreForward(
Tensor hiddenStates, Tensor attenMask, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen,
Tensor pastKeyStates, const RoPETileShapeConfig& ropeTileShapeConfig);
// Variant of AtentionPreForward with the same signature (same misspelling kept).
std::tuple<Tensor, Tensor> AtentionPreForwardCv(
Tensor hiddenStates, Tensor attenMask, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen,
Tensor pastKeyStates, const RoPETileShapeConfig& ropeTileShapeConfig);
// Takes the hidden states and a second tensor `qPeRope`; returns two tensors.
// `isQuant` presumably selects a quantized path — confirm.
std::tuple<Tensor, Tensor> MlaPrologAbForward(Tensor hiddenStates, Tensor qPeRope, bool isQuant = false);
// Prolog forward entry point returning a vector of tensors.
// NOTE(review): "Foward" is a misspelling of "Forward"; the name is kept because callers and the
// out-of-line definition use it.
std::vector<Tensor> MlaPrologFoward(
Tensor hiddenStates, Tensor positionIds, Tensor cos, Tensor sin, Tensor kvLen, Tensor pastKeyStates,
const RoPETileShapeConfig& ropeTileShapeConfig, bool isQuant = false);
private:
// Scalar settings. Every integer defaults to 0 and isCausal to true; several share their names with keys
// of the configuration map, so they are presumably filled in by the constructor — confirm.
int layerIdx = 0;
int attentionDropout = 0;
int hiddenSize = 0;
int numHeads = 0;
int maxPositionEmbeddings = 0;
int ropeTheta = 0;
int qLoraRank = 0;
int qkRopeHeadDim = 0;
int kvLoraRank = 0;
int vHeadDim = 0;
int qkNopeHeadDim = 0;
int qHeadDim = 0;
bool isCausal = true;
// Weight tensors; the names match the fields of AttentionW one-to-one.
Tensor qAProjW;
Tensor qBProjW;
Tensor qBProjWScale;
Tensor kvAProjWithMqaW;
Tensor kvBProjWK;
Tensor kvBProjWV;
Tensor oProjW;
// Defaults to 0.0f; presumably the scaling factor applied before softmax — confirm.
float softmaxScale = 0.0f;
};
/**
 * Feed-forward (MLP) block holding gate, up and down projection weight tensors.
 *
 * Both constructors allocate the three FP32 weight tensors from the hidden and intermediate sizes.
 */
class DeepseekV2MLP {
public:
    /**
     * Builds the block from a configuration map.
     *
     * Reads the integer entries "hiddenSize" and "intermediateSize" and delegates to the (int, int)
     * constructor so the tensor set-up lives in one place. The map is taken by const reference and read
     * with at(), so it is neither copied nor modified by the lookup.
     *
     * \param config configuration map; must contain both keys as int values.
     * \throws std::out_of_range if a key is missing.
     * \throws std::bad_variant_access if an entry does not hold an int.
     */
    explicit DeepseekV2MLP(const std::map<std::string, std::variant<bool, int, float, std::string>>& config)
        : DeepseekV2MLP(std::get<int>(config.at("hiddenSize")), std::get<int>(config.at("intermediateSize")))
    {}

    /**
     * Builds the block from explicit sizes.
     *
     * \param hs hidden size; first dimension of the gate/up weights, second dimension of the down weight.
     * \param is intermediate size; second dimension of the gate/up weights, first dimension of the down weight.
     */
    DeepseekV2MLP(int hs, int is) : hiddenSize(hs), intermediateSize(is)
    {
        gateProjW = Tensor(DataType::DT_FP32, {hiddenSize, intermediateSize});
        upProjW = Tensor(DataType::DT_FP32, {hiddenSize, intermediateSize});
        downProjW = Tensor(DataType::DT_FP32, {intermediateSize, hiddenSize});
    }

    // Forward pass on `x`; presumably uses the weight tensors owned by this object — confirm.
    Tensor Forward(Tensor x);
    // Forward pass on `x` with three caller-supplied weight tensors.
    Tensor Forward(Tensor x, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3);
    // Forward pass with three caller-supplied weight tensors and one scale tensor per weight;
    // presumably the quantized path — confirm.
    Tensor ForwardWithQuant(
        Tensor x, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3, Tensor ffnwight1Scale, Tensor ffnwight2Scale,
        Tensor ffnwight3Scale);

private:
    int hiddenSize = 0;
    int intermediateSize = 0;
    Tensor gateProjW;  // shape {hiddenSize, intermediateSize}
    Tensor upProjW;    // shape {hiddenSize, intermediateSize}
    Tensor downProjW;  // shape {intermediateSize, hiddenSize}
};
/**
 * Gate of the mixture-of-experts layer: holds the gate weight and a per-expert score-correction bias.
 */
class MoEGate {
public:
    /**
     * Builds the gate from a configuration map.
     *
     * Reads the integer entries "nRoutedExperts", "nGroup", "topkGroup", "numExpertsPerTok" and
     * "hiddenSize". The map is taken by const reference and read with at(), so it is neither copied nor
     * modified by the lookup.
     *
     * \param config configuration map; must contain all five keys as int values.
     * \throws std::out_of_range if a key is missing.
     * \throws std::bad_variant_access if an entry does not hold an int.
     */
    explicit MoEGate(const std::map<std::string, std::variant<bool, int, float, std::string>>& config)
        : nRoutedExperts(std::get<int>(config.at("nRoutedExperts"))),
          nGroup(std::get<int>(config.at("nGroup"))),
          topkGroup(std::get<int>(config.at("topkGroup"))),
          numExpertsPerTok(std::get<int>(config.at("numExpertsPerTok")))
    {
        // hiddenSize is only needed for the weight shape, so it is not stored as a member.
        int hiddenSize = std::get<int>(config.at("hiddenSize"));
        std::vector<int64_t> biasShape = {1, nRoutedExperts};
        weight = Tensor(DataType::DT_FP32, {nRoutedExperts, hiddenSize});
        eScoreCorrectionBias = Tensor(DataType::DT_FP32, biasShape, "eScoreCorrectionBias");
    }

    // Takes the hidden states and returns two tensors; presumably the selected expert indices and their
    // weights — confirm against the implementation.
    std::tuple<Tensor, Tensor> Forward(const Tensor& hiddenStates);

private:
    int nRoutedExperts = 0;
    int nGroup = 0;
    int topkGroup = 0;
    int numExpertsPerTok = 0;
    Tensor weight;                // FP32, shape {nRoutedExperts, hiddenSize}
    Tensor eScoreCorrectionBias;  // FP32, shape {1, nRoutedExperts}
};
/**
 * Mixture-of-experts layer combining a routed expert MLP, a gate and a shared expert MLP.
 */
class DeepseekV2MoE {
public:
    /**
     * Builds the layer from a configuration map.
     *
     * Reads the integer entries "numExpertsPerTok", "nRoutedExperts", "hiddenSize", "moeIntermediateSize"
     * and "nSharedExperts", and passes the map on to the gate. The map is taken by const reference and
     * read with at(), so it is neither copied nor modified by the lookup. Members are initialised in
     * their declaration order.
     *
     * NOTE(review): epSize is fixed to 1 and epRank to 0 here; the "epSize" entry of the configuration
     * map is not consulted — confirm this is intended.
     *
     * \param config configuration map; must contain the keys listed above as int values, plus whatever
     *               the MoEGate constructor requires.
     * \throws std::out_of_range if a key is missing.
     * \throws std::bad_variant_access if an entry does not hold an int.
     */
    explicit DeepseekV2MoE(const std::map<std::string, std::variant<bool, int, float, std::string>>& config)
        : numExpertsPerTok(std::get<int>(config.at("numExpertsPerTok"))),
          epSize(1),
          expertsPerRank(std::get<int>(config.at("nRoutedExperts"))),
          epRank(0),
          expert(std::get<int>(config.at("hiddenSize")), std::get<int>(config.at("moeIntermediateSize"))),
          moeGate(config),
          // The shared expert's intermediate size is the per-expert size times the number of shared experts.
          sharedExpert(
              std::get<int>(config.at("hiddenSize")),
              std::get<int>(config.at("moeIntermediateSize")) * std::get<int>(config.at("nSharedExperts")))
    {}

    // Expert inference with caller-supplied routing results and three weight tensors.
    Tensor MoeInfer(
        Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
        int nRoutedExperts);
    // Same arguments as the overload above; the name suggests a single-MLP path — confirm.
    Tensor MoeInferSingleMlp(
        Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
        int nRoutedExperts);
    // As MoeInferSingleMlp, with one additional scale tensor per weight; presumably the quantized path — confirm.
    Tensor MoeInferSingleMlpQuant(
        Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
        Tensor ffnwight1Scale, Tensor ffnwight2Scale, Tensor ffnwight3Scale, int nRoutedExperts);
    // Overload taking three additional tensors by non-const reference (`idxs`, `sortedTokens`, `outs`);
    // presumably used to return intermediate results — confirm.
    Tensor MoeInfer(
        Tensor x, Tensor topkIds, Tensor topkWeight, Tensor ffnWeight1, Tensor ffnWeight2, Tensor ffnWeight3,
        Tensor& idxs, Tensor& sortedTokens, Tensor& outs, int nRoutedExperts);
    // Overload without explicit weight tensors; nRoutedExperts defaults to 256.
    Tensor MoeInfer(Tensor x, Tensor topkIds, Tensor topkWeight, int nRoutedExperts = 256);
    // Forward pass on the hidden states; returns one tensor.
    Tensor Forward(Tensor hiddenStates);

private:
    int numExpertsPerTok = 0;
    int epSize = 0;
    int expertsPerRank = 0;
    int epRank = 0;
    DeepseekV2MLP expert;
    MoEGate moeGate;
    DeepseekV2MLP sharedExpert;
};
}
#endif