import os
import sys
import json
import torch
import argparse
import numpy as np
from easydict import EasyDict
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
sys.path.append('./models')
from mtcnn import PNet, RNet, ONet
from facenet_pytorch import MTCNN
from facenet_pytorch.models.utils.detect_face import imresample, generateBoundingBox, batched_nms, rerec, pad, bbreg, batched_nms_numpy
# Lookup from the ``net`` value of a stage config to the wrapper class that
# MTCNNPreprocessor instantiates for that stage.
NET_MAP = dict(pnet=PNet, rnet=RNet, onet=ONet)
def build_dataset(config):
    """Create a DataLoader that yields ``(images, paths)`` batches.

    The dataset is an ``ImageFolder`` rooted at ``config.data_dir`` with no
    transform applied. Each sample's label is replaced by the sample's own
    file path, so every batch carries the source path of each image.

    Args:
        config: object exposing ``data_dir``, ``num_workers`` and
            ``batch_size``.

    Returns:
        A ``DataLoader`` whose batches are two parallel lists: the loaded
        images and their file paths.
    """
    folder_ds = datasets.ImageFolder(config.data_dir, transform=None)
    # Swap the class index for the file path so the path travels with the image.
    folder_ds.samples = [(path, path) for path, _ in folder_ds.samples]

    def collate_fn(x):
        # Keep the samples as plain Python lists instead of stacking tensors.
        images = [image for image, _ in x]
        paths = [path for _, path in x]
        return images, paths

    return DataLoader(
        folder_ds,
        num_workers=config.num_workers,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
    )
def dump_to_json(content, outpath):
    """Serialize ``content`` as JSON to ``outpath``, creating parent folders.

    Args:
        content: any object ``json.dump`` can serialize.
        outpath: destination file path. It may be a bare file name, in which
            case the file is written to the current working directory.
    """
    parent_dir = os.path.dirname(outpath)
    # os.path.dirname returns '' for a bare file name and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(outpath, 'w') as f:
        json.dump(content, f)
def load_json(json_path):
    """Read the file at ``json_path`` and return its parsed JSON content."""
    with open(json_path) as stream:
        parsed = json.load(stream)
    return parsed
class MTCNNPreprocessor():
    """Run a single MTCNN stage (PNet, RNet or ONet) on a batch of images.

    One instance wraps exactly one network, selected by ``config.net``. Only
    the ``*_process`` method matching that network may be called; the other
    two raise ``ValueError``.
    """

    def __init__(self, config):
        """Instantiate the network named by ``config.net``.

        Args:
            config: object exposing ``net`` (a key of ``NET_MAP``); it is also
                passed unchanged to the network constructor.
        """
        self.net_name = config.net
        self.net = NET_MAP[self.net_name](config)
        # Score thresholds, indexed by stage: [pnet, rnet, onet].
        self.threshold = [0.6, 0.7, 0.7]
        self.data_device = torch.device('cpu')

    @staticmethod
    def _as_batch_array(per_image):
        """Pack per-image arrays into one ndarray, tolerating ragged input."""
        try:
            return np.array(per_image)
        except ValueError:
            # NumPy >= 1.24 refuses to build an array from items of different
            # shapes; fall back to a 1-D object array with one entry per image.
            ragged = np.empty(len(per_image), dtype=object)
            for pos, item in enumerate(per_image):
                ragged[pos] = item
            return ragged

    def pnet_process(self, imgs):
        """Run the first stage over an image pyramid and return candidates.

        Args:
            imgs: image batch tensor; it is permuted with ``(0, 3, 1, 2)`` so
                it is expected channel-last (presumably batch, height, width,
                channels - confirm against the caller).

        Returns:
            ``(boxes, image_inds)``: the regressed, squared candidate boxes
            with their score in the last column, and for each box the index
            of the image it belongs to.

        Raises:
            ValueError: if this instance does not wrap PNet.
        """
        if self.net_name != 'pnet':
            raise ValueError('Pnet process not support for {} !'.format(self.net_name))
        factor = 0.709
        minsize = 20
        imgs = imgs.permute(0, 3, 1, 2).type(torch.float32)
        h, w = imgs.shape[2:4]

        # Build the pyramid: start at 12 / minsize and shrink by ``factor``
        # until the scaled shorter side would drop below 12 pixels.
        m = 12.0 / minsize
        minl = min(h, w) * m
        scale_i = m
        scales = []
        while minl >= 12:
            scales.append(scale_i)
            scale_i = scale_i * factor
            minl = minl * factor

        boxes = []
        image_inds = []
        scale_picks = []
        # Running count of boxes from earlier scales, so per-scale NMS indices
        # can address the concatenated box tensor.
        offset = 0
        for scale in scales:
            im_data = imresample(imgs, (int(h * scale + 1), int(w * scale + 1)))
            # Centre on 127.5 and scale by 1/128.
            im_data = (im_data - 127.5) * 0.0078125
            # The wrapped network consumes and produces numpy arrays.
            reg, probs = self.net.forward(im_data.cpu().numpy())
            reg = torch.from_numpy(reg)
            probs = torch.from_numpy(probs)
            boxes_scale, image_inds_scale = generateBoundingBox(reg, probs[:, 1], scale, self.threshold[0])
            boxes.append(boxes_scale)
            image_inds.append(image_inds_scale)
            # NMS within this scale only.
            pick = batched_nms(boxes_scale[:, :4], boxes_scale[:, 4], image_inds_scale, 0.5)
            scale_picks.append(pick + offset)
            offset += boxes_scale.shape[0]

        boxes = torch.cat(boxes, dim=0)
        image_inds = torch.cat(image_inds, dim=0)
        scale_picks = torch.cat(scale_picks, dim=0)
        boxes, image_inds = boxes[scale_picks], image_inds[scale_picks]

        # Second NMS pass across all scales.
        pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
        boxes, image_inds = boxes[pick], image_inds[pick]

        # Apply the regression offsets (columns 5-8) relative to box size.
        regw = boxes[:, 2] - boxes[:, 0]
        regh = boxes[:, 3] - boxes[:, 1]
        qq1 = boxes[:, 0] + boxes[:, 5] * regw
        qq2 = boxes[:, 1] + boxes[:, 6] * regh
        qq3 = boxes[:, 2] + boxes[:, 7] * regw
        qq4 = boxes[:, 3] + boxes[:, 8] * regh
        boxes = torch.stack([qq1, qq2, qq3, qq4, boxes[:, 4]]).permute(1, 0)
        boxes = rerec(boxes)
        return boxes, image_inds

    def rnet_process(self, imgs, boxes, image_inds):
        """Refine first-stage candidates with the second-stage network.

        Args:
            imgs: image batch tensor, permuted with ``(0, 3, 1, 2)`` before use.
            boxes: candidate boxes from the previous stage.
            image_inds: for each box, the index of its image in ``imgs``.

        Returns:
            ``(boxes, image_inds)`` after thresholding, NMS, box regression
            and squaring. Inputs are returned unchanged when ``boxes`` is
            empty.

        Raises:
            ValueError: if this instance does not wrap RNet.
        """
        if self.net_name != 'rnet':
            raise ValueError('Rnet process not support for {} !'.format(self.net_name))
        imgs = imgs.permute(0, 3, 1, 2).type(torch.float32)
        h, w = imgs.shape[2:4]
        y, ey, x, ex = pad(boxes, w, h)
        if len(boxes) > 0:
            im_data = []
            for k in range(len(y)):
                # Skip boxes whose clipped crop would be empty.
                if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
                    img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
                    im_data.append(imresample(img_k, (24, 24)))
            im_data = torch.cat(im_data, dim=0)
            im_data = (im_data - 127.5) * 0.0078125
            out = self.net.forward(im_data.cpu().numpy())
            out = [torch.from_numpy(o) for o in out]
            out0 = out[0].permute(1, 0)
            out1 = out[1].permute(1, 0)
            score = out1[1, :]
            ipass = score > self.threshold[1]
            boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
            image_inds = image_inds[ipass]
            mv = out0[:, ipass].permute(1, 0)
            pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
            boxes, image_inds, mv = boxes[pick], image_inds[pick], mv[pick]
            boxes = bbreg(boxes, mv)
            boxes = rerec(boxes)
        return boxes, image_inds

    def onet_process(self, imgs, boxes, image_inds):
        """Run the final stage and group boxes and landmarks per image.

        Args:
            imgs: image batch tensor, permuted with ``(0, 3, 1, 2)`` before use.
            boxes: candidate boxes from the previous stage.
            image_inds: for each box, the index of its image in ``imgs``.

        Returns:
            ``(batch_boxes, batch_points)``: arrays with one entry per image
            in ``imgs`` holding that image's boxes (with score) and its
            5-point landmarks. When images have different numbers of
            detections the arrays have dtype ``object``.

        Raises:
            ValueError: if this instance does not wrap ONet.
        """
        if self.net_name != 'onet':
            raise ValueError('Onet process not support for {} !'.format(self.net_name))
        imgs = imgs.permute(0, 3, 1, 2).type(torch.float32)
        # Group results by the images actually present in this batch rather
        # than by a module-level ``config`` global, which only exists when the
        # file runs as a script and is too large for a final partial batch.
        batch_size = len(imgs)
        h, w = imgs.shape[2:4]
        points = torch.zeros(0, 5, 2, device=self.data_device)
        if len(boxes) > 0:
            y, ey, x, ex = pad(boxes, w, h)
            im_data = []
            for k in range(len(y)):
                # Skip boxes whose clipped crop would be empty.
                if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
                    img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
                    im_data.append(imresample(img_k, (48, 48)))
            im_data = torch.cat(im_data, dim=0)
            im_data = (im_data - 127.5) * 0.0078125
            out = self.net.forward(im_data.cpu().numpy())
            out = [torch.from_numpy(o) for o in out]
            out0 = out[0].permute(1, 0)
            out1 = out[1].permute(1, 0)
            out2 = out[2].permute(1, 0)
            score = out2[1, :]
            points = out1
            ipass = score > self.threshold[2]
            points = points[:, ipass]
            boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
            image_inds = image_inds[ipass]
            mv = out0[:, ipass].permute(1, 0)

            # Landmarks come out relative to the box; map them to image
            # coordinates (rows 0-4 are x, rows 5-9 are y).
            w_i = boxes[:, 2] - boxes[:, 0] + 1
            h_i = boxes[:, 3] - boxes[:, 1] + 1
            points_x = w_i.repeat(5, 1) * points[:5, :] + boxes[:, 0].repeat(5, 1) - 1
            points_y = h_i.repeat(5, 1) * points[5:10, :] + boxes[:, 1].repeat(5, 1) - 1
            points = torch.stack((points_x, points_y)).permute(2, 1, 0)
            boxes = bbreg(boxes, mv)
            pick = batched_nms_numpy(boxes[:, :4], boxes[:, 4], image_inds, 0.7, 'Min')
            boxes, image_inds, points = boxes[pick], image_inds[pick], points[pick]

        boxes = boxes.cpu().numpy()
        points = points.cpu().numpy()
        image_inds = image_inds.cpu()

        batch_boxes = []
        batch_points = []
        for b_i in range(batch_size):
            b_i_inds = np.where(image_inds == b_i)
            batch_boxes.append(boxes[b_i_inds].copy())
            batch_points.append(points[b_i_inds].copy())
        return self._as_batch_array(batch_boxes), self._as_batch_array(batch_points)
def process_pnet(config):
    """Run the PNet stage over the dataset and write the results to JSON.

    For every batch the candidate boxes and their image indices are stored
    under the batch index (as a string). The result is written to
    ``pnet.json`` inside ``config.output_path``.

    Args:
        config: PNet stage configuration (see ``build_config``).
    """
    loader = build_dataset(config)
    processor = MTCNNPreprocessor(config)
    cpu = torch.device('cpu')
    results = {}
    batches = tqdm(enumerate(loader), total=len(loader))
    for batch_idx, (pil_imgs, _paths) in batches:
        stacked = np.stack([np.uint8(image) for image in pil_imgs])
        batch_tensor = torch.as_tensor(stacked.copy(), device=cpu)
        cand_boxes, cand_inds = processor.pnet_process(batch_tensor)
        results[str(batch_idx)] = {
            'boxes': cand_boxes.tolist(),
            'image_inds': cand_inds.tolist(),
        }
    target = os.path.join(config.output_path, 'pnet.json')
    os.makedirs(os.path.dirname(target), exist_ok=True)
    dump_to_json(results, target)
def process_rnet(config):
    """Run the RNet stage on the saved PNet candidates and write JSON.

    The candidates are read from ``config.input_path`` and looked up by batch
    index, so the dataset must be iterated in the same order and with the
    same batch size as when that file was produced. The refined boxes are
    written to ``rnet.json`` inside ``config.output_path``.

    Args:
        config: RNet stage configuration (see ``build_config``).
    """
    loader = build_dataset(config)
    processor = MTCNNPreprocessor(config)
    results = {}
    previous_stage = load_json(config.input_path)
    cpu = torch.device('cpu')
    batches = tqdm(enumerate(loader), total=len(loader))
    for batch_idx, (pil_imgs, _paths) in batches:
        key = str(batch_idx)
        stacked = np.stack([np.uint8(image) for image in pil_imgs])
        batch_tensor = torch.as_tensor(stacked.copy(), device=cpu)
        cand_boxes = torch.from_numpy(np.array(previous_stage[key]['boxes']))
        cand_inds = torch.from_numpy(np.array(previous_stage[key]['image_inds']))
        refined_boxes, refined_inds = processor.rnet_process(batch_tensor, cand_boxes, cand_inds)
        results[key] = {
            'boxes': refined_boxes.tolist(),
            'image_inds': refined_inds.tolist(),
        }
    target = os.path.join(config.output_path, 'rnet.json')
    os.makedirs(os.path.dirname(target), exist_ok=True)
    dump_to_json(results, target)
def process_onet(config):
    """Run the ONet stage, save the cropped faces and record their paths.

    Candidates are read from ``config.input_path`` and looked up by batch
    index. Each crop is saved under a sibling of the data directory named
    ``<data_dir>_split_om_cropped_<batch_size>``, and the list of all crop
    paths is written to ``onet.json`` inside ``config.output_path``.

    Args:
        config: ONet stage configuration (see ``build_config``).
    """
    data_dir = config.data_dir
    # Drop one trailing slash so the suffix is appended to the folder name.
    if data_dir[-1] == "/":
        data_dir = data_dir[:-1]
    cropped_root = data_dir + '_split_om_cropped_{}'.format(config.batch_size)

    loader = build_dataset(config)
    processor = MTCNNPreprocessor(config)
    previous_stage = load_json(config.input_path)
    cpu = torch.device('cpu')
    crop_paths = []
    batches = tqdm(enumerate(loader), total=len(loader))
    for batch_idx, (pil_imgs, src_paths) in batches:
        key = str(batch_idx)
        stacked = np.stack([np.uint8(image) for image in pil_imgs])
        batch_tensor = torch.as_tensor(stacked.copy(), device=cpu)
        cand_boxes = torch.from_numpy(np.array(previous_stage[key]['boxes']))
        cand_inds = torch.from_numpy(np.array(previous_stage[key]['image_inds']))
        batch_boxes, batch_points = processor.onet_process(batch_tensor, cand_boxes, cand_inds)
        save_paths = [src.replace(data_dir, cropped_root) for src in src_paths]
        save_crop_imgs(batch_boxes, batch_points, pil_imgs, save_paths)
        crop_paths.extend(save_paths)
    target = os.path.join(config.output_path, 'onet.json')
    os.makedirs(os.path.dirname(target), exist_ok=True)
    dump_to_json(crop_paths, target)
def _to_batch_array(per_image):
    """Pack per-image entries into one ndarray, tolerating ragged input."""
    try:
        return np.array(per_image)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from entries of different
        # shapes (e.g. None mixed with arrays, or differing face counts);
        # fall back to a 1-D object array with one entry per image.
        ragged = np.empty(len(per_image), dtype=object)
        for pos, item in enumerate(per_image):
            ragged[pos] = item
        return ragged


def save_crop_imgs(batch_boxes, batch_points, img, save_path):
    """Select one face per image and save the cropped faces to disk.

    Args:
        batch_boxes: per-image box arrays whose column 4 is the score.
        batch_points: per-image landmark arrays, parallel to ``batch_boxes``.
        img: the batch of source images, passed to ``MTCNN.select_boxes`` and
            ``MTCNN.extract``.
        save_path: output path(s) for the crops, passed to ``MTCNN.extract``.

    Returns:
        Whatever ``MTCNN.extract`` returns for the selected boxes.
    """
    # The MTCNN instance is used only for its box-selection and cropping
    # helpers; detection itself has already been done by the caller.
    mtcnn = MTCNN(
        image_size=160, margin=14, min_face_size=20,
        thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
        selection_method='center_weighted_size'
    )
    boxes, probs, points = [], [], []
    for box, point in zip(batch_boxes, batch_points):
        box = np.array(box)
        point = np.array(point)
        if len(box) == 0:
            # No detection for this image: use None placeholders.
            boxes.append(None)
            probs.append([None])
            points.append(None)
        elif mtcnn.select_largest:
            # Order this image's detections by box area, largest first.
            box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1]
            box = box[box_order]
            point = point[box_order]
            boxes.append(box[:, :4])
            probs.append(box[:, 4])
            points.append(point)
        else:
            boxes.append(box[:, :4])
            probs.append(box[:, 4])
            points.append(point)
    batch_boxes = _to_batch_array(boxes)
    batch_probs = _to_batch_array(probs)
    batch_points = _to_batch_array(points)
    batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(
        batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method
    )
    faces = mtcnn.extract(img, batch_boxes, save_path)
    return faces
def build_config(arg):
    """Build the PNet, RNet and ONet stage configurations.

    All three stages share the output folder
    ``./data/output/split_bs<batch_size>/``. RNet reads ``pnet.json`` from it
    and ONet reads ``rnet.json``; PNet has no ``input_path``.

    Args:
        arg: parsed command-line namespace exposing ``device_id``,
            ``batch_size`` and ``data_dir``.

    Returns:
        A tuple ``(pnet_config, rnet_config, onet_config)`` of ``EasyDict``s.
    """
    out_dir = './data/output/split_bs' + str(arg.batch_size) + '/'

    def stage_config(net, model_path, input_name=None):
        # Assemble one stage's settings; only later stages get an input file.
        settings = {'net': net, 'device_id': arg.device_id}
        if input_name is not None:
            settings['input_path'] = out_dir + input_name
        settings['output_path'] = out_dir
        settings['model_path'] = model_path
        settings['data_dir'] = arg.data_dir
        settings['num_workers'] = 8
        settings['batch_size'] = arg.batch_size
        return EasyDict(settings)

    pnet_config = stage_config('pnet', './weights/PNet_dynamic.om')
    rnet_config = stage_config('rnet', './weights/RNet_dynamic.om', 'pnet.json')
    onet_config = stage_config('onet', './weights/ONet_dynamic.om', 'rnet.json')
    return pnet_config, rnet_config, onet_config
if __name__ == '__main__':
    # Command-line entry point: run exactly one MTCNN stage over the dataset.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='[PNet/RNet/ONet]')
    parser.add_argument('--data_dir', type=str, help='the absolute files path of lfw dataset')
    parser.add_argument('--batch_size', default=1, type=int, help='[1/16]')
    parser.add_argument('--device_id', default=0, type=int)
    arg = parser.parse_args()
    pnet_config, rnet_config, onet_config = build_config(arg)
    stages = {
        'pnet': (pnet_config, process_pnet),
        'rnet': (rnet_config, process_rnet),
        'onet': (onet_config, process_onet),
    }
    # Match case-insensitively so both the spelling in the help text
    # ('PNet') and the previously required one ('Pnet') are accepted.
    stage_key = (arg.model or '').lower()
    if stage_key not in stages:
        # Report a bad or missing stage instead of silently doing nothing.
        parser.error('--model must be one of PNet/RNet/ONet, got {!r}'.format(arg.model))
    # ``config`` is deliberately kept as a module-level name: code elsewhere
    # in this file may read it as a global.
    config, run_stage = stages[stage_key]
    run_stage(config)