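# NOTE: leftover page header from the code-hosting site ("05360171, created on 2022-03-18, commit history"); not part of the program.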
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import pprint

import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms

import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
from core.function import validate
from utils.utils import create_logger

import dataset
import models
import torch.npu
import os


def parse_args():
    """Parse the command line of the evaluation script.

    Recognised arguments:
      * ``--cfg`` (required): experiment configuration file name.
      * ``opts``: every remaining token on the command line, collected as a
        list of config overrides.
      * ``--modelDir``, ``--logDir``, ``--dataDir``, ``--prevModelDir``:
        optional directory strings, each defaulting to ``''``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='Train keypoints network')

    # general
    cli.add_argument('--cfg',
                     type=str,
                     required=True,
                     help='experiment configure file name')

    # Everything left over after the recognised options is kept verbatim.
    cli.add_argument('opts',
                     nargs=argparse.REMAINDER,
                     default=None,
                     help="Modify config options using the command-line")

    # The directory overrides share the same type and default, so they are
    # registered from a table (registration order is preserved).
    directory_flags = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_flags:
        cli.add_argument(flag, type=str, default='', help=description)

    return cli.parse_args()


def _resolve_attr(root, dotted_name):
    """Return the attribute of ``root`` addressed by ``dotted_name``.

    The name is walked one component at a time with ``getattr`` so that a
    configuration string can select a module member without being handed to
    ``eval``.

    Raises:
        AttributeError: if any component of ``dotted_name`` does not exist.
    """
    target = root
    for part in dotted_name.split('.'):
        target = getattr(target, part)
    return target


def main():
    """Evaluate a pose network on the configured test set.

    Steps, in order:
      1. Parse the command line and merge it into the global ``cfg``.
      2. Build the network named by ``cfg.MODEL.NAME`` via its
         ``get_pose_net`` factory.
      3. Load weights from ``cfg.TEST.MODEL_FILE`` if set, otherwise from
         ``final_state.pth`` in the run's output directory.
      4. Wrap the model in ``DataParallel`` and move it to the NPU.
      5. Build the dataset named by ``cfg.DATASET.DATASET`` and a
         non-shuffled loader over it.
      6. Hand everything to ``validate``.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # Look the factory up by attribute access instead of eval(): the name
    # comes from a config file / command line and must not be executed.
    get_pose_net = _resolve_attr(models, cfg.MODEL.NAME + '.get_pose_net')
    model = get_pose_net(cfg, is_train=False)

    if cfg.TEST.MODEL_FILE:
        model_state_file = cfg.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info('=> loading model from {}'.format(model_state_file))

    # Both checkpoint sources use the same remapping so that tensors saved
    # from 'cuda:0' are restored onto 'npu:0'; other storage tags are left
    # as they are.
    ckpt_state_dict = torch.load(
        model_state_file, map_location={'cuda:0': 'npu:0'})
    # print(ckpt_state_dict['pos_embedding'])  # FOR UNSeen Resolutions
    # ckpt_state_dict.pop('pos_embedding') # FOR UNSeen Resolutions
    # Use strict=False FOR UNSeen Resolutions.
    model.load_state_dict(ckpt_state_dict, strict=True)

    ######### FOR UNSeen Resolutions  #########
    # w, h = cfg.MODEL.IMAGE_SIZE
    # input_feature_length = int(w * h / 8 * 8)  # for TransPose-R
    # if hasattr(model, 'pos_embedding') and \
    #     model.pos_embedding is not None and input_feature_length != len(model.pos_embedding):
    #     import torch.nn.functional as F
    #     pos_embedding_org = model.pos_embedding
    #     pos_embedding_org = \
    #         pos_embedding_org.view(model.pe_h, model.pe_w, 1, -1).permute(2,3,0,1)  # [h,w,1,d]
    #     pos_embedding_new = F.interpolate(pos_embedding_org, size=(h//8,w//8), mode='bilinear', align_corners=True) #[1,d,h,w]
    #     model.pos_embedding = torch.nn.Parameter(pos_embedding_new.flatten(2).permute(2, 0, 1))
    #     print(model.pos_embedding.shape)
    ######### FOR UNSeen Resolutions  #########

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).npu()

    # define loss function (criterion)
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).npu()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    dataset_cls = _resolve_attr(dataset, cfg.DATASET.DATASET)
    valid_dataset = dataset_cls(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

    # NOTE(review): drop_last=True makes the loader skip a trailing
    # incomplete batch, so those samples are never evaluated -- confirm this
    # is intended for validation.
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=False,
        drop_last=True,
        sampler=None,
    )

    # evaluate on validation set
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()