import unittest
import torch
import torch.nn as nn
try:
    import torchair
    from torchair.configs.compiler_config import CompilerConfig
    HAS_TORCHAIR = True
except ImportError:
    HAS_TORCHAIR = False

import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests

# Size, in bytes, of the int8 buffer handed to the scheduler as its attention window
# (209715200 == 200 * 1024 * 1024).
window_size = 209715200
# NOTE(review): this buffer is allocated on the NPU at import time, so merely
# importing this module requires torch_npu and a working NPU device.
attn_window_tensor = torch.zeros([window_size], dtype=torch.int8).npu()

attn_workers = 144
micro_batch_number = 3
batch_size = 30
top_k = 8
hidden_size = 7168
expert_num = 288
# Per-token payload sizes are derived from hidden_size instead of repeating the
# literal, so changing hidden_size keeps them consistent.
# attn -> ffn: hidden_size plus 4 extra units, rounded up to a multiple of 512
# (presumably bytes, with 512 an alignment requirement of the op -- TODO confirm).
attn_to_ffn_token_size = (hidden_size + 4 + 511) // 512 * 512
# ffn -> attn: presumably hidden_size 16-bit values expressed in bytes -- TODO confirm.
ffn_to_attn_token_size = hidden_size * 2
# Raw address of the window buffer; passed as `attention_window` to
# create_schedule_context_holder.
attn_window = attn_window_tensor.data_ptr()


def _set_all_flags():
    """Write int32 value 1 into the flag region at the start of the attention window.

    The region holds one int32 per (micro batch, batch entry, selected expert)
    triple, where the selected-expert count is ``top_k + 1``. It is written through
    an int32 view of the leading bytes of the int8 ``attn_window_tensor``.
    """
    flag_count = batch_size * (top_k + 1) * micro_batch_number
    # Each flag is an int32, i.e. 4 bytes of the int8 window buffer.
    flag_bytes = flag_count * 4
    attn_window_tensor.narrow(0, 0, flag_bytes).view(torch.int32).fill_(1)


class TestModelInplace(nn.Module):
    """Parameter-free module wrapping ``torch_npu._afd.attention_worker_scheduler_``.

    It exists so the in-place op can be traced by ``torch.compile``. ``forward``
    passes the schedule context straight to the op and returns None; any update
    to the context is expected to happen in place.
    """

    def forward(self, schedule_context):
        # In-place variant (trailing underscore): no result is propagated.
        torch_npu._afd.attention_worker_scheduler_(schedule_context)


class TestModel(nn.Module):
    """Parameter-free module wrapping ``torch_npu._afd.attention_worker_scheduler``.

    It exists so the out-of-place op can be traced by ``torch.compile``. ``forward``
    returns whatever the op returns for the given schedule context.
    """

    def forward(self, schedule_context):
        updated_context = torch_npu._afd.attention_worker_scheduler(schedule_context)
        return updated_context


@unittest.skipUnless(HAS_TORCHAIR, "torchair is not available")
class TestAttentionWorkerScheduler(TestCase):
    """Tests for ``torch_npu._afd.attention_worker_scheduler`` and its in-place variant.

    Each op is exercised eagerly and through ``torch.compile`` with the torchair
    NPU backend. The in-place variant must change the schedule context it is
    given. The out-of-place variant must leave its input unchanged and return a
    tensor that differs from it.
    """

    def setUp(self):
        super().setUp()
        # The holder is kept on self because the schedule context tensor is
        # obtained from it (NOTE(review): presumably the holder owns that
        # memory -- confirm before dropping the reference).
        self.context_holder = torch_npu._afd.create_schedule_context_holder(
            schedule_mode=1,
            session_num=attn_workers,
            micro_batch_num=micro_batch_number,
            micro_batch_size=batch_size,
            selected_expert_num=top_k + 1,
            expert_num=expert_num,
            attn_to_ffn_token_size=attn_to_ffn_token_size,
            ffn_to_attn_token_size=ffn_to_attn_token_size,
            attention_window=attn_window,
            attention_window_size=window_size,
        )
        self.schedule_context = self.context_holder.get_schedule_context_tensor()
        # The attention window is a module-level buffer shared by every test,
        # so its flag region is re-initialised here before each test runs.
        _set_all_flags()

    def _compile_for_npu(self, module):
        """Move *module* to the NPU and compile it with the torchair NPU backend.

        ``torch._dynamo.reset`` is registered as a cleanup, so compiled state is
        discarded even when an assertion in the calling test fails. Otherwise
        that state would leak into later tests.
        """
        self.addCleanup(torch._dynamo.reset)
        npu_backend = torchair.get_npu_backend(compiler_config=CompilerConfig())
        return torch.compile(module.npu(), backend=npu_backend)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_attention_worker_scheduler_(self):
        """Eager in-place op: the schedule context must be modified."""
        before = self.schedule_context.clone()
        torch_npu._afd.attention_worker_scheduler_(self.schedule_context)
        self.assertNotEqual(before, self.schedule_context)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_attention_worker_scheduler(self):
        """Eager out-of-place op: the input is untouched and the result differs from it."""
        before = self.schedule_context.clone()
        result = torch_npu._afd.attention_worker_scheduler(self.schedule_context)
        self.assertEqual(before, self.schedule_context)
        self.assertNotEqual(result, self.schedule_context)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_attention_worker_scheduler__graph(self):
        """Compiled in-place op: the schedule context must be modified."""
        model = self._compile_for_npu(TestModelInplace())
        before = self.schedule_context.clone()
        model(self.schedule_context)
        self.assertNotEqual(before, self.schedule_context)

    @unittest.skip("skip case until cann supported")
    @SupportedDevices(['Ascend910B'])
    def test_attention_worker_scheduler_graph(self):
        """Compiled out-of-place op: the input is untouched and the result differs from it."""
        model = self._compile_for_npu(TestModel())
        before = self.schedule_context.clone()
        result = model(self.schedule_context)
        self.assertEqual(before, self.schedule_context)
        self.assertNotEqual(result, self.schedule_context)


if __name__ == "__main__":
    # Run this module's tests through torch_npu's test runner.
    run_tests()