from __future__ import absolute_import, division, print_function
import os
import argparse
import numpy as np
def om_post(ar):
    """Compute and print binary-classification accuracy of saved results.

    The ground-truth labels are loaded from ``ar.label_path`` and flattened
    to one dimension. For the sample at flat index ``i`` the result file
    ``<ar.result_dir>/{i}_0.npy`` is loaded and flattened; the sample counts
    as correct when element 1 is greater than element 0 and the label is
    above 0.5, or when element 1 is less than element 0 and the label is
    below 0.5. Ties between the two elements, and labels equal to exactly
    0.5, never count as correct.

    Args:
        ar: Object exposing ``result_dir`` (directory holding the per-sample
            ``.npy`` result files) and ``label_path`` (path of the ``.npy``
            label file), e.g. an ``argparse.Namespace``.

    Returns:
        The accuracy as a float in ``[0, 1]``. The same value is printed as
        ``acc = <value>`` with three decimals.

    Raises:
        ValueError: If the label file contains no labels.
    """
    res_dir = ar.result_dir
    # Flatten once, up front: the sample count and the per-sample lookup then
    # refer to the same 1-D view regardless of how the labels were saved
    # (e.g. shape (N,), (N, 1) or (1, N)).
    labels = np.load(ar.label_path).flatten()
    total = labels.size
    if total == 0:
        raise ValueError(
            "no labels found in {!r}; cannot compute accuracy".format(
                ar.label_path))

    correct = 0
    for i, label in enumerate(labels):
        res = np.load(os.path.join(res_dir, f'{i}_0.npy')).flatten()
        picks_one = res[1] > res[0]
        picks_zero = res[1] < res[0]
        if (picks_one and label > 0.5) or (picks_zero and label < 0.5):
            correct += 1

    accuracy = correct / total
    print("acc = {:.3f}".format(accuracy))
    return accuracy
def main():
    """Parse the command-line options and run the accuracy post-processing.

    Both ``--result_dir`` and ``--label_path`` are mandatory string options;
    the parsed namespace is handed to ``om_post`` unchanged.
    """
    cli = argparse.ArgumentParser()
    # (flag, help text) for every mandatory path option.
    required_paths = (
        ("--result_dir", "infer result dir."),
        ("--label_path", "path for gt label."),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, default="", type=str, required=True,
                         help=description)
    om_post(cli.parse_args())
# Run the CLI only when executed as a script, so that importing this module
# has no side effects.
if __name__ == "__main__":
    main()