import os
import time
import sys
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torch
if torch.__version__ >= "1.8":
import torch_npu
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torchvision import transforms as trans
from tensorboardX import SummaryWriter
import apex
from apex import amp
from data.data_pipe import get_train_loader, get_val_pair
from model import Backbone, Arcface, MobileFaceNet, l2_norm
from verifacation import evaluate
from utils import get_time, gen_plot, hflip_batch, separate_bn_paras
# Select matplotlib's non-interactive Agg backend so figures can be rendered
# without a display (e.g. on headless training nodes).
plt.switch_backend('agg')
def flush_print(func):
    """Wrap a print-like callable so its output is flushed right after each call.

    Args:
        func: callable with the builtin ``print`` signature.

    Returns:
        A wrapper that forwards all arguments to ``func``, flushes the stream
        that was written to (``file=`` when given, otherwise ``sys.stdout``)
        and returns whatever ``func`` returned.
    """
    import functools

    @functools.wraps(func)
    def new_print(*args, **kwargs):
        result = func(*args, **kwargs)
        stream = kwargs.get('file')
        # print() writes to sys.stdout when `file` is omitted or None, so flush
        # that; otherwise flush the stream the caller actually targeted.
        (sys.stdout if stream is None else stream).flush()
        return result

    return new_print
# Shadow the builtin print for this module so every message is followed by a
# flush (useful when stdout is redirected to a log file).
print = flush_print(print)
def prepare_eval_data(data_folder):
    """Load a fine-tuning verification set laid out as image-pair folders.

    Used when fine-tuning. The layout read below is::

        data_folder/
            label.txt      # one "<pair_dir> <label>" entry per line
            <pair_dir>/    # exactly two image files

    Args:
        data_folder: path (str or os.PathLike) of the evaluation set.

    Returns:
        ``(carray, issame)``: ``carray`` stacks the transformed images of every
        pair consecutively (two entries per pair along dim 0) and ``issame`` is
        a bool numpy array with one entry per pair, ``True`` iff the label is
        the literal string ``'True'``.

    Raises:
        ValueError: on a malformed label line, a pair folder that does not
            hold exactly two images, or a label file listing no pairs.
        FileNotFoundError: if ``label.txt`` or a listed pair folder is missing.
    """
    data_label_path = os.path.join(data_folder, 'label.txt')
    transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    with open(data_label_path, 'r', encoding='utf-8') as f:
        # Skip blank lines; split() also tolerates repeated spaces/tabs and '\r'.
        lines = [line.strip() for line in f if line.strip()]

    carray, issame = [], []
    for line in lines:
        fields = line.split()
        if len(fields) != 2:
            raise ValueError(f'please check eval dataset, malformed label line: {line!r}')
        idx, label = fields
        pair_dir = os.path.join(data_folder, idx)
        # Non-directory entries (what os.walk reports as files); sorted so the
        # in-pair order is deterministic rather than filesystem-dependent.
        img_ids = sorted(
            name for name in os.listdir(pair_dir)
            if not os.path.isdir(os.path.join(pair_dir, name))
        )
        if len(img_ids) != 2:
            raise ValueError(
                f'please check eval dataset, {idx} must contain exactly 2 images, found {len(img_ids)}'
            )
        for img_id in img_ids:
            # The context manager closes the file; convert() returns a loaded copy.
            with Image.open(os.path.join(pair_dir, img_id)) as img:
                carray.append(transform(img.convert('RGB')))
        issame.append(label == 'True')

    if not carray:
        raise ValueError(f'please check eval dataset, no pairs listed in {data_label_path}')
    # np.bool was removed in NumPy 1.24; the builtin bool is the portable dtype.
    return torch.stack(carray), np.array(issame, dtype=bool)
class face_learner(object):
    """Training / evaluation / inference wrapper around a face-embedding model.

    With ``inference=False`` it builds the backbone, the Arcface head, the loss,
    the optimizer, the training loader and the validation pairs, and optionally
    sets up apex AMP and distributed data parallel. With ``inference=True`` only
    the backbone and the match ``threshold`` are set up.
    """

    def __init__(self, conf, inference=False):
        if conf.is_master_node:
            print(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode)
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))
        self.model = self.model.to(conf.device)
        # Handle to the bare module (no DDP wrapper) used for saving/loading weights.
        self.model_without_ddp = self.model
        if inference:
            self.threshold = conf.threshold
            return
        self._setup_training(conf)

    def _setup_training(self, conf):
        """Build data, head, loss, optimizer and AMP/DDP state for training."""
        self.milestones = conf.milestones
        self.loader, self.class_num = get_train_loader(conf)
        if conf.is_finetune:
            self.lfw, self.lfw_issame = prepare_eval_data(conf.emore_folder / conf.eval_data_mode)
        else:
            self.lfw, self.lfw_issame = get_val_pair(conf.emore_folder, conf.eval_data_mode)
        self.writer = SummaryWriter(conf.log_path)
        self.start_epoch = conf.start_epoch
        self.step = 1
        self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num).to(conf.device)
        self.head_without_ddp = self.head
        self.loss_func = conf.ce_loss.to(conf.device)
        self.optimizer = self._build_optimizer(conf)
        if conf.is_master_node:
            print(self.optimizer)
            print('optimizers generated')
        self.board_loss_every = self._compute_board_loss_every(conf.max_iter, self.loader)
        if conf.use_amp:
            self._init_amp(conf)
        if conf.distributed:
            self._wrap_distributed(conf)

    def _build_optimizer(self, conf):
        """Create the SGD optimizer; BN parameters get no weight decay."""
        paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
        if conf.use_mobilfacenet:
            param_groups = [
                {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
                {'params': paras_only_bn},
            ]
            return optim.SGD(param_groups, lr=conf.lr, momentum=conf.momentum)
        param_groups = [
            {'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
            {'params': paras_only_bn},
        ]
        if conf.use_amp and conf.device_type == 'npu':
            return apex.optimizers.NpuFusedSGD(param_groups, lr=conf.lr, momentum=conf.momentum)
        return optim.SGD(param_groups, lr=conf.lr, momentum=conf.momentum)

    @staticmethod
    def _compute_board_loss_every(max_iter, loader):
        """Batches between train-loss reports; never 0 (it is used as a modulus).

        With a finite ``max_iter`` this is ``min(500, max_iter // 10)``;
        with ``max_iter == -1`` (no limit) it is ``len(loader) // 10``.
        """
        if max_iter != -1:
            interval = min(500, max_iter // 10)
        else:
            interval = len(loader) // 10
        return max(1, interval)

    def _init_amp(self, conf):
        """Initialise apex AMP for model, head and optimizer."""
        amp_kwargs = {'opt_level': conf.opt_level, 'loss_scale': conf.loss_scale}
        if conf.device_type == 'npu':
            amp_kwargs['combine_grad'] = True
        [self.model, self.head], self.optimizer = amp.initialize(
            [self.model, self.head], self.optimizer, **amp_kwargs)
        # Keep the bare-module handles pointing at what amp returned.
        self.model_without_ddp = self.model
        self.head_without_ddp = self.head

    def _wrap_distributed(self, conf):
        """Wrap model AND head in DDP so both have their gradients all-reduced."""
        if conf.use_amp and conf.device_type == 'gpu':
            from apex.parallel import DistributedDataParallel as DDP
            self.model = DDP(self.model)
            self.head = DDP(self.head)
        else:
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[conf.device_id],
                                                 broadcast_buffers=False)
            # Without this the head's gradients stay rank-local and the
            # per-rank classifier weights diverge.
            self.head = DistributedDataParallel(self.head,
                                                device_ids=[conf.device_id],
                                                broadcast_buffers=False)
        self.model_without_ddp = self.model.module
        self.head_without_ddp = self.head.module

    @staticmethod
    def _checkpoint_name(time_str, epoch, extra):
        """Checkpoint file name; ``extra`` (e.g. accuracy) is printed with 5 decimals."""
        extra_str = 'none' if extra is None else '{:.5f}'.format(extra)
        return 'model_{}_epoch:{}_{}.pth'.format(time_str, epoch, extra_str)

    def save_state(self, conf, extra=None, epoch=0):
        """Save model/head/optimizer state, the epoch and ``conf`` under ``conf.model_path``."""
        ckpt = {
            'model': self.model_without_ddp.state_dict(),
            'head': self.head_without_ddp.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': epoch,
            'config': conf,
        }
        torch.save(ckpt, conf.model_path / self._checkpoint_name(get_time(), epoch, extra))

    def load_state_dict(self, weights_path, is_finetune=False):
        """Load weights from a checkpoint written by ``save_state`` or a bare state dict.

        For a full checkpoint the head and optimizer state are restored too,
        unless ``is_finetune`` is set or the learner was built for inference
        (in which case no head/optimizer exists).
        """
        # torch.load unpickles arbitrary objects: only load trusted files.
        ckpt = torch.load(weights_path, map_location='cpu')
        if 'model' not in ckpt:
            self.model_without_ddp.load_state_dict(ckpt)
            return
        self.model_without_ddp.load_state_dict(ckpt['model'])
        if not is_finetune and hasattr(self, 'head_without_ddp'):
            self.head_without_ddp.load_state_dict(ckpt['head'])
            self.optimizer.load_state_dict(ckpt['optimizer'])

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log validation accuracy, threshold and ROC image to TensorBoard and stdout."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)
        print('========================')
        print(f'{db_name}_accuracy: \t{accuracy}')
        print(f'{db_name}_best_threshold: \t{best_threshold}')
        print('========================')

    def _embed_batch(self, batch, device, tta):
        """Embed one batch; with ``tta`` add the h-flipped embedding and L2-normalise."""
        if tta:
            fliped = hflip_batch(batch)
            emb_batch = self.model(batch.to(device)) + self.model(fliped.to(device))
            emb = l2_norm(emb_batch)
        else:
            emb = self.model(batch.to(device))
        return emb.detach().cpu().numpy()

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed ``carray`` in batches and run pair verification on it.

        Leaves the model in eval mode. Returns
        ``(mean accuracy, mean best threshold, ROC-curve image tensor)``.
        """
        self.model.eval()
        num_samples = len(carray)
        batch_size = conf.batch_size
        embeddings = np.zeros([num_samples, conf.embedding_size])
        with torch.no_grad():
            # The last slice is simply shorter, covering the remainder batch.
            for start in range(0, num_samples, batch_size):
                chunk = carray[start:start + batch_size]
                # Same copy as torch.tensor(chunk), without the warning torch
                # emits when copy-constructing from a tensor.
                batch = chunk.clone().detach() if torch.is_tensor(chunk) else torch.tensor(chunk)
                embeddings[start:start + batch_size] = self._embed_batch(batch, conf.device, tta)
        # `evaluate` here is the module-level function imported from
        # verifacation, not this method (class scope is not visible in bodies).
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def _train_step(self, conf, imgs, labels):
        """Run one forward/backward/optimizer step; return the UNSCALED loss value."""
        imgs = imgs.to(conf.device)
        labels = labels.to(conf.device)
        self.optimizer.zero_grad()
        embeddings = self.model(imgs)
        thetas = self.head(embeddings, labels)
        loss = self.loss_func(thetas, labels)
        if conf.use_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # Report the real loss: scaled_loss is multiplied by the AMP loss scale.
        loss_value = loss.item()
        self.optimizer.step()
        return loss_value

    def _print_fps(self, step_times, warmup=5):
        """Print samples/s over ``step_times``, skipping the first ``warmup`` steps."""
        measured = step_times[warmup:]
        time_sum = sum(measured)
        if not measured or time_sum <= 0:
            return  # too few steps to measure; avoid dividing by zero
        print('***************************************')
        print("fps: %f" % ((len(measured) * self.loader.batch_size) / time_sum))
        print('***************************************')

    def train(self, conf, epochs):
        """Train from ``self.start_epoch`` up to ``epochs`` (exclusive).

        Per epoch: divide the LR by 10 on milestone epochs, reseed the
        distributed sampler, and iterate ``self.loader`` (stopping at batch index
        ``conf.max_iter`` unless it is -1). On the master node it also logs the
        mean loss every ``self.board_loss_every`` batches, evaluates on the
        validation pairs, saves a checkpoint and prints the throughput.
        """
        self.model.train()
        self.head.train()
        for e in range(self.start_epoch, epochs):
            if conf.is_master_node:
                print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr(conf.is_master_node)
            if conf.distributed:
                print("distributed set epoch: ", e)
                self.loader.sampler.set_epoch(e)
            step_times = []
            running_loss, running_count = 0., 0
            for index, (imgs, labels) in enumerate(self.loader):
                start = time.time()
                running_loss += self._train_step(conf, imgs, labels)
                running_count += 1
                step_times.append(time.time() - start)
                if conf.is_master_node and index % self.board_loss_every == 0 and index != 0:
                    # Average over the steps actually accumulated since the last report.
                    loss_board = running_loss / running_count
                    now_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
                    print(f'time:[{now_time}] \tstep: [{index}] \ttrain_loss: {loss_board}')
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss, running_count = 0., 0
                if index == conf.max_iter:
                    break
                self.step += 1
            if conf.is_master_node:
                accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.lfw, self.lfw_issame)
                self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                self.model.train()
                self.save_state(conf, extra=accuracy, epoch=e)
                self._print_fps(step_times)

    def schedule_lr(self, is_master_node):
        """Divide the learning rate of every optimizer param group by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        if is_master_node:
            print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        """Match each face against a bank of reference embeddings.

        faces : iterable of PIL Images (anything ``conf.test_transform`` accepts)
        target_embs : [n, embedding_size] embeddings of the faces in the facebank
        tta : test time augmentation (hflip, that's all)

        Returns ``(min_idx, minimum)``: per face, the index of the nearest target
        by squared Euclidean distance (``-1`` when that distance exceeds
        ``self.threshold``) and the distance itself. Empty input yields empty
        tensors.
        """
        with torch.no_grad():
            embs = []
            for img in faces:
                emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                if tta:
                    mirror = trans.functional.hflip(img)
                    emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                    emb = l2_norm(emb + emb_mirror)
                embs.append(emb)
            if not embs:
                # torch.cat([]) would raise; report "no matches" instead.
                return (torch.empty(0, dtype=torch.long, device=target_embs.device),
                        torch.empty(0, dtype=target_embs.dtype, device=target_embs.device))
            source_embs = torch.cat(embs)
            # [m, d, 1] - [1, d, n] -> [m, d, n]; summing over d gives [m, n].
            diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
            dist = torch.sum(torch.pow(diff, 2), dim=1)
            minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1
        return min_idx, minimum