# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

# ============================================================================



import os

import time

import argparse

import datetime

import numpy as np



import torch

import torch.backends.cudnn as cudnn

import torch.distributed as dist



from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

from timm.utils import accuracy, AverageMeter



from config import get_config

from classification import build_model

from data import build_loader

from lr_scheduler import build_scheduler

from optimizer import build_optimizer

from logger import create_logger

from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor

from apex.optimizers import NpuFusedAdamW



# Options handed to torch.npu.set_option() at import time.  Judging by the key
# names they switch on the operator-compiler cache and point it at a local
# directory -- NOTE(review): confirm exact semantics against the NPU docs.
kernel_meta_path = "./kernel_meta"
option = {
    # "ACL_OP_DEBUG_LEVEL": 3,  # left disabled, as in the original script
    "ACL_OP_COMPILER_CACHE_MODE": "enable",
    "ACL_OP_COMPILER_CACHE_DIR": kernel_meta_path,
}

# Create the cache directory on first use.
if not os.path.exists(kernel_meta_path):
    os.makedirs(kernel_meta_path, exist_ok=True)

torch.npu.set_option(option)



try:

    # noinspection PyUnresolvedReferences

    from apex import amp

except ImportError:

    amp = None





def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string.  This converter understands the usual
    spellings instead and rejects anything else.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_option():
    """Parse the command line and build the run configuration.

    Unknown command-line arguments are tolerated (``parse_known_args``) and
    ignored.  The parsed namespace is passed to ``get_config``.

    Returns:
        A ``(args, config)`` tuple: the ``argparse.Namespace`` and whatever
        ``get_config(args)`` returns.
    """
    parser = argparse.ArgumentParser('Focal Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--dataset', type=str, default='imagenet', help='dataset name')
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--resume', help='resume from checkpoint', default=False)
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
    parser.add_argument('--debug', action='store_true', help='Perform debug only')
    # Boolean flags take an explicit value ("--stop_step True"); _str2bool
    # makes "--stop_step False" actually mean False.
    parser.add_argument('--stop_step', type=_str2bool, default=False)
    parser.add_argument('--finetune_switch', type=_str2bool, default=False)
    # The default stays False (not None) so downstream truthiness checks keep working.
    parser.add_argument('--finetune_model', type=str, default=False)

    # distributed training
    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # Arguments this parser does not know are deliberately discarded.
    args, _ = parser.parse_known_args()

    config = get_config(args)

    return args, config





def check_keywords_in_name(name, keywords=()):
    """Tell whether any of ``keywords`` occurs inside ``name``.

    Args:
        name: The string to search in (a parameter name in this file).
        keywords: Iterable of candidate substrings; empty by default.

    Returns:
        ``True`` if at least one keyword is contained in ``name``, otherwise
        ``False`` (also for an empty ``keywords``).
    """
    # any() stops at the first match instead of scanning every keyword.
    return any(keyword in name for keyword in keywords)





def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """Split the trainable parameters of ``model`` into two optimizer groups.

    A parameter is exempt from weight decay when it is one-dimensional, when
    its name ends with ``.bias``, when its name is listed in ``skip_list``, or
    when its name contains one of ``skip_keywords``.  Parameters with
    ``requires_grad`` unset are left out of both groups.

    Args:
        model: Object exposing ``named_parameters()``.
        skip_list: Container of exact parameter names to exempt.
        skip_keywords: Substrings; any match in the name exempts the parameter.

    Returns:
        A two-element list of param-group dicts: the decayed parameters first,
        then the exempt ones with ``'weight_decay': 0.``.
    """
    decayed, exempt = [], []

    for name, param in model.named_parameters():
        if param.requires_grad:
            is_exempt = (
                len(param.shape) == 1
                or name.endswith(".bias")
                or (name in skip_list)
                or check_keywords_in_name(name, skip_keywords)
            )
            bucket = exempt if is_exempt else decayed
            bucket.append(param)

    decay_group = {'params': decayed}
    no_decay_group = {'params': exempt, 'weight_decay': 0.}
    return [decay_group, no_decay_group]





def main(config):
    """Build the model and optimizer, optionally resume, then train and validate.

    Relies on the module-level ``logger`` (created in the ``__main__`` block)
    and, when ``config.AMP_OPT_LEVEL`` is not ``"O0"``, on ``amp``.

    Flow, as implemented below:
      1. build data loaders (skipped when ``config.DEBUG_MODE`` is set);
      2. build the model, optionally load ``config.FINETUNE_MODEL`` weights
         non-strictly, move it to the NPU;
      3. create ``NpuFusedAdamW`` over the weight-decay groups, initialise amp,
         wrap the model in ``DistributedDataParallel``;
      4. build the LR scheduler and pick the training criterion;
      5. auto-resume / resume, validating once after a resume;
      6. return early in eval-only or throughput mode;
      7. otherwise loop over epochs: train, checkpoint on rank 0, validate.

    Args:
        config: Run configuration (project config object).

    Returns:
        None.
    """
    # NOTE(review): when DEBUG_MODE is truthy none of these names are bound,
    # yet data_loader_train / data_loader_val / dataset_val / mixup_fn are used
    # unconditionally further down -- that path would raise NameError.
    if not config.DEBUG_MODE:
        dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    if config.FINETUNE_MODEL:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(config.FINETUNE_MODEL, map_location='cpu')
        # strict=False: keys that do not match the model are tolerated.
        model.load_state_dict(checkpoint['model'], strict=False)
        print("load pretrained model: ",config.FINETUNE_MODEL)
    model.npu()
    # logger.info(str(model))

    # The optimizer is built inline here rather than through build_optimizer:
    #optimizer = build_optimizer(config, model)
    skip = {}
    skip_keywords = {}
    # Models may declare parameter names / name fragments exempt from decay.
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)
    optim_dict = {"lr": config.TRAIN.BASE_LR, "weight_decay": config.TRAIN.WEIGHT_DECAY,
                  "eps": config.TRAIN.OPTIMIZER.EPS, "betas": config.TRAIN.OPTIMIZER.BETAS}
    optimizer = NpuFusedAdamW(parameters, **optim_dict)
    if config.AMP_OPT_LEVEL != "O0":
        # Mixed precision with dynamic loss scaling; combine_grad is passed
        # through to amp.initialize as-is.
        model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL,loss_scale="dynamic", combine_grad=True)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    # Unwrapped module, used for flops(), checkpoint save and load.
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if dist.get_rank() == 0:
        logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    # Training criterion: soft targets with mixup, else label smoothing, else
    # plain cross entropy.
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    criterion = criterion.npu()
    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            # A checkpoint found in the output directory overrides an
            # explicitly configured resume path.
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        # Eval-only mode stops here; note it is only honoured on the resume path.
        if config.EVAL_MODE:
            return

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
        # Only rank 0 writes checkpoints: every SAVE_FREQ epochs and after the
        # last epoch.  NOTE(review): the save happens before this epoch's
        # validation, so the stored max_accuracy excludes the current epoch.
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)

        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))





def _backward_and_clip(config, model, loss, optimizer):
    """Back-propagate ``loss`` and return the gradient norm.

    With amp enabled (``config.AMP_OPT_LEVEL != "O0"``) the loss is scaled
    through ``amp.scale_loss`` and the norm is taken over the amp master
    parameters; otherwise over ``model.parameters()``.  When
    ``config.TRAIN.CLIP_GRAD`` is truthy the gradients are clipped to that
    value, otherwise the norm is only measured.
    """
    if config.AMP_OPT_LEVEL != "O0":
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        params = amp.master_params(optimizer)
    else:
        loss.backward()
        params = model.parameters()

    if config.TRAIN.CLIP_GRAD:
        return torch.nn.utils.clip_grad_norm_(params, config.TRAIN.CLIP_GRAD)
    return get_grad_norm(params)


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
    """Train ``model`` for one pass over ``data_loader`` and log progress.

    Uses the module-level ``logger``.  Throughput (FPS) is computed from the
    per-step wall-clock times as ``batch size * world size * steps / seconds``.

    Args:
        config: Run configuration; reads ``TRAIN.ACCUMULATION_STEPS``,
            ``TRAIN.CLIP_GRAD``, ``AMP_OPT_LEVEL``, ``DATA.BATCH_SIZE``,
            ``PRINT_FREQ``, ``STOP_STEP`` and ``TRAIN.EPOCHS``.
        model: The (DDP-wrapped) model to train.
        criterion: Loss callable applied to ``(outputs, targets)``.
        data_loader: Iterable of ``(samples, targets)`` batches with a length.
        optimizer: Optimizer to step.
        epoch: Zero-based epoch index, used for logging and LR scheduling.
        mixup_fn: Optional callable applied to ``(samples, targets)``; may be
            ``None``.
        lr_scheduler: Scheduler exposing ``step_update(global_step)``.

    Returns:
        None.

    Raises:
        SystemExit: After step 20 when ``config.STOP_STEP`` is ``True``.
    """
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    accumulation_steps = config.TRAIN.ACCUMULATION_STEPS
    # Samples processed per step across all ranks.  Taken from the process
    # group instead of a script-level global so the function also works when
    # imported from another module.
    samples_per_step = config.DATA.BATCH_SIZE * dist.get_world_size()

    # FPS bookkeeping.  Steps 0..fps_warmup_step are treated as warm-up; the
    # short benchmark covers the steps up to fps_probe_step.
    fps_warmup_step = 10
    fps_probe_step = 20
    warmup_sum = None  # batch_time.sum right after the warm-up step
    # Start of the current reporting window: (last step before it, time so far).
    window_idx, window_sum = -1, 0.0

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.npu(non_blocking=True)
        targets = targets.npu(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)
        loss = criterion(outputs, targets)

        if accumulation_steps > 1:
            # Average over the accumulated micro-batches.
            loss = loss / accumulation_steps
            # NOTE: as in the original code, clipping runs on every micro-step
            # over the gradients accumulated so far.
            grad_norm = _backward_and_clip(config, model, loss, optimizer)
            if (idx + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.zero_grad()
            grad_norm = _backward_and_clip(config, model, loss, optimizer)
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        # Wait for the device so the measured step time is meaningful.
        torch.npu.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx == fps_warmup_step:
            warmup_sum = batch_time.sum

        if idx == fps_probe_step:
            FPS = samples_per_step * (fps_probe_step - fps_warmup_step) / (batch_time.sum - warmup_sum)
            logger.info(f"EPOCH {epoch} training FPS: {FPS}\t")
            # Explicit comparison kept on purpose: only a real True (or 1)
            # stops the run, not an arbitrary truthy config value.
            if config.STOP_STEP == True:  # noqa: E712
                print("train 20 steps and finish")
                # SystemExit works even where the site-provided exit() is absent.
                raise SystemExit

        if idx % config.PRINT_FREQ == 0 and idx != 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.npu.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            # Throughput over the steps since the previous report (or since
            # the warm-up cut-off).  The window always holds at least one
            # step, so this is well defined for any PRINT_FREQ.
            window_steps = idx - window_idx
            window_time = batch_time.sum - window_sum
            fps = samples_per_step * window_steps / window_time if window_time > 0 else float('nan')
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB\t'
                f'FPS: {fps}')
            window_idx, window_sum = idx, batch_time.sum

        if idx == fps_warmup_step:
            # Drop the warm-up steps from the next reported window.
            window_idx, window_sum = idx, batch_time.sum

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")





@torch.no_grad()
def validate(config, data_loader, model):
    """Evaluate ``model`` on ``data_loader`` and return averaged metrics.

    Runs with gradients disabled and the model in eval mode.  Per batch, the
    cross-entropy loss and top-1/top-5 accuracy are computed, passed through
    ``reduce_tensor`` (presumably a cross-rank reduction -- confirm in
    ``utils``) and accumulated weighted by batch size.  Progress is written to
    the module-level ``logger`` every ``config.PRINT_FREQ`` batches.

    Args:
        config: Run configuration; only ``PRINT_FREQ`` is read.
        data_loader: Iterable of ``(images, target)`` batches.
        model: Model to evaluate.

    Returns:
        Tuple ``(top-1 average, top-5 average, loss average)``.
    """
    loss_fn = torch.nn.CrossEntropyLoss().npu()
    model.eval()

    timer = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    tick = time.time()
    for step, (images, target) in enumerate(data_loader):
        images = images.npu(non_blocking=True)
        target = target.npu(non_blocking=True)

        # Forward pass, then loss and accuracy for this batch.
        logits = model(images)
        batch_loss = loss_fn(logits, target)
        batch_top1, batch_top5 = accuracy(logits, target, topk=(1, 5))

        # Keep this call order: every rank must issue the reductions alike.
        batch_top1 = reduce_tensor(batch_top1)
        batch_top5 = reduce_tensor(batch_top5)
        batch_loss = reduce_tensor(batch_loss)

        n_samples = target.size(0)
        losses.update(batch_loss.item(), n_samples)
        top1.update(batch_top1.item(), n_samples)
        top5.update(batch_top5.item(), n_samples)

        # Wall-clock time of this batch.
        timer.update(time.time() - tick)
        tick = time.time()

        if step % config.PRINT_FREQ != 0:
            continue
        memory_used = torch.npu.max_memory_allocated() / (1024.0 * 1024.0)
        logger.info(
            f'Test: [{step}/{len(data_loader)}]\t'
            f'Time {timer.val:.3f} ({timer.avg:.3f})\t'
            f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
            f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
            f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
            f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}')

    return top1.avg, top5.avg, losses.avg





@torch.no_grad()
def throughput(data_loader, model, logger):
    """Measure inference throughput on the first batch of ``data_loader``.

    The model is put in eval mode, run 50 times on the batch as warm-up, then
    timed over 30 further forward passes; the result (samples per second) is
    written to ``logger``.  Only the first batch is used; an empty loader
    results in no measurement.

    Args:
        data_loader: Iterable of ``(images, labels)`` batches.
        model: Model to benchmark.
        logger: Logger receiving the result.

    Returns:
        None.
    """
    model.eval()

    for images, _ in data_loader:
        images = images.npu(non_blocking=True)
        batch_size = images.shape[0]

        # Warm-up passes, excluded from the timing.
        for _warm in range(50):
            model(images)
        torch.npu.synchronize()

        logger.info("throughput averaged with 30 times")
        started = time.time()
        for _rep in range(30):
            model(images)
        # Make sure all queued device work is done before stopping the clock.
        torch.npu.synchronize()
        finished = time.time()

        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (finished - started)}")
        break  # one batch is enough





# Script entry point: parse options, initialise the NPU device and the HCCL
# process group, seed the RNGs, rescale the learning rates, prepare the output
# directory and logger, then hand over to main().
if __name__ == '__main__':
    _, config = parse_option()

    # Mixed precision needs apex.amp, whose import is optional at the top of
    # the file.
    if config.AMP_OPT_LEVEL != "O0":
        assert amp is not None, "amp not installed!"

    # Rank / world size come from the launcher's environment; -1 is passed
    # through to init_process_group when they are missing.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.npu.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='hccl', init_method='env://', world_size=world_size, rank=rank)
    # Wait until every rank has joined before continuing.
    torch.distributed.barrier()

    # Per-rank seed so that ranks do not draw identical random numbers.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # NOTE(review): this is a cuDNN flag; whether it has any effect on NPU is
    # not visible from this file -- confirm.
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    # (reference total batch size: 512)
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also needs to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    # The config is frozen; unfreeze it just long enough to store the new rates.
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    # A non-empty PT_OUTPUT_DIR environment variable overrides the configured
    # output directory.
    config.defrost()
    config.OUTPUT = os.getenv('PT_OUTPUT_DIR') if os.getenv('PT_OUTPUT_DIR') else config.OUTPUT
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    # Module-level logger; the functions above read it as a global.
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    # Only rank 0 writes the resolved configuration to disk.
    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())
    main(config)