"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import torch
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
from config import cfg
from utils.reid_metric import R1_mAP
from apex import amp
# Per-epoch iteration counter shared by the logging handlers in this module.
# The handlers declare `global ITER`, increment it after every iteration and
# set it back to 0 once an epoch's worth of iterations has been counted.
# (A `global` statement at module level is a no-op, so none is needed here.)
ITER = 0
def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (`torch.nn.Module`): the model to train; called as ``model(img)``
            and expected to return ``(score, feat)``
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use; called as
            ``loss_fn(score, feat, target)``
        device (str, optional): device type specification (default: None).
            Each batch is moved to it when an NPU or CUDA device is present.
            The model itself is not moved by this function.

    Returns:
        Engine: a trainer engine with supervised update function; its output
        for every iteration is the tuple ``(loss, accuracy)`` as Python floats
    """
    def _accelerator_available():
        # `torch.npu` only exists on PyTorch builds with the Ascend NPU
        # backend; look it up defensively so CUDA/CPU-only builds do not
        # raise AttributeError on every batch.
        npu = getattr(torch, "npu", None)
        if npu is not None and npu.device_count() >= 1:
            return True
        return torch.cuda.device_count() >= 1

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        # One availability check per batch covers both tensors.
        if _accelerator_available():
            img = img.to(device)
            target = target.to(device)
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        if "npu" in cfg.MODEL.DEVICE:
            # On NPU the backward pass goes through apex loss scaling.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # Top-1 accuracy of the classification scores on this batch.
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, center_loss_weight,
                                          device=None):
    """
    Factory function for creating a trainer that also updates a center criterion

    Args:
        model (`torch.nn.Module`): the model to train; called as ``model(img)``
            and expected to return ``(score, feat)``
        center_criterion (`torch.nn.Module`): module whose parameters are
            updated by ``optimizer_center``
        optimizer (`torch.optim.Optimizer`): the optimizer for the model
        optimizer_center (`torch.optim.Optimizer`): the optimizer for the
            parameters of ``center_criterion``
        loss_fn (torch.nn loss function): the loss function to use; called as
            ``loss_fn(score, feat, target)``
        center_loss_weight (float): the gradients of the ``center_criterion``
            parameters are multiplied by ``1 / center_loss_weight`` before
            ``optimizer_center`` steps
        device (str, optional): device type specification (default: None).
            Each batch is moved to it when an NPU or CUDA device is present.
            The model itself is not moved by this function.

    Returns:
        Engine: a trainer engine with supervised update function; its output
        for every iteration is the tuple ``(loss, accuracy)`` as Python floats
    """
    def _accelerator_available():
        # `torch.npu` only exists on PyTorch builds with the Ascend NPU
        # backend; look it up defensively so other builds do not raise
        # AttributeError on every batch.
        npu = getattr(torch, "npu", None)
        if npu is not None and npu.device_count() >= 1:
            return True
        return torch.cuda.device_count() >= 1

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        optimizer_center.zero_grad()
        img, target = batch
        # One availability check per batch covers both tensors.
        if _accelerator_available():
            img = img.to(device)
            target = target.to(device)
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        # Backward through apex loss scaling, registered for both optimizers.
        with amp.scale_loss(loss, [optimizer, optimizer_center]) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        # NOTE(review): presumably loss_fn already multiplies the center term
        # by center_loss_weight, and this rescaling undoes that factor for the
        # center parameters -- confirm against the loss definition.
        for param in center_criterion.parameters():
            param.grad.data *= (1. / center_loss_weight)
        optimizer_center.step()
        # Top-1 accuracy of the classification scores on this batch.
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to evaluate; called as
            ``model(data)`` in eval mode with gradients disabled
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            The image batch is moved to it when an NPU or CUDA device is
            present. The model itself is not moved by this function.

    Returns:
        Engine: an evaluator engine with supervised inference function; its
        output for every iteration is ``(feat, pids, camids)``
    """
    def _accelerator_available():
        # `torch.npu` only exists on PyTorch builds with the Ascend NPU
        # backend; look it up defensively so other builds do not raise
        # AttributeError on every batch.
        npu = getattr(torch, "npu", None)
        if npu is not None and npu.device_count() >= 1:
            return True
        return torch.cuda.device_count() >= 1

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            if _accelerator_available():
                data = data.to(device)
            feat = model(data)
            # pids and camids are passed through untouched for the metrics.
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    """
    Run the supervised training loop with periodic logging, checkpointing
    and validation.

    Args:
        cfg: configuration node; this function reads SOLVER.LOG_PERIOD,
            SOLVER.CHECKPOINT_PERIOD, SOLVER.EVAL_PERIOD, SOLVER.MAX_EPOCHS,
            OUTPUT_DIR, MODEL.DEVICE, MODEL.NAME and TEST.FEAT_NORM
        model (`torch.nn.Module`): the model to train and evaluate
        train_loader: iterable of training batches; its ``len()`` and
            ``batch_size`` are used for logging
        val_loader: iterable of validation batches
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        scheduler: learning-rate scheduler, stepped at the start of every
            epoch; ``get_lr()[0]`` is reported in the iteration log
        loss_fn: the loss function handed to the trainer
        num_query: forwarded to the ``R1_mAP`` metric
        start_epoch: value written to ``engine.state.epoch`` when the run starts
    """
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")

    def _emit(message):
        # Every progress message goes to stdout first, then to the logger.
        print(message)
        logger.info(message)

    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    eval_metrics = {'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}
    evaluator = create_supervised_evaluator(model, metrics=eval_metrics, device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False, save_as_state_dict=True)
    timer = Timer(average=True)

    # Handlers fire in registration order, so the order below is significant.
    to_save = {'model': model,
               'optimizer': optimizer}
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, to_save)

    # The timer only runs between ITERATION_STARTED and ITERATION_COMPLETED.
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # The trainer output is (loss, acc); keep a running average of each.
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            running = engine.state.metrics
            _emit("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                  .format(engine.state.epoch, ITER, len(train_loader),
                          running['avg_loss'], running['avg_acc'],
                          scheduler.get_lr()[0]))
        # Restart the per-epoch counter after the last batch of the epoch.
        if len(train_loader) == ITER:
            ITER = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        _emit('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
              .format(engine.state.epoch, timer.value() * timer.step_count,
                      train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Only evaluate on every eval_period-th epoch.
        if engine.state.epoch % eval_period != 0:
            return
        evaluator.run(val_loader)
        cmc, mAP = evaluator.state.metrics['r1_mAP']
        _emit("Validation Results - Epoch: {}".format(engine.state.epoch))
        _emit("mAP: {:.1%}".format(mAP))
        for r in [1, 5, 10]:
            _emit("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
def do_train_with_center(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        loss_fn,
        num_query,
        start_epoch,
        num_npus,
        rank
):
    """
    Run the training loop with a center criterion, with periodic logging,
    checkpointing and validation.

    Args:
        cfg: configuration node; this function reads SOLVER.LOG_PERIOD,
            SOLVER.CHECKPOINT_PERIOD, SOLVER.EVAL_PERIOD, SOLVER.MAX_EPOCHS,
            SOLVER.CENTER_LOSS_WEIGHT, OUTPUT_DIR, MODEL.DEVICE, MODEL.NAME
            and TEST.FEAT_NORM
        model (`torch.nn.Module`): the model to train and evaluate
        center_criterion (`torch.nn.Module`): center criterion, trained by
            ``optimizer_center`` and saved with the checkpoints
        train_loader: iterable of training batches; its ``len()`` and
            ``batch_size`` are used for logging and epoch termination
        val_loader: iterable of validation batches
        optimizer (`torch.optim.Optimizer`): the optimizer for the model
        optimizer_center (`torch.optim.Optimizer`): the optimizer for
            ``center_criterion``
        scheduler: learning-rate scheduler, stepped at the start of every
            epoch; ``get_lr()[0]`` is reported in the iteration log
        loss_fn: the loss function handed to the trainer
        num_query: forwarded to the ``R1_mAP`` metric
        start_epoch: value written to ``engine.state.epoch`` when the run starts
        num_npus (int): number of participating NPUs; when greater than 1 the
            loader length is summed across processes with ``all_reduce``
        rank (int): rank of this process; only rank 0 logs the epoch timing
    """
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False, save_as_state_dict=True)
    timer1 = Timer(average=True)
    timer2 = Timer(average=True)

    # Iterations per epoch, averaged over all processes. The loader length
    # does not change while training, so the device tensor, the collective
    # and the device-to-host read happen once here rather than per iteration.
    data_len = torch.tensor(len(train_loader)).float().npu()
    if num_npus > 1:
        torch.distributed.all_reduce(data_len)
    avg_data_len = int(data_len // num_npus)

    # Handlers fire in registration order, so the order below is significant.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                     'optimizer': optimizer,
                                                                     'center_param': center_criterion,
                                                                     'optimizer_center': optimizer_center})

    # timer1 only runs between ITERATION_STARTED and ITERATION_COMPLETED.
    timer1.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                  pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
    # timer2 is resumed, paused and reset manually by the handlers below.
    timer2.attach(trainer, start=Events.STARTED, step=Events.ITERATION_COMPLETED)

    # The trainer output is (loss, acc); keep a running average of each.
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_STARTED)
    def log_time(engine):
        # timer2 is paused at the end of every epoch and only resumed once
        # four iterations of the current epoch have completed.
        if ITER >= 4:
            timer2.resume()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, ITER, avg_data_len,
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))
        if ITER == avg_data_len:
            # End the epoch after the averaged iteration count and restart
            # the per-epoch counter.
            ITER = 0
            engine.should_terminate_single_epoch = True
        if ITER == 4:
            # Discard the timing of the first four iterations of the epoch;
            # timer2 is only reset while engine.state.epoch is 1.
            timer1.reset()
            if engine.state.epoch == 1:
                timer2.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        if rank == 0:
            logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                        .format(engine.state.epoch, timer1.value() * timer1.step_count,
                                num_npus * train_loader.batch_size / timer1.value()))
            # NOTE(review): labelled "total time" but the expression divides a
            # sample count by a timer value -- confirm the intended quantity.
            logger.info("total time: {:.3f}".format(num_npus * train_loader.batch_size * engine.state.epoch / timer2.value()))
            logger.info('-' * 10)
        timer2.pause()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)