"""
Operator estimation (算子预估): analytic activation-memory and peak-memory cost models.
"""
import math
import os

from mindspeed.auto_settings.config.model_config import get_model_config
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.utils.utils import get_module_info, get_black_prof_file
from mindspeed.auto_settings.config.search_config import SearchConfig


class Activation:
    """
    Model the activation memory generated during the forward pass of a single transformer block.

    ``activation_mem`` returns a profiled value; the remaining methods return
    analytic per-component estimates. The analytic estimates multiply element
    counts by 2, presumably assuming 2-byte (fp16/bf16) activations -- TODO confirm.
    """

    def __init__(self, config: "SearchConfig"):
        self.unit_gb = 1024 ** 3
        self.config = config
        self.tp = config.tensor_model_parallel_size
        self.cp = config.ring_attention_size
        self.up = config.ulysses_size
        self.ep = config.expert_model_parallel_size
        self.mbs = config.micro_batch_size
        self.seq_len = config.seq_length
        self.hidden_size = config.hidden_size
        self.ffn_hidden_size = config.ffn_hidden_size
        self.num_query_groups = config.num_query_groups
        self.num_attention_heads = config.num_attention_heads
        self.num_experts = config.num_experts
        self.swiglu = config.swiglu
        self.top_k = config.moe_router_topk
        self.recompute_activation_function = config.recompute_activation_function
        self.swap_attention = config.swap_attention

    def _local_seq_len(self):
        """Return the sequence length left after splitting by ring attention (cp) and Ulysses (up)."""
        return self.seq_len // self.cp // self.up

    @property
    def activation_mem(self):
        """
        Return the profiled activation memory of transformer block ``'0'``, in bytes.

        The value is read with ``get_module_info`` from the profile file that
        ``get_black_prof_file`` resolves for a cropped copy of the search config
        (``crop()`` is assumed to return a copy -- TODO confirm). The raw value is
        scaled by ``unit_gb``, so it is presumably reported in GB.

        If the lookup yields ``inf`` (no usable profile for the configured
        micro-batch size), the profile for micro-batch size 1 is used instead
        and scaled linearly by the real micro-batch size.
        """
        cropped = self.config.crop()
        act_mem = get_module_info(get_black_prof_file(cropped), '0', 'memory')
        if math.isinf(act_mem):
            # The profile file is resolved from the config, so it must be
            # re-resolved after changing micro_batch_size; otherwise the same
            # (missing) entry would simply be queried a second time.
            cropped.micro_batch_size = 1
            act_mem = get_module_info(get_black_prof_file(cropped), '0', 'memory')
            act_mem *= self.config.micro_batch_size
        return act_mem * self.unit_gb

    def layer_norm(self):
        """Estimate layer-norm activation bytes; tokens are additionally split across tp."""
        shape = [self._local_seq_len() // self.tp, self.mbs, self.hidden_size]
        return 2 * math.prod(shape)

    def linear_qkv(self):
        """Estimate activation bytes of the fused QKV projection output on one tp rank."""
        groups = self.num_query_groups // self.tp
        heads = self.num_attention_heads // self.tp
        head_dim = self.hidden_size // self.num_attention_heads
        # Per query group: heads // groups query heads plus one key and one value head.
        shape = [self._local_seq_len(), self.mbs, groups * (heads // groups + 2) * head_dim]
        return 2 * math.prod(shape)

    def linear_proj(self):
        """Estimate activation bytes of the attention output projection; tokens split across tp."""
        shape = [self._local_seq_len() // self.tp, self.mbs, self.hidden_size]
        return 2 * math.prod(shape)

    def core_attention(self):
        """Estimate core-attention activation bytes: the query tensor plus Ulysses / ring-attention extras."""
        heads = self.num_attention_heads // self.tp
        head_dim = self.hidden_size // self.num_attention_heads
        q_mem = 2 * math.prod([self._local_seq_len(), self.mbs, heads, head_dim])
        ret = q_mem
        if self.up > 1:
            # Extra memory when Ulysses is enabled, modelled as 4x the query size.
            ret += 4 * q_mem
        if self.cp > 1:
            # Fixed extra term for ring attention; the origin of 2048 * 2048 is not shown here.
            ret += (2048 * 2048)
        return ret

    def mlp(self):
        """
        Estimate MLP activation bytes: linear1 output + activation output + linear2 output.

        A dense MLP is modelled when the model has no experts (``num_experts``
        unset/0) or ``ep`` is unset/0; otherwise the MLP of a single expert is
        modelled, fed with its share of the routed tokens. The activation term
        is dropped when ``recompute_activation_function`` is set.
        """
        ffn_hidden_size = self.ffn_hidden_size
        if self.swiglu:
            # SwiGLU: linear1 produces twice ffn_hidden_size; the activation halves it again.
            ffn_hidden_size *= 2

        if not self.num_experts or not self.ep:
            # Dense MLP.
            linear1_shape = [self._local_seq_len(), self.mbs, ffn_hidden_size // self.tp]
            linear2_shape = [self._local_seq_len() // self.tp, self.mbs, self.hidden_size]
        else:
            # One expert's MLP: routed tokens (top_k copies, gathered over ep) split evenly across experts.
            num_total_tokens = self._local_seq_len() * self.ep * self.top_k
            tokens_per_expert = num_total_tokens // self.num_experts
            linear1_shape = [tokens_per_expert, self.mbs, ffn_hidden_size // self.tp]
            linear2_shape = [tokens_per_expert, self.mbs, self.hidden_size]

        linear1_mem = 2 * math.prod(linear1_shape)
        linear2_mem = 2 * math.prod(linear2_shape)
        # linear1_mem is always even (factor 2), so the integer halving is exact.
        activation_func_mem = linear1_mem // 2 if self.swiglu else linear1_mem

        if self.recompute_activation_function:
            activation_func_mem = 0

        return linear1_mem + activation_func_mem + linear2_mem

    def moe_layer(self):
        """Estimate MoE-layer activation bytes: dispatched tokens, all local experts' MLPs, un-dispatched tokens."""
        num_local_experts = self.num_experts // self.ep
        num_total_tokens = self._local_seq_len() * self.ep * self.top_k
        local_tokens = num_total_tokens // self.num_experts * num_local_experts

        # Dispatch and un-dispatch buffers have the same shape.
        dispatcher = 2 * math.prod([local_tokens, self.mbs, self.hidden_size])
        sequential_mlp = self.mlp() * num_local_experts
        undispatcher = dispatcher

        return dispatcher + sequential_mlp + undispatcher


class MemoryCostBlack(object):
    """
    Estimate per-device peak memory for a search configuration.

    Parameter, gradient and optimizer memory are computed analytically;
    per-layer activation memory comes from ``Activation.activation_mem``.
    """
    unit_gb = 1024 ** 3
    # Constant overhead (4.5 GB, in bytes) added to every estimate; presumably
    # memory reserved by the CANN runtime -- TODO confirm.
    cann_memory = 4.5 * 1024 ** 3

    def __init__(self):
        self.logger = get_logger("MemoryCostBlack")

    def compute_params(self, config: "SearchConfig"):
        """
        Calculate the number of model parameters held by one device of pipeline stage 0.

        Transformer-layer parameters are divided by pp. Non-MLP parameters and
        the embedding are divided by tp. MLP parameters are additionally divided
        by ep when ep is set (non-zero). When ``untie_embeddings_and_output_weights``
        is set and pp == 1, a second embedding-sized matrix is counted.

        Returns:
            int: the (truncated) parameter count.
        """
        pp = config.pipeline_model_parallel_size
        tp = config.tensor_model_parallel_size
        ep = config.expert_model_parallel_size
        num_experts = config.num_experts if config.num_experts else 1

        gated_linear_multiplier = 3 / 2 if config.swiglu else 1
        embedding_size = config.hidden_size * config.padded_vocab_size
        num_parameters_in_transformer_layers = (
                2
                * config.num_layers
                * config.hidden_size
                * config.hidden_size
                * (
                        1
                        + ((config.ffn_hidden_size / config.hidden_size) * num_experts * gated_linear_multiplier)
                        + (config.num_query_groups / config.num_attention_heads)
                        + (2 / config.hidden_size)
                        + (1 / (config.num_layers * config.hidden_size))
                )
        )
        # MLP parameters of the layers on one pipeline stage.
        mlp_params_shard = (
                2
                * config.hidden_size * config.ffn_hidden_size
                * num_experts * gated_linear_multiplier
                * config.num_layers / pp
        )
        # An unset (None) or zero ep means MLP weights are not expert-sharded.
        mlp_params_per_device = mlp_params_shard / tp if not ep else mlp_params_shard / tp / ep
        total_params_count = (
                (
                        num_parameters_in_transformer_layers / pp
                        + embedding_size
                        - mlp_params_shard
                ) / tp
                + mlp_params_per_device
        )
        if config.untie_embeddings_and_output_weights and pp == 1:
            total_params_count += embedding_size / tp
        self.logger.debug(f'num_parameters_in_transformer_layers: {num_parameters_in_transformer_layers}')
        self.logger.debug(f'mlp_params_shard: {mlp_params_shard}')
        self.logger.debug(f'total_params_count: {total_params_count}')

        return int(total_params_count)

    def compute_static_memory(self, params: int, config: "SearchConfig"):
        """
        Return ``(mem_para, mem_grad, mem_optimizer)`` in bytes for ``params`` parameters.

        The per-parameter byte multipliers depend on fp16/bf16,
        ``reuse_fp32_param`` and ``use_distributed_optimizer``. Terms divided by
        dp are sharded across data-parallel ranks.

        Raises:
            AssertionError: if neither fp16 nor bf16 is enabled (fp32 is unsupported).
        """
        dp = config.data_parallel_size
        if config.fp16:
            mem_para = 2 * params
            mem_grad = 2 * params
            if config.reuse_fp32_param and config.use_distributed_optimizer:
                mem_optimizer = 4 * params + 8 * params / dp
            elif config.use_distributed_optimizer:
                mem_optimizer = 4 * params + 4 * params + 8 * params / dp
            elif config.reuse_fp32_param:
                mem_optimizer = 12 * params
            else:
                mem_optimizer = 16 * params
        elif config.bf16:
            # With reuse_fp32_param no separate 16-bit parameter copy is counted.
            if config.reuse_fp32_param and config.use_distributed_optimizer:
                mem_para = 0
                mem_grad = 4 * params
                mem_optimizer = 4 * params + 8 * params / dp
            elif config.use_distributed_optimizer:
                mem_para = 2 * params
                mem_grad = 4 * params
                mem_optimizer = 4 * params + 8 * params / dp
            elif config.reuse_fp32_param:
                mem_para = 0
                mem_grad = 4 * params
                mem_optimizer = 4 * params + 8 * params
            else:
                mem_para = 2 * params
                mem_grad = 4 * params
                mem_optimizer = 4 * params + 4 * params + 4 * params
        else:
            raise AssertionError('not support fp32 training')
        return mem_para, mem_grad, mem_optimizer

    def get_peak_memory(self, config: "SearchConfig"):
        """
        Estimate the peak memory of one device of pipeline stage 0, in GB.

        Activation memory is the per-layer profiled value times the number of
        layers held in flight. Non-interleaved (vpp unset/0/1): ``num_layers // pp``
        layers times ``pp`` micro-batches. Interleaved: ``num_layers // pp // vpp``
        layers times ``pp * vpp + (pp - 1)`` chunks. Under full recomputation
        activation memory is taken as zero and the profile is not consulted.
        ``cann_memory`` is added on top.
        """
        model_config = get_model_config()
        pp = config.pp
        vpp = config.vpp

        params = self.compute_params(config)
        mem_para, mem_grad, mem_optimizer = self.compute_static_memory(params, config)

        if model_config.recompute_granularity == 'full':
            # The original model discards activations here; skip the profile lookup.
            mem_activation_per_layer = 0
            mem_activation_per_batch = 0
            mem_activation = 0
        else:
            mem_activation_per_layer = Activation(config).activation_mem
            if not vpp or vpp == 1:
                # non-interleaved pipeline (an unset vpp is treated as 1)
                mem_activation_per_batch = mem_activation_per_layer * (model_config.num_layers // pp)
                mem_activation = mem_activation_per_batch * pp
            else:
                num_layers_per_vpp_stage = model_config.num_layers // pp // vpp
                mem_activation_per_batch = mem_activation_per_layer * num_layers_per_vpp_stage
                mem_activation = mem_activation_per_batch * (pp * vpp + (pp - 1))

        # Candidate totals; with all terms non-negative, m3 is always the largest.
        m1 = mem_para + mem_optimizer + mem_activation
        m2 = mem_para + mem_optimizer + mem_activation + mem_grad - mem_activation_per_batch
        m3 = mem_para + mem_optimizer + mem_activation + mem_grad
        peak_memory = (max(m1, m2, m3) + self.cann_memory) / self.unit_gb

        self.logger.debug(
            f"### config: {config} \n"
            f"mem_para: {mem_para / self.unit_gb}\n"
            f"mem_grad: {mem_grad / self.unit_gb}\n"
            f"mem_optimizer: {mem_optimizer / self.unit_gb}\n"
            f"mem_activate_per_layer: {mem_activation_per_layer / self.unit_gb}\n"
            f"mem_activation: {mem_activation / self.unit_gb}\n"
            f"peak_memory: {peak_memory}"
        )
        return peak_memory