# --------------------------------------------------------

# Focal Transformer

# Copyright (c) 2021 Microsoft

# Licensed under The MIT License [see LICENSE for details]

# Modified by Jianwei Yang (jianwyan@microsoft.com)

# Based on Swin Transformer written by Zhe Liu

# --------------------------------------------------------



import os

import torch

import torch.distributed as dist

from timm.models.layers import trunc_normal_



try:

    # noinspection PyUnresolvedReferences

    from apex import amp

except ImportError:

    amp = None





def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Restore model state (and, when training, optimizer/scheduler/amp state) from a checkpoint.

    ``config.MODEL.RESUME`` is either a URL starting with ``https`` (downloaded
    through ``torch.hub`` with hash checking) or a local path read with
    ``torch.load``. Model weights are always loaded non-strictly. If the
    checkpoint's ``head.weight`` has a different shape than the model's, the
    model keeps its own classification head.

    Optimizer, LR-scheduler, start-epoch and apex amp state are restored only
    when ``config.EVAL_MODE`` is false and the checkpoint contains
    ``optimizer``, ``lr_scheduler`` and ``epoch`` entries.

    Args:
        config: Experiment config. ``TRAIN.START_EPOCH`` is set to
            ``checkpoint['epoch'] + 1`` (between ``defrost()``/``freeze()``)
            when training state is resumed.
        model: Module whose parameters are loaded.
        optimizer: Optimizer whose state is restored when resuming training.
        lr_scheduler: Scheduler whose state is restored when resuming training.
        logger: Logger receiving progress messages and the load result.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was resumed and that
        entry exists, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which cannot unpickle the config object that save_checkpoint stores
        # under 'config' -- confirm the torch version in use.
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    # Some "focal" checkpoints are a bare state dict; wrap them so both layouts
    # can be handled identically below.
    if "focal" in config.MODEL.RESUME and 'model' not in checkpoint:
        checkpoint = {'model': checkpoint}

    model_state = model.state_dict()
    ckpt_state = checkpoint['model']
    if ('head.weight' in model_state and 'head.weight' in ckpt_state
            and model_state['head.weight'].shape != ckpt_state['head.weight'].shape):
        # The class count differs between checkpoint and model: keep the model's
        # own head. The whole tensor is copied (not just the first 1000 rows),
        # so the shapes always match and load_state_dict cannot fail with a
        # size mismatch, whatever the number of classes.
        # TODO: select the corresponding weights for 1K
        for key in ('head.weight', 'head.bias'):
            if key in model_state:
                ckpt_state[key] = model_state[key]

    msg = model.load_state_dict(ckpt_state, strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        # Restore amp state only if both this run and the saved run used mixed precision.
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop every reference to the checkpoint tensors before releasing cached
    # device memory.
    del checkpoint, ckpt_state, model_state
    # torch.npu exists only when an NPU backend (e.g. torch_npu) is installed;
    # otherwise fall back to CUDA, or do nothing on CPU-only setups.
    npu = getattr(torch, 'npu', None)
    if npu is not None:
        npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
    return max_accuracy





def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
    """Write the complete training state to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    The saved dict contains the model, optimizer and LR-scheduler state dicts,
    ``max_accuracy``, ``epoch`` and the ``config`` object itself. When
    ``config.AMP_OPT_LEVEL`` is anything other than ``"O0"``, apex amp state is
    also stored under ``'amp'``.
    """
    state = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        max_accuracy=max_accuracy,
        epoch=epoch,
        config=config,
    )
    uses_amp = config.AMP_OPT_LEVEL != "O0"
    if uses_amp:
        state['amp'] = amp.state_dict()

    file_name = f'ckpt_epoch_{epoch}.pth'
    target = os.path.join(config.OUTPUT, file_name)
    logger.info(f"{target} saving......")
    torch.save(state, target)
    logger.info(f"{target} saved !!!")





def get_grad_norm(parameters, norm_type=2):
    """Return the ``norm_type``-norm of all gradients treated as one vector.

    Args:
        parameters: A single tensor or an iterable of tensors. Entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type: Order of the norm (anything ``float()`` accepts, including
            ``'inf'``). For ``inf`` the result is the largest absolute gradient
            entry.

    Returns:
        The total norm as a Python float; ``0.0`` when no parameter has a
        gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        # The generic formula sum(x ** inf) ** (1 / inf) collapses to 0/1/inf.
        # The inf-norm of the concatenation is the max of the per-tensor inf-norms.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)
    total_norm = 0.0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1. / norm_type)





def auto_resume_helper(output_dir):
    """Return the path of the most recently modified checkpoint in ``output_dir``.

    Every directory entry whose name ends in ``'pth'`` counts as a checkpoint.
    The one with the greatest modification time wins. Both the candidate list
    and the choice are printed.

    Returns:
        The full path of the newest checkpoint, or ``None`` when there is none.
    """
    found = [name for name in os.listdir(output_dir) if name.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {found}")
    if not found:
        return None

    candidate_paths = (os.path.join(output_dir, name) for name in found)
    newest = max(candidate_paths, key=os.path.getmtime)
    print(f"The latest checkpoint founded: {newest}")
    return newest





def reduce_tensor(tensor):
    """Return the element-wise mean of ``tensor`` over all ranks of the default process group.

    The input is left untouched. A clone is summed with ``all_reduce`` and then
    divided in place by the world size.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged