# -*- coding: utf-8 -*-

# BSD 3-Clause License

#

# Copyright (c) 2017

# All rights reserved.

# Copyright 2022 Huawei Technologies Co., Ltd

#

# Redistribution and use in source and binary forms, with or without

# modification, are permitted provided that the following conditions are met:

#

# * Redistributions of source code must retain the above copyright notice, this

#   list of conditions and the following disclaimer.

#

# * Redistributions in binary form must reproduce the above copyright notice,

#   this list of conditions and the following disclaimer in the documentation

#   and/or other materials provided with the distribution.

#

# * Neither the name of the copyright holder nor the names of its

#   contributors may be used to endorse or promote products derived from

#   this software without specific prior written permission.

#

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"

# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE

# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE

# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL

# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR

# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER

# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,

# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# ==========================================================================





# Copyright (c) 2015-present, Facebook, Inc.

# All rights reserved.

"""

Implements the knowledge distillation loss

"""

import torch

from torch.nn import functional as F





class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.

    The returned loss is ``base_loss * (1 - alpha) + distillation_loss * alpha`` when
    distillation is enabled, and just ``base_loss`` when ``distillation_type == 'none'``.
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        """
        Args:
            base_criterion: criterion applied to the student's main output and the labels.
            teacher_model: model whose predictions supervise the distillation output. It is
                only called when ``distillation_type`` is not ``'none'``.
            distillation_type: one of ``'none'``, ``'soft'`` or ``'hard'``.
            alpha: weight of the distillation term; the base loss is weighted by ``1 - alpha``.
            tau: temperature used by ``'soft'`` distillation.

        Raises:
            ValueError: if ``distillation_type`` is not one of the supported values.
        """
        super().__init__()
        # An explicit exception (rather than ``assert``) keeps this check active
        # even when Python runs with ``-O``.
        if distillation_type not in ('none', 'soft', 'hard'):
            raise ValueError(
                "distillation_type must be one of 'none', 'soft' or 'hard', "
                f"got {distillation_type!r}")
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion

        Returns:
            The base criterion loss when ``distillation_type == 'none'``, otherwise the
            ``alpha``-weighted mix of the base loss and the distillation loss.

        Raises:
            ValueError: if distillation is enabled but ``outputs`` is a single Tensor, or
                if ``distillation_type`` holds an unsupported value.
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            # The summed KL divergence is scaled by T^2 and divided by the total number of
            # elements in outputs_kd. For (batch, classes) logits that is batch * classes,
            # which is not the per-batch mean. It is kept as-is to preserve the loss scale.
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            # Use the teacher's arg-max prediction as a hard pseudo-label.
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
        else:
            # Only reachable if the attribute was changed after construction; fail
            # loudly instead of hitting an unbound local below.
            raise ValueError(
                f"unsupported distillation_type {self.distillation_type!r}")

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss