import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

PER_BLOCK_SIZE = 128
FP8_MAX = 448.0

class TestNPUFlashAttentionV2(TestCase):
    def supported_op_exec(self, query, key, value, drop_mask=None, keep_prob=1.0):
        scale = 0.08838
        qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
        softmax_res = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        if drop_mask == None or len(drop_mask.shape) == 0:
            drop_res = softmax_res
        else:
            drop_res = softmax_res * drop_mask * (1.0 / keep_prob)
        output = torch.matmul(drop_res, value)
        output = output.transpose(1, 2)
        output = output.reshape(output.shape[0], output.shape[1], -1)
        return output

    def custom_op_exec(self, query, key, value, keep_prob=1.0):
        scale = 0.08838
        return torch_npu.npu_fusion_attention_v2(
            query, key, value, head_num=32, input_layout="BSH", scale=scale, keep_prob=keep_prob)

    def custom_op_exec_with_dropout_mask(self, query, key, value, dropout_mask, seed=0, offset=0, keep_prob=1.0):
        scale = 0.08838
        return torch_npu.npu_fusion_attention_v2(
            query, key, value, head_num=32, input_layout="BSH", scale=scale, keep_prob=keep_prob,
            dropout_mask=dropout_mask, seed=seed, offset=offset)

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        tensor = torch.transpose(tensor, 1, 2)
        tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
        return tensor

    def get_drop_mask(self, q, B, N1, S1, S2, seed=2, gen_p=0.2):
        torch.npu.set_compile_mode(jit_compile=False)
        torch.npu.manual_seed(seed)
        drop_mask_uint8 = torch_npu._npu_dropout_gen_mask(q.npu(), [B, N1, S1, S2], p=gen_p, seed=seed, offset=0,
                                                          parallel=True, sync=False)
        drop_mask_bit_np = np.unpackbits(drop_mask_uint8.cpu().numpy(), count=B*N1*S1*S2, bitorder='little')
        drop_mask_bit = torch.from_numpy(drop_mask_bit_np).reshape([B, N1, S1, S2])
        drop_mask_bit = drop_mask_bit.detach().clone().to(torch.uint8)
        return drop_mask_bit.cpu()

    def get_drop_mask_for_npu(self, B, N1, S1, S2, seed=2, gen_p=0.2, device="npu"):
        torch.npu.set_compile_mode(jit_compile=False)
        torch.npu.manual_seed(seed)
        shape = [B, N1, S1, S2]
        drop_mask_uint8 = torch_npu._npu_dropout_gen_mask(
            torch.randn(1, device=device), shape, p=gen_p, seed=seed, offset=0,
            parallel=True, sync=False)
        return drop_mask_uint8


    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2(self, device="npu"):
        query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 128, 128, dtype=torch.float16)

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()
        output = self.supported_op_exec(query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)).to(torch.float16)
        attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels = self.custom_op_exec(q_npu, k_npu, v_npu)
        self.assertRtolEqual(output, attention_score)

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2_with_dropmask(self, device="npu"):
        query = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        key = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        value = torch.randn(1, 32, 256, 128, dtype=torch.float16)
        keep_prob = 0.9
        drop_mask = self.get_drop_mask(query, 1, 32, 256, 256, seed=2, gen_p=1-keep_prob)

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()
        output = self.supported_op_exec(query.to(torch.float32), key.to(torch.float32), value.to(torch.float32),
                                        drop_mask, keep_prob).to(torch.float16)
        attention_score, _, _, _, _, _, _ = self.custom_op_exec(q_npu, k_npu, v_npu, keep_prob)
        self.assertRtolEqual(output, attention_score)

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2_with_external_dropout_mask(self, device="npu"):
        B, N, S, D = 1, 32, 128, 128
        query = torch.randn(B, N, S, D, dtype=torch.float16)
        key = torch.randn(B, N, S, D, dtype=torch.float16)
        value = torch.randn(B, N, S, D, dtype=torch.float16)
        keep_prob = 0.8
        seed = 123
        offset = 0

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()

        dropout_mask = self.get_drop_mask_for_npu(B, N, S, S, seed=seed, gen_p=1-keep_prob, device=device)

        attention_score, softmax_max, softmax_sum, softmax_out, out_seed, out_offset, numels = \
            self.custom_op_exec_with_dropout_mask(q_npu, k_npu, v_npu, dropout_mask, seed=seed, offset=offset, keep_prob=keep_prob)

        self.assertEqual(attention_score.shape, q_npu.shape)
        self.assertEqual(softmax_max.shape[0], B)
        self.assertEqual(softmax_max.shape[1], N)
        self.assertEqual(softmax_sum.shape[0], B)
        self.assertEqual(softmax_sum.shape[1], N)

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2_dropout_mask_reproducibility(self, device="npu"):
        B, N, S, D = 1, 32, 128, 128
        query = torch.randn(B, S, N * D, dtype=torch.float16).npu()
        key = torch.randn(B, S, N * D, dtype=torch.float16).npu()
        value = torch.randn(B, S, N * D, dtype=torch.float16).npu()
        keep_prob = 0.9
        seed = 456
        offset = 0

        dropout_mask = self.get_drop_mask_for_npu(B, N, S, S, seed=seed, gen_p=1-keep_prob, device=device)

        result1, _, _, _, _, _, _ = self.custom_op_exec_with_dropout_mask(
            query, key, value, dropout_mask, seed=seed, offset=offset, keep_prob=keep_prob)

        result2, _, _, _, _, _, _ = self.custom_op_exec_with_dropout_mask(
            query, key, value, dropout_mask, seed=seed, offset=offset, keep_prob=keep_prob)

        self.assertRtolEqual(result1, result2)

    @SupportedDevices(['Ascend910B'])
    def test_npu_flash_attention_v2_without_dropout_mask(self, device="npu"):
        B, N, S, D = 1, 32, 128, 128
        query = torch.randn(B, N, S, D, dtype=torch.float16)
        key = torch.randn(B, N, S, D, dtype=torch.float16)
        value = torch.randn(B, N, S, D, dtype=torch.float16)
        keep_prob = 1.0
        seed = 0
        offset = 0

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()

        attention_score, softmax_max, softmax_sum, softmax_out, out_seed, out_offset, numels = \
            self.custom_op_exec_with_dropout_mask(
                q_npu, k_npu, v_npu, dropout_mask=None, seed=seed, offset=offset, keep_prob=keep_prob)

        self.assertEqual(attention_score.shape, q_npu.shape)

    @SupportedDevices(['Ascend950'])
    def test_npu_flash_attention_v2_dropout_mask_with_offset_seed(self, device="npu"):
        B, N, S, D = 1, 32, 128, 128
        query = torch.randn(B, N, S, D, dtype=torch.float16)
        key = torch.randn(B, N, S, D, dtype=torch.float16)
        value = torch.randn(B, N, S, D, dtype=torch.float16)
        keep_prob = 0.8
        seed = 100
        offset = 200

        q_npu = self.trans_BNSD2BSH(query).npu()
        k_npu = self.trans_BNSD2BSH(key).npu()
        v_npu = self.trans_BNSD2BSH(value).npu()

        attention_score, softmax_max, softmax_sum, softmax_out, out_seed, out_offset, numels = \
            self.custom_op_exec_with_dropout_mask(
                q_npu, k_npu, v_npu, dropout_mask=None, seed=seed, offset=offset, keep_prob=keep_prob)

        self.assertEqual(seed, out_seed)
        self.assertEqual(offset, out_offset)

if __name__ == "__main__":
    run_tests()