05360171创建于 2022年3月18日历史提交
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os, argparse
from utils.dist_utils import is_main_process, dist_print, DistSummaryWriter
from utils.config import Config
import torch
import time


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='path to config file')
    parser.add_argument('--local_rank', type=int, default=0)

    parser.add_argument('--dataset', default=None, type=str)
    parser.add_argument('--data_root', default=None, type=str)
    parser.add_argument('--epoch', default=None, type=int)
    parser.add_argument('--batch_size', default=None, type=int)
    parser.add_argument('--optimizer', default=None, type=str)
    parser.add_argument('--learning_rate', default=None, type=float)
    parser.add_argument('--weight_decay', default=None, type=float)
    parser.add_argument('--momentum', default=None, type=float)
    parser.add_argument('--scheduler', default=None, type=str)
    parser.add_argument('--steps', default=None, type=int, nargs='+')
    parser.add_argument('--gamma', default=None, type=float)
    parser.add_argument('--warmup', default=None, type=str)
    parser.add_argument('--warmup_iters', default=None, type=int)
    parser.add_argument('--backbone', default=None, type=str)
    parser.add_argument('--griding_num', default=None, type=int)
    parser.add_argument('--use_aux', default=None, type=str2bool)
    parser.add_argument('--sim_loss_w', default=None, type=float)
    parser.add_argument('--shp_loss_w', default=None, type=float)
    parser.add_argument('--note', default=None, type=str)
    parser.add_argument('--log_path', default=None, type=str)
    parser.add_argument('--finetune', default=None, type=str)
    parser.add_argument('--resume', default=None, type=str)
    parser.add_argument('--test_model', default=None, type=str)
    parser.add_argument('--test_work_dir', default=None, type=str)
    parser.add_argument('--num_lanes', default=None, type=int)
    parser.add_argument('--auto_backup', action='store_true', help='automatically backup current code in the log path')

    return parser


def merge_config():
    args = get_args().parse_args()
    cfg = Config.fromfile(args.config)
    items = ['dataset', 'data_root', 'epoch', 'batch_size', 'optimizer', 'learning_rate',
             'weight_decay', 'momentum', 'scheduler', 'steps', 'gamma', 'warmup', 'warmup_iters',
             'use_aux', 'griding_num', 'backbone', 'sim_loss_w', 'shp_loss_w', 'note', 'log_path',
             'finetune', 'resume', 'test_model', 'test_work_dir', 'num_lanes']
    for item in items:
        if getattr(args, item) is not None:
            dist_print('merge ', item, ' config')
            setattr(cfg, item, getattr(args, item))
    return args, cfg


def save_model(net, optimizer, epoch, save_path, distributed):
    if is_main_process():
        model_state_dict = net.state_dict()
        state = {'model': model_state_dict, 'optimizer': optimizer.state_dict()}
        # state = {'model': model_state_dict}
        assert os.path.exists(save_path)
        model_path = os.path.join(save_path, 'ep%03d.pth' % epoch)
        torch.save(state, model_path)


import pathspec


def cp_projects(auto_backup, to_path):
    if is_main_process() and auto_backup:
        with open('./.gitignore', 'r') as fp:
            ign = fp.read()
        ign += '\n.git'
        spec = pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, ign.splitlines())
        all_files = {os.path.join(root, name) for root, dirs, files in os.walk('./') for name in files}
        matches = spec.match_files(all_files)
        matches = set(matches)
        to_cp_files = all_files - matches
        dist_print('Copying projects to ' + to_path + ' for backup')
        t0 = time.time()
        warning_flag = True
        for f in to_cp_files:
            dirs = os.path.join(to_path, 'code', os.path.split(f[2:])[0])
            if not os.path.exists(dirs):
                os.makedirs(dirs)
            os.system('cp %s %s' % (f, os.path.join(to_path, 'code', f[2:])))
            elapsed_time = time.time() - t0
            if elapsed_time > 5 and warning_flag:
                dist_print(
                    'If the program is stuck, it might be copying large files in this directory. please don\'t set --auto_backup. Or please make you working directory clean, i.e, don\'t place large files like dataset, log results under this directory.')
                warning_flag = False


import datetime, os


def get_work_dir(cfg):
    # now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    # hyper_param_str = '_lr_%1.0e_b_%d' % (cfg.learning_rate, cfg.batch_size)
    # work_dir = os.path.join(cfg.log_path, now + hyper_param_str + cfg.note)
    work_dir = os.path.join(cfg.log_path, "model_par")
    return work_dir


def get_logger(work_dir, cfg):
    logger = DistSummaryWriter(work_dir)
    config_txt = os.path.join(work_dir, 'cfg.txt')
    if is_main_process():
        with open(config_txt, 'w') as fp:
            fp.write(str(cfg))

    return logger