# -*- coding: utf-8 -*-

# BSD 3-Clause License

#

# Copyright (c) 2017

# All rights reserved.

# Copyright 2022 Huawei Technologies Co., Ltd

#

# Redistribution and use in source and binary forms, with or without

# modification, are permitted provided that the following conditions are met:

#

# * Redistributions of source code must retain the above copyright notice, this

#   list of conditions and the following disclaimer.

#

# * Redistributions in binary form must reproduce the above copyright notice,

#   this list of conditions and the following disclaimer in the documentation

#   and/or other materials provided with the distribution.

#

# * Neither the name of the copyright holder nor the names of its

#   contributors may be used to endorse or promote products derived from

#   this software without specific prior written permission.

#

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"

# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE

# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE

# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL

# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR

# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER

# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,

# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# ==========================================================================





# Copyright (c) 2015-present, Facebook, Inc.

# All rights reserved.

import argparse

import datetime

import numpy as np

import time

import torch

import torch.backends.cudnn as cudnn

import json

import os

import apex

from pathlib import Path



from mixup_nova import Mixup_nova as Mixup

# from timm.data import Mixup

from timm.models import create_model

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

from timm.scheduler import create_scheduler

from timm.optim import create_optimizer

from timm.utils import NativeScaler, get_state_dict, ModelEma



from datasets import build_dataset

from engine import train_one_epoch, evaluate, throughput

from losses import DistillationLoss

from samplers import RASampler

import models

import utils

from logger import create_logger



def print_tensor(name, tensors):
    """Print ``name`` next to the shape of each tensor contained in ``tensors``.

    Lists and tuples are walked recursively, so arbitrarily nested containers
    are supported. Anything that is neither a container nor a
    ``torch.Tensor`` is reported by its Python type instead of a shape.

    Args:
        name: Label printed in front of every reported item.
        tensors: A tensor, a (possibly nested) list/tuple, or any other object.
    """
    if isinstance(tensors, (list, tuple)):
        # Recurse element by element; every leaf is printed with the same label.
        for item in tensors:
            print_tensor(name, item)
        return

    if isinstance(tensors, torch.Tensor):
        description = tensors.shape
    else:
        description = type(tensors)
    print(name, description)

        

def hook_func(name, module):
    """Create a debugging hook that prints the shapes flowing through a module.

    Args:
        name: Label prefix used for every line the hook prints.
        module: Accepted for call-site symmetry; not used when building the hook.

    Returns:
        A callable ``hook(module, inputs, outputs)`` that reports ``inputs``
        and ``outputs`` through ``print_tensor`` and returns ``None``.
    """
    def hook_function(module, inputs, outputs):
        # Report inputs first, then outputs, each under its own label.
        for suffix, payload in ((' inputs', inputs), (' outputs', outputs)):
            print_tensor(name + suffix, payload)

    return hook_function





def hook_for_model(model):
    """Attach shape-printing forward and backward hooks to every submodule.

    Iterates ``model.named_modules()`` and, for each entry, registers one
    forward hook and one backward hook built by ``hook_func``. The hooks are
    labelled with the direction and the module's qualified name.

    Args:
        model: Object exposing ``named_modules()`` whose modules provide
            ``register_forward_hook`` and ``register_backward_hook``.
    """
    for module_name, submodule in model.named_modules():
        registrations = (
            (submodule.register_forward_hook, '[forward]: '),
            (submodule.register_backward_hook, '[backward]: '),
        )
        for register, prefix in registrations:
            register(hook_func(prefix + module_name, submodule))





def get_args_parser():
    """Build the argument parser shared by training, evaluation and benchmarking.

    The parser is created with ``add_help=False`` so that it can be passed as
    a parent (``parents=[get_args_parser()]``) to a top-level parser, which
    then supplies ``-h/--help`` itself.

    Returns:
        argparse.ArgumentParser: parser holding every option of the script.
    """
    parser = argparse.ArgumentParser('Training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=128, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='smlpnet_tiny', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # EMA is on by default; --no-model-ema switches it off.
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The help text used to embed a stray '" + \' line continuation inside the
    # literal; it is now a single well-formed string.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Repeated augmentation is on by default; --no-repeated-aug switches it off.
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160"')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='npu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    # Pinned memory is on by default; --no-pin-mem switches it off.
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument("--local_rank", type=int, default=0)

    # parameters for training on preemptible clusters (auto-resume is on by default)
    parser.add_argument('--auto-resume', action='store_true')
    parser.add_argument('--no-auto-resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    # spach parameters
    parser.add_argument('--stem-type', default='conv1', type=str, choices=['conv1', 'conv4'])
    parser.add_argument('--shared-spatial-func', action='store_true')

    # npu parameters
    parser.add_argument('--npu', action='store_true', default=False, help='Enabling npu training')

    # parameters for benchmark
    parser.add_argument('--throughput', action='store_true')

    return parser





def parse_model_args(args):
    """Collect the model-specific keyword arguments for ``create_model``.

    Only models whose name starts with ``'spach'`` take extra arguments
    (``stem_type`` and ``shared_spatial_func``); every other model yields an
    empty dict.

    Args:
        args: Parsed command-line namespace; must expose ``model``.

    Returns:
        dict: Mapping of extra keyword-argument name to its value in ``args``.
    """
    if args.model.startswith('spach'):
        wanted = ('stem_type', 'shared_spatial_func')
    else:
        wanted = ()
    namespace = vars(args)
    return {key: namespace[key] for key in wanted}





def _load_checkpoint(location):
    """Load a checkpoint onto the CPU from an ``https`` URL or a local path."""
    if location.startswith('https'):
        return torch.hub.load_state_dict_from_url(
            location, map_location='cpu', check_hash=True)
    # NOTE(security): torch.load unpickles arbitrary Python objects; only point
    # this at checkpoint files you trust.
    return torch.load(location, map_location='cpu')


def _interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to ``model``'s patch grid."""
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)


def _latest_periodic_checkpoint(output_dir):
    """Return the ``checkpoint_<epoch>.pth`` with the highest epoch in ``output_dir``, or None."""
    prefix = 'checkpoint_'
    best_epoch, best_path = -1, None
    for candidate in output_dir.glob(prefix + '*.pth'):
        try:
            epoch = int(candidate.stem[len(prefix):])
        except ValueError:
            # Not one of the periodic checkpoints written by the training loop.
            continue
        if epoch > best_epoch:
            best_epoch, best_path = epoch, candidate
    return best_path


def _checkpoint_state(model_without_ddp, optimizer, lr_scheduler, epoch, model_ema, args):
    """Assemble the dict saved for both periodic and best checkpoints."""
    return {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        # With --no-model-ema there is no EMA model to serialise.
        'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
        'amp': apex.amp.state_dict(),
        'args': args,
    }


def main(args):
    """Run training, evaluation only, or a throughput benchmark as selected by ``args``.

    The function initialises distributed mode and logging, builds the datasets,
    samplers and loaders, creates the model (optionally loading fine-tuning
    weights), sets up EMA, the apex NPU fused AdamW optimizer, the LR scheduler
    and apex AMP, optionally wraps the model in DDP, and builds the (optionally
    distilled) criterion. It then resumes from a checkpoint if one is
    configured or discovered. With ``--eval`` or ``--throughput`` it runs that
    single task and returns; otherwise it trains for ``args.epochs`` epochs,
    evaluating after each one.

    Args:
        args: Namespace produced by the parser from ``get_args_parser``. It is
            mutated: ``nb_classes`` is added, ``lr`` is rescaled by the global
            batch size, and ``resume`` / ``start_epoch`` may be updated.

    Raises:
        NotImplementedError: if fine-tuning is combined with distillation
            while training.
    """
    utils.init_distributed_mode(args)

    # Logging setup.
    logger = create_logger(args.output_dir, utils.get_rank(), args.model)
    logger.info(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    if args.npu:
        device = f'npu:{str(utils.get_rank())}'
    else:
        device = torch.device(args.device)

    # fix the seed for reproducibility (offset per rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Datasets.
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                logger.info('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                            'This will slightly alter validation results as extra duplicate entries are added to achieve '
                            'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    # Create the model.
    logger.info(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        **parse_model_args(args)
    )

    # Fine-tuning: load pretrained weights, dropping mismatching classifier heads.
    if args.finetune:
        checkpoint = _load_checkpoint(args.finetune)

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                logger.info(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Interpolate the position embedding only when both the checkpoint and
        # the model actually have one; otherwise this step used to raise.
        if ('pos_embed' in checkpoint_model and hasattr(model, 'pos_embed')
                and hasattr(model, 'patch_embed')):
            _interpolate_pos_embed(model, checkpoint_model)

        model.load_state_dict(checkpoint_model, strict=False)

    # Move the model to the target device.
    print(device, flush=True)
    model.to(device)

    # EMA model.
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    # Optimizer, scheduler and AMP; the learning rate is scaled linearly with
    # the global batch size relative to 512.
    model_without_ddp = model
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = apex.optimizers.NpuFusedAdamW(model.parameters(), args.lr,
                                              weight_decay=args.weight_decay)
    lr_scheduler, _ = create_scheduler(args, optimizer)
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic", combine_grad=True)

    # DDP wrapper.
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')
    if hasattr(model_without_ddp, 'flops'):
        try:
            flops = model_without_ddp.flops()
            logger.info(f"number of GFLOPs: {flops / 1e9}")
        except Exception as e:
            # FLOP counting is informational only; log and carry on.
            logger.exception(e)

    # Pick the loss by whether Mixup/CutMix is active, not by --mixup alone:
    # mixup_fn is also created for cutmix-only runs.
    # NOTE(review): assumes Mixup emits soft targets like timm's Mixup - confirm.
    if mixup_active:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        logger.info(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        checkpoint = _load_checkpoint(args.teacher_path)
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
    )

    output_dir = Path(args.output_dir)
    if args.auto_resume:
        _resume = str((output_dir / 'checkpoint.pth').absolute())
        if os.path.exists(_resume):
            logger.info(f'auto resume from {output_dir}/checkpoint.pth')
            args.resume = _resume
        elif args.output_dir and not args.resume:
            # The training loop only writes checkpoint_<epoch>.pth, so fall
            # back to the newest of those when no explicit --resume was given.
            latest = _latest_periodic_checkpoint(output_dir)
            if latest is not None:
                logger.info(f'auto resume from {latest}')
                args.resume = str(latest.absolute())

    if args.resume:
        checkpoint = _load_checkpoint(args.resume)
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            # The checkpoint may have been written without EMA weights.
            if model_ema is not None and checkpoint.get('model_ema') is not None:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            # Restore the AMP state saved alongside the optimizer.
            if 'amp' in checkpoint:
                apex.amp.load_state_dict(checkpoint['amp'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device, args.batch_size, logger=logger, use_npu=args.npu)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    if args.throughput:
        throughput(data_loader_val, model, logger=logger, use_npu=args.npu)
        return

    criterion = criterion.to(device)
    logger.info(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Let the distributed sampler know the epoch before iterating.
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, args.output_dir, args.batch_size,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
            logger=logger,
            use_npu=args.npu
        )

        lr_scheduler.step(epoch)

        # Periodic checkpoint every 5 epochs.
        if args.output_dir and epoch % 5 == 0:
            utils.save_on_master(
                _checkpoint_state(model_without_ddp, optimizer, lr_scheduler, epoch, model_ema, args),
                output_dir / f'checkpoint_{str(epoch)}.pth')

        test_stats = evaluate(data_loader_val, model, device, args.batch_size, logger=logger, use_npu=args.npu)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

        # Keep the best model, but only when saving was requested
        # (an empty --output_dir means "no saving").
        if args.output_dir and test_stats["acc1"] > max_accuracy:
            utils.save_on_master(
                _checkpoint_state(model_without_ddp, optimizer, lr_scheduler, epoch, model_ema, args),
                output_dir / 'best.pth')
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))





if __name__ == '__main__':
    # Wrap the shared option definitions in a top-level parser and read the CLI.
    parser = argparse.ArgumentParser(
        'DeiT training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()

    # Create the output directory (and any missing parents) before main() runs.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    main(args)