"""Wizardcoder modules."""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.context import ParallelMode
from mindspore import log as logger
from mindformers.modules.flash_attention import FlashAttention
from mindformers.modules.transformer.moe import default_moe_config
from mindformers.modules.transformer import TransformerEncoderLayer, MultiHeadAttention, \
VocabEmbedding, TransformerOpParallelConfig, EmbeddingOpParallelConfig
from mindformers.modules.transformer.op_parallel_config import default_dpmp_config
from mindformers.modules.layers import Linear, LayerNorm
# Module-level default parallel configurations, each built with its constructor defaults.
# default_embedding_parallel_config is the default `parallel_config` of WizardCoderVocabEmbedding.
default_transformer_config: TransformerOpParallelConfig = TransformerOpParallelConfig()
default_embedding_parallel_config: EmbeddingOpParallelConfig = EmbeddingOpParallelConfig()
class WizardCoderVocabEmbedding(VocabEmbedding):
    """Embedding lookup table used by WizardCoder.

    Behaves like :class:`VocabEmbedding`. The one difference applies when
    ``parallel_config.vocab_emb_dp`` is True: the gather primitive is then replaced by one whose
    first operand is split along its first axis over ``model_parallel`` and whose second operand
    is split along its first axis over ``data_parallel``. The first operand is presumably the
    embedding table and the second the token ids; confirm against ``VocabEmbedding.construct``.
    """

    def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config,
                 param_init='normal'):
        super(WizardCoderVocabEmbedding, self).__init__(vocab_size, embedding_size, parallel_config, param_init)
        data_parallel = parallel_config.data_parallel
        model_parallel = parallel_config.model_parallel
        if not parallel_config.vocab_emb_dp:
            # Keep the gather sharding configured by the parent class.
            return
        table_strategy = (model_parallel, 1)
        indices_strategy = (data_parallel, 1)
        self.gather = P.Gather().shard((table_strategy, indices_strategy))
        logger.info(f"Using {data_parallel} data parallel for the embedding lookup.")
class MultiQueryAttention(MultiHeadAttention):
    r"""
    This is an implementation of multi query attention.

    All query heads share a single key/value head (``kv_heads = 1``):

    - ``dense1`` projects the query to ``hidden_size`` (``n_head`` heads).
    - ``dense2`` / ``dense3`` project the key and value to ``size_per_head`` only.
    - With flash attention, the single key/value head is tiled to ``n_head`` heads before the fused
      kernel is called.
    - Otherwise :meth:`_attn` multiplies the ``n_head`` query directly against the one-head key.

    Args:
        batch_size (int): Batch size, forwarded to :class:`MultiHeadAttention`.
        src_seq_length (int): Query sequence length.
        tgt_seq_length (int): Key/value sequence length.
        hidden_size (int): Hidden size of the input.
        num_heads (int): Number of query heads.
        compute_dtype (dtype.Number): Dtype used by the dense layers and attention matmuls.
        softmax_compute_type (dtype.Number): Dtype used by the softmax.
        param_init_type (dtype.Number): Parameter initialisation dtype of the dense layers.
        hidden_dropout_rate (float): Forwarded to :class:`MultiHeadAttention`. Default: 0.1.
        attention_dropout_rate (float): Forwarded to :class:`MultiHeadAttention` and to the
            FlashAttention cell. Default: 0.1.
        use_past (bool): Enable incremental (key/value cache) inference. Default: False.
        use_seq_parallel (bool): Split the output of the output projection along its first (token)
            axis over ``dp * mp`` devices and mark its matmul for recomputation. Default: False.
        use_flash_attention (bool): Use the FlashAttention cell instead of :meth:`_attn`.
            Default: True.
        parallel_config (OpParallelConfig): Parallel configuration. It is only read, never
            modified. Default: `default_dpmp_config`.

    Raises:
        ValueError: If the backend is not Ascend.

    Supported Platforms:
        ``Ascend``
    """

    def __init__(self, batch_size,
                 src_seq_length,
                 tgt_seq_length,
                 hidden_size,
                 num_heads,
                 compute_dtype,
                 softmax_compute_type,
                 param_init_type,
                 hidden_dropout_rate=0.1,
                 attention_dropout_rate=0.1,
                 use_past=False,
                 use_seq_parallel=False,
                 use_flash_attention=True,
                 parallel_config=default_dpmp_config):
        super(MultiQueryAttention, self).__init__(batch_size,
                                                  src_seq_length,
                                                  tgt_seq_length,
                                                  hidden_size,
                                                  num_heads,
                                                  hidden_dropout_rate,
                                                  attention_dropout_rate,
                                                  compute_dtype,
                                                  softmax_compute_type,
                                                  param_init_type,
                                                  use_past,
                                                  parallel_config)
        if not self._is_ascend:
            raise ValueError("For 'MultiQueryAttention', now only support Ascend")
        self.compute_dtype = compute_dtype
        self.is_parallel_mode = _get_parallel_mode() in (
            ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        dp, mp = parallel_config.data_parallel, parallel_config.model_parallel

        # Sequence parallelism only changes how the output projection (and, outside AUTO_PARALLEL,
        # its dropout) is sharded.
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
            if use_seq_parallel:
                self.projection.shard(strategy_bias=((dp, 1), (1,)),
                                      strategy_matmul=((dp, mp), (mp, 1)),
                                      out_strategy_matmul=((dp * mp, 1),))
                logger.info("Enabling matmul recompuation when seq parallel enabled")
                self.projection.matmul.add_prim_attr("recompute", True)
                self.projection.matmul.add_prim_attr("recompute_comm_op", True)
        elif use_seq_parallel:
            self.dropout.dropout.shard(((dp * mp, 1),))
            self.projection.shard(
                strategy_bias=((dp * mp, 1), (1,)),
                strategy_matmul=((dp, mp), (mp, 1)),
                out_strategy_matmul=((dp * mp, 1),))
            logger.info("Enabling matmul recompuation when seq parallel enabled")
            self.projection.matmul.add_prim_attr("recompute", True)
            self.projection.matmul.add_prim_attr("recompute_comm_op", True)

        # Everything below is used by construct()/_attn() and is defined in every parallel mode.
        # Scores: query split over (dp, mp heads) against the single-head key, split over dp only.
        self.batch_matmul = P.BatchMatMul().shard(((dp, mp, 1, 1), (dp, 1, 1, 1)))
        self.kv_heads = 1
        self.kv_dim = self.kv_heads * self.size_per_head
        self.transpose_one_head = P.Transpose().shard(((dp, 1, 1, 1),))
        self.tile_for_batch_matmul = P.Tile().shard(((dp, mp, 1, 1),))
        self.real_div_one_head = P.RealDiv().shard(((dp, 1, 1, 1), ()))

        # Query projection: hidden_size -> hidden_size, model-parallel over the output features.
        self.dense1 = Linear(hidden_size,
                             hidden_size,
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.dense1.shard(strategy_matmul=((dp, 1), (mp, 1)),
                          strategy_bias=((dp, mp), (mp,)))

        # Key/value projections: hidden_size -> kv_dim (one head), sharded data-parallel only
        # (model_parallel = 1). The strategies are written out directly instead of temporarily
        # overwriting parallel_config.model_parallel. That config may be the shared module default,
        # and the old approach was not restored if an exception occurred in between.
        kv_matmul_strategy = ((dp, 1), (1, 1))
        kv_bias_strategy = ((dp, 1), (1,))
        self.dense2 = Linear(hidden_size,
                             self.kv_dim,
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.dense2.shard(strategy_matmul=kv_matmul_strategy, strategy_bias=kv_bias_strategy)
        self.dense2.weight.parallel_optimizer = False
        self.dense3 = Linear(hidden_size,
                             self.kv_dim,
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.dense3.shard(strategy_matmul=kv_matmul_strategy, strategy_bias=kv_bias_strategy)
        self.dense3.weight.parallel_optimizer = False

        self.cast_rec = P.Cast()
        self.reshape_rec = P.Reshape()
        self.flash_attention_flag = use_flash_attention
        # Used by construct() to drop axis 1 of a 4-D mask on the flash-attention path.
        self.squeeze = P.Squeeze(1)
        if self.flash_attention_flag:
            self.flash_attention = FlashAttention(self.size_per_head, attention_dropout_rate, prev_block_num=65536,
                                                  next_block_num=0, tiling_stgy_name="sparse",
                                                  dp=dp, mp=mp)
            self.flash_attention.shard(((dp, mp, 1, 1),
                                        (dp, mp, 1, 1),
                                        (dp, mp, 1, 1),
                                        (dp, 1, 1),
                                        (dp, mp, 1, 1)))
            self.flash_attention.drop_gen_mask.recompute(False)
            self.flash_attention.fill_v2.recompute(False)
            self.flash_attention.flash_attention.recompute(False)
        logger.info("dp_num = {}, mp_num = {}".format(dp, mp))
        # Lazy %-style argument: the original passed an extra argument without a placeholder,
        # which makes the logging module fail while formatting the record.
        logger.info("Using FlashAttention in this round of operation = %s", self.flash_attention_flag)
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.get_dtype = P.DType()

    def set_select_recompute(self):
        """operator select recompute"""
        self.batch_matmul.recompute()
        self.real_div.recompute()
        self.real_div_one_head.recompute()
        self.sub.recompute()
        self.add.recompute()
        self.prob_dropout.dropout.recompute()
        self.softmax_3d.softmax.recompute()
        self.softmax.softmax.recompute()
        self.cast_rec.recompute()
        self.mul.recompute()
        self.reshape_rec.recompute()

    def construct(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                  value_past=None, batch_valid_length=None):
        """Forward process of the MultiQueryAttention.

        Returns:
            Tuple ``(output, layer_present)``.

            - ``output`` has the shape and dtype of the incoming ``query_tensor``.
            - ``layer_present`` is the ``(key, value)`` pair to be cached. On the non-flash path
              the key is laid out as ``(batch, 1, size_per_head, seq)`` and the value as
              ``(batch, 1, seq, size_per_head)``.
        """
        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
                           value_past, batch_valid_length)
        ori_shape = self.shape(query_tensor)
        batch_size = self._get_batch_size_from_query(query_tensor)
        query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
                                                                            key_tensor,
                                                                            value_tensor)
        ori_dtype = self.get_dtype(query_tensor)
        query_tensor = self.cast(query_tensor, self.dtype)
        key_tensor = self.cast(key_tensor, self.dtype)
        value_tensor = self.cast(value_tensor, self.dtype)
        # Multi-query projections: the query keeps n_head heads, key/value get kv_heads (= 1) head.
        query = self.dense1(query_tensor)
        key = self.dense2(key_tensor)
        value = self.dense3(value_tensor)
        query_len = self._get_seq_length_under_incremental(self.src_seq_length)
        kv_len = self._get_seq_length_under_incremental(self.tgt_seq_length)
        # query -> (batch, n_head, query_len, size_per_head)
        query = self.transpose(
            self.reshape(query, (batch_size, query_len, self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        key = self.reshape(key, (batch_size, kv_len, self.kv_heads, self.size_per_head))
        if self.flash_attention_flag:
            # Same layout as query: (batch, 1, kv_len, size_per_head).
            key = self.transpose_one_head(key, (0, 2, 1, 3))
        else:
            # Pre-transposed for query @ key in _attn: (batch, 1, size_per_head, kv_len).
            key = self.transpose_one_head(key, (0, 2, 3, 1))
        # value -> (batch, 1, kv_len, size_per_head)
        value = self.transpose_one_head(
            self.reshape(value, (batch_size, kv_len, self.kv_heads, self.size_per_head)),
            (0, 2, 1, 3))
        # _attn works with a 4-D mask, the flash-attention path with a 3-D one.
        if attention_mask is not None and not self.flash_attention_flag and len(self.shape(attention_mask)) == 3:
            attention_mask = self.expand_dims(attention_mask, 1)
        if attention_mask is not None and self.flash_attention_flag and len(self.shape(attention_mask)) == 4:
            attention_mask = self.squeeze(attention_mask)
        key_present = key
        value_present = value
        if self.use_past:
            if self.is_first_iteration:
                # Keep only positions where self.range < batch_valid_length
                # (self.range is assumed to enumerate sequence positions).
                valid_length_vector = self.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype)
                key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
                value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
            else:
                # One-hot at position (batch_valid_length - 1): write the current step's key/value
                # into that slot and add it to the cached past.
                valid_length = batch_valid_length - 1
                valid_length = self.reshape(valid_length, (-1, 1, 1))
                valid_length_vector = self.cast(self.equal(valid_length, self.range), self.dtype)
                current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
                                        self.expand_dims(valid_length_vector, 2))
                current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
                                          self.expand_dims(valid_length_vector, 3))
                key = self.add(key_past, current_key)
                value = self.add(value_past, current_value)
                key_present = key
                value_present = value
                # On the _attn path this mask is replaced by one rebuilt from the cached key.
                attention_mask = self.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
        layer_present = (key_present, value_present)
        if self.flash_attention_flag:
            # Expand the shared key/value head to all n_head query heads for the fused kernel.
            key = self.tile_for_batch_matmul(key, (1, self.n_head, 1, 1))
            value = self.tile_for_batch_matmul(value, (1, self.n_head, 1, 1))
            attention = self.flash_attention(query, key, value, attention_mask)
            attention = self._merge_heads(attention)
        else:
            attention = self._attn(query, key, value, attention_mask)
        output = self.projection(attention)
        output = self.dropout(output)
        output = self.reshape(output, ori_shape)
        output = self.cast(output, ori_dtype)
        return output, layer_present

    def _softmax(self, attention_scores):
        """
        For the consideration of the performance, do softmax according to different situations.

        - Off Ascend, or on Ascend with a float16 softmax dtype, ``self.softmax`` is applied directly.
        - Otherwise the scores are flattened to ``(shape[0], -1, shape[-1])``, passed through
          ``self.softmax_3d`` and reshaped back.

        :param attention_scores: the attention scores before softmax (4-D when called from ``_attn``).
        :return: the attention probabilities, with the same shape as ``attention_scores``.
        """
        if self._is_ascend and self.softmax_dtype == mstype.float16 or not self._is_ascend:
            attention_probs = self.softmax(attention_scores)
        else:
            shape = self.shape(attention_scores)
            attention_probs = self.softmax_3d(
                self.reshape_rec(attention_scores,
                                 (shape[0], -1, shape[-1])))
            attention_probs = self.reshape_rec(attention_probs, shape)
        return attention_probs

    def _attn(self, query, key, value, attention_mask):
        """
        Get the weighted score along the seq_length.

        Inputs:
            query: the query matrix, laid out by ``construct`` as
                (batch_size, n_head, query_len, size_per_head).
            key: the single-head key, pre-transposed by ``construct`` to
                (batch_size, 1, size_per_head, kv_len).
            value: the single-head value, (batch_size, 1, kv_len, size_per_head).
            attention_mask: the attention mask matrix with shape (batch_size,
                1, seq_length, seq_length). During incremental steps after the first it is
                rebuilt here from the non-zero entries of the cached key.
        Outputs:
            weighted_values: Tensor, the weighted sum scores with the heads merged.
        """
        # Query and key are each divided by scale_factor before the product.
        factor = self.cast(self.scale_factor, self.get_dtype(query))
        query = self.real_div(query, factor)
        key = self.real_div_one_head(key, factor)
        query = self.cast(query, self.compute_dtype)
        key = self.cast(key, self.compute_dtype)
        score = self.batch_matmul(query, key)
        ori_dtype = self.get_dtype(score)
        attention_scores = self.cast_rec(score, self.softmax_dtype)
        if attention_mask is not None:
            if self.use_past and not self.is_first_iteration:
                # Count the filled cache slots (non-zero entries) to find the current position.
                bs, *_ = self.shape(query)
                tmp = self.not_equal(self.slice(key, (0, 0, 0, 0), (bs, 1, 1, self.seq_length), (1, 1, 1, 1)), 0)
                current_index = self.reducesum(self.cast(tmp, mstype.float32), (1, 2, 3))
                index = self.sub1(self.cast(current_index, mstype.int32), 1)
                index = self.reshape(index, (-1, 1, 1))
                attention_mask = self.cast(self.tensor_le(self.range, index), mstype.int32)
                attention_mask = self.expand_dims(attention_mask, 2)
            # (1 - mask) * multiply_data is added to the scores at masked-out positions.
            inverted_mask = self.sub(
                self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                self.cast_rec(attention_mask, self.get_dtype(attention_scores)))
            adder = self.mul(inverted_mask, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self._softmax(attention_scores)
        attention_probs = self.cast_rec(attention_probs, ori_dtype)
        attention_probs = self.prob_dropout(attention_probs)
        attention_probs = self.cast(attention_probs, self.compute_dtype)
        value = self.cast(value, self.compute_dtype)
        weighted_values = self.batch_matmul(attention_probs, value)
        attention_merge = self._merge_heads(weighted_values)
        return attention_merge
class WizardCoderTransformerDecoderLayer(TransformerEncoderLayer):
    r"""WizardCoder Transformer Decoder Layer.

    A :class:`TransformerEncoderLayer` whose attention block is replaced by
    :class:`MultiQueryAttention` (a single shared key/value head). When ``use_past`` is set, its
    key/value cache is sized for that single head.

    Args:
        batch_size(int): The batch size of the input tensor when doing incremental prediction.
            Should be a positive value. When doing training or prediction, the argument will not
            work and the user can just pass None to the argument.
        hidden_size(int): The hidden size of the input.
        ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
        num_heads(int): The number of query heads.
        seq_length(int): The input sequence length.
        compute_dtype(dtype.Number): The computation type of the dense layers and attention
            matmuls. Required.
        layernorm_compute_type(dtype.Number): The parameter type of the layernorms.
            Should be mstype.float32 or mstype.float16. Required.
        softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
            Should be mstype.float32 or mstype.float16. Required.
        param_init_type(dtype.Number): The parameter initialization type of the module.
            Should be mstype.float32 or mstype.float16. Required.
        attention_dropout_rate(float): The dropout rate of the attention scores. Default: 0.1.
        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
        post_layernorm_residual(bool): Do residual adds before the layernorm. Default: False.
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer.
            - Supports 'relu', 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu',
              'leakyrelu', 'hswish', 'hsigmoid', 'logsigmoid' and so on.
            - The user can provide a custom activation. To run the net in parallel mode, the custom
              activation must also provide the `activation_shard` function; see the examples of
              class:`mindformers.modules.transformer.FeedForward`.
            Default: 'gelu'.
        use_past(bool): Use the past state to compute, used for incremental prediction.
            - For example, with two words given and ten more to generate, the two words' state is
              computed only once and then the next word is generated one by one.
            - First step: set is_first_iteration to True with
              `model.add_flags_recursive(is_first_iteration=True)` and pass the full inputs.
            - Following steps: set is_first_iteration to False with
              `model.add_flags_recursive(is_first_iteration=False)`, pass the single step's input
              tensor, and loop.
            Default: False.
        use_seq_parallel(bool): Shard the residual add, the layernorms and (without MoE) the
            feedforward output projection along the token axis over ``dp * mp`` devices, outside
            AUTO_PARALLEL mode. Also forwarded to :class:`MultiQueryAttention`. Default: False.
        use_flash_attention(bool): Forwarded to :class:`MultiQueryAttention`. Default: True.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance
            of MoEConfig with default values. Please see `MoEConfig`.
        parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration.
            MoEParallelConfig is effective when MoE is applied, otherwise OpParallelConfig is
            effective. Default: `default_dpmp_config`, an instance of `OpParallelConfig` with
            default args.

    Inputs:
        - **x** (Tensor) - Float Tensor.
          If use_past is False or is_first_iteration=True, the shape should be
          [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size].
          Otherwise it should be [batch_size, 1, hidden_size].
        - **input_mask** (Tensor) - Float Tensor.
          If use_past is False or is_first_iteration=True, the attention mask matrix should be
          [batch_size, seq_length, seq_length], or None. None means there will be no mask in the
          softmax computation. In later incremental steps the attention layer rebuilds the mask
          from the key cache.
        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and
          past value parameters used in incremental prediction. Only valid when use_past is True.
          Default: True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] holding the
          number of already-processed positions per sample. Used for incremental prediction when
          use_past is True. Default: None.

    Outputs:
        Tuple, a tuple containing (`output`, `aux_loss`) when MoE is used, otherwise `output` only.

        - **output** (Tensor) - The float tensor of the output of the layer.
          Its shape is (batch_size, seq_length, hidden_size) or
          (batch_size * seq_length, hidden_size) if use_past is False or is_first_iteration=True.
          Otherwise it is (batch_size, 1, hidden_size).
        - The projected key/value of each step are written into the `key_past` / `value_past`
          parameters when use_past is True. These have a single head:
          key (batch_size, 1, size_per_head, seq_length) and
          value (batch_size, 1, seq_length, size_per_head).

    Supported Platforms:
        ``Ascend``
    """

    def __init__(self,
                 batch_size,
                 hidden_size,
                 ffn_hidden_size,
                 num_heads,
                 seq_length,
                 compute_dtype,
                 layernorm_compute_type,
                 softmax_compute_type,
                 param_init_type,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 post_layernorm_residual=False,
                 hidden_act='gelu',
                 use_past=False,
                 use_seq_parallel=False,
                 use_flash_attention=True,
                 moe_config=default_moe_config,
                 parallel_config=default_dpmp_config):
        super(WizardCoderTransformerDecoderLayer, self).__init__(
            batch_size=batch_size,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_heads=num_heads,
            seq_length=seq_length,
            attention_dropout_rate=attention_dropout_rate,
            hidden_dropout_rate=hidden_dropout_rate,
            post_layernorm_residual=post_layernorm_residual,
            layernorm_compute_type=layernorm_compute_type,
            softmax_compute_type=softmax_compute_type,
            param_init_type=param_init_type,
            compute_dtype=compute_dtype,
            hidden_act=hidden_act,
            use_past=use_past,
            moe_config=moe_config,
            parallel_config=parallel_config
        )
        self.is_first_iteration = True
        # NOTE(review): the layernorms are recreated here. They only receive an explicit shard
        # strategy when use_seq_parallel is set; confirm the default strategy is intended otherwise.
        self.layernorm1 = LayerNorm((hidden_size,), param_init_type=layernorm_compute_type)
        self.layernorm2 = LayerNorm((hidden_size,), param_init_type=layernorm_compute_type)
        dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
        if _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
            if use_seq_parallel:
                # Split the token axis over all dp * mp devices for the element-wise ops.
                self.add.shard(((dp * mp, 1), (dp * mp, 1)))
                self.layernorm1.shard(((dp * mp, 1),))
                self.layernorm2.shard(((dp * mp, 1),))
                if not self.use_moe:
                    self.output.projection.shard(
                        strategy_bias=((dp * mp, 1), (1,)),
                        strategy_matmul=((dp, mp), (mp, 1)),
                        out_strategy_matmul=((dp * mp, 1),))
                    self.output.dropout.dropout.shard(((dp * mp, 1),))
                    # Only the non-MoE feedforward is addressed through output.projection here.
                    self.output.projection.matmul.add_prim_attr("recompute_comm_op", True)
                self.layernorm1.layer_norm.add_prim_attr("recompute_comm_op", True)
                self.layernorm2.layer_norm.add_prim_attr("recompute_comm_op", True)
        attention_parallel_config = parallel_config.dpmp if self.use_moe else parallel_config
        self.attention = MultiQueryAttention(batch_size=batch_size,
                                             src_seq_length=seq_length,
                                             tgt_seq_length=seq_length,
                                             hidden_size=hidden_size,
                                             num_heads=num_heads,
                                             hidden_dropout_rate=hidden_dropout_rate,
                                             attention_dropout_rate=attention_dropout_rate,
                                             compute_dtype=compute_dtype,
                                             softmax_compute_type=softmax_compute_type,
                                             param_init_type=param_init_type,
                                             use_past=use_past,
                                             use_seq_parallel=use_seq_parallel,
                                             use_flash_attention=use_flash_attention,
                                             parallel_config=attention_parallel_config)
        self.dtype = compute_dtype
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()
        self.depend = P.Depend()
        if self.use_past:
            # Key/value caches sized for the single shared head of MultiQueryAttention.
            size_per_head = hidden_size // num_heads
            self.key_shape = (batch_size, 1, size_per_head, seq_length)
            self.value_shape = (batch_size, 1, seq_length, size_per_head)
            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")

    def construct(self, x, input_mask=None, init_reset=True, batch_valid_length=None):
        """forward process"""
        self._check_input(x, input_mask, init_reset, batch_valid_length)
        x_shape = self.shape(x)
        x = self.reshape(x, (-1, x_shape[-1]))
        if self.post_layernorm_residual:
            input_x = x
        else:
            input_x = self.layernorm1(x)
        input_x = self.cast(input_x, self.dtype)
        key_reset = None
        value_reset = None
        if self.use_past and self.is_first_iteration:
            # Multiply the caches by init_reset (cast to the compute dtype) before the first step.
            self.assign(self.key_past, self.mul(self.key_past, self.cast(init_reset, self.dtype)))
            key_reset = self.key_past
            self.assign(self.value_past, self.mul(self.value_past, self.cast(init_reset, self.dtype)))
            value_reset = self.value_past
            # Make the attention input depend on the reset so the reset runs first.
            input_x = self.depend(input_x, key_reset)
            input_x = self.depend(input_x, value_reset)
        attention, layer_present = self.attention(input_x, input_x, input_x, input_mask,
                                                  self.key_past, self.value_past, batch_valid_length)
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        else:
            x = self.cast(x, self.dtype)
            x = self.add(x, attention)
        output_x = self.layernorm2(x)
        output_x = self.cast(output_x, self.dtype)
        aux_loss = None
        if self.use_moe:
            mlp_logit, aux_loss = self.output(output_x)
        else:
            mlp_logit = self.output(output_x)
        value_update = None
        key_update = None
        if self.use_past:
            # Store this step's key/value into the caches.
            key_present, value_present = layer_present
            self.assign(self.key_past, key_present)
            key_update = self.key_past
            self.assign(self.value_past, value_present)
            value_update = self.value_past
            key_update = self.depend(key_update, key_reset)
            value_update = self.depend(value_update, value_reset)
        # Tie the cache updates into the output so they are not pruned from the graph.
        mlp_logit = self.depend(mlp_logit, value_update)
        mlp_logit = self.depend(mlp_logit, key_update)
        if len(x_shape) == 3:
            output_x = self.reshape(output_x, x_shape)
            mlp_logit = self.reshape(mlp_logit, x_shape)
            x = self.reshape(x, x_shape)
            if self.post_layernorm_residual:
                output = self.add_3d(output_x, mlp_logit)
                output = self.reshape(output, (-1, x_shape[-1]))
                output = self.layernorm1(output)
                output = self.reshape(output, x_shape)
            else:
                output = self.add_3d(x, mlp_logit)
        else:
            if self.post_layernorm_residual:
                output = self.add(output_x, mlp_logit)
                output = self.layernorm1(output)
            else:
                output = self.add(x, mlp_logit)
            output = self.reshape(output, x_shape)
        if self.use_moe:
            return output, aux_loss
        return output