"""DeepseekV3 models' APIs."""
import math
from enum import Enum
from typing import Tuple, Optional, Dict
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn, mint, ops, Parameter
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.common.initializer import Zero
from mindspore.communication._comm_helper import _is_initialized
try:
from mindspore._checkparam import Validator
except ImportError:
import mindspore._checkparam as Validator
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import lazy_inline, LayerSetting, check_fine_grain_interleave_valid, predict_lazy_inline
from mindformers.modules.layers import Linear, FreqsMgr, _check_input_dtype, _yarn_get_mscale
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.modules.transformer.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer import TransformerOpParallelConfig
from mindformers.modules.infer_attention import InferRotaryEmbedding, FlashAttention
from mindformers.tools.logger import logger
from mindformers.tools.utils import get_predict_run_mode, is_pynative
from mindformers.experimental.infer.core.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.experimental.parallel_core.pynative.parallel_state import get_group_info, initialize_model_parallel
from mindformers.experimental.infer.core.utils import get_tp_world_size
from mindformers.experimental.infer.core.norm import RMSNorm
from mindformers.experimental.infer.core.moe import RoutedParallelMLP, SharedParallelMLP, ParallelMoEV2
from mindformers.experimental.infer.core.transformer import ParallelMLP, VocabEmbedding
from research.deepseek3.deepseek3_config import DeepseekV3Config
from research.deepseek3.utils import convert_model_config
__all__ = ['InferenceDeepseekV3ForCausalLM', 'DeepseekV3Model']
class CacheConfig(Enum):
    """Integer-valued identifiers for attention cache layouts.

    The member meanings below are read off the member names only; nothing in
    the surrounding code shown here consumes this enum, so treat them as
    assumptions to confirm against the callers.
    """

    # presumably: separate key cache and value cache -- TODO confirm
    KEY_VALUE_CACHE = 0
    # presumably: a single key cache only -- TODO confirm
    KEY_CACHE = 1
    # presumably: key/value caches plus a kv-scale cache -- TODO confirm
    KEY_VALUE_CACHE_KVSCALE_CACHE = 2
class MLAPagedAttentionMgr(nn.Cell):
    r"""Paged-attention cache manager for MLA that owns a single ``key_cache`` buffer.

    The same buffer is handed to the paged attention op as both the key cache
    and the value cache (see :meth:`paged_attn`).

    Args:
        - **n_heads** (int): Query head number passed to the PagedAttention op.
        - **head_dim** (int): Head dimension; used to derive the default score scale.
        - **n_kv_heads** (int): Key/value head number passed to the PagedAttention op.
        - **kv_shape** (tuple): Shape of the cache parameter, e.g.
          :math:`(num\_blocks, block\_size, n\_kv\_head, head\_dim)`.
        - **compute_dtype** (mstype): dtype of the cache parameter. Default mstype.float16.
        - **parallel_decoding** (bool): Stored on the instance as ``parallel_decoding``. Default False.
        - **scale_value** (float): Score scale factor. ``None`` selects ``1 / sqrt(head_dim)``. Default None.
        - **mla_v_dim** (int): Forwarded to the PagedAttention op as ``mla_v_dim``. Default 512.

    Inputs:
        - **key** (Tensor) - The tensor written into ``key_cache``.
        - **slot_mapping** (Tensor) - Physical slot index of each token in the cache.

    Outputs:
        The result of the ReshapeAndCache op.
    """

    def __init__(self, n_heads, head_dim, n_kv_heads, kv_shape, compute_dtype=mstype.float16,
                 parallel_decoding=False, scale_value=None, mla_v_dim=512):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads

        # Fall back to the conventional 1/sqrt(d) scaling when no scale is supplied.
        if scale_value is None:
            scale_value = 1 / math.sqrt(self.head_dim)
        self.scale_value = scale_value

        # Zero-initialised, non-trainable cache buffer.
        cache_init = Tensor(shape=kv_shape, dtype=compute_dtype, init=Zero())
        self.key_cache = Parameter(cache_init, name="key_cache", requires_grad=False)

        self.reshape_and_cache = ops.auto_generate.ReshapeAndCache()
        self.paged_attention = ops.auto_generate.PagedAttention(
            self.n_heads, self.scale_value, self.n_kv_heads, mla_v_dim=mla_v_dim)
        self.parallel_decoding = parallel_decoding

    def construct(self, key, slot_mapping):
        """Write ``key`` into the single cache at the positions given by ``slot_mapping``."""
        # Value input and value cache are None: only the key cache exists here.
        cache_out = self.reshape_and_cache(key, None, self.key_cache, None, slot_mapping)
        return cache_out

    def paged_attn(self, query, batch_valid_length, block_tables):
        """Run paged attention for ``query`` against the cache."""
        # ``key_cache`` is passed twice: once as key cache, once as value cache.
        attn_out = self.paged_attention(query, self.key_cache, self.key_cache, block_tables,
                                        batch_valid_length)
        return attn_out
class MLAInferAttention(nn.Cell):
    r"""Inference attention cell for Multi-Latent Attention (MLA).

    Wraps a FlashAttention cell, used while ``is_first_iteration`` is True (prefill), and an
    :class:`MLAPagedAttentionMgr`, used otherwise (incremental decoding).

    Notation used below:
        B -- Batch size
        S1 -- Sequence length of query.
        S2 -- Sequence length of key and value.
        N1 -- Num heads of query
        N2 -- Num heads of key and value
        D -- Head size.
        H1 -- Hidden size of query, which equals to N1 * D
        H2 -- Hidden size of key and value, which equals to N2 * D

    Args:
        n_head (int): The head num of query.
        head_dim (int): The dim of head; also the last dim of the paged cache.
        n_kv_head (int): The head num of key and value.
        pa_n_head_split (int): The query head num given to the paged attention op.
            ``None`` means ``n_head``. Default: None.
        pa_n_kv_head_split (int): The key and value head num given to the paged attention op.
            ``None`` means ``n_kv_head``. Default: None.
        keep_prob (float): The keep probability of dropout. Default: 1.0.
        scale_value (float): The scale factor of score. Default: 1.0.
        pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
            Default: 2147483647.
        next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
            Default: 2147483647.
        sparse_mode (int): Sparse mode forwarded to FlashAttention. Default: 0.
        block_size (int): Block size for paged attention. Default: 16.
        num_blocks (int): Block num for paged attention. Default: 1024.
        use_alibi_mask (bool): The value is True if alibi_mask is passed. When True the attention
            mask is disabled for FlashAttention. Default: False.
        compute_dtype (mstype): dtype of the paged cache. Default: mstype.float16.
        parallel_decoding (bool): Forwarded to the paged attention manager. Default: False.
        prefill_head_dim (int): The dim of head used to flatten q/k/v for prefill attention.
            A falsy value means ``head_dim``. Default: None.

    Inputs:
        - **query** (Tensor) - The query tensor. Prefill expects a 3-D tensor ``(B, S1, H1)``.
        - **key** (Tensor) - The key tensor (prefill only).
        - **value** (Tensor) - The value tensor (prefill only).
        - **batch_valid_length** (Tensor) - In prefill it is passed to FlashAttention as both the
          actual query and actual key/value sequence lengths; in decode it is passed to paged attention.
        - **block_tables** (Tensor) - Store mapping tables for each sequence (decode only).
        - **attn_mask** (Union[Tensor, None]) - The attention mask tensor (prefill only). Default: None.
        - **alibi_mask** (Union[Tensor, None]) - The alibi mask (prefill only). Default: None.

    Outputs:
        - **attention_out** (Tensor) - The output of attention. In prefill it is reshaped to
          ``(B, S1, n_head * prefill_head_dim)``.
    """

    def __init__(self, n_head, head_dim, n_kv_head, pa_n_head_split=None, pa_n_kv_head_split=None,
                 keep_prob=1.0, scale_value=1.0, pre_tokens=2147483647, next_tokens=2147483647,
                 sparse_mode=0, block_size=16, num_blocks=1024, use_alibi_mask=False,
                 compute_dtype=mstype.float16, parallel_decoding=False, prefill_head_dim=None):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.n_kv_head = n_kv_head
        # The paged attention op may run on a different head split than flash attention.
        self.pa_n_head_split = n_head if pa_n_head_split is None else pa_n_head_split
        self.pa_n_kv_head_split = n_kv_head if pa_n_kv_head_split is None else pa_n_kv_head_split
        self.keep_prob = keep_prob
        self.scale_value = scale_value
        self.pre_tokens = pre_tokens
        self.next_tokens = next_tokens
        self.sparse_mode = sparse_mode
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.use_alibi_mask = use_alibi_mask
        self.compute_dtype = compute_dtype
        self.is_first_iteration = True
        self.reshape = P.Reshape()

        # Layout of the flash attention inputs depends on the execution mode.
        self.is_pynative = is_pynative()
        self.input_layout = "BSH" if self.is_pynative else "TH"

        # Attention mask and alibi mask are mutually exclusive.
        self.use_attention_mask = not self.use_alibi_mask
        self.flash_attention = FlashAttention(
            head_num=self.n_head,
            pre_tokens=self.pre_tokens,
            next_tokens=self.next_tokens,
            keep_prob=self.keep_prob,
            scale_value=self.scale_value,
            sparse_mode=self.sparse_mode,
            use_attention_mask=self.use_attention_mask,
            use_alibi_mask=self.use_alibi_mask,
            input_layout=self.input_layout,
        )

        cache_shape = (self.num_blocks, self.block_size, self.n_kv_head, self.head_dim)
        self.paged_attention_mgr = MLAPagedAttentionMgr(
            self.pa_n_head_split,
            self.head_dim,
            self.pa_n_kv_head_split,
            cache_shape,
            compute_dtype=self.compute_dtype,
            parallel_decoding=parallel_decoding,
            scale_value=self.scale_value,
        )
        self.prefill_head_dim = prefill_head_dim

    def _prefill_attention(self, query, key, value, attn_mask, alibi_mask, actual_seq_qlen=None,
                           actual_seq_kvlen=None):
        """Run flash attention over the full prompt and return a ``(B, S1, hidden)`` tensor."""
        batch, q_len, _ = query.shape

        # Width of one head during prefill; falls back to head_dim when not configured.
        head_width = self.head_dim
        if self.prefill_head_dim:
            head_width = self.prefill_head_dim
        hidden = self.n_head * head_width

        if not self.is_pynative:
            # Outside pynative mode the batch and sequence axes are merged into one token axis.
            query = self.reshape(query, (-1, hidden))
            key = self.reshape(key, (-1, hidden))
            value = self.reshape(value, (-1, hidden))

        attn_out = self.flash_attention(query, key, value, attn_mask, alibi_mask, None, None,
                                        actual_seq_qlen, actual_seq_kvlen)
        return self.reshape(attn_out, (batch, q_len, hidden))

    def _incre_attention(self, query, batch_valid_length, block_tables):
        """Run paged attention against the cache for incremental decoding."""
        return self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables)

    def construct(self, query, key, value, batch_valid_length, block_tables,
                  attn_mask=None, alibi_mask=None):
        """Dispatch to paged attention (decode) or flash attention (prefill)."""
        if not self.is_first_iteration:
            return self._incre_attention(query, batch_valid_length, block_tables)
        # In prefill, batch_valid_length serves as both the q and kv actual sequence lengths.
        return self._prefill_attention(query, key, value, attn_mask, alibi_mask, batch_valid_length,
                                       batch_valid_length)
class DeepseekV3Attention(nn.Cell):
    r"""
    Multi-Latent Attention (MLA) block of DeepSeek-V3 for inference.

    Args:
        - **dim** (int): The hidden size of the input. Default 512.
        - **n_heads** (int): The number of the heads. Default 8.
        - **n_kv_heads** (int): The number of key_value heads. ``None`` means ``n_heads``. Default None.
        - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16.
        - **param_init_type** (dtype.Number): The parameter initialization type of the module.
          Default mstype.float32.
        - **qkv_has_bias** (bool): Whether the Q/K/V projections have bias or not. Default False.
        - **use_past** (bool): Must be True; incremental inference with a cache is the only supported mode.
          The prefill step runs with ``is_first_iteration=True`` (set through
          `model.add_flags_recursive(is_first_iteration=True)`), the decode steps with
          ``is_first_iteration=False``. Default True.
        - **use_flash_attention** (bool): Must be True. Default True.
        - **block_size** (int): The maximum number of tokens in one block when using paged attention.
          Default None.
        - **num_blocks** (int): The maximum number of blocks when using paged attention. Default None.
        - **parallel_config** (OpParallelConfig): The parallel configure. Default is an instance of
          `TransformerOpParallelConfig` with default args.
        - **kv_lora_rank** (int): kv_lora_rank for Multi-Latent-Attention. Default 512.
        - **q_lora_rank** (int): q_lora_rank for Multi-Latent-Attention. ``0`` selects a direct query
          projection without the low-rank bottleneck. Default 1536.
        - **qk_rope_head_dim** (int): qk_rope_head_dim for Multi-Latent-Attention. Default 64.
        - **v_head_dim** (int): v_head_dim for Multi-Latent-Attention. Default 128.
        - **qk_nope_head_dim** (int): qk_nope_head_dim for Multi-Latent-Attention. Default 128.
        - **max_position_embeddings** (int): The maximum sequence length that this model might ever be
          used with. Default 2048.
        - **scaling_factor** (dict): Scaling settings; keys ``"factor"`` (required when given) and
          ``"mscale_all_dim"`` (optional) are read to adjust the attention score scale. Default None.
        - **norm_eps** (float): The epsilon value of the denominator. Default 1e-5.
        - **layernorm_compute_dtype** (dtype.Number): The computation type of layernorm. Default mstype.float32.
        - **config** (DeepseekV3Config): Model config; only ``qkv_concat`` is read here. ``None`` is
          treated as ``qkv_concat=False``. Default None.

    Inputs:
        - **x** (Tensor) - The input hidden states, a 3-D tensor (batch_size, seq_length, hidden_size).
        - **freqs_cis** (Tuple) - The precompute freqs for rotary position embedding used in attention.
        - **mask** (Tensor) - The attention mask, or None. Default None.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index.
          Default None.
        - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. Default None.
        - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index. Default None.

    Outputs:
        - **output** (Tensor) - The attention output after the output projection, cast back to the
          dtype of ``x``.

    Raises:
        ValueError: If ``use_past`` or ``use_flash_attention`` is False, if ``dim`` is not a multiple
            of ``n_heads``, or if the kv head number is not a multiple of
            ``parallel_config.model_parallel``.
    """

    def __init__(self,
                 dim=512,
                 n_heads=8,
                 n_kv_heads=None,
                 compute_dtype=mstype.float16,
                 param_init_type=mstype.float32,
                 qkv_has_bias=False,
                 use_past=True,
                 use_flash_attention=True,
                 block_size=None,
                 num_blocks=None,
                 parallel_config=TransformerOpParallelConfig(),
                 kv_lora_rank=512,
                 q_lora_rank=1536,
                 qk_rope_head_dim=64,
                 v_head_dim=128,
                 qk_nope_head_dim=128,
                 max_position_embeddings=2048,
                 scaling_factor=None,
                 norm_eps=1e-5,
                 layernorm_compute_dtype=mstype.float32,
                 config: DeepseekV3Config = None
                 ):
        super().__init__()
        self.hidden_size = dim
        self.tensor_parallel_group_size = get_tp_world_size()
        self.n_head = n_heads
        # Heads handled by this tensor-parallel rank.
        self.n_local_heads = n_heads // self.tensor_parallel_group_size
        self.head_dim = dim // n_heads
        self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads
        self.kv_dim = self.n_kv_head * self.head_dim
        self.block_size = block_size
        self.num_blocks = num_blocks

        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.max_position_embeddings = max_position_embeddings
        self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.scaling_factor = scaling_factor

        self.dtype = compute_dtype
        self.is_first_iteration = True
        self.use_past = use_past
        self.use_flash_attention = use_flash_attention
        # ``config`` defaults to None in the signature, so it must not be dereferenced blindly.
        self.qkv_concat = config.qkv_concat if config is not None else False

        if not self.use_past:
            raise ValueError("For 'DeepseekV3Attention', the use_past must be enabled.")
        if not self.use_flash_attention:
            raise ValueError("For 'DeepseekV3Attention', the use_flash_attention must be enabled.")

        if self.hidden_size % self.n_head != 0:
            raise ValueError(f"For 'DeepseekV3Attention', the class variable 'hidden_size' must be a multiple "
                             f"of 'n_head', but got the hidden_size is {self.hidden_size} "
                             f"and the n_head is {self.n_head}.")
        if self.n_kv_head % parallel_config.model_parallel != 0:
            raise ValueError(f"For 'DeepseekV3Attention', the class variable 'n_kv_head' must be a multiple of "
                             f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} "
                             f"and the parallel_config.model_parallel is {parallel_config.model_parallel}.")

        self.shape = P.Shape()
        self.cast = P.Cast()

        if self.q_lora_rank == 0:
            # Direct query projection, no low-rank bottleneck.
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.n_head * self.q_head_dim,
                config=parallel_config,
                bias=qkv_has_bias,
                param_init_type=param_init_type,
                compute_dtype=compute_dtype
            )
            self.kv2l = Linear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                has_bias=qkv_has_bias,
                compute_dtype=compute_dtype,
                param_init_type=param_init_type
            )
        else:
            if self.qkv_concat:
                # One fused down-projection producing [q_latent | kv_latent | k_pe].
                self.qkv2l = Linear(
                    self.hidden_size,
                    self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                    has_bias=qkv_has_bias,
                    compute_dtype=compute_dtype,
                    param_init_type=param_init_type
                )
            else:
                self.q2l_proj = Linear(
                    self.hidden_size,
                    self.q_lora_rank,
                    has_bias=qkv_has_bias,
                    compute_dtype=compute_dtype,
                    param_init_type=param_init_type
                )
                self.kv2l = Linear(
                    self.hidden_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    has_bias=qkv_has_bias,
                    compute_dtype=compute_dtype,
                    param_init_type=param_init_type
                )
            self.lq_norm = RMSNorm(self.q_lora_rank, norm_eps, compute_type=layernorm_compute_dtype)
            self.l2q_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.n_head * self.q_head_dim,
                config=parallel_config,
                bias=qkv_has_bias,
                param_init_type=param_init_type,
                compute_dtype=compute_dtype
            )

        self.lkv_norm = RMSNorm(self.kv_lora_rank, norm_eps, compute_type=layernorm_compute_dtype)
        # Up-projections from the kv latent to per-head k_nope and v.
        self.lkv2kv_k_nope = ColumnParallelLinear(
            self.kv_lora_rank,
            self.n_head * self.qk_nope_head_dim,
            config=parallel_config,
            bias=qkv_has_bias,
            param_init_type=param_init_type,
            compute_dtype=compute_dtype
        )
        self.lkv2kv_v = ColumnParallelLinear(
            self.kv_lora_rank,
            self.n_head * self.v_head_dim,
            config=parallel_config,
            bias=qkv_has_bias,
            param_init_type=param_init_type,
            compute_dtype=compute_dtype
        )

        self.wo = RowParallelLinear(
            self.n_head * self.v_head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            config=parallel_config,
            compute_dtype=compute_dtype,
            param_init_type=param_init_type,
        )

        # Attention score scale; optionally corrected by the yarn mscale factor.
        self.scale_fa = 1. / math.sqrt(self.q_head_dim)
        if self.scaling_factor is not None:
            mscale_all_dim = self.scaling_factor.get("mscale_all_dim", 0)
            factor = self.scaling_factor["factor"]
            if mscale_all_dim:
                mscale = _yarn_get_mscale(factor, mscale_all_dim)
                self.scale_fa = mscale * mscale / (math.sqrt(self.q_head_dim))

        self.reshape = P.Reshape()
        self.tile_kv = P.Tile()
        self.dim_slice_4d = P.Slice()
        self.kpe_concat = P.Concat(2)
        self.pe_concat = P.Concat(3)
        self.qabsorb_matmul = P.BatchMatMul()
        self.outabsorb_matmul = P.BatchMatMul(transpose_b=True)

        # The cache holds one "head" of width kv_lora_rank + qk_rope_head_dim per token.
        self.infer_attention = MLAInferAttention(self.n_local_heads,
                                                 self.kv_lora_rank + self.qk_rope_head_dim,
                                                 1,
                                                 scale_value=self.scale_fa,
                                                 next_tokens=0,
                                                 block_size=self.block_size,
                                                 num_blocks=self.num_blocks,
                                                 compute_dtype=compute_dtype,
                                                 prefill_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim)

        self.apply_rotary_emb = InferRotaryEmbedding(rotary_cos_format=2)

    def construct(self, x: Tensor, freqs_cis: Tuple[Tensor, Tensor], mask=None, batch_valid_length=None,
                  block_tables=None, slot_mapping=None):
        """ Forward process of the DeepseekV3Attention. """
        ori_dtype = x.dtype
        # Step 1: project x to the per-head query and to the kv latent plus the rope key part.
        if self.q_lora_rank == 0:
            bs, seq_len, _ = self.shape(x)
            q = self.q_proj(x)
            latent_kv_all = self.kv2l(x)
            latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        else:
            if self.qkv_concat:
                qkv2l = self.qkv2l(x)
                q, latent_kv, k_pe = mint.split(qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
                                                dim=-1)
                bs, seq_len, _ = self.shape(q)
                norm_q = self.lq_norm(q)
                q = self.l2q_proj(norm_q)
            else:
                q = self.q2l_proj(x)
                bs, seq_len, _ = self.shape(q)
                norm_q = self.lq_norm(q)
                q = self.l2q_proj(norm_q)
                latent_kv_all = self.kv2l(x)
                latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        q = self.reshape(q, (bs, seq_len, self.n_local_heads, self.q_head_dim))
        q_nope, q_pe = mint.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        i_kv = self.lkv_norm(latent_kv)

        # Step 2: rotary embedding on the rope parts; k_pe is shared by all heads (head axis of 1).
        k_pe = self.reshape(k_pe, (bs, seq_len, 1, self.qk_rope_head_dim))
        q_pe, k_pe = self.apply_rotary_emb(q_pe, k_pe, freqs_cis, batch_valid_length)
        q_pe = self.reshape(q_pe, (bs, seq_len, self.n_local_heads, self.qk_rope_head_dim))
        k_pe = self.reshape(k_pe, (bs, seq_len, 1, self.qk_rope_head_dim))

        # Step 3: write [normed kv latent | k_pe] into the paged cache.
        key_states_cache = self.kpe_concat((i_kv, k_pe.view(bs, seq_len, self.qk_rope_head_dim)))
        key_out = self.infer_attention.paged_attention_mgr(key_states_cache, slot_mapping)
        # Tie q_nope to the cache write so the write is not dropped or reordered after attention.
        q_nope = ops.depend(q_nope, key_out)

        if self.is_first_iteration:
            # Prefill: expand the latent to full per-head k/v and run flash attention.
            o_k_nope = self.lkv2kv_k_nope(i_kv)
            o_v = self.lkv2kv_v(i_kv)
            k_nope = self.reshape(o_k_nope, (bs, seq_len, self.n_local_heads, self.qk_nope_head_dim))
            value_states = self.reshape(o_v, (bs, seq_len, self.n_local_heads, self.v_head_dim))
            query_states = self.pe_concat((q_nope, q_pe))
            k_pe = self.tile_kv(k_pe, (1, 1, self.n_local_heads, 1))
            key_states = self.pe_concat((k_nope, k_pe))
            # Value is padded with k_pe to widen its head dim; the padding is sliced off below.
            # NOTE(review): the view to q_head_dim assumes v_head_dim == qk_nope_head_dim -- confirm.
            value_states = self.pe_concat((value_states, k_pe))

            key_states = key_states.view(bs, seq_len, -1)
            value_states = value_states.view(bs, seq_len, -1)
            query_states = query_states.view(bs, seq_len, -1)
            context_layer = self.infer_attention(query_states, key_states, value_states, batch_valid_length,
                                                 block_tables, mask)
            context_layer = context_layer.view(bs, seq_len, self.n_local_heads, self.q_head_dim)
            context_layer = self.dim_slice_4d(context_layer, (0, 0, 0, 0), (bs, seq_len, self.n_local_heads,
                                                                            self.v_head_dim))
            attn_out = context_layer.view(bs, seq_len, self.n_local_heads * self.v_head_dim)
            output = self.wo(attn_out)
            output = self.cast(output, ori_dtype)
            return output

        # Decode: fold the k_nope / v up-projection weights into the query and the output
        # instead of expanding the cached latent to per-head keys and values.
        q_absorb = self.lkv2kv_k_nope.weight.view(self.n_local_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.lkv2kv_v.weight.view(self.n_local_heads, self.v_head_dim, self.kv_lora_rank)
        q_nope = self.qabsorb_matmul(q_nope.transpose(0, 2, 1, 3), q_absorb).transpose(0, 2, 1, 3)
        query_states = self.pe_concat((q_nope, q_pe))
        query_states = query_states.view(bs, seq_len, -1)
        key_states = key_states_cache
        context_layer = self.infer_attention(query_states, key_states, key_states, batch_valid_length,
                                             block_tables, attn_mask=mask)
        context_layer = context_layer.view(bs, seq_len, self.n_local_heads, -1).transpose(0, 2, 1, 3)
        attn_out = self.outabsorb_matmul(context_layer, out_absorb).transpose(0, 2, 1, 3)
        attn_out = attn_out.view(bs, seq_len, self.n_local_heads * self.v_head_dim)
        output = self.wo(attn_out)
        output = self.cast(output, ori_dtype)
        return output
class DeepseekV3ParallelMLP(ParallelMLP):
    r"""
    Gated parallel feed-forward block: ``w2(w3(x) * act_func(w1(x)))``.

    When ``ffn_concat`` is set, the gate and hidden projections come from one fused
    ``w_gate_hidden`` projection whose output is split in two equal parts.

    Args:
        config (dict): Configuration, forwarded to ``ParallelMLP``.
        is_expert (bool): This block is an expert block. Only False is supported. Default: False.

    Inputs:
        - **hidden_states** (Tensor) - Tensor of shape :math:`(B, S, H)`.

    Outputs:
        - **output** (Tensor) - Output tensor of shape :math:`(B, S, H)`.

    Raises:
        NotImplementedError: If ``is_expert`` is truthy.

    Supported Platforms:
        ``Ascend``
    """

    def __init__(self, config, is_expert=False):
        super().__init__(config)
        if is_expert:
            raise NotImplementedError("For ParallelMLP, `is_expert` is not supported for now.")

    def construct(self, x):
        """ Apply the gated feed-forward transform to ``x``. """
        if not self.ffn_concat:
            gate = self.w1(x)
            hidden = self.w3(x)
        else:
            # Fused projection: first half is the gate, second half the hidden branch.
            fused = self.w_gate_hidden(x)
            split_sizes = (self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition)
            gate, hidden = mint.split(fused, split_sizes, -1)
        activated = self.act_func(gate)
        gated = mint.mul(hidden, activated)
        return self.w2(gated)
class DeepseekV3MoE(Cell):
    r"""
    Mixture-of-Experts feed-forward layer of DeepSeek-V3: routed experts plus optional shared experts.

    Args:
        - **config** (Config): Model config of DeepSeek-V3. Its ``parallel_config``, ``moe_config``,
          ``router_dense_type`` and ``hidden_size`` are read; ``moe_config.router_dense_type`` is
          overwritten with ``config.router_dense_type``.

    Inputs:
        - **x** (Tensor): Should be `[batch, seq_length, hidden_size]`. Float tensor.

    Outputs:
        - **output** (Tensor): The output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.

    Raises:
        NotImplementedError: If ``parallel_config.expert_parallel`` is not 1.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.parallel_config = config.parallel_config
        self.moe_config = config.moe_config
        # The router dtype lives on the model config; mirror it onto the MoE config.
        self.moe_config.router_dense_type = config.router_dense_type
        base_intermediate_size = self.moe_config.moe_intermediate_size

        if self.parallel_config.expert_parallel != 1:
            raise NotImplementedError("For ParallelMoEV2, `expert_parallel` is not supported for now.")
        routed_ffn = RoutedParallelMLP(config)
        self.routed_experts = ParallelMoEV2(routed_ffn, self.config.hidden_size, self.moe_config)

        if self.moe_config.shared_expert_num is not None:
            # All shared experts are realised as one MLP with a proportionally wider hidden size.
            shared_intermediate_size = base_intermediate_size * self.moe_config.shared_expert_num
            self.shared_experts = SharedParallelMLP(config, shared_intermediate_size)

        self.add = P.Add()

    def construct(self, x):
        """ Return the routed-expert output, plus the shared-expert output when configured. """
        routed_out = self.routed_experts(x)
        if self.moe_config.shared_expert_num is None:
            return routed_out
        return self.add(routed_out, self.shared_experts(x))
class DeepseekV3DecodeLayer(nn.Cell):
r"""
Transformer Layer. This is an implementation of the single layer of the transformer
encoder layer, including multihead attention and feedforward layer.
Args:
- **layer_id** (int): The layer id of current transformer block layer.
- **dim** (int): The hidden size of the input.
- **num_heads** (int): The number of the heads.
- **n_kv_heads** (int): The number of key_value heads that should be used to implement
Grouped Query Attention.
- **norm_eps** (float): The epsilon value of the denominator. Default 1e-5.
- **compute_dtype** (dtype.Number): The computation type of the layer.
Should be mstype.float32 or mstype.float16. Default mstype.float32.
- **layernorm_compute_type** (dtype.Number): The computation type of the norm.
Should be mstype.float32 or mstype.float16. Default mstype.float32.
- **param_init_type** (dtype.Number): The parameter initialization type of the module.
Should be mstype.float32 or mstype.float16. Default mstype.float32.
- **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not.
- **use_past** (bool): Use the past state to compute, used for incremental prediction.
For example, if we have two words and want to generate the ten more words.
We just need to compute the two words' state only once, and generate the next word one by one.
When use_past is True, there are two steps to run the prediction. In the first step,
set the is_first_iteration to be True by `model.add_flags_recursive(is_first_iteration=True)`,
and pass the full inputs. Then, set the is_first_iteration to be False by
`model.add_flags_recursive(is_first_iteration=False)`.
At this moment, pass the single step's input tensor, and loop it. Default True.
- **moe_config** (MoEConfig): The MoE configuration. Default: ``default_moe_config`` ,
an instance of `MoEConfig` with default args.
- **use_flash_attention** (bool): Whether to enable flash attention ops. Default True.
- **block_size** (int): The maximum number of tokens in one block can have when using paged attention.
Default 16.
- **num_blocks** (int): The maximum number of blocks when using paged attention. Default 512.
- **parallel_config** (OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied,
MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
- **kv_lora_rank** (int): kv_lora_rank for Multi-Latent-Attention. Default 512.
- **q_lora_rank** (int): q_lora_rank for Multi-Latent-Attention. Default 1536.
- **qk_rope_head_dim** (int): qk_rope_head_dim for Multi-Latent-Attention. Default 64.
- **v_head_dim** (int): v_head_dim for Multi-Latent-Attention. Default 128.
- **qk_nope_head_dim** (int): qk_nope_head_dim for Multi-Latent-Attention. Default 128.
- **max_position_embeddings** (int): The maximum sequence length that this model might ever be used with.
Default 2048.
- **scaling_factor** (float): Scaling factor of Multi-Latent Attention. Default None.
- **config** (Config): Model config of DeepSeek-V3. Default None.
Inputs:
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
should be [batch_size, 1, hidden_size]
- **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention.
- **input_mask** (Tensor) - Float Tensor, If the use_past is False or is_first_iteration=True,
the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
Used for incremental prediction when the use_past is True. Default None.
- **block_tables** (Tensor[int64]) - Store mapping tables for each sequence.
- **slot_mapping** (Tensor[int32]) - Store token cache physical slot index.
Outputs:
Tuple, a tuple contains(`output`, `layer_present`).
- **output** (Tensor) - The float tensor of the output of the layer with
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past
is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size)
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
((batch_size, num_heads, head_dim, seq_length),
(batch_size, num_heads, seq_length, head_dim)).
"""
@predict_lazy_inline
def __init__(self,
             layer_id,
             dim: int = 512,
             n_heads: int = 8,
             n_kv_heads: Optional[int] = None,
             norm_eps: float = 1e-5,
             compute_dtype=mstype.float32,
             layernorm_compute_dtype=mstype.float32,
             param_init_type=mstype.float32,
             qkv_has_bias=False,
             use_past=True,
             moe_config=None,
             use_flash_attention=True,
             block_size: Optional[int] = None,
             num_blocks: Optional[int] = None,
             parallel_config=TransformerOpParallelConfig(),
             kv_lora_rank=512,
             q_lora_rank=1536,
             qk_rope_head_dim=64,
             v_head_dim=128,
             qk_nope_head_dim=128,
             max_position_embeddings=2048,
             scaling_factor: Optional[Dict] = None,
             config: DeepseekV3Config = None
             ):
    """Create the two RMSNorm cells, the attention cell and the feed-forward cell of one layer.

    The feed-forward cell is a ``DeepseekV3ParallelMLP`` when ``moe_config.first_k_dense_replace`` is
    set and ``layer_id`` is smaller than it; otherwise it is a ``DeepseekV3MoE``. Both are built from
    ``config``. ``moe_config=None`` (the declared default) is treated as "no dense replacement"
    instead of raising ``AttributeError``.
    """
    super().__init__()
    self.layer_id = layer_id
    self.hidden_size = dim
    self.n_head = n_heads
    self.head_dim = self.hidden_size // self.n_head
    self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads
    self.dtype = compute_dtype
    self.is_first_iteration = True
    self.use_past = use_past

    self.cast = P.Cast()
    self.shape = P.Shape()
    self.reshape = P.Reshape()
    self.add = P.Add()

    self.ffn_norm = RMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype)
    self.attention_norm = RMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype)
    self.attention = DeepseekV3Attention(dim=dim,
                                         n_heads=n_heads,
                                         n_kv_heads=n_kv_heads,
                                         compute_dtype=compute_dtype,
                                         param_init_type=param_init_type,
                                         qkv_has_bias=qkv_has_bias,
                                         use_past=use_past,
                                         use_flash_attention=use_flash_attention,
                                         block_size=block_size,
                                         num_blocks=num_blocks,
                                         parallel_config=parallel_config,
                                         kv_lora_rank=kv_lora_rank,
                                         q_lora_rank=q_lora_rank,
                                         qk_rope_head_dim=qk_rope_head_dim,
                                         v_head_dim=v_head_dim,
                                         qk_nope_head_dim=qk_nope_head_dim,
                                         max_position_embeddings=max_position_embeddings,
                                         scaling_factor=scaling_factor,
                                         norm_eps=norm_eps,
                                         layernorm_compute_dtype=layernorm_compute_dtype,
                                         config=config)

    self.expert_num = 1 if moe_config is None else moe_config.expert_num
    self.shared_expert_num = 0 if moe_config is None else moe_config.shared_expert_num

    # Guard the None default like the two lines above do, and store a real bool rather than
    # whatever operand the `and` chain happened to return (0 / None / bool).
    dense_layer_count = None if moe_config is None else moe_config.first_k_dense_replace
    self.first_k_dense = bool(dense_layer_count) and layer_id < dense_layer_count
    if self.first_k_dense:
        logger.warning("first_k_dense_replace is provided in MoEConfig, "
                       "a normal dense FFN will be used in this block.")
        self.feed_forward = DeepseekV3ParallelMLP(config)
    else:
        self.feed_forward = DeepseekV3MoE(config=config)

    self.predict_run_mode = get_predict_run_mode()
    if self.predict_run_mode:
        self.no_inline = False
def construct(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None,
              slot_mapping=None):
    """ Forward of transformer block: normed attention plus residual, then normed FFN plus residual. """
    # Input dtype validation only runs on the path without kv-cache.
    if not self.use_past:
        self._check_input(x, freqs_cis, mask)

    # Attention sub-layer: norm -> attention -> add the block input back.
    attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, batch_valid_length,
                              block_tables, slot_mapping)
    residual = self.add(x, attn_out)

    # Feed-forward sub-layer: norm -> FFN/MoE -> add the attention result back.
    return self.add(residual, self.feed_forward(self.ffn_norm(residual)))
def _check_input(self, x, freqs_cis, mask):
    r"""Validate the dtypes of `x`, the three members of `freqs_cis` and `mask`; returns True.

    `swap_mask` (third member of `freqs_cis`) and `mask` are only checked when they are not None.
    """
    float_dtypes = [mstype.float32, mstype.float16, mstype.bfloat16]
    cell_name = self.cls_name

    _check_input_dtype(x.dtype, "x", float_dtypes, cell_name)

    freqs_cos, freqs_sin, swap_mask = freqs_cis
    _check_input_dtype(freqs_cos.dtype, "freqs_cos", float_dtypes, cell_name)
    _check_input_dtype(freqs_sin.dtype, "freqs_sin", float_dtypes, cell_name)
    if swap_mask is not None:
        _check_input_dtype(swap_mask.dtype, "swap_mask", float_dtypes, cell_name)

    if mask is not None:
        # The attention mask may additionally be uint8 or bool.
        mask_dtypes = [mstype.float32, mstype.float16, mstype.bfloat16, mstype.uint8, mstype.bool_]
        _check_input_dtype(mask.dtype, "input_mask", mask_dtypes, cell_name)
    return True
class DeepseekV3PreTrainedModel(PreTrainedModel):
    """
    Common base class of the DeepseekV3 networks defined in this file.

    It only binds the configuration class and the base-model prefix; everything else
    (presumably weight initialization and pretrained-model loading) comes from `PreTrainedModel`.
    """

    # Prefix identifying the base model.
    base_model_prefix = "deepseekv3"
    # Configuration class associated with this model family.
    config_class = DeepseekV3Config
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    r"""
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`DeepseekV3DecodeLayer`].

    Args:
        config(DeepseekV3Config): the config of network

    Inputs:
        tokens: the tokenized inputs with datatype int32

    Returns:
        output: Tensor, the output of the last decode layer after the final RMSNorm
    """

    def __init__(self,
                 config: DeepseekV3Config = None):
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        if config.batch_size or config.use_past:
            Validator.check_positive_int(config.batch_size)

        # Scalar settings mirrored from the config.
        self.dtype = config.compute_dtype
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.n_head = config.num_heads
        self.head_dim = self.hidden_size // self.n_head
        self.pad_token_id = config.pad_token_id
        self.max_position_embeddings = config.max_position_embeddings
        self.use_past = config.use_past
        self.is_dynamic = config.is_dynamic

        # Multi-Latent-Attention dimensions.
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.v_head_dim = config.v_head_dim

        # Runtime flags.
        self.is_first_iteration = True
        self.is_pynative = is_pynative()

        # Primitive operators.
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.tile = P.Tile()
        self.expand_dims = P.ExpandDims()
        self.gather = P.Gather()
        self.slice = P.StridedSlice()

        # Rotary frequencies are sized by the rope part of the query/key head.
        self.freqs_mgr = FreqsMgr(head_dim=self.qk_rope_head_dim,
                                  seq_length=config.seq_length,
                                  max_position_embedding=config.max_position_embeddings,
                                  rotary_dtype=config.rotary_dtype,
                                  theta=config.theta,
                                  scaling_factor=config.scaling_factor,
                                  extend_method=config.extend_method,
                                  is_dynamic=config.is_dynamic)
        self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
                                                          compute_type=config.compute_dtype,
                                                          is_dynamic=config.is_dynamic,
                                                          pad_token_id=config.pad_token_id,
                                                          use_flash_attention=config.use_flash_attention,
                                                          use_past=config.use_past)

        # NOTE(review): vocab_emb_dp selects the *parallel* embedding here, whereas the same flag
        # selects the non-parallel `Linear` head in InferenceDeepseekV3ForCausalLM -- confirm that
        # this asymmetry is intended.
        if config.parallel_config.vocab_emb_dp:
            self.tok_embeddings = VocabParallelEmbedding(num_embeddings=config.vocab_size,
                                                         embedding_dim=config.hidden_size,
                                                         parallel_config=config.parallel_config,
                                                         init_method="normal",
                                                         init_type=config.param_init_type)
        else:
            self.tok_embeddings = VocabEmbedding(num_embeddings=config.vocab_size,
                                                 embedding_dim=config.hidden_size,
                                                 param_init_type=config.param_init_type,
                                                 param_init="normal")

        self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave,
                                                                       config.parallel_config)
        self.layers = nn.CellList()
        self.layer_setting = LayerSetting(config.num_layers,
                                          config.offset,
                                          config.parallel_config,
                                          config.pp_interleave_num)

        # Every decode layer receives the same settings; only `layer_id` differs.
        layer_kwargs = dict(dim=config.hidden_size,
                            n_heads=config.num_heads,
                            n_kv_heads=config.n_kv_heads,
                            norm_eps=config.rms_norm_eps,
                            qkv_has_bias=config.qkv_has_bias,
                            compute_dtype=config.compute_dtype,
                            layernorm_compute_dtype=config.layernorm_compute_type,
                            param_init_type=config.param_init_type,
                            use_past=config.use_past,
                            use_flash_attention=config.use_flash_attention,
                            block_size=config.block_size,
                            num_blocks=config.num_blocks,
                            parallel_config=config.parallel_config,
                            moe_config=config.moe_config,
                            kv_lora_rank=config.kv_lora_rank,
                            q_lora_rank=config.q_lora_rank,
                            qk_rope_head_dim=config.qk_rope_head_dim,
                            v_head_dim=config.v_head_dim,
                            qk_nope_head_dim=config.qk_nope_head_dim,
                            max_position_embeddings=config.max_position_embeddings,
                            scaling_factor=config.scaling_factor,
                            config=config)
        for layer_id in range(config.num_layers):
            decode_layer = DeepseekV3DecodeLayer(layer_id, **layer_kwargs)
            self.layer_setting(decode_layer, layer_id)
            self.layers.append(decode_layer)

        self.norm_out = RMSNorm(config.hidden_size, config.rms_norm_eps,
                                compute_type=config.layernorm_compute_type)

    def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
                  block_tables=None, slot_mapping=None):
        """
        Forward of deepseekv3 model.

        Args:
            tokens: the tokenized inputs with datatype int32
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.
            batch_index(Tensor): The generated batch index when use continuous batching in LLM serving.
                Tensor of shape :math:`(batch_size,)`. Default None. Not read by this method.
            zactivate_len(Tensor): The slice length of KVCache when use dynamic shape infer.
                Tensor of shape :math:`(seq_length,)`. Default None. Not read by this method.
            block_tables(Tensor[int64]): Store mapping tables for each sequence.
            slot_mapping(Tensor[int32]): Store token cache physical slot index.

        Returns:
            output: Tensor, the output of the last decode layer after the final RMSNorm
        """
        bs, seq_len = self.shape(tokens)

        # Choose rotary frequencies and attention mask for the current phase.
        mask = None
        if not self.use_past:
            mask = self.casual_mask(tokens)
            freqs_cis = self.freqs_mgr(seq_len)
        elif self.is_first_iteration:
            freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
            if self.is_pynative:
                mask = self.casual_mask(tokens)
            else:
                mask = self.casual_mask.prefill()
        else:
            # Incremental step: no mask is passed to the layers.
            freqs_cis = self.freqs_mgr.increment(batch_valid_length)

        hidden = self.cast(self.tok_embeddings(tokens), self.dtype)
        hidden = self.reshape(hidden, (bs, seq_len, self.hidden_size))
        for decode_layer in self.layers:
            hidden = decode_layer(hidden, freqs_cis, mask, batch_valid_length=batch_valid_length,
                                  block_tables=block_tables, slot_mapping=slot_mapping)
        return self.norm_out(hidden)
class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
r"""
Provide DeepseekV3 logits through network.
Args:
config (DeepseekV3Config): The config of DeepseekV3 model.
Inputs:
input_ids(Tensor): The tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
labels(Tensor): The tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
input_position(Tensor): Current position, used by model.predict.
position_ids(Tensor): Reserved param, not used.
attention_mask(Tensor): Reserved param, not used.
input_embeds(Tensor): Reserved param, not used.
init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Default True.
batch_valid_length(Tensor): The past calculated the index with datatype int32, used for incremental
prediction. Tensor of shape :math:`(batch_size,)`. Default None.
batch_index(Tensor): The generated batch index when use continuous batching in LLM serving.
Tensor of shape :math:`(batch_size,)`. Default None.
zactivate_len(Tensor): The slice length of KVCache when use dynamic shape infer.
Tensor of shape :math:`(seq_length,)`. Default None.
block_tables(Tensor, optional): Int64 type Tensor, store mapping tables for each sequence. Default None.
slot_mapping(Tensor, optional): Int32 type Tensor, token cache physical slot index. Default None.
Returns:
Tensor. If it is in prediction mode, the output Tensor contains logits;
If it is in evaluation mode, the output Tensor contains logits, tokens, and input masks.
"""
@lazy_inline
def __init__(self, config: DeepseekV3Config = None):
    """Build the DeepseekV3 backbone and output projection, then call `load_checkpoint(config)`."""
    super(InferenceDeepseekV3ForCausalLM, self).__init__(config, auto_prefix=True)
    _check_config(config.parallel_config)
    self.config = convert_model_config(config)
    self.parallel_config = self.config.parallel_config

    # Model-parallel groups are created only when none of tp/ep/pp has a group yet and the
    # communication layer reports itself as initialized.
    group_missing = [get_group_info(group_name).group is None for group_name in ('tp', 'ep', 'pp')]
    if all(group_missing) and _is_initialized():
        initialize_model_parallel(pipeline_model_parallel_size=self.parallel_config.pipeline_model_parallel_size,
                                  expert_model_parallel_size=self.parallel_config.expert_parallel,
                                  tensor_model_parallel_size=self.parallel_config.tensor_model_parallel_size,
                                  order='tp-ep-dp-pp')

    # Scalar settings mirrored from the config.
    self.seq_length = config.seq_length
    self.vocab_size = config.vocab_size
    self.pad_token_id = config.pad_token_id
    self.ignore_token_id = config.ignore_token_id
    self.use_past = config.use_past
    self.is_first_iteration = True

    # Primitive operators.
    self.shape = P.Shape()
    self.reshape = P.Reshape()
    self.cast = P.Cast()
    self.slice = P.StridedSlice()
    self.not_equal = P.NotEqual()
    self.mul = P.Mul()
    self.add = P.Add()
    self.ones = P.Ones()
    self.gather = P.Gather()
    self.sub_batch_valid_len = P.Sub()

    self.model = DeepseekV3Model(config=config)

    # Settings shared by both variants of the output projection.
    head_kwargs = dict(weight_init="normal",
                       param_init_type=config.param_init_type,
                       compute_dtype=config.compute_dtype)
    if config.parallel_config.vocab_emb_dp:
        self.lm_head = Linear(in_channels=config.hidden_size,
                              out_channels=config.vocab_size,
                              has_bias=False,
                              **head_kwargs)
    else:
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            config=config.parallel_config,
                                            bias=False,
                                            gather_output=True,
                                            **head_kwargs)
    self.prefill_gather_flatten = P.Gather()

    self.load_checkpoint(config)
    self.predict_run_mode = get_predict_run_mode()
    logger.info("Predict run mode:{}".format(self.predict_run_mode))
    self.return_hidden_states = config.return_hidden_states
def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
    """ Get deepseekv3 model input tuple for transform ckpt.

    Returns a 12-element tuple: int32 `input_ids`, `labels` (a Tensor built from `kwargs["labels"]`
    or None), nine None placeholders, and an all-ones int32 `slot_mapping` of length batch * seq.
    """
    ids = Tensor(input_ids, mstype.int32)
    labels = None
    if "labels" in kwargs:
        labels = Tensor(kwargs["labels"])
    batch, seq = ids.shape[0], ids.shape[1]
    slot_mapping = Tensor(np.ones(shape=(batch * seq,)), mstype.int32)
    return (ids, labels) + (None,) * 9 + (slot_mapping,)
def set_dynamic_inputs(self, **kwargs):
    """ Mindspore's feature, Set dynamic input for DeepseekV3.

    Registers int32 placeholders with unknown dimensions for `input_ids`, `batch_valid_length`,
    `block_tables` and `slot_mapping`; every other positional input is passed as None, except
    `init_reset`, which is passed as True.
    """
    def _dynamic_int32(rank):
        # Placeholder whose every dimension is left unspecified.
        return Tensor(shape=[None] * rank, dtype=mstype.int32)

    input_ids = _dynamic_int32(2)
    init_reset = True
    batch_valid_length = _dynamic_int32(2)
    block_tables = _dynamic_int32(2)
    slot_mapping = _dynamic_int32(1)
    # Positional order follows the parameters of `construct`.
    self.set_inputs(input_ids, None, None, None, None, None, init_reset,
                    batch_valid_length, None, None, block_tables,
                    slot_mapping)
    logger.info("Set dynamic input for DeepseekV3.")
def pre_gather_func(self, pre_gather, output, batch_valid_length):
    """Pre gather operation in infer mode.

    Args:
        pre_gather: when falsy, `output` is returned untouched.
        output: the tensor to gather from.
        batch_valid_length: valid lengths; the gather index is `batch_valid_length - 1`.

    Returns:
        `output` itself, or the result of gathering it along axis 1 at `batch_valid_length - 1`.
        With `config.is_dynamic` set, the lengths are first replaced by their cumulative sum
        (presumably because the sequences are packed along one axis -- confirm with the caller).
    """
    if not pre_gather:
        return output
    # From here on `pre_gather` is known to be truthy, so no second check is needed.
    if self.config.is_dynamic:
        batch_valid_length = mint.cumsum(batch_valid_length, 0)
        return self.prefill_gather_flatten(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
    return self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
              input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None,
              block_tables=None, slot_mapping=None):
    """ DeepseekV3ForCausalLM forward.

    Returns the flattened hidden states when `return_hidden_states` is set, the flattened float32
    logits in predict run mode, and `(logits, tokens, input_mask)` otherwise.
    """
    bsz, _ = self.shape(input_ids)
    tokens = input_ids

    # With kv-cache enabled a missing (non-Tensor) valid length falls back to ones of shape (bsz,).
    if self.use_past and not isinstance(batch_valid_length, Tensor):
        batch_valid_length = self.ones((bsz,), mstype.int32)
    if batch_valid_length is not None:
        batch_valid_length = self.reshape(batch_valid_length, (-1,))

    hidden = self.model(tokens, batch_valid_length, batch_index, zactivate_len, block_tables,
                        slot_mapping)
    if self.return_hidden_states:
        return self.reshape(hidden, (-1, hidden.shape[-1]))

    # Gather before the projection only on the non-incremental path and when lengths are known.
    pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
    hidden = self.pre_gather_func(pre_gather, hidden, batch_valid_length)
    logits = self.lm_head(hidden)

    # Mask of non-pad tokens, further restricted to non-ignored labels when labels are given.
    input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
    if labels is not None and labels.ndim > 1:
        label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
        input_mask = self.mul(input_mask, label_mask)

    logits = self.cast(logits, mstype.float32)
    if not self.predict_run_mode:
        return logits, tokens, input_mask
    return self.reshape(logits, (-1, logits.shape[-1]))
def kvcache(self, layer_idx):
    """Return `(key_cache, None)` for the decode layer at `layer_idx`."""
    attention = self.model.layers[layer_idx].attention
    cache_mgr = attention.infer_attention.paged_attention_mgr
    # Only a key cache is returned; the second slot is always None.
    return cache_mgr.key_cache, None