import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


def tsoftmax_grad(dp, softmax_res):
    muls = dp * softmax_res
    muls_r = muls.sum(dim=-1, keepdims=True)
    sub_r = dp - muls_r
    res = sub_r * softmax_res
    return res


def tsoftmax(x):
    x_max = torch.max(x, dim=-1, keepdims=True)[0]
    x_sub = x.sub(x_max)
    y = torch.exp(x_sub)
    x_sum = y.sum(dim=-1, keepdims=True)
    ans = y.div(x_sum)
    return ans, x_max, x_sum


class TestNPUFlashAttentionV2(TestCase):
    def supported_op_exec(self, query, key, value, dy, drop_mask=None, keep_prob=1.0):
        scale = 0.08838
        qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
        softmax_res, x_max, x_sum = tsoftmax(qk.to(torch.float32))
        dp = torch.matmul(dy, value.transpose(2, 3))
        if drop_mask == None or len(drop_mask.shape) == 0:
            drop_res = softmax_res
            dp_drop = dp
        else:
            drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob)
            dp_drop = dp * drop_mask * (1.0 / keep_prob)
        y = torch.matmul(drop_res, value)
        dv = torch.matmul(drop_res.transpose(2, 3), dy)
        softmax_grad_res = (tsoftmax_grad(dp_drop, softmax_res) * scale)
        dq = torch.matmul(softmax_grad_res, key)
        dk = torch.matmul(softmax_grad_res.transpose(2, 3), query)       
        dq = dq.transpose(1, 2)
        dq = dq.reshape(dq.shape[0], dq.shape[1], -1)
        dk = dk.transpose(1, 2)
        dk = dk.reshape(dk.shape[0], dk.shape[1], -1)
        dv = dv.transpose(1, 2)
        dv = dv.reshape(dv.shape[0], dv.shape[1], -1)
        return y, softmax_res, x_max, x_sum, dq, dk, dv

    def custom_op_exec(self, query, key, value, dy, softmax_max, softmax_sum, attention_in, keep_prob=1.0, numels=0, seed=2):
        scale = 0.08838
        return torch_npu.npu_fusion_attention_grad_v2(
            query, key, value, dy, head_num=32, input_layout="BSH", softmax_max=softmax_max, softmax_sum=softmax_sum,
            attention_in=attention_in, scale_value=scale, keep_prob=keep_prob, numels=numels, seed=seed)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor
    
    def get_drop_mask(self, q, B, N1, S1, S2, seed=2, gen_p=0.2):
        torch.npu.set_compile_mode(jit_compile=False)
        torch.npu.manual_seed(seed)
        drop_mask_uint8 = torch_npu._npu_dropout_gen_mask(q.npu(), [B, N1, S1, S2], p=gen_p, seed=seed, offset=0,
                                                          parallel=True, sync=False)
        drop_mask_bit_np = np.unpackbits(drop_mask_uint8.cpu().numpy(), count=B*N1*S1*S2, bitorder='little')
        drop_mask_bit = torch.from_numpy(drop_mask_bit_np).reshape([B, N1, S1, S2])
        drop_mask_bit = drop_mask_bit.detach().clone().to(torch.uint8)
        return drop_mask_bit.cpu()

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2(self, device="npu"):
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        dy = torch.randn(1, 32, 128, 128, dtype=torch.float16)

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()
        dy_npu = self.trans_BNSD2BSH(dy).npu()
        out, softmax_res, x_max, x_sum, dq_cpu, dk_cpu, dv_cpu = self.supported_op_exec(query.to(torch.float32), key.to(torch.float32), value.to(torch.float32), dy.to(torch.float32))
        x_max = x_max.expand(1, 32, 128, 8).npu()
        x_sum = x_sum.expand(1, 32, 128, 8).npu()
        out_npu = self.trans_BNSD2BSH(out).to(torch.float16).npu()
        dq, dk, dv, dpse, dq_rope, dk_rope, dsink = self.custom_op_exec(q_npu, k_npu, v_npu, dy_npu, x_max, x_sum, out_npu)
        self.assertRtolEqual(dq_cpu, dq.to(torch.float32), prec=0.005, prec16=0.005)

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2_with_dropmask(self, device="npu"):
        query = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        dy = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        keep_prob = 0.9
        numels = 1 * 32 * 256 * 256
        drop_mask = self.get_drop_mask(query, 1, 32, 256, 256, seed=2, gen_p=1-keep_prob)

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()
        dy_npu = self.trans_BNSD2BSH(dy).npu()
        out, softmax_res, x_max, x_sum, dq_cpu, dk_cpu, dv_cpu = self.supported_op_exec(
            query.to(torch.float32),key.to(torch.float32), value.to(torch.float32),
            dy.to(torch.float32), drop_mask, keep_prob)
        x_max = x_max.expand(1, 32, 256, 8).npu()
        x_sum = x_sum.expand(1, 32, 256, 8).npu()
        out_npu = self.trans_BNSD2BSH(out).to(torch.float16).npu()
        dq, dk, dv, dpse, dq_rope, dk_rope, dsink = self.custom_op_exec(q_npu, k_npu, v_npu, dy_npu, x_max, x_sum,
                                                                 out_npu, keep_prob, numels)
        self.assertRtolEqual(dq_cpu, dq.to(torch.float32), prec=0.005, prec16=0.005)

if __name__ == "__main__":
    run_tests()