import os
import sys
import pickle
import argparse
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms as trans
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch import distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel
from verifacation import evaluate
from model import Backbone, l2_norm
class FinetuneDataset(Dataset):
    """Image-folder dataset used to evaluate on a custom (fine-tune) dataset.

    The code reads every file found in the immediate sub-directories of
    ``dataset_folder`` as an image, and reads pair labels from
    ``dataset_folder/label.txt``.
    """

    def __init__(self, dataset_folder, transform):
        # Sorted so sample order (and thus embedding row order) does not depend on
        # the filesystem's directory-listing order.
        self.ids = sorted(next(os.walk(dataset_folder))[1])
        self.label_path = os.path.join(dataset_folder, 'label.txt')
        self.transform = transform
        self.issame_list = self.prepare_labels(self.label_path)
        self.img_paths = self.prepare_imgs(dataset_folder)

    def prepare_labels(self, label_path):
        """Return each kept line of ``label_path`` split on single spaces.

        Lines of 4 characters or fewer (trailing newline included) are skipped,
        which drops blank lines. Fields are left as strings.
        NOTE(review): LFWDataset's issame_list comes straight from its pickle;
        confirm ``evaluate`` accepts this list-of-string-lists format.
        """
        with open(label_path, 'r', encoding='utf-8') as f:
            return [line.strip().split(' ') for line in f if len(line) > 4]

    def prepare_imgs(self, dataset_folder):
        """Return the paths of all files one level below each id directory, sorted."""
        img_paths = []
        for idx in self.ids:
            sub_dir = os.path.join(dataset_folder, idx)
            for img_id in sorted(next(os.walk(sub_dir))[2]):
                img_paths.append(os.path.join(sub_dir, img_id))
        return img_paths

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, item):
        """Return ``(item, transformed RGB image)`` for sample ``item``."""
        sample = self.img_paths[item]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(sample) as img:
            img = img.convert('RGB')
        img = self.transform(img)
        return item, img
class LFWDataset(Dataset):
    """Verification dataset backed by a pickled ``(bins, issame_list)`` file.

    ``bins`` holds encoded image bytes, one entry per sample; ``issame_list`` is
    exposed unchanged for the caller.
    """

    def __init__(self, lfw_bin_path, transform):
        # SECURITY: pickle.load can execute arbitrary code - only load trusted files.
        with open(lfw_bin_path, 'rb') as f:
            self.bins, self.issame_list = pickle.load(f, encoding='bytes')
        self.transform = transform

    def __len__(self):
        return len(self.bins)

    def __getitem__(self, item):
        """Decode sample ``item`` and return ``(item, transformed image)``.

        Raises:
            ValueError: if the stored bytes cannot be decoded as an image.
        """
        sample = self.bins[item]
        img_np_arr = np.frombuffer(sample, np.uint8)
        # cv2 decodes to BGR channel order and the array is passed on unchanged.
        # NOTE(review): confirm this matches the channel order used in training.
        img = cv2.imdecode(img_np_arr, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError(f'failed to decode image at index {item}')
        img = Image.fromarray(img.astype(np.uint8))
        img = self.transform(img)
        return item, img
def build_model(args, device):
    """Create the backbone, load its checkpoint and put it in inference mode.

    Args:
        args: parsed CLI namespace; reads ``net_depth``, ``net_mode``, ``weights``,
            ``distributed`` and ``device_id``.
        device: torch device the model is moved to.

    Returns:
        The model in eval mode, wrapped in DistributedDataParallel when
        ``args.distributed`` is truthy.
    """
    network = Backbone(num_layers=args.net_depth, drop_ratio=0.6, mode=args.net_mode)
    state_dict = torch.load(args.weights, map_location='cpu')
    # Accept both a bare state dict and a checkpoint that nests it under 'model'.
    state_dict = state_dict['model'] if 'model' in state_dict else state_dict
    network.load_state_dict(state_dict)
    network = network.to(device)
    network.eval()
    if not args.distributed:
        return network
    return DistributedDataParallel(network, device_ids=[args.device_id], broadcast_buffers=False)
def build_data_loader(args, embedding_size=512):
    """Build the evaluation loader together with an empty embedding buffer.

    Args:
        args: parsed CLI namespace; reads ``finetune``, ``data_path``,
            ``batch_size``, ``num_workers``, ``distributed`` and, when
            distributed, ``world_size`` and ``rank``.
        embedding_size: width of each embedding row; must match the backbone's
            output size (defaults to the previously hard-coded 512).

    Returns:
        ``(loader, embeddings, issame_list)`` where ``embeddings`` is a zeroed
        float32 tensor of shape ``(len(dataset), embedding_size)``.
    """
    transform = trans.Compose([
        trans.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset_cls = FinetuneDataset if args.finetune else LFWDataset
    dataset = dataset_cls(args.data_path, transform)
    embeddings = torch.zeros([len(dataset), embedding_size], dtype=torch.float32)
    sampler = None
    if args.distributed:
        # Each rank gets a strided subset; with the default drop_last=False the
        # sampler pads by repeating a few samples so all ranks get equal counts.
        sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank)
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        num_workers=args.num_workers, sampler=sampler)
    return loader, embeddings, dataset.issame_list
@torch.no_grad()
def evaluation(args, model, loader, embeddings, issame_list, device):
    """Embed every sample with horizontal-flip TTA and report verification accuracy.

    Args:
        args: parsed CLI namespace; reads ``distributed`` and ``is_master_node``
            (and ``world_size`` is no longer needed for merging).
        model: network mapping an image batch to an embedding batch.
        loader: data loader yielding ``(indices, images)`` batches.
        embeddings: zero-filled result buffer; row ``i`` receives the embedding of
            sample ``i`` (the index returned by the dataset).
        issame_list: pair labels, passed unchanged to ``evaluate``.
        device: device on which inference runs.

    Only the master process runs ``evaluate`` and prints the results.
    """
    embeddings = embeddings.to(device)
    # How many times this process wrote each row (0 or 1); summed across ranks to
    # detect rows written more than once.
    hits = torch.zeros(embeddings.shape[0], dtype=embeddings.dtype, device=device)
    for idx, img in loader:
        # Flip along dim 3 (width for NCHW batches) and sum both views' outputs.
        img_flip = torch.flip(img, dims=[3])
        img, img_flip = img.to(device), img_flip.to(device)
        emb_batch = model(img) + model(img_flip)
        idx = idx.to(device)
        embeddings[idx] = l2_norm(emb_batch).detach()
        hits[idx] = 1
    if args.distributed:
        # Ranks fill mostly disjoint rows, but DistributedSampler (drop_last=False)
        # repeats a few samples on a second rank as padding. A plain sum would
        # double those rows and distort their distances, so sum both rows and hit
        # counts and divide to average duplicates.
        dist.all_reduce(embeddings, op=dist.ReduceOp.SUM)
        dist.all_reduce(hits, op=dist.ReduceOp.SUM)
        embeddings = embeddings / hits.clamp(min=1).unsqueeze(1)
    embeddings = embeddings.cpu().numpy()
    if args.is_master_node:
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame_list, nrof_folds=10)
        print('*'*50)
        print('lfw_accuracy: {}'.format(accuracy.mean()))
        print('best_thresholds: {}'.format(best_thresholds.mean()))
        print('*'*50)
def prepare_parser():
    """Parse the evaluation script's command-line options.

    Returns:
        argparse.Namespace with model, data, device and distributed settings.
    """
    # (flag, help text, default, type), registered in this order so --help and
    # parsing behave exactly as if each add_argument call were written out.
    options = (
        ("--net_mode", "which network, [ir, ir_se, mobilefacenet]", 'ir_se', str),
        ("--net_depth", "how many layers [50,100,152]", 100, int),
        ("--weights", "weights path name", './work_space/save/model_ir_se100.pth', str),
        ("--data_path", "lfw bin data path", './data/faces_emore/lfw.bin', str),
        ("--batch_size", "eval batch size", 512, int),
        ("--num_workers", "num of workers", 8, int),
        ("--finetune", "if finetune dataset", 0, int),
        ("--device_type", "device_type choice in [npu gpu]", 'npu', str),
        ("--device_id", "device_id", 0, int),
        ("--distributed", "is distributed evaluation", 1, int),
        ("--backend", "", 'nccl', str),
        ("--dist_url", "", '127.0.0.1:41111', str),
        ("--gpus", "number of gpus per node", 1, int),
        ("--dist_rank", "node rank for distributed training", 0, int),
    )
    parser = argparse.ArgumentParser(description='evaluation')
    for flag, help_text, default_value, value_type in options:
        parser.add_argument(flag, help=help_text, default=default_value, type=value_type)
    return parser.parse_args()
if __name__ == '__main__':
    args = prepare_parser()
    if args.device_type == 'gpu':
        device = torch.device(f"cuda:{args.device_id}")
        # Bind this process to its own GPU so collectives and default allocations
        # do not all land on cuda:0.
        torch.cuda.set_device(device)
    elif args.device_type == 'npu':
        device = torch.device(f"npu:{args.device_id}")
        torch.npu.set_device(device)
    else:
        raise ValueError('device type error,please choice in ["gpu","npu"]')

    if args.distributed:
        # rsplit: only the last ':' separates the host from the port.
        addr, port = args.dist_url.rsplit(':', 1)
        os.environ['MASTER_ADDR'] = addr
        os.environ['MASTER_PORT'] = port
        if 'RANK_SIZE' not in os.environ:
            raise RuntimeError("init_distributed_mode failed.")
        args.rank_size = int(os.environ['RANK_SIZE'])
        args.rank = args.dist_rank * args.rank_size + args.device_id
        # NOTE(review): `gpus` is documented as "gpus per node" but is used here as
        # the number of nodes (rank_size is per-node); confirm the intended meaning.
        args.world_size = args.gpus * args.rank_size
        # Split the global batch across local devices; never drop to 0.
        args.batch_size = max(1, args.batch_size // args.rank_size)
        torch.distributed.init_process_group(backend=args.backend, init_method="env://",
                                             world_size=args.world_size, rank=args.rank)
    # Only global rank 0 reports; device 0 on every node would print duplicates.
    args.is_master_node = not args.distributed or args.rank == 0

    model = build_model(args, device)
    dataloader, embeddings, issame_list = build_data_loader(args)
    evaluation(args, model, dataloader, embeddings, issame_list, device)
    if args.distributed:
        dist.destroy_process_group()