import os
import sys
sys.path.append("./LResNet")
import argparse
import numpy as np
from tqdm import tqdm
import torch
from verifacation import evaluate
def l2_norm(input, axis=1):
    """Scale ``input`` to unit Euclidean (L2) length along dimension ``axis``.

    Args:
        input: a torch tensor.
        axis: dimension along which the L2 norm is taken (default 1).

    Returns:
        A tensor with the same shape as ``input``, where every slice along
        ``axis`` has L2 norm 1. No epsilon is added, so a slice whose norm is
        zero produces NaN (0 / 0).
    """
    # keepdim=True keeps the reduced dimension so the norm broadcasts back
    # against ``input`` in the division below.
    magnitude = input.norm(p=2, dim=axis, keepdim=True)
    return input / magnitude
def _latest_subdir(result_dir):
    """Return the path of the last subdirectory of ``result_dir`` in sorted name order.

    Directory listing order is filesystem-dependent, so the names are sorted to
    make the choice deterministic (if the dump subdirectories are named by
    timestamp, this selects the newest one).

    Raises:
        FileNotFoundError: if ``result_dir`` does not exist or contains no
            subdirectory.
    """
    if not os.path.isdir(result_dir):
        raise FileNotFoundError(f'result directory not found: {result_dir}')
    subdirs = sorted(entry.name for entry in os.scandir(result_dir) if entry.is_dir())
    if not subdirs:
        raise FileNotFoundError(f'no result subdirectory found in: {result_dir}')
    return os.path.join(result_dir, subdirs[-1])


def _load_embeddings(results_path):
    """Load, fuse and L2-normalise the per-image embeddings stored in ``results_path``.

    Expects text files ``{idx}_0.txt`` and ``{idx}_flip_0.txt`` for
    ``idx in range(n)``, where ``n`` is half the number of files in the
    directory. Each output row is the sum of the plain and flipped embeddings,
    normalised to unit L2 length.

    Returns:
        float32 array of shape (n, embedding_width). The width is taken from
        the files rather than being hard-coded.

    Raises:
        ValueError: if the directory contains no result files, or (from numpy)
            if the embeddings do not all have the same length.
    """
    file_count = len(next(os.walk(results_path))[2])
    # Each sample contributes exactly two files (plain + horizontally flipped).
    n = file_count // 2
    if n == 0:
        raise ValueError(f'no result files found in: {results_path}')

    rows = []
    for idx in tqdm(range(n)):
        plain = np.loadtxt(os.path.join(results_path, f'{idx}_0.txt'), dtype=np.float32)
        flipped = np.loadtxt(os.path.join(results_path, f'{idx}_flip_0.txt'), dtype=np.float32)
        rows.append((plain + flipped).ravel())

    # Normalise all rows in one call instead of one tensor round-trip per image.
    fused = torch.from_numpy(np.stack(rows))
    return l2_norm(fused, axis=1).numpy()


def evaluation_om(result_dir, target_list_path):
    """Compute and print verification accuracy from offline-model dump files.

    Args:
        result_dir: directory containing the dump subdirectories; the last one
            in sorted name order is evaluated.
        target_list_path: ``.npy`` file loaded with ``np.load`` and passed to
            ``evaluate`` as ``issame`` (presumably one same/different flag per
            image pair -- confirm against the data-prep script).

    Prints the mean accuracy and mean best threshold over 10 folds.

    Raises:
        FileNotFoundError: if ``result_dir`` is missing or has no subdirectory.
        ValueError: if the selected dump directory contains no result files.
    """
    issame = np.load(target_list_path)
    results_path = _latest_subdir(result_dir)
    embeddings = _load_embeddings(results_path)

    _tpr, _fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds=10)
    print('*' * 50)
    print('accuracy: {}'.format(accuracy.mean()))
    print('best_thresholds: {}'.format(best_thresholds.mean()))
    print('*' * 50)
if __name__ == '__main__':
    # Command-line entry point: evaluate the dumped inference results.
    parser = argparse.ArgumentParser(
        description='Compute verification accuracy from offline-model inference dumps.')
    parser.add_argument('--result', type=str, default="./result/dumpOutput_device0",
                        help='directory containing the inference dump subdirectory')
    parser.add_argument('--data_path', type=str, default="./data/lfw_list.npy",
                        help='.npy file with the pair labels (issame) used for evaluation')
    args = parser.parse_args()
    evaluation_om(args.result, args.data_path)