import timeit
import time
import numpy as np
import torch
import torchvision
import torchvision.datasets
from torchvision import transforms
from PIL import Image
import torchvision_npu
# Root directory of the ImageFolder-style test data read by ImagenetHandle.
IMAGENET_PATH = "./test/Data/"

# Image geometry: the validation pipeline resizes to IMAGE_RESIZE, then both
# pipelines produce crops of IMAGE_SIZE.
IMAGE_RESIZE = 256
IMAGE_SIZE = 224

# Benchmark settings.
EPOCH = 10         # number of timed passes over each dataset per run
BATCH_SIZE = 256   # stored on ImagenetHandle; its iteration loop does not read it
class ImagenetHandle:
    """Build ImageNet-style ``ImageFolder`` datasets and time how fast they iterate.

    Each dataset is iterated directly (no ``DataLoader``), so every step decodes
    and transforms a single sample and the reported FPS is samples per second.
    """

    # Per-channel mean/std passed to ``transforms.Normalize`` in both pipelines
    # (the commonly used ImageNet statistics).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, imagenet_path):
        """Store the dataset root and the benchmark settings.

        Args:
            imagenet_path: Root directory in the layout expected by
                ``torchvision.datasets.ImageFolder`` (one sub-directory per class).
        """
        self.imagenet_path = imagenet_path
        self.img_size = IMAGE_SIZE      # output size of the (center/random) crops
        self.epoch = EPOCH              # timed passes over each dataset
        # NOTE: not read by any method of this class; datasets are iterated per sample.
        self.batch_size = BATCH_SIZE
        self.img_resize = IMAGE_RESIZE  # resize target applied before the validation center crop

    def _build_dataset(self, augmentations):
        """Return an ImageFolder applying *augmentations*, then ToTensor and Normalize."""
        pipeline = transforms.Compose([
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        return torchvision.datasets.ImageFolder(self.imagenet_path, transform=pipeline)

    def imagenet_valid_dataset(self):
        """Return the deterministic validation dataset (resize, center crop, normalize)."""
        return self._build_dataset([
            transforms.Resize(self.img_resize),
            transforms.CenterCrop(self.img_size),
        ])

    def imagenet_train_dataset(self):
        """Return the training dataset (random resized crop, random flip, normalize)."""
        return self._build_dataset([
            transforms.RandomResizedCrop(self.img_size),
            transforms.RandomHorizontalFlip(),
        ])

    @staticmethod
    def _pass_fps(dataset):
        """Iterate *dataset* once and return samples per second for that pass."""
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _sample in dataset:
            pass  # only the cost of loading + transforming each sample is measured
        return len(dataset) / (time.perf_counter() - start)

    def trans_dataset(self):
        """Time ``self.epoch`` passes over the train and valid datasets.

        Returns:
            ``(train_fps, valid_fps)``: mean samples-per-second per pass for each dataset.

        Raises:
            ValueError: if ``self.epoch`` is not positive.
        """
        if self.epoch <= 0:
            raise ValueError(f"epoch must be positive, got {self.epoch}")
        train_dataset = self.imagenet_train_dataset()
        valid_dataset = self.imagenet_valid_dataset()
        train_total = 0.0
        valid_total = 0.0
        for _ in range(self.epoch):
            # Train and valid passes alternate within each epoch.
            train_total += self._pass_fps(train_dataset)
            valid_total += self._pass_fps(valid_dataset)
        train_fps = train_total / self.epoch
        valid_fps = valid_total / self.epoch
        print('train data {:.4f} FPS , valid data {:.4f} FPS'.format(train_fps, valid_fps))
        return train_fps, valid_fps
def test_cv2_accelerate():
    """Check the ``cv2`` image backend iterates the train dataset faster than ``PIL``.

    The ``cv2`` backend is presumably made available by ``torchvision_npu``.
    Both runs reseed torch with the same value so the random train augmentations
    start from the same RNG state. The global torchvision image backend is
    restored afterwards so later tests are not affected.
    """
    previous_backend = torchvision.get_image_backend()
    try:
        torchvision.set_image_backend("PIL")
        torch.manual_seed(10)
        pil_train_fps, _pil_valid_fps = ImagenetHandle(IMAGENET_PATH).trans_dataset()

        torchvision.set_image_backend("cv2")
        torch.manual_seed(10)
        cv2_train_fps, _cv2_valid_fps = ImagenetHandle(IMAGENET_PATH).trans_dataset()
    finally:
        # set_image_backend mutates process-global state; undo it.
        torchvision.set_image_backend(previous_backend)

    assert pil_train_fps < cv2_train_fps, (
        f"cv2 backend ({cv2_train_fps:.4f} FPS) was not faster than "
        f"PIL ({pil_train_fps:.4f} FPS)"
    )