"""
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the BSD 3-Clause License (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://spdx.org/licenses/BSD-3-Clause.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
""" Mixup and Cutmix
Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch
Hacked together by / Copyright 2020 Ross Wightman
"""
import numpy as np
import torch
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    """Encode class indices as a dense on/off matrix.

    Args:
        x (torch.Tensor): class indices; cast to int64 and flattened so that
            each index becomes one row of the result.
        num_classes (int): number of columns in the encoded matrix.
        on_value (float): value written at each row's class column.
        off_value (float): fill value for every other column.
        device: device on which the encoded matrix is allocated.

    Returns:
        torch.Tensor: tensor of shape (x.numel(), num_classes).
    """
    indices = x.long().view(-1, 1)
    encoded = torch.full((indices.size()[0], num_classes), off_value, device=device)
    # In-place scatter along dim 1 puts on_value at the column named by each index.
    encoded.scatter_(1, indices, on_value)
    return encoded
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    """Build soft labels for a batch that is mixed with its own reverse.

    Each sample's (label-smoothed) one-hot vector is blended with the vector
    of the sample at the mirrored batch position.

    Args:
        target (torch.Tensor): class indices for the batch.
        num_classes (int): number of classes.
        lam: weight of the sample's own label; the flipped batch's label
            receives ``1 - lam``. May be a scalar or anything that
            broadcasts against a (batch, num_classes) tensor.
        smoothing (float): label smoothing amount.
        device: device on which the one-hot tensors are created.

    Returns:
        torch.Tensor: mixed targets of shape (batch, num_classes).
    """
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    encode_kwargs = dict(on_value=on_value, off_value=off_value, device=device)
    own = one_hot(target, num_classes, **encode_kwargs)
    mirrored = one_hot(target.flip(0), num_classes, **encode_kwargs)
    return own * lam + mirrored * (1. - lam)
def rand_bbox(img_shape, lam, margin=0., count=None):
    """ Standard CutMix bounding-box

    Draws a random box whose height and width are ``sqrt(1 - lam)`` times the
    image height and width, centred at a uniformly sampled point and clipped
    to the image. A border margin, given as a fraction of the box dimensions,
    can be enforced to keep the box centre away from the image edges.

    Args:
        img_shape (tuple): Image shape as tuple; the last two entries are used as (height, width)
        lam (float): Cutmix lambda value
        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
        count (int): Number of bbox to generate

    Returns:
        tuple: (yl, yh, xl, xh) box bounds; arrays of length ``count`` when count is given.
    """
    img_h, img_w = img_shape[-2:]
    ratio = np.sqrt(1 - lam)
    cut_h = int(img_h * ratio)
    cut_w = int(img_w * ratio)
    margin_y = int(margin * cut_h)
    margin_x = int(margin * cut_w)
    # Row centre is sampled before column centre; keep this order so the
    # random stream is consumed the same way.
    cy = np.random.randint(margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(margin_x, img_w - margin_x, size=count)
    half_h = cut_h // 2
    half_w = cut_w // 2
    yl = np.clip(cy - half_h, 0, img_h)
    yh = np.clip(cy + half_h, 0, img_h)
    xl = np.clip(cx - half_w, 0, img_w)
    xh = np.clip(cx + half_w, 0, img_w)
    return yl, yh, xl, xh
def rand_bbox_minmax(img_shape, minmax, count=None):
    """ Min-Max CutMix bounding-box

    Inspired by the Darknet cutmix impl: samples a random rectangular box whose
    height and width are each drawn between a min and max fraction of the
    corresponding image dimension, then places it fully inside the image.

    Typical defaults for minmax are usually in the .2-.3 range for min and .8-.9 range for max.

    Args:
        img_shape (tuple): Image shape as tuple; the last two entries are used as (height, width)
        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
        count (int): Number of bbox to generate

    Returns:
        tuple: (yl, yu, xl, xu) box bounds; arrays of length ``count`` when count is given.
    """
    assert len(minmax) == 2
    img_h, img_w = img_shape[-2:]
    lo_frac = minmax[0]
    hi_frac = minmax[1]
    # Draw order (height, width, top, left) is kept so the random stream is
    # consumed the same way.
    cut_h = np.random.randint(int(img_h * lo_frac), int(img_h * hi_frac), size=count)
    cut_w = np.random.randint(int(img_w * lo_frac), int(img_w * hi_frac), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
    xl = np.random.randint(0, img_w - cut_w, size=count)
    return yl, yl + cut_h, xl, xl + cut_w
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
    """ Generate bbox and apply lambda correction.

    Args:
        img_shape (tuple): Image shape as tuple; the last two entries are used as (height, width)
        lam (float): Cutmix lambda value (only drives the box size when ratio_minmax is None)
        ratio_minmax (tuple or list): when not None, the box is sampled from these
            min/max ratios and lambda is always recomputed from the box
        correct_lam (bool): recompute lambda from the actual box area
        count (int): Number of bbox to generate

    Returns:
        tuple: ``((yl, yu, xl, xu), lam)`` where lam is 1 minus the fraction of the
        image covered by the box whenever it has been recomputed.
    """
    use_minmax = ratio_minmax is not None
    if use_minmax:
        box = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
        box = rand_bbox(img_shape, lam, count=count)
    yl, yu, xl, xu = box
    if correct_lam or use_minmax:
        img_area = float(img_shape[-2] * img_shape[-1])
        lam = 1. - (yu - yl) * (xu - xl) / img_area
    return (yl, yu, xl, xu), lam
class Mixup:
    """ Mixup/Cutmix that applies different params to each element or whole batch

    Args:
        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
        prob (float): probability of applying mixup or cutmix per batch or element
        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements) or 'elem' (element).
            Any value other than 'elem' or 'pair' is treated as 'batch'.
        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
        label_smoothing (float): apply label smoothing to the mixed target tensor
        num_classes (int): number of classes for target
    """

    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # A min/max range activates cutmix regardless of the alpha passed in.
            self.cutmix_alpha = 1.0
        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = mode
        self.correct_lam = correct_lam
        # When False, the parameter samplers return lam == 1 and nothing is mixed.
        self.mixup_enabled = True

    def _params_per_elem(self, batch_size):
        """Sample one lambda and one mixup-vs-cutmix choice per element.

        Returns:
            tuple: ``(lam, use_cutmix)`` — a float32 array and a bool array,
            both of length ``batch_size``. Elements that are not mixed keep
            ``lam == 1``.

        Raises:
            AssertionError: if neither mixup nor cutmix is active.
        """
        lam = np.ones(batch_size, dtype=np.float32)
        # Builtin bool instead of the np.bool alias, which NumPy >= 1.24 no longer provides.
        use_cutmix = np.zeros(batch_size, dtype=bool)
        if self.mixup_enabled:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand(batch_size) < self.switch_prob
                lam_mix = np.where(
                    use_cutmix,
                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
            elif self.cutmix_alpha > 0.:
                use_cutmix = np.ones(batch_size, dtype=bool)
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
            else:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError(
                    "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true.")
            # Each element is mixed with probability mix_prob, otherwise lam stays 1.
            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
        return lam, use_cutmix

    def _params_per_batch(self):
        """Sample a single lambda and mixup-vs-cutmix choice for the whole batch.

        Returns:
            tuple: ``(lam, use_cutmix)`` — a Python float and a bool. ``lam`` is
            1.0 when mixing is disabled or skipped for this batch.

        Raises:
            AssertionError: if neither mixup nor cutmix is active.
        """
        lam = 1.
        use_cutmix = False
        if self.mixup_enabled and np.random.rand() < self.mix_prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.cutmix_alpha > 0.:
                use_cutmix = True
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            else:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError(
                    "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true.")
            lam = float(lam_mix)
        return lam, use_cutmix

    def _mix_elem(self, x):
        """Mix every element of ``x`` in place with its mirrored partner.

        Element ``i`` is mixed with the unmodified copy of element
        ``batch_size - i - 1`` using its own lambda.

        Returns:
            torch.Tensor: per-element lambdas, shape (batch_size, 1).
        """
        batch_size = len(x)
        lam_batch, use_cutmix = self._params_per_elem(batch_size)
        x_orig = x.clone()  # sources must stay unmixed while x is overwritten
        for i in range(batch_size):
            j = batch_size - i - 1
            lam = lam_batch[i]
            if lam != 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
                    lam_batch[i] = lam  # store the (possibly corrected) lambda
                else:
                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)

    def _mix_pair(self, x):
        """Mix ``x`` in place pair-wise, sharing one lambda per mirrored pair.

        Elements ``i`` and ``batch_size - i - 1`` exchange content using the
        same lambda (and the same box for cutmix).

        Returns:
            torch.Tensor: per-element lambdas, shape (batch_size, 1).
        """
        batch_size = len(x)
        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
        x_orig = x.clone()  # sources must stay unmixed while x is overwritten
        for i in range(batch_size // 2):
            j = batch_size - i - 1
            lam = lam_batch[i]
            if lam != 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
                    lam_batch[i] = lam  # store the (possibly corrected) lambda
                else:
                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
        # The second half of the batch reuses the first half's lambdas in mirrored order.
        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)

    def _mix_batch(self, x):
        """Mix the whole batch in place with its flipped self using one lambda.

        Returns:
            float: the lambda that was applied (1. when nothing was mixed).
        """
        lam, use_cutmix = self._params_per_batch()
        if lam == 1.:
            return 1.
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
        else:
            x_flipped = x.flip(0).mul_(1. - lam)
            x.mul_(lam).add_(x_flipped)
        return lam

    def __call__(self, x, target):
        """Apply mixup/cutmix to a batch.

        Args:
            x (torch.Tensor): input batch; modified in place. Its length must be even.
            target (torch.Tensor): class indices for the batch.

        Returns:
            tuple: ``(x, target)`` where target is the mixed, label-smoothed
            tensor of shape (batch, num_classes).
        """
        assert len(x) % 2 == 0, 'Batch size should be even when using this'
        if self.mode == 'elem':
            lam = self._mix_elem(x)
        elif self.mode == 'pair':
            lam = self._mix_pair(x)
        else:
            lam = self._mix_batch(x)
        # Build the soft targets on the device that holds the labels instead of
        # relying on mixup_target's hard-coded 'cuda' default.
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device=target.device)
        return x, target
class FastCollateMixup(Mixup):
    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch

    A Mixup impl that's performed while collating the batches: the input is a
    sequence of ``(image, label)`` samples whose images are numpy arrays, and
    the result is a uint8 image tensor plus the mixed soft targets.
    """

    def _mix_elem_collate(self, output, batch, half=False):
        """Mix per element while writing samples into ``output``.

        Args:
            output (torch.Tensor): zero-initialised uint8 destination, one slot per produced sample.
            batch: sequence of (image, label) samples.
            half (bool): when True only the first half of the batch is emitted,
                each mixed with its mirrored partner from the second half.

        Returns:
            torch.Tensor: per-sample lambdas of shape (len(batch), 1); in half
            mode the second half is padded with ones.
        """
        total = len(batch)
        num_elem = total // 2 if half else total
        assert len(output) == num_elem
        lam_batch, use_cutmix = self._params_per_elem(num_elem)
        for idx in range(num_elem):
            partner = total - idx - 1
            lam = lam_batch[idx]
            sample = batch[idx][0]
            if lam != 1.:
                other = batch[partner][0]
                if use_cutmix[idx]:
                    if not half:
                        # Outside half mode this sample is itself a source for its
                        # partner, so paste into a copy rather than the original.
                        sample = sample.copy()
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    sample[:, yl:yh, xl:xh] = other[:, yl:yh, xl:xh]
                    lam_batch[idx] = lam
                else:
                    sample = sample.astype(np.float32) * lam + other.astype(np.float32) * (1 - lam)
                    np.rint(sample, out=sample)
            output[idx] += torch.from_numpy(sample.astype(np.uint8))
        if half:
            lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
        return torch.tensor(lam_batch).unsqueeze(1)

    def _mix_pair_collate(self, output, batch):
        """Mix mirrored pairs with a shared lambda while writing into ``output``.

        For cutmix the two images of a pair swap the same region in place.

        Returns:
            torch.Tensor: per-sample lambdas of shape (len(batch), 1).
        """
        total = len(batch)
        num_pairs = total // 2
        lam_batch, use_cutmix = self._params_per_elem(num_pairs)
        for idx in range(num_pairs):
            partner = total - idx - 1
            lam = lam_batch[idx]
            mixed_i = batch[idx][0]
            mixed_j = batch[partner][0]
            assert 0 <= lam <= 1.0
            if lam < 1.:
                if use_cutmix[idx]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    region = (slice(None), slice(yl, yh), slice(xl, xh))
                    # Save i's patch first: the region is exchanged in place.
                    saved = mixed_i[region].copy()
                    mixed_i[region] = mixed_j[region]
                    mixed_j[region] = saved
                    lam_batch[idx] = lam
                else:
                    float_i = mixed_i.astype(np.float32)
                    float_j = mixed_j.astype(np.float32)
                    mixed_i, mixed_j = (
                        float_i * lam + float_j * (1 - lam),
                        float_j * lam + float_i * (1 - lam),
                    )
                    np.rint(mixed_j, out=mixed_j)
                    np.rint(mixed_i, out=mixed_i)
            output[idx] += torch.from_numpy(mixed_i.astype(np.uint8))
            output[partner] += torch.from_numpy(mixed_j.astype(np.uint8))
        # The second half of the batch reuses the first half's lambdas in mirrored order.
        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
        return torch.tensor(lam_batch).unsqueeze(1)

    def _mix_batch_collate(self, output, batch):
        """Mix every sample with its mirrored partner using one shared lambda/box.

        Returns:
            the lambda applied to the whole batch (1. when nothing was mixed).
        """
        total = len(batch)
        lam, use_cutmix = self._params_per_batch()
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
        for idx in range(total):
            partner = total - idx - 1
            sample = batch[idx][0]
            if lam != 1.:
                other = batch[partner][0]
                if use_cutmix:
                    # Every sample is also a source for its partner; keep originals intact.
                    sample = sample.copy()
                    sample[:, yl:yh, xl:xh] = other[:, yl:yh, xl:xh]
                else:
                    sample = sample.astype(np.float32) * lam + other.astype(np.float32) * (1 - lam)
                    np.rint(sample, out=sample)
            output[idx] += torch.from_numpy(sample.astype(np.uint8))
        return lam

    def __call__(self, batch, _=None):
        """Collate ``batch`` into a mixed uint8 tensor and soft targets.

        Args:
            batch: sequence of (image, label) samples; its length must be even.
            _: ignored.

        Returns:
            tuple: ``(output, target)`` — the uint8 image tensor and the mixed
            targets, both with one row per emitted sample (half the batch when
            the mode contains 'half').
        """
        num_samples = len(batch)
        assert num_samples % 2 == 0, 'Batch size should be even when using this'
        half = 'half' in self.mode
        out_size = num_samples // 2 if half else num_samples
        output = torch.zeros((out_size, *batch[0][0].shape), dtype=torch.uint8)
        if self.mode in ('elem', 'half'):
            lam = self._mix_elem_collate(output, batch, half=half)
        elif self.mode == 'pair':
            lam = self._mix_pair_collate(output, batch)
        else:
            lam = self._mix_batch_collate(output, batch)
        labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        mixed_target = mixup_target(labels, self.num_classes, lam, self.label_smoothing, device='cpu')
        # Targets are built for the full batch, then trimmed to the emitted samples.
        return output, mixed_target[:out_size]