# -*- coding: utf-8 -*-

# BSD 3-Clause License

#

# Copyright (c) 2017

# All rights reserved.

# Copyright 2022 Huawei Technologies Co., Ltd

#

# Redistribution and use in source and binary forms, with or without

# modification, are permitted provided that the following conditions are met:

#

# * Redistributions of source code must retain the above copyright notice, this

#   list of conditions and the following disclaimer.

#

# * Redistributions in binary form must reproduce the above copyright notice,

#   this list of conditions and the following disclaimer in the documentation

#   and/or other materials provided with the distribution.

#

# * Neither the name of the copyright holder nor the names of its

#   contributors may be used to endorse or promote products derived from

#   this software without specific prior written permission.

#

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"

# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE

# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE

# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL

# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR

# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER

# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,

# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# ==========================================================================





# Copyright (c) 2015-present, Facebook, Inc.

# All rights reserved.

"""

Train and eval functions used in main.py

"""

import math

from operator import mod

import sys

from typing import Iterable, Optional

import time

import logging

import os

import torch



from timm.data import Mixup

from timm.utils import accuracy, ModelEma



from losses import DistillationLoss

import utils

import apex.amp



def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,

                    data_loader: Iterable, optimizer: torch.optim.Optimizer,

                    device: torch.device, epoch: int, 

                    output_dir: str, batch_size: int,

                    max_norm: float = 0,

                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,

                    set_training_mode=True, logger=logging, use_npu=False):

    model.train(set_training_mode)

    metric_logger = utils.MetricLogger(delimiter="  ", logger=logger)

    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    header = 'Epoch: [{}]'.format(epoch)

    print_freq = 10



    for samples, targets in metric_logger.log_every(data_loader, print_freq, batch_size, header, use_npu=use_npu):

        samples = samples.to(device, non_blocking=True)

        targets = targets.to(device, non_blocking=True)



        if mixup_fn is not None:

            samples, targets = mixup_fn(samples, targets)



        outputs = model(samples)

        loss = criterion(samples, outputs, targets)

        loss_value = loss.item()



        if not math.isfinite(loss_value):

            logger.info("Loss is {}, stopping training".format(loss_value))

            sys.exit(1)



        optimizer.zero_grad()



        # this attribute is added by timm on one optimizer (adahessian)

        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order



        with apex.amp.scale_loss(loss, optimizer) as scaled_loss:

            scaled_loss.backward(create_graph=is_second_order)



        optimizer.step()

        # loss_scaler(loss, optimizer, clip_grad=max_norm,

        #            parameters=model.parameters(), create_graph=is_second_order)



        # torch.npu.synchronize()

        # if model_ema is not None:

        #    model_ema.update(model)



        metric_logger.update(loss=loss_value)

        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes

    metric_logger.synchronize_between_processes(use_npu=use_npu, device=device)

    logger.info(f"Averaged stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}





@torch.no_grad()

def evaluate(data_loader, model, device, batch_size, logger=logging, use_npu=False):

    criterion = torch.nn.CrossEntropyLoss()



    metric_logger = utils.MetricLogger(delimiter="  ")

    header = 'Test:'



    # switch to evaluation mode

    model.eval()



    for images, target in metric_logger.log_every(data_loader, 10, batch_size, header, use_npu=use_npu):

        images = images.to(device, non_blocking=True)

        target = target.to(device, non_blocking=True)



        # compute output

        # with torch.npu.amp.autocast():

        output = model(images)

        loss = criterion(output, target)



        acc1, acc5 = accuracy(output, target, topk=(1, 5))



        batch_size = images.shape[0]

        metric_logger.update(loss=loss.item())

        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)

        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes

    metric_logger.synchronize_between_processes(use_npu=use_npu, device=device)

    logger.info('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'

                .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))



    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}





@torch.no_grad()

def throughput(data_loader, model, logger=logging, use_npu=False):

    model.eval()



    if use_npu:

        for idx, (images, _) in enumerate(data_loader):

            images = images.npu(non_blocking=True)

            batch_size = images.shape[0]

            for i in range(50):

                model(images)

            #torch.npu.synchronize()

            logger.info(f"throughput averaged with 30 times")

            tic1 = time.time()

            for i in range(30):

                model(images)

            # torch.npu.synchronize()

            tic2 = time.time()

            logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")

        return

    else:

        for idx, (images, _) in enumerate(data_loader):

            images = images.cuda(non_blocking=True)

            batch_size = images.shape[0]

            for i in range(50):

                model(images)

            #torch.cuda.synchronize()

            logger.info(f"throughput averaged with 30 times")

            tic1 = time.time()

            for i in range(30):

                model(images)

            #torch.cuda.synchronize()

            tic2 = time.time()

            logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")

            return