import threading
import torch
from torch.utils.data import Dataset

from torch_npu.testing.testcase import TestCase, run_tests


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.data[index].clone()


def thread_worker(data, device, results, index):
    for _ in range(50):
        data = data.to(device, non_blocking=True)
        results.append(data)


class TestModel:

    def run(self):
        torch.npu.set_device(0)
        device = torch.device(f'npu:{0}')

        dataset = RandomDataset(1000, 1000)

        results = [None] * len(dataset)

        threads = []
        batch_size = 64
        for i in range(0, len(dataset), batch_size):
            batch = dataset[i:i + batch_size]
            thread = threading.Thread(target=thread_worker, args=(batch, device, results, i))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()


class AllocatorMultiThreadProf(TestCase):
    test_model = TestModel()

    def test_model_run_succ(self):
        res = True
        try:
            self.test_model.run()
        except Exception as e:
            res = False
        self.assertEqual(res, True)


if __name__ == "__main__":
    run_tests()