# Copyright 2026 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import subprocess

import ast

import os

import argparse





def _parse_result_dict(output):
    """Return the last dictionary printed in *output*, or ``None``.

    Lines are scanned from the end. A candidate is any line that, once
    stripped, starts with ``{`` and ends with ``}``. Each candidate is parsed
    as a Python literal first and as JSON second. Candidates that fail to
    parse, or that parse to something other than a ``dict`` (for example a
    set literal), are skipped instead of raising.

    Args:
        output: Captured standard output of the evaluator process.

    Returns:
        The parsed ``dict`` of the last parseable candidate line, or ``None``
        when no line yields a dictionary.
    """
    import json  # Local import: only needed for the JSON fallback.

    for raw_line in reversed(output.split('\n')):
        candidate = raw_line.strip()
        if not (candidate.startswith('{') and candidate.endswith('}')):
            continue
        for parse in (ast.literal_eval, json.loads):
            try:
                parsed = parse(candidate)
            except (ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError):
                # Not valid in this format; try the next parser or line.
                continue
            if isinstance(parsed, dict):
                return parsed
    return None


def main():
    """Run ``standalone_cgf1.py`` for every subset and print the cgF1 scores.

    Command-line options:
        --gt_base: Directory holding the ground-truth annotation files
            (default ``./dataset/gt-annotations``).
        --pred_dir: Directory holding the ``predictions_<subset>.json`` files
            (default ``./gold_predictions``).

    For each subset whose prediction file exists, the evaluator script is run
    as a child process with the three ground-truth files (versions ``a``,
    ``b``, ``c``). The ``cgF1_eval_segm_cgF1`` value is read from the last
    dictionary line of its standard output. A subset that fails (missing
    file, non-zero exit status, process that cannot be started, unparseable
    output, missing or non-numeric score) is reported and skipped; it never
    aborts the remaining subsets. Finally the mean over all successfully
    evaluated subsets is printed.

    Returns:
        None. All results are written to standard output.
    """
    import sys  # Local import: used only to locate the running interpreter.

    parser = argparse.ArgumentParser(description='批量计算 SAM3 预测结果的 cgF1 指标')
    parser.add_argument('--gt_base', type=str, default='./dataset/gt-annotations',
                        help='Ground truth 标注文件所在目录')
    parser.add_argument('--pred_dir', type=str, default='./gold_predictions',
                        help='预测结果 JSON 文件所在目录')
    args = parser.parse_args()

    gt_base = args.gt_base
    pred_dir = args.pred_dir

    subsets = ["metaclip", "sa1b", "crowded", "fg_food", "fg_sports_equipment", "attributes", "wiki_common"]
    versions = ["a", "b", "c"]

    # Use the interpreter that is running this script rather than whatever
    # "python" resolves to on PATH (which may be missing or a different
    # environment). Fall back to "python" if the interpreter path is unknown.
    python_exe = sys.executable or "python"

    results = {}
    for subset in subsets:
        pred_file = os.path.join(pred_dir, f"predictions_{subset}.json")
        if not os.path.exists(pred_file):
            print(f"警告: 预测文件不存在 {pred_file}")
            continue

        gt_files = [os.path.join(gt_base, f"gold_{subset}_merged_{v}_release_test.json") for v in versions]
        # The script path is relative, so it is resolved against the current
        # working directory of this process.
        cmd = [python_exe, "standalone_cgf1.py", "--pred_file", pred_file, "--gt_files", *gt_files]
        print(f"运行: {' '.join(cmd)}")

        # Keep the try body minimal: only the process launch can raise here.
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as e:
            print(f"{subset}: 运行失败, 错误码 {e.returncode}")
            print(e.stderr)
            continue
        except OSError as e:
            # The interpreter itself could not be started.
            print(f"{subset}: 无法启动评估进程: {e}")
            continue

        output = proc.stdout.strip()
        data = _parse_result_dict(output)
        if data is None:
            print(f"{subset}: 未找到 JSON 输出")
            # Debug aid: show the tail of the evaluator's output.
            print(output[-500:])
            continue

        cgf1 = data.get('cgF1_eval_segm_cgF1')
        if cgf1 is None:
            print(f"{subset}: 未找到 cgF1 键")
        elif isinstance(cgf1, bool) or not isinstance(cgf1, (int, float)):
            # A non-numeric value would break both the formatting below and
            # the average, so report it and leave the subset out.
            print(f"{subset}: cgF1 值不是数字: {cgf1!r}")
        else:
            results[subset] = cgf1
            print(f"{subset}: {cgf1:.6f}")

    if results:
        avg = sum(results.values()) / len(results)
        print(f"\n平均 cgF1: {avg:.6f}")
    else:
        print("\n没有成功评估的子集")



# Entry point: run the batch evaluation only when executed as a script.
if __name__ == "__main__":
    main()