import logging
from typing import List, Optional
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_npu
from torch_npu.utils._error_code import ErrCode, ops_error
from torch_npu.contrib.module._ensemble_dropout import NpuPreGenDropout

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Names exported by ``from <module> import *``; both refer to the same class.
__all__ = [
    "NpuCachedDropout",
    "NpuFairseqDropout"
]


class _DropOutTask:
    """Bookkeeping record for one dropout-mask configuration.

    Stores the parameters that describe a mask (``shape``, ``dtype``,
    ``device``, ``p``) together with a request counter and the list of
    entries that have been queued for this configuration.
    """

    def __init__(self, shape, dtype, device, p):
        # Parameters describing the masks that belong to this task.
        self.shape, self.dtype, self.device, self.p = shape, dtype, device, p
        # A new task starts with nothing requested and nothing queued.
        self.request_count = 0
        self.mask_queue = list()
                
                        
class NpuCachedDropout(torch.nn.Dropout):
    r"""Dropout for NPU devices that consumes pre-generated masks.

    Masks are not produced inside :meth:`forward`. Each call that finds no
    queued mask records a request for its ``(shape, dtype, device, p)``
    combination in the class-level ``task_dict``. The hooks installed by
    :meth:`enable_dropout_ensemble` generate the requested number of masks
    after every forward pass of the hooked model, and later calls pop them
    from the queue.

    While no mask is queued for a combination (always the case on the first
    call, or when the hooks are not installed) the input is returned
    unchanged, i.e. no dropout is applied for that call.

    .. note::
        Dynamic shapes are not supported.

    Args:
        p (float): probability of an element to be zeroed.
        module_name (string): the name of the model
    """

    # Shared by every instance: maps (shape, dtype, device, p) -> _DropOutTask.
    task_dict = {}
    # Side stream used for mask generation; created lazily by
    # enable_dropout_ensemble().
    dropout_stream = None

    def __init__(self, p, module_name=None):
        super().__init__(p)
        self.module_name = module_name

    def forward(self, x):
        """Apply a cached dropout mask, or hand the mask itself back.

        Args:
            x: either a tensor to apply dropout to, or a three-element list
                ``[shape, dtype, device]`` describing the mask wanted.

        Returns:
            For a tensor input: the tensor after dropout, or ``x`` itself
            when ``p == 0``, the module is in eval mode, or no mask is queued
            yet. For a list input: the raw mask, or ``None`` in those same
            pass-through situations.

        Raises:
            RuntimeError: if ``x`` is neither a tensor nor a list.
        """
        if isinstance(x, torch.Tensor):
            shape, dtype, device = x.shape, x.dtype, x.device
            apply_mask = True
            passthrough = x
        elif isinstance(x, list):
            shape, dtype, device = x
            # A list-typed shape is unhashable and would make the cache-key
            # lookup below raise TypeError; a tuple compares equal to the
            # torch.Size produced by the tensor branch.
            if isinstance(shape, list):
                shape = tuple(shape)
            apply_mask = False
            passthrough = None
        else:
            raise RuntimeError("input type error!" + ops_error(ErrCode.TYPE))

        # Dropout is the identity when nothing can be dropped and, like
        # nn.Dropout, whenever the module is in evaluation mode. Returning
        # before the bookkeeping also keeps eval passes from requesting masks.
        if self.p == 0 or not self.training:
            return passthrough

        key = (shape, dtype, device, self.p)
        task = NpuCachedDropout.task_dict.get(key)
        if task is None:
            task = _DropOutTask(shape, dtype, device, self.p)
            NpuCachedDropout.task_dict[key] = task

        if not task.mask_queue:
            # Nothing pre-generated yet: ask for one more mask per forward
            # pass and skip dropout for this call.
            task.request_count += 1
            return passthrough

        mask, _ = task.mask_queue.pop(0)
        if apply_mask:
            return torch_npu.npu_dropout_do_mask(x, mask, self.p)[0]
        return mask

    @classmethod
    def enable_dropout_ensemble(cls, model):
        """Install the hooks on ``model`` that pre-generate dropout masks.

        A forward pre-hook makes the current stream wait for the dropout
        stream, so masks generated earlier are complete before the model
        runs. A forward hook then tops every task's queue up to its
        ``request_count`` on the dropout stream.

        Args:
            model: module to register the forward pre-hook and forward hook on.
        """
        if cls.dropout_stream is None:
            cls.dropout_stream = torch.npu.Stream()

        def wait_for_masks(module, inputs):
            torch.npu.current_stream().wait_stream(cls.dropout_stream)

        def refill_mask_queues(module, inputs, outputs):
            # Generate on the side stream: this is the work the pre-hook
            # waits for. Without switching streams the dedicated stream would
            # stay empty and generation would run on the caller's stream.
            # NOTE(review): masks are allocated on the side stream and used
            # on the caller's stream; confirm whether the NPU caching
            # allocator needs record_stream() for that hand-over.
            with torch.npu.stream(cls.dropout_stream), torch.no_grad():
                for task in cls.task_dict.values():
                    missing = task.request_count - len(task.mask_queue)
                    # range() of a non-positive count is empty, so a full
                    # queue is left untouched.
                    for _ in range(missing):
                        mask = torch_npu.npu_dropout_gen_mask(
                            task.shape, p=task.p, dtype=task.dtype,
                            device=task.device)
                        # Queue entries are (mask, event) pairs; no event is
                        # recorded.
                        task.mask_queue.append((mask, None))

        model.register_forward_pre_hook(wait_for_masks)
        model.register_forward_hook(refill_mask_queues)


# Alias: the same class is also exposed under the name ``NpuFairseqDropout``.
NpuFairseqDropout = NpuCachedDropout