# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""WizardCoder model"""
import copy

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import functional as F

from mindformers.models.utils import cell_reuse
from mindformers.modules.transformer.moe import default_moe_config
from mindformers.modules.layers import LayerNorm
from mindformers.version_control import get_dropout
from mindformers.core.loss import CrossEntropyLoss
from mindformers.modules.transformer import AttentionMask
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.models.base_model import BaseModel
from mindformers.tools.logger import logger
from wizardcoder_config import WizardCoderConfig
from wizardcoder_modules import WizardCoderTransformerDecoderLayer, WizardCoderVocabEmbedding

__all__ = ['WizardCoderLMHeadModel']


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class WizardCoderLMHeadModel(BaseModel):
    r"""
        Provide wizardcoder training loss or logits through network.
        Args:
            config (WizardCoderConfig): The config of WizardCoderModel. When None, a default
                ``WizardCoderConfig()`` is used.

        Returns:
            Tensor, the loss or logits of the network.
        """
    @cell_reuse
    def __init__(self, config: WizardCoderConfig = None):
        config = config if config is not None else WizardCoderConfig()
        super(WizardCoderLMHeadModel, self).__init__(config, auto_prefix=True)
        self.use_past = config.use_past
        # construct() reads this flag when use_past is enabled; default it (as WizardCoderModel
        # does) so the attribute exists even before any caller has toggled it.
        self.is_first_iteration = True
        self.eos_token = self.config.eos_token
        self.pad_token = self.config.pad_token
        self.eos_token_tensor = Tensor((np.ones((1, 1)) * self.eos_token).astype(np.int32))
        self.seq_length = config.seq_length

        parallel_config = self.config.parallel_config
        self.stridedslice = P.StridedSlice().shard(((parallel_config.data_parallel, 1),))
        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))

        self.get_attention_mask = AttentionMask(
            seq_length=config.seq_length, parallel_config=parallel_config.dp_mp_config).to_float(config.compute_dtype)

        self.backbone = WizardCoderModel(config)
        self.head = WizardCoderHead(vocab_size=config.vocab_size, parallel_config=self.config.parallel_config)

        if parallel_config.pipeline_stage > 1:
            # The head runs on the last stage and multiplies by the word embedding table that the
            # backbone returns, so that table must also be registered on the head's stage.
            self.head.pipeline_stage = parallel_config.pipeline_stage - 1
            self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

        mp = config.parallel_config.model_parallel
        vocab_size = config.vocab_size
        loss_parallel_config = copy.deepcopy(parallel_config)

        if vocab_size % mp != 0:
            logger.warning("The vocab size of WizardCoder Loss is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of WizardCoder Loss will be changed: mp = 1")
            loss_parallel_config.model_parallel = 1

        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config, eps_const=1e-24)
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.add = P.Add().shard(((parallel_config.data_parallel, 1), ()))
        self.mul = P.Mul().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
        self.tile = P.Tile()
        self.gather = P.Gather()
        self.concat = P.Concat(axis=-1)
        self.ones = P.Ones()
        # Load weights only once every sub-cell of the network has been created.
        self.load_checkpoint(config)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Wrap generation inputs as int32 Tensors.

        Args:
            input_ids: token ids to feed to ``construct``.
            **kwargs: ``current_index`` (optional) is converted to the ``input_position`` Tensor.

        Returns:
            dict with ``input_ids`` and ``input_position`` (None when ``current_index`` is absent).
        """
        input_position = kwargs.get("current_index", None)
        if input_position is not None:
            input_position = Tensor(input_position, mstype.int32)
        return {
            "input_ids": Tensor(input_ids, mstype.int32),
            "input_position": input_position
        }

    def construct(self, input_ids, labels=None, input_mask=None, input_position=None,
                  init_reset=True, batch_valid_length=None):
        r"""
            construct function for Language Modeling

            Args:
                input_ids (Tensor): token indices of shape (batch_size, seq_length). In the "train"
                    phase the last token is dropped from the model input.
                labels (Tensor, optional): target indices. If None, the targets are ``input_ids``
                    shifted left by one. Otherwise ``labels`` is shifted left by one as well, and
                    positions whose label equals -100 are excluded from the loss.
                input_mask (Tensor, optional): ignored; the mask is recomputed as
                    ``tokens != pad_token``.
                input_position (Tensor, optional): forwarded to the backbone. Outside training,
                    when given and use_past is off (or this is the first iteration), it also
                    selects which rows of the flattened logits are returned.
                init_reset (bool or Tensor): forwarded to the backbone; wrapped in a bool Tensor
                    when use_past is enabled.
                batch_valid_length (Tensor, optional): forwarded to the backbone; defaults to ones
                    of shape (batch_size, 1) when use_past is enabled.

            Returns:
                In the "train" phase, the loss. Otherwise the tuple
                ``(logits, tokens, input_mask + 1)`` with logits flattened to 2-D.
        """
        batch_size, seq_length = input_ids.shape
        if self.use_past:
            if not isinstance(init_reset, Tensor):
                init_reset = Tensor([init_reset], mstype.bool_)
            if not isinstance(batch_valid_length, Tensor):
                batch_valid_length = self.ones((batch_size, 1), mstype.int32)

        if self.phase == "train":
            tokens = self.stridedslice(input_ids, (0, 0), (batch_size, seq_length - 1), (1, 1))
        else:
            tokens = input_ids

        input_mask = self.cast(self.not_equal(tokens, self.pad_token), mstype.float16)
        attention_mask = self.get_attention_mask(input_mask)

        # The backbone flattens the batch and sequence dims, so the head's logits are 2-D
        # (presumably [batch_size * seq_length, vocab_size] -- confirm against WizardCoderModel).
        output_states, table = self.backbone(tokens, attention_mask, input_position, init_reset=init_reset,
                                             batch_valid_length=batch_valid_length)

        logits = self.head(output_states, table)

        if self.phase != 'train':
            logits = self.reshape(logits, (-1, logits.shape[-1]))
            if (not self.use_past or self.is_first_iteration) and input_position is not None:
                logits = self.gather(logits, input_position, 0)
            # makes cast effective to avoid allgather issue in Mindspore1.10
            input_mask = self.add(input_mask, 1)
            return logits, tokens, input_mask

        # Only the "train" phase reaches this point, so labels are always shifted like input_ids.
        if labels is None:
            labels = self.stridedslice(input_ids, (0, 1), (batch_size, seq_length), (1, 1))
        else:
            labels = self.stridedslice(labels, (0, 1), (batch_size, seq_length), (1, 1))
            label_mask = self.cast(self.not_equal(labels, -100), mstype.float16)
            input_mask = self.mul(input_mask, label_mask)

        labels = self.reshape(labels, (-1,))
        input_mask = self.reshape(input_mask, (-1,))
        loss = self.loss(logits, labels, input_mask)

        return loss


class WizardCoderEmbeddingLayer(nn.Cell):
    r"""The Embedding Layer of WizardCoder network.

    Adds a word embedding and a learned position embedding, then applies dropout.

    Args:
        config (WizardCoderConfig): the model config. When None, a default ``WizardCoderConfig()``
            is used (the same fallback WizardCoderLMHeadModel applies).

    Inputs:
        input_ids: token indices looked up in the word embedding table.
        input_position: position indices looked up in the position embedding table.

    Returns:
        embedding: dropout(word_embedding + position_embedding).
        word_table: the word embedding table (returned so callers can reuse it, e.g. in the LM head).
    """

    def __init__(self, config: WizardCoderConfig = None):
        super(WizardCoderEmbeddingLayer, self).__init__()
        # A None config previously crashed on the first attribute access below.
        config = config if config is not None else WizardCoderConfig()
        parallel_config = copy.deepcopy(config.parallel_config)
        embedding_mp = config.parallel_config.embedding_dp_mp_config.model_parallel
        vocab_size = config.vocab_size
        if vocab_size % embedding_mp != 0:
            logger.warning("The vocab size of embedding layer is: %s, it is not divide by model_parallel: %s",
                           vocab_size, embedding_mp)
            logger.warning("Now, model_parallel will be changed: mp = 1")
            parallel_config.embedding_dp_mp_config.model_parallel = 1

        self.word_embedding = WizardCoderVocabEmbedding(vocab_size=vocab_size,
                                                        embedding_size=config.embedding_size,
                                                        param_init=initializer('normal',
                                                                               [vocab_size, config.embedding_size],
                                                                               dtype=mstype.float32),
                                                        parallel_config=parallel_config.embedding_dp_mp_config)
        self.word_embedding.embedding_table.parallel_optimizer = True
        # The position table gets its own config copy with vocab_emb_dp forced on; this presumably
        # replicates it instead of sharding it over model_parallel -- confirm in WizardCoderVocabEmbedding.
        new_parallel_config = copy.deepcopy(parallel_config)
        new_parallel_config.vocab_emb_dp = True

        self.position_embedding = WizardCoderVocabEmbedding(vocab_size=config.n_position,
                                                            embedding_size=config.embedding_size,
                                                            param_init=initializer('normal',
                                                                                   [config.n_position,
                                                                                    config.embedding_size],
                                                                                   dtype=mstype.float32),
                                                            parallel_config=new_parallel_config.embedding_dp_mp_config)
        self.add = P.Add().shard(
            ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))

        self.dropout = get_dropout(config.dropout_prob)
        self.dropout.dropout.shard(((parallel_config.data_parallel, 1, 1),))

    def construct(self, input_ids, input_position):
        """The forward compute of Embedding Layer."""
        word_embedding, word_table = self.word_embedding(input_ids)
        position_embedding, _ = self.position_embedding(input_position)
        embedding = self.add(word_embedding, position_embedding)
        embedding = self.dropout(embedding)
        return embedding, word_table


def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers, use_select_recompute):
    r"""
        Assign a pipeline stage, a comm-fusion tag and a recompute policy to one transformer block.

        Layers are grouped in chunks of ``max(ceil((layers - 1) / pipeline_stage), 1)``. When
        ``layers`` is not a multiple of ``pipeline_stage`` (and the chunk is larger than one),
        layers below ``(pipeline_stage - remainder) * (chunk - 1)`` use a chunk one smaller. The
        remaining layers have their index shifted by ``pipeline_stage - remainder`` before dividing.
        The resulting stage index is clamped to the last stage.

        Args:
            network(Cell) - The transformer block to configure.
            layer_id(int) - Index of this layer, counting from zero.
            offset(int) - Extra offset added to the layer index when other modules precede the blocks.
            parallel_config - Parallel config object; reads ``pipeline_stage``,
                ``gradient_aggregation_group`` and ``recompute`` (a bool or an object exposing
                ``recompute`` / ``recompute_slice_activation``).
            layers(int) - Total number of layers in the model.
            use_select_recompute(bool) - When True, only ``network.attention`` is marked for selective
                recompute and ``parallel_config.recompute`` is not consulted.
    """
    n_stages = parallel_config.pipeline_stage
    chunk = max(int(np.ceil((layers - 1) / n_stages)), 1)
    leftover = layers % n_stages
    index = layer_id
    if leftover > 0 and chunk != 1:
        if index < (n_stages - leftover) * (chunk - 1):
            # Leading layers are grouped with a chunk one smaller.
            chunk -= 1
        else:
            # Later layers are shifted before division, pushing them onto later stages.
            index += n_stages - leftover
    network.pipeline_stage = min((index + offset) // chunk, n_stages - 1)

    # Optimizer fusion tag: a fixed tag under pipeline parallelism, otherwise grouped by layer index.
    group_span = max(int((layers + 1) / parallel_config.gradient_aggregation_group), 1)
    if n_stages > 1:
        fusion_tag = 2
    else:
        fusion_tag = int((index + offset) / group_span) + 1
    network.set_comm_fusion(fusion_tag)

    if use_select_recompute:
        network.attention.set_select_recompute()
        return

    policy = parallel_config.recompute
    if isinstance(policy, bool):
        if policy:
            network.recompute()
    elif policy.recompute:
        network.recompute(recompute_slice_activation=policy.recompute_slice_activation)


class WizardCoderModel(nn.Cell):
    """
    The backbone of WizardCoder network.

    Embeds tokens and positions, runs ``config.num_layers`` WizardCoderTransformerDecoderLayer blocks
    on hidden states flattened to 2-D, and applies a final LayerNorm.

    Args:
        config(WizardCoderConfig): the config of network. Note that if ``config.parallel_config``
            has no ``moe_config`` attribute, ``default_moe_config`` is assigned onto it (the passed
            config object is mutated).

    Inputs:
        input_ids: token indices of shape (batch_size, seq_length).
        attention_mask: mask passed unchanged to every decoder block.
        input_position: position indices. Ignored (replaced by 0..seq_length-1 per row) when None or
            when ``is_first_iteration`` is True; otherwise treated as per-row offsets (see construct).
        init_reset: forwarded to every decoder block.
        batch_valid_length: forwarded to every decoder block.

    Returns:
        output_state: Tensor, the LayerNorm-ed output of the last block (hidden states were reshaped
            to (-1, embedding_size) before the blocks).
        embedding_table: Tensor, the word embedding table returned by the embedding layer.
    """

    def __init__(self, config):
        super(WizardCoderModel, self).__init__()

        # The embedding layer is pinned to the first pipeline stage.
        self.embedding = WizardCoderEmbeddingLayer(config)
        self.embedding.pipeline_stage = 0
        self.cast_rec = P.Cast()
        self.reshape_rec = P.Reshape()
        self.config = config
        self.is_first_iteration = True
        self.use_past = config.use_past

        # The final LayerNorm is pinned to the last pipeline stage.
        self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.layernorm_dtype)
        if config.parallel_config.pipeline_stage > 1:
            self.layernorm.set_comm_fusion(2)
        else:
            self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
        self.layernorm.shard(((config.parallel_config.data_parallel, 1),))
        self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1

        if config.use_select_recompute:
            self.layernorm.layer_norm.add_prim_attr("recompute_comm_op", True)

        # NOTE: mutates the caller's parallel_config when it has no moe_config yet.
        if not hasattr(config.parallel_config, "moe_config"):
            config.parallel_config.moe_config = default_moe_config
        moe_config = config.parallel_config.moe_config

        self.blocks = nn.CellList()
        for i in range(config.num_layers):
            block = WizardCoderTransformerDecoderLayer(
                hidden_size=config.embedding_size,
                batch_size=config.batch_size,
                ffn_hidden_size=config.embedding_size * config.expand_ratio,
                seq_length=config.seq_length,
                num_heads=config.num_heads,
                attention_dropout_rate=config.attention_probs_dropout_prob,
                hidden_dropout_rate=config.hidden_dropout_prob,
                hidden_act=config.hidden_act,
                use_past=config.use_past,
                param_init_type=config.param_init_type,
                layernorm_compute_type=config.layernorm_dtype,
                softmax_compute_type=config.softmax_dtype,
                parallel_config=config.parallel_config.dp_mp_config,
                use_seq_parallel=config.use_seq_parallel,
                use_flash_attention=config.use_flash_attention,
                moe_config=moe_config)
            # Assigns pipeline stage, comm-fusion tag and recompute policy for this block.
            set_parallel_configure_for_layer(
                block, layer_id=i, layers=config.num_layers,
                offset=0, parallel_config=config.parallel_config,
                use_select_recompute=config.use_select_recompute)
            self.blocks.append(block)

        self.tile = P.Tile().shard(((config.parallel_config.data_parallel,),))
        self.dtype = mstype.float16
        self.num_layers = config.num_layers
        # Positions 0..config.seq_length-1, used on full-sequence passes. Assumes the runtime
        # sequence length then equals config.seq_length -- TODO confirm for all callers.
        self.input_position = Tensor(np.arange(config.seq_length), mstype.int32)
        # NOTE(review): not read by construct, which rebuilds these offsets from the runtime batch size.
        self.bias = Tensor(np.arange(config.batch_size) * self.config.seq_length, mstype.int32)

    def construct(self, input_ids, attention_mask, input_position=None, init_reset=False, batch_valid_length=None):
        """Run the WizardCoder backbone.

        Args:
            input_ids: token indices of shape (batch_size, seq_length).
            attention_mask: mask forwarded to every decoder block.
            input_position: optional positions; see the class docstring.
            init_reset: forwarded to every decoder block.
            batch_valid_length: forwarded to every decoder block.

        Returns:
            (output_state, embedding_table).
        """
        batch_size, seq_length = F.shape(input_ids)
        if input_position is None or self.is_first_iteration:
            # Full-sequence pass: every row uses positions 0..seq_length-1.
            if batch_size == 1:
                input_position = self.reshape_rec(self.input_position, (1, seq_length))
            else:
                input_position = self.tile(self.input_position, (batch_size, 1))
        else:
            # Incremental step: input_position appears to hold flattened indices
            # (row * config.seq_length + pos); subtracting each row's offset recovers pos.
            bias = Tensor(np.arange(batch_size) * self.config.seq_length, mstype.int32)
            input_position = F.sub(input_position, bias)
            input_position = F.reshape(input_position, (batch_size, 1))
        input_embedding, embedding_table = self.embedding(input_ids, input_position)

        # Cast to float16 and flatten to (-1, hidden) before the decoder blocks.
        hidden_states = self.cast_rec(input_embedding, self.dtype)
        hidden_shape = F.shape(hidden_states)
        hidden_states = self.reshape_rec(hidden_states, (-1, hidden_shape[-1]))

        for i in range(self.num_layers):
            hidden_states = self.blocks[i](hidden_states, attention_mask, init_reset=init_reset,
                                           batch_valid_length=batch_valid_length)
        output_state = self.layernorm(hidden_states)
        return output_state, embedding_table


class WizardCoderHead(nn.Cell):
    r"""Head for wizardcoder to get the logits of each token in the vocab.

    Computes ``state @ table^T`` after casting both operands to ``compute_type``.

    Args:
        vocab_size (int): number of rows in the vocabulary table; used only to decide whether the
            table can be split over ``model_parallel``.
        compute_type: dtype both operands are cast to before the matmul. Default: float16.
        parallel_config: parallel config object (deep-copied, never mutated). Reads
            ``model_parallel``, ``pipeline_stage``, ``vocab_emb_dp`` and ``data_parallel``.

    Inputs:
        state: hidden states, 2-D (rows x hidden).
        table: vocabulary table, 2-D (vocab x hidden).

    Returns:
        logits: ``state @ table^T``.
    """

    def __init__(self,
                 vocab_size,
                 compute_type=mstype.float16,
                 parallel_config=None):
        super().__init__()
        head_config = copy.deepcopy(parallel_config)
        model_parallel = head_config.model_parallel
        if vocab_size % model_parallel != 0:
            logger.warning("The vocab size of WizardCoderHead MatMul is: %s, it is not divide by model_parallel: %s",
                           vocab_size, model_parallel)
            logger.warning("Now, the model_parallel num of WizardCoderHead MatMul will be changed: mp = 1")
            head_config.model_parallel = 1

        # With more than one pipeline stage the table is always split over model_parallel.
        if head_config.pipeline_stage > 1:
            head_config.vocab_emb_dp = False

        row_split = head_config.data_parallel
        table_split = (1, 1) if head_config.vocab_emb_dp else (head_config.model_parallel, 1)
        self.matmul = P.MatMul(transpose_b=True).shard(((row_split, 1), table_split))
        self.dtype = compute_type
        self.cast = P.Cast()

    def construct(self, state, table):
        state = self.cast(state, self.dtype)
        table = self.cast(table, self.dtype)
        return self.matmul(state, table)