"""
DEPRECATED MODULE
This module is deprecated and will be removed in future releases.
For transformer
"""
import math
import os
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer
from research.llama3_1.infer.norm import RMSNorm
from research.llama3_1.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr
from research.llama3_1.infer.scale_mask_softmax import ScaleMaskSoftmax
from research.llama3_1.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from mindformers.parallel_core.inference.utils import divide, get_attn_mask_func
from mindformers.parallel_core.inference.transformer.activation import get_act_func
from mindformers.parallel_core.process_group_config import default_model_comm_pgs
from mindformers.modules.flash_attention import FlashAttention
from mindformers.modules.infer_attention import InferRotaryEmbedding
from mindformers.modules.layers import FreqsMgr, RotaryEmbedding
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.version_control import need_nz
class VocabEmbedding(nn.Cell):
"""
Embedding Layer.
Args:
- **num_embeddings** (int): Size of the dictionary of embeddings.
- **embedding_dim** (int): The size of each embedding vector.
- **param_init_type** (mstype): The param init type, default mstype.float32.
- **param_init** (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
Refer to class `initializer` for the values of string when a string
is specified. Default: 'normal'.
Inputs:
- **input_ids** (Tensor) - The tokenized inputs with datatype int32 with shape (batch_size, seq_length)
Outputs:
- **output** (Tensor) - The embedding vector for the input with shape (batch_size,
seq_length, embedding_size).
"""
def __init__(self, num_embeddings, embedding_dim, param_init_type=mstype.float32, param_init='normal',
parallel_optimizer=False):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embedding_weight = Parameter(
initializer(param_init, [self.num_embeddings, self.embedding_dim], dtype=param_init_type),
name='embedding_weight', parallel_optimizer=parallel_optimizer)
self.gather = ops.Gather()
def construct(self, input_ids):
"""Forward of vocab embedding."""
output = self.gather(self.embedding_weight, input_ids, 0)
return output
class ParallelMLP(nn.Cell):
r"""
Implementation of parallel feedforward block.
Args:
config (dict): Configuration.
is_expert (book): This block is an expert block. Default: False.
model_comm_pgs (ModelCommProcessGroups, optional): Model communication process group.
Default: default_model_comm_pgs.
Inputs:
- **hidden_states** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Outputs:
- **output** (Tensor) - Output tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def __init__(self, config, is_expert=False, model_comm_pgs=default_model_comm_pgs):
super().__init__(config)
if is_expert:
raise NotImplementedError("For ParallelMLP, `is_expert` is not supported for now.")
self.config = config
self.has_bias = self.config.mlp_has_bias
self.hidden_size = self.config.hidden_size
self.ffn_hidden_size = self.config.ffn_hidden_size
self.mlp_has_gate = self.config.mlp_has_gate
self.ffn_concat = self.config.ffn_concat
self.tp = model_comm_pgs.tp
tp_group_size = self.tp.size
self.ffn_hidden_size_per_partition = divide(self.ffn_hidden_size, tp_group_size)
if self.mlp_has_gate:
if self.ffn_concat:
self.w_gate_hidden = ColumnParallelLinear(
self.hidden_size,
self.ffn_hidden_size * 2,
config=self.config.parallel_config,
bias=self.has_bias,
transpose_b=True,
gather_output=False,
is_expert=is_expert,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
else:
self.w1 = ColumnParallelLinear(
self.hidden_size,
self.ffn_hidden_size,
config=self.config.parallel_config,
bias=self.has_bias,
transpose_b=True,
gather_output=False,
is_expert=is_expert,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
self.w3 = ColumnParallelLinear(
self.hidden_size,
self.ffn_hidden_size,
config=self.config.parallel_config,
bias=self.has_bias,
transpose_b=True,
gather_output=False,
is_expert=is_expert,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
else:
self.w1 = ColumnParallelLinear(
self.hidden_size,
self.ffn_hidden_size,
config=self.config.parallel_config,
bias=self.has_bias,
transpose_b=True,
gather_output=False,
is_expert=is_expert,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
self.act_type = self.config.hidden_act
self.act_func = get_act_func(self.act_type)
self.w2 = RowParallelLinear(
self.ffn_hidden_size,
self.hidden_size,
input_is_parallel=True,
config=self.config.parallel_config,
bias=self.has_bias,
transpose_b=True,
is_expert=is_expert,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
self.mul = ops.Mul()
self.reshape = ops.Reshape()
def construct(self, x):
""" Construct function of mlp block. """
if self.mlp_has_gate:
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x)
gate_hidden_out_shape = gate_hidden_out.shape
reshape_out = self.reshape(gate_hidden_out,
(*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2))
gate, hidden = ops.function.array_func.split_ext(reshape_out,
(1, 1), -1)
gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition))
else:
gate = self.w1(x)
hidden = self.w3(x)
gate = self.act_func(gate)
hidden = mint.mul(hidden, gate)
else:
hidden = self.w1(x)
hidden = self.act_func(hidden)
output = self.w2(hidden)
return output
class CoreAttention(nn.Cell):
r"""
Get the weighted score along the seq_length.
Args:
layer_number (int): Number which indicates the index of this transformer layer in the
whole transformer block.
config (dict): Configuration.
attn_type (str): Attention type. Support ['self_attn', 'cross_attn']. Default: 'self_attn'.
Inputs:
- **query** (Tensor) - Tensor of query matrix.
- **key** (Tensor) - Tensor of key matrix.
- **value** (Tensor) - Tensor of value matrix.
- **attention_mask** (Tensor) - Tensor of attention mask matrix.
Outputs:
- **attn_output** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def __init__(self, layer_number, config, attn_mask_type=None):
super().__init__()
if attn_mask_type:
raise NotImplementedError("For CoreAttention, `attn_mask_type` is not supported for now.")
self.config = config
self.layer_index = max(1, layer_number)
self.compute_dtype = self.config.compute_dtype
self.softmax_compute_dtype = self.config.softmax_compute_dtype
self.sequence_parallel = self.config.parallel_config.use_sequence_parallel
self.apply_query_key_layer_scaling = self.config.apply_query_key_layer_scaling
self.num_heads = self.config.num_heads
self.hidden_size = self.config.hidden_size
self.head_dim = divide(self.hidden_size, self.num_heads)
coeff = None
norm_factor = math.sqrt(self.head_dim)
if self.apply_query_key_layer_scaling:
coeff = self.layer_index
norm_factor *= coeff
self.inv_norm_factor = Tensor(1.0 / norm_factor, dtype=self.compute_dtype)
self.mask_func = get_attn_mask_func(self.config.mask_func_type)
self.scale_mask_softmax = ScaleMaskSoftmax(self.mask_func,
softmax_compute_type=self.softmax_compute_dtype)
self.attention_dropout = mint.nn.Dropout(p=self.config.attention_dropout_rate)
def construct(self, query_layer, key_layer, value_layer, attention_mask):
"""
Computes the attention scores, applies the attention mask, and returns the weighted
sum of the value layer based on the attention probabilities.
Inputs:
----------
query_layer : Tensor
The query tensor of shape [B, N, S_q, D].
key_layer : Tensor
The key tensor of shape [B, N, S_k, D].
value_layer : Tensor
The value tensor of shape [B, N, S_k, D].
attention_mask : Tensor
The attention mask tensor of shape [B, N, S_q, S_k].
Returns:
-------
Tensor
The attention output tensor of shape [B, N, S_q, D].
"""
score = ops.bmm(query_layer, key_layer.transpose(0, 1, 3, 2))
score = mint.mul(score, self.inv_norm_factor)
attention_probs = self.scale_mask_softmax(score, attention_mask)
attention_probs = self.attention_dropout(attention_probs)
weighted_values = ops.bmm(attention_probs, value_layer)
return weighted_values
class ParallelAttention(nn.Cell):
r"""
Parallel attention block.
Args:
layer_index (int): Number which indicates the index of this transformer layer in the
whole transformer block.
config (dict): Configuration.
attn_type (str): Attention type. Support ['self_attn', 'cross_attn']. Default: 'self_attn'.
model_comm_pgs (ModelCommProcessGroups, optional): Model communication process group.
Default: default_model_comm_pgs.
Inputs:
- **hidden_states** (Tensor) - Tensor of shape :math:`(B, S, H)`.
- **attention_mask** (Tensor) - Tensor of attention mask.
- **encoder_output** (Tensor) - Tensor of encoder output used for cross attention. Default: None.
- **rotary_pos_emb** (Tensor) - Tensor of rotary position embedding. Default: None.
Outputs:
- **output** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def __init__(self, config, layer_number, attention_type="self_attn", attn_mask_type=None,
model_comm_pgs=default_model_comm_pgs):
super().__init__(config)
if attn_mask_type:
raise NotImplementedError("For ParallelAttention, `attn_mask_type` is not supported for now.")
self.config = config
self.layer_index = max(1, layer_number)
self.param_init_dtype = self.config.param_init_dtype
self.compute_dtype = self.config.compute_dtype
self.is_first_iteration = True
self.use_past = self.config.use_past
self.qkv_concat = self.config.qkv_concat
self.attn_type = attention_type
self.num_heads = self.config.num_heads
self.kv_num_heads = self.num_heads if config.n_kv_heads is None else config.n_kv_heads
self.hidden_size = self.config.hidden_size
self.head_dim = divide(self.hidden_size, self.num_heads)
self.kv_hidden_size = self.head_dim * self.kv_num_heads
self.n_rep = divide(self.num_heads, self.kv_num_heads)
self.sequence_parallel = self.config.parallel_config.use_sequence_parallel
self.use_flash_attention = self.config.use_flash_attention
self.norm_factor = math.sqrt(self.head_dim)
self.tp = model_comm_pgs.tp
self.tp_group_size = self.tp.size
self.num_heads_per_partition = divide(self.num_heads, self.tp_group_size)
self.use_gqa = self.num_heads != self.kv_num_heads
if self.use_gqa:
self._check_gqa_valid()
self.kv_num_heads_per_partition = divide(self.kv_num_heads, self.tp_group_size)
self.repeat_num = divide(self.num_heads, self.kv_num_heads)
else:
self.kv_num_heads_per_partition = self.num_heads_per_partition
if self.attn_type == "self_attn":
self._init_self_attn()
elif self.attn_type == "cross_attn":
self._init_cross_attn()
else:
raise NotImplementedError(
f"attention_type(str) should be 'self_attn' or 'cross_attn', but got {self.attn_type}")
self.reshape = ops.Reshape()
self.cast = ops.Cast()
self.wo = RowParallelLinear(
self.hidden_size,
self.hidden_size,
input_is_parallel=True,
config=self.config.parallel_config,
bias=self.config.out_proj_has_bias,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
if self.use_flash_attention:
input_layout = "TH" if self.use_past else "BNSD"
self.flash_attention = FlashAttention(head_num=self.num_heads_per_partition,
scale_value=1.0 / self.norm_factor,
next_tokens=0,
input_layout=input_layout)
else:
self.core_attention = CoreAttention(self.layer_index, self.config)
if self.use_past:
if need_nz():
kv_shape = (self.config.num_blocks, self.config.block_size,
self.kv_num_heads_per_partition * self.head_dim)
else:
kv_shape = (self.config.num_blocks, self.config.block_size,
self.kv_num_heads_per_partition, self.head_dim)
self.npu_mem_size = config.npu_mem_size if hasattr(config, "npu_mem_size") else 2
self.paged_attention_mgr = ParallelPagedAttentionMgr(self.num_heads_per_partition,
self.head_dim,
self.kv_num_heads_per_partition,
kv_shape,
config.seq_length,
compute_dtype=self.compute_dtype,
npu_mem_size=self.npu_mem_size)
self.rotary_embedding = InferRotaryEmbedding(rotary_cos_format=2)
else:
self.apply_rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_dtype)
def construct(self, x, batch_valid_length, block_tables, slot_mapping, freqs_cis=None,
attn_mask=None, alibi_mask=None, encoder_output=None, prefix_keys_values=None,
q_seq_lens=None, key_cache=None, value_cache=None):
"""Construct function of attention block."""
if self.attn_type == "self_attn":
if self.qkv_concat:
qkv = self.cast(self.w_qkv(x), self.compute_dtype)
reshape_qkv = self.reshape(qkv,
(-1,
self.kv_num_heads_per_partition,
(self.n_rep + 2) * self.head_dim))
query, key, value = ops.function.array_func.split_ext(reshape_qkv,
(self.head_dim * self.n_rep,
self.head_dim,
self.head_dim), -1)
if self.use_past:
query = self.reshape(query, (-1, self.hidden_size_per_partition))
key = self.reshape(key, (-1, self.kv_hidden_size_per_partition))
value = self.reshape(value, (-1, self.kv_hidden_size_per_partition))
else:
query = self.cast(self.wq(x), self.compute_dtype)
key = self.cast(self.wk(x), self.compute_dtype)
value = self.cast(self.wv(x), self.compute_dtype)
if not self.use_past:
bs, seq_len, _ = x.shape
query = self.reshape(query, (bs, seq_len, self.num_heads_per_partition, self.head_dim))
key = self.reshape(key, (bs, seq_len, self.kv_num_heads_per_partition, self.head_dim))
value = self.reshape(value, (bs, seq_len, self.kv_num_heads_per_partition, self.head_dim))
else:
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
key, value = ops.function.array_func.split_ext(
kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
if self.use_past:
if freqs_cis is not None:
query, key = self.rotary_embedding(query, key, freqs_cis, batch_valid_length)
if prefix_keys_values is not None:
prefix_len = prefix_keys_values.shape[2]
slot_mapping = slot_mapping + self.cast(mint.ne(slot_mapping, -1), mstype.int32) * prefix_len
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)
key_out = self.paged_attention_mgr(key, value, slot_mapping, batch_valid_length,
key_cache=key_cache, value_cache=value_cache)
query = ops.depend(query, key_out)
if self.is_first_iteration:
if self.use_flash_attention:
context_layer = self.flash_attention(query, key, value, attn_mask, alibi_mask, None, None,
q_seq_lens, batch_valid_length)
else:
bs, seq_len, _ = x.shape
query = query.reshape(bs, seq_len, -1, self.head_dim)
key = key.reshape(bs, seq_len, -1, self.head_dim)
value = value.reshape(bs, seq_len, -1, self.head_dim)
if self.use_gqa:
key = mint.repeat_interleave(key, repeats=self.repeat_num, dim=2)
value = mint.repeat_interleave(value, repeats=self.repeat_num, dim=2)
query = query.transpose(0, 2, 1, 3)
key = key.transpose(0, 2, 1, 3)
value = value.transpose(0, 2, 1, 3)
context_layer = self.core_attention(query, key, value, attn_mask)
context_layer = context_layer.transpose(0, 2, 1, 3).reshape(
bs, seq_len, self.hidden_size_per_partition)
else:
context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables,
attn_mask, q_seq_lens, key_cache, value_cache)
else:
bs, seq_len, _ = x.shape
query = query.transpose(0, 2, 1, 3)
key = key.transpose(0, 2, 1, 3)
value = value.transpose(0, 2, 1, 3)
if freqs_cis is not None:
query, key = self.apply_rotary_emb(query, key, freqs_cis)
if self.use_flash_attention:
if os.getenv('RUN_MODE') == 'predict':
raise NotImplementedError(
"Conflict detected in predict mode: "
"Flash Attention is incompatible when use_past=False")
context_layer = self.flash_attention(query, key, value, attn_mask)
else:
if self.use_gqa:
key = mint.repeat_interleave(key, repeats=self.repeat_num, axis=1)
value = mint.repeat_interleave(value, repeats=self.repeat_num, axis=1)
context_layer = self.core_attention(query, key, value, attn_mask)
context_layer = context_layer.transpose(0, 2, 1, 3).reshape(
bs, seq_len, self.hidden_size_per_partition)
output = self.wo(context_layer)
output = self.cast(output, x.dtype)
return output
def _cat_prefix(self, key, value, prefix_keys_values):
"""
concat prefix_keys_values to key and value
prefix_keys_values: shape(2, bs, pre_len, num_heads * kv_channels)
"""
if prefix_keys_values is not None:
past_key = prefix_keys_values[0]
past_value = prefix_keys_values[1]
past_key = self.cast(past_key, key.dtype)
past_value = self.cast(past_value, value.dtype)
key = ops.concat((past_key, key), 1)
value = ops.concat((past_value, value), 1)
return key, value
def _check_gqa_valid(self):
"""check whether the config is valid for grouped-query-attention"""
if self.num_heads % self.kv_num_heads != 0:
raise ValueError(
f"num_heads must be divisible by kv_num_heads, "
f"but got num_heads {self.num_heads} and kv_num_heads {self.kv_num_heads}"
)
if self.kv_num_heads % self.tp_group_size != 0:
raise ValueError(
f"kv_num_heads must be divisible by tp_group_size, "
f"but got kv_num_heads {self.kv_num_heads} and kv_num_heads {self.tp_group_size}"
)
def _init_self_attn(self):
"""init qkv linears of self-attention"""
self.hidden_size_per_partition = divide(self.hidden_size, self.tp_group_size)
self.kv_hidden_size_per_partition = divide(self.kv_hidden_size, self.tp_group_size)
if self.qkv_concat:
self.w_qkv = ColumnParallelLinear(
self.hidden_size,
self.hidden_size + 2 * self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
else:
self.wq = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
self.wk = ColumnParallelLinear(
self.hidden_size,
self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
self.wv = ColumnParallelLinear(
self.hidden_size,
self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
tp_group=self.tp,
)
def _init_cross_attn(self):
"""init qkv linears of cross-attention"""
if self.hidden_size != self.kv_hidden_size:
raise ValueError("hidden_size must be equal to kv_hidden_size!")
self.wq = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
if self.qkv_concat:
self.w_kv = ColumnParallelLinear(
self.hidden_size,
2 * self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
else:
self.wk = ColumnParallelLinear(
self.hidden_size,
self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
self.wv = ColumnParallelLinear(
self.hidden_size,
self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
class ParallelTransformerLayer(nn.Cell):
r"""
Single parallel transformer layer.
Args:
config (dict): Configuration.
layer_index (int): Number which indicates the index of this transformer layer in the
whole transformer block.
model_comm_pgs (ModelCommProcessGroups, optional): Model communication process group.
Default: default_model_comm_pgs.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(B, S, H)`.
- **attention_mask** (Tensor) - Tensor of attention mask.
- **rotary_pos_emb** (Tensor) - Tensor of rotary position embedding. Default: None.
Outputs:
- **output** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def __init__(
self,
config,
layer_number: int,
layer_type=None,
self_attn_mask_type=None,
drop_path_rate: float = 0.0,
model_comm_pgs=default_model_comm_pgs,
):
super().__init__(config)
if layer_type:
raise NotImplementedError("For ParallelTransformerLayer, only decoder only structure is supported for now.")
if self_attn_mask_type:
raise NotImplementedError("For ParallelTransformerLayer, `self_attn_mask_type` is not supported for now.")
if drop_path_rate > 0.0:
raise NotImplementedError(
f"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, "
f"but got `drop_path_rate={drop_path_rate}`"
)
self.config = config
self.apply_residual_connection_post_norm = self.config.apply_residual_connection_post_norm
self.attention_norm = RMSNorm(dim=config.hidden_size,
eps=config.layernorm_epsilon,
compute_type=config.layernorm_compute_dtype)
self.attention = ParallelAttention(config, layer_number, model_comm_pgs=model_comm_pgs)
self.ffn_norm = RMSNorm(dim=config.hidden_size,
eps=config.layernorm_epsilon,
compute_type=config.layernorm_compute_dtype)
self.feed_forward = ParallelMLP(config, model_comm_pgs=model_comm_pgs)
def construct(self, x, freqs_cis=None, mask=None, batch_valid_length=None, block_tables=None,
slot_mapping=None, prefix_keys_values=None, q_seq_lens=None, key_cache=None, value_cache=None):
"""Construct function of transformer layer."""
norm_output = self.attention_norm(x)
attention_output = self.attention(norm_output, batch_valid_length, block_tables, slot_mapping, freqs_cis,
mask, prefix_keys_values=prefix_keys_values,
q_seq_lens=q_seq_lens, key_cache=key_cache, value_cache=value_cache)
if self.apply_residual_connection_post_norm:
residual = norm_output
else:
residual = x
norm_input = ops.add(residual, attention_output)
norm_output = self.ffn_norm(norm_input)
mlp_output = self.feed_forward(norm_output)
if self.apply_residual_connection_post_norm:
residual = norm_output
else:
residual = norm_input
output = ops.add(residual, mlp_output)
return output
class ParallelTransformer(nn.Cell):
r"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParallelTransformerLayer`]
Args:
config: the config of transformer
model_comm_pgs (ModelCommProcessGroups, optional): Model communication process group.
Default: default_model_comm_pgs.
Returns:
output: Tensor, the output of transformerlayer
"""
def __init__(
self,
config,
model_type=None,
layer_type=None,
self_attn_mask_type=None,
post_norm: bool = True,
pre_process=False,
post_process=False,
drop_path_rate: float = 0.0,
model_comm_pgs=default_model_comm_pgs,
):
super().__init__(config)
if model_type:
raise NotImplementedError("For ParallelTransformer, 'model_type' is not support for now.")
if layer_type:
raise NotImplementedError("For ParallelTransformer, 'layer_type' is not support for now.")
if self_attn_mask_type:
raise NotImplementedError("For ParallelTransformer, 'self_attn_mask_type' is not support for now.")
if pre_process:
raise NotImplementedError("For ParallelTransformer, 'pre_process' is not support for now.")
if post_process:
raise NotImplementedError("For ParallelTransformer, 'post_process' is not support for now.")
if drop_path_rate:
raise NotImplementedError("For ParallelTransformer, 'drop_path_rate' is not support for now.")
self.config = config
self.post_norm = post_norm
self.head_dim = config.hidden_size // config.num_heads
self.num_layers = config.num_layers
self.use_past = config.use_past
self.is_first_iteration = True
self.use_flash_attention = config.use_flash_attention
self.compute_dtype = config.compute_dtype
self.cast = ops.Cast()
self.shape = ops.Shape()
self.freqs_mgr = FreqsMgr(head_dim=self.head_dim,
seq_length=config.seq_length,
max_position_embedding=config.max_position_embedding,
rotary_dtype=config.rotary_dtype,
theta=config.theta,
scaling_factor=config.scaling_factor,
extend_method=config.extend_method,
parallel_config=config.parallel_config,
is_dynamic=config.is_dynamic)
self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
compute_type=config.compute_dtype,
is_dynamic=config.is_dynamic,
pad_token_id=config.pad_token_id,
use_flash_attention=config.use_flash_attention,
use_attn_mask_compression=config.use_attn_mask_compression,
use_past=config.use_past)
self.tp = model_comm_pgs.tp
self.tp_group_size = self.tp.size
if config.parallel_config.vocab_emb_dp or self.tp_group_size == 1:
self.tok_embeddings = VocabEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
param_init_type=config.param_init_dtype,
param_init="normal",
)
else:
self.tok_embeddings = VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
parallel_config=config.parallel_config,
init_method="normal",
init_type=config.param_init_dtype,
tp_group=self.tp)
self.layers = nn.CellList()
for layer_id in range(config.num_layers):
layer = ParallelTransformerLayer(
config=self.config,
layer_number=layer_id + 1,
model_comm_pgs=model_comm_pgs
)
self.layers.append(layer)
if self.post_norm:
self.norm_out = RMSNorm(dim=config.hidden_size,
eps=config.layernorm_epsilon,
compute_type=config.layernorm_compute_dtype)
def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
block_tables=None, slot_mapping=None, prefix_keys_values=None, position_ids=None, attention_mask=None,
q_seq_lens=None, key_cache=None, value_cache=None):
"""
Forward of ParallelTransformer.
Args:
tokens: the tokenized inputs with datatype int32
batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
prediction. Tensor of shape :math:`(batch_size,)`. Default None.
block_tables (Tensor[int64]): Store mapping tables for each sequence.
slot_mapping (Tensor[int32]): Store token cache physical slot index.
Returns:
output: Tensor, the output of ParallelTransformer
"""
mask = attention_mask
if self.use_past:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill()
if prefix_keys_values is not None:
bs, seq_len = self.shape(tokens)
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.chunk_with_decode(position_ids)
else:
bs, seq_len = self.shape(tokens)
mask = self.casual_mask(tokens)
freqs_cis = self.freqs_mgr(seq_len)
if prefix_keys_values is not None:
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
hidden_states = self.cast(self.tok_embeddings(tokens), self.compute_dtype)
for i in range(self.num_layers):
prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
key_cache_i = key_cache[i] if key_cache is not None else None
value_cache_i = value_cache[i] if value_cache is not None else None
hidden_states = self.layers[i](hidden_states, freqs_cis, mask, batch_valid_length=batch_valid_length,
block_tables=block_tables, slot_mapping=slot_mapping,
prefix_keys_values=prefix_kv, q_seq_lens=q_seq_lens,
key_cache=key_cache_i, value_cache=value_cache_i)
if self.post_norm:
hidden_states = self.norm_out(hidden_states)
return hidden_states