# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



from base_model import Loss

import sys

import os

from opt_loss import OptLoss

from mlperf_logger import configure_logger, log_start, log_end, log_event, set_seeds, get_rank, barrier

from mlperf_logging.mllog import constants

import torch

from torch.autograd import Variable

import time

import numpy as np

import io

from bisect import bisect       # for lr_scheduler

from apex import amp

from ssd300 import SSD300

from master_params import create_flat_master

from parse_config import parse_args, validate_arguments, validate_group_bn

from data.build_pipeline import prebuild_pipeline, build_pipeline

from box_coder import dboxes300_coco, build_ssd300_coder

from async_evaluator import AsyncEvaluator

from eval import coco_eval

from apex.optimizers import NpuFusedSGD

import gc

from torch.nn.parallel import DistributedDataParallel

# necessary pytorch imports

import torch.utils.data.distributed

import torch.distributed as dist



# Apex imports

try:

    import apex_C

    import apex

    from apex.parallel.LARC import LARC

    from apex.parallel import DistributedDataParallel as DDP

    from apex.fp16_utils import *

    from apex.multi_tensor_apply import multi_tensor_applier

    #import amp_C

except ImportError:

    raise ImportError("Please install APEX from https://github.com/nvidia/apex")



from contextlib import redirect_stdout



import logging





class Logger(object):

    logfile = ""



    def __init__(self, filename=""):

        self.logfile = filename

        self.terminal = sys.stdout

        # self.log = open(filename, "a")

        return



    def write(self, message):

        self.terminal.write(message)

        if self.logfile != "":

            try:

                self.log = open(self.logfile, "a")

                self.log.write(message)

                self.log.close()

            except:

                pass



    def flush(self):

        pass





def print_message(rank, *print_args):

    if rank == 0:

        print(*print_args)



def load_checkpoint(model, checkpoint):

    print("loading model checkpoint", checkpoint)

    od = torch.load(checkpoint)

    # remove proceeding 'module' from checkpoint

    saved_model = od["model"]

    for k in list(saved_model.keys()):

        if k.startswith('module.'):

            saved_model[k[7:]] = saved_model.pop(k)

        if k.startswith('mbox.'):

            saved_model.pop(k)

    model.load_state_dict(saved_model,strict=False)



def check_async_evals(args, evaluator, threshold):

    finished = 0

    # Note: only one rank does COCOEval, so we need to check there if we've

    # finished -- we'll broadcast that to a "finished" tensor to determine

    # if we should stop

    # Note2: ssd_print contains a barrier() call, implemented with all_reduce

    #        If we conditional on rank 0, then an ssd_print all_reduce matches with

    #        the finished all_reduce and all hell breaks loose.

    if args.rank == 0:

        for epoch, current_accuracy in evaluator.finished_tasks().items():

            # Note: Move to per-iter check

            # EVAL_START should be prior to the accuracy/score evaluation but adding the missing EVAL_START here for now

            log_start(key=constants.EVAL_START, metadata={'epoch_num' : epoch})

            log_event(key=constants.EVAL_ACCURACY,

                      value=current_accuracy,

                      metadata={'epoch_num' : epoch})

            log_end(key=constants.EVAL_STOP, metadata={'epoch_num' : epoch})

            if current_accuracy >= threshold:

                finished = 1



    # handle the non-distributed case -- don't need to bcast, just take local result

    if not args.distributed:

        return finished == 1



    # Now we know from all ranks if they're done - reduce result

    # Note: Already caught the non-distributed case above, can assume broadcast is available

    with torch.no_grad():

        finish_tensor = torch.tensor([finished], dtype=torch.int32, device=torch.device('npu'))

        # torch.distributed.all_reduce(finish_tensor)

        torch.distributed.broadcast(finish_tensor, src=0)



        # >= 1 ranks has seen final accuracy

        if finish_tensor.item() >= 1:

            return True



    # Default case: No results, or no accuracte enough results

    return False



def lr_warmup(optim, warmup_iter, iter_num, epoch, base_lr, args):

    if iter_num < warmup_iter:

        warmup_step = base_lr / (warmup_iter * (2 ** args.warmup_factor))

        new_lr = base_lr - (warmup_iter - iter_num) * warmup_step



        for param_group in optim.param_groups:

            param_group['lr'] = new_lr



def setup_distributed(args):

    # Setup multi-GPU if necessary



    args.distributed = False

    if 'WORLD_SIZE' in os.environ:

        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    os.environ['MASTER_ADDR'] = '127.0.0.1'

    os.environ['MASTER_PORT'] = '29688'

    if args.distributed:

        torch.npu.set_device(args.local_rank)

        torch.distributed.init_process_group(backend='hccl',

                                            world_size=int(os.environ['WORLD_SIZE']),

                                            rank=args.local_rank,

                                             )

    args.local_seed = set_seeds(args)

    # start timing here

    if args.distributed:

        args.N_gpu = torch.distributed.get_world_size()

        args.rank = torch.distributed.get_rank()

    else:

        args.N_gpu = 1

        args.rank = 0



    validate_group_bn(args.bn_group)



    return args



def train300_mlperf_coco(args):



    

    args = setup_distributed(args)

    

    if not args.distributed:

        torch.npu.set_device(args.device_id)



    # Build the model

    model_options = {

        'use_nhwc' : args.nhwc,

        'pad_input' : args.pad_input,

        'bn_group' : args.bn_group,

    }



    ssd300 = SSD300(args, args.num_classes, **model_options)

    if args.checkpoint is not None:

        load_checkpoint(ssd300, args.checkpoint)



    ssd300.train()

    ssd300.npu()

    dboxes = dboxes300_coco()

    # Note: No reason not to use optimised loss

    #loss_func = OptLoss()

    loss_func = Loss(dboxes)

    loss_func.npu()



    # Create optimizer.  This must also be done after network_to_half.

    global_batch_size = (args.N_gpu * args.batch_size)

    log_event(key=constants.MODEL_BN_SPAN, value=args.bn_group*args.batch_size)

    log_event(key=constants.GLOBAL_BATCH_SIZE, value=global_batch_size)



    # mlperf only allows base_lr scaled by an integer

    base_lr = 2.5e-3

    requested_lr_multiplier = args.lr / base_lr

    adjusted_multiplier = max(1, round(requested_lr_multiplier * global_batch_size / 32))



    current_lr = base_lr * adjusted_multiplier

    current_momentum = 0.9

    current_weight_decay = args.wd

    static_loss_scale = args.loss_scale





    optim = apex.optimizers.NpuFusedSGD(ssd300.parameters(),

                                     lr=current_lr,

                                     momentum=current_momentum,

                                     weight_decay=current_weight_decay)

    ssd300, optim = amp.initialize(ssd300, optim, opt_level='O2', loss_scale=static_loss_scale,combine_grad=True)

    # Parallelize.  Need to do this after network_to_half.

    if args.distributed:

        if args.delay_allreduce:

            print_message(args.local_rank, "Delaying allreduces to the end of backward()")

        ssd300 = DistributedDataParallel(ssd300, device_ids=[args.local_rank])



    log_event(key=constants.OPT_BASE_LR, value=current_lr)

    log_event(key=constants.OPT_LR_DECAY_BOUNDARY_EPOCHS, value=args.lr_decay_epochs)

    log_event(key=constants.OPT_LR_DECAY_STEPS, value=args.lr_decay_epochs)

    log_event(key=constants.OPT_WEIGHT_DECAY, value=current_weight_decay)

    if args.warmup is not None:

        log_event(key=constants.OPT_LR_WARMUP_STEPS, value=args.warmup)

        log_event(key=constants.OPT_LR_WARMUP_FACTOR, value=args.warmup_factor)



    # Model is completely finished -- need to create separate copies, preserve parameters across

    # them, and jit

    ssd300_eval = SSD300(args, args.num_classes, **model_options).npu()



    if args.use_fp16:

        convert_network(ssd300_eval, torch.half)



    # Get the existant state from the train model

    # * if we use distributed, then we want .module

    train_model = ssd300.module if args.distributed else ssd300

    ssd300_eval.load_state_dict(train_model.state_dict())

    ssd300_eval.eval()





    print_message(args.local_rank, "epoch", "nbatch", "loss")



    iter_num = args.iteration

    avg_loss = 0.0



    start_elapsed_time = time.time()

    last_printed_iter = args.iteration

    num_elapsed_samples = 0



    input_c = 4 if args.pad_input else 3

    example_shape = [args.batch_size, 300, 300, input_c] if args.nhwc else [args.batch_size, input_c, 300, 300]

    example_input = torch.randn(*example_shape).npu()



    if args.use_fp16:

        example_input = example_input.half()

    

    if args.jit:

        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,

        # the resulting ScriptModule will elide this control flow, resulting in allreduce

        # hooks not being called.  If we're running distributed, we need to extract and JIT

        # the wrapped .module.

        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks

        # to go out of scope, and therefore silently disappear.

        module_to_jit = ssd300.module if args.distributed else ssd300

        if args.distributed:

            ssd300.module = torch.jit.trace(module_to_jit, example_input, check_trace=False)

        else:

            ssd300 = torch.jit.trace(module_to_jit, example_input, check_trace=False)

        # JIT the eval model too

        ssd300_eval = torch.jit.trace(ssd300_eval, example_input, check_trace=False)



    # do a dummy fprop & bprop to make sure cudnnFind etc. are timed here

    ploc, plabel = ssd300(example_input)



    # produce a single dummy "loss" to make things easier

    loss = ploc[0,0,0] + plabel[0,0,0]

    dloss = torch.randn_like(loss)

    # Cause cudnnFind for dgrad, wgrad to run

    loss.backward(dloss)



    # Necessary import in init

    #from pycocotools.coco import COCO



    encoder = build_ssd300_coder()



    evaluator = AsyncEvaluator(num_threads=1)



    log_end(key=constants.INIT_STOP)



    ##### END INIT



    # This is the first place we touch anything related to data

    ##### START DATA TOUCHING

    barrier()

    log_start(key=constants.RUN_START)

    barrier()



    train_pipe = prebuild_pipeline(args)

       

    train_loader, epoch_size = build_pipeline(args, training=True, pipe=train_pipe)

    if args.rank == 0:

        print("epoch size is: ", epoch_size, " images")



    val_loader, inv_map, cocoGt = build_pipeline(args, training=False)

    if args.profile_gc_off:

        gc.disable()

        gc.collect()



    ##### END DATA TOUCHING

    i_eval = 0

    block_start_epoch = 1

    log_start(key=constants.BLOCK_START,

              metadata={'first_epoch_num': block_start_epoch,

                        'epoch_count': args.evaluation[i_eval]})

    for epoch in range(args.epochs):

        optim.zero_grad()

        



        if epoch in args.evaluation:

            # Get the existant state from the train model

            # * if we use distributed, then we want .module

            train_model = ssd300.module if args.distributed else ssd300



            if args.distributed and args.allreduce_running_stats:

                if args.rank == 0: print("averaging bn running means and vars")

                # make sure every node has the same running bn stats before

                # using them to evaluate, or saving the model for inference

                world_size = float(torch.distributed.get_world_size())

                for bn_name, bn_buf in train_model.named_buffers(recurse=True):

                    if ('running_mean' in bn_name) or ('running_var' in bn_name):

                        torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)

                        bn_buf /= world_size



            if args.rank == 0:

                if args.save:

                    print("saving model...")

                    if not os.path.isdir('./models'):

                        os.mkdir('./models')

                    torch.save({"model" : ssd300.state_dict()}, "./models/iter_{}.pt".format(iter_num))

            

            ssd300_eval.load_state_dict(train_model.state_dict())

            # Note: No longer returns, evaluation is abstracted away inside evaluator

            coco_eval(args,

                      ssd300_eval,

                      val_loader,

                      cocoGt,

                      encoder,

                      inv_map,

                      epoch,

                      iter_num,

                      evaluator=evaluator)

            log_end(key=constants.BLOCK_STOP, metadata={'first_epoch_num': block_start_epoch})

            if epoch != max(args.evaluation):

                i_eval += 1

                block_start_epoch = epoch + 1

                log_start(key=constants.BLOCK_START,

                          metadata={'first_epoch_num': block_start_epoch,

                                    'epoch_count': (args.evaluation[i_eval] -

                                                    args.evaluation[i_eval - 1])})



        if epoch in args.lr_decay_epochs:

            current_lr *= args.lr_decay_factor

            print_message(args.rank, "lr decay step #" + str(bisect(args.lr_decay_epochs, epoch)))

            for param_group in optim.param_groups:

                param_group['lr'] = current_lr



        log_start(key=constants.EPOCH_START,

                  metadata={'epoch_num': epoch + 1,

                            'current_iter_nufm': iter_num})



        for i, data in enumerate(train_loader):

            (img, bbox, label, _) = data

            img = img.npu()

            bbox = bbox.npu()

            label = label.npu()

            if args.profile_start is not None and iter_num == args.profile_start:

                torch.npu.profiler.start()

                torch.npu.synchronize()

                if args.profile_nvtx:

                    torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)



            if args.profile is not None and iter_num == args.profile:

                if args.profile_start is not None and iter_num >=args.profile_start:

                    # we turned npu and nvtx profiling on, better turn it off too

                    if args.profile_nvtx:

                        torch.autograd._disable_profiler()

                    torch.npu.profiler.stop()

                return



            if args.warmup is not None:

                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr, args)



            if (img is None) or (bbox is None) or (label is None):

                print("No labels in batch")

                continue



            ploc, plabel = ssd300(img)

            ploc, plabel = ploc.float(), plabel.float()



            N = img.shape[0]

            bbox.requires_grad = False

            label.requires_grad = False

            # reshape (N*8732X4 -> Nx8732x4) and transpose (Nx8732x4 -> Nx4x8732)

            bbox = bbox.view(N, -1, 4).transpose(1,2).contiguous()

            # reshape (N*8732 -> Nx8732) and cast to Long

            label = label.view(N, -1).long()

            loss = loss_func(ploc, plabel, bbox, label)



            if np.isfinite(loss.item()):

                avg_loss = 0.999*avg_loss + 0.001*loss.item()

            else:

                print("model exploded (corrupted by Inf or Nan)")

                sys.exit()



            num_elapsed_samples += N

            # if args.rank == 0 and iter_num % args.print_interval == 0:

            if args.rank == 0 and iter_num % args.print_interval == 0:

                end_elapsed_time = time.time()

                elapsed_time = end_elapsed_time - start_elapsed_time



                avg_samples_per_sec = num_elapsed_samples * args.N_gpu / elapsed_time



                print("Epoch:{:4d}, Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\

                            .format(epoch, iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")



                last_printed_iter = iter_num

                start_elapsed_time = time.time()

                num_elapsed_samples = 0





            with amp.scale_loss(loss, optim) as scaled_loss:

                scaled_loss.backward()





            optim.step()





            # Likely a decent skew here, let's take this opportunity to set the

            # gradients to None.  After DALI integration, playing with the

            # placement of this is worth trying.



            optim.zero_grad()



            # Don't check every iteration due to cost of broadcast

            if iter_num % 20 == 0:

                finished = check_async_evals(args, evaluator, args.threshold)



                if finished:

                    return True



            iter_num += 1



        #train_loader.reset()

        log_end(key=constants.EPOCH_STOP, metadata={'epoch_num': epoch + 1})



    return False



def main():

    # torch.multiprocessing.set_start_method('spawn')

    configure_logger(constants.SSD)

    log_start(key=constants.INIT_START, log_all_ranks=True)

    args = parse_args()

    sys.stdout = Logger("test/output/%s/%s_%s.log"%(args.device_id,args.tag,args.device_id))

    sys.stderr = Logger("test/output/%s/%s_%s.log"%(args.device_id,args.tag,args.device_id))#1p

    if args.local_rank == 0:

        print(args)



    # make sure the epoch lists are in sorted order

    args.evaluation.sort()

    args.lr_decay_epochs.sort()



    validate_arguments(args)



    torch.set_num_threads(1)

    torch.backends.cudnn.benchmark = not args.profile_cudnn_get



    success = train300_mlperf_coco(args)

    status = 'success' if success else 'aborted'



    # end timing here

    log_end(key=constants.RUN_STOP, metadata={'status': status})





if __name__ == "__main__":



    main()