# Commit 05360171, created on 2022-03-18 (see commit history).
#!/usr/bin/env python

# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the BSD 3-Clause License  (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# https://opensource.org/licenses/BSD-3-Clause

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet

training results with some of the latest networks and training techniques. It favours canonical PyTorch

and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed

and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example

(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples

(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)

"""

import argparse

import time

import yaml

import os

import logging

from collections import OrderedDict

from contextlib import suppress

from datetime import datetime

import glob

import shutil



import torch

import torch.nn as nn

import torchvision.utils

from torch.nn.parallel import DistributedDataParallel as NativeDDP

import torch.onnx



from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset

from timm.data import create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset

from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model

from timm.utils import *

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy

from timm.optim import create_optimizer

from timm.scheduler import create_scheduler

from timm.utils import ApexScaler, NativeScaler

from ghostnet.ghostnet_pytorch.ghostnet import ghostnet

import torch.npu



# modelarts modification

import moxing as mox



# Default NPU device string; main() builds its own device string from --npu.
CALCULATE_DEVICE = "npu:0"

# NVIDIA Apex is optional: record whether its AMP / DDP / SyncBN helpers import.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
except ImportError:
    has_apex = False
else:
    has_apex = True

# Native AMP counts as available when torch.cuda.amp exposes a non-None `autocast`.
try:
    has_native_amp = torch.cuda.amp.autocast is not None
except AttributeError:
    has_native_amp = False

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')



# The first arg parser parses out only the --config argument; it is used to load a
# YAML file whose key-values override the defaults of the main parser below.
config_parser = argparse.ArgumentParser(description='Training Config', add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                           help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset / Model parameters
parser.add_argument('data_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
# NOTE: main() always builds ghostnet; --model only feeds log messages and the
# output directory name.
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N',
                    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=1024, metavar='N',
                    help='input batch size for training (default: 1024)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')



# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')



# Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "step")')
parser.add_argument('--lr', type=float, default=0.4, metavar='LR',
                    help='learning rate (default: 0.4)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=0, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 0)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')



# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.2, metavar='PCT',
                    help='Dropout rate (default: 0.2)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')

# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
                    help='number of checkpoints to keep (default: 10)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading worker processes to use (default: 4)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
# NOTE: store_true with default=True means this flag is always True from the command
# line, so the prefetcher stays disabled; only a YAML config (set_defaults) can flip it.
parser.add_argument('--no-prefetcher', action='store_true', default=True,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, uses ./output)')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')
parser.add_argument('--width', type=float, default=1.0,
                    help='Width ratio (default: 1.0)')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:50000', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--npu', default=None, type=int,
                    help='NPU id to use.')

# ModelArts-specific arguments
parser.add_argument('--modelarts_mod', action='store_true', default=False,
                    help='Enable modelarts mode loss function to train')
parser.add_argument('--train_url',
                    default="/cache/training",
                    type=str,
                    help="setting dir of training output")
parser.add_argument('--pretrained_weight', default='', type=str, metavar='PATH',
                    help='path to pretrained weight')
# `--onnx` is store_true with default=True, so on its own it can never turn export off.
# `--no-onnx` writes to the same dest and provides the off switch.
parser.add_argument('--onnx', default=True, action='store_true',
                    help="convert pth model to onnx")
parser.add_argument('--no-onnx', dest='onnx', action='store_false',
                    help='skip ONNX export (this also skips copying outputs to --train_url)')
parser.add_argument('--data_url',
                    type=str,
                    default='/cache/data_url',
                    help='the training data')

# Local staging directory for the best checkpoint / ONNX file before upload to --train_url.
CACHE_TRAINING_URL = "/cache/training"

def _parse_args():
    """Parse command-line arguments, applying defaults from an optional YAML config.

    ``config_parser`` first extracts ``-c/--config``. If a file is given, its
    top-level mapping overrides the main parser's defaults. The remaining
    command-line arguments are then parsed by ``parser``.

    Returns:
        tuple: ``(args, args_text)``. ``args`` is the parsed namespace and
        ``args_text`` is its YAML dump, which main() later writes to the output dir.
    """
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r', encoding='utf-8') as f:
            # An empty YAML file loads as None; treat it as "no overrides" rather
            # than crashing in set_defaults(**None).
            cfg = yaml.safe_load(f) or {}
        parser.set_defaults(**cfg)

    # The main parser parses the rest of the args; its defaults have already been
    # overridden if a config file was specified.
    args = parser.parse_args(remaining)
    print(args)

    # Cache the args as a text string to save them in the output dir later.
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text





def _load_pretrained_weights(model, pretrained_weight_url):
    """Copy a pretrained checkpoint locally, load it into ``model`` and freeze all but the classifier.

    The checkpoint is copied with moxing from ``pretrained_weight_url`` to
    ``/cache/model/model_best.pth.tar`` and must contain a ``"state_dict"`` entry.
    A leading ``module.`` prefix is stripped from each key. The classifier weight and
    bias are dropped so that a head with a different class count can be trained.
    Loading is non-strict, so missing or unexpected keys are ignored.
    """
    cache_model_url = "/cache/model"
    os.makedirs(cache_model_url, exist_ok=True)
    pretrained_weight = os.path.join(cache_model_url, "model_best.pth.tar")
    mox.file.copy_parallel(pretrained_weight_url, pretrained_weight)

    # The model is still on the CPU at this point. map_location='cpu' avoids
    # depending on the device the checkpoint was saved from.
    pretrained_dict = torch.load(pretrained_weight, map_location='cpu')["state_dict"]
    prefix = 'module.'
    pretrained_model = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in pretrained_dict.items()
    }
    for key in ('classifier.weight', 'classifier.bias'):
        pretrained_model.pop(key, None)
    model.load_state_dict(pretrained_model, strict=False)

    # Fine-tune only the classifier head.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True


def _export_onnx_and_upload(args, output_dir):
    """Stage the best checkpoint, convert it to ONNX and upload the cache to ``args.train_url``.

    Raises:
        FileNotFoundError: if ``model_best.pth.tar`` is missing. ``shutil.copy`` raises
            when it is absent from ``output_dir``; this function raises if it is not
            found in the cache after copying.
    """
    os.makedirs(CACHE_TRAINING_URL, exist_ok=True)
    abs_output_dir = os.path.abspath(output_dir)
    print("abspath:", abs_output_dir)
    print("output_dir:", os.listdir(abs_output_dir))
    shutil.copy(os.path.join(abs_output_dir, 'model_best.pth.tar'), CACHE_TRAINING_URL)

    pth_pattern = os.path.join(CACHE_TRAINING_URL, 'model_best.pth.tar')
    print("pth_pattern:", os.path.abspath(pth_pattern))
    pth_file_list = glob.glob(pth_pattern)
    if not pth_file_list:
        # Previously this only printed and then crashed with IndexError below.
        raise FileNotFoundError(f"can't find pth {pth_pattern}")
    pth_file = pth_file_list[0]

    # Strip every extension from the file *name* only (model_best.pth.tar -> model_best.onnx).
    # Splitting the whole path on "." would truncate at any dot in a directory name.
    pth_dir, pth_name = os.path.split(pth_file)
    onnx_path = os.path.join(pth_dir, pth_name.split(".")[0] + '.onnx')
    # NOTE(review): convert() is not defined in the visible part of this file;
    # presumably it is defined later in the module.
    convert(pth_file, onnx_path)

    # --------------modelarts modification----------
    mox.file.copy_parallel(CACHE_TRAINING_URL, args.train_url)
    # --------------modelarts modification end----------
    print("CACHE_TRAINING_URL:", os.listdir(CACHE_TRAINING_URL))


def main():
    """Train GhostNet on a single Ascend NPU and optionally export the best checkpoint to ONNX.

    Steps:
      1. Parse args and select the NPU given by ``--npu``.
      2. Build the model, optionally loading pretrained weights and freezing
         everything except the classifier.
      3. Set up AMP, the optimizer and the LR schedule.
      4. Copy the data from ``args.data_url`` to ``/cache/data_url``, then build
         the datasets, loaders and losses.
      5. Run the train/validate loop with checkpointing.

    When ``args.onnx`` is set, the best checkpoint is converted to ONNX and the
    training cache is uploaded to ``args.train_url``.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if args.npu is None:
        args.npu = 0
    device = "npu:{}".format(args.npu)
    torch.npu.set_device(device)
    print("use ", device)

    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    # Record the device actually selected above. It was previously hard-coded to
    # 'npu:0', which disagreed with --npu N for N != 0.
    args.device = device
    args.world_size = 1
    args.rank = 0  # global rank; this script always runs a single process
    _logger.info('Training with a single process on 1 NPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDIA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)
    model_kwargs = dict(width=args.width, dropout=args.drop)
    if args.num_classes is not None:
        # Only override the model's default head size when one was requested.
        # Forwarding num_classes=None would presumably break the classifier head
        # (ghostnet() is defined outside this file -- TODO confirm its default).
        model_kwargs['num_classes'] = args.num_classes
    model = ghostnet(**model_kwargs)

    if args.pretrained:
        _load_pretrained_weights(model, args.pretrained_weight)

    if args.num_classes is None:
        num_classes = getattr(model, 'num_classes', None)
        if num_classes is None:
            # Fall back to the output width of the classifier head (e.g. nn.Linear.out_features).
            num_classes = getattr(getattr(model, 'classifier', None), 'out_features', None)
        assert num_classes is not None, 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = num_classes

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to the selected NPU, enable channels last layout if set
    model = model.to(device)
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2', loss_scale=128)
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    elif args.local_rank == 0:
        _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets. Data is always copied from args.data_url;
    # the positional data_dir argument is not used here.
    real_path = '/cache/data_url'
    os.makedirs(real_path, exist_ok=True)
    mox.file.copy_parallel(args.data_url, real_path)
    print("training data finish copy to %s." % real_path)
    dataset_train = create_dataset(
        args.dataset, root=real_path, split=args.train_split, is_training=True, batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=real_path, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=not args.no_prefetcher,
        no_aug=args.no_aug,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=not args.no_prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        # NOTE: the upstream `assert num_aug_splits > 1` check is disabled in this port.
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).to(device)
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().to(device)
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).to(device)
    else:
        train_loss_fn = nn.CrossEntropyLoss().to(device)
    if args.modelarts_mod:
        # ModelArts mode always trains with plain cross-entropy, overriding the choice above.
        train_loss_fn = nn.CrossEntropyLoss().to(device)

    validate_loss_fn = nn.CrossEntropyLoss().to(device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, mixup_fn=mixup_fn)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C stops training but still reports the best metric and exports.
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

    if args.onnx:
        _export_onnx_and_upload(args, output_dir)

        

def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, mixup_fn=None):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        epoch: zero-based epoch index; used to derive the global update counter
            passed to ``lr_scheduler.step_update`` and to decide whether mixup is
            switched off (``args.mixup_off_epoch``).
        model: network to train; put into ``train()`` mode here.
        loader: iterable of ``(input, target)`` batches that supports ``len()``.
        optimizer: optimizer stepped once per batch (directly, or through
            ``loss_scaler`` when one is given).
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed command-line namespace; reads ``mixup_off_epoch``,
            ``prefetcher``, ``channels_last``, ``clip_grad``, ``log_interval``,
            ``local_rank``, ``world_size``, ``batch_size``, ``save_images`` and
            ``recovery_interval``.
        lr_scheduler: optional scheduler; ``step_update`` is called after every batch.
        saver: optional checkpoint saver used for periodic recovery checkpoints.
        output_dir: directory for debug image dumps when ``args.save_images`` is set.
        amp_autocast: context-manager factory wrapping the forward pass.
        loss_scaler: optional callable used instead of the manual
            backward / clip / ``optimizer.step()`` path.
        mixup_fn: optional callable applied to ``(input, target)`` pairs.

    Returns:
        OrderedDict with the running-average training ``'loss'``.
    """
    # Switch mixup off once the configured epoch has been reached.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter(start_count_index=0)

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            # Without the prefetcher the batch is not yet on the device.
            inputs, target = inputs.npu(), target.npu()
        if mixup_fn is not None:
            inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(inputs)
            # NOTE(review): int32 targets are presumably required by the NPU loss kernel — confirm.
            target = target.to(torch.int32)
            loss = loss_fn(output, target)

        losses_m.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optimizer.step()

        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.local_rank == 0:
                if batch_time_m.avg > 0:
                    # A single-batch loader has last_idx == 0; avoid dividing by zero.
                    progress = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                    _logger.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                        'LR: {lr:.3e}  '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})  '
                        'fps: {fps:.3f}  '
                        'Batch_Size:{batch_size:.1f}  '.format(
                            epoch,
                            batch_idx, len(loader),
                            progress,
                            loss=losses_m,
                            batch_time=batch_time_m,
                            rate=inputs.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=inputs.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m,
                            fps=args.batch_size / batch_time_m.avg,
                            batch_size=args.batch_size))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])





def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
    """Evaluate ``model`` on ``loader`` and return averaged loss and accuracy.

    Args:
        model: network to evaluate; put into ``eval()`` mode here.
        loader: iterable of ``(input, target)`` batches that supports ``len()``.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: parsed command-line namespace; reads ``no_prefetcher``,
            ``channels_last``, ``tta``, ``local_rank``, ``log_interval`` and
            ``batch_size``.
        amp_autocast: context-manager factory wrapping the forward pass.
        log_suffix: appended to the ``'Test'`` label in log lines.

    Returns:
        OrderedDict with keys ``'loss'``, ``'top1'`` and ``'top5'`` holding the
        ``AverageMeter`` averages accumulated over the whole loader.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter(start_count_index=0)
    top1_m = AverageMeter(start_count_index=0)
    top5_m = AverageMeter(start_count_index=0)
    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if args.no_prefetcher:
                # Without the prefetcher the batch is not yet on the device.
                inputs = inputs.npu()
                target = target.npu()
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(inputs)
            # Some models return (logits, aux...); keep only the first output.
            if isinstance(output, (tuple, list)):
                output = output[0]

            # Test-time augmentation reduction: average each group of `tta`
            # consecutive output rows and keep one target per group.
            # NOTE(review): assumes the loader emits the `tta` augmented copies of
            # a sample consecutively — confirm against the TTA dataset setup.
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            # NOTE(review): int32 targets are presumably required by the NPU loss kernel — confirm.
            target = target.to(torch.int32)
            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            losses_m.update(loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                if batch_time_m.avg > 0:
                    log_name = 'Test' + log_suffix
                    _logger.info(
                        '{0}: [{1:>4d}/{2}]  '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                        'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                        'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})  '
                        'fps: {fps:.3f}  '
                        'Batch_Size:{batch_size:.1f}  '.format(
                            log_name, batch_idx, last_idx, batch_time=batch_time_m,
                            loss=losses_m, top1=top1_m, top5=top5_m, fps=args.batch_size / batch_time_m.avg,
                            batch_size=args.batch_size))

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics



def proc_node_module(checkpoint, attr_name):
    """Return ``checkpoint[attr_name]`` with a leading ``"module."`` stripped from each key.

    Keys without the prefix are kept unchanged. The input mapping is not modified.

    Args:
        checkpoint: mapping that holds a state dict under ``attr_name``.
        attr_name: key of the state dict inside ``checkpoint``.

    Returns:
        A new OrderedDict with the same values in the same insertion order.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint[attr_name].items()
    )

    

def convert(pth_file_path, onnx_file_path):
    """Export a trained GhostNet checkpoint to ONNX.

    The model hyper-parameters (``num_classes``, ``width``, ``drop``) are re-read
    from the command line via ``_parse_args()``. They must therefore match the
    values used to produce the checkpoint.

    Args:
        pth_file_path: checkpoint file loaded with ``torch.load``. It must contain
            a ``'state_dict'`` entry.
        onnx_file_path: destination path of the exported ONNX model.
    """
    args, _ = _parse_args()
    checkpoint = torch.load(pth_file_path, map_location='cpu')
    # Strip the 'module.' key prefix added by (Distributed)DataParallel wrapping.
    state_dict = proc_node_module(checkpoint, 'state_dict')
    model = ghostnet(num_classes=args.num_classes, width=args.width, dropout=args.drop)
    # The load is non-strict (strict=False). Report any key mismatch so that an
    # export with partially uninitialised weights does not go unnoticed.
    incompatible = model.load_state_dict(state_dict, False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        _logger.warning(
            'Checkpoint/model key mismatch while exporting ONNX: missing=%s unexpected=%s',
            incompatible.missing_keys, incompatible.unexpected_keys)
    model.eval()

    input_names = ["image"]
    output_names = ["output1"]
    # No dynamic_axes are passed, so the exported graph is fixed to this input shape.
    dummy_input = torch.randn(2, 3, 224, 224)
    torch.onnx.export(model, dummy_input, onnx_file_path, input_names=input_names, output_names=output_names, opset_version=11)

    

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()