from __future__ import print_function
import os
import sys
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from sklearn.metrics import average_precision_score
from datasets import get_num_classes, get_class_names
from models import get_model
from base_trainer import BaseTrainer
from functools import partial
from opts import get_arguments
from core.config import cfg, cfg_from_file, cfg_from_list
from datasets.utils import Colorize
from losses import get_criterion, mask_loss_ce
from utils.timer import Timer
from utils.stat_manager import StatManager
from utils.metrics import compute_jaccard
from apex import amp
from models.stage_net import balanced_mask_loss_ce
# Let cuDNN benchmark and pick convolution algorithms for the shapes it sees.
torch.backends.cudnn.benchmark = True
# Debug switch. When True: DecTrainer.step() uses a fixed pretrain threshold of
# 11 epochs instead of cfg.TRAIN.PRETRAIN, DecTrainer.train_epoch() stops once
# the batch index exceeds 100, and the main loop validates every 5 epochs
# (instead of every 2) and skips the validation pass over the train split.
DEBUG = False
# NOTE(review): not referenced anywhere in the visible code -- the process
# group is initialised with args.world_size instead. Confirm whether this
# constant is still needed.
world_size = 8
# Backend name handed to torch.distributed.init_process_group() by the main
# script.
dist_backend = 'hccl'
import random
def seed_everything(seed=1021202930):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    default generator, and exports the value as ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed. Defaults to the value that used to be hard-coded
            here, so ``seed_everything()`` behaves exactly as before. The
            value must be accepted by ``np.random.seed`` (an integer in
            ``[0, 2**32 - 1]``).
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. child
    # processes); the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
from torch.utils import data
from datasets.pascal_voc import VOCSegmentation
# Registry of supported datasets: maps the lower-cased dataset name
# (args.dataset.lower()) to the dataset class that get_dataloader()
# instantiates.
datasets = {
    'pascal_voc': VOCSegmentation
}
def get_dataloader(args, cfg, split, batch_size=None, test_mode=None):
    """Build the ``DataLoader`` for ``split`` plus a ``DistributedSampler``.

    Args:
        args: Parsed command-line arguments; ``args.dataset`` selects the
            dataset class from the module-level ``datasets`` registry and
            ``args.workers`` sets the number of loader workers.
        cfg: Configuration object; passed to the dataset constructor and used
            for ``cfg.TRAIN.BATCH_SIZE`` when ``batch_size`` is ``None``.
        split: One of ``'train'``, ``'train_voc'`` or ``'val'``.
        batch_size: Batch size override; defaults to ``cfg.TRAIN.BATCH_SIZE``.
        test_mode: Forwarded unchanged to the dataset constructor.

    Returns:
        A ``(loader, sampler)`` tuple. The sampler is created for every
        split, but it is attached to the loader (and incomplete batches are
        dropped) only for ``'train'``; the loader itself never shuffles.

    Raises:
        AssertionError: If ``split`` is not one of the known splits.
        KeyError: If ``args.dataset`` is not in the ``datasets`` registry.
    """
    assert split in ('train', 'train_voc', 'val'), "Unknown split '{}'".format(split)

    is_train = split == 'train'
    dataset = datasets[args.dataset.lower()](cfg, split, test_mode)
    num_workers = args.workers

    # Always built and returned; callers of the non-train splits discard it.
    train_sampler = data.distributed.DistributedSampler(dataset)

    if batch_size is None:
        batch_size = cfg.TRAIN.BATCH_SIZE

    loader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=is_train,
        shuffle=False,
        sampler=train_sampler if is_train else None,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader, train_sampler
def rescale_as(x, y, mode="bilinear", align_corners=True):
    """Resize ``x`` to the spatial size (last two dimensions) of ``y``.

    Args:
        x: Tensor to resize; passed to ``F.interpolate`` with a 2-element
            target size.
        y: Reference tensor. It must be 4-dimensional, because its size
            after the first two dimensions is unpacked into ``h, w``.
        mode: Interpolation mode forwarded to ``F.interpolate``.
        align_corners: Forwarded to ``F.interpolate`` for the interpolating
            modes (``linear``, ``bilinear``, ``bicubic``, ``trilinear``).
            It is ignored for every other mode, because ``F.interpolate``
            rejects an explicit ``align_corners`` there (previously
            ``mode="nearest"`` raised ``ValueError`` with the default
            arguments).

    Returns:
        The resized tensor.
    """
    h, w = y.size()[2:]
    if mode not in ("linear", "bilinear", "bicubic", "trilinear"):
        align_corners = None
    return F.interpolate(x, size=[h, w], mode=mode, align_corners=align_corners)
class DecTrainer(BaseTrainer):
    """Distributed trainer for the model returned by ``get_model``.

    The constructor builds the data loaders, the model, the classification
    criterion and the optimiser, initialises apex AMP (``O1``) and wraps the
    model in ``DistributedDataParallel``. The class implements the per-batch
    ``step``, the per-epoch training loop and a validation pass that reports
    per-class average precision and their mean.

    ``get_optim``, ``_define_checkpoint``, ``_load_checkpoint``,
    ``checkpoint_best``, ``_visualise_grid`` and ``writer`` are used but not
    defined here; presumably they are provided by ``BaseTrainer``.
    """

    # Epochs at which the learning rate of the first parameter group is
    # multiplied by LR_DECAY_FACTOR (see adjust_learning_rate).
    LR_DECAY_EPOCHS = (10, 15, 20)
    LR_DECAY_FACTOR = 0.1

    def __init__(self, args, **kwargs):
        """Set up loaders, model, criterion, optimiser, AMP and DDP.

        Args:
            args: Parsed command-line arguments. Read here: ``resume``,
                ``fixed_batch_path`` and ``local_rank`` (plus whatever
                ``get_dataloader``, ``get_num_classes`` and
                ``get_class_names`` read).
            **kwargs: Forwarded to ``BaseTrainer.__init__``.
        """
        super(DecTrainer, self).__init__(args, **kwargs)

        # Data loaders; only the training split keeps its distributed sampler
        # (the main loop calls set_epoch() on it).
        self.trainloader, self.train_sampler = get_dataloader(args, cfg, 'train')
        self.trainloader_val, _ = get_dataloader(args, cfg, 'train_voc')
        self.valloader, _ = get_dataloader(args, cfg, 'val')
        self.denorm = self.trainloader.dataset.denorm

        # The last class name is dropped; the remaining count must match the
        # number of classes.
        self.nclass = get_num_classes(args)
        self.classNames = get_class_names(args)[:-1]
        assert self.nclass == len(self.classNames)

        self.classIndex = {cname: i for i, cname in enumerate(self.classNames)}

        # Model and classification criterion live on the NPU.
        self.enc = get_model(cfg.NET, num_classes=self.nclass).npu()
        self.criterion_cls = get_criterion(cfg.NET.LOSS).npu()
        print(self.enc)

        # Optimiser and mixed precision.
        enc_params = self.enc.parameter_groups(cfg.NET.LR, cfg.NET.WEIGHT_DECAY)
        self.optim_enc = self.get_optim(enc_params, cfg.NET)
        self.enc, self.optim_enc = amp.initialize(self.enc, self.optim_enc,
                                                  opt_level="O1",
                                                  loss_scale=128,
                                                  combine_grad=True)

        # Checkpoint registration / restore happens on the unwrapped model,
        # before the DDP wrapper is applied below.
        self._define_checkpoint('enc', self.enc, self.optim_enc)
        self._load_checkpoint(args.resume)

        # A fixed batch is re-used at the end of every training epoch; it is
        # loaded from disk if present, otherwise train_epoch() creates it.
        self.fixed_batch = None
        self.fixed_batch_path = args.fixed_batch_path
        if os.path.isfile(self.fixed_batch_path):
            print("Loading fixed batch from {}".format(self.fixed_batch_path))
            # NOTE(review): torch.load unpickles the file; the path must point
            # to a trusted file.
            self.fixed_batch = torch.load(self.fixed_batch_path)

        self.enc = torch.nn.parallel.DistributedDataParallel(
            self.enc,
            device_ids=[args.local_rank],
            find_unused_parameters=True,
            broadcast_buffers=False)

    def step(self, epoch, image, gt_labels, train=False, visualise=False):
        """Run one forward pass and, optionally, one optimisation step.

        Args:
            epoch: Current epoch; while it is below the pretrain threshold
                (11 when ``DEBUG``, else ``cfg.TRAIN.PRETRAIN``) the mask
                loss is logged but not added to the optimised loss.
            image: Batch of normalised images.
            gt_labels: Batch of image-level labels.
            train: If True, back-propagate the AMP-scaled loss and step the
                optimiser.
            visualise: If True, call ``_visualise`` on the outputs.

        Returns:
            Tuple ``(losses, cls_out, masks, mask_logits)``: ``losses`` is a
            dict of Python floats, ``masks`` is the dict returned by the
            model (with ``"pseudo"`` added when it contains ``"dec"``), and
            all tensors are detached from the graph.
        """
        PRETRAIN = epoch < (11 if DEBUG else cfg.TRAIN.PRETRAIN)

        # The model receives both the normalised and the de-normalised image.
        image_raw = self.denorm(image.clone())

        cls_out, cls_fg, masks, mask_logits, pseudo_gt = \
            self.enc(image.npu(), image_raw.npu(), gt_labels.npu())

        loss_cls = self.criterion_cls(cls_out.npu(), gt_labels.npu()).mean()
        loss_mask = balanced_mask_loss_ce(mask_logits.npu(), pseudo_gt.npu(), gt_labels.npu())

        # Scalar copies of all losses, for logging.
        losses = {"loss_cls": loss_cls.item(),
                  "loss_fg": cls_fg.mean().item()}

        loss = loss_cls.clone()
        if "dec" in masks:
            loss_mask = loss_mask.mean()

            if not PRETRAIN:
                loss += cfg.NET.MASK_LOSS_BCE * loss_mask

            assert not "pseudo" in masks
            masks["pseudo"] = pseudo_gt
            losses["loss_mask"] = loss_mask.item()

        losses["loss"] = loss.item()

        if train:
            self.optim_enc.zero_grad()
            with amp.scale_loss(loss, self.optim_enc) as scaled_loss:
                scaled_loss.backward()
            self.optim_enc.step()

        # Cut everything that is returned from the autograd graph. Only the
        # values are replaced, so mutating the dict while iterating is safe.
        for mask_key in masks:
            masks[mask_key] = masks[mask_key].detach()

        mask_logits = mask_logits.detach()

        if visualise:
            self._visualise(epoch, image, masks, mask_logits, cls_out, gt_labels)

        return losses, cls_out.detach(), masks, mask_logits

    def adjust_learning_rate(self, epoch, t=10):
        """Decay the learning rate of the first parameter group.

        At each epoch listed in ``LR_DECAY_EPOCHS`` the learning rate of
        ``param_groups[0]`` is multiplied by ``LR_DECAY_FACTOR``.

        NOTE(review): only the first parameter group is touched, and a run
        resumed past a decay epoch does not re-apply that decay here --
        confirm both are intended.

        Args:
            epoch: Current epoch.
            t: Unused; kept for interface compatibility.
        """
        if epoch in self.LR_DECAY_EPOCHS:
            group = self.optim_enc.param_groups[0]
            group['lr'] = group['lr'] * self.LR_DECAY_FACTOR
            print("LR_reduce: ", group['lr'])

    def train_epoch(self, epoch):
        """Train for one epoch and log the averaged losses.

        Args:
            epoch: Current epoch; forwarded to ``adjust_learning_rate`` and
                ``step`` and used as the x-axis value for the writer.
        """
        self.enc.train()

        # NOTE(review): "loss_bce" is registered here, but step() reports the
        # mask loss under the key "loss_mask" -- confirm the intended key.
        stat = StatManager()
        stat.add_val("loss")
        stat.add_val("loss_cls")
        stat.add_val("loss_fg")
        stat.add_val("loss_bce")

        self.adjust_learning_rate(epoch)
        print("LR: ", self.optim_enc.state_dict()['param_groups'][0]['lr'])

        timer = Timer("New Epoch: ")
        train_step = partial(self.step, train=True, visualise=False)

        for i, (image, gt_labels, _) in enumerate(self.trainloader):
            losses, _, _, _ = train_step(epoch, image, gt_labels)

            # The first batch ever seen becomes the fixed batch and is
            # persisted so that later runs re-use it.
            if self.fixed_batch is None:
                self.fixed_batch = {}
                self.fixed_batch["image"] = image.clone()
                self.fixed_batch["labels"] = gt_labels.clone()
                torch.save(self.fixed_batch, self.fixed_batch_path)

            for loss_key, loss_val in losses.items():
                stat.update_stats(loss_key, loss_val)

            # Intermediate logging every 10 batches.
            if i % 10 == 0:
                msg = "Loss [{:04d}]: ".format(i)
                for loss_key, loss_val in losses.items():
                    msg += "{}: {:.4f} | ".format(loss_key, loss_val)

                msg += " | Im/Sec: {:.1f}".format(i * cfg.TRAIN.BATCH_SIZE / timer.get_stage_elapsed())
                print(msg)
                sys.stdout.flush()

            del image, gt_labels

            if DEBUG and i > 100:
                break

        # Epoch summary of every tracked loss.
        for stat_key in stat.vals.keys():
            print("{}: {:4.3f}".format(stat_key, stat.summarize_key(stat_key)))

        for ii, group in enumerate(self.optim_enc.param_groups):
            print("Learning rate [{}]: {:4.3e}".format(ii, group['lr']))
            self.writer.add_scalar('lr/enc_group_%02d' % ii, group['lr'], epoch)

        # Forward pass on the fixed batch in eval mode. The batch is still
        # None when the loader yielded nothing and no batch was loaded from
        # disk; skip the pass in that case instead of failing on None.
        self.enc.eval()
        if self.fixed_batch is not None:
            with torch.no_grad():
                self.step(epoch, self.fixed_batch["image"],
                          self.fixed_batch["labels"],
                          train=False, visualise=False)

    def _mask_rgb(self, masks, image_norm):
        """Blend a colour rendering of ``masks`` with ``image_norm``.

        The arg-max over dimension 1 selects the index that is coloured; its
        score, capped at 1, is used as the confidence. The result is
        ``0.3 * image_norm + 0.7 * coloured_mask``.
        """
        masks_conf, masks_idx = torch.max(masks, 1)
        # x - relu(x - 1) == min(x, 1): cap the confidence at 1.
        masks_conf = masks_conf - F.relu(masks_conf - 1, 0)

        masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu())
        return 0.3 * image_norm + 0.7 * masks_idx_rgb

    def _init_norm(self):
        """Install the model's ``normalize`` callable on all three datasets."""
        # __init__ wraps self.enc in DistributedDataParallel, which does not
        # expose custom attributes of the wrapped model; reach through
        # ``.module`` in that case.
        enc = self.enc
        if isinstance(enc, torch.nn.parallel.DistributedDataParallel):
            enc = enc.module
        normalize = enc.normalize

        self.trainloader.dataset.set_norm(normalize)
        self.valloader.dataset.set_norm(normalize)
        self.trainloader_val.dataset.set_norm(normalize)

    def _apply_cmap(self, mask_idx, mask_conf):
        """Colourise index masks and weight them by ``mask_conf / 255``.

        Args:
            mask_idx: Batch of index masks; split along dimension 0.
            mask_conf: Batch of confidences, split in the same way.

        Returns:
            The coloured, confidence-weighted masks concatenated along a new
            leading dimension.
        """
        # NOTE(review): the palette is fetched but never used; the call is
        # kept because get_palette() is not visible from here.
        palette = self.trainloader.dataset.get_palette()

        col = Colorize()
        mask_conf = mask_conf.float() / 255.0

        masks = []
        for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)):
            m = col(mask).float() * conf
            masks.append(m[None, ...])

        return torch.cat(masks, 0)

    def validation(self, epoch, writer, loader, checkpoint=False):
        """Evaluate on ``loader`` and log per-class AP, mAP and mean losses.

        Args:
            epoch: Current epoch (x-axis value for every scalar).
            writer: Summary writer receiving the scalars.
            loader: Data loader yielding ``(image, gt_labels, _)`` triples.
            checkpoint: If True and ``epoch >= cfg.TRAIN.PRETRAIN``, compute
                ``1 - mean("loss")`` as a proxy score and pass it to
                ``checkpoint_best``.
        """
        stat = StatManager()

        def eval_batch(image, gt_labels):
            """Run one evaluation step and accumulate its losses in ``stat``."""
            losses, cls, masks, mask_logits = \
                self.step(epoch, image, gt_labels, train=False, visualise=False)

            for loss_key, loss_val in losses.items():
                stat.update_stats(loss_key, loss_val)

            return cls.cpu(), masks, mask_logits.cpu()

        self.enc.eval()

        targets_all = []  # ground-truth labels, one array per batch
        preds_all = []    # sigmoid class scores, one array per batch

        for image, gt_labels, _ in loader:
            with torch.no_grad():
                cls_raw, _, _ = eval_batch(image, gt_labels)

            preds_all.append(torch.sigmoid(cls_raw).numpy())
            targets_all.append(gt_labels.cpu().numpy())

        targets_stacked = np.vstack(targets_all)
        preds_stacked = np.vstack(preds_all)
        aps = average_precision_score(targets_stacked, preds_stacked, average=None)

        # One class has no AP entry (presumably background, given the
        # 'all_wo_BG' tag below); its name is skipped.
        offset = self.nclass - aps.size
        assert offset == 1, 'Class number mismatch'

        classNames = self.classNames[offset:]
        for ni, className in enumerate(classNames):
            writer.add_scalar('%02d_%s/AP' % (ni + offset, className), aps[ni], epoch)
            print("AP_{}: {:4.3f}".format(className, aps[ni]))

        meanAP = np.mean(aps)
        writer.add_scalar('all_wo_BG/mAP', meanAP, epoch)
        print('mAP: {:4.3f}'.format(meanAP))

        # Mean of every loss accumulated during this pass.
        for stat_key in stat.vals.keys():
            writer.add_scalar('all/{}'.format(stat_key), stat.summarize_key(stat_key), epoch)

        if checkpoint and epoch >= cfg.TRAIN.PRETRAIN:
            # Lower mean loss means a higher score; used to keep the best
            # checkpoint so far.
            proxy_score = 1 - stat.summarize_key("loss")
            writer.add_scalar('all/checkpoint_score', proxy_score, epoch)
            self.checkpoint_best(proxy_score, epoch)

    def _visualise(self, epoch, image, masks, mask_logits, cls_out, gt_labels):
        """Assemble a side-by-side visualisation and pass it to ``_visualise_grid``.

        The de-normalised image is followed by one panel for each of the
        ``"cam"``, ``"dec"`` and ``"pseudo"`` entries present in ``masks``;
        panels are concatenated along the last dimension.
        """
        image_norm = self.denorm(image.clone()).cpu()
        visual = [image_norm]

        if "cam" in masks:
            visual.append(self._mask_rgb(masks["cam"], image_norm))

        if "dec" in masks:
            visual.append(self._mask_rgb(masks["dec"], image_norm))

        if "pseudo" in masks:
            pseudo_gt_rgb = self._mask_rgb(masks["pseudo"], image_norm)

            # Where the pseudo masks sum to 0 over dimension 1, show the
            # plain image instead of the blended mask.
            ambiguous = 1 - masks["pseudo"].sum(1, keepdim=True).cpu()
            pseudo_gt_rgb = ambiguous * image_norm + (1 - ambiguous) * pseudo_gt_rgb
            visual.append(pseudo_gt_rgb)

        visual_logits = torch.cat(visual, -1)
        self._visualise_grid(visual_logits, gt_labels, epoch, scores=cls_out)
if __name__ == "__main__":
    # Entry point: one process per device, selected via args.local_rank.
    args = get_arguments(sys.argv[1:])

    rank = args.local_rank
    print("rank: ", rank)
    torch.npu.set_device(rank)

    seed_everything()
    torch.distributed.init_process_group(
        backend=dist_backend,
        init_method="env://",
        world_size=args.world_size,
        rank=rank,
    )

    # Configuration: file first, then optional command-line overrides.
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print("Config: \n", cfg)

    trainer = DecTrainer(args)
    # NOTE(review): this replaces the torch seed set by seed_everything()
    # once the trainer has been built -- confirm that is intended.
    torch.manual_seed(0)

    timer = Timer()

    def time_call(func, msg, *call_args, **call_kwargs):
        """Run ``func`` and print ``msg`` followed by the elapsed minutes."""
        timer.reset_stage()
        func(*call_args, **call_kwargs)
        elapsed_min = timer.get_stage_elapsed() / 60.
        print(msg + (" {:3.2}m".format(elapsed_min)))

    # Validation interval in epochs.
    log_int = 5 if DEBUG else 2

    for epoch in range(trainer.start_epoch, cfg.TRAIN.NUM_EPOCHS):
        print("Epoch >>> ", epoch)
        # Re-seed the distributed sampler so each epoch sees a new ordering.
        trainer.train_sampler.set_epoch(epoch)

        print("log_int: ", log_int)

        if epoch % log_int == 0:
            # Validation can be switched off with RUN_VALIDATION=0.
            if int(os.getenv("RUN_VALIDATION", 1)):
                with torch.no_grad():
                    if not DEBUG:
                        time_call(trainer.validation, "Validation / Train: ",
                                  epoch, trainer.writer, trainer.trainloader_val)
                    time_call(trainer.validation, "Validation / Val: ",
                              epoch, trainer.writer_val, trainer.valloader,
                              checkpoint=True)

        time_call(trainer.train_epoch, "Train epoch: ", epoch)