import os
import math

import torch
import torch_npu
import torch.nn.functional as F
from loguru import logger

import numpy as np


def get_dense_attention_score(q, k):
    """Return dense attention probabilities ``softmax(q @ k^T / (sqrt(d) + 1e-8))``.

    ``d`` is the size of the last axis of ``q``. The softmax is taken over the last
    axis and is computed (and returned) in float32 regardless of the input dtype.
    """
    head_dim = q.size(-1)
    raw_scores = torch.matmul(q, k.transpose(-2, -1))
    del q, k  # drop local references before the next large allocation
    # The small epsilon on the denominator mirrors the original formulation.
    scaled_scores = raw_scores / (math.sqrt(head_dim) + 1e-8)
    return F.softmax(scaled_scores, dim=-1, dtype=torch.float32)


def get_per_row_min_cumulative_coverage(dense_qk, sparsity, device=None):
    """Per (batch, head), return the worst-row share of mass kept by the largest entries.

    For every query row of ``dense_qk[b, h]`` the ``int(l_k * sparsity)`` largest
    entries are kept. The row's coverage is ``kept_sum / row_sum``, computed as
    ``1 - dropped_sum / row_sum``. The minimum coverage over all rows is stored for
    that (batch, head).

    Args:
        dense_qk: 4-D tensor unpacked as (batch, head, l_q, l_k). The caller in this
            file passes softmax attention probabilities.
        sparsity: fraction of keys KEPT per row. ``int(l_k * sparsity)`` entries are
            retained and the remaining ``l_k - int(l_k * sparsity)`` are dropped.
        device: device of the returned tensor. ``None`` (the default) means "same
            device as ``dense_qk``". This default was previously hard-coded to
            ``"npu:0"``, which moved results off the input's device and failed on
            hosts without an NPU.

    Returns:
        Tensor of shape (batch, head) with ``dense_qk.dtype``.
    """
    batch, head, _, l_k = dense_qk.shape
    if device is None:
        device = dense_qk.device
    num_kept = int(l_k * sparsity)
    num_dropped = l_k - num_kept
    coverage = torch.empty(batch, head, device=device, dtype=dense_qk.dtype)

    # Presumably processed one (b, h) slice at a time to bound peak memory on large
    # attention maps -- a fully batched topk would materialize all slices at once.
    for b in range(batch):
        for h in range(head):
            scores = dense_qk[b, h]
            # Summing the smallest entries gives the dropped mass directly.
            dropped, _ = torch.topk(scores, k=num_dropped, dim=-1, largest=False, sorted=False)
            row_coverage = 1 - dropped.sum(dim=-1) / scores.sum(dim=-1)
            coverage[b, h] = row_coverage.min()

    return coverage


def remove_q_k_spec_token(q, txt_len=11):
    """Drop the trailing ``txt_len`` positions along axis 2 of a 4-D tensor.

    Args:
        q: 4-D tensor. The caller in this file passes (batch, head, seq, dim)-ordered
            q/k tensors.
        txt_len: number of trailing positions on axis 2 to remove. ``0`` returns the
            tensor unchanged. The previous ``q[:, :, :-0, :]`` form returned an empty
            tensor in that case. A value larger than the axis length yields an empty
            axis.

    Returns:
        A view of ``q`` without its last ``txt_len`` positions on axis 2.

    Raises:
        ValueError: if ``txt_len`` is negative.
    """
    if txt_len < 0:
        raise ValueError(f"txt_len must be non-negative, got {txt_len}")
    # Explicit, clamped end index: avoids the ``:-0`` pitfall and never wraps around.
    end = max(q.shape[2] - txt_len, 0)
    return q[:, :, :end, :]


'''
Iterate over different sparsity levels and compute the corresponding cumulative attention-score coverage.
'''


def get_cumulative_coverage_of_different_sparsity(dir_path, global_layer_num, sparsity_list, image_len=10, 
                                                    txt_len=11, cu_seqlens_q=10206, img_token_lens=10200, 
                                                    frame_count=17, filter_first_frame=True, device="npu:0"):
    """Compute per-layer, per-sparsity coverage from dumped q/k tensors.

    For each layer index ``i`` in ``range(global_layer_num)`` this:

    1. Loads ``<dir_path>/layer-<i>-qk.pt``, a dict with ``'q'`` and ``'k'`` tensors,
       onto ``device``.
    2. Swaps axes 1 and 2 and keeps the first ``cu_seqlens_q`` positions.
    3. Optionally drops the first frame's positions.
    4. Removes the trailing ``cu_seqlens_q - img_token_lens`` positions.
    5. Builds the dense attention-probability matrix.
    6. Evaluates ``get_per_row_min_cumulative_coverage`` for each value in
       ``sparsity_list``.

    Args:
        dir_path: directory holding the ``layer-{idx}-qk.pt`` dumps.
        global_layer_num: number of layers to process.
        sparsity_list: values passed as ``sparsity`` (fraction of keys kept) to
            ``get_per_row_min_cumulative_coverage``.
        image_len: unused; kept for signature compatibility.
        txt_len: ignored. The number of trailing positions removed is always
            recomputed as ``cu_seqlens_q - img_token_lens``. Kept for signature
            compatibility.
        cu_seqlens_q: number of leading positions kept from the dumps.
        img_token_lens: subtracted from ``cu_seqlens_q`` to get the trailing count to
            remove.
        frame_count: divisor used to derive tokens-per-frame from the kept length.
        filter_first_frame: if True, drop the first ``tokens_per_frame`` positions of
            q and k.
        device: ``map_location`` for ``torch.load`` and device of the coverage tensors.
            This is now forwarded to the coverage helper; previously the helper's
            hard-coded default was used.

    Returns:
        ``{layer_idx: {sparsity: coverage}}``, where each coverage tensor has shape
        (batch, head).
    """
    all_layer_sparsity_cumulative_coverage = {}
    for global_idx in range(global_layer_num):
        per_layer_sparsity_cumulative_coverage = {}
        qk_path = os.path.join(dir_path, f"layer-{global_idx}-qk.pt")
        qk = torch.load(qk_path, map_location=device)
        # NOTE(review): assumes dumps are (batch, seq, head, dim); the permute makes
        # them (batch, head, seq, dim). TODO confirm against the dumping code.
        q = qk['q'].permute(0, 2, 1, 3)[:, :, :cu_seqlens_q, :]
        k = qk['k'].permute(0, 2, 1, 3)[:, :, :cu_seqlens_q, :]
        tokens_per_frame = q.shape[2] // frame_count
        if filter_first_frame:
            q = q[:, :, tokens_per_frame:, :]
            k = k[:, :, tokens_per_frame:, :]
        # The txt_len argument is intentionally overridden (see docstring).
        txt_len = cu_seqlens_q - img_token_lens
        q_img = remove_q_k_spec_token(q, txt_len)
        k_img = remove_q_k_spec_token(k, txt_len)
        del q, k
        dense_qk = get_dense_attention_score(q_img, k_img)
        del q_img, k_img
        for sparsity in sparsity_list:
            logger.info(f"Global layer idx:{global_idx}, Sparsity: {sparsity}.")
            coverage = get_per_row_min_cumulative_coverage(dense_qk, sparsity, device=device)
            per_layer_sparsity_cumulative_coverage[sparsity] = coverage
        del dense_qk
        all_layer_sparsity_cumulative_coverage[global_idx] = per_layer_sparsity_cumulative_coverage

    return all_layer_sparsity_cumulative_coverage


def get_sparsity_of_target_cumulative_coverage(all_layer_sparsity_cumulative_coverage, global_layer_num, head_num,
                                                sparsity_list, target_coverage=0.95, device="npu:0"):
    """Pick, per (layer, head), the largest sparsity whose coverage exceeds the target.

    Each ``s`` in ``sparsity_list`` is a fraction of keys kept, so its sparsity is
    ``1 - s``. Only batch index 0 of each coverage tensor is consulted. Among all
    ``s`` whose coverage for a head is strictly greater than ``target_coverage``, the
    largest ``1 - s`` is recorded. Heads with no qualifying ``s`` get 0.

    The result does not depend on the order of ``sparsity_list``. The previous
    implementation kept whichever qualifying entry came LAST, which matched this
    only when the list was sorted by descending ``s``.

    Args:
        all_layer_sparsity_cumulative_coverage: ``{layer_idx: {s: coverage}}`` with
            coverage tensors of shape (batch, head).
        global_layer_num: number of layers (rows of the result).
        head_num: number of heads (columns of the result).
        sparsity_list: the ``s`` keys to look up for every layer.
        target_coverage: coverage threshold; the comparison is strict (``>``).
        device: device of the returned tensor.

    Returns:
        Float tensor of shape (global_layer_num, head_num) on ``device``.
    """
    # -1 marks "no qualifying sparsity yet"; it is mapped to 0 at the end.
    result = torch.full((global_layer_num, head_num), -1.0, device=device)

    for sparsity in sparsity_list:
        candidate = 1 - sparsity
        for global_idx in range(global_layer_num):
            head_coverage = all_layer_sparsity_cumulative_coverage[global_idx][sparsity][0].to(result.device)
            layer_row = result[global_idx, :head_coverage.shape[0]]  # view into result
            meets_target = head_coverage > target_coverage
            # Keep the largest qualifying sparsity seen so far (order-independent).
            layer_row[meets_target] = layer_row[meets_target].clamp(min=candidate)

    result[result == -1] = 0
    return result


def save_expected_sparsity(dir_path, target_sparsity_expected_coverage, target_coverage=0.95):
    """Save the per-(layer, head) sparsity tensor with ``torch.save``.

    The target directory is created if needed. The file is named
    ``sparsity_of_RE_<target_coverage>_only_img.pt`` inside ``dir_path``.
    """
    file_name = f"sparsity_of_RE_{target_coverage}_only_img.pt"
    output_path = os.path.join(dir_path, file_name)
    os.makedirs(dir_path, exist_ok=True)
    torch.save(target_sparsity_expected_coverage, output_path)