# 05360171 created on 2022-03-18 (commit history)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import time
import torch
import torch.distributed as dist


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: the most recent value passed to :meth:`update`.
        sum: running sum of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count``; left unchanged (0 after :meth:`reset`) while
            ``count`` is zero, so zero-weight updates never divide by zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean and its batch size).

        Args:
            val: value to record; anything supporting ``*`` and ``+``.
            n: weight of this observation.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Only recompute the mean when there is a non-zero total weight; a
        # first update with n=0 used to raise ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count


def make_average_meters(n:int):
    """Build ``n`` independent, freshly reset ``AverageMeter`` instances.

    Args:
        n: how many meters to create.

    Returns:
        A list holding ``n`` distinct meter objects.
    """
    meters = [AverageMeter() for _ in range(n)]
    return meters


class BlockTimer(object):
    """Measures time used of code block.

    Prints a start message on entering a ``with`` block and a finish message
    with the elapsed seconds on leaving it. The elapsed time is also stored in
    ``self.elapsed``::

        with BlockTimer('device 0', 'evaluation') as timer:
            ...
        print(timer.elapsed)

    Args:
        device_id: label printed in front of each message; ``''`` or ``None``
            is replaced by ``'(unspecified device)'``.
        description: optional text describing the timed block; ``''`` or
            ``None`` omits it from the messages.
    """

    def __init__(self, device_id='', description=''):
        if device_id in ['', None]:
            device_id = '(unspecified device)'
        if description in ['', None]:
            self.start_str = "{} starts.".format(device_id)
            self.finish_formatter = "{} finished. Time used = {{:.3f}}s".format(device_id)
        else:
            self.start_str = "{} starts {}.".format(device_id, description)
            self.finish_formatter = "{} finished {}. Time used = {{:.3f}}s".format(device_id, description)
        self.start_time = None
        self.elapsed = None

    def __enter__(self):
        print(self.start_str)
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump if the system clock is adjusted mid-block.
        self.start_time = time.perf_counter()
        # Return self so `with BlockTimer(...) as t:` gives access to t.elapsed.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start_time
        print(self.finish_formatter.format(self.elapsed))
        # Implicitly returns None, so exceptions raised in the block propagate.


def reduce_tensor(tensor, n):
    """Sum ``tensor`` across all ranks of the default process group and divide by ``n``.

    Requires ``torch.distributed`` to be initialized. The input is cloned, so
    the caller's tensor is not modified by the collective.

    Args:
        tensor: tensor to reduce; all ranks must participate in the call.
        n: divisor, typically the world size, to turn the sum into a mean.

    Returns:
        A new tensor equal to ``sum over ranks / n``. Floating-point and complex
        inputs keep their dtype. Integer inputs are promoted by true division
        (to the default floating dtype).
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if rt.is_floating_point() or rt.is_complex():
        rt /= n
    else:
        # In-place true division cannot store a floating result in an integer
        # tensor (PyTorch raises RuntimeError), so divide out of place here.
        rt = rt / n
    return rt


def remove_ddp_module_prefix(state_dict, prefix='module.'):
    '''Remove the 'module.' prefix generated by distributed training.

    Models wrapped for distributed/data-parallel training save their parameters
    with a leading ``'module.'`` on every key. This returns a new ``OrderedDict``
    with that leading prefix stripped, preserving key order. Keys without the
    prefix are copied unchanged, and the input mapping is not modified.

    Args:
        state_dict: mapping of string keys to values (typically tensors).
        prefix: leading string to strip; defaults to ``'module.'``. Only one
            occurrence at the very start of a key is removed.

    Returns:
        A new ``OrderedDict`` with the stripped keys.
    '''
    # Derive the slice length from the prefix instead of a magic number.
    prefix_len = len(prefix)
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[prefix_len:]
        new_state_dict[key] = value

    return new_state_dict
    

def load_state_dict(model_path, map_location='cpu'):
    """Read a checkpoint with ``torch.load`` and strip any ``'module.'`` key prefix.

    Args:
        model_path: path (or file-like object) accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``; ``'cpu'`` by default.

    Returns:
        A new ``OrderedDict`` with the prefix-free keys. The loaded object is
        assumed to be a flat ``{name: tensor}`` mapping (TODO confirm for
        checkpoints that nest the weights under another key).

    NOTE(review): ``torch.load`` unpickles the file, which can run arbitrary
    code. Only load checkpoints from trusted sources.
    """
    loaded = torch.load(model_path, map_location=map_location)
    cleaned = remove_ddp_module_prefix(loaded)
    return cleaned


def save_state_dict(state_dict, model_path):
    """Save ``state_dict`` to ``model_path`` after removing any ``'module.'`` key prefix.

    The caller's mapping is left untouched; a prefix-free copy is what gets
    written with ``torch.save``.
    """
    torch.save(remove_ddp_module_prefix(state_dict), model_path)