"""DeepseekV3 models' APIs."""
import math
from enum import Enum
from typing import Tuple, Optional, Dict
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn, mint, ops, Parameter, mutable
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.common.initializer import Zero
from mindspore.communication import get_group_size, get_rank
try:
from mindspore.ops.auto_generate import (FusedAddTopKDiv,
MoeDistributeDispatch,
MoeDistributeCombine)
MOE_FUSED_OP_VALID = True
except ImportError:
MOE_FUSED_OP_VALID = False
try:
from mindspore._checkparam import Validator
except ImportError:
import mindspore._checkparam as Validator
from research.deepseek3.deepseek3_config import DeepseekV3Config
from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP
from research.deepseek3.utils import convert_model_config
from research.deepseek3.infer.norm import RMSNorm
from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding
from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import lazy_inline, check_fine_grain_interleave_valid, predict_lazy_inline,\
jit
from mindformers.parallel_core.training_graph.transformer.utils import LayerSetting
from mindformers.modules.layers import Linear, FreqsMgr, _check_input_dtype, _yarn_get_mscale
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.modules.transformer.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer import TransformerOpParallelConfig
from mindformers.modules.infer_attention import InferRotaryEmbedding, FlashAttention
from mindformers.tools.logger import logger
from mindformers.tools.utils import get_predict_run_mode
from mindformers.parallel_core.inference.parallel_state import (get_moe_expert_parallel_group,
get_tensor_model_parallel_group, get_world_group,
initialize_model_parallel, is_initialized)
from mindformers.parallel_core.inference.utils import get_tp_world_size, get_moe_tp_world_size, get_moe_ep_world_size
from mindformers.parallel_core.inference.tensor_parallel.mappings import (gather_from_model_parallel_region,
reduce_from_model_parallel_region,
reduce_scatter_to_model_parallel_region,
scatter_to_model_parallel_region)
from mindformers.version_control import is_910b
from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group
__all__ = ['InferenceDeepseekV3ForCausalLM', 'DeepseekV3Model']
class CacheConfig(Enum):
    """Integer codes identifying which cache tensors a layer keeps.

    Member order and values are part of the contract (0, 1, 2) and must not change.
    """

    # Separate key and value caches.
    KEY_VALUE_CACHE = 0
    # Key cache only.
    KEY_CACHE = 1
    # Key and value caches plus a kv-scale cache.
    KEY_VALUE_CACHE_KVSCALE_CACHE = 2
class MLAPagedAttentionMgr(nn.Cell):
    r"""Paged Attention Manager for MLA, which only stores a key cache.

    No separate value cache exists: ``paged_attn`` passes the key cache to ``PagedAttention`` in both the key and
    the value position.

    Args:
        - **n_heads** (int): The head num of query.
        - **head_dim** (int): The dim of head (last dim of the key cache).
        - **n_kv_heads** (int): The head num of key and value.
        - **kv_shape** (tuple): Shape of the key cache, e.g. :math:`(num_blocks, block_size, n_kv_heads, head_dim)`.
        - **compute_dtype** (mstype): Dtype of the key cache. Default mstype.float16.
        - **parallel_decoding** (bool): If open parallel decoding. Only stored on the cell; not read by this
          class. Default False.
        - **scale_value** (float): The scale factor of score. Default None, which means ``1 / sqrt(head_dim)``.
        - **mla_v_dim** (int): The dim of value in Multi-Latent Attention, forwarded to ``PagedAttention``.
          Default 512.
        - **npu_mem_size** (int): When greater than 0, this cell allocates its own zero-initialised key cache
          ``Parameter`` of shape ``kv_shape``. Otherwise no cache is allocated and the caller must pass
          ``key_cache`` to ``construct`` / ``paged_attn``. Default 2.

    Inputs:
        - **key** (Tensor[float16, bfloat16]) - The key entries to write into the cache.
          NOTE(review): exact layout is defined by the ``ReshapeAndCache`` op — confirm against its docs.
        - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index.
        - **key_cache** (Union[Tensor, None]) - External key cache. Only used when ``npu_mem_size <= 0``.

    Outputs:
        - The output of the ``ReshapeAndCache`` op. In this file, ``DeepseekV3Attention`` only uses it as an
          ``ops.depend`` anchor so the cache write is ordered before attention.
    """
    def __init__(self,
                 n_heads,
                 head_dim,
                 n_kv_heads,
                 kv_shape,
                 compute_dtype=mstype.float16,
                 parallel_decoding=False,
                 scale_value=None,
                 mla_v_dim=512,
                 npu_mem_size=2):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        # Fall back to the standard 1 / sqrt(head_dim) score scaling.
        self.scale_value = 1 / math.sqrt(self.head_dim) if scale_value is None else scale_value
        self.key_cache = None
        self.npu_mem_size = npu_mem_size
        # Own the cache only when npu_mem_size > 0; otherwise it is supplied externally at call time.
        if self.npu_mem_size > 0:
            self.key_cache = Parameter(Tensor(shape=kv_shape, dtype=compute_dtype, init=Zero()), name="key_cache",
                                       requires_grad=False)
        self.reshape_and_cache = ops.auto_generate.ReshapeAndCache()
        self.paged_attention = ops.auto_generate.PagedAttention(self.n_heads,
                                                                self.scale_value,
                                                                self.n_kv_heads,
                                                                mla_v_dim=mla_v_dim)
        self.parallel_decoding = parallel_decoding

    def construct(self, key, slot_mapping, key_cache=None):
        """Write ``key`` into the key cache at ``slot_mapping``.

        Uses the cell-owned cache when ``npu_mem_size > 0``, otherwise the ``key_cache`` argument.
        """
        # The None arguments presumably fill the value / value-cache slots of ReshapeAndCache: only keys are cached.
        if self.npu_mem_size > 0:
            return self.reshape_and_cache(key, None, self.key_cache, None, slot_mapping)
        return self.reshape_and_cache(key, None, key_cache, None, slot_mapping)

    def paged_attn(self, query, batch_valid_length, block_tables, attn_mask=None, q_seq_lens=None, key_cache=None):
        """Run ``PagedAttention`` of ``query`` over the key cache.

        Args:
            query: Query tensor.
            batch_valid_length: Per-sequence valid lengths forwarded to ``PagedAttention``.
            block_tables: Block tables mapping each sequence to cache blocks.
            attn_mask: Optional attention mask. Default None.
            q_seq_lens: Optional per-sequence query lengths. Default None.
            key_cache: External key cache, used only when ``npu_mem_size <= 0``. Default None.

        Returns:
            The output of ``PagedAttention``.
        """
        # The same key cache is passed as both key and value cache (MLA keeps a single latent cache).
        if self.npu_mem_size > 0:
            return self.paged_attention(query, self.key_cache, self.key_cache, block_tables, batch_valid_length,
                                        None, None, attn_mask, q_seq_lens)
        return self.paged_attention(query, key_cache, key_cache, block_tables, batch_valid_length,
                                    None, None, attn_mask, q_seq_lens)
class MLAInferAttention(nn.Cell):
    r"""Multi-Latent-Attention Layer for infer.

    Uses ``FlashAttention`` (input layout ``"TH"``) in the first iteration (prefill) and paged attention over
    the MLA key cache (via ``MLAPagedAttentionMgr``) in later iterations (decode).

    Args:
        n_head (int): The head num of query used by flash attention.
        head_dim (int): The dim of head of the paged-attention key cache.
        n_kv_head (int): The head num of key and value, used for the key-cache shape.
        pa_n_head_split (int): The query head num of paged attention op after split. Default None (= n_head).
        pa_n_kv_head_split (int): The key and value head num of paged attention op after split.
            Default None (= n_kv_head).
        keep_prob (float): The keep probability of dropout. Default: 1.0.
        scale_value (float): The scale factor of score. Default: 1.0.
        pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
            When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
        next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
            When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
        sparse_mode (int): Indicates sparse mode. Default 0.

            - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
              and preTokens and nextTokens(internally assigned as INT_MAX) are ignored. If passed in, the full
              attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
              nextTokens needs to be calculated.
            - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
            - 2: Representing the leftUpCausal mode corresponds to the lower triangle scenario divided by the left
              vertex, and the optimized attn_mask matrix (2048*2048) is required.
            - 3: Representing the rightDownCausal model corresponds to the lower triangle scene divided by the lower
              right vertex, and the optimized attn_mask matrix (2048*2048) is required.
            - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the
              optimized attn_mask matrix (2048*2048) is required.
            - 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1
              and width N is added to the left side. The value of N is obtained by the new input prefix, and the N
              value of each Batch axis is different. Not implemented yet.
            - 6: Represents the global scenario, not implemented yet.
            - 7: Represents the dilated scenario, not implemented yet.
            - 8: Represents the block_local scenario, not implemented yet.
        block_size (int): Block size for paged attention. Default 16.
        num_blocks (int): Block num for paged attention. Default 1024.
        use_alibi_mask (bool): The value is True if alibi_mask is passed. Default: False.
        compute_dtype (mstype): Compute dtype (and key-cache dtype) for infer attention. Default mstype.float16.
        parallel_decoding (bool): If open parallel decoding. Forwarded to the paged-attention manager.
            Default False.
        prefill_head_dim (int): The per-head dim of query/key/value in prefill attention. Must be set before the
            prefill path is used. Default None.
        config (DeepseekV3Config): Optional model config; only ``npu_mem_size`` is read (default 2 when absent).

    Inputs:
        - **query** (Tensor[float16, bfloat16]) - The query tensor.
        - **key** (Union[Tensor[float16, bfloat16], None]) - The key tensor. Only used in prefill.
        - **value** (Union[Tensor[float16, bfloat16], None]) - The value tensor. Only used in prefill.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size]. Used as ``actual_seq_kvlen`` in
          prefill and as the context lengths of paged attention in decode.
        - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. Only used in decode.
        - **attn_mask** (Union[Tensor[uint8], None]) - The attention mask tensor. Default None.
        - **alibi_mask** (Union[Tensor[float16, bfloat16], None]) - The alibi position embedding. Only used in
          prefill. Default None.
        - **q_seq_lens** (Union[Tensor[int32], None]) - The query actual seq len. Used as ``actual_seq_qlen`` in
          prefill. Default None.
        - **key_cache** (Union[Tensor, None]) - External key cache, used by the paged-attention manager when it
          does not own its cache. Default None.

    Outputs:
        - **attention_out** (Tensor[float16, bfloat16]) - The output of flash attention (prefill) or paged
          attention (decode).
    """
    def __init__(self,
                 n_head,
                 head_dim,
                 n_kv_head,
                 pa_n_head_split=None,
                 pa_n_kv_head_split=None,
                 keep_prob=1.0,
                 scale_value=1.0,
                 pre_tokens=2147483647,
                 next_tokens=2147483647,
                 sparse_mode=0,
                 block_size=16,
                 num_blocks=1024,
                 use_alibi_mask=False,
                 compute_dtype=mstype.float16,
                 parallel_decoding=False,
                 prefill_head_dim=None,
                 config: DeepseekV3Config = None
                 ):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.n_kv_head = n_kv_head
        self.pa_n_head_split = pa_n_head_split if pa_n_head_split is not None else n_head
        self.pa_n_kv_head_split = pa_n_kv_head_split if pa_n_kv_head_split is not None else n_kv_head
        self.keep_prob = keep_prob
        self.scale_value = scale_value
        self.pre_tokens = pre_tokens
        self.next_tokens = next_tokens
        self.sparse_mode = sparse_mode
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.use_alibi_mask = use_alibi_mask
        self.compute_dtype = compute_dtype
        # Toggled by the caller between prefill (True) and decode (False).
        self.is_first_iteration = True
        self.reshape = P.Reshape()
        # Tokens are flattened to (tokens, heads * dim) before flash attention; see _prefill_attention.
        self.input_layout = "TH"
        self.use_attention_mask = not self.use_alibi_mask
        self.flash_attention = FlashAttention(head_num=self.n_head,
                                              pre_tokens=self.pre_tokens,
                                              next_tokens=self.next_tokens,
                                              keep_prob=self.keep_prob,
                                              scale_value=self.scale_value,
                                              sparse_mode=self.sparse_mode,
                                              use_attention_mask=self.use_attention_mask,
                                              use_alibi_mask=self.use_alibi_mask,
                                              input_layout=self.input_layout)
        kv_shape = (self.num_blocks, self.block_size, self.n_kv_head, self.head_dim)
        npu_mem_size = config.npu_mem_size if hasattr(config, 'npu_mem_size') else 2
        # NOTE(review): mla_v_dim is not forwarded, so the manager uses its default (512); confirm this matches
        # kv_lora_rank for every supported config.
        self.paged_attention_mgr = MLAPagedAttentionMgr(self.pa_n_head_split,
                                                        self.head_dim,
                                                        self.pa_n_kv_head_split,
                                                        kv_shape,
                                                        compute_dtype=self.compute_dtype,
                                                        parallel_decoding=parallel_decoding,
                                                        scale_value=self.scale_value,
                                                        npu_mem_size=npu_mem_size)
        self.prefill_head_dim = prefill_head_dim

    def _prefill_attention(self, query, key, value, attn_mask, alibi_mask, actual_seq_qlen=None,
                           actual_seq_kvlen=None):
        """
        Prefill attention: flatten Q/K/V to (-1, n_head * prefill_head_dim) and run flash attention.
        """
        query = self.reshape(query, (-1, self.n_head * self.prefill_head_dim))
        key = self.reshape(key, (-1, self.n_head * self.prefill_head_dim))
        value = self.reshape(value, (-1, self.n_head * self.prefill_head_dim))
        return self.flash_attention(query, key, value, attn_mask, alibi_mask, None, None,
                                    actual_seq_qlen, actual_seq_kvlen)

    def _incre_attention(self, query, batch_valid_length, block_tables, attn_mask, q_seq_lens, key_cache=None):
        """Decode attention: paged attention of ``query`` over the MLA key cache."""
        return self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables,
                                                   attn_mask, q_seq_lens, key_cache=key_cache)

    def construct(self, query, key, value, batch_valid_length, block_tables,
                  attn_mask=None, alibi_mask=None, q_seq_lens=None, key_cache=None):
        """ Forward process of the MLA Infer Attention Cell """
        if self.is_first_iteration:
            # q_seq_lens -> actual_seq_qlen, batch_valid_length -> actual_seq_kvlen.
            return self._prefill_attention(query, key, value, attn_mask, alibi_mask, q_seq_lens,
                                           batch_valid_length)
        return self._incre_attention(query, batch_valid_length, block_tables,
                                     attn_mask, q_seq_lens, key_cache=key_cache)
class DeepseekV3Attention(nn.Cell):
    r"""
    This is an implementation of the Multi-Latent self-attention mechanism in DeepSeek-V3 for inference.

    The prefill path (``is_first_iteration`` True) expands the latent KV into per-head keys and values and runs
    flash attention. The decode path folds the ``lkv2kv_k_nope`` / ``lkv2kv_v`` weights into the query and output
    (``q_absorb`` / ``out_absorb``), so paged attention runs directly over the latent key cache.

    Args:
        - **dim** (int): The hidden size of the input. Default 512.
        - **n_heads** (int): The number of the heads. Default 8.
        - **n_kv_heads** (int): The number of key_value heads. Default None (= n_heads). Only used for the
          divisibility check against ``parallel_config.model_parallel``.
        - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16.
        - **param_init_type** (dtype.Number): The parameter initialization type of the module.
          Default mstype.float32.
        - **qkv_has_bias** (bool): Whether Q/K/V projections have bias or not. Default False.
        - **use_past** (bool): Use the past state to compute (incremental prediction). Must be True, otherwise
          ``ValueError`` is raised. Default True.
        - **use_flash_attention** (bool): Whether to enable flash attention ops. Must be True, otherwise
          ``ValueError`` is raised. Default True.
        - **block_size** (int): The maximum number of tokens in one block when using paged attention.
          Default None.
        - **num_blocks** (int): The maximum number of blocks when using paged attention. Default None.
        - **parallel_config** (OpParallelConfig): The parallel configure. Default: a ``TransformerOpParallelConfig()``
          instance created once at function definition time and shared by all callers that omit it.
        - **kv_lora_rank** (int): kv_lora_rank for Multi-Latent-Attention. Default 512.
        - **q_lora_rank** (int): q_lora_rank for Multi-Latent-Attention. 0 means the query is projected directly
          from the hidden states without a low-rank stage. Default 1536.
        - **qk_rope_head_dim** (int): qk_rope_head_dim for Multi-Latent-Attention. Default 64.
        - **v_head_dim** (int): v_head_dim for Multi-Latent-Attention. Default 128.
        - **qk_nope_head_dim** (int): qk_nope_head_dim for Multi-Latent-Attention. Default 128.
        - **max_position_embeddings** (int): The maximum sequence length. Only stored. Default 2048.
        - **scaling_factor** (dict): YaRN scaling settings; reads ``"factor"`` and optional ``"mscale_all_dim"``
          to adjust the softmax scale. Default None (scale = 1 / sqrt(q_head_dim)).
        - **norm_eps** (float): The epsilon value of the RMSNorm denominators. Default 1e-5.
        - **layernorm_compute_dtype** (dtype.Number): The computation type of the norms. Default mstype.float32.
        - **delay_allreduce** (bool): Whether to postpone the allreduce of the output projection. Default False.
        - **config** (DeepseekV3Config): Model config; ``qkv_concat`` is read here and the whole config is passed
          to ``MLAInferAttention``.

    Inputs:
        - **x** (Tensor) - The input hidden states. NOTE(review): the projections are reshaped with a leading -1,
          which suggests a flattened (tokens, hidden_size) layout — confirm with the caller.
        - **freqs_cis** (Tuple) - The precomputed freqs for rotary position embedding.
        - **mask** (Tensor) - The attention mask forwarded to the attention op, or None.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,). Used by the rotary embedding
          and the attention ops.
        - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence.
        - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index.
        - **q_seq_lens** (Tensor[int32]) - The query actual seq len, or None.
        - **key_cache** (Tensor) - External key cache, or None when the attention cell owns its cache.

    Outputs:
        - **output** (Tensor) - Output of the ``wo`` projection, cast back to the dtype of ``x``.

    Raises:
        ValueError: If ``use_past`` or ``use_flash_attention`` is False, if ``dim`` is not divisible by
            ``n_heads``, or if ``n_kv_heads`` is not divisible by ``parallel_config.model_parallel``.
    """
    def __init__(self,
                 dim=512,
                 n_heads=8,
                 n_kv_heads=None,
                 compute_dtype=mstype.float16,
                 param_init_type=mstype.float32,
                 qkv_has_bias=False,
                 use_past=True,
                 use_flash_attention=True,
                 block_size=None,
                 num_blocks=None,
                 parallel_config=TransformerOpParallelConfig(),
                 kv_lora_rank=512,
                 q_lora_rank=1536,
                 qk_rope_head_dim=64,
                 v_head_dim=128,
                 qk_nope_head_dim=128,
                 max_position_embeddings=2048,
                 scaling_factor=None,
                 norm_eps=1e-5,
                 layernorm_compute_dtype=mstype.float32,
                 delay_allreduce=False,
                 config: DeepseekV3Config = None
                 ):
        super().__init__()
        self.hidden_size = dim
        self.tp_group = get_tensor_model_parallel_group()
        self.tensor_parallel_group_size = get_tp_world_size()
        self.n_head = n_heads
        # Heads handled by this tensor-parallel rank.
        self.n_local_heads = n_heads // self.tensor_parallel_group_size
        self.head_dim = dim // n_heads
        self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads
        self.kv_dim = self.n_kv_head * self.head_dim
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.max_position_embeddings = max_position_embeddings
        # Per-head query dim = non-rotary part + rotary part.
        self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.scaling_factor = scaling_factor
        self.dtype = compute_dtype
        self.is_first_iteration = True
        self.use_past = use_past
        self.use_flash_attention = use_flash_attention
        self.qkv_concat = config.qkv_concat
        if not self.use_past:
            raise ValueError("For 'DeepseekV3Attention', the use_past must be enabled.")
        if not self.use_flash_attention:
            raise ValueError("For 'DeepseekV3Attention', the use_flash_attention must be enabled.")
        if self.hidden_size % self.n_head != 0:
            raise ValueError(f"For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
                             f"of 'n_head', but got the hidden_size is {self.hidden_size} and "
                             f"the n_head is {self.n_head}.")
        if self.n_kv_head % parallel_config.model_parallel != 0:
            raise ValueError(f"For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of "
                             f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} "
                             f"and the parallel_config.model_parallel is {parallel_config.model_parallel}.")
        self.shape = P.Shape()
        self.cast = P.Cast()
        if self.q_lora_rank == 0:
            # No low-rank query stage: project hidden states straight to all query heads.
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.n_head * self.q_head_dim,
                config=parallel_config,
                bias=qkv_has_bias,
                param_init_type=param_init_type,
                compute_dtype=compute_dtype,
                tp_group=self.tp_group,
            )
            self.kv2l = Linear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                has_bias=qkv_has_bias,
                compute_dtype=compute_dtype,
                param_init_type=param_init_type
            )
        else:
            if self.qkv_concat:
                # One fused projection producing [q latent | kv latent | k rope] along the last dim.
                self.qkv2l = Linear(
                    self.hidden_size,
                    self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                    has_bias=qkv_has_bias,
                    compute_dtype=compute_dtype,
                    param_init_type=param_init_type
                )
            else:
                self.q2l_proj = Linear(
                    self.hidden_size,
                    self.q_lora_rank,
                    has_bias=qkv_has_bias,
                    compute_dtype=compute_dtype,
                    param_init_type=param_init_type
                )
                self.kv2l = Linear(
                    self.hidden_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    has_bias=qkv_has_bias,
                    compute_dtype=compute_dtype,
                    param_init_type=param_init_type
                )
            self.lq_norm = RMSNorm(self.q_lora_rank, norm_eps, compute_type=layernorm_compute_dtype)
            self.l2q_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.n_head * self.q_head_dim,
                config=parallel_config,
                bias=qkv_has_bias,
                param_init_type=param_init_type,
                compute_dtype=compute_dtype,
                tp_group=self.tp_group,
            )
        self.lkv_norm = RMSNorm(self.kv_lora_rank, norm_eps, compute_type=layernorm_compute_dtype)
        # Up-projections from the KV latent to per-head k_nope / v; in decode their weights are absorbed instead.
        self.lkv2kv_k_nope = ColumnParallelLinear(
            self.kv_lora_rank,
            self.n_head * self.qk_nope_head_dim,
            config=parallel_config,
            bias=qkv_has_bias,
            param_init_type=param_init_type,
            compute_dtype=compute_dtype,
            tp_group=self.tp_group,
        )
        self.lkv2kv_v = ColumnParallelLinear(
            self.kv_lora_rank,
            self.n_head * self.v_head_dim,
            config=parallel_config,
            bias=qkv_has_bias,
            param_init_type=param_init_type,
            compute_dtype=compute_dtype,
            tp_group=self.tp_group,
        )
        self.wo = RowParallelLinear(
            self.n_head * self.v_head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            config=parallel_config,
            compute_dtype=compute_dtype,
            param_init_type=param_init_type,
            delay_allreduce=delay_allreduce,
            tp_group=self.tp_group,
        )
        self.scale_fa = 1. / math.sqrt(self.q_head_dim)
        if self.scaling_factor is not None:
            # YaRN: multiply the softmax scale by mscale^2 when mscale_all_dim is set.
            mscale_all_dim = self.scaling_factor.get("mscale_all_dim", 0)
            factor = self.scaling_factor["factor"]
            if mscale_all_dim:
                mscale = _yarn_get_mscale(factor, mscale_all_dim)
                self.scale_fa = mscale * mscale / (math.sqrt(self.q_head_dim))
        self.reshape = P.Reshape()
        self.tile_kv = P.Tile()
        self.dim_slice_4d = P.Slice()
        self.kpe_concat = P.Concat(1)
        self.pe_concat = P.Concat(2)
        self.qabsorb_matmul = P.BatchMatMul()
        self.outabsorb_matmul = P.BatchMatMul(transpose_b=True)
        # Paged attention runs over a single latent "head" of size kv_lora_rank + qk_rope_head_dim.
        self.infer_attention = MLAInferAttention(self.n_local_heads,
                                                 self.kv_lora_rank + self.qk_rope_head_dim,
                                                 1,
                                                 scale_value=self.scale_fa,
                                                 next_tokens=0,
                                                 block_size=self.block_size,
                                                 num_blocks=self.num_blocks,
                                                 compute_dtype=compute_dtype,
                                                 prefill_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
                                                 config=config)
        self.apply_rotary_emb = InferRotaryEmbedding(rotary_cos_format=2)

    def construct(self, x: Tensor, freqs_cis: Tuple[Tensor, Tensor], mask=None, batch_valid_length=None,
                  block_tables=None, slot_mapping=None, q_seq_lens=None, key_cache=None):
        """ Forward process of the DeepseekV3Attention. """
        ori_dtype = x.dtype
        if self.q_lora_rank == 0:
            q = self.q_proj(x)
            latent_kv_all = self.kv2l(x)
            latent_kv, k_pe = ops.function.array_func.split_ext(
                latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        else:
            if self.qkv_concat:
                qkv2l = self.qkv2l(x)
                q, latent_kv, k_pe = ops.function.array_func.split_ext(
                    qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
                norm_q = self.lq_norm(q)
                q = self.l2q_proj(norm_q)
            else:
                q = self.q2l_proj(x)
                norm_q = self.lq_norm(q)
                q = self.l2q_proj(norm_q)
                latent_kv_all = self.kv2l(x)
                latent_kv, k_pe = ops.function.array_func.split_ext(
                    latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        q = self.reshape(q, (-1, self.n_local_heads, self.q_head_dim))
        q_nope, q_pe = ops.function.array_func.split_ext(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        i_kv = self.lkv_norm(latent_kv)
        # Rotary embedding is applied only to the rope parts of q and k.
        q_pe = self.reshape(q_pe, (-1, self.n_local_heads * self.qk_rope_head_dim))
        q_pe, k_pe = self.apply_rotary_emb(q_pe, k_pe, freqs_cis, batch_valid_length)
        q_pe = self.reshape(q_pe, (-1, self.n_local_heads, self.qk_rope_head_dim))
        # Cache [normalized kv latent | rotated k_pe]; this is the only KV state kept between steps.
        key_states_cache = self.kpe_concat((i_kv, k_pe))
        key_out = self.infer_attention.paged_attention_mgr(key_states_cache, slot_mapping, key_cache=key_cache)
        # Order the cache write before any attention computation.
        q_nope = ops.depend(q_nope, key_out)
        if self.is_first_iteration:
            # Prefill: materialize per-head k_nope and v from the latent.
            o_k_nope = self.lkv2kv_k_nope(i_kv)
            o_v = self.lkv2kv_v(i_kv)
            k_nope = self.reshape(o_k_nope, (-1, self.n_local_heads, self.qk_nope_head_dim))
            value_states = self.reshape(o_v, (-1, self.n_local_heads, self.v_head_dim))
            query_states = self.pe_concat((q_nope, q_pe))
            # k_pe is shared by all heads; broadcast it by tiling over the head axis.
            k_pe = self.reshape(k_pe, (-1, 1, self.qk_rope_head_dim))
            k_pe = self.tile_kv(k_pe, (1, self.n_local_heads, 1))
            key_states = self.pe_concat((k_nope, k_pe))
            # Pad V with k_pe so Q/K/V share one head dim for flash attention; the padding is sliced off below.
            # NOTE(review): this relies on v_head_dim == qk_nope_head_dim so that V's dim matches q_head_dim.
            value_states = self.pe_concat((value_states, k_pe))
            context_layer = self.infer_attention(query_states, key_states, value_states, batch_valid_length,
                                                 block_tables, mask, q_seq_lens=q_seq_lens, key_cache=key_cache)
            context_layer = context_layer.view(-1, self.n_local_heads, self.q_head_dim)
            context_layer = self.dim_slice_4d(context_layer, (0, 0, 0), (-1, self.n_local_heads, self.v_head_dim))
        else:
            # Decode: absorb the k_nope up-projection into the query and the v up-projection into the output,
            # so attention operates over the cached latent directly.
            # NOTE(review): the views assume the weights are laid out as (heads * out_dim, kv_lora_rank).
            q_absorb = self.lkv2kv_k_nope.weight.view(self.n_local_heads, self.qk_nope_head_dim, self.kv_lora_rank)
            out_absorb = self.lkv2kv_v.weight.view(self.n_local_heads, self.v_head_dim, self.kv_lora_rank)
            # Per-head matmul: (heads, T, nope) x (heads, nope, kv_lora_rank) -> (heads, T, kv_lora_rank).
            q_nope = self.qabsorb_matmul(q_nope.transpose(1, 0, 2), q_absorb).transpose(1, 0, 2)
            query_states = self.pe_concat((q_nope, q_pe))
            query_states = self.reshape(query_states, (-1, self.n_local_heads *
                                                       (self.kv_lora_rank + self.qk_rope_head_dim)))
            context_layer = self.infer_attention(query_states, None, None, batch_valid_length,
                                                 block_tables, attn_mask=mask, q_seq_lens=q_seq_lens,
                                                 key_cache=key_cache)
            context_layer = context_layer.view(-1, self.n_local_heads, self.kv_lora_rank)
            # Per-head matmul with out_absorb transposed: (heads, T, kv_lora_rank) -> (heads, T, v_head_dim).
            context_layer = self.outabsorb_matmul(context_layer.transpose(1, 0, 2), out_absorb).transpose(1, 0, 2)
        attn_out = context_layer.view(-1, self.n_local_heads * self.v_head_dim)
        output = self.wo(attn_out)
        output = self.cast(output, ori_dtype)
        return output
class DeepseekV3ParallelMLP(ParallelMLP):
    r"""
    Parallel gated feed-forward block fused with a residual add.

    Computes ``h + w2(act(gate) * hidden)``, where ``gate`` and ``hidden`` come either from one fused
    ``w_gate_hidden`` projection (``ffn_concat``) or from separate ``w1`` / ``w3`` projections. The projections,
    activation and ``ffn_concat`` flag are provided by ``ParallelMLP``.

    Args:
        config (dict): Configuration, forwarded to ``ParallelMLP``.
        is_expert (bool): Whether this block is an expert block. Only ``False`` is supported. Default: False.

    Inputs:
        - **x** (Tensor) - Input hidden states of the MLP.
        - **h** (Tensor) - Residual tensor added to the MLP output.

    Outputs:
        - **output** (Tensor) - ``h + MLP(x)``.

    Raises:
        NotImplementedError: If ``is_expert`` is True.

    Supported Platforms:
        ``Ascend``
    """

    def __init__(self, config, is_expert=False):
        # Reject the unsupported mode before the parent constructor allocates any weights.
        if is_expert:
            raise NotImplementedError("For ParallelMLP, `is_expert` is not supported for now.")
        super().__init__(config)
        self.add = P.Add()

    def construct(self, x, h):
        """ Construct function of mlp block. """
        if self.ffn_concat:
            # The fused projection yields [gate | hidden], each of ffn_hidden_size_per_partition on this rank.
            gate_hidden_out = self.w_gate_hidden(x)
            gate, hidden = ops.function.array_func.split_ext(
                gate_hidden_out,
                (self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
        else:
            gate = self.w1(x)
            hidden = self.w3(x)
        gate = self.act_func(gate.contiguous())
        hidden = mint.mul(hidden, gate)
        output = self.w2(hidden)
        # Residual connection fused into the block.
        output = self.add(h, output)
        return output
class DeepseekV3MoE(Cell):
    r"""
    Mixture-of-Experts feed-forward block of DeepSeek-V3 (routed experts plus optional shared experts),
    with optional collective communication around it and a fused residual add.

    Routed experts use ``ParallelMoEV2`` when ``parallel_config.expert_parallel == 1``, otherwise
    ``ExpertParallelMoE``. Shared experts (when ``moe_config.shared_expert_num`` is not None) use ``SharedMLP`` if
    ``use_alltoall`` is set, otherwise ``SharedParallelMLP`` over the world group.

    Note: the constructor sets ``config.moe_config.router_dense_type = config.router_dense_type``, i.e. it
    mutates the passed config object.

    Args:
        - **config** (Config): Model config of DeepSeek-V3.

    Inputs:
        - **x** (Tensor): Hidden states entering the MoE block.
        - **attn_unpadding_idx** (Tensor): Indices used to drop padding rows after the attention all-gather
          (only used when ``attn_allgather`` is set).
        - **ffn_padding_idx** (Tensor): Indices used to re-pad rows before the reduce-scatter (only used when
          ``ffn_reduce_scatter`` is set).
        - **h** (Tensor): Residual tensor added to the MoE output.

    Outputs:
        - **output** (Tensor): ``h + MoE(x)``, optionally all-gathered over the tensor-parallel group.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.parallel_config = config.parallel_config
        self.moe_config = config.moe_config
        self.moe_config.router_dense_type = config.router_dense_type
        intermediate_size = self.moe_config.moe_intermediate_size
        self.tp_group = get_tensor_model_parallel_group()
        self.world_group = get_world_group()
        ffn = RoutedParallelMLP(config)
        if self.parallel_config.expert_parallel == 1:
            self.routed_experts = ParallelMoEV2(ffn, self.config.hidden_size, self.moe_config)
        else:
            self.routed_experts = ExpertParallelMoE(ffn, self.config.hidden_size,
                                                    self.moe_config, self.config.parallel_config.use_alltoall,
                                                    self.config.compute_dtype)
        # Communication-pattern flags; not all of them are read by this class's construct.
        self.attn_reduce_scatter = config.parallel_config.attn_reduce_scatter
        self.attn_allgather = config.parallel_config.attn_allgather
        self.attn_allreduce = config.parallel_config.attn_allreduce
        self.ffn_allgather = config.parallel_config.ffn_allgather
        self.ffn_allreduce = config.parallel_config.ffn_allreduce
        self.ffn_reduce_scatter = config.parallel_config.ffn_reduce_scatter
        self.use_alltoall = config.parallel_config.use_alltoall
        if self.moe_config.shared_expert_num is not None:
            # All shared experts are merged into one MLP whose width scales with their count.
            intermediate_size = intermediate_size * self.moe_config.shared_expert_num
            self.shared_experts = SharedMLP(config, intermediate_size) if self.use_alltoall else \
                SharedParallelMLP(config, intermediate_size, tp_group=self.world_group)
        self.add = P.Add()

    def construct(self, x, attn_unpadding_idx, ffn_padding_idx, h):
        """ Construct function of moe block. """
        if self.attn_allgather:
            # Gather token rows from all ranks, then drop the padding rows.
            x = gather_from_model_parallel_region(x, self.world_group, dim=0)
            x = ops.gather(x, attn_unpadding_idx, 0)
        output = self.routed_experts(x)
        if self.moe_config.shared_expert_num is not None:
            shared_res = self.shared_experts(x)
            output = self.add(output, shared_res)
        if self.ffn_allreduce:
            output = reduce_from_model_parallel_region(output, self.world_group)
        elif self.ffn_reduce_scatter:
            # Re-pad rows so they divide evenly before reduce-scattering across the world group.
            output = ops.gather(output, ffn_padding_idx, 0)
            output = reduce_scatter_to_model_parallel_region(output, self.world_group)
        output = self.add(h, output)
        if self.ffn_allgather:
            # NOTE(review): this gather uses tp_group while the other collectives above use world_group —
            # confirm this asymmetry is intentional.
            output = gather_from_model_parallel_region(output, self.tp_group, dim=0)
        return output
class DeepseekV3MoEWithMicroBatch(DeepseekV3MoE):
    r"""
    DeepSeek-V3 MoE block that exposes routing, token dispatch and result combine as separate steps.

    Intended for DualPipe / micro-batch scheduling: callers drive ``gating_logits`` -> ``moe_dispatch`` ->
    expert computation -> ``moe_combine`` themselves, presumably so that communication of one micro-batch can
    overlap with computation of another (TODO confirm against the caller). Built on the fused
    ``FusedAddTopKDiv``, ``MoeDistributeDispatch`` and ``MoeDistributeCombine`` operators, using the MoE
    expert-parallel communication group.

    Args:
        - **config** (Config): Model config of DeepSeek-V3.

    Raises:
        RuntimeError: If the fused MoE operators could not be imported from ``mindspore.ops.auto_generate``.
    """

    def __init__(self, config):
        # The fused ops are imported optionally at module load. Without this check, a missing op would surface
        # as an opaque NameError after the parent block has already been built.
        if not MOE_FUSED_OP_VALID:
            raise RuntimeError("DeepseekV3MoEWithMicroBatch requires FusedAddTopKDiv, MoeDistributeDispatch and "
                               "MoeDistributeCombine from mindspore.ops.auto_generate, which are not available "
                               "in the installed MindSpore version.")
        super().__init__(config=config)
        self.moe_tp_size = get_moe_tp_world_size()
        self.moe_ep_size = get_moe_ep_world_size()
        # Expert-parallel rank, assuming global ranks are grouped into consecutive blocks of moe_tp_size.
        self.ep_rank_id = get_rank() // self.moe_tp_size
        self.moe_expert_num = config.moe_config.expert_num
        self.moe_ep_group = get_moe_expert_parallel_group()
        # The fused dispatch/combine ops get tp_world_size / shared_expert_num of 0 on 910B and 1 otherwise.
        self.dispatch_tp_world_size = 0 if is_910b() else 1
        self.dispatch_shared_expert_num = 0 if is_910b() else 1
        self.cast = P.Cast()
        self.fused_add_topk_div = FusedAddTopKDiv()
        self.dispatch = MoeDistributeDispatch()
        self.dispatch.add_prim_attr("group", self.moe_ep_group)
        self.dispatch.add_prim_attr("is_comm_op", True)
        self.combine = MoeDistributeCombine()
        self.combine.add_prim_attr("group", self.moe_ep_group)
        self.combine.add_prim_attr("is_comm_op", True)

    def moe_dispatch(self, x, expert_index):
        """Dispatch tokens ``x`` to the experts selected in ``expert_index`` across the expert-parallel group.

        Returns:
            The seven outputs of ``MoeDistributeDispatch`` in op order: ``(expand_x, dynamic_scales, expand_idx,
            expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales)``.
        """
        # Every output gets its own name; re-using `_` twice made the returned tuple carry the last output in
        # both position 1 and position 6.
        (expand_x, dynamic_scales, expand_idx, expert_token_nums,
         ep_recv_counts, tp_recv_counts, expand_scales) = \
            self.dispatch(x=x,
                          expert_ids=expert_index,
                          ep_world_size=self.moe_ep_size,
                          ep_rank_id=self.ep_rank_id,
                          moe_expert_num=self.moe_expert_num,
                          group_ep=self.moe_ep_group,
                          tp_world_size=self.dispatch_tp_world_size,
                          shared_expert_num=self.dispatch_shared_expert_num,
                          global_bs=self.routed_experts.dispatch_global_max_bs*self.moe_ep_size,
                          expert_token_nums_type=1)
        return (expand_x, dynamic_scales, expand_idx, expert_token_nums,
                ep_recv_counts, tp_recv_counts, expand_scales)

    def moe_combine(self, ffn_res, expert_index, expand_idx, ep_recv_counts, expert_weight, tp_recv_counts):
        """Combine expert outputs ``ffn_res`` back to their source tokens, weighted by ``expert_weight``.

        ``expand_idx``, ``ep_recv_counts`` and ``tp_recv_counts`` are the matching outputs of ``moe_dispatch``.

        Returns:
            The output of ``MoeDistributeCombine``.
        """
        moe_output = self.combine(expand_x=ffn_res,
                                  expert_ids=expert_index,
                                  expand_idx=expand_idx,
                                  ep_send_counts=ep_recv_counts,
                                  expert_scales=expert_weight,
                                  ep_world_size=self.moe_ep_size,
                                  ep_rank_id=self.ep_rank_id,
                                  moe_expert_num=self.moe_expert_num,
                                  tp_send_counts=tp_recv_counts,
                                  group_ep=self.moe_ep_group,
                                  tp_world_size=self.dispatch_tp_world_size,
                                  shared_expert_num=self.dispatch_shared_expert_num,
                                  global_bs=self.routed_experts.dispatch_global_max_bs*self.moe_ep_size)
        return moe_output

    def gating_logits(self, x):
        """Compute router logits for ``x`` and select experts with the fused add-topk-div op.

        Returns:
            Tuple ``(expert_weight, expert_index)`` from ``FusedAddTopKDiv``.
        """
        gating_logits = self.routed_experts.gating(self.cast(x, self.routed_experts.router_dense_type))
        # Routing is computed in float32.
        gating_logits = self.routed_experts.cast(gating_logits, mstype.float32)
        # NOTE(review): num_experts_chosen is passed both in the third positional slot (group_num in the
        # fused_add_topk_div signature) and as k; this is only correct when the number of expert groups equals
        # num_experts_chosen — confirm.
        expert_weight, expert_index = self.fused_add_topk_div(gating_logits,
                                                              self.routed_experts.router.e_score_correction_bias,
                                                              self.routed_experts.num_experts_chosen,
                                                              self.routed_experts.topk_group,
                                                              self.routed_experts.group_topk_inner,
                                                              self.routed_experts.num_experts_chosen,
                                                              0,
                                                              True,
                                                              self.routed_experts.moe_config.routed_scaling_factor)
        return expert_weight, expert_index
class AttentionReduceScatter(Cell):
    r"""
    Redistribute the attention output and the residual across the tensor-parallel group.

    ``construct`` applies ``reduce_scatter_to_model_parallel_region`` to the attention output and
    ``scatter_to_model_parallel_region`` (along dim 0) to the residual, both over the TP group.
    ``padding_with_idx`` gathers rows (dim 0) of both tensors with a padding index; it is meant to be
    called beforehand.

    Args:
        - **config** (Config): Model config of DeepSeek-V3.

    Inputs:
        - **hidden_state** (Tensor): Attention output.
        - **x** (Tensor): Residual input.

    Outputs:
        Tuple ``(hidden_state, x)`` after redistribution.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        parallel_config = config.parallel_config
        self.compute_dtype = config.compute_dtype
        self.hidden_size = config.hidden_size
        self.model_parallel = parallel_config.model_parallel
        self.moe_config = config.moe_config
        self.is_first_iteration = True
        self.tp_group = get_tensor_model_parallel_group()
        self.reshape = P.Reshape()

    def padding_with_idx(self, hidden_state, x, attn_padding_idx):
        """Gather rows of ``hidden_state`` and ``x`` along dim 0 using ``attn_padding_idx``."""
        padded_hidden = ops.gather(hidden_state, attn_padding_idx, 0)
        padded_residual = ops.gather(x, attn_padding_idx, 0)
        return padded_hidden, padded_residual

    def construct(self, hidden_state, x):
        scattered_hidden = reduce_scatter_to_model_parallel_region(hidden_state, self.tp_group)
        scattered_residual = scatter_to_model_parallel_region(x, self.tp_group, dim=0)
        return scattered_hidden, scattered_residual
class DeepseekV3DecodeLayer(nn.Cell):
    r"""
    Transformer Layer. A single DeepSeek-V3 transformer block: RMSNorm + multi-latent attention,
    followed by RMSNorm + a feed-forward block.

    The feed-forward block is:

    - ``DeepseekV3ParallelMLP`` for dense layers (``layer_id < moe_config.first_k_dense_replace``);
    - ``DeepseekV3MoEWithMicroBatch`` when ``config.enable_micro_batch`` is set;
    - ``DeepseekV3MoE`` otherwise.

    The attention residual is added in this layer. The post-attention hidden state ``h`` is passed to
    the feed-forward module, which is responsible for its own residual.

    Args:
        - **layer_id** (int): The layer id of current transformer block layer.
        - **dim** (int): The hidden size of the input. Default 512.
        - **n_heads** (int): The number of the heads. Default 8.
        - **n_kv_heads** (int): The number of key_value heads that should be used to implement
          Grouped Query Attention. Default None (treated as ``n_heads``).
        - **norm_eps** (float): The epsilon value of the denominator. Default 1e-5.
        - **compute_dtype** (dtype.Number): The computation type of the layer.
          Should be mstype.float32 or mstype.float16. Default mstype.float32.
        - **layernorm_compute_dtype** (dtype.Number): The computation type of the norm.
          Should be mstype.float32 or mstype.float16. Default mstype.float32.
        - **param_init_type** (dtype.Number): The parameter initialization type of the module.
          Should be mstype.float32 or mstype.float16. Default mstype.float32.
        - **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not. Default False.
        - **use_past** (bool): Use the past state to compute, used for incremental prediction.
          For example, if we have two words and want to generate ten more words, we only need to
          compute the two words' state once, and then generate the next word one by one.
          When use_past is True, there are two steps to run the prediction. In the first step,
          set is_first_iteration to True by `model.add_flags_recursive(is_first_iteration=True)`,
          and pass the full inputs. Then, set is_first_iteration to False by
          `model.add_flags_recursive(is_first_iteration=False)`.
          At this moment, pass the single step's input tensor, and loop it. Default True.
        - **moe_config** (MoEConfig): The MoE configuration; ``first_k_dense_replace``, ``expert_num`` and
          ``shared_expert_num`` are read from it. Default None.
        - **use_flash_attention** (bool): Whether to enable flash attention ops. Default True.
        - **block_size** (int): The maximum number of tokens in one block when using paged attention.
          Default None.
        - **num_blocks** (int): The maximum number of blocks when using paged attention. Default None.
        - **parallel_config** (TransformerOpParallelConfig): The parallel configuration forwarded to the
          attention block. Default: a ``TransformerOpParallelConfig()`` instance.
        - **kv_lora_rank** (int): kv_lora_rank for Multi-Latent-Attention. Default 512.
        - **q_lora_rank** (int): q_lora_rank for Multi-Latent-Attention. Default 1536.
        - **qk_rope_head_dim** (int): qk_rope_head_dim for Multi-Latent-Attention. Default 64.
        - **v_head_dim** (int): v_head_dim for Multi-Latent-Attention. Default 128.
        - **qk_nope_head_dim** (int): qk_nope_head_dim for Multi-Latent-Attention. Default 128.
        - **max_position_embeddings** (int): The maximum sequence length that this model might ever be used with.
          Default 2048.
        - **scaling_factor** (dict, optional): Scaling factor of Multi-Latent Attention. Default None.
        - **config** (Config): Model config of DeepSeek-V3; ``moe_config``, ``enable_micro_batch``,
          ``num_layers`` and ``parallel_config`` are read from it. Default None.

    Inputs:
        - **x** (Tensor) - Float hidden states. In micro-batch prefill mode, a pair ``(x_0, x_1)``, one per
          micro-batch.
        - **freqs_cis** (Tuple) - The precomputed freqs and mask for rotary position embedding used in attention.
        - **mask** (Tensor) - Attention mask, or None.
        - **batch_valid_length** (Tensor) - Int32 tensor of per-sequence valid lengths, used for incremental
          prediction. In micro-batch mode, a pair of tensors.
        - **block_tables** (Tensor) - Mapping tables for each sequence (pair in micro-batch mode).
        - **slot_mapping** (Tensor) - Token cache physical slot index (pair in micro-batch mode).
        - **q_seq_lens** (Tensor) - Per-sequence query lengths (pair in micro-batch mode).
        - **attn_padding_idx**, **attn_unpadding_idx**, **ffn_padding_idx**, **ffn_unpadding_idx**
          (Tensor) - Gather indices used when padding/redistributing tokens across DP/TP ranks.
        - **key_cache** (Tensor) - This layer's key cache, or None.

    Outputs:
        - **output** (Tensor) - The hidden states after this layer.
          In micro-batch prefill mode (``enable_micro_batch`` and ``is_first_iteration`` on a non-dense
          layer) a tuple of two tensors is returned instead, one per micro-batch.
    """
    @predict_lazy_inline
    def __init__(self,
                 layer_id,
                 dim: int = 512,
                 n_heads: int = 8,
                 n_kv_heads: Optional[int] = None,
                 norm_eps: float = 1e-5,
                 compute_dtype=mstype.float32,
                 layernorm_compute_dtype=mstype.float32,
                 param_init_type=mstype.float32,
                 qkv_has_bias=False,
                 use_past=True,
                 moe_config=None,
                 use_flash_attention=True,
                 block_size: Optional[int] = None,
                 num_blocks: Optional[int] = None,
                 # NOTE(review): this default instance is created once at definition time and shared
                 # by every call that omits `parallel_config`.
                 parallel_config=TransformerOpParallelConfig(),
                 kv_lora_rank=512,
                 q_lora_rank=1536,
                 qk_rope_head_dim=64,
                 v_head_dim=128,
                 qk_nope_head_dim=128,
                 max_position_embeddings=2048,
                 scaling_factor: Optional[Dict] = None,
                 config: DeepseekV3Config = None
                 ):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = dim
        self.n_head = n_heads
        self.head_dim = self.hidden_size // self.n_head
        self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads
        self.dtype = compute_dtype
        self.is_first_iteration = True
        self.use_past = use_past
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.add = P.Add()
        self.ffn_norm = RMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype)
        self.attention_norm = RMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype)
        self.moe_config = config.moe_config
        self.enable_micro_batch = config.enable_micro_batch
        self.num_layers = config.num_layers
        # Dense (non-MoE) layer when layer_id is below first_k_dense_replace.
        # NOTE(review): this dereferences the `moe_config` argument, so the `moe_config is None`
        # checks further down can never take their None branch; the MoE branches also read
        # `config.moe_config` instead of this argument.
        self.first_k_dense = (moe_config.first_k_dense_replace and layer_id < moe_config.first_k_dense_replace)
        # Attention skips its own all-reduce on MoE layers when the config disables attn_allreduce.
        self.attn_delay_allreduce = not config.parallel_config.attn_allreduce and not self.first_k_dense
        self.attention = DeepseekV3Attention(dim=dim,
                                             n_heads=n_heads,
                                             n_kv_heads=n_kv_heads,
                                             compute_dtype=compute_dtype,
                                             param_init_type=param_init_type,
                                             qkv_has_bias=qkv_has_bias,
                                             use_past=use_past,
                                             use_flash_attention=use_flash_attention,
                                             block_size=block_size,
                                             num_blocks=num_blocks,
                                             parallel_config=parallel_config,
                                             kv_lora_rank=kv_lora_rank,
                                             q_lora_rank=q_lora_rank,
                                             qk_rope_head_dim=qk_rope_head_dim,
                                             v_head_dim=v_head_dim,
                                             qk_nope_head_dim=qk_nope_head_dim,
                                             max_position_embeddings=max_position_embeddings,
                                             scaling_factor=scaling_factor,
                                             norm_eps=norm_eps,
                                             layernorm_compute_dtype=layernorm_compute_dtype,
                                             delay_allreduce=self.attn_delay_allreduce,
                                             config=config)
        self.expert_num = 1 if moe_config is None else moe_config.expert_num
        self.shared_expert_num = 0 if moe_config is None else moe_config.shared_expert_num
        if self.first_k_dense:
            logger.warning("first_k_dense_replace is provided in MoEConfig, "
                           "a normal dense FFN will be used in this block.")
            self.feed_forward = DeepseekV3ParallelMLP(config)
        elif self.enable_micro_batch:
            self.feed_forward = DeepseekV3MoEWithMicroBatch(config=config)
        else:
            self.feed_forward = DeepseekV3MoE(config=config)
        self.predict_run_mode = get_predict_run_mode()
        if self.predict_run_mode:
            self.no_inline = False
        # Communication flags only apply to MoE layers; dense layers keep the default path.
        self.attn_reduce_scatter = config.parallel_config.attn_reduce_scatter and not self.first_k_dense
        self.attn_allgather = config.parallel_config.attn_allgather and not self.first_k_dense
        self.use_alltoall = config.parallel_config.use_alltoall
        # Token padding is needed before a reduce-scatter, or when the attention all-reduce is
        # delayed and all-to-all is not used.
        self.need_padding = self.attn_reduce_scatter or \
            (self.attn_delay_allreduce and not self.attn_reduce_scatter and not self.use_alltoall)
        if self.need_padding:
            self.attention_reduce_scatter = AttentionReduceScatter(config)

    def construct(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None,
                  slot_mapping=None, q_seq_lens=None, attn_padding_idx=None, attn_unpadding_idx=None,
                  ffn_padding_idx=None, ffn_unpadding_idx=None, key_cache=None):
        """ Forward of transformer block. """
        if not self.use_past:
            self._check_input(x, freqs_cis, mask)
        # Micro-batch prefill on MoE layers: `x` and the per-sequence inputs are pairs.
        if self.enable_micro_batch and self.is_first_iteration and not self.first_k_dense:
            return self._micro_batch_overlap(x, freqs_cis, mask, batch_valid_length, block_tables,
                                             slot_mapping, q_seq_lens, key_cache)
        input_x = self.attention_norm(x)
        h = self.attention(input_x, freqs_cis, mask, batch_valid_length, block_tables,
                           slot_mapping, q_seq_lens, key_cache=key_cache)
        if self.need_padding:
            h, x = self.attention_reduce_scatter.padding_with_idx(h, x, attn_padding_idx)
        if self.attn_reduce_scatter:
            h, x = self.attention_reduce_scatter(h, x)
        # Attention residual.
        h = self.add(x, h)
        ffn_norm = self.ffn_norm(h)
        # Dense MLP takes (input, residual); the MoE block also takes the (un)padding indices.
        out = self.feed_forward(ffn_norm, h) if self.first_k_dense else \
            self.feed_forward(ffn_norm, attn_unpadding_idx, ffn_padding_idx, h)
        # NOTE(review): rows are only un-padded here when attn_reduce_scatter is set; when padding was
        # applied for the delayed-allreduce path, confirm the feed-forward module removes it.
        if self.attn_reduce_scatter:
            out = ops.gather(out, ffn_unpadding_idx, 0)
        return out

    def _check_input(self, x, freqs_cis, mask):
        r"""Validate input dtypes; raises via `_check_input_dtype` on mismatch and returns True otherwise."""
        _check_input_dtype(
            x.dtype, "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
        freqs_cos, freqs_sin, swap_mask = freqs_cis
        _check_input_dtype(freqs_cos.dtype, "freqs_cos",
                           [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
        _check_input_dtype(freqs_sin.dtype, "freqs_sin",
                           [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
        if swap_mask is not None:
            _check_input_dtype(swap_mask.dtype, "swap_mask",
                               [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
        if mask is not None:
            _check_input_dtype(mask.dtype, "input_mask",
                               [mstype.float32, mstype.float16, mstype.bfloat16, mstype.uint8, mstype.bool_],
                               self.cls_name)
        return True

    def _micro_batch_overlap(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None,
                             slot_mapping=None, q_seq_lens=None, key_cache=None):
        """Run this MoE layer on two micro-batches with interleaved stages.

        ``x``, ``batch_valid_length``, ``block_tables``, ``slot_mapping`` and ``q_seq_lens`` are pairs
        indexed ``[0]`` / ``[1]``, one entry per micro-batch. ``ops.depend(a, b)`` calls return ``a`` while
        adding an ordering dependency on ``b``; they fix the interleaving of the two micro-batches' stages.

        For micro-batch 0, attention runs here only in the first MoE layer
        (``layer_id == first_k_dense_replace``). In later layers, ``x[0]`` is taken as already containing
        the attention residual, which this method computes at its tail (unless this is the last layer).

        Returns:
            Tuple ``(h_0, h_1)``:

            - ``h_0``: micro-batch 0 after this layer's FFN, plus the look-ahead attention when this is
              not the last layer.
            - ``h_1``: micro-batch 1 after this layer's FFN.
        """
        # Micro-batch 0: attention (first MoE layer only) -> routing -> dispatch.
        if self.layer_id == self.moe_config.first_k_dense_replace:
            input_x_0 = self.attention_norm(x[0])
            h_0 = self.attention(input_x_0, freqs_cis, mask, batch_valid_length[0], block_tables[0], slot_mapping[0],
                                 q_seq_lens[0], key_cache=key_cache)
            h_0_new = self.add(x[0], h_0)
        else:
            h_0_new = x[0]
        ffn_norm_0 = self.ffn_norm(h_0_new)
        expert_weight_0, expert_index_0 = self.feed_forward.gating_logits(ffn_norm_0)
        expand_x_0, _, expand_idx_0, expert_token_nums_0, ep_recv_counts_0, tp_recv_counts_0, _ = \
            self.feed_forward.moe_dispatch(x=ffn_norm_0,
                                           expert_index=expert_index_0)
        # Micro-batch 1: attention and routing, ordered after micro-batch 0's routing.
        input_x_1 = self.attention_norm(x[1])
        input_x_1 = ops.depend(input_x_1, expert_index_0)
        h_1 = self.attention(input_x_1, freqs_cis, mask, batch_valid_length[1], block_tables[1], slot_mapping[1],
                             q_seq_lens[1], key_cache=key_cache)
        h_1_new = self.add(x[1], h_1)
        ffn_norm_1 = self.ffn_norm(h_1_new)
        expert_weight_1, expert_index_1 = self.feed_forward.gating_logits(ffn_norm_1)
        # Micro-batch 0 expert compute (shared + routed), ordered after micro-batch 1's routing.
        shared_res_0 = self.feed_forward.shared_experts(ffn_norm_0)
        shared_res_0 = ops.depend(shared_res_0, expert_index_1)
        expand_x_0 = ops.depend(expand_x_0, expert_index_1)
        ffn_res_0 = self.feed_forward.routed_experts.ffn(expand_x_0, expert_token_nums_0)
        # Micro-batch 1 dispatch and expert compute, ordered after micro-batch 0's routed FFN.
        ffn_norm_1 = ops.depend(ffn_norm_1, ffn_res_0)
        expand_x_1, _, expand_idx_1, expert_token_nums_1, ep_recv_counts_1, tp_recv_counts_1, _ = \
            self.feed_forward.moe_dispatch(x=ffn_norm_1,
                                           expert_index=expert_index_1)
        shared_res_1 = self.feed_forward.shared_experts(ffn_norm_1)
        shared_res_1 = ops.depend(shared_res_1, ffn_res_0)
        ffn_res_1 = self.feed_forward.routed_experts.ffn(expand_x_1, expert_token_nums_1)
        # Combine: routed output + shared-expert output + residual, for each micro-batch.
        out_0 = self.feed_forward.moe_combine(ffn_res=ffn_res_0, expert_index=expert_index_0, expand_idx=expand_idx_0,
                                              ep_recv_counts=ep_recv_counts_0, expert_weight=expert_weight_0,
                                              tp_recv_counts=tp_recv_counts_0)
        out_0 = self.add(out_0, shared_res_0)
        out_0 = self.add(out_0, h_0_new)
        out_1 = self.feed_forward.moe_combine(ffn_res=ffn_res_1, expert_index=expert_index_1, expand_idx=expand_idx_1,
                                              ep_recv_counts=ep_recv_counts_1, expert_weight=expert_weight_1,
                                              tp_recv_counts=tp_recv_counts_1)
        out_1 = self.add(out_1, shared_res_1)
        out_1 = self.add(out_1, h_1_new)
        if self.layer_id != self.num_layers - 1:
            # Look-ahead: compute the following layer's attention for micro-batch 0 here, ordered after
            # micro-batch 1's routed FFN; the following layer then skips micro-batch 0's attention.
            # NOTE(review): this uses *this* layer's attention_norm/attention weights and `key_cache`,
            # so layer `layer_id + 1` appears to receive attention computed with layer `layer_id`'s
            # parameters and KV-cache. Confirm this is intended.
            next_layer_x_0 = self.attention_norm(out_0)
            next_layer_x_0 = ops.depend(next_layer_x_0, ffn_res_1)
            next_layer_h_0 = self.attention(next_layer_x_0, freqs_cis, mask, batch_valid_length[0], block_tables[0],
                                            slot_mapping[0], q_seq_lens[0], key_cache=key_cache)
            next_layer_h_0_new = self.add(out_0, next_layer_h_0)
        else:
            next_layer_h_0_new = out_0
        return next_layer_h_0_new, out_1
class DeepseekV3PreTrainedModel(PreTrainedModel):
    """
    Abstract DeepSeek-V3 base model.

    Binds ``DeepseekV3Config`` as the configuration class and ``"deepseekv3"`` as the base-model prefix
    for the weight-initialization and pretrained-loading interface inherited from ``PreTrainedModel``.
    """
    base_model_prefix = "deepseekv3"
    config_class = DeepseekV3Config
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    r"""
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`DeepseekV3DecodeLayer`].

    When ``config.enable_micro_batch`` is set, the first iteration (prefill) runs the MoE layers on two
    micro-batches split by :meth:`_get_split_index`. Dense layers and later iterations run on the full batch.

    Args:
        config(DeepseekV3Config): the config of network

    Inputs:
        tokens: the tokenized inputs with datatype int32

    Returns:
        output: Tensor, the last layer's hidden states after the final RMSNorm.
    """
    def __init__(self,
                 config: DeepseekV3Config = None):
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        if config.batch_size or config.use_past:
            Validator.check_positive_int(config.batch_size)
        self.dtype = config.compute_dtype
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.n_head = config.num_heads
        self.head_dim = self.hidden_size // self.n_head
        self.pad_token_id = config.pad_token_id
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.is_first_iteration = True
        self.use_past = config.use_past
        self.moe_config = config.moe_config
        self.enable_micro_batch = config.enable_micro_batch
        self.is_dynamic = config.is_dynamic
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.tile = P.Tile()
        self.expand_dims = P.ExpandDims()
        self.gather = P.Gather()
        self.slice = P.StridedSlice()
        # Used to cut inputs into the two micro-batches.
        self.mb_split = ops.Slice()
        self.freqs_mgr = FreqsMgr(head_dim=self.qk_rope_head_dim,
                                  seq_length=config.seq_length,
                                  max_position_embedding=config.max_position_embeddings,
                                  rotary_dtype=config.rotary_dtype,
                                  theta=config.theta,
                                  scaling_factor=config.scaling_factor,
                                  extend_method=config.extend_method,
                                  is_dynamic=config.is_dynamic)
        # Attribute name (sic) is kept: callers access `model.casual_mask`.
        self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
                                                          compute_type=config.compute_dtype,
                                                          is_dynamic=config.is_dynamic,
                                                          pad_token_id=config.pad_token_id,
                                                          use_flash_attention=config.use_flash_attention,
                                                          use_past=config.use_past)
        if config.parallel_config.vocab_emb_dp:
            self.tp_group = get_tensor_model_parallel_group()
            self.tok_embeddings = VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                         embedding_dim=config.hidden_size,
                                                         parallel_config=config.parallel_config,
                                                         init_method="normal",
                                                         init_type=config.param_init_type,
                                                         tp_group=self.tp_group)
        else:
            self.tok_embeddings = VocabEmbedding(num_embeddings=config.vocab_size,
                                                 embedding_dim=config.hidden_size,
                                                 param_init_type=config.param_init_type,
                                                 param_init="normal")
        self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave,
                                                                       config.parallel_config)
        self.layers = nn.CellList()
        self.layer_setting = LayerSetting(config.num_layers,
                                          config.offset,
                                          config.parallel_config,
                                          config.pp_interleave_num)
        for layer_id in range(config.num_layers):
            layer = DeepseekV3DecodeLayer(layer_id,
                                          dim=config.hidden_size,
                                          n_heads=config.num_heads,
                                          n_kv_heads=config.n_kv_heads,
                                          norm_eps=config.rms_norm_eps,
                                          qkv_has_bias=config.qkv_has_bias,
                                          compute_dtype=config.compute_dtype,
                                          layernorm_compute_dtype=config.layernorm_compute_type,
                                          param_init_type=config.param_init_type,
                                          use_past=config.use_past,
                                          use_flash_attention=config.use_flash_attention,
                                          block_size=config.block_size,
                                          num_blocks=config.num_blocks,
                                          parallel_config=config.parallel_config,
                                          moe_config=config.moe_config,
                                          kv_lora_rank=config.kv_lora_rank,
                                          q_lora_rank=config.q_lora_rank,
                                          qk_rope_head_dim=config.qk_rope_head_dim,
                                          v_head_dim=config.v_head_dim,
                                          qk_nope_head_dim=config.qk_nope_head_dim,
                                          max_position_embeddings=config.max_position_embeddings,
                                          scaling_factor=config.scaling_factor,
                                          config=config)
            self.layer_setting(layer, layer_id)
            self.layers.append(layer)
        self.norm_out = RMSNorm(config.hidden_size, config.rms_norm_eps,
                                compute_type=config.layernorm_compute_type)

    def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
                  block_tables=None, slot_mapping=None, position_ids=None, q_seq_lens=None,
                  attention_mask=None, attn_padding_idx=None, attn_unpadding_idx=None, ffn_padding_idx=None,
                  ffn_unpadding_idx=None, key_cache=None):
        """
        Forward of deepseekv3 model.

        Args:
            tokens: the tokenized inputs with datatype int32
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.
            batch_index(Tensor): The generated batch index when use continuous batching in LLM serving.
                Tensor of shape :math:`(batch_size,)`. Default None.
            zactivate_len(Tensor): The slice length of KVCache when use dynamic shape infer.
                Tensor of shape :math:`(seq_length,)`. Default None.
            block_tables(Tensor[int64]): Store mapping tables for each sequence.
            slot_mapping(Tensor[int32]): Store token cache physical slot index.
            position_ids(Tensor): Positions used for rotary embeddings after the first iteration.
            q_seq_lens(Tensor): Per-sequence query lengths.
            attention_mask(Tensor): Attention mask passed to every layer.
            attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx(Tensor):
                Token (un)padding gather indices forwarded to the layers.
            key_cache(list[Tensor], optional): Per-layer key caches, indexed by layer id.

        Returns:
            output: Tensor, the output of deepseekv3 decoderlayer
        """
        mask = attention_mask
        if self.is_first_iteration:
            freqs_cis = self.freqs_mgr.prefill()
        else:
            freqs_cis = self.freqs_mgr.chunk_with_decode(position_ids)
        h = self.cast(self.tok_embeddings(tokens), self.dtype)
        split_input = None
        split_bvl = None
        split_bt = None
        split_sm = None
        split_qsl = None
        for i in range(self.num_layers):
            key_cache_i = key_cache[i] if key_cache is not None else None
            # Dense layers and non-micro-batch runs use the full batch.
            if (self.moe_config.first_k_dense_replace and i < self.moe_config.first_k_dense_replace) \
                    or not (self.enable_micro_batch and self.is_first_iteration):
                h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length,
                                   block_tables=block_tables, slot_mapping=slot_mapping,
                                   q_seq_lens=q_seq_lens, attn_padding_idx=attn_padding_idx,
                                   attn_unpadding_idx=attn_unpadding_idx, ffn_padding_idx=ffn_padding_idx,
                                   ffn_unpadding_idx=ffn_unpadding_idx, key_cache=key_cache_i)
            else:
                # Split once, at the first MoE layer; later layers consume and return the pair.
                if i == self.moe_config.first_k_dense_replace:
                    split_input, split_bvl, split_bt, split_sm, split_qsl = self._split_micro_batch_input(
                        h, batch_valid_length, block_tables, slot_mapping, q_seq_lens)
                split_input = self.layers[i](split_input, freqs_cis, mask, batch_valid_length=split_bvl,
                                             block_tables=split_bt, slot_mapping=split_sm,
                                             q_seq_lens=split_qsl, attn_padding_idx=attn_padding_idx,
                                             attn_unpadding_idx=attn_unpadding_idx, ffn_padding_idx=ffn_padding_idx,
                                             ffn_unpadding_idx=ffn_unpadding_idx, key_cache=key_cache_i)
                # Re-join the two micro-batches along the token dimension after the last layer.
                if i == self.num_layers - 1:
                    h = mint.concat((split_input[0], split_input[1]), dim=0)
        output = self.norm_out(h)
        return output

    def _split_micro_batch_input(self, h, batch_valid_length, block_tables, slot_mapping, q_seq_lens):
        """Split hidden states and per-sequence inputs into two micro-batches.

        Token-indexed tensors (``h``, ``slot_mapping``) are split at the token index, and sequence-indexed
        tensors (``batch_valid_length``, ``block_tables``, ``q_seq_lens``) at the sequence index, both from
        :meth:`_get_split_index`.

        Returns:
            Five pairs ``(part_0, part_1)``: hidden states, batch_valid_length, block_tables, slot_mapping
            and q_seq_lens.
        """
        seq_split_idx, token_split_idx = self._get_split_index(q_seq_lens)
        split_input_0 = self.mb_split(h, (0, 0), (token_split_idx, h.shape[1]))
        split_input_1 = self.mb_split(h, (token_split_idx, 0), (h.shape[0]-token_split_idx, h.shape[1]))
        split_bvl_0 = self.mb_split(batch_valid_length, (0,), (seq_split_idx,))
        split_bvl_1 = self.mb_split(batch_valid_length, (seq_split_idx,), \
                                    (batch_valid_length.shape[0]-seq_split_idx,))
        split_bt_0 = self.mb_split(block_tables, (0, 0), (seq_split_idx, block_tables.shape[1]))
        split_bt_1 = self.mb_split(block_tables, (seq_split_idx, 0), \
                                   (block_tables.shape[0]-seq_split_idx, block_tables.shape[1]))
        split_sm_0 = self.mb_split(slot_mapping, (0,), (token_split_idx,))
        split_sm_1 = self.mb_split(slot_mapping, (token_split_idx,), (slot_mapping.shape[0]-token_split_idx,))
        split_qsl_0 = self.mb_split(q_seq_lens, (0,), (seq_split_idx,))
        split_qsl_1 = self.mb_split(q_seq_lens, (seq_split_idx,), (q_seq_lens.shape[0]-seq_split_idx,))
        return (split_input_0, split_input_1), (split_bvl_0, split_bvl_1), (split_bt_0, split_bt_1), \
            (split_sm_0, split_sm_1), (split_qsl_0, split_qsl_1)

    def _get_split_index(self, q_seq_lens):
        """Choose where to split the batch into two micro-batches.

        In the first iteration, the split is placed after the sequence that minimizes
        ``|tokens before the split - tokens after the split|``, balancing token counts between the two
        micro-batches. Otherwise each sequence is treated as one token and the batch is halved by
        sequence count.

        Args:
            q_seq_lens (Tensor): 1-D per-sequence query lengths.

        Returns:
            Tuple[int, int]: ``(seq_split_idx, token_split_idx)`` - the number of sequences and the number
            of tokens placed in micro-batch 0.
        """
        seq_split_idx = 0
        token_split_idx = 0
        if self.is_first_iteration:
            total_sum = ops.ReduceSum()(q_seq_lens)
            # prefix_sums[k] = tokens in micro-batch 0 if the split is after sequence k. The last
            # candidate is dropped so micro-batch 1 always keeps at least one sequence.
            prefix_sums = ops.CumSum()(q_seq_lens, 0)[:-1]
            # Tokens left for micro-batch 1 at each candidate split. Must not be replaced by
            # prefix_sums: that would make every diff zero and always split after the first sequence.
            suffix_sums = total_sum - prefix_sums
            diffs = ops.Abs()(prefix_sums - suffix_sums)
            seq_split_idx = ops.Argmin()(diffs)
            token_split_idx = prefix_sums[seq_split_idx]
            # 0-based candidate position -> number of sequences in micro-batch 0.
            seq_split_idx += 1
        else:
            # One token per sequence, so the token and sequence split points coincide.
            seq_split_idx = q_seq_lens.shape[0] // 2
            token_split_idx = seq_split_idx
        return int(seq_split_idx), int(token_split_idx)
class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
r"""
Provide DeepseekV3 logits through network.
Args:
config (DeepseekV3Config): The config of DeepseekV3 model.
Inputs:
input_ids(Tensor): The tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
labels(Tensor): The tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
input_position(Tensor): Current position, used by model.predict.
position_ids(Tensor): Reserved param, not used.
attention_mask(Tensor): Reserved param, not used.
input_embeds(Tensor): Reserved param, not used.
init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Default True.
batch_valid_length(Tensor): The past calculated the index with datatype int32, used for incremental
prediction. Tensor of shape :math:`(batch_size,)`. Default None.
batch_index(Tensor): The generated batch index when use continuous batching in LLM serving.
Tensor of shape :math:`(batch_size,)`. Default None.
zactivate_len(Tensor): The slice length of KVCache when use dynamic shape infer.
Tensor of shape :math:`(seq_length,)`. Default None.
block_tables(Tensor, optional): Int64 type Tensor, store mapping tables for each sequence. Default None.
slot_mapping(Tensor, optional): Int32 type Tensor, token cache physical slot index. Default None.
Returns:
Tensor. If it is in prediction mode, the output Tensor contains logits;
If it is in evaluation mode, the output Tensor contains logits, tokens, and input masks.
"""
@lazy_inline
def __init__(self, config: DeepseekV3Config = None):
    """Build the DeepSeek-V3 causal LM for inference.

    Steps, in order:

    1. Validate the parallel config and convert the model config.
    2. Initialize the model-parallel groups if needed.
    3. Rewrite the communication flags via ``update_comm_config``.
    4. Build the backbone and the LM head.
    5. Load the checkpoint.

    Args:
        config (DeepseekV3Config): Model configuration.
    """
    super().__init__(config, auto_prefix=True)
    _check_config(config.parallel_config)
    self.config = convert_model_config(config)
    self.parallel_config = self.config.parallel_config
    # Falls back to 2 when the config has no `npu_mem_size`.
    self.npu_mem_size = getattr(config, 'npu_mem_size', 2)
    if not is_initialized():
        pcfg = self.parallel_config
        initialize_model_parallel(pipeline_model_parallel_size=pcfg.pipeline_model_parallel_size,
                                  expert_model_parallel_size=pcfg.expert_parallel,
                                  tensor_model_parallel_size=pcfg.tensor_model_parallel_size,
                                  data_parallel_size=pcfg.data_parallel,
                                  order='tp-ep-dp-pp',)
    self.tp_group = get_tensor_model_parallel_group()
    config = self.update_comm_config(config)

    # Scalar settings read from the updated config.
    self.seq_length = config.seq_length
    self.ignore_token_id = config.ignore_token_id
    self.pad_token_id = config.pad_token_id
    self.use_past = config.use_past
    self.vocab_size = config.vocab_size
    self.is_first_iteration = True

    # Primitive operators.
    self.shape = P.Shape()
    self.reshape = P.Reshape()
    self.cast = P.Cast()
    self.slice = P.StridedSlice()
    self.not_equal = P.NotEqual()
    self.mul = P.Mul()
    self.add = P.Add()
    self.ones = P.Ones()
    self.gather = P.Gather()
    self.sub_batch_valid_len = P.Sub()

    self.model = DeepseekV3Model(config=config)
    if not config.parallel_config.vocab_emb_dp:
        # Column-parallel head over the TP group, gathering the full vocab dimension.
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            config=config.parallel_config,
            bias=False,
            param_init_type=config.param_init_type,
            compute_dtype=config.compute_dtype,
            weight_init="normal",
            gather_output=True,
            tp_group=self.tp_group,
        )
    else:
        self.lm_head = Linear(
            in_channels=config.hidden_size,
            out_channels=config.vocab_size,
            weight_init="normal",
            has_bias=False,
            param_init_type=config.param_init_type,
            compute_dtype=config.compute_dtype
        )
    self.prefill_gather_flatten = P.Gather()
    self.load_checkpoint(config)
    self.predict_run_mode = get_predict_run_mode()
    logger.info(f"Predict run mode:{self.predict_run_mode}")
    self.return_hidden_states = config.return_hidden_states
def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
    """ Get deepseekv3 model input tuple for transform ckpt.

    Returns a 12-element positional tuple:

    - index 0: ``input_ids`` as int32;
    - index 1: ``labels``, or None when not supplied;
    - indices 2-10: None;
    - index 11: an int32 all-ones ``slot_mapping`` of length ``batch * seq``.
    """
    input_ids = Tensor(input_ids, mstype.int32)
    labels = None
    if "labels" in kwargs:
        labels = Tensor(kwargs["labels"])
    num_tokens = input_ids.shape[0] * input_ids.shape[1]
    slot_mapping = Tensor(np.ones(shape=(num_tokens,)), mstype.int32)
    placeholders = (None,) * 9
    return (input_ids, labels) + placeholders + (slot_mapping,)
def update_comm_config(self, config):
    """Choose attention/FFN communication patterns from the parallel layout.

    Side effects:

    - Resets the communication flags on ``config.parallel_config`` to plain all-reduce.
    - Records ``tp_size``, ``dp_size``, ``moe_ep_size`` and ``moe_tp_size`` on ``self``.
    - When ``dp_size > 1``, switches to reduce-scatter / all-gather / all-to-all patterns depending on
      whether attention TP (``tp_size``) and MoE TP (``moe_tp_size``) are 1.
    - Disables micro-batching in every ``dp_size > 1`` case except ``tp_size == moe_tp_size == 1``.

    Returns:
        The same ``config`` object, updated in place.
    """
    flags = config.parallel_config
    flags.attn_reduce_scatter = False
    flags.attn_allgather = False
    flags.attn_allreduce = True
    flags.ffn_reduce_scatter = False
    flags.ffn_allgather = False
    flags.ffn_allreduce = True
    flags.use_alltoall = False

    world_size = get_group_size()
    self.tp_size = self.parallel_config.tensor_model_parallel_size
    self.dp_size = world_size // self.tp_size
    self.moe_ep_size = self.parallel_config.expert_parallel
    self.moe_tp_size = world_size // self.moe_ep_size

    if self.dp_size <= 1:
        return config

    attn_tp_is_one = self.tp_size == 1
    moe_tp_is_one = self.moe_tp_size == 1
    # With data parallelism, neither attention nor FFN uses a plain all-reduce.
    flags.attn_allreduce = False
    flags.ffn_allreduce = False
    if moe_tp_is_one:
        flags.use_alltoall = True
    else:
        flags.attn_allgather = True
        flags.ffn_reduce_scatter = True
    if not attn_tp_is_one:
        flags.attn_reduce_scatter = True
        flags.ffn_allgather = True
    if not (attn_tp_is_one and moe_tp_is_one):
        config.enable_micro_batch = False
    return config
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """
    prepare inputs for generation.
    A model class needs to define a `prepare_inputs_for_generation` method
    in order to use `.generate()`

    Reads ``valid_length_each_example``, ``prefill``, and (dynamic prefill only) ``origin_inputs`` /
    ``slot_mapping`` from ``kwargs``.

    Adds the following keys to the returned dict:

    - ``position_ids``: ``valid_length - 1``.
    - ``q_seq_lens``: the valid lengths on prefill, ones otherwise.
    - ``attention_mask``: the causal mask.
    - ``need_flatten``: always True.
    """
    valid_lengths = kwargs.get("valid_length_each_example")
    is_prefill = kwargs.get("prefill")
    model_inputs = {"input_ids": Tensor.from_numpy(input_ids.astype(np.int32))}
    if self.config.is_dynamic and is_prefill and "origin_inputs" in kwargs:
        model_inputs = self._prepare_inputs_for_prefill_flatten(kwargs["origin_inputs"],
                                                                valid_lengths,
                                                                kwargs.get("slot_mapping"),
                                                                model_inputs)
    positions = valid_lengths - 1
    model_inputs["position_ids"] = ms.Tensor(positions, dtype=ms.int32).reshape(-1)
    if is_prefill:
        q_seq_lens = valid_lengths.astype(np.int32).reshape(-1)
    else:
        # Decode step: one new token per sequence.
        q_seq_lens = np.ones(valid_lengths.shape, dtype=np.int32).reshape(-1)
    model_inputs["q_seq_lens"] = Tensor.from_numpy(q_seq_lens)
    model_inputs["attention_mask"] = self.model.casual_mask.gen_attention_mask(is_prefill)
    model_inputs["need_flatten"] = True
    return model_inputs
def update_padding_index_to_inputs(self, model_inputs):
    """generate padding index in tp region.

    When data parallelism is active (``data_parallel > 1``) and ``q_seq_lens`` is present, the per-DP-rank
    token counts are all-gathered. Every rank's token block is padded up to a common ``padding_size``
    (the largest count rounded up to a multiple of ``model_parallel``). Four int32 gather indices are
    then added to ``model_inputs``:

    - ``attn_padding_idx``: this rank's tokens followed by zero padding (length ``padding_size``).
    - ``attn_unpadding_idx``: positions of all real tokens inside the concatenation of every rank's
      padded block.
    - ``ffn_padding_idx``: for every rank, the global positions of its tokens followed by zero padding.
    - ``ffn_unpadding_idx``: ``arange`` over this rank's token count.

    Args:
        model_inputs (dict): Model inputs; read for ``q_seq_lens`` and updated in place.

    Returns:
        dict: ``model_inputs`` (unchanged when the config lacks parallel settings, ``data_parallel`` is 1,
        or ``q_seq_lens`` is missing).
    """
    if not hasattr(self.config, 'parallel_config') or \
            not hasattr(self.config.parallel_config, 'data_parallel') or \
            not hasattr(self.config.parallel_config, 'model_parallel'):
        return model_inputs
    dp_size = self.config.parallel_config.data_parallel
    tp_size = self.config.parallel_config.model_parallel
    q_seq_len = model_inputs.get("q_seq_lens", None)
    if dp_size == 1 or q_seq_len is None:
        return model_inputs
    tokens_len_per_dp = q_seq_len.sum().reshape(-1)
    tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group().group)(tokens_len_per_dp)
    tokens_len_per_dp = tokens_len_per_dp.asnumpy()
    # Common per-rank block size: the largest token count, rounded up to a multiple of tp_size.
    padding_size = (tokens_len_per_dp.max() + tp_size - 1) // tp_size * tp_size
    dp_rank_id = get_rank() // tp_size
    attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx = \
        self._build_padding_indices(tokens_len_per_dp, padding_size, dp_rank_id)
    model_inputs["attn_padding_idx"] = ms.from_numpy(attn_padding_idx)
    model_inputs["attn_unpadding_idx"] = ms.from_numpy(attn_unpadding_idx)
    model_inputs["ffn_padding_idx"] = ms.from_numpy(ffn_padding_idx)
    model_inputs["ffn_unpadding_idx"] = ms.from_numpy(ffn_unpadding_idx)
    return model_inputs

@staticmethod
def _build_padding_indices(tokens_len_per_dp, padding_size, dp_rank_id):
    """Compute the four (un)padding gather indices from per-DP-rank token counts (pure numpy).

    A running token offset replaces "last element + 1" bookkeeping, so DP ranks that report zero tokens
    are handled without indexing into an empty array.

    Args:
        tokens_len_per_dp (np.ndarray): Token count of every DP rank, in rank order.
        padding_size (int): Per-rank padded block length (>= every entry of ``tokens_len_per_dp``).
        dp_rank_id (int): This process's DP rank.

    Returns:
        Tuple ``(attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx)`` of numpy
        arrays. The two per-rank arrays are None if ``dp_rank_id`` does not appear in the enumeration.
    """
    attn_padding_idx = None
    ffn_unpadding_idx = None
    attn_unpadding_parts = []
    ffn_padding_parts = []
    token_offset = 0  # number of real tokens on all lower DP ranks
    for dp_rank, tokens_length in enumerate(tokens_len_per_dp):
        tokens_length = int(tokens_length)
        local_idx = np.arange(0, tokens_length, dtype=np.int32)
        pad = np.zeros(padding_size - tokens_length, dtype=np.int32)
        if dp_rank == dp_rank_id:
            ffn_unpadding_idx = local_idx
            attn_padding_idx = np.concatenate((local_idx, pad), axis=0)
        # Real tokens of rank r sit at [r * padding_size, r * padding_size + len_r) in the padded layout.
        attn_unpadding_parts.append(local_idx + padding_size * dp_rank)
        # In the unpadded layout, rank r's tokens start after all tokens of lower ranks.
        ffn_padding_parts.append(local_idx + token_offset)
        ffn_padding_parts.append(pad)
        token_offset += tokens_length
    attn_unpadding_idx = np.concatenate(attn_unpadding_parts, axis=0)
    ffn_padding_idx = np.concatenate(ffn_padding_parts, axis=0)
    return attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx
def set_dynamic_inputs(self, **kwargs):
    """Declare dynamic-shape placeholder inputs for DeepseekV3 through ``set_inputs``.

    Every dimension of every placeholder is ``None`` (dynamic).

    - The four data-parallel padding/unpadding index placeholders are declared
      only when ``self.dp_size > 1``; otherwise ``None`` is passed for them.
    - The key cache (one 4-D placeholder per entry of ``self.model.layers``,
      wrapped with ``mutable``) is declared only when ``self.npu_mem_size <= 0``.
    - ``kwargs`` is accepted for interface compatibility and is not used.
    """
    def _placeholder(rank, dtype=mstype.int32):
        # Shape/dtype-only Tensor with `rank` dynamic dimensions.
        return Tensor(shape=[None] * rank, dtype=dtype)

    if self.dp_size > 1:
        # attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx
        padding_indices = tuple(_placeholder(1) for _ in range(4))
    else:
        padding_indices = (None, None, None, None)

    if self.npu_mem_size > 0:
        key_cache = None
    else:
        key_cache = mutable([_placeholder(4, self.config.compute_dtype) for _ in self.model.layers])

    # Positional order mirrors the parameters of `construct`.
    self.set_inputs(_placeholder(1),                   # input_ids
                    None,                              # labels
                    None,                              # input_position
                    _placeholder(1),                   # position_ids
                    _placeholder(2, mstype.bfloat16),  # attention_mask
                    None,                              # input_embeds
                    True,                              # init_reset
                    _placeholder(1),                   # batch_valid_length
                    None,                              # batch_index
                    None,                              # zactivate_len
                    _placeholder(2),                   # block_tables
                    _placeholder(1),                   # slot_mapping
                    _placeholder(1),                   # q_seq_lens
                    *padding_indices,
                    key_cache,
                    None)                              # value_cache
    logger.info("Set dynamic input for DeepseekV3.")
def pre_gather_func(self, pre_gather, output, batch_valid_length):
    """Optionally gather rows of ``output`` before the LM head (infer mode).

    When ``pre_gather`` is false, ``output`` is returned untouched. Otherwise the
    gather indices are ``batch_valid_length - 1``:

    - in the dynamic configuration (``self.config.is_dynamic``) they are applied
      along axis 0 after a cumulative sum over ``batch_valid_length``, which
      suggests a flattened token layout;
    - in the static configuration they are applied along axis 1.

    Presumably this picks the last valid token of each sequence — confirm
    against the caller's layout.
    """
    if not pre_gather:
        return output
    if not self.config.is_dynamic:
        return self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
    end_offsets = mint.cumsum(batch_valid_length, 0)
    return self.prefill_gather_flatten(output, self.sub_batch_valid_len(end_offsets, 1), 0)
@jit
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
              input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None,
              block_tables=None, slot_mapping=None, q_seq_lens=None, attn_padding_idx=None,
              attn_unpadding_idx=None, ffn_padding_idx=None, ffn_unpadding_idx=None, key_cache=None,
              value_cache=None):
    """Run the DeepseekV3 backbone followed by the LM head.

    ``input_position``, ``input_embeds``, ``init_reset`` and ``value_cache`` are accepted
    for interface compatibility and are not used here.

    Returns:
        - ``self.return_hidden_states`` set: the backbone output, unchanged.
        - ``self.predict_run_mode`` set: float32 logits reshaped to ``(-1, logits.shape[-1])``.
        - Otherwise: ``(logits, input_ids, input_mask)``. ``input_mask`` is a float32
          mask that is 1 where ``input_ids != self.pad_token_id``. When ``labels`` has
          more than one dimension, it is further multiplied by
          ``labels != self.ignore_token_id``.
    """
    output = self.model(input_ids, batch_valid_length, batch_index, zactivate_len, block_tables,
                        slot_mapping, position_ids, q_seq_lens, attention_mask, attn_padding_idx,
                        attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx, key_cache=key_cache)
    if self.return_hidden_states:
        return output
    # Gather before the LM head only when there is no KV cache or on the first iteration.
    pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
    output = self.pre_gather_func(pre_gather, output, batch_valid_length)
    logits = self.cast(self.lm_head(output), mstype.float32)
    if self.predict_run_mode:
        # Predict mode only needs the logits; the input mask would be discarded, so skip it.
        return self.reshape(logits, (-1, logits.shape[-1]))
    input_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), mstype.float32)
    if labels is not None and labels.ndim > 1:
        label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
        input_mask = self.mul(input_mask, label_mask)
    return logits, input_ids, input_mask
def kvcache(self, layer_idx):
    """Return ``(key_cache, None)`` for decoder layer ``layer_idx``.

    The key cache is read from that layer's paged-attention manager. ``None`` is
    returned in place of a value cache.
    """
    attention_mgr = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr
    return attention_mgr.key_cache, None
def add_flags_enable_micro_batch(self, enable_micro_batch):
    """Set the ``enable_micro_batch`` flag on the backbone and on every decoder layer.

    The backbone (``self.model``) is flagged first, then each entry of
    ``self.model.layers`` in order. No ``use_past`` check is performed here; the
    caller decides when to invoke this.
    """
    for cell in (self.model, *self.model.layers):
        cell.add_flags(enable_micro_batch=enable_micro_batch)
class DeepseekV3MTPLayer(nn.Cell):
    r"""
    DeepseekV3 multi-token-prediction (MTP) layer wrapping a single `DeepseekV3DecodeLayer`.

    The token embedding and the incoming hidden states are processed as follows:

    1. Each is RMS-normalized (``enorm`` / ``hnorm``).
    2. The two results are concatenated on the last axis.
    3. ``eh_proj`` projects the concatenation from ``2 * hidden_size`` back to ``hidden_size``.
    4. The projection is run through the decoder layer.

    Args:
        config(DeepseekV3Config): the config of network

    Inputs:
        emb(Tensor): Float Tensor of token embeddings. Its last dimension must be
            ``config.hidden_size`` (it is normalized by ``enorm``).
        hidden_states(Tensor): Float Tensor of hidden states from the main model. Its last
            dimension must be ``config.hidden_size`` (it is normalized by ``hnorm``).
        freqs_cis(Tensor): The precomputed freqs and mask for rotary position embedding used in attention.
        mask(Tensor, optional): Attention mask, forwarded unchanged to the decoder layer. Default None.
        batch_valid_length(Tensor, optional): The past calculated the index with datatype int32, used for
            incremental prediction. Default None.
        block_tables(Tensor, optional): Int32 Tensor mapping each sequence to its cache blocks
            (the MTP wrapper declares it as int32). Default None.
        slot_mapping(Tensor, optional): Int32 type Tensor, token cache physical slot index. Default None.
        q_seq_lens(Tensor, optional): Int32 type Tensor, indicating the sequence length of query tokens.
        key_cache(Tensor, optional): Key cache forwarded to the decoder layer. Default None.

    Returns:
        Tensor, the output of the wrapped decoder layer.
    """
    def __init__(self, config: DeepseekV3Config = None):
        super().__init__()
        self.enorm = RMSNorm(config.hidden_size, config.rms_norm_eps,
                             compute_type=config.layernorm_compute_type)
        self.hnorm = RMSNorm(config.hidden_size, config.rms_norm_eps,
                             compute_type=config.layernorm_compute_type)
        self.concat = P.Concat(axis=-1)
        self.tp_group = get_tensor_model_parallel_group()
        # Fuses [norm(emb), norm(hidden_states)] (2 * hidden_size) back to hidden_size;
        # gather_output=True so every TP rank sees the full projection.
        self.eh_proj = ColumnParallelLinear(config.hidden_size * 2,
                                            config.hidden_size,
                                            config=config.parallel_config,
                                            bias=False,
                                            gather_output=True,
                                            param_init_type=config.param_init_type,
                                            weight_init="normal",
                                            compute_dtype=config.compute_dtype,
                                            tp_group=self.tp_group)
        # NOTE(review): config.num_layers is passed as the first positional argument, presumably the
        # layer index, i.e. the MTP layer is numbered right after the main model's layers — confirm.
        self.decode_layer = DeepseekV3DecodeLayer(config.num_layers,
                                                  dim=config.hidden_size,
                                                  n_heads=config.num_heads,
                                                  n_kv_heads=config.n_kv_heads,
                                                  norm_eps=config.rms_norm_eps,
                                                  qkv_has_bias=config.qkv_has_bias,
                                                  compute_dtype=config.compute_dtype,
                                                  layernorm_compute_dtype=config.layernorm_compute_type,
                                                  param_init_type=config.param_init_type,
                                                  use_past=config.use_past,
                                                  use_flash_attention=config.use_flash_attention,
                                                  block_size=config.block_size,
                                                  num_blocks=config.num_blocks,
                                                  parallel_config=config.parallel_config,
                                                  moe_config=config.moe_config,
                                                  kv_lora_rank=config.kv_lora_rank,
                                                  q_lora_rank=config.q_lora_rank,
                                                  qk_rope_head_dim=config.qk_rope_head_dim,
                                                  v_head_dim=config.v_head_dim,
                                                  qk_nope_head_dim=config.qk_nope_head_dim,
                                                  max_position_embeddings=config.max_position_embeddings,
                                                  scaling_factor=config.scaling_factor,
                                                  config=config)

    def construct(self, emb, hidden_states, freqs_cis, mask=None, batch_valid_length=None, block_tables=None,
                  slot_mapping=None, q_seq_lens=None, key_cache=None):
        """ Forward of DeepSeekV3MTP layer: normalize, concatenate, project, then decode. """
        norm_emb = self.enorm(emb)
        norm_h = self.hnorm(hidden_states)
        # Embedding first, hidden states second; eh_proj weights depend on this order.
        cat_h = self.concat([norm_emb, norm_h])
        h = self.eh_proj(cat_h)
        hidden_states = self.decode_layer(h, freqs_cis, mask, batch_valid_length=batch_valid_length,
                                          block_tables=block_tables, slot_mapping=slot_mapping, q_seq_lens=q_seq_lens,
                                          key_cache=key_cache)
        return hidden_states
class DeepseekV3MTPModel(DeepseekV3PreTrainedModel):
    r"""
    DeepseekV3 MTP (multi-token prediction) backbone.

    It consists of a token embedding, rotary-frequency management, a single
    `DeepseekV3MTPLayer` and a final RMSNorm.

    Args:
        config(DeepseekV3Config): the config of network

    Inputs:
        input_ids(Tensor): The tokenized inputs with datatype int32. The pynative first-iteration
            path unpacks its shape into ``(bs, seq_len)``, so it must be 2-D there.
        hidden_states(Tensor): Hidden states produced by the main DeepseekV3 model. Its last dimension
            must be ``config.hidden_size``. It is cast to ``config.compute_dtype`` before use.
        batch_valid_length(Tensor): The past calculated the index with datatype int32, used for incremental
            prediction. Flattened to 1-D when given.
        block_tables(Tensor): Integer Tensor storing the cache-block mapping for each sequence.
        slot_mapping(Tensor): Int32 type Tensor, token cache physical slot index.
        position_ids(Tensor, optional): Int32 Tensor with the position of each token. Used to fetch
            rotary frequencies when ``is_first_iteration`` is False.
        attention_mask(Tensor, optional): Attention mask forwarded to the MTP layer. The exception is the
            pynative first-iteration path, where a causal mask is built from ``input_ids`` instead.
        q_seq_lens(Tensor, optional): Int32 type Tensor, indicating the sequence length of query tokens.
        key_cache(list[Tensor], optional): Only its first element is forwarded to the MTP layer.

    Returns:
        output: Tensor, the MTP layer output after the final RMSNorm.
    """
    def __init__(self, config: DeepseekV3Config = None):
        super().__init__(config, auto_prefix=True)
        self.dtype = config.compute_dtype
        self.use_past = config.use_past
        self.is_first_iteration = True
        # NOTE(review): prefill_gather_flatten / sub_batch_valid_len / use_past are not used by
        # this construct; presumably kept for external callers — confirm before removing.
        self.prefill_gather_flatten = P.Gather()
        self.sub_batch_valid_len = P.Sub()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()
        self.tok_embeddings = VocabEmbedding(num_embeddings=config.vocab_size,
                                             embedding_dim=config.hidden_size,
                                             param_init_type=config.param_init_type,
                                             param_init="normal")
        self.freqs_mgr = FreqsMgr(head_dim=config.qk_rope_head_dim,
                                  seq_length=config.seq_length,
                                  max_position_embedding=config.max_position_embeddings,
                                  rotary_dtype=config.rotary_dtype,
                                  theta=config.theta,
                                  scaling_factor=config.scaling_factor,
                                  extend_method=config.extend_method,
                                  is_dynamic=config.is_dynamic)
        self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
                                                          compute_type=config.compute_dtype,
                                                          is_dynamic=config.is_dynamic,
                                                          pad_token_id=config.pad_token_id,
                                                          use_flash_attention=config.use_flash_attention,
                                                          use_past=config.use_past)
        self.layer = DeepseekV3MTPLayer(config)
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps,
                            compute_type=config.layernorm_compute_type)
        self.tp_group = get_tensor_model_parallel_group()
        # NOTE(review): `head` is not applied in construct below; it presumably exists so its
        # weights are created/loaded for use elsewhere — confirm before removing.
        self.head = ColumnParallelLinear(config.hidden_size,
                                         config.vocab_size,
                                         config=config.parallel_config,
                                         bias=False,
                                         param_init_type=config.param_init_type,
                                         compute_dtype=config.compute_dtype,
                                         weight_init="normal",
                                         gather_output=True,
                                         tp_group=self.tp_group)

    def construct(self, input_ids, hidden_states, batch_valid_length, block_tables, slot_mapping,
                  position_ids=None, attention_mask=None, q_seq_lens=None, key_cache=None):
        """
        Forward of the DeepseekV3 MTP model: embed, run the single MTP layer, then apply the final norm.
        """
        emb = self.cast(self.tok_embeddings(input_ids), self.dtype)
        hidden_states = self.cast(hidden_states, self.dtype)
        if batch_valid_length is not None:
            batch_valid_length = self.reshape(batch_valid_length, (-1,))
        mask = attention_mask
        if self.is_first_iteration:
            # NOTE(review): `self.is_pynative` is not set in this class's __init__; presumably it is
            # provided by a base class — confirm.
            if self.is_pynative:
                bs, seq_len = self.shape(input_ids)
                freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
                mask = self.casual_mask(input_ids)
            else:
                freqs_cis = self.freqs_mgr.prefill()
        else:
            freqs_cis = self.freqs_mgr.chunk_with_decode(position_ids)
        # The MTP model has a single layer, so only the first cache entry is relevant.
        key_cache = key_cache[0] if key_cache is not None else None
        hidden_states = self.layer(emb, hidden_states, freqs_cis, mask, batch_valid_length,
                                   block_tables, slot_mapping, q_seq_lens, key_cache)
        output = self.norm(hidden_states)
        return output
class InferenceDeepseekV3MTPForCausalLM(DeepseekV3PreTrainedModel):
    r"""
    Inference wrapper around `DeepseekV3MTPModel`.

    It initializes model parallelism if needed, configures the communication pattern,
    and returns the MTP hidden states flattened to 2-D. No LM head is applied here.

    Args:
        config (DeepseekV3Config): The config of DeepseekV3 model.

    Inputs:
        input_ids(Tensor): The tokenized inputs with datatype int32.
        hidden_states(Tensor): Hidden states produced by the main DeepseekV3 model. The dynamic
            input declared in `set_dynamic_inputs` is 2-D with dtype ``config.compute_dtype``.
        batch_valid_length(Tensor): The past calculated the index with datatype int32, used for incremental
            prediction. Tensor of shape :math:`(batch_size,)`. Default None.
        block_tables(Tensor, optional): Int32 Tensor storing the cache-block mapping for each sequence.
        slot_mapping(Tensor, optional): Int32 type Tensor, token cache physical slot index. Default None.
        position_ids(Tensor): Int32 type Tensor, indicating the position of each token.
        attention_mask(Tensor): Attention mask forwarded to the MTP model.
        q_seq_lens(Tensor): Int32 type Tensor, indicating the sequence length of query tokens.
        key_cache(list[Tensor], optional): Key caches; only the first entry is used by the MTP model.
        value_cache: Accepted for interface compatibility; not used.

    Returns:
        Tensor, the MTP model output reshaped to ``(-1, hidden_states.shape[-1])``.
    """
    def __init__(self, config: DeepseekV3Config = None):
        super().__init__(config, auto_prefix=True)
        self.dtype = config.compute_dtype
        self.config = convert_model_config(config)
        self.parallel_config = self.config.parallel_config
        self.npu_mem_size = config.npu_mem_size
        if not is_initialized():
            initialize_model_parallel(pipeline_model_parallel_size=self.parallel_config.pipeline_model_parallel_size,
                                      expert_model_parallel_size=self.parallel_config.expert_parallel,
                                      tensor_model_parallel_size=self.parallel_config.tensor_model_parallel_size,
                                      data_parallel_size=self.parallel_config.data_parallel,
                                      order='tp-ep-dp-pp',)
        # NOTE(review): update_comm_config reads sizes from self.parallel_config (converted config) but
        # writes flags on `config.parallel_config` (the original object); these coincide only if
        # convert_model_config mutates/returns the same object — confirm.
        config = self.update_comm_config(config)
        self.mtp_model = DeepseekV3MTPModel(config)

    def update_comm_config(self, config):
        """Select the attention/FFN communication pattern and store it on ``config.parallel_config``.

        Side effects: sets ``self.tp_size``, ``self.dp_size`` (``world_size // tp_size``),
        ``self.moe_ep_size`` and ``self.moe_tp_size`` (``world_size // moe_ep_size``).

        Patterns:
            - Default, also used when ``dp_size == 1``: all-reduce after attention and after
              the FFN; no reduce-scatter, all-gather or all-to-all.
            - ``dp_size > 1 and tp_size == 1``:
                - ``moe_tp_size == 1``: both all-reduces off, all-to-all on.
                - otherwise: attention all-gather plus FFN reduce-scatter, both all-reduces off.
            - ``dp_size > 1 and tp_size > 1``:
                - ``moe_tp_size == 1``: attention reduce-scatter, FFN all-gather, all-to-all on,
                  both all-reduces off.
                - otherwise: reduce-scatter and all-gather on both attention and FFN,
                  both all-reduces off.

        Returns:
            The same ``config`` object, mutated in place.
        """
        config.parallel_config.attn_reduce_scatter = False
        config.parallel_config.attn_allgather = False
        config.parallel_config.attn_allreduce = True
        config.parallel_config.ffn_reduce_scatter = False
        config.parallel_config.ffn_allgather = False
        config.parallel_config.ffn_allreduce = True
        config.parallel_config.use_alltoall = False
        world_size = get_group_size()
        self.tp_size = self.parallel_config.tensor_model_parallel_size
        self.dp_size = world_size // self.tp_size
        self.moe_ep_size = self.parallel_config.expert_parallel
        self.moe_tp_size = world_size // self.moe_ep_size
        if self.dp_size > 1 and self.tp_size == 1:
            if self.moe_tp_size == 1:
                config.parallel_config.attn_allreduce = False
                config.parallel_config.ffn_allreduce = False
                config.parallel_config.use_alltoall = True
            else:
                config.parallel_config.attn_allgather = True
                config.parallel_config.attn_allreduce = False
                config.parallel_config.ffn_reduce_scatter = True
                config.parallel_config.ffn_allreduce = False
        elif self.dp_size > 1:
            if self.moe_tp_size == 1:
                config.parallel_config.attn_reduce_scatter = True
                config.parallel_config.attn_allreduce = False
                config.parallel_config.ffn_allgather = True
                config.parallel_config.ffn_allreduce = False
                config.parallel_config.use_alltoall = True
            else:
                config.parallel_config.attn_reduce_scatter = True
                config.parallel_config.attn_allgather = True
                config.parallel_config.attn_allreduce = False
                config.parallel_config.ffn_reduce_scatter = True
                config.parallel_config.ffn_allgather = True
                config.parallel_config.ffn_allreduce = False
        return config

    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
        """Get deepseekv3 MTP model input tuple for transform ckpt.

        Returns an 8-tuple in `construct` order: (input_ids, hidden_states, batch_valid_length,
        block_tables, slot_mapping, position_ids, attention_mask, q_seq_lens). Every entry is
        ``None`` except ``input_ids`` (as int32) and ``slot_mapping`` (all ones, length bs * seq).
        """
        input_ids = Tensor(input_ids, mstype.int32)
        bs, seq = input_ids.shape[0], input_ids.shape[1]
        slot_mapping = Tensor(np.ones((bs * seq,), dtype=np.int32), mstype.int32)
        return input_ids, None, None, None, slot_mapping, None, None, None

    def set_dynamic_inputs(self, **kwargs):
        """Mindspore's feature, Set dynamic input for DeepSeekV3-MTP.

        All dimensions are dynamic. A single 4-D key-cache placeholder is declared only when
        ``self.npu_mem_size <= 0``; the value cache is always ``None``.
        """
        dynamic_input_ids = Tensor(shape=[None], dtype=mstype.int32)
        dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.dtype)
        dynamic_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32)
        dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
        dynamic_position_ids = Tensor(shape=[None], dtype=mstype.int32)
        dynamic_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32)
        dynamic_attention_mask = Tensor(shape=[None, None], dtype=mstype.bfloat16)
        dynamic_key_cache = None if self.npu_mem_size > 0 else \
            mutable([Tensor(shape=[None, None, None, None], dtype=self.dtype)])
        self.set_inputs(dynamic_input_ids, dynamic_hidden_states, dynamic_batch_valid_length, dynamic_block_tables,
                        dynamic_slot_mapping, dynamic_position_ids, dynamic_attention_mask, dynamic_q_seq_lens,
                        dynamic_key_cache, None)
        logger.info("Set dynamic input for DeepSeekV3-MTP.")

    @jit
    def construct(self, input_ids, hidden_states, batch_valid_length, block_tables, slot_mapping, position_ids=None,
                  attention_mask=None, q_seq_lens=None, key_cache=None, value_cache=None):
        """InferenceDeepseekV3MTPForCausalLM forward: run the MTP model and flatten to 2-D.

        ``value_cache`` is accepted for interface compatibility and is not used.
        """
        hidden_states = self.mtp_model(input_ids, hidden_states, batch_valid_length, block_tables, slot_mapping,
                                       position_ids, attention_mask, q_seq_lens, key_cache)
        return hidden_states.reshape(-1, hidden_states.shape[-1])

    def kvcache(self, layer_idx):
        """Return ``(key_cache, None)`` of the single MTP decode layer.

        ``layer_idx`` is ignored: the MTP model has exactly one decode layer, and its key
        cache is always returned.
        """
        key_cache = self.mtp_model.layer.decode_layer.attention.infer_attention.paged_attention_mgr.key_cache
        return key_cache, None