import copy
import json
import os
import random
import time
from typing import List, Optional, Union
import mindspeed.megatron_adaptor
import torch
import torch.distributed
from megatron.core import mpu
from megatron.training import get_args, print_rank_0
from megatron.training.initialize import initialize_megatron, set_jit_fusion_options
from numpy import save
from mindspeed_mm.configs.config import merge_mm_args, mm_extra_args_provider
from mindspeed_mm.data import build_mm_dataloader, build_mm_dataset
from mindspeed_mm.data.data_utils.constants import (
FILE_INFO,
PROMPT_IDS,
PROMPT_MASK,
VIDEO,
VIDEO_MASK,
)
from mindspeed_mm.data.data_utils.transform_pipeline import get_transforms
from mindspeed_mm.data.datasets.t2v_dataset import T2VDataset
from mindspeed_mm.models.ae import AEModel
from mindspeed_mm.models.text_encoder import TextEncoder
from mindspeed_mm.tools.profiler import Profiler
from mindspeed_mm.utils.utils import get_device, get_dtype, is_npu_available
from mindspeed_mm.tools.feature_extraction.get_sora_feature import FeatureExtractor
# Pull in the NPU backend only when is_npu_available() reports one.
if is_npu_available():
    import torch_npu
    # Not referenced by name in this script; presumably imported for its
    # side effects (redirecting CUDA-style calls to NPU) -- confirm against
    # the torch_npu documentation.
    from torch_npu.contrib import transfer_to_npu
    # Turn off the backend's "internal format" option for this process.
    torch.npu.config.allow_internal_format = False
class WanTextVideoDataset(T2VDataset):
    """Text/video dataset that applies the video transforms in two stages.

    ``vid_img_process["train_pipeline"]`` must provide two entries:

    * ``"video_and_first_frame"`` -- handed to ``T2VDataset`` as its
      ``"video"`` pipeline. For ``task == "i2v"`` the first frame is taken
      from the output of this stage.
    * ``"video_only"`` -- applied afterwards, to the video only.
    """

    def __init__(
        self,
        task,
        basic_param,
        vid_img_process: dict,
        use_text_processer: bool = False,
        enable_text_preprocessing: bool = True,
        text_preprocess_methods: Optional[Union[dict, List[dict]]] = None,
        tokenizer_config: Optional[Union[dict, List[dict]]] = None,
        **kwargs,
    ):
        """Split the configured transforms and initialise the parent dataset.

        Args:
            task: Task name; ``"i2v"`` makes ``__getitem__`` additionally
                return the first frame.
            basic_param: Forwarded unchanged to ``T2VDataset``.
            vid_img_process: Preprocessing configuration. Must contain
                ``"train_pipeline"`` (with the two entries described on the
                class) as well as ``"max_height"`` and ``"max_width"``.
                The caller's dict is left untouched.
            use_text_processer: Forwarded to ``T2VDataset``; also selects
                between tokenised and raw text in ``__getitem__``.
            enable_text_preprocessing: Forwarded to ``T2VDataset``.
            text_preprocess_methods: Forwarded to ``T2VDataset``.
            tokenizer_config: Forwarded to ``T2VDataset``.
            **kwargs: Accepted and ignored, so a caller may splat a whole
                configuration dict into the constructor.

        Raises:
            ValueError: If ``"video_only"`` or ``"video_and_first_frame"``
                is missing from ``vid_img_process["train_pipeline"]``.
            KeyError: If ``"max_height"`` or ``"max_width"`` is missing.
        """
        # `or {}` also covers an explicit `train_pipeline: None`, so that case
        # reports the descriptive ValueError below instead of AttributeError.
        train_pipeline = vid_img_process.get("train_pipeline") or {}

        video_only_transforms = train_pipeline.get("video_only")
        if video_only_transforms is None:
            raise ValueError('"video_only" key not found in vid_img_process["train_pipeline"]')

        video_and_first_frame_transforms = train_pipeline.get("video_and_first_frame")
        if video_and_first_frame_transforms is None:
            raise ValueError('"video_and_first_frame" key not found in vid_img_process["train_pipeline"]')

        video_only_preprocess = {"video": video_only_transforms}

        # Work on a shallow copy: overwriting "train_pipeline" on the caller's
        # dict would drop both source entries and make any later construction
        # from the same config fail with the ValueErrors above.
        vid_img_process = {
            **vid_img_process,
            "train_pipeline": {"video": video_and_first_frame_transforms},
        }

        super().__init__(
            basic_param=basic_param,
            vid_img_process=vid_img_process,
            use_text_processer=use_text_processer,
            enable_text_preprocessing=enable_text_preprocessing,
            text_preprocess_methods=text_preprocess_methods,
            tokenizer_config=tokenizer_config,
        )

        # Second-stage transform, run in __getitem__ after the optional
        # first-frame extraction.
        self.video_only_preprocess = get_transforms(
            is_video=True,
            train_pipeline=video_only_preprocess,
            transform_size={
                "max_height": vid_img_process["max_height"],
                "max_width": vid_img_process["max_width"],
            },
        )
        self.task = task

    def __getitem__(self, index):
        """Load, transform and package the sample at ``index``.

        Returns:
            A dict holding the processed video under ``VIDEO`` and the source
            path under ``FILE_INFO``. It also holds either ``PROMPT_IDS`` and
            ``PROMPT_MASK`` (when ``use_text_processer`` is set) or the
            caption under ``"text"`` as a one-element list. For the
            ``"i2v"`` task it additionally holds ``"first_frame"``.

        Raises:
            AssertionError: If the sample's file path does not exist.
        """
        example = {}
        sample = self.data_samples[index]
        file_path = sample["path"]
        if not os.path.exists(file_path):
            raise AssertionError(f"file {file_path} do not exist!")

        vframes = self.video_reader(file_path)
        video = self.video_processer(vframes=vframes, **sample)

        if self.task == "i2v":
            # Taken before the video-only stage runs. Index 0 along the
            # second axis of the 4-D clip -- assumes that axis is time;
            # TODO confirm against the transform pipeline.
            example["first_frame"] = video[:, 0, :, :]

        example[VIDEO] = self.video_only_preprocess(video)

        text = sample["cap"]
        if not isinstance(text, list):
            text = [text]
        # When several captions are available, pick one at random per access.
        text = [random.choice(text)]

        if self.use_text_processer:
            prompt_ids, prompt_mask = self.get_text_processer(text)
            example[PROMPT_IDS], example[PROMPT_MASK] = prompt_ids, prompt_mask
        else:
            example["text"] = text

        example[FILE_INFO] = file_path
        return example
class WanFeatureExtractor(FeatureExtractor):
    """``FeatureExtractor`` variant whose data source is ``WanTextVideoDataset``."""

    def _prepare_data(self):
        """Create the dataset and its dataloader from the global arguments.

        The task comes from ``args.mm.model.task`` and falls back to
        ``"t2v"`` when the model config has no such attribute.

        Returns:
            A ``(dataset, dataloader)`` tuple.
        """
        args = get_args()
        task = getattr(args.mm.model, "task", "t2v")

        params = args.mm.data.dataset_param.to_dict()
        # The whole dict is splatted as well; the dataset constructor accepts
        # extra keyword arguments.
        wan_dataset = WanTextVideoDataset(
            task,
            params["basic_parameters"],
            params["preprocess_parameters"],
            **params,
        )

        # The loader is built with the data-parallel process group.
        loader = build_mm_dataloader(
            wan_dataset,
            args.mm.data.dataloader_param,
            process_group=mpu.get_data_parallel_group(),
            dataset_param=args.mm.data.dataset_param,
        )
        return wan_dataset, loader
if __name__ == "__main__":
    # Script entry point: run one complete extraction pass, announcing the
    # start and the successful end of the run.
    print_rank_0("Starting feature extraction process")
    WanFeatureExtractor().extract_all()
    print_rank_0("Feature extraction completed successfully")