import multiprocessing as mp
import os
from functools import partial

import numpy as np
import pandas as pd
import torch.distributed as dist

from mindspeed_mm.tasks.evaluation.eval_impl.impl_base import BaseEvalImpl
from mindspeed_mm.tasks.evaluation.utils.analysis_utils import hit_calculate, process_line
from mindspeed_mm.tasks.evaluation.utils.string_utils import is_list_in_str, dict2dataframe, logger_rank_0
from mindspeed_mm.utils.security_utils.input_filter import sanitize_dataframe
class VQAEvalImpl(BaseEvalImpl):
    """Evaluation implementation that scores VQA-style prediction files.

    After the base evaluation has run, the gathered result spreadsheet is
    read back on rank 0, every row is scored with ``process_line``, and the
    mean hit rates (overall, per split, per category) are written to a
    sibling ``*_acc.xlsx`` file.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseEvalImpl``."""
        super().__init__(**kwargs)

    def __call__(self):
        """Run the base evaluation, then gather and analyse the results."""
        super().__call__()
        self.gather_result()
        self.analyse_result()

    @staticmethod
    def _mean_hit_percent(results, dataset):
        """Return the mean of ``hit_calculate(results, dataset)`` as a percentage."""
        return np.mean(hit_calculate(results, dataset)) * 100

    def analyse_result(self):
        """Score the gathered predictions and save an accuracy spreadsheet.

        Only rank 0 reads ``self.result_path`` and writes the report; the
        barriers before and after keep the other ranks in step with it.

        Raises:
            ValueError: If the result file lacks an ``answer`` or
                ``prediction`` column.
        """
        # NOTE(review): `> 0` looks like it is always true for a valid world
        # size; confirm whether `> 1` was intended. Kept as-is because
        # dist.get_rank() below needs an initialised process group anyway.
        if self.world_size > 0:
            dist.barrier()

        if dist.get_rank() == 0:
            data = pd.read_excel(self.result_path)
            dataset = self.dataset_name
            if 'answer' not in data or 'prediction' not in data:
                raise ValueError("Data must contain both 'answer' and 'prediction' keys.")

            # Normalise both columns to strings before scoring.
            data['prediction'] = [str(x) for x in data['prediction']]
            data['answer'] = [str(x) for x in data['answer']]
            lines = [data.iloc[i] for i in range(len(data))]

            # Choose the scoring method from the dataset name.
            if is_list_in_str(['chartqa'], dataset):
                scorer = partial(process_line, method='relaxed_accuracy')
            elif is_list_in_str(['docvqa', 'infovqa'], dataset):
                scorer = partial(process_line, method='anls')
            else:
                scorer = process_line

            # The context manager tears the worker down even when scoring
            # raises; map() has already collected every result by then.
            with mp.Pool(1) as pool:
                res = pool.map(scorer, lines)

            overall = self._mean_hit_percent(res, dataset)
            ret = dict()
            if 'split' in data:
                # Per-split columns come first, followed by the overall score.
                for sp in set(data['split']):
                    sub = [result for line, result in zip(lines, res) if line['split'] == sp]
                    ret[sp] = self._mean_hit_percent(sub, dataset)
            ret['Overall'] = overall

            if 'category' in data:
                for c in sorted(set(data['category'])):
                    sub = [result for line, result in zip(lines, res) if line['category'] == c]
                    ret[c] = self._mean_hit_percent(sub, dataset)

            ret = dict2dataframe(ret)
            # DataFrame.round returns a new frame, so the result must be kept.
            ret = ret.round(2)

            # Replace only the file extension; a plain str.replace would also
            # rewrite matching text in directory names and could return the
            # input path unchanged when the file has no extension.
            root, _ = os.path.splitext(self.result_path)
            result_file = f'{root}_acc.xlsx'

            cleaned_ret = sanitize_dataframe(ret)
            cleaned_ret.to_excel(result_file, index=False, engine='xlsxwriter')
            logger_rank_0(f"save acc file to {result_file}")

        if self.world_size > 0:
            dist.barrier()