import time
import argparse
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.multiprocessing import spawn
import torchvision.transforms as transforms
import torchbiomed.datasets as dset
import os
import shutil
import vnet
from apex import amp
import apex
# Dataset component names. ``ct_images`` and ``ct_targets`` are handed to
# ``dset.LUNA16`` as the image and target sources; ``nodule_masks`` is defined
# but not selected as the target in this script.
nodule_masks, lung_masks, ct_images = (
    "normalized_nodule_mask",
    "normalized_lung_mask",
    "normalized_lung_ct",
)
# Score predictions against the lung masks.
ct_targets = lung_masks
# Forwarded unchanged as ``split=`` to the dataset constructor.
target_split = [2] * 3
def _parse_args():
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=4)
    parser.add_argument('--device', type=str, default='gpu')
    # NOTE(review): --device_num is parsed but not used anywhere in this script.
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--opt_level', type=str, default='O2')
    parser.add_argument('--data', type=str, default='/opt/npu/dataset/luna16')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    return parser.parse_args()


def _resolve_device(device):
    """Return the torch device string for device index 0 of *device*.

    ``torch.device`` has no ``"gpu"`` type, so the flag's default value
    ``"gpu"`` is translated to ``"cuda"``. Every other value (``"cpu"``,
    ``"cuda"``, ``"npu"``, ...) is passed through unchanged.
    """
    backend = 'cuda' if device == 'gpu' else device
    return backend + ':0'


def _load_model(args):
    """Build the V-Net on ``args.device_id`` and load the ``args.resume`` checkpoint.

    Sets ``args.start_epoch`` from the checkpoint. Returns the model, or
    ``None`` after printing a message when ``args.resume`` is not a file.
    """
    model = vnet.VNet(elu=False, nll=True).to(args.device_id)
    if args.amp:
        # Evaluation only: no optimizer is handed to amp.
        model = amp.initialize(model, opt_level=args.opt_level)
    if not os.path.isfile(args.resume):
        print("=> no checkpoint found at '{}'".format(args.resume))
        return None
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location=args.device_id)
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    return model


def _build_test_loader(args):
    """Create a non-shuffled DataLoader over the LUNA16 ``"test"`` split."""
    # Fixed mean/std normalisation constants applied after ToTensor.
    norm_transform = transforms.Normalize([-642.794], [459.512])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        norm_transform,
    ])
    test_set = dset.LUNA16(root=args.data, images=ct_images, targets=ct_targets,
                           mode="test", transform=test_transform, masks=None,
                           split=target_split)
    return DataLoader(test_set, batch_size=args.batchSz, shuffle=False,
                      num_workers=4, sampler=None)


def _count_errors(model, loader, device):
    """Return ``(incorrect, numel)`` summed over every batch of *loader*.

    ``numel`` is the total number of flattened target elements, and
    ``incorrect`` is how many of them the model's arg-max prediction misses.
    """
    model.eval()
    incorrect = 0
    numel = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device).view(-1)
            numel += target.numel()
            # One row of two class scores per target element; predict the arg-max.
            pred = model(data).view(-1, 2).argmax(dim=1)
            incorrect += pred.ne(target).sum().item()
    return incorrect, numel


def main():
    """Evaluate a V-Net checkpoint on the LUNA16 test split and print the error rate.

    Returns early, without evaluating, when ``--resume`` does not point to an
    existing file.
    """
    args = _parse_args()
    # The default --device 'gpu' is not a valid torch device type, so map it to 'cuda'.
    args.device_id = _resolve_device(args.device)
    model = _load_model(args)
    if model is None:
        return
    test_loader = _build_test_loader(args)
    incorrect, numel = _count_errors(model, test_loader, args.device_id)
    if numel == 0:
        # Avoid ZeroDivisionError when the test split is empty.
        print('Error rate: no test samples were evaluated\n')
        return
    err = 100. * incorrect / numel
    print('Error rate: {}/{} ({:.4f}%)\n'.format(incorrect, numel, err))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()