import os
import sys
from collections import defaultdict
import torch
import numpy as np
from tqdm import tqdm
from ECAPA_TDNN.main import inference_embeddings_to_plt_hist_and_roc
# Each row of a ``mels{i}_0.bin`` result file is one float32 embedding of this width
# (the original script reshaped with ``reshape(-1, 192)``).
EMBEDDING_DIM = 192
# Number of ``speakers{i}.pt`` / ``mels{i}_0.bin`` pairs, indexed 1..N (original hard-coded value).
DEFAULT_TOTAL_NUMS = 4648
# Second positional argument to ``inference_embeddings_to_plt_hist_and_roc``.
# NOTE(review): its meaning is defined by that function (not visible here) — confirm and rename.
EVAL_PARAM = 88600


def add_batch_to_holder(embedding_holder, embeddings, speakers):
    """Append each embedding row to ``embedding_holder`` under its speaker id.

    Args:
        embedding_holder: mapping ``speaker_id -> list``; a ``defaultdict(list)``
            is expected so unseen ids are created on first use.
        embeddings: 2-D array-like whose rows are individual embeddings.
        speakers: sequence of scalar tensors (anything with ``.item()``), one per
            embedding row. Presumably a 1-D tensor loaded via ``torch.load`` — confirm.

    Raises:
        ValueError: if the number of embeddings and speaker labels differ. A plain
            ``zip`` would silently truncate and drop or mis-pair data.
    """
    if len(embeddings) != len(speakers):
        raise ValueError(
            f"got {len(embeddings)} embeddings but {len(speakers)} speaker labels"
        )
    for embedding, speaker in zip(embeddings, speakers):
        embedding_holder[speaker.item()].append(embedding)


def collect_embeddings(result_path, speakers_path, total_nums=DEFAULT_TOTAL_NUMS):
    """Load all inference outputs and group their embeddings by speaker id.

    For each index ``i`` in ``1..total_nums`` this reads
    ``<speakers_path>/speakers{i}.pt`` (speaker labels) and
    ``<result_path>/mels{i}_0.bin`` (raw float32 embeddings, EMBEDDING_DIM wide).

    Returns:
        ``defaultdict(list)`` mapping speaker id -> list of 1-D float32 numpy arrays.
    """
    embedding_holder = defaultdict(list)
    for index in tqdm(range(1, total_nums + 1)):
        # NOTE(review): torch.load unpickles; only use on trusted files.
        speakers = torch.load(os.path.join(speakers_path, f"speakers{index}.pt"))
        bin_file_path = os.path.join(result_path, f"mels{index}_0.bin")
        # Rows are already float32 numpy arrays on CPU, so no torch round-trip is needed.
        batch = np.fromfile(bin_file_path, dtype="float32").reshape(-1, EMBEDDING_DIM)
        add_batch_to_holder(embedding_holder, batch, speakers)
    return embedding_holder


def main(argv=None):
    """Command-line entry point: ``RESULT_PATH SPEAKERS_PATH [TOTAL_NUMS]``.

    Groups embeddings by speaker, runs the project's histogram/ROC evaluation
    and prints the ROC AUC.
    """
    args = sys.argv[1:] if argv is None else argv
    if len(args) < 2:
        sys.exit(f"usage: {sys.argv[0]} RESULT_PATH SPEAKERS_PATH [TOTAL_NUMS]")
    result_path, speakers_path = args[0], args[1]
    total_nums = int(args[2]) if len(args) > 2 else DEFAULT_TOTAL_NUMS

    embedding_holder = collect_embeddings(result_path, speakers_path, total_nums)
    # infer_hist / infer_roc are returned by the evaluator but not used by this script.
    infer_hist, infer_roc, scores = inference_embeddings_to_plt_hist_and_roc(
        embedding_holder, EVAL_PARAM
    )
    _tp, _tn, roc_auc = scores
    print("roc_auc: ")
    print(roc_auc)


if __name__ == "__main__":
    main()