import torch
import torch_npu
from torch_npu.utils._error_code import ErrCode, ops_error


@torch.library.register_fake("npu::npu_fusion_attention_v3")
def npu_fusion_attention_forward_v3(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
                                atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647,
                                inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                 gen_mask_parallel=True, sync=False, softmax_layout="", sink=None):
    B = query.size(0)
    S1 = query.size(2)
    T = query.size(0)
    N_L = query.size(1)
    aten_score_shape = query.shape

    if input_layout == "BSH":
        S1 = query.size(1)
        H = query.size(2)
        D = H / head_num
        D2 = 0 if D == 0 or key.size(2) == 0 else value.size(2) / (key.size(2) / D)
        aten_score_shape = [B, S1, int(head_num * D2)]
    elif input_layout == "SBH":
        B = query.size(1)
        S1 = query.size(0)
        H = query.size(2)
        D = H / head_num
        D2 = 0 if D == 0 or key.size(2) == 0 else value.size(2) / (key.size(2) / D)
        aten_score_shape = [S1, B, int(head_num * D2)]
    elif input_layout == "BNSD":
        D2 = value.size(3)
        aten_score_shape = [B, N_L, S1, D2]
    elif input_layout == "BSND":
        S1 = query.size(1)
        N_L = query.size(2)
        D2 = value.size(3)
        aten_score_shape = [B, S1, N_L, D2]
    elif input_layout == "TND":
        D2 = value.size(2)
        aten_score_shape = [T, N_L, D2]

    if input_layout == "TND":
        softmax_shape = [T, N_L, 8]
    else:
        softmax_shape = [B, head_num, S1, 8]

    seed = torch.empty([1], dtype=torch.long, device=query.device)
    offset = torch.empty([1], dtype=torch.long, device=query.device)
    attention_score = query.new_empty(aten_score_shape, dtype=query.dtype, device=query.device)
    softmax_max = torch.empty(softmax_shape, dtype=torch.float32, device=query.device)
    softmax_sum = torch.empty(softmax_shape, dtype=torch.float32, device=query.device)
    softmax_out = torch.empty([0], dtype=query.dtype, device=query.device)
    return (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset)


@torch.library.register_fake("npu::npu_fusion_attention_grad_v3")
def npu_fusion_attention_backward_v3(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None,
                                  softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0,
                                  keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=None, offset=None,
                                  prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                  gen_mask_parallel=True, sync=False, softmax_layout="", sink=None):
    dq = query.new_empty(query.shape, dtype=query.dtype, device=query.device)
    dk = key.new_empty(key.shape, dtype=query.dtype, device=query.device)
    dv = value.new_empty(value.shape, dtype=query.dtype, device=query.device)
    dpse = torch.empty([0], dtype=query.dtype, device=query.device)
    dsink = torch.empty([], device=query.device) if sink is None else torch.empty(sink.shape, dtype=sink.dtype, device=query.device)
    return (dq, dk, dv, dpse, dsink)