import math
import torch

from megatron.core import parallel_state


def get_slopes(n):
    def get_slopes_power_of_2(n):
        start = (2 ** (-2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
                                                           :n - closest_power_of_2]


class AlibiForFusionAttnSingleton:
    _alibi_tensor_args = None
    _alibi_tensor = None

    _alibi_slopes_headnum = None
    _alibi_slopes = None

    @classmethod
    def get_alibi_tensor_for_fusion_attn(cls, max_seq_len, num_attention_heads, dtype, neg_diagonal_opposite=False,
                                         last_k=1024):
        if cls._alibi_tensor is None or cls._alibi_tensor_args != (
        max_seq_len, num_attention_heads, neg_diagonal_opposite, last_k):
            if last_k > max_seq_len:
                last_k = max_seq_len

            tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
            current_head_num = num_attention_heads // tp_world_size
            slopes = AlibiForFusionAttnSingleton.get_alibi_slopes_for_fusion_attn(num_attention_heads)

            position_point = torch.arange(max_seq_len) - max_seq_len + 1
            diag = torch.diag(torch.diag(position_point)).unsqueeze(0).unsqueeze(0)

            position_point = position_point.unsqueeze(0).unsqueeze(0).expand(current_head_num, last_k, -1)
            position_point = position_point - diag.transpose(-1, -2)[:, -last_k:, :].expand(current_head_num, last_k,
                                                                                            max_seq_len)

            alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point.npu()

            if not neg_diagonal_opposite:
                alibi = -torch.abs(alibi)

            alibi = alibi.unsqueeze(0)
            alibi = torch.Tensor(alibi).npu()
            if dtype == torch.float16:
                alibi = alibi.to(torch.float16)
            elif dtype == torch.bfloat16:
                alibi = alibi.to(torch.bfloat16)

            cls._alibi_tensor = alibi
            cls._alibi_tensor_args = (max_seq_len, num_attention_heads, neg_diagonal_opposite, last_k)

        return cls._alibi_tensor

    @classmethod
    def get_alibi_slopes_for_fusion_attn(cls, n):
        if cls._alibi_slopes is None or cls._alibi_slopes_headnum != n:
            slopes = get_slopes(n)

            tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
            tp_index = parallel_state.get_tensor_model_parallel_rank()

            current_head_num = n // tp_world_size
            slopes = torch.Tensor(slopes[tp_index * current_head_num: tp_index * current_head_num + current_head_num]).npu()

            cls._alibi_slopes = slopes
            cls._alibi_slopes_headnum = n
            return cls._alibi_slopes

        return cls._alibi_slopes