""" For transformer """
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn, ops, mint, Tensor
from mindformers.experimental.parallel_core.pynative.utils import divide
from mindformers.experimental.infer.core.layers import ColumnParallelLinear, RowParallelLinear
from mindformers.experimental.infer.core.transformer import \
ParallelAttention, ParallelTransformerLayer, ParallelTransformer, ParallelMLP
from mindformers.modules.layers import FreqsMgrDynamicNTK
from mindformers.tools.logger import logger
class TelechatParallelMLP(ParallelMLP):
r"""
Implementation of parallel feedforward block.
Args:
config (dict): Configuration.
is_expert (book): This block is an expert block. Default: False.
Inputs:
- **hidden_states** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Outputs:
- **output** (Tensor) - Output tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def __init__(self, config, is_expert=False):
super().__init__(config, is_expert)
self.w2 = RowParallelLinear(
self.ffn_hidden_size,
self.hidden_size,
input_is_parallel=True,
config=self.config.parallel_config,
bias=True,
transpose_b=True,
is_expert=is_expert,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
def construct(self, x):
""" Construct function of mlp block. """
if self.mlp_has_gate:
if self.ffn_concat:
gate_hidden_out = self.w_gate_hidden(x)
gate, hidden = mint.split(gate_hidden_out,
(self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1)
else:
gate = self.w1(x)
hidden = self.w3(x)
gate = self.act_func(gate)
hidden = mint.mul(hidden, gate)
else:
hidden = self.w1(x)
hidden = self.act_func(hidden)
output = self.w2(hidden)
return output
class TelechatParallelAttention(ParallelAttention):
r"""
Parallel attention block.
Args:
layer_index (int): Number which indicates the index of this transformer layer in the
whole transformer block.
config (dict): Configuration.
attn_type (str): Attention type. Support ['self_attn', 'cross_attn']. Default: 'self_attn'.
Inputs:
- **hidden_states** (Tensor) - Tensor of shape :math:`(B, S, H)`.
- **attention_mask** (Tensor) - Tensor of attention mask.
- **encoder_output** (Tensor) - Tensor of encoder output used for cross attention. Default: None.
- **rotary_pos_emb** (Tensor) - Tensor of rotary position embedding. Default: None.
Outputs:
- **output** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def construct(self, x, batch_valid_length, block_tables, slot_mapping, freqs_cis=None,
attn_mask=None, alibi_mask=None, prefix_keys_values=None, encoder_output=None,
key_cache=None, value_cache=None):
"""Construct function of attention block."""
ori_dtype = x.dtype
bs, seq_len, _ = x.shape
if self.attn_type == "self_attn":
if self.sequence_parallel:
seq_len = seq_len * self.tp_group_size
if self.qkv_concat:
qkv = self.cast(self.w_qkv(x), self.compute_dtype)
query, key, value = mint.split(qkv,
(self.hidden_size_per_partition,
self.kv_hidden_size_per_partition,
self.kv_hidden_size_per_partition), -1)
else:
query = self.cast(self.wq(x), self.compute_dtype)
key_value = self.cast(self.wk_v(x), self.compute_dtype)
key_value = key_value.reshape(-1, self.kv_num_heads_per_partition, self.head_dim * 2)
key, value = mint.split(key_value, (self.head_dim, self.head_dim), -1)
key = key.reshape(bs, seq_len, -1)
value = value.reshape(bs, seq_len, -1)
else:
query = self.cast(self.wq(x), self.compute_dtype)
if self.qkv_concat:
kv = self.cast(self.w_kv(encoder_output), self.compute_dtype)
key, value = mint.split(kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1)
else:
key = self.cast(self.wk(encoder_output), self.compute_dtype)
value = self.cast(self.wv(encoder_output), self.compute_dtype)
if self.use_past:
if freqs_cis is not None:
query, key = self.rotary_embedding(query, key, freqs_cis, batch_valid_length)
if prefix_keys_values is not None:
prefix_len = prefix_keys_values.shape[2]
slot_mapping = slot_mapping + self.cast(mint.ne(slot_mapping, -1), mstype.int32) * prefix_len
if self.is_first_iteration:
key, value = self._cat_prefix(key, value, prefix_keys_values)
key_out = self.paged_attention_mgr(key, value, slot_mapping, batch_valid_length,
key_cache=key_cache, value_cache=value_cache)
query = ops.depend(query, key_out)
if self.is_first_iteration:
if self.use_flash_attention:
if self.is_pynative:
context_layer = self.flash_attention(query, key, value, attn_mask, alibi_mask)
else:
bs, seq_len, _ = query.shape
query = self.reshape(query, (-1, self.num_heads_per_partition * self.head_dim))
key = self.reshape(key, (-1, self.kv_num_heads_per_partition * self.head_dim))
value = self.reshape(value, (-1, self.kv_num_heads_per_partition * self.head_dim))
context_layer = self.flash_attention(query, key, value, attn_mask, alibi_mask, None, None,
batch_valid_length, batch_valid_length)
context_layer = self.reshape(context_layer, (bs, seq_len,
self.num_heads_per_partition * self.head_dim))
else:
query = query.reshape(bs, seq_len, -1, self.head_dim).transpose((0, 2, 1, 3))
key = key.reshape(bs, seq_len, -1, self.head_dim)
value = value.reshape(bs, seq_len, -1, self.head_dim)
if self.use_gqa:
repeat_num = self.num_heads_per_partition - self.kv_num_heads_per_partition
key = self._repeat_kv(key, repeat_num)
value = self._repeat_kv(value, repeat_num)
else:
key = key.transpose((0, 2, 1, 3))
value = value.transpose((0, 2, 1, 3))
context_layer = self.core_attention(query, key, value, attn_mask)
else:
context_layer = self.paged_attention_mgr.paged_attn(query, batch_valid_length, block_tables,
key_cache=key_cache, value_cache=value_cache)
else:
query = query.reshape(bs, seq_len, -1, self.head_dim).transpose((0, 2, 1, 3))
key = key.reshape(bs, seq_len, -1, self.head_dim)
value = value.reshape(bs, seq_len, -1, self.head_dim)
if self.use_gqa:
repeat_num = self.num_heads_per_partition - self.kv_num_heads_per_partition
key = self._repeat_kv(key, repeat_num)
value = self._repeat_kv(value, repeat_num)
else:
key = key.transpose((0, 2, 1, 3))
value = value.transpose((0, 2, 1, 3))
if freqs_cis is not None:
query, key = self.apply_rotary_emb(query, key, freqs_cis)
context_layer = self.core_attention(query, key, value, attn_mask)
output = self.wo(context_layer)
output = self.cast(output, ori_dtype)
return output
def _init_self_attn(self):
"""init qkv linears of self-attention"""
if self.qkv_concat:
self.w_qkv = ColumnParallelLinear(
self.hidden_size,
self.hidden_size + 2 * self.kv_hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
self.hidden_size_per_partition = divide(self.hidden_size, self.tp_group_size)
self.kv_hidden_size_per_partition = divide(self.kv_hidden_size, self.tp_group_size)
else:
self.wq = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
self.wk_v = ColumnParallelLinear(
self.hidden_size,
self.kv_hidden_size * 2,
config=self.config.parallel_config,
bias=self.config.qkv_has_bias,
gather_output=False,
transpose_b=True,
param_init_type=self.config.param_init_dtype,
compute_dtype=self.config.compute_dtype,
)
self.kv_hidden_size_per_partition = divide(self.kv_hidden_size, self.tp_group_size)
class TelechatParallelTransformerLayer(ParallelTransformerLayer):
r"""
Single parallel transformer layer.
Args:
config (dict): Configuration.
layer_index (int): Number which indicates the index of this transformer layer in the
whole transformer block.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(B, S, H)`.
- **attention_mask** (Tensor) - Tensor of attention mask.
- **rotary_pos_emb** (Tensor) - Tensor of rotary position embedding. Default: None.
Outputs:
- **output** (Tensor) - Tensor of shape :math:`(B, S, H)`.
Supported Platforms:
``Ascend``
"""
def __init__(
self,
config,
layer_number: int,
layer_type=None,
self_attn_mask_type=None,
drop_path_rate: float = 0.0,
):
super().__init__(
config,
layer_number,
layer_type,
self_attn_mask_type,
drop_path_rate
)
self.attention = TelechatParallelAttention(config, layer_number)
self.feed_forward = TelechatParallelMLP(config)
class TelechatParallelTransformer(ParallelTransformer):
r"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParallelTransformerLayer`]
Args:
config: the config of transformer
Returns:
output: Tensor, the output of transformerlayer
"""
def __init__(
self,
config,
model_type=None,
layer_type=None,
self_attn_mask_type=None,
post_norm: bool = True,
pre_process=False,
post_process=False,
drop_path_rate: float = 0.0
):
super().__init__(
config,
model_type,
layer_type,
self_attn_mask_type,
post_norm,
pre_process,
post_process,
drop_path_rate
)
self.enable_dynamic_ntk = False
if config.extend_method == 'DYNAMIC_NTK':
self.enable_dynamic_ntk = True
self.freqs_mgr = FreqsMgrDynamicNTK(head_dim=self.head_dim,
max_position_embedding=config.max_position_embedding,
rotary_dtype=config.rotary_dtype,
theta=config.theta,
parallel_config=config.parallel_config,
is_dynamic=config.is_dynamic)
logger.info("Running with dynamic NTK.")
self.layers = nn.CellList()
for layer_id in range(config.num_layers):
layer = TelechatParallelTransformerLayer(config=self.config, layer_number=layer_id + 1)
self.layers.append(layer)
def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
block_tables=None, slot_mapping=None, prefix_keys_values=None, key_cache=None, value_cache=None):
"""
Forward of ParallelTransformer.
Args:
tokens: the tokenized inputs with datatype int32
batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
prediction. Tensor of shape :math:`(batch_size,)`. Default None.
block_tables (Tensor[int64]): Store mapping tables for each sequence.
slot_mapping (Tensor[int32]): Store token cache physical slot index.
Returns:
output: Tensor, the output of ParallelTransformer
"""
bs, seq_len = self.shape(tokens)
mask = None
if self.use_past:
if self.is_first_iteration:
if self.enable_dynamic_ntk:
freqs_cis = self.freqs_mgr.prefill(bs, batch_valid_length.max())
else:
freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
if self.is_pynative:
mask = self.casual_mask(tokens)
else:
mask = self.casual_mask.prefill()
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
else:
mask = self.casual_mask(tokens)
freqs_cis = self.freqs_mgr(seq_len)
if prefix_keys_values is not None:
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
hidden_states = self.cast(self.tok_embeddings(tokens), self.compute_dtype)
for i in range(self.num_layers):
prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
key_cache_i = key_cache[i] if key_cache is not None else None
value_cache_i = value_cache[i] if value_cache is not None else None
hidden_states = self.layers[i](hidden_states, freqs_cis, mask, batch_valid_length=batch_valid_length,
block_tables=block_tables, slot_mapping=slot_mapping,
prefix_keys_values=prefix_kv,
key_cache=key_cache_i, value_cache=value_cache_i)
if self.post_norm:
hidden_states = self.norm_out(hidden_states)
return hidden_states