import argparse
import ast
import os
import subprocess
import sys
def _extract_result_dict(output):
    """Return the last dict literal printed in *output*, or ``None``.

    Lines are scanned from the end. A line is a candidate when, after
    stripping whitespace, it starts with ``{`` and ends with ``}``. Candidates
    that fail to parse as a Python literal, or that parse to something other
    than a ``dict`` (e.g. a set literal), are skipped so that one malformed
    line cannot abort the whole batch.
    """
    for raw_line in reversed(output.splitlines()):
        candidate = raw_line.strip()
        if not (candidate.startswith('{') and candidate.endswith('}')):
            continue
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # These are the exceptions ast.literal_eval documents for
            # malformed input; treat the line as "not a result" and keep
            # looking at earlier lines.
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def main():
    """Run ``standalone_cgf1.py`` for every subset and print the cgF1 scores.

    Command-line options:
        --gt_base   Directory containing the ground-truth annotation files.
        --pred_dir  Directory containing the ``predictions_<subset>.json`` files.

    For each subset whose prediction file exists, the evaluation script is
    launched as a child process, the last dict literal on its stdout is
    parsed, and the value under ``'cgF1_eval_segm_cgF1'`` is recorded.
    Subsets that are missing, fail to run, or produce unusable output are
    reported and skipped. Finally the mean over the successfully evaluated
    subsets is printed.
    """
    parser = argparse.ArgumentParser(description='批量计算 SAM3 预测结果的 cgF1 指标')
    parser.add_argument('--gt_base', type=str, default='./dataset/gt-annotations',
                        help='Ground truth 标注文件所在目录')
    parser.add_argument('--pred_dir', type=str, default='./gold_predictions',
                        help='预测结果 JSON 文件所在目录')
    args = parser.parse_args()
    gt_base = args.gt_base
    pred_dir = args.pred_dir

    subsets = ["metaclip", "sa1b", "crowded", "fg_food", "fg_sports_equipment", "attributes", "wiki_common"]
    versions = ["a", "b", "c"]

    # Launch the child with the interpreter running this script so it sees the
    # same environment/packages; a bare "python" may resolve to a different
    # interpreter or not exist at all. sys.executable can be empty in unusual
    # embedded setups, hence the fallback.
    interpreter = sys.executable or "python"

    results = {}
    for subset in subsets:
        pred_file = os.path.join(pred_dir, f"predictions_{subset}.json")
        if not os.path.exists(pred_file):
            print(f"警告: 预测文件不存在 {pred_file}")
            continue

        gt_files = [os.path.join(gt_base, f"gold_{subset}_merged_{v}_release_test.json") for v in versions]
        cmd = [interpreter, "standalone_cgf1.py", "--pred_file", pred_file, "--gt_files", *gt_files]
        print(f"运行: {' '.join(cmd)}")

        # Only the subprocess call can raise CalledProcessError, so keep the
        # try body limited to it.
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as e:
            print(f"{subset}: 运行失败, 错误码 {e.returncode}")
            print(e.stderr)
            continue

        output = proc.stdout.strip()
        data = _extract_result_dict(output)
        if data is None:
            print(f"{subset}: 未找到 JSON 输出")
            # Show the tail of the output to help diagnose what was printed.
            print(output[-500:])
            continue

        cgf1 = data.get('cgF1_eval_segm_cgF1')
        if cgf1 is None:
            print(f"{subset}: 未找到 cgF1 键")
            continue
        # A non-numeric value would crash the ':.6f' format below and the
        # final average; bool is excluded because it is an int subclass.
        if isinstance(cgf1, bool) or not isinstance(cgf1, (int, float)):
            print(f"{subset}: cgF1 值不是数字: {cgf1!r}")
            continue

        results[subset] = cgf1
        print(f"{subset}: {cgf1:.6f}")

    if results:
        avg = sum(results.values()) / len(results)
        print(f"\n平均 cgF1: {avg:.6f}")
    else:
        print("\n没有成功评估的子集")
# Run the batch evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()