# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import argparse

import logging

import math

import os

import random

import time

from pathlib import Path

from warnings import warn

import numpy as np

import torch

if torch.__version__ >= '1.8':

    import torch_npu

import torch.distributed as dist

import torch.nn.functional as F

import torch.optim as optim

import torch.optim.lr_scheduler as lr_scheduler

import torch.utils.data

import yaml

from torch.nn.parallel import DistributedDataParallel as DDP

from torch.utils.tensorboard import SummaryWriter

import torch.multiprocessing as mp



import test  # import test.py to get mAP after each epoch

#from models.yolo import Model

from models.models import *

from utils.autoanchor import check_anchors

from utils.datasets import create_dataloader

from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \

    fitness, fitness_p, fitness_r, fitness_ap50, fitness_ap, fitness_f, strip_optimizer, get_latest_run,\

    check_dataset, check_file, check_git_status, check_img_size, print_mutation, set_logging

from utils.google_utils import attempt_download

from utils.loss import compute_loss

from utils.plots import plot_images, plot_labels, plot_results, plot_evolution, output_to_target

from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first



logger = logging.getLogger(__name__)

# Apex is optional: when it cannot be imported, training falls back to plain
# full-precision backward passes (see the `mixed_precision` checks in train()).
mixed_precision = True
try:
    import apex
    from apex import amp
except Exception:  # not a bare `except:` so Ctrl-C / SystemExit still propagate
    print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
    mixed_precision = False  # not installed



# Weights & Biases logging is opt-in. `wandb` is always bound (to None when
# logging is disabled or the package is missing) so later `if wandb:` checks
# never hit a NameError.
use_wandb = False
wandb = None
if use_wandb:
    try:
        import wandb
    except ImportError:
        print("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")





def train(hyp, opt, device, tb_writer=None, wandb=None):
    """Train a Darknet model described by ``opt.cfg`` and return the last results tuple.

    Args:
        hyp: dict of hyperparameters. Keys read here include 'lr0', 'lrf',
            'momentum', 'weight_decay', 'warmup_epochs', 'warmup_bias_lr',
            'warmup_momentum' and 'cls'. NOTE: 'weight_decay' and 'cls' are
            rescaled in place.
        opt: options namespace (see ``main``); must also carry ``save_dir``,
            ``total_batch_size``, ``world_size`` and ``local_rank``.
        device: device object exposing ``.type`` ('cpu', 'cuda' or 'npu').
        tb_writer: optional TensorBoard writer used for scalars/histograms.
        wandb: optional ``wandb`` module; logging is skipped when falsy.

    Returns:
        The ``results`` tuple (P, R, mAP@.5, mAP@.5-.95, val box/obj/cls loss).
        Evaluation during training is disabled, so it stays all zeros.
        Training returns early with this tuple if the loss becomes non-finite.

    Side effects:
        Writes hyp.yaml, opt.yaml, results.txt and checkpoints under
        ``opt.save_dir``; optionally uploads to ``opt.bucket`` via gsutil.
    """
    print(f'Hyperparameters {hyp}')
    save_dir, epochs, batch_size, total_batch_size, weights, rank = (
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank)

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names'])  # number classes, names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Darknet(opt.cfg).to(device)  # create
        # Build the model state dict once, and skip checkpoint entries the model
        # does not have (previously a KeyError) or whose element count differs.
        model_state = model.state_dict()
        state_dict = {k: v for k, v in ckpt['model'].items()
                      if k in model_state and model_state[k].numel() == v.numel()}
        model.load_state_dict(state_dict, strict=False)
        print('Transferred %g/%g items from %s' % (len(state_dict), len(model_state), weights))  # report
        del model_state
    else:
        model = Darknet(opt.cfg).to(device)  # create

    # Image sizes
    gs = 64  # grid size (max stride), hard-coded instead of int(max(model.stride))
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    decayed_weight_tags = ('Conv2d.weight', 'm.weight', 'w.weight')
    for k, v in model.named_parameters():
        if '.bias' in k:
            pg2.append(v)  # biases
        elif any(tag in k for tag in decayed_weight_tags):
            pg1.append(v)  # apply weight_decay
        else:
            pg0.append(v)  # all else

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    elif device.type == 'npu':
        optimizer = apex.optimizers.NpuFusedSGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Logging
    # wandb_run is always bound so checkpoint writing below cannot hit a
    # NameError when a W&B run was already active before train() was called.
    wandb_run = wandb.run if wandb else None
    if wandb and wandb_run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(config=opt, resume="allow",
                               project='YOLOR' if opt.project == 'runs/train' else Path(opt.project).stem,
                               name=save_dir.stem,
                               id=ckpt.get('wandb_id') if pretrained else None)

    # Resume
    start_epoch = 0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                  (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        if device.type == 'npu':
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0, loss_scale=64,
                                              combine_grad=True)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0, loss_scale=64)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    def lf(x):
        """Cosine decay multiplier: 1.0 at x=0 down to hyp['lrf'] at x=epochs."""
        return ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf']

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        print('Using SyncBatchNorm()')

    # EMA (only kept on the process that writes checkpoints)
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
    if device.type == 'cuda' and rank != -1:
        model = DDP(model, device_ids=[rank], output_device=rank)
    elif device.type == 'npu' and rank != -1:
        model = DDP(model, device_ids=[rank], broadcast_buffers=False)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
                                            rank=rank, world_size=opt.world_size, workers=opt.workers)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
    model.names = names

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
                                       rank=-1, world_size=opt.world_size, workers=opt.workers)[0]  # testloader

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            if plots:
                plot_labels(labels, save_dir=save_dir)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)
                if wandb:
                    wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    if rank in [0, -1]:
        print('Image sizes %g train, %g test\n'
              'Using %g dataloader workers\nLogging results to %s\n'
              'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))

    epoch = start_epoch - 1  # keeps the final summary valid even if the loop below never runs
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        # When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        optimizer.zero_grad()
        start_time = time.time()
        d_1 = time.time()
        # Optional cap on batches per epoch (performance runs); parsed once per epoch.
        perf_max_steps = os.environ.get("PERF_MAX_STEPS", None)
        perf_max_steps = int(perf_max_steps) if perf_max_steps else None
        for i, (imgs, targets, paths, _) in enumerate(dataloader):  # batch ----------------------------------------------
            t_time = time.time()
            d_time = t_time - d_1  # time spent waiting for this batch since the previous step ended
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                # randrange() rejects float arguments on Python 3.12+, so cast the bounds to int
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            pred = model(imgs)

            # Loss
            loss, loss_items = compute_loss(pred, targets.to(device), model)  # scaled by batch_size
            if rank != -1:
                loss *= opt.world_size  # gradient averaged between devices in DDP mode
            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                return results

            # Backward
            if mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Optimize
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()
                if ema is not None:
                    x = torch.tensor([1.]).to(device)
                    if device.type == 'npu':
                        params_fp32_fused = optimizer.get_model_combined_params()
                        ema.update(model, x, params_fp32_fused[0])
                    else:
                        ema.update(model, x)

            # Mean step time; the timer restarts after step 10 so the first steps are excluded afterwards
            if i <= 10:
                sum_time = (time.time() - start_time) / (i + 1)
                if i == 10:
                    start_time = time.time()
            else:
                sum_time = (time.time() - start_time) / (i - 10)
            ptime = time.time() - d_1

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                # memory_reserved() replaces the deprecated memory_cached()
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                print(
                    'Epoch:[%2g][%4g/%4g][%s][FPS:%3.1f][mTime:%3.3f][pTime:%3.3f][dTime:%3.3f] GIoU:%.3f objectness:%.3f classfication:%.3f totalLoss:%.3f' % (
                        epoch, i, nb, device, opt.total_batch_size / sum_time, sum_time, ptime, d_time, *mloss))

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                elif plots and ni == 3 and wandb:
                    wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')]})
            if perf_max_steps is not None and i > perf_max_steps:
                break
            d_1 = time.time()
            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(model)
            final_epoch = epoch + 1 == epochs
            if False:  # No test during training (deliberately disabled)
                results, maps, _ = test.test(opt.data,
                                             batch_size=batch_size * 2,
                                             imgsz=imgsz_test,
                                             model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
                                             single_cls=opt.single_cls,
                                             dataloader=testloader,
                                             save_dir=save_dir,
                                             plots=True)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x})  # W&B

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:
                    training_results = f.read()
                # Unwrap a parallel wrapper on the EMA copy if there is one
                # (the check must be on ema.ema, not on the ModelEMA helper itself).
                ema_model = ema.ema.module if hasattr(ema.ema, 'module') else ema.ema
                ckpt = {'epoch': epoch,
                        'training_results': training_results,
                        'model': ema_model.state_dict(),
                        'optimizer': None if final_epoch else optimizer.state_dict(),
                        'wandb_id': wandb_run.id if wandb_run else None}

                # Save last plus periodic / end-of-run snapshots, then free the checkpoint
                torch.save(ckpt, last)
                if epoch == 0 or (epoch + 1) % 25 == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                if epoch >= (epochs - 5) or epoch >= 420:
                    torch.save(ckpt, wdir / 'last_{:03d}.pt'.format(epoch))
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        n = opt.name if opt.name.isnumeric() else ''
        fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
        for f1, f2 in zip([last, best, results_file], [flast, fbest, fresults]):
            if f1.exists():
                os.rename(f1, f2)  # rename
                if str(f2).endswith('.pt'):  # is *.pt
                    strip_optimizer(f2)  # strip optimizer
                    if opt.bucket:
                        os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket))  # upload
        # Finish
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in
                                       ['results.png', 'precision-recall_curve.png']]})
        print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))

    if wandb and wandb.run:
        wandb.run.finish()
    torch.cuda.empty_cache()

    return results





def main():
    """Parse command-line options and launch training.

    Spawns one ``main_worker`` process per device when ``--device-num`` or
    ``--world-size`` is greater than 1, otherwise runs ``main_worker`` in this
    process. With ``--full``, runs ``test.py`` on the saved ``last.pt``
    afterwards through the shell.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolor_p6.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.1280.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=8, help='total batch size for all GPUs')
    parser.add_argument('--img-size', nargs='+', type=int, default=[1280, 1280], help='[train, test] image sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu or npu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--full', action='store_true', help='full mode')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--npu', default=-1, type=int, help='NPU id to use.')

    # NPU DDP mode
    parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--device-num', default=1, type=int, help='multi NPU parameter, GPU or CPU do not modify')
    parser.add_argument('--addr', default='127.0.0.1', type=str, help='DDP master node IP')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:29501', type=str,
                        help='url used to set up distributed training')
    opt = parser.parse_args()

    if opt.dist_url == "env://":
        opt.world_size = int(os.environ["WORLD_SIZE"])
    ngpus_per_node = opt.device_num
    opt.npu_ddp = (opt.device_num > 1 or opt.world_size > 1)
    opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve)  # increment run
    if opt.npu_ddp:
        print('multi npu training')
        os.environ['MASTER_ADDR'] = opt.addr  # master ip
        os.environ['MASTER_PORT'] = '29501'
        os.environ['KERNEL_NAME_ID'] = str(0)
        opt.world_size = ngpus_per_node * opt.world_size  # the sum of GPU or NPU in all the nodes
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
    else:
        print('1p training')
        main_worker(opt.npu, ngpus_per_node, opt)

    if opt.full and opt.local_rank in [-1, 0]:
        import shlex  # local import: only needed for the post-training evaluation command

        # Evaluate the final checkpoint. The command goes through the shell
        # because it relies on redirection and a pipe into `tee`, so the
        # weights path is quoted to survive spaces or shell metacharacters.
        last_weights = Path(opt.save_dir) / 'weights' / 'last.pt'
        command = ' '.join([
            'python3 test.py',
            '--data data/coco.yaml',
            '--img 1280',
            '--batch 32',
            '--conf 0.001',
            '--iou 0.65',
            '--cfg cfg/yolor_p6.cfg',
            '--weights ' + shlex.quote(str(last_weights)),
            '--name yolor_val',
            '--device npu',
            '--npu 1',
            '2>&1 | tee -a npu_8p_full.log',
        ])
        os.system(command)



    # # Set DDP variables

    # opt.total_batch_size = opt.batch_size

    # opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1

    # opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1

    # set_logging(opt.global_rank)

    # if opt.global_rank in [-1, 0]:

    #     check_git_status()



    # # DDP mode

    # device = select_device(opt.device, batch_size=opt.batch_size)

    # if opt.local_rank != -1:

    #     assert torch.cuda.device_count() > opt.local_rank

    #     torch.cuda.set_device(opt.local_rank)

    #     device = torch.device('cuda', opt.local_rank)

    #     dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend

    #     assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'

    #     opt.batch_size = opt.total_batch_size // opt.world_size





def main_worker(npu, ngpus_per_node, opt):
    """Per-process entry point: resolve options, set up the device / process group, then train or evolve.

    Args:
        npu: Device index for this process. In multi-NPU mode this is the process index supplied by
            ``mp.spawn``; in single-device mode the caller passes ``opt.npu``.
        ngpus_per_node: Number of devices per node, used to derive this process's rank.
        opt: Parsed options (``argparse.Namespace``). Mutated in place; when ``opt.resume`` is set the
            local name is rebound to the options loaded from the run's ``opt.yaml``.

    Raises:
        AssertionError: If the resume checkpoint is missing, neither ``--cfg`` nor ``--weights`` is
            given, the batch size is not divisible by the world size, or ``--evolve`` is combined
            with distributed training.
    """
    global mixed_precision  # module-level flag; switched off below when running on CPU

    # Resume
    if opt.resume:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        # NOTE(review): yaml.FullLoader is used on a local run file; prefer yaml.safe_load if these
        # files can ever come from an untrusted source.
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))  # replace
        opt.cfg, opt.weights, opt.resume = '', ckpt, True
        print('Resuming training from %s' % ckpt)
    else:
        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.FullLoader)  # load hyps
    if 'box' not in hyp:
        # Older hyperparameter files use the key "giou"; migrate it to "box".
        warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
             (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
        hyp['box'] = hyp.pop('giou')

    # npu DDP
    if opt.npu_ddp:
        opt.npu = npu
        os.environ['KERNEL_NAME_ID'] = str(npu)
        print("[npu id:", opt.npu, "]", "+++++++++++++++++++++++++++KERNEL_NAME_ID:", os.environ['KERNEL_NAME_ID'])
        # Rank of this process = incoming opt.local_rank * devices per node + device index
        # (presumably opt.local_rank holds the node index on entry; verify against the argument parser).
        opt.local_rank = opt.local_rank * ngpus_per_node + npu

    device = torch_utils.select_device(opt.device, opt.npu, apex=mixed_precision, batch_size=opt.batch_size)
    opt.total_batch_size = opt.batch_size
    if device.type == 'cpu':
        mixed_precision = False
    elif opt.local_rank != -1 and device.type == 'cuda':
        # CUDA DDP mode (NCCL backend)
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend

        opt.world_size = dist.get_world_size()
        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
        opt.batch_size = opt.total_batch_size // opt.world_size  # per-process batch size
    elif opt.local_rank != -1 and device.type == 'npu':
        # NPU DDP mode (HCCL backend); world size and rank come from opt rather than the environment
        dist.init_process_group(backend='hccl', world_size=opt.world_size, rank=opt.local_rank)
        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
        opt.batch_size = opt.total_batch_size // opt.world_size  # per-process batch size

    # Train
    print(opt)
    if not opt.evolve:
        tb_writer = None  # init loggers
        if opt.local_rank in [-1, 0]:  # only the single/first process writes Tensorboard logs
            print(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
        train(hyp, opt, device, tb_writer, wandb)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0)}  # image mixup (probability)

        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(300):  # generations to evolve
            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                # The epsilon keeps the weight sum strictly positive. Without it a single row (or
                # tied rows) gives all-zero weights, which makes random.choices raise ValueError
                # on Python 3.9+ and makes the weighted combination divide by zero.
                w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
                if parent == 'single' or len(x) == 1:
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                # (named mutation_prob rather than `mp` to avoid shadowing torch.multiprocessing)
                mutation_prob, sigma = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([limits[0] for limits in meta.values()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mutation_prob) * npr.randn(ng) * npr.random() * sigma + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):
                    # the first 7 columns of an evolve.txt row are skipped; hyperparameters follow
                    hyp[k] = float(x[i + 7] * v[i])  # mutate

            # Constrain to limits
            for k, limits in meta.items():
                hyp[k] = max(hyp[k], limits[1])  # lower limit
                hyp[k] = min(hyp[k], limits[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device, wandb=wandb)

            # Write mutation results
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results
        plot_evolution(yaml_file)
        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')





if __name__ == '__main__':
    # Script entry point: set the NPU operator-compiler options, then run main().
    # Operator debug options, disabled for now:
    #   "ACL_OP_DEBUG_LEVEL": 3              # operator debug feature
    #   "ACL_DEBUG_DIR": "debug_file"        # directory used by the operator debug feature
    option = {
        "ACL_OP_COMPILER_CACHE_MODE": "enable",  # enable the operator compile cache
        "ACL_OP_COMPILER_CACHE_DIR": "./cache",  # directory holding the cache
    }
    print("option:", option)
    torch.npu.set_option(option)
    main()