"""Loss functions."""
import torch
import torch.nn as nn
from utils.metrics import bbox_iou
from utils.torch_utils import de_parallel
def smooth_BCE(eps=0.1):
    """Return the label-smoothed BCE targets as a `(positive, negative)` pair.

    The positive target is `1.0 - 0.5*eps` and the negative target is `0.5*eps`, so `eps=0` yields the hard
    targets `(1.0, 0.0)`. Background: https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441.
    """
    positive_target = 1.0 - 0.5 * eps
    negative_target = 0.5 * eps
    return positive_target, negative_target
class BCEBlurWithLogitsLoss(nn.Module):
    """BCE-with-logits loss whose per-element terms are scaled by a factor that shrinks as `sigmoid(pred) - true`
    approaches 1, reducing the effect of missing labels in YOLOv5 training.
    """

    def __init__(self, alpha=0.05):
        """Builds the element-wise BCE criterion and stores the `alpha` smoothing parameter."""
        super().__init__()
        self.alpha = alpha
        # The criterion must stay un-reduced so each element can be re-weighted in forward().
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, pred, true):
        """Returns the mean of the blur-weighted BCE loss between logits `pred` and targets `true`."""
        elementwise = self.loss_fcn(pred, true)
        probability = torch.sigmoid(pred)  # logits -> probabilities
        # Signed gap: positive where the prediction exceeds the label.
        gap = probability - true
        # The weight tends to 0 as the gap tends to 1; the 1e-4 keeps the divisor non-zero when alpha is 0.
        blur_weight = 1 - torch.exp((gap - 1) / (self.alpha + 1e-4))
        elementwise.mul_(blur_weight)
        return elementwise.mean()
class FocalLoss(nn.Module):
    """Wraps a BCE-with-logits style criterion and re-weights each element with the focal-loss `alpha` and `gamma`
    factors to counter class imbalance.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """Stores the wrapped criterion and focal parameters.

        The criterion's own `reduction` is remembered on `self.reduction` and then overwritten with "none" on the
        passed-in object, because the focal weights have to be applied per element before any reduction.
        """
        super().__init__()
        self.reduction = loss_fcn.reduction  # remember the caller's choice before it is overwritten
        loss_fcn.reduction = "none"
        self.loss_fcn = loss_fcn
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred, true):
        """Returns the focal loss between logits `pred` and targets `true`, reduced per the wrapped criterion's
        original `reduction` setting ("mean", "sum", or anything else for no reduction).
        """
        loss = self.loss_fcn(pred, true)
        prob = torch.sigmoid(pred)  # logits -> probabilities

        # p_t: probability assigned to the true label.
        p_t = true * prob + (1 - true) * (1 - prob)
        alpha_weight = true * self.alpha + (1 - true) * (1 - self.alpha)
        focal_weight = (1.0 - p_t) ** self.gamma  # small for well-classified elements
        loss.mul_(alpha_weight * focal_weight)

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
class QFocalLoss(nn.Module):
    """Wraps a BCE-with-logits style criterion and re-weights each element by `|true - sigmoid(pred)| ** gamma`
    (Quality Focal Loss) together with an `alpha` balance term.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """Stores the wrapped criterion and focal parameters; the criterion's `reduction` is saved on
        `self.reduction` and then set to "none" on the passed-in object.
        """
        super().__init__()
        self.reduction = loss_fcn.reduction  # remember the caller's choice before it is overwritten
        loss_fcn.reduction = "none"
        self.loss_fcn = loss_fcn
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred, true):
        """Returns the quality focal loss between logits `pred` and targets `true`, reduced per the wrapped
        criterion's original `reduction` setting ("mean", "sum", or anything else for no reduction).
        """
        loss = self.loss_fcn(pred, true)
        prob = torch.sigmoid(pred)  # logits -> probabilities

        alpha_weight = true * self.alpha + (1 - true) * (1 - self.alpha)
        # The weight grows with the distance between the target and the predicted probability.
        quality_weight = torch.abs(true - prob) ** self.gamma
        loss.mul_(alpha_weight * quality_weight)

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
class ComputeLoss:
    """Computes the total loss for YOLOv5 model predictions, including classification, box, and objectness losses.

    Instances are callable: `__call__(p, targets)` matches targets to anchors via `build_targets` and returns the
    combined loss together with its detached (box, obj, cls) components.
    """

    # When True, matched targets are reordered by ascending IoU before their IoU is written into the objectness target.
    sort_obj_iou = False

    def __init__(self, model, autobalance=False):
        """Initializes ComputeLoss with model and autobalance option, autobalances losses if True.

        Args:
            model: Object providing `parameters()` (the first parameter's device is used), a `hyp` mapping of
                hyperparameters, and - via `de_parallel(model).model[-1]` - a final module exposing `na`, `nc`,
                `nl`, `anchors` and `stride`.
            autobalance: If True, the per-layer objectness weights in `self.balance` are updated on every call.
        """
        device = next(model.parameters()).device  # work on the same device as the model weights
        # Hyperparameters; this class reads cls_pw, obj_pw, label_smoothing (optional), fl_gamma, box, obj, cls
        # and anchor_t from it.
        h = model.hyp
        # Class and objectness criteria, each with its own positive-sample weight.
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["obj_pw"]], device=device))
        # Positive / negative class targets; label smoothing defaults to 0.0, i.e. hard (1.0, 0.0) targets.
        self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0))
        g = h["fl_gamma"]  # focal-loss gamma; focal loss is enabled only when this is > 0
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
        m = de_parallel(model).model[-1]  # presumably the detection head - TODO confirm against the model definition
        # Per-layer objectness loss weights: a 3-entry list for 3 layers, otherwise the 5-entry default.
        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])
        # Index of the stride-16 layer; used to normalise `self.balance` when autobalancing.
        self.ssi = list(m.stride).index(16) if autobalance else 0
        # `gr` blends the objectness target between a constant 1 and the IoU; at 1.0 the raw IoU is used.
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        self.na = m.na  # used as the number of anchors in build_targets
        self.nc = m.nc  # used as the number of class logits when splitting predictions
        self.nl = m.nl  # used as the number of prediction layers in build_targets
        self.anchors = m.anchors
        self.device = device

    def __call__(self, p, targets):
        """Performs forward pass, calculating class, box, and object loss for given predictions and targets.

        Args:
            p: Sequence of per-layer prediction tensors. The indexing below requires each to be 5-D, indexed as
                (image, anchor, grid-y, grid-x, output) with `5 + self.nc` outputs (box xy, box wh, objectness,
                class logits).
            targets: Target tensor, passed unchanged to `build_targets`.

        Returns:
            Tuple `(loss, loss_items)`. `loss` is `(lbox + lobj + lcls) * bs` after each term has been scaled by
            its `hyp` gain, where `bs` is the first dimension of the last prediction tensor. `loss_items` is the
            detached concatenation of `(lbox, lobj, lcls)`.
        """
        # Accumulators for the class, box and objectness losses (each a 1-element tensor).
        lcls = torch.zeros(1, device=self.device)
        lbox = torch.zeros(1, device=self.device)
        lobj = torch.zeros(1, device=self.device)
        tcls, tbox, indices, anchors = self.build_targets(p, targets)
        for i, pi in enumerate(p):
            b, a, gj, gi = indices[i]  # image, anchor, grid-y and grid-x index of every matched target
            # Objectness target: zero everywhere except the matched cells filled in below.
            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)
            n = b.shape[0]  # number of matched targets on this layer
            if n:
                # Gather the predictions at the matched cells and split them into xy, wh, objectness and class logits.
                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)
                pxy = pxy.sigmoid() * 2 - 0.5  # xy offset within the cell, range (-0.5, 1.5)
                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]  # wh as a multiple in (0, 4) of the matched anchor
                pbox = torch.cat((pxy, pwh), 1)
                iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()
                lbox += (1.0 - iou).mean()
                # The objectness target must not carry gradient; negative IoU values are clamped to 0.
                iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    j = iou.argsort()
                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
                if self.gr < 1:
                    iou = (1.0 - self.gr) + self.gr * iou
                tobj[b, a, gj, gi] = iou
                if self.nc > 1:  # the class loss is skipped for single-class models
                    # Class targets: `cn` everywhere, `cp` at each target's class index.
                    t = torch.full_like(pcls, self.cn, device=self.device)
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(pcls, t)
            # The objectness loss covers every cell of the layer, matched or not.
            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]
            if self.autobalance:
                # Moving-average update that pulls the layer weight toward 1 / (objectness loss).
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
        if self.autobalance:
            # Rescale so the weight at index `ssi` is exactly 1.
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        # Apply the per-component gains from the hyperparameters.
        lbox *= self.hyp["box"]
        lobj *= self.hyp["obj"]
        lcls *= self.hyp["cls"]
        bs = tobj.shape[0]  # NOTE: relies on `tobj` left over from the last loop iteration, so `p` must be non-empty
        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets):
        """Prepares model targets from input targets (image,class,x,y,w,h) for loss computation, returning class, box,
        indices, and anchors.

        Args:
            p: Sequence of per-layer prediction tensors; only `p[i].shape` is read, with `shape[2]` and `shape[3]`
                used as the grid height and width.
            targets: 2-D tensor with one row per target and columns (image, class, x, y, w, h). The xywh values are
                multiplied by the grid size below, so they are presumably normalised to 0-1 - TODO confirm.

        Returns:
            Tuple `(tcls, tbox, indices, anch)` of lists with one entry per layer: the class index of each matched
            target, its box as (xy offset from the assigned cell, wh in grid units), the index tuple
            `(image, anchor, grid-y, grid-x)`, and the matched anchor rows.
        """
        na, nt = self.na, targets.shape[0]  # number of anchors, number of targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=self.device)  # per-column scale for the 7-column target rows built below
        # Anchor-index grid of shape (na, nt): row k is filled with k.
        ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)
        # Repeat every target once per anchor and append the anchor index as a 7th column -> (na, nt, 7).
        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)
        g = 0.5  # serves both as the fractional-position threshold and as the offset magnitude
        # Candidate cell offsets: none, then +x, +y, -x, -y, each of magnitude g.
        off = (
            torch.tensor(
                [
                    [0, 0],
                    [1, 0],
                    [0, 1],
                    [-1, 0],
                    [0, -1],
                ],
                device=self.device,
            ).float()
            * g
        )
        for i in range(self.nl):
            anchors, shape = self.anchors[i], p[i].shape
            # Scale columns x, y, w, h by (shape[3], shape[2], shape[3], shape[2]) to get grid units.
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]
            t = targets * gain
            if nt:
                # Keep (anchor, target) pairs whose width and height ratios, in either direction, stay below anchor_t.
                r = t[..., 4:6] / anchors[:, None]
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp["anchor_t"]
                t = t[j]  # boolean indexing flattens (na, nt, 7) to (matches, 7)
                gxy = t[:, 2:4]  # target centre in grid units
                gxi = gain[[2, 3]] - gxy  # the same centre measured from the opposite grid edge
                # j / k: centre lies in the lower half of its cell along x / y, and is more than 1 from the near edge.
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                # l / m: the same test measured from the far edge, i.e. centre in the upper half of its cell.
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                # Five masks, one per row of `off`; the first is all True so the original cell is always kept.
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                # No targets: an empty slice keeps the unpacking below valid.
                t = targets[0]
                offsets = 0
            # The 7 columns split into chunks of 2, 2, 2 and 1: (image, class), xy, wh, anchor index.
            bc, gxy, gwh, a = t.chunk(4, 1)
            a, (b, c) = a.long().view(-1), bc.long().T
            # Truncating after subtracting the offset picks the target's own cell or a neighbouring one.
            gij = (gxy - offsets).long()
            gi, gj = gij.T
            # NOTE: gi and gj are views into gij, so the in-place clamp_ also changes gij before tbox uses it below.
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # xy relative to the assigned cell, wh in grid units
            anch.append(anchors[a])
            tcls.append(c)
        return tcls, tbox, indices, anch