# b686393d - created on 2024-11-21 (commit history)
import os
import unittest
from pathlib import Path
import torch_npu
import torchvision
from torchvision import datasets, transforms
import torchvision_npu


# Runs at import time, before any test case is constructed.
# NOTE(review): presumably this marks the current NPU stream as the one used
# for data preprocessing - confirm against the torch_npu documentation.
_current_stream = torch_npu.npu.current_stream()
_current_stream.set_data_preprocess_stream(True)

# Second-level parent of this file's resolved path; the tests below look for
# a "Data" directory directly underneath it.
_this_file = Path(__file__).resolve()
TEST_DIR = _this_file.parents[1]


class TestAccelerate(unittest.TestCase):
    """Check the ``accelerate_enable`` and ``device`` attributes of ``ImageFolder``.

    Each test selects an image backend through ``torchvision.set_image_backend``
    and then builds a ``datasets.ImageFolder`` over the ``Data`` directory that
    lives under ``TEST_DIR``.
    """

    def setUp(self):
        """Resolve the dataset directory and schedule restoring the image backend."""
        self.data_path = os.path.join(TEST_DIR, "Data")
        # The image backend is a global setting. Capture the value that was
        # active before the test and put it back afterwards, so the backend
        # chosen by one test cannot leak into whichever test runs next.
        self.addCleanup(torchvision.set_image_backend, torchvision.get_image_backend())

    def test_accelerate_npu(self):
        """With the 'npu' backend the dataset enables acceleration off the CPU."""
        torchvision.set_image_backend('npu')
        train_datasets = datasets.ImageFolder(self.data_path, transforms.RandomHorizontalFlip())
        self.assertTrue(train_datasets.accelerate_enable)
        self.assertNotEqual(train_datasets.device, "cpu")

    def test_accelerate_cpu(self):
        """With the 'PIL' backend the dataset stays on the CPU without acceleration."""
        torchvision.set_image_backend('PIL')
        train_datasets = datasets.ImageFolder(self.data_path)
        self.assertFalse(train_datasets.accelerate_enable)
        self.assertEqual(train_datasets.device, "cpu")


# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()