import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from mindspeed_mm.tasks.evaluation.utils.string_utils import string_to_list, is_expected_type
from mindspeed_mm.tasks.evaluation.utils.file_utils import is_valid_image, decode_base64_to_image_file
# Lookup table from benchmark dataset name to a task-type tag.
# NOTE(review): "MCQ" / "VQA" presumably stand for multiple-choice and visual
# question answering -- confirm against the code that consumes this mapping.
datasets_type = {
    "ai2d_test": "MCQ",
    "docvqa_val": "VQA",
    "chartqa_test": "VQA",
    "mmmu_dev_val": "MCQ",
}
class BaseEvalDataset(Dataset):
    """Evaluation dataset backed by a tab-separated file.

    The TSV is loaded into a pandas DataFrame (``self.data``). Rows may carry
    image payloads in an ``image`` column, image file names in an
    ``image_path`` column, or both. When no ``image`` column is present the
    dataset stays in "meta only" mode (``self.meta_only`` is True) and
    prompts are built from ``image_path`` alone.
    """

    def __init__(self, dataset_path, dataset_name):
        """Load the TSV at *dataset_path* and normalise its image columns.

        Args:
            dataset_path: Path of the TSV file to load.
            dataset_name: Dataset name; also used as the sub-directory
                (next to the TSV) under which decoded images are stored.
        """
        # Stays True unless prepare_image() finds an 'image' column.
        self.meta_only = True
        self.dataset_name = dataset_name
        self.dataset_type = None
        # Decoded images go to <dir of TSV>/<dataset_name>/images.
        self.image_path = os.path.join(os.path.split(dataset_path)[0], dataset_name, 'images')
        data = self.prepare_tsv(dataset_path)
        self.data = self.prepare_image(data)

    def __len__(self):
        """Return the number of rows kept after loading."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the row at position *idx* as a plain ``dict``."""
        return dict(self.data.iloc[idx])

    @staticmethod
    def prepare_tsv(data_path):
        """Read the TSV and return a DataFrame with string ``index`` values.

        Rows whose ``image`` cell is missing (NaN) are dropped. A file that
        has no ``image`` column at all is accepted unchanged, so datasets
        that only provide ``image_path`` can be loaded.

        Args:
            data_path: Path (or file-like object) accepted by
                ``pandas.read_csv``.

        Returns:
            The loaded ``pandas.DataFrame``.

        Raises:
            KeyError: If the file has no ``index`` column.
        """
        data = pd.read_csv(data_path, sep='\t')
        if 'image' in data:
            # Explicit copy so the column assignment below works on an
            # independent frame instead of a slice of the original
            # (avoids pandas' SettingWithCopyWarning).
            data = data[~pd.isna(data['image'])].copy()
        data['index'] = [str(x) for x in data['index']]
        return data

    def prepare_image(self, data):
        """Normalise the ``image``, ``image_path`` and ``index`` columns.

        *data* is modified in place and also returned.

        * ``image``: a value of at most 64 characters is treated as the
          ``index`` of another row whose image should be reused (presumably
          real payloads are base64 strings and therefore much longer --
          TODO confirm). Each value is then passed through
          ``string_to_list``; single-element results are unwrapped.
          Finding this column switches ``self.meta_only`` to False.
        * ``image_path``: passed through ``string_to_list`` and unwrapped
          the same way.
        * ``index``: converted to ``int`` when every value passes
          ``is_expected_type(x, int)``.

        Raises:
            ValueError: If a short ``image`` value points at an index that
                does not exist or whose own value is also short.
        """
        if 'image' in data:
            data['image'] = [str(x) for x in data['image']]
            image_map = dict(zip(data['index'], data['image']))
            # Only values are reassigned while iterating; the key set is
            # never changed, so iterating the dict here is safe.
            for k in image_map:
                if len(image_map[k]) <= 64:
                    idx = image_map[k]
                    if idx not in image_map or len(image_map.get(idx, "")) <= 64:
                        raise ValueError(
                            f"Key {k} maps to a value of length {len(image_map[k])}, but the target key {idx} "
                            f"is either not found or has a value of length {len(image_map.get(idx, ''))}.")
                    image_map[k] = image_map[idx]
            images = [string_to_list(image_map[k]) for k in data['index']]
            data['image'] = [x[0] if len(x) == 1 else x for x in images]
            self.meta_only = False
        if 'image_path' in data:
            paths = [string_to_list(x) for x in data['image_path']]
            data['image_path'] = [x[0] if len(x) == 1 else x for x in paths]
        if all(is_expected_type(x, int) for x in data['index']):
            data['index'] = [int(x) for x in data['index']]
        return data

    def dump_image(self, line):
        """Make sure the images of one row exist on disk and return their paths.

        Args:
            line: One dataset row (mapping-like; supports ``in`` and
                ``[]`` lookups by column name).

        Returns:
            list: File paths for the row. When the row has no ``image``
            entry, this is ``string_to_list(line['image_path'])``.

        Raises:
            ValueError: If ``image`` is a list but ``image_path`` is absent.
            AssertionError: If neither ``image`` nor ``image_path`` is
                present.
        """
        os.makedirs(self.image_path, exist_ok=True)
        if 'image' in line:
            if isinstance(line['image'], list):
                tgt_path = []
                if 'image_path' not in line:
                    raise ValueError("The required key 'image_path' is missing from the provided data.")
                # NOTE: zip() stops at the shorter of the two sequences.
                for img, im_name in zip(line['image'], line['image_path']):
                    path = os.path.join(self.image_path, im_name)
                    # Skip decoding when a valid image is already there.
                    if not is_valid_image(path):
                        decode_base64_to_image_file(img, path)
                    tgt_path.append(path)
            else:
                # Single image: the file is named after the row index.
                tgt_path = os.path.join(self.image_path, f"{line['index']}.jpg")
                if not is_valid_image(tgt_path):
                    decode_base64_to_image_file(line['image'], tgt_path)
                tgt_path = [tgt_path]
        else:
            if 'image_path' not in line:
                raise AssertionError("Either image or image_path must be non-empty.")
            tgt_path = string_to_list(line['image_path'])
        return tgt_path

    def build_prompt(self, line):
        """Build the message list (images first, then the question) for a row.

        Args:
            line: Either a row (mapping-like) or an integer position into
                ``self.data``. NumPy integer scalars are accepted as
                positions as well as built-in ``int``.

        Returns:
            list[dict]: One ``{'type': 'image', 'value': path}`` entry per
            image followed by a single ``{'type': 'text', 'value':
            question}`` entry.
        """
        if isinstance(line, (int, np.integer)):
            line = self.data.iloc[line]
        if self.meta_only:
            tgt_path = string_to_list(line['image_path'])
        else:
            tgt_path = self.dump_image(line)
        question = line['question']
        if isinstance(tgt_path, list):
            msgs = [dict(type='image', value=p) for p in tgt_path]
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=question))
        return msgs