import os.path
from mindspeed_mm.data.datasets.mm_base_dataset import MMBaseDataset
class BaseGenEvalDataset(MMBaseDataset):
    """Common base for the generation-evaluation datasets in this module.

    Construction is delegated entirely to ``MMBaseDataset``; indexing returns
    the raw entry stored in ``self.data_samples``.
    """

    def __init__(self, basic_param: dict, **kwargs):
        """Forward ``basic_param`` and any extra keywords to the parent.

        Args:
            basic_param: Keyword arguments for ``MMBaseDataset.__init__``,
                unpacked as-is.
            **kwargs: Additional keyword arguments, also passed through.
        """
        super().__init__(**basic_param, **kwargs)

    def __getitem__(self, index):
        """Return the entry of ``self.data_samples`` at ``index``."""
        # ``data_samples`` is not assigned here; presumably the parent class
        # populates it -- confirm against MMBaseDataset.
        sample = self.data_samples[index]
        return sample
class DimCounter:
    """Tally how many times each dimension name has been seen.

    ``cnt`` maps a dimension name to the number of ``count`` calls that
    included it.
    """

    def __init__(self):
        # dimension name -> number of ``count`` calls that included it
        self.cnt = {}

    def count(self, dims):
        """Increase the tally of every distinct name in ``dims`` by one.

        Duplicates inside a single call are collapsed, so one call adds at
        most one to any given dimension.
        """
        for dim in set(dims):
            self.cnt[dim] = self.cnt.get(dim, 0) + 1

    def min_count(self, dims):
        """Return the smallest tally among ``dims``.

        A name that was never counted contributes ``0``. An empty ``dims``
        also yields ``0`` rather than raising ``ValueError`` from ``min()``
        on an empty sequence.
        """
        return min((self.cnt.get(dim, 0) for dim in set(dims)), default=0)
# Maps a VBench dimension name to the dimension name that ``prepare_dims``
# adds alongside it. Many entries map to themselves.
# NOTE(review): presumably a dimension is evaluated on videos generated from
# the mapped dimension's prompt file -- confirm against the VBench docs.
vbench_dim_proj = {
    "subject_consistency": "subject_consistency",
    "background_consistency": "scene",
    "aesthetic_quality": "overall_consistency",
    "imaging_quality": "overall_consistency",
    "object_class": "object_class",
    "multiple_objects": "multiple_objects",
    "color": "color",
    "spatial_relationship": "spatial_relationship",
    "scene": "scene",
    "temporal_style": "temporal_style",
    "overall_consistency": "overall_consistency",
    "human_action": "human_action",
    "temporal_flickering": "temporal_flickering",
    "motion_smoothness": "subject_consistency",
    "dynamic_degree": "subject_consistency",
    "appearance_style": "appearance_style"
}


def prepare_dims(dimensions):
    """Expand ``dimensions`` with their ``vbench_dim_proj`` targets.

    Args:
        dimensions: Sized collection of dimension names, or ``None``.

    Returns:
        A list without duplicates: the requested names in first-seen order,
        followed by any mapped names that were not already present, in the
        order their source dimensions appear. ``None`` or an empty input
        yields ``[]``. Names missing from ``vbench_dim_proj`` are kept as-is.
    """
    if dimensions is None or len(dimensions) == 0:
        return []
    # A dict is used as an ordered set. Going through ``set`` would make the
    # order depend on per-process string hashing, so it could differ between
    # runs or between processes.
    ordered = dict.fromkeys(dimensions)
    for dim in list(ordered):
        if dim in vbench_dim_proj:
            ordered.setdefault(vbench_dim_proj[dim], None)
    return list(ordered)
class VbenchGenEvalDataset(BaseGenEvalDataset):
    """Prompt dataset for generation evaluation, read from prompt text files.

    Every prompt is exposed ``samples_per_prompt`` times; item ``index`` maps
    to prompt ``index // samples_per_prompt`` and repetition
    ``index % samples_per_prompt``.

    ``self.data_folder`` and ``self.get_data`` are not defined in this class;
    presumably they come from the base dataset -- confirm against
    MMBaseDataset.
    """

    def __init__(self,
                 basic_param: dict,
                 extra_param: dict,
                 dimensions: list = None,
                 ):
        """Load prompts and, when requested, their augmented counterparts.

        Args:
            basic_param: Keyword arguments forwarded to the base dataset.
            extra_param: Options read with ``dict.get``:
                ``augment`` (default ``False``) -- use augmented prompts as
                the caption;
                ``prompts_per_dim`` (default ``0``) -- keep at most this many
                prompts per dimension file, values ``<= 0`` mean no limit;
                ``samples_per_prompt`` (default ``5``);
                ``prompt_file`` / ``augmented_prompt_file`` -- files used
                when no dimensions are given.
            dimensions: Dimension names to restrict to. They are expanded
                with ``prepare_dims``. ``None`` or empty selects the
                all-dimension files.

        Raises:
            ValueError: If ``augment`` is on and the number of augmented
                prompts differs from the number of prompts.
        """
        super().__init__(basic_param)
        self.dimensions = prepare_dims(dimensions)
        self.augment = extra_param.get("augment", False)
        self.prompts_per_dim = extra_param.get("prompts_per_dim", 0)
        self.samples_per_prompt = extra_param.get("samples_per_prompt", 5)
        self.prompt_file = extra_param.get("prompt_file", "all_dimension.txt")
        self.augmented_prompt_file = extra_param.get(
            "augmented_prompt_file",
            "augmented_prompts/gpt_enhanced_prompts/all_dimension_longer.txt")
        self.captions = []
        self.augmented_captions = []

        if self.dimensions:
            # Per-dimension files: a missing file is skipped, and each file
            # is truncated to ``prompts_per_dim`` when that is positive.
            for dim in self.dimensions:
                self.captions += self._load_prompts(
                    os.path.join("prompts_per_dimension", f"{dim}.txt"),
                    skip_missing=True,
                    limit=self.prompts_per_dim,
                )
            if self.augment:
                for dim in self.dimensions:
                    self.augmented_captions += self._load_prompts(
                        os.path.join("augmented_prompts/gpt_enhanced_prompts/prompts_per_dimension_longer",
                                     f"{dim}_longer.txt"),
                        skip_missing=True,
                        limit=self.prompts_per_dim,
                    )
        else:
            # No dimension filter: read the combined files, without an
            # existence check and without truncation.
            self.captions += self._load_prompts(self.prompt_file)
            if self.augment:
                self.augmented_captions += self._load_prompts(self.augmented_prompt_file)

        # ``__getitem__`` pairs captions[i] with augmented_captions[i]. If the
        # two lists differ in length, that pairing is either out of step or
        # runs off the end, so fail here with a clear message.
        if self.augment and len(self.augmented_captions) != len(self.captions):
            raise ValueError(
                f"Loaded {len(self.augmented_captions)} augmented prompts but "
                f"{len(self.captions)} prompts from '{self.data_folder}'; the two "
                f"must line up one-to-one when 'augment' is enabled."
            )

    def _load_prompts(self, relative_path, skip_missing=False, limit=0):
        """Read one prompt file located under ``self.data_folder``.

        Args:
            relative_path: File path relative to ``self.data_folder``.
            skip_missing: If true, return ``[]`` when the file does not exist
                instead of calling ``get_data``.
            limit: Keep only the first ``limit`` entries when positive.

        Returns:
            Whatever ``self.get_data`` returns for the file, optionally
            sliced to ``limit`` entries.
        """
        full_path = os.path.join(self.data_folder, relative_path)
        if skip_missing and not os.path.exists(full_path):
            return []
        prompts = self.get_data(full_path)
        return prompts[:limit] if limit > 0 else prompts

    def prepare_item(self):
        """Return an empty item template with ``caption`` and ``prefix`` keys."""
        return {
            "caption": "",
            "prefix": "",
        }

    def __getitem__(self, index):
        """Build the item for flat ``index``.

        ``prefix`` is always the original prompt followed by
        ``-<repetition>``. ``caption`` is the augmented prompt when
        ``augment`` is on, otherwise the original prompt.
        """
        caption_index = index // self.samples_per_prompt
        sample_index = index % self.samples_per_prompt
        item = self.prepare_item()
        item["prefix"] = self.captions[caption_index] + f"-{sample_index}"
        item["caption"] = self.augmented_captions[caption_index] if self.augment else self.captions[caption_index]
        return item

    def __len__(self):
        """Return the number of prompts times ``samples_per_prompt``."""
        return len(self.captions) * self.samples_per_prompt
class VbenchI2VGenEvalDataset(BaseGenEvalDataset):
    """Image-to-video evaluation dataset backed by ``self.data_samples``.

    Each kept sample is exposed ``samples_per_prompt`` times. The entries of
    ``data_samples`` are read through the keys ``"dimension"``,
    ``"prompt_en"`` and ``"image_name"``.
    """

    def __init__(self,
                 basic_param: dict,
                 extra_param: dict,
                 dimensions: list,
                 ):
        """Initialise the base dataset, read options, then filter samples.

        Args:
            basic_param: Keyword arguments forwarded to the base dataset.
            extra_param: Options read with ``dict.get``: ``ratio`` (default
                ``"16-9"``), ``prompts_per_dim`` (default ``0``) and
                ``samples_per_prompt`` (default ``5``).
            dimensions: Dimension names to keep; ``None`` or empty keeps
                every sample.
        """
        super().__init__(basic_param)
        self.ratio = extra_param.get("ratio", "16-9")
        self.prompts_per_dim = extra_param.get("prompts_per_dim", 0)
        self.samples_per_prompt = extra_param.get("samples_per_prompt", 5)
        self.filter_by_dimension(dimensions)

    def prepare_item(self):
        """Return an empty item template with caption, prefix and image keys."""
        return {
            "caption": "",
            "prefix": "",
            "image": "",
        }

    def filter_by_dimension(self, dimensions):
        """Replace ``self.data_samples`` with the samples matching ``dimensions``.

        A sample matches when at least one of its ``"dimension"`` entries is
        requested. When ``prompts_per_dim`` is positive, a matching sample is
        dropped once the least-counted of its dimensions has already reached
        that number. Nothing happens for ``None`` or empty ``dimensions``.
        """
        if dimensions is None or len(dimensions) == 0:
            return

        wanted = set(dimensions)
        quota = self.prompts_per_dim
        tally = DimCounter()
        kept = []
        for sample in self.data_samples:
            sample_dims = sample.get("dimension")
            if not any(dim in wanted for dim in sample_dims):
                continue
            # NOTE(review): the quota check looks at every dimension the
            # sample lists, including ones that were not requested.
            if quota > 0 and tally.min_count(sample_dims) >= quota:
                continue
            kept.append(sample)
            tally.count(sample_dims)
        self.data_samples = kept

    def __getitem__(self, index):
        """Build the item for flat ``index``.

        ``caption`` is the sample's ``"prompt_en"``, ``prefix`` is that
        prompt followed by ``-<repetition>``, and ``image`` is the path
        ``<data_folder>/crop/<ratio>/<image_name>``.
        """
        caption_index, sample_index = divmod(index, self.samples_per_prompt)
        data = self.data_samples[caption_index]
        prompt = data.get("prompt_en")
        item = self.prepare_item()
        item.update({
            "caption": prompt,
            "prefix": prompt + f"-{sample_index}",
            "image": os.path.join(self.data_folder, "crop", self.ratio, data.get("image_name")),
        })
        return item

    def __len__(self):
        """Return the number of kept samples times ``samples_per_prompt``."""
        return self.samples_per_prompt * len(self.data_samples)