# coding: UTF-8
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import shutil
import torch
def logging(s, log_path, print_=True, log_=True):
    """Print a message and/or append it as one line to a log file.

    Args:
        s: Message to emit. Non-string values are converted with ``str()``
            before being written to the file (``print`` already accepts them).
        log_path: Path of the file to append to; only used when ``log_`` is
            true.
        print_: If true, echo ``s`` to stdout.
        log_: If true, append ``s`` followed by a newline to ``log_path``.
    """
    if print_:
        print(s)
    if log_:
        # Explicit UTF-8 so the log's bytes do not depend on the platform's
        # locale encoding; plain append mode because the file is never read.
        with open(log_path, 'a', encoding='utf-8') as f_log:
            f_log.write(str(s) + '\n')
def get_logger(log_path, **kwargs):
    """Return ``logging`` with ``log_path`` (and any extra options) pre-bound.

    Args:
        log_path: File the returned callable appends messages to.
        **kwargs: Additional keyword arguments forwarded to ``logging``
            (e.g. ``print_`` / ``log_``).

    Returns:
        A ``functools.partial`` that only needs the message argument.
    """
    bound_options = {'log_path': log_path, **kwargs}
    return functools.partial(logging, **bound_options)
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    """Announce the experiment directory and return a file logger.

    Args:
        dir_path: Experiment directory; only printed here.
        scripts_to_save: Currently unused.
        debug: Currently unused.

    Returns:
        A logger from ``get_logger`` that appends to ``log.txt`` in the
        current working directory.

    NOTE(review): ``dir_path`` is not created and the log is not placed
    inside it, and ``scripts_to_save`` / ``debug`` are ignored. Confirm
    that callers rely on ``./log.txt`` before changing this.
    """
    announcement = 'Experiment dir : {}'.format(dir_path)
    print(announcement)
    exp_logger = get_logger(log_path='log.txt')
    return exp_logger
def save_checkpoint(model, optimizer, path, epoch):
    """Write the model and optimizer checkpoint files for one epoch.

    The whole ``model`` object is serialized with ``torch.save`` (not its
    ``state_dict``), while for ``optimizer`` only ``state_dict()`` is saved.

    Args:
        model: Object passed directly to ``torch.save``.
        optimizer: Object whose ``state_dict()`` is saved.
        path: Existing directory that receives the files.
        epoch: Value embedded in the file names
            ``model_<epoch>.pt`` and ``optimizer_<epoch>.pt``.
    """
    model_file = os.path.join(path, 'model_{}.pt'.format(epoch))
    torch.save(model, model_file)

    optim_file = os.path.join(path, 'optimizer_{}.pt'.format(epoch))
    torch.save(optimizer.state_dict(), optim_file)