"""YiZhao Modules."""
from typing import Tuple
import numpy as np
from mindspore import nn, Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from yizhao_config import YiZhaoConfig
from mindformers.modules.layers import Linear
from mindformers.version_control import check_rmsnorm_big_kernel_valid, check_valid_big_kernel
class YiZhaoFreqsMgr(nn.Cell):
    r"""Manager for precomputed rotary-embedding cos/sin tables.

    At construction time this builds ``freqs_cos`` / ``freqs_sin`` tables of shape
    ``(seq_length, 2 * dim)`` and a ``(2 * dim, 2 * dim)`` swap matrix. It then serves
    them for full-sequence (``construct``), prefill and incremental use.

    Args:
        dim (int): Rotary dimension; ``dim // 2`` inverse frequencies are generated.
            The stored tables are ``2 * dim`` wide: the extra half holds cos = 1 and
            sin = 0 padding.
        seq_length (int): Number of positions to precompute. It must be given even
            though the default is ``None`` (``np.arange(None)`` would fail).
        rotary_dtype: dtype of the stored tables and swap mask. Default mstype.float16.
        base (float): Rotary base. Default 10000.
        rope_ratio (float): Multiplier applied to ``base``. Default 1.0.
    """
    def __init__(self,
                 dim,
                 seq_length=None,
                 rotary_dtype=mstype.float16,
                 base=10000,
                 rope_ratio=1.0
                 ):
        super().__init__()
        self.reshape = P.Reshape()
        base = base * rope_ratio
        # dim // 2 inverse frequencies: base ** (-2i / dim).
        theta = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
        seq_idx = np.arange(seq_length, dtype=np.float32)
        # idx_theta[p, i] = p * theta[i], shape (seq_length, dim // 2).
        idx_theta = np.outer(seq_idx, theta).astype(np.float32)
        # Duplicate every angle so the columns read theta_0, theta_0, theta_1, theta_1, ...
        freqs = np.expand_dims(idx_theta, 2)
        emb = np.concatenate((freqs, freqs), axis=-1)
        emb = emb.reshape(seq_length, dim)
        # Pad to width 2 * dim with cos = 1 / sin = 0 columns. In the
        # x * cos + swap(x) * sin formulation, those positions pass through unchanged.
        freqs_cos = np.concatenate((np.cos(emb), np.ones_like(emb)), axis=-1)
        freqs_sin = np.concatenate((np.sin(emb), np.zeros_like(emb)), axis=-1)
        swap_mask = YiZhaoFreqsMgr.get_swap_mask_llama(dim * 2)
        self.seq_length = seq_length
        self.dim = dim

        def rearange(w):
            """Reorder the last axis as [even-indexed columns..., odd-indexed columns...].

            Args:
                w: weight
            """
            w = np.concatenate(
                [
                    w[..., 0::2],
                    w[..., 1::2],
                ],
                axis=-1
            )
            return w

        freqs_cos = rearange(freqs_cos)
        freqs_sin = rearange(freqs_sin)
        # NOTE: despite the name, head_dim equals the rotary `dim`; tables are 2 * dim wide.
        self.head_dim = dim
        self.freqs_cos = Tensor(freqs_cos, dtype=rotary_dtype)
        self.freqs_sin = Tensor(freqs_sin, dtype=rotary_dtype)
        self.swap_mask = Tensor(swap_mask, dtype=rotary_dtype)
        self.slice = P.StridedSlice().shard(((1, 1),))
        self.gather = P.Gather().shard(((1, 1), (1,)))
        self.tile = P.Tile().shard(((1, 1),))

    def construct(self, seq_length):
        """Return the first ``seq_length`` rows of the tables and the swap mask.

        Args:
            seq_length (int): Number of leading positions to return.

        Returns:
            Tuple of ``freqs_cos`` and ``freqs_sin``, each of shape
            ``(seq_length, 1, 2 * dim)``, plus the ``(2 * dim, 2 * dim)`` swap mask.
        """
        freqs_cos = self.slice(self.freqs_cos, (0, 0), (seq_length, self.head_dim * 2), (1, 1))
        freqs_sin = self.slice(self.freqs_sin, (0, 0), (seq_length, self.head_dim * 2), (1, 1))
        # Reshape with the requested length, not self.seq_length. The slice above keeps
        # only `seq_length` rows, so the precomputed length would make the reshape
        # fail whenever a shorter sequence is requested.
        freqs_cos = self.reshape(freqs_cos, (seq_length, 1, 2 * self.dim))
        freqs_sin = self.reshape(freqs_sin, (seq_length, 1, 2 * self.dim))
        return freqs_cos, freqs_sin, self.swap_mask

    def prefill(self):
        """Return the full ``(seq_length, 2 * dim)`` cos/sin tables and the swap mask."""
        return self.freqs_cos, self.freqs_sin, self.swap_mask

    def increment(self, batch_valid_length):
        """Gather the table rows at position ``batch_valid_length - 1``.

        Args:
            batch_valid_length (Tensor): Integer index tensor. Presumably this holds the
                current valid length of each sequence — confirm against callers.

        Returns:
            Gathered ``freqs_cos`` and ``freqs_sin`` rows, plus the swap mask.
        """
        indices = batch_valid_length - 1
        freqs_cos = self.gather(self.freqs_cos, indices, 0)
        freqs_sin = self.gather(self.freqs_sin, indices, 0)
        return freqs_cos, freqs_sin, self.swap_mask

    @staticmethod
    def get_swap_mask_llama(head_dim):
        """Build the ``(head_dim, head_dim)`` block matrix ``[[0, I], [-I, 0]]``.

        For a row vector ``x = [x1, x2]`` split in halves, ``x @ mask == [-x2, x1]``.
        """
        zero_block = np.zeros((head_dim // 2, head_dim // 2), dtype=np.float32)
        id_block = np.identity(head_dim // 2, dtype=np.float32)
        return np.block([[zero_block, id_block], [-id_block, zero_block]])
class YiZhaoRMSNorm(nn.Cell):
    r"""
    A self-defined RMSNorm operation using reduce mean.

    Computes ``x * rsqrt(mean(x^2, axis=-1) + eps) * weight``. When
    ``check_rmsnorm_big_kernel_valid()`` is true, the fused ``RmsNorm`` primitive
    is used. Otherwise the expression is composed from Square, ReduceMean, Add,
    Rsqrt and Mul.

    Args:
        dim (int): Size of the normalized (last) axis; the weight has shape ``(dim,)``.
        eps (float): The epsilon value added to the mean square. Default 1e-6.
        param_init_type: dtype of the weight parameter. Default mstype.float32.
        compute_type: dtype used to compute the normalization statistics.
            Default mstype.float32.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

    Outputs:
        Tensor of shape :math:`(batch, seq\_length, hidden\_size)`, with the dtype of ``x``.
    """
    def __init__(self, dim, eps=1e-6, param_init_type=mstype.float32, compute_type=mstype.float32):
        super(YiZhaoRMSNorm, self).__init__()
        # eps is added to the mean square, which is computed in compute_type, so it is
        # stored in that dtype. Storing it in param_init_type (e.g. float16) would round
        # it (1e-6 is subnormal in float16) and turn the Add into a mixed-dtype op.
        self.eps = Tensor(float(eps), dtype=compute_type)
        if param_init_type == mstype.bfloat16:
            # Initialize in float32, then convert the ones to bfloat16.
            self.weight = Parameter(initializer('ones', (dim,), dtype=mstype.float32).astype(mstype.bfloat16))
        else:
            self.weight = Parameter(initializer('ones', (dim,), dtype=param_init_type))
        self.compute_type = compute_type
        if not check_rmsnorm_big_kernel_valid():
            # Composite fallback built from elementary primitives.
            self.square = P.Square()
            self.mean = P.ReduceMean(keep_dims=True)
            self.add = P.Add()
            self.rsqrt = P.Rsqrt()
            self.mul = P.Mul()
            self.mul2 = P.Mul()
            self.rms_norm = self._self_norm
            self.self_define = True
        else:
            self.norm = P.RmsNorm(float(eps))
            self.rms_norm = self._rms_norm
            self.self_define = False
        self.cast = P.Cast()
        self.rcast = P.Cast()

    def _self_norm(self, x):
        """Composite RMSNorm: statistics in compute_type, scaling in x's dtype."""
        original_type = x.dtype
        norm_factor = self.square(self.cast(x, self.compute_type))
        norm_factor = self.mean(norm_factor, -1)
        norm_factor = self.add(norm_factor, self.eps)
        norm_factor = self.rsqrt(norm_factor)
        output = self.mul(x, self.cast(norm_factor, original_type))
        output = self.mul2(output, self.cast(self.weight, original_type))
        return output

    def _rms_norm(self, x):
        """Fused RMSNorm in compute_type, cast back to x's dtype."""
        original_type = x.dtype
        x = self.cast(x, self.compute_type)
        # RmsNorm returns a tuple; element 0 is the normalized tensor.
        output = self.norm(x, self.cast(self.weight, self.compute_type))[0]
        return self.rcast(output, original_type)

    def construct(self, x):
        """Forward of RMSNorm."""
        return self.rms_norm(x)

    def shard(self, strategy):
        """Parallel strategy configuration interface.

        Args:
            strategy (tuple): Tuple whose first element is the input's layout.
        """
        if self.self_define:
            self.square.shard(strategy)
            self.mean.shard(strategy)
            self.rsqrt.shard(strategy)
            # eps is a scalar, hence the empty strategy for Add's second input.
            self.add.shard((strategy[0], ()))
            self.mul.shard((strategy[0], strategy[0]))
            self.mul2.shard((strategy[0], (1,)))
        else:
            self.norm.shard((strategy[0], (1,)))
class YiZhaoSiLU(nn.Cell):
    r"""
    SiLU activation, ``silu(x) = x * sigmoid(x)``.

    When ``check_valid_big_kernel()`` is true, the fused inner ``SiLU`` primitive is
    used. Otherwise the activation is composed from ``Sigmoid`` followed by ``Mul``.

    Inputs:
        - **x** (Tensor) - Tensor.

    Outputs:
        Tensor. x * sigmoid(x).
    """
    def __init__(self):
        super(YiZhaoSiLU, self).__init__()
        if not check_valid_big_kernel():
            # Composite fallback: x * sigmoid(x).
            self.sigmoid = P.Sigmoid()
            self.mul = P.Mul()
            self.silu = self._self_silu
            self.self_define = True
            return
        self.silu = P._inner_ops.SiLU()
        self.self_define = False

    def shard(self, strategy):
        """Apply ``strategy`` to whichever SiLU implementation is in use."""
        if not self.self_define:
            self.silu.shard(strategy)
            return
        self.sigmoid.shard(strategy)
        input_layout = strategy[0]
        self.mul.shard((input_layout, input_layout))

    def _self_silu(self, x):
        """Composite SiLU: multiply x by its sigmoid."""
        gate = self.sigmoid(x)
        return self.mul(x, gate)

    def construct(self, x):
        """Apply SiLU to x."""
        return self.silu(x)
class YiZhaoSwiGLU4ConcatMLP(nn.Cell):
    """SwiGLU over a last-axis concatenation ``[gate, up]``: ``silu(gate) * up``."""
    def __init__(self):
        super(YiZhaoSwiGLU4ConcatMLP, self).__init__()
        self.split = P.Split(axis=-1, output_num=2)
        self.silu = YiZhaoSiLU()
        self.mul = P.Mul()

    def construct(self, x):
        """Split x in half along the last axis and gate the second half with the first."""
        gate, up = self.split(x)
        activated = self.silu(gate)
        return self.mul(activated, up)

    def shard(self, strategy):
        """Apply ``strategy`` to the split, SiLU and Mul primitives."""
        layout = strategy[0]
        self.split.shard(strategy)
        self.silu.shard(strategy)
        self.mul.shard((layout, layout))
class YiZhaoSwiGLU(nn.Cell):
    """SwiGLU on two separate inputs: ``silu(left) * right``."""
    def __init__(self):
        super(YiZhaoSwiGLU, self).__init__()
        self.silu = YiZhaoSiLU()
        self.mul = P.Mul()

    def construct(self, left, right):
        """Gate ``right`` with ``silu(left)``."""
        activated = self.silu(left)
        return self.mul(activated, right)

    def shard(self, strategy):
        """Shard SiLU with ``strategy``; Mul gets the same strategy for both inputs."""
        self.silu.shard(strategy)
        both_inputs = strategy + strategy
        self.mul.shard(both_inputs)
class YiZhaoConcatMLP(nn.Cell):
    """Feed-forward block with a fused (concatenated) gate/up projection.

    ``hidden_size -> 2 * ffn_hidden_size`` in one Linear. ``YiZhaoSwiGLU4ConcatMLP``
    then splits that output in half and computes ``silu(first) * second``. A final
    Linear projects ``ffn_hidden_size -> hidden_size``.
    """
    def __init__(self, config: YiZhaoConfig):
        super(YiZhaoConcatMLP, self).__init__()
        parallel = config.parallel_config
        dp = parallel.data_parallel
        cp = parallel.context_parallel
        mp = parallel.model_parallel
        rows = dp * cp
        self.add_bias = config.add_bias_linear
        linear_kwargs = dict(
            has_bias=self.add_bias,
            param_init_type=config.param_init_type,
            compute_dtype=config.compute_dtype,
        )

        self.dense_h_to_4h = Linear(config.hidden_size, config.ffn_hidden_size * 2, **linear_kwargs)
        self.dense_h_to_4h.shard(strategy_matmul=((rows, 1), (mp, 1)),
                                 strategy_bias=((rows, mp), (mp,)))

        self.activation_func = YiZhaoSwiGLU4ConcatMLP()
        self.activation_func.shard(((dp, cp, 1),))

        self.dense_4h_to_h = Linear(config.ffn_hidden_size, config.hidden_size, **linear_kwargs)
        self.dense_4h_to_h.shard(strategy_matmul=((rows, mp), (1, mp)),
                                 strategy_bias=((rows, 1), (1,)))

    def construct(self, hidden_states):
        """Project up, apply SwiGLU, project back down."""
        return self.dense_4h_to_h(self.activation_func(self.dense_h_to_4h(hidden_states)))
class YiZhaoMLP(nn.Cell):
    """Feed-forward block with separate gate ("left") and up ("right") projections.

    ``silu(dense_left(x)) * dense_right(x)`` is computed at width ``ffn_hidden_size``.
    The result is then projected back to ``hidden_size`` by ``dense_4h_to_h``.
    """
    def __init__(self, config: YiZhaoConfig):
        super(YiZhaoMLP, self).__init__()
        self.add_bias = config.add_bias_linear

        def make_linear(in_features, out_features):
            # Every projection in this block shares the bias/dtype settings.
            return Linear(
                in_features,
                out_features,
                has_bias=self.add_bias,
                param_init_type=config.param_init_type,
                compute_dtype=config.compute_dtype,
            )

        self.dense_left = make_linear(config.hidden_size, config.ffn_hidden_size)
        self.dense_right = make_linear(config.hidden_size, config.ffn_hidden_size)

        parallel = config.parallel_config
        dp, cp, mp = parallel.data_parallel, parallel.context_parallel, parallel.model_parallel
        rows = dp * cp
        up_matmul = ((rows, 1), (mp, 1))
        up_bias = ((rows, mp), (mp,))
        for proj in (self.dense_left, self.dense_right):
            proj.shard(strategy_matmul=up_matmul, strategy_bias=up_bias)

        self.activation_func = YiZhaoSwiGLU()
        self.activation_func.shard(((dp, cp, mp),))

        self.dense_4h_to_h = make_linear(config.ffn_hidden_size, config.hidden_size)
        self.dense_4h_to_h.shard(strategy_matmul=((rows, mp), (1, mp)), strategy_bias=((rows, 1), (1,)))

    def construct(self, hidden_states):
        """Gate/up projections, SwiGLU, then down projection."""
        gate = self.dense_left(hidden_states)
        up = self.dense_right(hidden_states)
        activated = self.activation_func(gate, up)
        return self.dense_4h_to_h(activated)
class YiZhaoRotaryEmbedding(nn.Cell):
    """Apply rotary position embedding to query and key.

    Each input ``t`` becomes ``t * freqs_cos + (t @ swap_mask) * freqs_sin``. The
    arithmetic is done in ``compute_dtype``, and both results are cast back to the
    original dtype of ``query``.

    Args:
        compute_dtype: dtype used for the rotation arithmetic. Default mstype.float32.
    """
    def __init__(self, compute_dtype=mstype.float32):
        super(YiZhaoRotaryEmbedding, self).__init__()
        self.mul = P.Mul()
        self.bmm_swap = P.BatchMatMul()
        self.add = P.Add()
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.dtype = compute_dtype

    def _rotate(self, x, freqs_cos, freqs_sin, swap_mask):
        """Return ``x * cos + (x @ swap_mask) * sin``."""
        swapped = self.bmm_swap(x, swap_mask)
        return self.add(self.mul(x, freqs_cos), self.mul(swapped, freqs_sin))

    def construct(self,
                  query: Tensor,
                  key: Tensor,
                  rotary_pos_emb: Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """Rotate query and key with ``(freqs_cos, freqs_sin, swap_mask)``."""
        freqs_cos, freqs_sin, swap_mask = rotary_pos_emb
        out_dtype = query.dtype
        rotated_query = self._rotate(self.cast(query, self.dtype), freqs_cos, freqs_sin, swap_mask)
        rotated_key = self._rotate(self.cast(key, self.dtype), freqs_cos, freqs_sin, swap_mask)
        return self.cast(rotated_query, out_dtype), self.cast(rotated_key, out_dtype)

    def shard(self, strategy):
        """Shard with ``strategy`` for the activations; freqs and swap mask are replicated."""
        replicated = (1, 1)
        self.add.shard((strategy, strategy))
        self.bmm_swap.shard((strategy, replicated))
        self.mul.shard((strategy, replicated))
class YiZhaoRotaryEmbeddingOpt(nn.Cell):
    """Rotary position embedding for 4-D query/key, with cos/sin reshaped for broadcasting.

    ``query`` is unpacked as four dimensions. The cos/sin tables are reshaped to
    ``(query.shape[1], 1, query.shape[3])`` before computing
    ``t * cos + (t @ swap_mask) * sin`` in ``compute_dtype``. Both results are cast
    back to the original dtype of ``query``.

    Args:
        compute_dtype: dtype used for the rotation arithmetic. Default mstype.float32.
    """
    def __init__(self, compute_dtype=mstype.float32):
        super(YiZhaoRotaryEmbeddingOpt, self).__init__()
        self.mul = P.Mul()
        self.bmm_swap = P.BatchMatMul()
        self.add = P.Add()
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.dtype = compute_dtype

    def _rotate(self, x, freqs_cos, freqs_sin, swap_mask):
        """Return ``x * cos + (x @ swap_mask) * sin``."""
        swapped = self.bmm_swap(x, swap_mask)
        return self.add(self.mul(x, freqs_cos), self.mul(swapped, freqs_sin))

    def construct(self,
                  query: Tensor,
                  key: Tensor,
                  rotary_pos_emb: Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """Rotate query and key with ``(freqs_cos, freqs_sin, swap_mask)``."""
        freqs_cos, freqs_sin, swap_mask = rotary_pos_emb
        _, seq_len, _, head_dim = query.shape
        out_dtype = query.dtype
        freqs_cos = F.reshape(freqs_cos, (seq_len, 1, head_dim))
        freqs_sin = F.reshape(freqs_sin, (seq_len, 1, head_dim))
        rotated_query = self._rotate(self.cast(query, self.dtype), freqs_cos, freqs_sin, swap_mask)
        rotated_key = self._rotate(self.cast(key, self.dtype), freqs_cos, freqs_sin, swap_mask)
        return self.cast(rotated_query, out_dtype), self.cast(rotated_key, out_dtype)

    def shard(self, strategy):
        """Shard with a 4-element ``strategy``.

        The second element (context parallel) is also applied to the
        ``(seq, 1, head_dim)`` cos/sin tables in Mul.
        """
        _, context_parallel, _, _ = strategy
        self.add.shard((strategy, strategy))
        self.bmm_swap.shard((strategy, (1, 1)))
        self.mul.shard((strategy, (context_parallel, 1, 1)))
class GetCompressMask(nn.Cell):
    """Precomputed square uint8 mask with 1 strictly above the diagonal, 0 elsewhere.

    Presumably 1 marks a masked (future) position — confirm against the attention op.

    Args:
        mask_length (int): Side length of the square mask.
        parallel_config: Accepted for interface compatibility; not used.
    """
    def __init__(self, mask_length, parallel_config):
        super(GetCompressMask, self).__init__()
        self.mask_length = mask_length
        # Strict upper triangle; identical to ones - tril(ones).
        strict_upper = np.triu(np.ones((self.mask_length, self.mask_length), dtype=np.int8), k=1)
        self.attention_mask = Tensor(strict_upper, dtype=mstype.uint8)
        self.cast = P.Cast()

    def construct(self, sequence_start_ids):
        """Return the precomputed mask as uint8; ``sequence_start_ids`` is ignored."""
        return self.cast(self.attention_mask, mstype.uint8)
class GetEodResetMask(nn.Cell):
    """Build a per-sample causal mask that also blocks attention across segments.

    Entry ``[b, i, j]`` is 0 when ``j <= i`` and ``eod_vec[b, i] == eod_vec[b, j]``.
    Otherwise it is 1. The result is uint8.

    Args:
        seq_length (int): Sequence length used to tile the segment ids into a square.
        parallel_config: Supplies ``data_parallel`` for the batch-axis sharding.
    """
    def __init__(self, seq_length, parallel_config):
        super(GetEodResetMask, self).__init__()
        dp = parallel_config.data_parallel
        batch_only_3d = (dp, 1, 1)
        self.seq_length = seq_length
        self.expand_dims = P.ExpandDims().shard(((dp, 1),))
        self.tile = P.Tile().shard((batch_only_3d,))
        self.equal = P.Equal().shard((batch_only_3d, batch_only_3d))
        self.tril_op = P.Tril().shard((batch_only_3d,))
        self.sub = P.Sub().shard(((), batch_only_3d))

    def construct(self, eod_vec):
        """Return ``1 - tril(same_segment)`` as uint8.

        ``eod_vec`` is assumed to be (batch, seq_length) segment ids — TODO confirm
        with the caller.
        """
        ids_as_row = self.expand_dims(eod_vec, 1)
        ids_as_col = self.expand_dims(eod_vec, 2)
        # Broadcast both views to a square so they can be compared elementwise.
        row_square = self.tile(ids_as_row, (1, self.seq_length, 1))
        col_square = self.tile(ids_as_col, (1, 1, self.seq_length))
        same_segment = F.cast(self.equal(row_square, col_square), mstype.uint8)
        visible = self.tril_op(same_segment)
        return self.sub(1, visible)