# BSD 3-Clause License
#
# Copyright (c) 2017 xxxx
# All rights reserved.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ============================================================================
import torch
import argparse
import os
import glob
import hashlib
from timm.models.helpers import load_state_dict

# Command-line interface of the checkpoint averaging script.
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')

# Which checkpoint files to read and where to write the result.
parser.add_argument(
    '--input',
    metavar='PATH',
    type=str,
    default='',
    help='path to base input folder containing checkpoints',
)
parser.add_argument(
    '--filter',
    metavar='WILDCARD',
    type=str,
    default='*.pth.tar',
    help='checkpoint filter (path wildcard)',
)
parser.add_argument(
    '--output',
    metavar='PATH',
    type=str,
    default='./averaged.pth',
    help='output filename',
)

# Switches that turn off the default behaviours (both default to False).
parser.add_argument(
    '--no-use-ema',
    action='store_true',
    dest='no_use_ema',
    help='Force not using ema version of weights (if present)',
)
parser.add_argument(
    '--no-sort',
    action='store_true',
    dest='no_sort',
    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant',
)

# How many of the best-ranked checkpoints take part in the average.
parser.add_argument(
    '-n',
    metavar='N',
    type=int,
    default=10,
    help='Number of checkpoints to average',
)


def checkpoint_metric(checkpoint_path):
    """Return the ranking metric stored in a checkpoint file.

    The checkpoint is loaded on the CPU and looked up in this order:

    1. a top-level ``'metric'`` entry;
    2. the ``'metrics'`` mapping indexed by the ``'metric_name'`` entry
       (only when both keys are present).

    Args:
        checkpoint_path: Path of the checkpoint file to inspect.

    Returns:
        The stored metric value, or ``None`` when the path is empty, is not a
        regular file, or the checkpoint carries no recognisable metric.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        # Must be None (not an empty dict): callers drop checkpoints whose
        # metric is None, whereas a dict would be kept and later break the
        # sort of (metric, path) tuples.
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'metric' in checkpoint:
        return checkpoint['metric']
    if 'metrics' in checkpoint and 'metric_name' in checkpoint:
        metrics = checkpoint['metrics']
        print(metrics)
        return metrics[checkpoint['metric_name']]
    return None


def main():
    """Average the weights of several checkpoints and save the result.

    Steps:

    1. Parse the command line and refuse to overwrite an existing output file.
    2. Collect checkpoint files matching ``<input>/<filter>``.
    3. Unless ``--no-sort`` is given, rank the checkpoints by their stored
       metric and keep the ``n`` with the largest values.
    4. Sum every state-dict entry in float64, divide by the number of
       checkpoints that contained it, clamp to the float32 range and cast to
       float32.
    5. Save the averaged state dict and print its SHA256 digest.

    Raises:
        SystemExit: With code 1 when the output file already exists or when no
            checkpoint could be loaded.
    """
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        raise SystemExit(1)

    pattern = args.input
    # Only insert a separator when there is a base folder; with an empty
    # --input the filter is matched relative to the current directory instead
    # of the filesystem root.
    if (args.input
            and not args.input.endswith(os.path.sep)
            and not args.filter.startswith(os.path.sep)):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        # Ascending sort, so the last n entries carry the largest metrics.
        checkpoint_metrics = sorted(checkpoint_metrics)
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        print("Selected checkpoints:")
        for m, c in checkpoint_metrics:
            print(m, c)
        avg_checkpoints = [c for _, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        print("Selected checkpoints:")
        for c in checkpoints:
            print(c)

    # Running float64 sums per key, plus how many checkpoints contributed to
    # each key (checkpoints need not share the exact same key set).
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        # NOTE(review): second argument presumably selects EMA weights when
        # available -- confirm against timm's load_state_dict.
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print("Error: Checkpoint ({}) doesn't exist".format(c))
            continue

        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    if not avg_state_dict:
        # Nothing matched or nothing could be loaded: do not write an empty file.
        print("Error: No checkpoints found to average (pattern: {}).".format(pattern))
        raise SystemExit(1)

    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    try:
        torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
    except TypeError:
        # torch.save versions without this keyword reject it with TypeError.
        torch.save(final_state_dict, args.output)

    # Hash in chunks so a large checkpoint is never held in memory twice.
    sha256 = hashlib.sha256()
    with open(args.output, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha256.update(chunk)
    sha_hash = sha256.hexdigest()
    print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))


# Run the averaging only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()