from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import pprint
from tqdm import tqdm
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms
import torch.multiprocessing
from tqdm import tqdm
import _init_paths
import models
from config import cfg
from config import check_config
from config import update_config
from core.inference import get_multi_stage_outputs
from core.inference import aggregate_results
from core.group import HeatmapParser
from dataset import make_test_dataloader
from utils.utils import create_logger
from utils.utils import get_model_summary
from utils.vis import save_debug_images
from utils.vis import save_valid_image
from utils.transforms import resize_align_multi_scale
from utils.transforms import get_final_preds
from utils.transforms import get_multi_scale_size
# Share tensors between processes (e.g. DataLoader workers) via the
# 'file_system' strategy instead of the platform default.
# NOTE(review): presumably chosen to avoid exhausting open file descriptors
# under the default 'file_descriptor' strategy on Linux — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
def parse_args():
    """Parse the command line of the keypoint evaluation script.

    Returns:
        argparse.Namespace with:
          * ``cfg``  -- path of the experiment config file (required);
          * ``opts`` -- every remaining command-line token, collected
            verbatim (config options to modify, per the help text).
    """
    cli = argparse.ArgumentParser(description='Test keypoints network')
    cli.add_argument(
        '--cfg',
        type=str,
        required=True,
        help='experiment configure file name',
    )
    cli.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    return cli.parse_args()
def _print_name_value(logger, name_value, full_arch_name):
names = name_value.keys()
values = name_value.values()
num_values = len(name_value)
logger.info(
'| Arch ' +
' '.join(['| {}'.format(name) for name in names]) +
' |'
)
logger.info('|---' * (num_values+1) + '|')
if len(full_arch_name) > 15:
full_arch_name = full_arch_name[:8] + '...'
logger.info(
'| ' + full_arch_name + ' ' +
' '.join(['| {:.3f}'.format(value) for value in values]) +
' |'
)
def _build_model(cfg, logger, final_output_dir):
    """Create the network named by cfg.MODEL.NAME, load weights, prepare for eval.

    The model summary is logged before the weights are loaded. Weights come
    from cfg.TEST.MODEL_FILE when it is set, otherwise from
    ``<final_output_dir>/model_best.pth.tar``. The model is then wrapped in
    DataParallel over cfg.GPUS, moved with ``.npu()`` and put in eval mode.

    Returns:
        The wrapped model, in eval mode.
    """
    # NOTE(review): eval() on a config-provided name executes arbitrary code
    # if the config is untrusted; a getattr lookup on `models` would suffice.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE)
    )
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.TEST.MODEL_FILE:
        model_state_file = cfg.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth.tar'
        )
    logger.info('=> loading model from {}'.format(model_state_file))
    # Both sources are loaded onto the CPU first so that a checkpoint saved
    # from an accelerator device can be restored regardless of where it was
    # produced. NOTE: torch.load unpickles the file — only load trusted
    # checkpoints.
    state_dict = torch.load(model_state_file, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict, strict=True)

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).npu()
    model.eval()
    return model


def _build_transforms(cfg):
    """Return the per-image preprocessing pipeline.

    'pose_hourglass' only converts the image to a tensor; every other model
    additionally normalizes with the mean/std constants below (the usual
    ImageNet statistics).
    """
    steps = [torchvision.transforms.ToTensor()]
    if cfg.MODEL.NAME != 'pose_hourglass':
        steps.append(
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        )
    return torchvision.transforms.Compose(steps)


def _infer_image(cfg, model, parser, transforms, image):
    """Run multi-scale inference on one image and group its keypoints.

    Args:
        cfg: experiment config (reads DATASET.INPUT_SIZE and several TEST.*
            fields).
        model: network returned by ``_build_model``.
        parser: ``HeatmapParser`` used to group detections.
        transforms: preprocessing pipeline from ``_build_transforms``.
        image: the single image of the batch, as produced in ``main`` by
            ``images[0].cpu().numpy()``.

    Returns:
        ``(final_results, scores)``: the output of ``get_final_preds``, and
        the scores returned by ``parser.parse``.
    """
    scale_factors = cfg.TEST.SCALE_FACTOR
    min_scale = min(scale_factors)
    input_size = cfg.DATASET.INPUT_SIZE

    # Image size at scale 1.0; passed to every get_multi_stage_outputs call.
    base_size, center, scale = get_multi_scale_size(
        image, input_size, 1.0, min_scale
    )

    with torch.no_grad():
        final_heatmaps = None
        tags_list = []
        # Largest scale first. center/scale are overwritten on every pass,
        # so get_final_preds below uses those of the last (smallest) scale.
        for s in sorted(scale_factors, reverse=True):
            image_resized, center, scale = resize_align_multi_scale(
                image, input_size, s, min_scale
            )
            image_resized = transforms(image_resized)
            image_resized = image_resized.unsqueeze(0).npu()

            _, heatmaps, tags = get_multi_stage_outputs(
                cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                cfg.TEST.PROJECT2IMAGE, base_size
            )
            final_heatmaps, tags_list = aggregate_results(
                cfg, s, final_heatmaps, tags_list, heatmaps, tags
            )

        # Divide by the number of scales (aggregate_results presumably sums
        # the per-scale heatmaps — confirm).
        final_heatmaps = final_heatmaps / float(len(scale_factors))
        tags = torch.cat(tags_list, dim=4)
        grouped, scores = parser.parse(
            final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
        )

        final_results = get_final_preds(
            grouped, center, scale,
            [final_heatmaps.size(3), final_heatmaps.size(2)]
        )
    return final_results, scores


def main():
    """Evaluate a trained keypoint network on the test dataset.

    Parses the command line into the global ``cfg``, builds and loads the
    model, runs multi-scale inference image by image (the test loader must
    yield batches of size 1), saves a visualization every cfg.PRINT_FREQ
    images, then scores all predictions with ``test_dataset.evaluate`` and
    logs the metric table(s).
    """
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, _ = create_logger(
        cfg, args.cfg, 'valid'
    )
    logger.info(pprint.pformat(args))
    logger.info(cfg)

    model = _build_model(cfg, logger, final_output_dir)

    data_loader, test_dataset = make_test_dataloader(cfg)
    transforms = _build_transforms(cfg)
    parser = HeatmapParser(cfg)

    all_preds = []
    all_scores = []

    # Exactly one progress bar, enabled by cfg.TEST.LOG_PROGRESS (wrapping the
    # loader in tqdm as well would draw a second, overlapping bar).
    pbar = tqdm(total=len(test_dataset)) if cfg.TEST.LOG_PROGRESS else None
    for i, (images, _) in enumerate(data_loader):
        assert 1 == images.size(0), 'Test batch size should be 1'
        image = images[0].cpu().numpy()

        final_results, scores = _infer_image(
            cfg, model, parser, transforms, image
        )

        if pbar is not None:
            pbar.update()

        if i % cfg.PRINT_FREQ == 0:
            prefix = '{}_{}'.format(os.path.join(final_output_dir, 'result_valid'), i)
            save_valid_image(image, final_results, '{}.jpg'.format(prefix), dataset=test_dataset.name)

        all_preds.append(final_results)
        all_scores.append(scores)

    if pbar is not None:
        pbar.close()

    name_values, _ = test_dataset.evaluate(
        cfg, all_preds, all_scores, final_output_dir
    )

    # evaluate() may return either one metrics mapping or a list of them.
    if isinstance(name_values, list):
        for name_value in name_values:
            _print_name_value(logger, name_value, cfg.MODEL.NAME)
    else:
        _print_name_value(logger, name_values, cfg.MODEL.NAME)
if __name__ == '__main__':
    # Run the evaluation only when this file is executed as a script,
    # never as a side effect of importing it.
    main()