from __future__ import print_function
import sys
sys.path.insert(0, '.')
import torch
if torch.__version__ >= '1.8':
import torch_npu
from torch.autograd import Variable
import torch.npu
import torch.nn as nn
from torch.nn.parallel import DataParallel
import apex
from apex import amp
import time
import numpy as np
import argparse
from aligned_reid.dataset import create_dataset
from aligned_reid.model.Model import Model
from aligned_reid.model.TripletLoss import TripletLoss
from aligned_reid.model.loss import global_loss
from aligned_reid.model.loss import local_loss
from aligned_reid.utils.utilsn1 import str2bool
from aligned_reid.utils.utilsn1 import may_set_mode
from aligned_reid.utils.utilsn1 import load_ckpt
from aligned_reid.utils.utilsn1 import save_ckpt
from aligned_reid.utils.utilsn1 import set_devices
from aligned_reid.utils.utilsn1 import AverageMeter
from aligned_reid.utils.utilsn1 import to_scalar
from aligned_reid.utils.utilsn1 import set_seed
from aligned_reid.utils.utilsn1 import adjust_lr_exp
from aligned_reid.utils.utilsn1 import adjust_lr_staircase
def _build_ddp_parser():
    """Build the parser for the options that Config and main() read directly.

    ``--data_pth`` and ``--pkl`` are forwarded to ``create_dataset`` in main();
    the remaining options populate the matching Config attributes.
    """
    ddp_parser = argparse.ArgumentParser(description='PyTorch AlignedReID Training')
    # (flag, type, default) -- registered in this order so --help is unchanged.
    option_table = (
        ('--data_pth', str, ''),
        ('--pkl', str, ''),
        ('--ids_per_batch', int, 32),
        ('--base_lr', float, 2e-4),
        ('--exp_decay_at_epoch', int, 1000),
        ('--total_epochs', int, 300),
        ('--model_weight_file', str, ''),
        ('--only_test', str2bool, False),
        ('--log_to_file', str2bool, False),
    )
    for flag, arg_type, default_value in option_table:
        ddp_parser.add_argument(flag, type=arg_type, default=default_value)
    return ddp_parser


parser_ddp = _build_ddp_parser()
class Config(object):
    """Training/testing configuration assembled from the command line.

    Options come from two parsers that read the same ``sys.argv``: a local
    parser (devices, dataset, losses, LR schedule, ...) and the module-level
    ``parser_ddp`` (data paths, batch size, base LR, epoch counts, run mode).
    Both are parsed with ``parse_known_args`` so that each ignores the flags
    belonging to the other.
    """

    def __init__(self):
        parser = argparse.ArgumentParser()
        # NOTE(review): the tuple-valued options below use ``type=eval``, which
        # executes arbitrary Python taken from the command line. Fine for a
        # trusted local launcher; consider ast.literal_eval if argv can ever
        # come from an untrusted source.
        parser.add_argument('-d', '--sys_device_ids', type=eval, default=(0,))
        parser.add_argument('-r', '--run', type=int, default=1)
        parser.add_argument('--set_seed', type=str2bool, default=True)
        parser.add_argument('--dataset', type=str, default='market1501',
                            choices=['market1501', 'cuhk03', 'duke', 'combined'])
        parser.add_argument('--trainset_part', type=str, default='trainval',
                            choices=['trainval', 'train'])
        parser.add_argument('--resize_h_w', type=eval, default=(256, 128))
        parser.add_argument('--crop_prob', type=float, default=0)
        parser.add_argument('--crop_ratio', type=float, default=1)
        parser.add_argument('--ims_per_id', type=int, default=4)
        parser.add_argument('--normalize_feature', type=str2bool, default=False)
        parser.add_argument('--local_dist_own_hard_sample',
                            type=str2bool, default=True)
        parser.add_argument('-gm', '--global_margin', type=float, default=0.3)
        parser.add_argument('-lm', '--local_margin', type=float, default=0.3)
        parser.add_argument('-glw', '--g_loss_weight', type=float, default=1)
        parser.add_argument('-llw', '--l_loss_weight', type=float, default=0.)
        parser.add_argument('-idlw', '--id_loss_weight', type=float, default=0.)
        parser.add_argument('--resume', type=str2bool, default=False)
        parser.add_argument('--exp_dir', type=str, default='')
        parser.add_argument('--lr_decay_type', type=str, default='exp',
                            choices=['exp', 'staircase'])
        parser.add_argument('--staircase_decay_at_epochs',
                            type=eval, default=(101, 201,))
        parser.add_argument('--staircase_decay_multiply_factor',
                            type=float, default=0.1)

        # Both parsers see the full argv. parse_args() would exit with
        # "unrecognized arguments" on the other parser's flags, so use
        # parse_known_args() for each and ignore the leftovers.
        args = parser.parse_known_args()[0]
        args_ddp = parser_ddp.parse_known_args()[0]

        # ---- Devices ----
        self.sys_device_ids = args.sys_device_ids
        self.sys_device_n = len(self.sys_device_ids)

        # ---- Seeding ----
        # --set_seed true -> fixed seed 1; false -> unseeded (None).
        self.seed = 1 if args.set_seed else None
        self.run = args.run
        # One prefetch thread when seeded (presumably to keep batch order
        # reproducible), two otherwise.
        self.prefetch_threads = 1 if self.seed is not None else 2

        # ---- Dataset ----
        self.dataset = args.dataset
        self.trainset_part = args.trainset_part
        self.crop_prob = args.crop_prob
        self.crop_ratio = args.crop_ratio
        self.resize_h_w = args.resize_h_w
        self.scale_im = True
        self.im_mean = [0.486, 0.459, 0.408]
        self.im_std = [0.229, 0.224, 0.225]

        self.ids_per_batch = args_ddp.ids_per_batch
        self.ims_per_id = args.ims_per_id
        self.train_final_batch = True
        self.train_mirror_type = 'random'  # one of 'random', 'always', None
        self.train_shuffle = True

        self.test_batch_size = 32
        self.test_final_batch = True
        self.test_mirror_type = None  # one of 'random', 'always', None
        self.test_shuffle = False

        # Keyword arguments shared by the train and test sets.
        dataset_kwargs = dict(
            name=self.dataset,
            resize_h_w=self.resize_h_w,
            scale=self.scale_im,
            im_mean=self.im_mean,
            im_std=self.im_std,
            batch_dims='NCHW',
            num_prefetch_threads=self.prefetch_threads)

        # Train and test each get their own generator object (identically
        # seeded when a seed is set); otherwise the global np.random module.
        prng = np.random
        if self.seed is not None:
            prng = np.random.RandomState(self.seed)
        self.train_set_kwargs = dict(
            part=self.trainset_part,
            ids_per_batch=self.ids_per_batch,
            ims_per_id=self.ims_per_id,
            final_batch=self.train_final_batch,
            shuffle=self.train_shuffle,
            crop_prob=self.crop_prob,
            crop_ratio=self.crop_ratio,
            mirror_type=self.train_mirror_type,
            prng=prng)
        self.train_set_kwargs.update(dataset_kwargs)

        prng = np.random
        if self.seed is not None:
            prng = np.random.RandomState(self.seed)
        self.test_set_kwargs = dict(
            part='test',
            batch_size=self.test_batch_size,
            final_batch=self.test_final_batch,
            shuffle=self.test_shuffle,
            mirror_type=self.test_mirror_type,
            prng=prng)
        self.test_set_kwargs.update(dataset_kwargs)

        # ---- Losses ----
        self.local_dist_own_hard_sample = args.local_dist_own_hard_sample
        self.normalize_feature = args.normalize_feature
        self.local_conv_out_channels = 128
        self.global_margin = args.global_margin
        self.local_margin = args.local_margin
        self.id_loss_weight = args.id_loss_weight
        self.g_loss_weight = args.g_loss_weight
        self.l_loss_weight = args.l_loss_weight

        # ---- Optimization / LR schedule ----
        self.weight_decay = 0.0005
        self.base_lr = args_ddp.base_lr
        self.lr_decay_type = args.lr_decay_type
        self.exp_decay_at_epoch = args_ddp.exp_decay_at_epoch
        self.staircase_decay_at_epochs = args.staircase_decay_at_epochs
        self.staircase_decay_multiply_factor = args.staircase_decay_multiply_factor
        self.total_epochs = args_ddp.total_epochs

        # ---- Logging / run mode / checkpoints ----
        self.log_steps = 1
        self.only_test = args_ddp.only_test
        self.resume = args.resume
        self.log_to_file = args_ddp.log_to_file
        self.ckpt_file = './ckpt.pth'
        self.model_weight_file = args_ddp.model_weight_file
class ExtractFeature(object):
    """Callable used by the val/test sets to extract features from a batch.

    Args:
        model: module whose forward pass maps an image batch to a feature
            tensor (main() passes the DataParallel-wrapped model).
        TVT: callable that transfers a tensor to the target device.

    The model is run in eval mode with autograd disabled, and its previous
    train/eval mode is restored afterwards, even if the forward pass raises.
    """

    def __init__(self, model, TVT):
        self.model = model
        self.TVT = TVT

    def __call__(self, ims):
        """Return the model output for ``ims`` as a numpy array.

        Args:
            ims: numpy array; converted to a float tensor and moved to the
                device with ``TVT`` before the forward pass.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            # Inference only: no autograd graph is needed, which saves memory.
            with torch.no_grad():
                inputs = self.TVT(torch.from_numpy(ims).float())
                feat = self.model(inputs)
            return feat.detach().cpu().numpy()
        finally:
            self.model.train(was_training)
def _format_global_log(prec, margin_prec, dist_ap, dist_an, loss, fps=None):
    """Format the global-triplet-loss statistics part of a log line.

    Args:
        prec: fraction of samples with dist_an > dist_ap.
        margin_prec: fraction of samples with dist_an > dist_ap + margin.
        dist_ap: mean anchor-positive distance.
        dist_an: mean anchor-negative distance.
        loss: global triplet loss value.
        fps: optional throughput value; appended only when not None.
    """
    log = (', gp {:.2%}, gm {:.2%}, '
           'gd_ap {:.4f}, gd_an {:.4f}, '
           'gL {:.4f}'.format(prec, margin_prec, dist_ap, dist_an, loss))
    if fps is not None:
        log += ', fps {:.2f}'.format(fps)
    return log


def main():
    """Train and/or test the AlignedReID model on NPU.

    Builds the configuration from the command line, creates the datasets,
    model, losses and optimizer (apex AMP O1), optionally resumes from
    ``cfg.ckpt_file``, then either runs one test pass (``--only_test``) or
    trains for ``cfg.total_epochs`` epochs, printing per-step and per-epoch
    statistics, and runs one test pass at the end.
    """
    cfg = Config()
    # parse_known_args: argv also carries Config's own flags (e.g. --dataset)
    # that parser_ddp does not define; parse_args() would exit on them.
    args = parser_ddp.parse_known_args()[0]

    TVT, TMO = set_devices(cfg.sys_device_ids)
    if cfg.seed is not None:
        set_seed(cfg.seed)

    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    # ----------------------------------------------------------- Datasets
    train_set = create_dataset(pth=args.data_pth, pkl=args.pkl, **cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(pth=args.data_pth, pkl=args.pkl, **cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(pth=args.data_pth, pkl=args.pkl, **cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    # ---------------------------------------------- Model, losses, optimizer
    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
    model = model.npu()

    id_criterion = nn.CrossEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizer = apex.optimizers.NpuFusedAdam(model.parameters(),
                                             lr=cfg.base_lr,
                                             weight_decay=cfg.weight_decay)
    torch.backends.cudnn.enabled = True
    amp.register_half_function(torch, 'addmm')
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1',
                                      loss_scale=128, combine_grad=True)
    modules_optims = [model, optimizer]
    model_w = DataParallel(model)
    model_w = model_w.npu()

    start_ep = 0
    if cfg.resume:
        start_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    TMO(modules_optims)

    def test(load_model_weight=False):
        """Evaluate on every test set, optionally loading cfg.model_weight_file first."""
        if load_model_weight:
            load_ckpt(modules_optims, cfg.model_weight_file)
        use_local_distance = False
        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(
                normalize_feat=cfg.normalize_feature,
                use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    # ------------------------------------------------------------ Training
    for ep in range(start_ep, cfg.total_epochs):
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(
                optimizer,
                cfg.base_lr,
                ep + 1,
                cfg.total_epochs,
                cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(
                optimizer,
                cfg.base_lr,
                ep + 1,
                cfg.staircase_decay_at_epochs,
                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        # Throughput bookkeeping: explicit sum/count so the epoch average is
        # well-defined however many steps were measured (possibly none).
        fps_sum = 0.0
        fps_count = 0
        epoch_done = False
        while not epoch_done:
            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()
            ims_var = TVT(torch.from_numpy(ims).float())
            labels_t = TVT(torch.from_numpy(labels).long())

            global_feat = model_w(ims_var)
            # The model returns a single tensor, reused as local feature and
            # ID logits. With the default l_loss_weight=0 / id_loss_weight=0
            # those branches are skipped.
            # NOTE(review): enabling them feeds the global feature into
            # local_loss / CrossEntropyLoss -- confirm that is intended.
            local_feat = logits = global_feat

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss, global_feat, labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss, local_feat, None, None, labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss, local_feat, p_inds, n_inds, labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_t)

            loss = g_loss * cfg.g_loss_weight \
                + l_loss * cfg.l_loss_weight \
                + id_loss * cfg.id_loss_weight

            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            # Step statistics.
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))
            loss_meter.update(to_scalar(loss))

            step_time = time.time() - step_st
            # The first two steps of each epoch are excluded from throughput
            # (presumably warm-up).
            # NOTE(review): divides by ids_per_batch only; a batch presumably
            # holds ids_per_batch * ims_per_id images -- confirm the unit.
            fps_step = None
            if step > 2:
                fps_step = cfg.ids_per_batch / step_time
                fps_sum += fps_step
                fps_count += 1

            # Step log.
            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step, ep + 1, step_time, )
                if cfg.g_loss_weight > 0:
                    g_log = _format_global_log(
                        g_prec_meter.val, g_m_meter.val,
                        g_dist_ap_meter.val, g_dist_an_meter.val,
                        g_loss_meter.val, fps=fps_step)
                else:
                    g_log = ''
                l_log = ''
                id_log = ''
                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)
                print(time_log + g_log + l_log + id_log + total_loss_log)

        # Epoch log: report the average throughput over the measured steps
        # (0.00 when the epoch was too short to measure any).
        fps_avg = fps_sum / fps_count if fps_count else 0.0
        time_log = 'Ep {}, {:.2f}s, FPS {:.2f}'.format(ep + 1, time.time() - ep_st, fps_avg)
        if cfg.g_loss_weight > 0:
            g_log = _format_global_log(
                g_prec_meter.avg, g_m_meter.avg,
                g_dist_ap_meter.avg, g_dist_an_meter.avg,
                g_loss_meter.avg)
        else:
            g_log = ''
        l_log = ''
        id_log = ''
        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)
        print(time_log + g_log + l_log + id_log + total_loss_log)

        # Checkpoint saving is gated on --log_to_file.
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    # --------------------------------------------------------------- Test
    test(load_model_weight=False)


if __name__ == '__main__':
    main()