"""yizhao model."""
import copy
import numpy as np
from mindpet.delta.ptuning2 import PrefixEncoder
from mindspore import Tensor, mint, ops
from mindspore.common import dtype as mstype
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from yizhao_config import YiZhaoConfig
from yizhao_loss import DPOLoss, DPOCrossEntropy
from yizhao_modules import YiZhaoFreqsMgr
from yizhao_transformer import YiZhaoTransformer
from mindformers.core.loss import CrossEntropyLoss
from mindformers.models.llama.llama_layer import LlamaEmbedding
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import lazy_inline
from mindformers.modules.layers import Linear
from mindformers.modules.transformer.transformer import LowerTriangularMaskWithDynamic
from mindformers.pet.tuners.pet_adapter import PetAdapter
from mindformers.tools.logger import logger
from mindformers.tools.register import MindFormerModuleType, MindFormerRegister
from mindformers.tools.utils import get_predict_run_mode
from mindformers.version_control import get_dropout
from mindformers.version_control import get_tril
class YiZhaoPreTrainedModel(PreTrainedModel):
    """
    Abstract base of the YiZhao networks.

    Handles weights initialization and offers a simple interface for downloading and loading
    pretrained models; it binds ``YiZhaoConfig`` as the config class of every subclass.
    """

    base_model_prefix = "YiZhao"
    config_class = YiZhaoConfig
class YiZhaoModel(YiZhaoPreTrainedModel):
    r"""
    YiZhao backbone: token embedding, rotary-frequency manager, transformer encoder and output projection.

    Args:
        config (YiZhaoConfig): The config of network.
    """

    def __init__(self, config: YiZhaoConfig, **kwargs):
        """Create the sub-cells and operators of the backbone and configure their sharding from ``config``."""
        super().__init__(config, **kwargs)
        parallel_config = config.parallel_config
        data_parallel = parallel_config.data_parallel
        context_parallel = parallel_config.context_parallel
        model_parallel = parallel_config.model_parallel
        # Number of shards of the batch/sequence side of the output matmul.
        batch_shards = data_parallel * context_parallel

        # Settings mirrored from the config.
        self.seq_length = config.seq_length
        self.num_layers = config.num_layers
        self.kv_channels = config.kv_channels
        self.multi_query_group_num = config.multi_query_group_num
        self.compute_dtype = config.compute_dtype
        self.use_flash_attention = config.use_flash_attention
        self.use_llama_rope = config.use_llama_rope
        self.use_past = config.use_past
        self.is_first_iteration = True

        self.casual_mask = LowerTriangularMaskWithDynamic(
            seq_length=config.seq_length,
            compute_type=config.compute_dtype,
            is_dynamic=config.is_dynamic,
            pad_token_id=config.pad_token_id,
            use_flash_attention=config.use_flash_attention,
            use_past=config.use_past,
        )
        self.embedding = LlamaEmbedding(
            vocab_table_size=config.vocab_size,
            embedding_size=config.hidden_size,
            param_init_type=config.param_init_type,
            parallel_optimizer=True,
        )

        # Rotary width is kv_channels when configured, otherwise hidden_size // num_attention_heads.
        if config.kv_channels is None:
            rotary_dim = config.hidden_size // config.num_attention_heads
        else:
            rotary_dim = config.kv_channels
        self.freqs_mgr = YiZhaoFreqsMgr(
            dim=rotary_dim // 2,
            seq_length=config.seq_length,
            rotary_dtype=config.rotary_dtype,
            base=10000,
            rope_ratio=config.rope_ratio,
        )
        self.encoder = YiZhaoTransformer(config)
        self.output_layer = Linear(
            config.hidden_size,
            config.vocab_size,
            has_bias=False,
            param_init_type=config.param_init_type,
            compute_dtype=config.compute_dtype,
        )

        # Manual strategies are only set when auto-parallel sharding propagation is not in charge.
        sharding_propagation = _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()
        if not sharding_propagation:
            self.embedding.pipeline_stage = 0
            if parallel_config.pipeline_stage > 1:
                self.embedding.set_comm_fusion(2)
                self.output_layer.pipeline_stage = parallel_config.pipeline_stage - 1
            else:
                self.embedding.set_comm_fusion(parallel_config.gradient_aggregation_group)
            self.embedding.shard(parallel_config)
            if parallel_config.vocab_emb_dp or (config.vocab_size % model_parallel != 0):
                # Vocabulary dimension stays unsplit; only the activation rows are sharded.
                self.output_layer.shard(strategy_matmul=((batch_shards, 1), (1, 1)))
            else:
                self.output_layer.shard(strategy_matmul=((1, 1), (batch_shards * model_parallel, 1)))
                if batch_shards >= 32 and batch_shards % 8 == 0:
                    # Large device counts: weight split 8 ways, the remaining devices split the rows.
                    self.output_layer.shard(strategy_matmul=((batch_shards * model_parallel // 8, 1), (8, 1)))

        # Operators used by get_masks / construct.
        self.tril = get_tril()
        self.ones = P.Ones()
        self.less = P.Less()
        self.gather = P.Gather()
        self.expand_dims = P.ExpandDims()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.tile = ops.Tile()
        lower_ones = np.tril(np.ones((1, self.seq_length, self.seq_length)))
        self.low_triangle = Tensor(lower_ones, mstype.int32)
        self.cast = P.Cast()
        self.dropout = get_dropout(config.hidden_dropout)
        self.dropout.dropout.recompute(False)

    def get_masks(self, batch_size, padding_mask=None, input_position=None):
        """
        Build the attention mask from the stored lower-triangular matrix.

        The triangle is tiled ``batch_size`` times, optionally combined with ``padding_mask``, and turned
        into a boolean mask that is True where the combined value is below 0.5. In incremental inference
        (``use_past`` and not the first iteration) the rows selected by ``input_position`` are gathered.

        Returns:
            Tensor reshaped to ``(batch_size, 1, -1, seq_length)``.
        """
        triangle = self.tile(self.low_triangle, (batch_size, 1, 1))
        has_padding = padding_mask is not None
        if has_padding:
            triangle = self.mul(triangle, self.expand_dims(padding_mask, 1))
        if self.use_past and has_padding:
            triangle -= self.expand_dims(padding_mask, -1) - 1
        masked = self.less(triangle, 0.5)
        if self.use_past and not self.is_first_iteration:
            flat_rows = masked.view(-1, self.seq_length)
            masked = self.gather(flat_rows, input_position, 0)
        return self.reshape(masked, (batch_size, 1, -1, self.seq_length))

    def construct(self, input_ids, input_position=None, position_ids=None, attention_mask=None,
                  input_embeds=None, batch_valid_length=None, full_attention_mask=None, prefix_key_values=None,
                  block_tables=None, slot_mapping=None):
        """
        Run the backbone and return the encoder hidden states.

        With ``use_past`` the rotary frequencies come from the prefill/increment paths of the frequency
        manager and a mask is only produced for the first iteration. Otherwise ``full_attention_mask`` is
        used, falling back to ``attention_mask``. ``position_ids`` and ``input_position`` are accepted but
        not used here.
        """
        _ = position_ids
        _, seq_len = input_ids.shape

        attn_mask = None
        if self.use_past:
            if self.is_first_iteration:
                freqs_cis = self.freqs_mgr.prefill()
                attn_mask = self.casual_mask.prefill()
            else:
                freqs_cis = self.freqs_mgr.increment(batch_valid_length)
        else:
            freqs_cis = self.freqs_mgr(seq_len)
            if full_attention_mask is None:
                attn_mask = attention_mask
            else:
                attn_mask = full_attention_mask

        if input_embeds is None:
            input_embeds = self.embedding(input_ids)
            input_embeds = self.dropout(input_embeds)

        return self.encoder(
            input_embeds,
            attn_mask,
            freqs_cis,
            batch_valid_length=batch_valid_length,
            prefix_key_values=prefix_key_values,
            block_tables=block_tables,
            slot_mapping=slot_mapping,
        )
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class YiZhaoForCausalLM(YiZhaoPreTrainedModel):
    r"""
    YiZhao causal language model: backbone plus loss / logits head.

    Args:
        config (YiZhaoConfig): The config of YiZhaoModel.

    Returns:
        Tensor, the loss or logits of the network.
    """

    @lazy_inline
    def __init__(self, config: YiZhaoConfig, **kwargs):
        """Build the backbone, the cross-entropy loss with its parallel config, and helper operators."""
        super().__init__(config, **kwargs)
        self.transformer = YiZhaoModel(config=config)
        self.cast = P.Cast()
        self.gather = P.Gather()

        parallel_config = config.parallel_config
        data_parallel = parallel_config.data_parallel
        context_parallel = parallel_config.context_parallel
        model_parallel = parallel_config.model_parallel
        if parallel_config.vocab_emb_dp or (config.vocab_size % model_parallel != 0):
            loss_parallel_config = parallel_config
        else:
            # The loss works on a private copy so the shared parallel config is left untouched.
            all_shards = data_parallel * model_parallel * context_parallel
            batch_shards = data_parallel * context_parallel
            if batch_shards >= 32 and batch_shards % 8 == 0:
                loss_model_parallel, loss_data_parallel = 8, all_shards // 8
            else:
                loss_model_parallel, loss_data_parallel = all_shards, 1
            loss_parallel_config = copy.deepcopy(parallel_config)
            loss_parallel_config.model_parallel = loss_model_parallel
            loss_parallel_config.data_parallel = loss_data_parallel
            loss_parallel_config.context_parallel = 1
        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config)

        self.use_past = config.use_past
        self.is_first_iteration = True
        self.rl_config = config.rl_config
        self.gmask = config.gmask_token_id
        self.bos_token_id = config.bos_token_id
        self.not_equal = P.NotEqual()
        self.add = P.Add()
        self.reshape = P.Reshape()
        self.load_checkpoint(config)
        self.vocab_size = config.padded_vocab_size
        self.predict_run_mode = get_predict_run_mode()
        self.sub_batch_valid_len = P.Sub()

    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
        """
        Get YiZhao model input tuple for transform ckpt.

        The tuple has one entry per ``construct`` parameter; only ``input_ids``, ``labels`` (when given in
        ``kwargs``), ``batch_valid_length`` and ``slot_mapping`` are filled, the latter two with ones of
        length ``batch``.
        """
        input_ids = Tensor(input_ids, mstype.int32)
        if "labels" in kwargs:
            labels = Tensor(kwargs["labels"])
        else:
            labels = None
        batch = input_ids.shape[0]
        slot_mapping = Tensor(np.ones(shape=(batch,)), mstype.int32)
        batch_valid_length = Tensor(np.ones(shape=(batch,)), mstype.int32)
        return (input_ids, labels, None, None, None, None, None, None, batch_valid_length,
                None, None, slot_mapping, None, None)

    def set_dynamic_inputs(self, **kwargs):
        """Register placeholder tensors with unknown dimensions as the network inputs via ``set_inputs``."""
        int32 = mstype.int32
        ids_placeholder = Tensor(shape=[None, None], dtype=int32)
        valid_length_placeholder = Tensor(shape=[None, None], dtype=int32)
        block_tables_placeholder = Tensor(shape=[None, None], dtype=int32)
        slot_mapping_placeholder = Tensor(shape=[None], dtype=int32)
        # Positional layout follows the parameter order of construct().
        self.set_inputs(ids_placeholder, None, None, None, None, None, None, None,
                        valid_length_placeholder, None, block_tables_placeholder, slot_mapping_placeholder,
                        None, None)
        logger.info("Set dynamic input for YiZhao.")

    def add_flags_custom(self, is_first_iteration):
        """Add customized attributes for specific cells in the model."""
        flag = is_first_iteration
        self.add_flags(is_first_iteration=flag)
        self.transformer.add_flags(is_first_iteration=flag)
        for block in self.transformer.encoder.layers:
            block.add_flags(is_first_iteration=flag)
            infer_attention = block.self_attention.infer_attention
            infer_attention.add_flags(is_first_iteration=flag)
            infer_attention.rotary_embedding.add_flags(is_first_iteration=flag)

    def construct(self, input_ids=None, labels=None, attention_mask=None, input_mask=None, input_position=None,
                  position_ids=None, input_embeds=None, init_reset=None, batch_valid_length=None,
                  prefix_key_values=None,
                  block_tables=None, slot_mapping=None, batch_index=None, zactivate_len=None):
        """
        YiZhao for conditional generation model.

        Returns the raw logits when ``rl_config`` is set, the loss when training with ``labels``, the
        padded ``(logits, labels, input_mask)`` triple when evaluating with ``labels``, and a one-element
        tuple of logits otherwise.
        """
        batch_size, seq_len = input_ids.shape
        if batch_valid_length is not None:
            batch_valid_length = self.reshape(batch_valid_length, (-1,))

        hidden_states = self.transformer(
            input_ids=input_ids,
            input_position=input_position,
            position_ids=position_ids,
            attention_mask=attention_mask,
            input_embeds=input_embeds,
            batch_valid_length=batch_valid_length,
            prefix_key_values=prefix_key_values,
            block_tables=block_tables,
            slot_mapping=slot_mapping,
        )

        # Outside incremental decoding, keep only the hidden state at index cumsum(valid_length) - 1.
        full_sequence_pass = not self.use_past or self.is_first_iteration
        pre_gather = full_sequence_pass and batch_valid_length is not None
        if pre_gather:
            batch_valid_length = mint.cumsum(batch_valid_length, 0)
            last_index = self.sub_batch_valid_len(batch_valid_length, 1)
            hidden_states = self.gather(hidden_states, last_index, 1)

        lm_logits = self.transformer.output_layer(hidden_states)
        outputs = (lm_logits,)
        if self.rl_config is not None:
            return lm_logits

        if labels is not None:
            logits = lm_logits.to(mstype.float32)
            labels = labels.reshape((-1,))
            logits = logits.reshape((-1, logits.shape[-1]))
            if input_mask is None:
                # Default mask: 1.0 wherever the label differs from -100.
                input_mask = self.not_equal(labels, -100).to(mstype.float32)
            input_mask = input_mask.reshape((-1,))
            if self.training:
                outputs = self.loss(logits, labels, input_mask)
            else:
                # Append one zero step to the logits and prepend one zero to labels and mask (axis 1).
                logits_pad = ops.zeros((batch_size, 1, self.vocab_size), dtype=logits.dtype)
                logits = logits.reshape((batch_size, seq_len, self.vocab_size))
                logits = ops.cat((logits, logits_pad), axis=1)
                label_pad = ops.zeros((batch_size, 1), dtype=labels.dtype)
                labels = labels.reshape((batch_size, seq_len))
                labels = ops.cat((label_pad, labels), axis=1)
                mask_pad = label_pad.to(input_mask.dtype)
                input_mask = input_mask.reshape((batch_size, seq_len))
                input_mask = ops.cat((mask_pad, input_mask), axis=1)
                outputs = logits, labels, input_mask

        if not self.training:
            lm_logits = self.cast(lm_logits, mstype.float32)
            if self.predict_run_mode:
                lm_logits = self.reshape(lm_logits, (-1, lm_logits.shape[-1]))
                outputs = (lm_logits,)
        return outputs
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class YiZhaoDPO(YiZhaoForCausalLM):
    """
    YiZhao Model for pretraining with DPO

    Args:
        config (YiZhaoConfig): The config of network.
    """

    @lazy_inline
    def __init__(self, config: YiZhaoConfig, **kwargs):
        """Build the causal LM, then the DPO loss and the SFT cross-entropy with their parallel configs."""
        super().__init__(config, **kwargs)
        data_parallel = config.parallel_config.data_parallel
        model_parallel = config.parallel_config.model_parallel
        keep_parallel_config = config.parallel_config.vocab_emb_dp or (config.vocab_size % model_parallel != 0)

        # Parallel sizes used by both losses when a private config copy is needed.
        # NOTE(review): unlike YiZhaoForCausalLM, context_parallel is not folded in here - confirm intended.
        if data_parallel >= 32 and data_parallel % 8 == 0:
            loss_model_parallel, loss_data_parallel = 8, data_parallel * model_parallel // 8
        else:
            loss_model_parallel, loss_data_parallel = data_parallel * model_parallel, 1

        # DPOLoss receives a whole model config, so the full config is copied for it.
        dpo_config = config
        if not keep_parallel_config:
            dpo_config = copy.deepcopy(config)
            dpo_config.parallel_config.model_parallel = loss_model_parallel
            dpo_config.parallel_config.data_parallel = loss_data_parallel
        self.dpo_loss = DPOLoss(dpo_config)

        self.alpha = config.alpha
        self.beta = config.beta

        # DPOCrossEntropy only needs the parallel config.
        # NOTE(review): construct() of this class does not reference self.sft_loss.
        sft_parallel_config = config.parallel_config
        if not keep_parallel_config:
            sft_parallel_config = copy.deepcopy(config.parallel_config)
            sft_parallel_config.model_parallel = loss_model_parallel
            sft_parallel_config.data_parallel = loss_data_parallel
        self.sft_loss = DPOCrossEntropy(parallel_config=sft_parallel_config)

    def construct(self, input_ids, labels=None,
                  attention_mask=None, loss_mask=None,
                  ref_chosen_logps=None, ref_rejected_logps=None,
                  input_position=None, position_ids=None, input_embeds=None,
                  batch_valid_length=None, block_tables=None, slot_mapping=None, prefix_key_values=None):
        """
        Compute the combined DPO training loss.

        Policy logits are produced by the backbone and output layer, cast to float32 and handed to the DPO
        loss together with the flattened reference log-probabilities.

        Returns:
            ``alpha * dpo_loss + beta * sft_loss``.
        """
        if batch_valid_length is not None:
            batch_valid_length = self.reshape(batch_valid_length, (-1,))

        hidden_states = self.transformer(
            input_ids=input_ids,
            input_position=input_position,
            position_ids=position_ids,
            attention_mask=attention_mask,
            input_embeds=input_embeds,
            batch_valid_length=batch_valid_length,
            prefix_key_values=prefix_key_values,
            block_tables=block_tables,
            slot_mapping=slot_mapping,
        )
        policy_logits = self.cast(self.transformer.output_layer(hidden_states), mstype.float32)

        flat_chosen = ref_chosen_logps.reshape((-1,))
        flat_rejected = ref_rejected_logps.reshape((-1,))
        dpo_term, sft_term = self.dpo_loss(policy_logits, labels, loss_mask, flat_chosen, flat_rejected)
        return self.alpha * dpo_term + self.beta * sft_term
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class YiZhaoWithPtuning2(YiZhaoForCausalLM):
    """
    YiZhao Model for pretraining with p-tuning-v2

    Args:
        config (YiZhaoConfig): The config of network.
    """

    def __init__(self, config: YiZhaoConfig = None, **kwargs):
        """Build the causal LM with a prefix encoder, load the checkpoint if one is set, then freeze weights."""
        # Hide the checkpoint path while the parent is built; it is loaded below, after the
        # prefix encoder has been created.
        ckpt_cfg = config.checkpoint_name_or_path
        config.checkpoint_name_or_path = None
        pet_config = config.pet_config
        config.pre_seq_len = pet_config.pre_seq_len
        super().__init__(config, **kwargs)
        self.use_past = config.use_past

        # Complete the PET config with the geometry of the backbone.
        pet_config.num_layers = config.num_layers
        pet_config.kv_channels = config.kv_channels
        pet_config.num_heads = (
            config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads
        )
        self.prefix_encoder = PrefixEncoder(
            pet_config.pre_seq_len,
            pet_config.num_layers,
            pet_config.num_heads,
            pet_config.kv_channels,
            pet_config.prefix_projection,
            pet_config.projection_dim,
            pet_config.dropout_prob,
        )

        if ckpt_cfg:
            config.checkpoint_name_or_path = ckpt_cfg
            self.load_checkpoint(config)
        PetAdapter.freeze_pretrained_model(self, pet_config.pet_type)

    def construct(self, input_ids=None, labels=None, input_position=None, position_ids=None, attention_mask=None,
                  input_embeds=None, init_reset=True, batch_valid_length=None, prefix_key_values=None,
                  block_tables=None, slot_mapping=None, batch_index=None, zactivate_len=None):
        """
        Run the causal LM with prefix key/values from the prefix encoder.

        The prefix is regenerated unless the model is in incremental inference (``use_past`` and not the
        first iteration), in which case the ``prefix_key_values`` argument is forwarded unchanged.
        ``init_reset``, ``batch_index`` and ``zactivate_len`` are accepted but not forwarded.
        """
        needs_prefix = not self.use_past or self.is_first_iteration
        if needs_prefix:
            prefix_key_values = self.prefix_encoder(input_ids.shape[0])
        return super().construct(
            input_ids=input_ids,
            labels=labels,
            input_position=input_position,
            position_ids=position_ids,
            attention_mask=attention_mask,
            input_embeds=input_embeds,
            batch_valid_length=batch_valid_length,
            prefix_key_values=prefix_key_values,
            block_tables=block_tables,
            slot_mapping=slot_mapping,
        )

    def kvcache(self, layer_idx):
        """Return ``(key_cache, value_cache)`` of the paged-attention manager of encoder layer ``layer_idx``."""
        cache_mgr = self.transformer.encoder.layers[layer_idx].self_attention.infer_attention.paged_attention_mgr
        return cache_mgr.key_cache, cache_mgr.value_cache