import os
import random
from typing import Union
import copy
import torch
from mindspeed_mm.data.data_utils.constants import (
CAPTIONS,
FILE_INFO,
IMG_FPS,
PROMPT_IDS,
PROMPT_MASK,
TEXT,
VIDEO,
MASKED_VIDEO
)
from mindspeed_mm.data.data_utils.data_transform import (
MaskGenerator,
add_aesthetic_notice_image,
add_aesthetic_notice_video,
)
from mindspeed_mm.data.data_utils.transform_pipeline import get_transforms
from mindspeed_mm.data.datasets.t2v_dataset import T2VDataset
from mindspeed_mm.utils.mask_utils import STR_TO_TYPE, MaskProcessor
def type_ratio_normalize(mask_type_ratio_dict):
    """Rescale the ratio values of a mapping so that they sum to one.

    Args:
        mask_type_ratio_dict: Mapping from a mask type to its numeric weight.

    Returns:
        A new dict with the same keys. Every weight is divided by the sum of
        all weights; when that sum equals zero each key receives the equal
        share ``1.0 / len(mask_type_ratio_dict)`` instead. The input mapping
        is not modified.
    """
    weight_sum = sum(mask_type_ratio_dict.values())
    entry_count = len(mask_type_ratio_dict)

    if weight_sum != 0:
        return {
            mask_type: weight / weight_sum
            for mask_type, weight in mask_type_ratio_dict.items()
        }

    # Zero total: fall back to a uniform split. The division stays inside the
    # comprehension so that an empty mapping yields {} rather than dividing
    # by zero.
    return {mask_type: 1.0 / entry_count for mask_type in mask_type_ratio_dict}
# Template of the per-sample output dict: every field starts as its own empty
# list. I2VDataset.getitem deep-copies this template before filling it in, so
# the module-level object itself is never handed out to callers.
I2VOutputData = {
    field: []
    for field in (VIDEO, MASKED_VIDEO, TEXT, PROMPT_IDS, PROMPT_MASK)
}
class I2VDataset(T2VDataset):
    """Image-to-video dataset built on top of ``T2VDataset``.

    Each sample is passed through a ``MaskProcessor``; the transformed pixels,
    the transformed masked pixels and the mask are concatenated into a single
    tensor that is stored under the ``VIDEO`` key of the returned dict.
    """

    def __init__(
        self,
        basic_param: dict,
        vid_img_process: dict,
        use_text_processer: bool = False,
        use_clean_caption: bool = True,
        support_chinese: bool = False,
        tokenizer_config: Union[dict, None] = None,
        vid_img_fusion_by_splicing: bool = False,
        use_img_num: int = 0,
        use_img_from_vid: bool = True,
        mask_type_ratio_dict_video: Union[dict, None] = None,
        mask_type_ratio_dict_image: Union[dict, None] = None,
        default_text_ratio: float = 0.5,
        min_clear_ratio: float = 0.0,
        max_clear_ratio: float = 1.0,
        **kwargs,
    ):
        """Set up mask sampling and post-resize transforms, then init the parent.

        Args:
            basic_param: Forwarded unchanged to ``T2VDataset.__init__``.
            vid_img_process: Processing config. ``num_frames``, ``max_height``
                and ``max_width`` are read here; ``train_pipeline_after_resize``
                is popped (removed from the passed dict) before the dict is
                forwarded to ``T2VDataset.__init__``.
            use_text_processer: Forwarded to the parent; also gates
                tokenization in ``getitem``.
            use_clean_caption: Forwarded unchanged to the parent.
            support_chinese: Forwarded unchanged to the parent.
            tokenizer_config: Forwarded unchanged to the parent.
            vid_img_fusion_by_splicing: Forwarded unchanged to the parent.
            use_img_num: Forwarded unchanged to the parent.
            use_img_from_vid: Forwarded unchanged to the parent.
            mask_type_ratio_dict_video: Mask-type name -> weight used for video
                samples; defaults to ``{'i2v': 1.0}``. Keys must exist in
                ``STR_TO_TYPE``.
            mask_type_ratio_dict_image: Mask-type name -> weight used for image
                samples; defaults to ``{'clear': 1.0}``. Keys must exist in
                ``STR_TO_TYPE``.
            default_text_ratio: Stored on the instance as
                ``self.default_text_ratio``.
            min_clear_ratio: Forwarded to ``MaskProcessor``.
            max_clear_ratio: Forwarded to ``MaskProcessor``.
            **kwargs: Forwarded unchanged to the parent.

        Raises:
            KeyError: If a mask-type name is not a key of ``STR_TO_TYPE``.
        """
        # The video ratios are only prepared when more than a single frame is
        # requested; with num_frames == 1 the attribute is left undefined.
        if vid_img_process.get("num_frames") != 1:
            video_ratios = mask_type_ratio_dict_video if mask_type_ratio_dict_video is not None else {'i2v': 1.0}
            video_ratios = {STR_TO_TYPE[k]: v for k, v in video_ratios.items()}
            self.mask_type_ratio_dict_video = type_ratio_normalize(video_ratios)

        image_ratios = mask_type_ratio_dict_image if mask_type_ratio_dict_image is not None else {'clear': 1.0}
        image_ratios = {STR_TO_TYPE[k]: v for k, v in image_ratios.items()}
        self.mask_type_ratio_dict_image = type_ratio_normalize(image_ratios)

        self.mask_processor = MaskProcessor(
            max_height=vid_img_process.get("max_height"),
            max_width=vid_img_process.get("max_width"),
            min_clear_ratio=min_clear_ratio,
            max_clear_ratio=max_clear_ratio,
        )

        # NOTE: pop() removes the key from the dict object the caller passed in,
        # so the parent constructor below no longer sees it.
        self.train_pipeline_after_resize = vid_img_process.pop("train_pipeline_after_resize", None)
        self.video_transforms_after_resize = get_transforms(is_video=True, train_pipeline=self.train_pipeline_after_resize)
        self.image_transforms_after_resize = get_transforms(
            is_video=False, train_pipeline=self.train_pipeline_after_resize
        )

        self.default_text_ratio = default_text_ratio

        super().__init__(
            basic_param=basic_param,
            vid_img_process=vid_img_process,
            use_text_processer=use_text_processer,
            use_clean_caption=use_clean_caption,
            support_chinese=support_chinese,
            tokenizer_config=tokenizer_config,
            vid_img_fusion_by_splicing=vid_img_fusion_by_splicing,
            use_img_num=use_img_num,
            use_img_from_vid=use_img_from_vid,
            **kwargs,
        )

    def getitem(self, index):
        """Load sample ``index`` and assemble its model inputs.

        Args:
            index: Position of the sample in ``self.data_samples``.

        Returns:
            A deep copy of ``I2VOutputData`` in which ``VIDEO`` holds the
            concatenation of transformed pixels, transformed masked pixels and
            mask, ``FILE_INFO`` holds the resolved file path, and
            ``PROMPT_IDS`` / ``PROMPT_MASK`` are filled when
            ``self.use_text_processer`` is set. The remaining keys keep their
            empty-list defaults.

        Raises:
            NotImplementedError: If ``self.data_storage_mode`` is neither
                "combine" nor "standard", or if the file is neither an image
                nor a video.
        """
        examples = copy.deepcopy(I2VOutputData)

        if self.data_storage_mode == "combine":
            sample = self.data_samples[index]
            file_path = sample["path"]
            texts = sample["cap"]
        elif self.data_storage_mode == "standard":
            sample = self.data_samples[index]
            file_path, texts = sample[FILE_INFO], sample[CAPTIONS]
            if self.data_folder:
                file_path = os.path.join(self.data_folder, file_path)
        else:
            raise NotImplementedError(
                f"Not support now: data_storage_mode={self.data_storage_mode}."
            )

        file_type = self.get_type(file_path)
        if file_type == "image":
            video_value = self.image_processer(file_path)
            video_value = video_value.transpose(0, 1)
            transforms_after_resize = self.image_transforms_after_resize
            # Images use their own mask-type ratios; the image dict is always
            # built in __init__, whereas the video one may not exist.
            mask_type_ratio_dict = self.mask_type_ratio_dict_image
        elif file_type == "video":
            vframes = self.video_reader(file_path)
            video_value = self.video_processer(vframes=vframes, **sample)
            if self.vid_img_fusion_by_splicing:
                video_value = self.get_vid_img_fusion(video_value)
            # Swap the first two axes of the 4-D tensor
            # (presumably C,T,H,W -> T,C,H,W; confirm against video_processer).
            video_value = video_value.permute(1, 0, 2, 3)
            transforms_after_resize = self.video_transforms_after_resize
            mask_type_ratio_dict = self.mask_type_ratio_dict_video
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise NotImplementedError(
                f"Not support now: file_type={file_type}."
            )

        inpaint_cond_data = self.mask_processor(video_value, mask_type_ratio_dict=mask_type_ratio_dict)
        mask, masked_video = inpaint_cond_data['mask'], inpaint_cond_data['masked_pixel_values']

        video_value = transforms_after_resize(video_value)
        masked_video = transforms_after_resize(masked_video)

        # Fuse pixels, masked pixels and mask along dim 1, then swap the first
        # two axes back before storing the result.
        video_value = torch.cat([video_value, masked_video, mask], dim=1)
        video_value = video_value.transpose(0, 1)
        examples[VIDEO] = video_value

        # With several captions available, pick one at random; a single-entry
        # sequence is passed on unchanged.
        if isinstance(texts, (list, tuple)) and len(texts) > 1:
            texts = random.choice(texts)

        if self.use_aesthetic:
            # Explicit None checks so that a falsy score such as 0 is not
            # discarded in favour of the fallback key.
            aes = sample.get('aesthetic', None)
            if aes is None:
                aes = sample.get('aes', None)
            if aes is not None:
                if file_type == "video":
                    texts = add_aesthetic_notice_video(texts, aes)
                elif file_type == "image":
                    texts = add_aesthetic_notice_image(texts, aes)

        if self.use_text_processer:
            prompt_ids, prompt_mask = self.get_text_processer(texts)
            examples[PROMPT_IDS], examples[PROMPT_MASK] = (
                prompt_ids,
                prompt_mask,
            )

        examples[FILE_INFO] = file_path
        return examples