import argparse
import os
from insightface.app import FaceAnalysis
import numpy as np
import cv2


def parse_pairs_file(pairs_file_path):
    """
    解析LFW的pairs.txt文件,生成图像对路径列表
    
    Args:
        pairs_file_path: pairs.txt文件路径
        
    Returns:
        img_pairs: 包含图像对路径的列表,格式为[(path1, path2, label), ...]
                  label=1表示同一人,label=0表示不同人
    """
    img_pairs = []
    
    with open(pairs_file_path, 'r') as f:
        # 跳过第一行(通常是"10 300")
        next(f)
        
        for line in f:
            parts = line.strip().split()
            
            # 处理匹配对(同一人)
            if len(parts) == 3:
                name = parts[0]
                idx1 = int(parts[1])
                idx2 = int(parts[2])
                # 构建文件名(确保4位数字格式)
                img1 = f"{name}_{idx1:04d}.jpg"
                img2 = f"{name}_{idx2:04d}.jpg"
                img_pairs.append((name, img1, name, img2, 1))
            
            # 处理不匹配对(不同人)
            elif len(parts) == 4:
                name1 = parts[0]
                idx1 = int(parts[1])
                name2 = parts[2]
                idx2 = int(parts[3])
                # 构建文件名
                img1 = f"{name1}_{idx1:04d}.jpg"
                img2 = f"{name2}_{idx2:04d}.jpg"
                img_pairs.append((name1, img1, name2, img2, 0))
    
    return img_pairs


def load_lfw_images(pairs_file_path, lfw_dir):
    """
    根据pairs.txt加载LFW图像
    
    Args:
        pairs_file_path: pairs.txt文件路径
        lfw_dir: LFW数据集根目录(包含按人名组织的子文件夹)
        
    Returns:
        images1: 第一张图像列表
        images2: 第二张图像列表
        labels: 标签列表(1=同一人,0=不同人)
    """
    # 解析pairs文件
    img_pairs = parse_pairs_file(pairs_file_path)
    
    images1 = []
    images2 = []
    labels = []
    
    # 读取图像
    for name1, img1_name, name2, img2_name, label in img_pairs:
        # 构建完整路径
        path1 = os.path.join(lfw_dir, name1, img1_name)
        path2 = os.path.join(lfw_dir, name2, img2_name)
        
        # 读取图像(使用OpenCV)
        img1 = cv2.imread(path1)
        img2 = cv2.imread(path2)
        
        # 检查图像是否成功加载
        if img1 is None or img2 is None:
            print(f"警告:无法加载图像 {path1}{path2}")
            continue

        images1.append(img1)
        images2.append(img2)
        labels.append(label)
    
    return images1, images2, labels


def evaluate_lfw_recall(model, pairs_file, lfw_dir):
    face_pairs = []
    detected_labels = []
    num_detected_faces = 0

    images1, images2, labels = load_lfw_images(pairs_file, lfw_dir)
    for img1, img2, label in zip(images1, images2, labels):
        # 提取特征向量
        face1 = model.get(img1)
        face2 = model.get(img2)
        # 两张图片均检测到人脸
        if len(face1) != 0 and len(face2) != 0:
            num_detected_faces += 2
            face_pairs.append((face1, face2))
            detected_labels.append(label)
        # 只有一张图片检测到人脸
        elif len(face1) != 0 or len(face2) != 0:
            num_detected_faces += 1

    recall = num_detected_faces / (len(images1) + len(images2))
    return recall, face_pairs, detected_labels


def evaluate_lfw_accuracy(model, face_pairs, labels, threshold=0.363):
    # 计算相似度
    similarities = []
    recognition_model = model.models['recognition']
    for face1, face2 in face_pairs:
        emb1 = face1[0].normed_embedding
        emb2 = face2[0].normed_embedding
        # 计算余弦相似度
        sim = recognition_model.compute_sim(emb1, emb2)
        similarities.append(sim)
    
    # 计算准确率
    predictions = [1 if s > threshold else 0 for s in similarities]
    accuracy = sum([prediction == label for prediction, label in zip(predictions, labels)]) / len(labels)
    
    return accuracy

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FaceAnalysis')
    parser.add_argument('--device', default=0)
    args = parser.parse_args()
    device = int(args.device)
    app = FaceAnalysis(device=device)
    app.prepare(ctx_id=0, det_size=(640, 640))
    recall, face_pairs, detected_labels = evaluate_lfw_recall(app, "./data/pairs.txt", "./data/lfw")
    accuracy = evaluate_lfw_accuracy(app, face_pairs, detected_labels)
    print(f"LFW验证召回率:{recall * 100:.2f}%, LFW 验证准确率: {accuracy * 100:.2f}%")