import torch


class Prefetcher:
    """Stage the next ``(input, target)`` batch on the NPU ahead of use.

    One batch is always kept ready: the constructor loads the first batch,
    and every call to :meth:`next` hands out the staged batch and immediately
    starts loading the following one. Host-to-device copies are issued with
    ``non_blocking=True`` on a dedicated side stream (``self.stream``).

    Args:
        loader (torch.utils.data.DataLoader or DataLoader like iterable):
            Used to generate inputs after preprocessing. Every item must
            unpack into exactly two objects, ``(input, target)``, each of
            which provides ``.npu(non_blocking=True)``.
        stream (torch.npu.Stream): Default None, in which case a new
            ``torch.npu.Stream()`` is created.
            Because of the limitation of NPU's memory mechanism,
            if prefetcher is initialized repeatedly during training,
            a defined stream should be introduced to prevent memory leakage;
            if prefetcher is initialized only once during training,
            a defined stream is not necessary.

    Example::

        prefetcher = Prefetcher(loader)
        inputs, targets = prefetcher.next()
        while inputs is not None:
            ...
            inputs, targets = prefetcher.next()

        # or, equivalently, through the iterator protocol
        for inputs, targets in Prefetcher(loader):
            ...
    """

    def __init__(self, loader, stream=None):
        self.loader = iter(loader)
        self.stream = stream if stream is not None else torch.npu.Stream()
        # Stage the first batch right away so the first next() has data.
        self.preload()

    def __iter__(self):
        """Return ``self``; the prefetcher is a single-pass iterator."""
        return self

    def __next__(self):
        """Return the next ``(input, target)`` pair.

        Raises:
            StopIteration: When the underlying loader is exhausted.
        """
        # Same exhaustion marker that next() checks; testing it first keeps
        # an exhausted prefetcher from touching the device at all.
        if self.next_target is None:
            raise StopIteration
        return self.next()

    def preload(self):
        """Fetch one batch from the loader and start copying it to the NPU.

        On exhaustion of the loader, ``next_input`` and ``next_target`` are
        both set to ``None`` and no device work is issued.
        """
        try:
            batch = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        host_input, host_target = batch

        # Issue the copies on the side stream; next() makes the consumer's
        # stream wait on it before the tensors are handed out.
        with torch.npu.stream(self.stream):
            self.next_input = host_input.npu(non_blocking=True)
            self.next_target = host_target.npu(non_blocking=True)

    def next(self):
        """Return the staged batch and begin loading the following one.

        Returns:
            tuple: ``(input, target)`` as produced by ``.npu()`` in
            :meth:`preload`, or ``(None, None)`` once the loader is
            exhausted.
        """
        # Order the current stream after the pending copies on self.stream.
        torch.npu.current_stream().wait_stream(self.stream)
        # NOTE(review): the CUDA version of this pattern usually also calls
        # ``record_stream`` on the returned tensors; confirm whether that is
        # needed (and safe) on NPU before adding it.
        batch_input = self.next_input
        batch_target = self.next_target
        if batch_target is not None:
            self.preload()
        return batch_input, batch_target