from __future__ import print_function
import os
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(cur_path)[0]
sys.path.append(root_path)
import logging
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from tabulate import tabulate
from torchvision import transforms
from segmentron.data.dataloader import get_segmentation_dataset
from segmentron.models.model_zoo import get_segmentation_model
from segmentron.utils.score import SegmentationMetric
from segmentron.utils.distributed import synchronize, make_data_sampler, make_batch_data_sampler
from segmentron.config import cfg
from segmentron.utils.options import parse_args
from segmentron.utils.default_setup import default_setup
class Evaluator(object):
    """Evaluate a segmentation model on the validation split on an NPU.

    The constructor builds the ``val`` dataset/loader (``mode='testval'``),
    instantiates the model from the global ``cfg`` and moves it to the NPU,
    optionally overrides BatchNorm ``eps`` inside the model's encoder, and
    wraps the model in ``DistributedDataParallel`` when ``args.distributed``
    is set. ``eval()`` then runs inference and logs pixAcc / mIoU.
    """

    def __init__(self, args):
        """Build the data pipeline, the model and the metric.

        Args:
            args: parsed command-line namespace. Reads ``distributed`` and, in
                distributed mode, ``local_rank``. ``args.device`` is
                overwritten with the NPU device string selected here.
        """
        # Each process must live on the NPU it passes to DDP (device_ids below).
        # Hard-coding npu:0 would put every rank on the same device; a
        # single-process run keeps using npu:0 as before.
        local_rank = args.local_rank if args.distributed else 0
        args.device = "npu:{}".format(local_rank)
        self.args = args
        torch.npu.set_device(args.device)
        self.device = torch.device(args.device)

        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
        ])
        val_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='val', mode='testval',
                                               transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE,
                                                    drop_last=False)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=cfg.DATASET.WORKERS,
                                          pin_memory=True)
        self.classes = val_dataset.classes

        # No device argument: .npu() targets the device chosen by set_device above.
        self.model = get_segmentation_model().npu()
        if (hasattr(self.model, 'encoder')
                and hasattr(self.model.encoder, 'named_modules')
                and cfg.MODEL.BN_EPS_FOR_ENCODER):
            logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
            self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=True)
        self.model.npu()
        self.metric = SegmentationMetric(val_dataset.num_class)

    def set_batch_norm_attr(self, named_modules, attr, value):
        """Set ``attr = value`` on every BatchNorm2d / SyncBatchNorm module.

        Args:
            named_modules: iterable of ``(name, module)`` pairs, e.g. the
                result of ``nn.Module.named_modules()``.
            attr: attribute name to assign (used here for ``'eps'``).
            value: value to assign.
        """
        for _name, module in named_modules:
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                setattr(module, attr, value)

    def eval(self):
        """Run inference over the validation loader and log the metrics.

        Logs running pixAcc / mIoU after every batch, then the final pixAcc,
        mIoU, the elapsed wall time and a per-class IoU table.
        """
        import time

        self.metric.reset()
        self.model.eval()
        # Unwrap DDP so the underlying model's ``evaluate`` method is reachable.
        model = self.model.module if self.args.distributed else self.model

        # NOTE: len(self.val_loader) is the number of batches; it equals the
        # number of samples only when cfg.TEST.BATCH_SIZE == 1.
        logging.info("Start validation, Total sample: {:d}".format(len(self.val_loader)))
        time_start = time.time()
        for i, (image, target, _) in enumerate(self.val_loader):
            image = image.npu()
            target = target.npu()
            with torch.no_grad():
                output = model.evaluate(image)
            # Diagnostic output goes through logging instead of an
            # unconditional per-batch print to stdout.
            logging.debug('eval output shape: %s', output.shape)
            self.metric.update(output, target)
            pixAcc, mIoU = self.metric.get()
            logging.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                i + 1, pixAcc * 100, mIoU * 100))
        synchronize()

        pixAcc, mIoU, category_iou = self.metric.get(return_category_iou=True)
        logging.info('Eval use time: {:.3f} second'.format(time.time() - time_start))
        logging.info('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
            pixAcc * 100, mIoU * 100))

        headers = ['class id', 'class name', 'iou']
        # Rows hold (name, iou); showindex="always" supplies the 'class id' column.
        table = [[cls_name, category_iou[idx]] for idx, cls_name in enumerate(self.classes)]
        logging.info('Category iou: \n {}'.format(tabulate(table, headers, tablefmt='grid', showindex="always",
                                                           numalign='center', stralign='center')))
if __name__ == '__main__':
    # Script entry point: assemble the frozen test-phase config, then evaluate.
    cli_args = parse_args()

    # The YAML config file is loaded first; command-line ``opts`` are applied
    # afterwards (presumably overriding the file values).
    cfg.update_from_file(cli_args.config_file)
    cfg.update_from_list(cli_args.opts)
    cfg.PHASE = 'test'
    cfg.ROOT_PATH = root_path
    cfg.check_and_freeze()

    default_setup(cli_args)
    Evaluator(cli_args).eval()