from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
from core.function import validate
from utils.utils import create_logger
from utils.utils import get_model_summary
import dataset
import models
def parse_args(argv=None):
    """Parse the command line of the keypoint validation script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what existing callers rely on.

    Returns:
        argparse.Namespace with attributes ``cfg`` (required config path),
        ``opts`` (trailing config-override tokens), ``modelDir``, ``logDir``,
        ``dataDir`` and ``prevModelDir`` (all default to ``''``).
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    # Every trailing token is collected verbatim as a config override.
    # Register this positional exactly once: a second REMAINDER positional
    # with the same dest matches zero tokens and overwrites args.opts with
    # [], silently dropping the overrides.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    return parser.parse_args(argv)
def _load_weights(model, logger, final_output_dir):
    """Load evaluation weights into ``model`` in place.

    If ``cfg.TEST.MODEL_FILE`` is set, that file is loaded. When the loaded
    object contains a ``'latest_state_dict'`` key, that entry is used as the
    state dict; otherwise the object itself is. Both of those paths load with
    ``strict=False``, and the load result is logged. If
    ``cfg.TEST.MODEL_FILE`` is empty, ``<final_output_dir>/final_state.pth``
    is loaded with strict key matching.

    Args:
        model: the freshly built network (not yet moved to the NPU).
        logger: logger returned by ``create_logger``.
        final_output_dir: experiment output directory from ``create_logger``.
    """
    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        # The model is still on the host at this point and is moved to the
        # NPU afterwards, so deserialize tensors onto the CPU instead of the
        # device the checkpoint happened to be saved from.
        model_object = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')
        if 'latest_state_dict' in model_object:
            logger.info('=> loading from latest_state_dict at {}'.format(cfg.TEST.MODEL_FILE))
            state_dict = model_object['latest_state_dict']
        else:
            logger.info('=> no latest_state_dict found')
            state_dict = model_object
        # strict=False tolerates key mismatches; log what was skipped so a
        # checkpoint that does not fit the model is not silently ignored.
        load_result = model.load_state_dict(state_dict, strict=False)
        logger.info('=> load_state_dict (strict=False) result: {}'.format(load_result))
    else:
        model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file, map_location='cpu'))


def _build_valid_loader():
    """Build the test-split dataset and a non-shuffling DataLoader over it.

    The dataset class is looked up by name from ``cfg.DATASET.TEST_DATASET``.
    Images are converted with ``ToTensor`` and then normalized with the fixed
    mean/std below.

    Returns:
        Tuple ``(valid_dataset, valid_loader)``.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # NOTE(review): eval() on a config-supplied class name executes whatever
    # the config file contains; only run this with trusted configs.
    valid_dataset = eval('dataset.' + cfg.DATASET.TEST_DATASET)(
        cfg=cfg,
        image_dir=cfg.DATASET.TEST_IMAGE_DIR,
        annotation_file=cfg.DATASET.TEST_ANNOTATION_FILE,
        dataset_type=cfg.DATASET.TEST_DATASET_TYPE,
        image_set=cfg.DATASET.TEST_SET,
        is_train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True
    )
    return valid_dataset, valid_loader


def main():
    """Evaluate a keypoint model on the configured test split.

    Steps:
    - Parse ``--cfg`` and any trailing overrides, then update the global
      ``cfg``.
    - Build the model named by ``cfg.MODEL.NAME`` and log its summary.
    - Load weights (see ``_load_weights``).
    - Move the model and ``JointsMSELoss`` to the NPU.
    - Run ``validate`` over the dataset built by ``_build_valid_loader``.

    The TensorBoard writer is always closed, even if loading or validation
    raises.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(review): eval() on a config-supplied module name executes whatever
    # the config file contains; only run this with trusted configs.
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }
    try:
        # Random tensor of shape (1, 3, IMAGE_SIZE[1], IMAGE_SIZE[0]) used only
        # for the model summary; presumably NCHW with IMAGE_SIZE = (W, H).
        dump_input = torch.rand(
            (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0])
        )
        logger.info(get_model_summary(model, dump_input))

        _load_weights(model, logger, final_output_dir)

        model = model.npu()
        criterion = JointsMSELoss(
            use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
        ).npu()

        valid_dataset, valid_loader = _build_valid_loader()

        validate(cfg, valid_loader, valid_dataset, model, criterion,
                 final_output_dir, tb_log_dir, writer_dict)
    finally:
        # Close on every path so pending TensorBoard events are flushed and
        # the event file handle is released.
        writer_dict['writer'].close()
if __name__ == '__main__':
    # Script entry point: run validation only when this file is executed
    # directly, not when it is imported.
    main()