from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import transforms, datasets, models
import os
# Select the NPU extension module by torch version: `torch_npu` when the
# version compares >= "1.8", otherwise the legacy `torch.npu` module.
# NOTE(review): the right-hand side is a plain string. The test is only
# version-aware when torch.__version__ is a TorchVersion (torch >= 1.10);
# on older builds it is a lexicographic str comparison — confirm for the
# torch builds this script targets.
# Neither imported name is referenced again in the code shown; the import
# presumably matters only for its side effects (registering the backend).
if torch.__version__ >= "1.8":
    import torch_npu
else:
    import torch.npu
import time
from collections import OrderedDict
from model.residual_attention_network import ResidualAttentionModel_92_32input_update as ResidualAttentionModel
def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Keys that do not start with *prefix* are kept unchanged, and the original
    key order is preserved. The "module." prefix presumably comes from a
    checkpoint saved while the model was wrapped (e.g. nn.DataParallel) —
    confirm against the training script.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def main(model_file='model_92_sgd.pkl', data_root='./data/', batch_size=512):
    """Evaluate the pretrained Residual Attention model on the CIFAR-10 test split.

    Builds a ``ResidualAttentionModel`` with 10 outputs and loads the weights
    from *model_file*. Then it prints one line per test image with the true
    label and the predicted class index.

    Args:
        model_file: Checkpoint path passed to ``torch.load``. It is expected
            to hold a state dict whose keys may carry a "module." prefix.
        data_root: CIFAR-10 root directory. No download is requested here,
            so the dataset must already be present.
        batch_size: Number of images per forward pass.
    """
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    test_dataset = datasets.CIFAR10(root=data_root, train=False, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    model = ResidualAttentionModel(10)
    print('Loading base network...')
    # torch.load unpickles the file, so only load checkpoints you trust.
    # map_location keeps the weights on CPU regardless of where they were saved.
    base_weights = torch.load(model_file, map_location="cpu")
    model.load_state_dict(_strip_module_prefix(base_weights))
    model.eval()

    cnt = 0
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            # Convert the tensors to Python lists once per batch instead of
            # indexing each element. A 0-d tensor formats as its .item(), so
            # the printed text is unchanged.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                cnt += 1
                print(f"Image{cnt} real_class: {label} pred_clss: {pred}")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()