# Last commit by Benedikt Schifferer: "adding license"
# Commit 780d274b, created on 2025-06-28 (commit history)
# --------------------------------------------------------
# Copyright (c) 2025 NVIDIA
# Licensed under customized NSCLv1 [see LICENSE.md for details]
# --------------------------------------------------------

import os
import sys

import argparse
import json
import torch
import types
import pandas as pd

from typing import Annotated, Dict, List, Optional, cast

from datasets import load_dataset
from tqdm import tqdm

from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA, ViDoReEvaluatorBEIR
from vidore_benchmark.evaluation.interfaces import MetadataModel, ViDoReBenchmarkResults
from vidore_benchmark.utils.data_utils import get_datasets_from_collection
from typing import List, Optional, Union

from datetime import datetime
from importlib.metadata import version

import torch
from transformers import AutoModel

def get_args():
    """Parse the command-line arguments of the evaluation script.

    Known options (``--model_name_or_path``, ``--model_revision``,
    ``--batch_size``, ``--savedir_datasets``) are parsed with argparse.
    Every remaining token is interpreted as part of a ``--key value`` pair
    and collected into a dictionary; numeric-looking values are converted
    to ``int`` or ``float``.

    Returns:
        tuple: ``(args, extra_args_dict)`` where ``args`` is the argparse
        namespace of known options and ``extra_args_dict`` maps each extra
        option name (leading dashes stripped) to its converted value.

    Raises:
        SystemExit: raised through ``parser.error`` when the extra tokens
            cannot be grouped into ``--key value`` pairs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        help='Path to model checkpoint if HF',
        default=''
    )
    parser.add_argument(
        '--model_revision',
        type=str,
        help='Commit Hash of the model as custom code is downloaded and executed',
        default=None
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        help='Batch Size',
        default=1
    )
    parser.add_argument(
        '--savedir_datasets',
        type=str,
        help='Path to save results',
        default='./default/'
    )
    args, extra_args = parser.parse_known_args()

    def convert_value(value):
        """Return ``value`` as int/float when it looks numeric, else unchanged."""
        # Allow a single leading minus sign so negative numbers are converted too.
        unsigned = value[1:] if value.startswith('-') else value
        if unsigned.replace('.', '', 1).isdigit():
            try:
                return float(value) if '.' in value else int(value)
            except ValueError:
                # str.isdigit() accepts characters (e.g. superscripts) that
                # int()/float() reject; keep those values as plain strings.
                return value
        return value  # Keep as string if not numeric

    # Extras are consumed pairwise; an unpaired token used to surface as an
    # opaque IndexError, so report it as a regular command-line error instead.
    if len(extra_args) % 2 != 0:
        parser.error(
            f"extra arguments must be given as '--key value' pairs, got: {extra_args}"
        )

    # Convert extra_args list to dictionary with proper type conversion
    extra_args_dict = {
        key.lstrip('-'): convert_value(value)
        for key, value in zip(extra_args[0::2], extra_args[1::2])
    }

    return args, extra_args_dict

def _evaluate_with_cache(dataset_name: str, savedir_datasets: str, evaluator, load_ds, batch_size: int):
    """Return ``(metrics, results)`` for one dataset, reusing a saved file if present.

    The per-dataset results file is ``<savedir_datasets>/<name>_metrics.json``
    where ``<name>`` is ``dataset_name`` with ``/`` replaced by ``_``. When
    that file exists its content is loaded and no evaluation is run;
    otherwise the dataset is loaded, evaluated, and the results are written
    to that file.

    Args:
        dataset_name: Dataset identifier, used as the metrics key.
        savedir_datasets: Directory holding the per-dataset results files.
        evaluator: Object exposing ``evaluate_dataset(ds=..., ...)``.
        load_ds: Zero-argument callable returning the ``ds`` value passed to
            the evaluator; only called when no saved file exists.
        batch_size: Used for both ``batch_query`` and ``batch_passage``.

    Returns:
        tuple: ``(metrics, results)`` with ``metrics`` a dict keyed by
        dataset name and ``results`` a ``ViDoReBenchmarkResults`` instance.
    """
    sanitized_dataset_name = dataset_name.replace("/", "_")
    savepath_results = os.path.join(savedir_datasets, f"{sanitized_dataset_name}_metrics.json")

    if os.path.isfile(savepath_results):
        # Read with the same encoding used for writing, and close the handle.
        with open(savepath_results, "r", encoding="utf-8") as f:
            saved_results = json.load(f)
        metrics = saved_results['metrics']
        results = ViDoReBenchmarkResults(
            metadata=MetadataModel(
                timestamp=saved_results['metadata']['timestamp'],
                vidore_benchmark_version=saved_results['metadata']['vidore_benchmark_version'],
            ),
            metrics=metrics,
        )
        return metrics, results

    metrics = {dataset_name: evaluator.evaluate_dataset(
        ds=load_ds(),
        batch_query=batch_size,
        batch_passage=batch_size,
        batch_score=128,
        dataloader_prebatch_query=512,
        dataloader_prebatch_passage=512,
    )}
    results = ViDoReBenchmarkResults(
        metadata=MetadataModel(
            timestamp=datetime.now(),
            vidore_benchmark_version=version("vidore_benchmark"),
        ),
        metrics={dataset_name: metrics[dataset_name]},
    )
    with open(savepath_results, "w", encoding="utf-8") as f:
        f.write(results.model_dump_json(indent=4))
    return metrics, results


def main():
    """Evaluate a retriever on the ViDoRe V1 and V2 datasets and save the metrics.

    Loads the model named on the command line, evaluates every dataset
    (skipping those whose results file already exists in the output
    directory), prints nDCG@5 per dataset, and writes the merged results
    to ``merged_metrics.json`` in the output directory.
    """
    # add_args holds the extra "--key value" pairs; they are parsed but not
    # consumed anywhere in this script.
    args, add_args = get_args()
    batch_size = int(args.batch_size)
    savedir_datasets = args.savedir_datasets

    os.makedirs(savedir_datasets, exist_ok=True)

    vision_retriever = AutoModel.from_pretrained(
        args.model_name_or_path,
        device_map='cuda',
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        revision=args.model_revision
    ).eval()

    vidore_evaluator_qa = ViDoReEvaluatorQA(vision_retriever)  # ViDoRe-v1
    vidore_evaluator_beir = ViDoReEvaluatorBEIR(vision_retriever)  # ViDoRe-v2

    # ViDoRe V2 datasets mapped to the dataset revision (commit hash) to load.
    vidore_v2_original_commits = {
        "vidore/synthetic_rse_restaurant_filtered_v1.0_multilingual": "c05a1da867bbedebef239d4aa96cab19160b3d88",
        "vidore/synthetic_mit_biomedical_tissue_interactions_unfiltered_multilingual": "9daa25abc1026f812834ca9a6b48b26ecbc61317",
        "vidore/synthetics_economics_macro_economy_2024_filtered_v1.0_multilingual": "909aa23589332c30d7c6c9a89102fe2711cbb7a9",
        "vidore/restaurant_esg_reports_beir": "d8830ba2d04b285cfb2532b95be3748214e305da",
        "vidore/synthetic_rse_restaurant_filtered_v1.0": "4e52fd878318adb8799d0b6567f1134b3985b9d3",
        "vidore/synthetic_economics_macro_economy_2024_filtered_v1.0": "b6ff628a0b3c49f074abdcc86d29bc0ec21fd0c1",
        "vidore/synthetic_mit_biomedical_tissue_interactions_unfiltered": "c1b889b051113c41e32960cd6b7c5ba5b27e39e2",
    }

    metrics_all: Dict[str, Dict[str, Optional[float]]] = {}
    results_all: List[ViDoReBenchmarkResults] = []  # same as metrics_all but structured + with metadata

    # Evaluate ViDoRe V1 with QA Datasets
    dataset_names = get_datasets_from_collection("vidore/vidore-benchmark-667173f98e70a1c0fa4db00d")
    for dataset_name in tqdm(dataset_names, desc="Evaluating dataset(s)"):
        # The lambda is invoked within this iteration, so capturing the loop
        # variable is safe.
        metrics, results = _evaluate_with_cache(
            dataset_name,
            savedir_datasets,
            vidore_evaluator_qa,
            lambda: load_dataset(dataset_name, split="test"),
            batch_size,
        )
        metrics_all.update(metrics)
        print(f"nDCG@5 on {dataset_name}: {metrics[dataset_name]['ndcg_at_5']}")
        results_all.append(results)

    # Evaluate ViDoRe V2 with BEIR-style datasets (corpus / queries / qrels)
    for dataset_name, revision in vidore_v2_original_commits.items():
        metrics, results = _evaluate_with_cache(
            dataset_name,
            savedir_datasets,
            vidore_evaluator_beir,
            lambda: {
                "corpus": load_dataset(dataset_name, name="corpus", split="test", revision=revision),
                "queries": load_dataset(dataset_name, name="queries", split="test", revision=revision),
                "qrels": load_dataset(dataset_name, name="qrels", split="test", revision=revision),
            },
            batch_size,
        )
        metrics_all.update(metrics)
        print(f"nDCG@5 on {dataset_name}: {metrics[dataset_name]['ndcg_at_5']}")
        results_all.append(results)

    results_merged = ViDoReBenchmarkResults.merge(results_all)
    savepath_results_merged = os.path.join(savedir_datasets, "merged_metrics.json")

    with open(savepath_results_merged, "w", encoding="utf-8") as f:
        f.write(results_merged.model_dump_json(indent=4))


if __name__ == "__main__":
    main()