import json
import os
import time
import copy
from typing import Dict, Any, Tuple, List, Union, Optional
import mindspeed.megatron_adaptor
import torch
import torch.distributed
from megatron.core import mpu
from megatron.training import get_args, print_rank_0
from megatron.training.initialize import initialize_megatron, set_jit_fusion_options
from numpy import save
from tqdm import tqdm
from mindspeed_mm.configs.config import merge_mm_args, mm_extra_args_provider
from mindspeed_mm.data import build_mm_dataloader, build_mm_dataset
from mindspeed_mm.data.data_utils.constants import (
FILE_INFO,
PROMPT_IDS,
PROMPT_MASK,
VIDEO,
VIDEO_MASK,
)
from mindspeed_mm.models.ae import AEModel
from mindspeed_mm.models.text_encoder import TextEncoder
from mindspeed_mm.tools.profiler import Profiler
from mindspeed_mm.utils.utils import get_device, get_dtype, is_npu_available
if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu
torch.npu.config.allow_internal_format = False
class FeatureExtractor:
    """
    Distributed feature extractor for multimodal data (video + text).

    This class handles:
    1. Distributed environment setup using Megatron
    2. Data loading and preprocessing
    3. Feature extraction using autoencoder (video) and text encoder models
    4. Saving extracted features to disk (one ``.pt`` file per sample under
       ``<save_path>/features``)
    5. Metadata management for extracted features (``<save_path>/data.jsonl``,
       written by rank 0 only)
    """

    def __init__(self):
        """Initialize the pipeline: distributed env, output dirs, data, metadata, models."""
        self._initialize_distributed()
        self.save_path = self.args.mm.tool.sorafeature.save_path
        self.features_dir = os.path.join(self.save_path, "features")
        if self.rank == 0:
            os.makedirs(self.features_dir, exist_ok=True)
            print_rank_0(f"Created features directory at: {self.features_dir}")
        set_jit_fusion_options()
        # Inference only: turn off autograd for everything that follows.
        torch.set_grad_enabled(False)
        self.device = get_device("npu")
        self.ae_dtype = get_dtype(self.args.mm.model.ae.dtype)
        self.dataset, self.dataloader = self._prepare_data()
        self._write_data_info()
        # Non-zero ranks wait here until rank 0 has created features_dir and
        # written data.jsonl.
        torch.distributed.barrier()
        self.vae, self.text_encoder = self._prepare_models()

    def _initialize_distributed(self):
        """Initialize Megatron and cache ``args``, ``rank`` and ``world_size`` on self."""
        initialize_megatron(extra_args_provider=mm_extra_args_provider, args_defaults={})
        args = get_args()
        # merge_mm_args presumably populates ``args.mm`` (read throughout this
        # class); the args are re-read afterwards.
        merge_mm_args(args)
        self.args = get_args()
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        print(f"Initialized distributed environment (rank {self.rank}/{self.world_size})")

    def _resolve_source_file_key(self) -> str:
        """
        Return the data-sample key that holds the source file path.

        The key depends on the configured ``data_storage_mode``.

        Raises:
            NotImplementedError: if the storage mode is neither "combine" nor "standard".
        """
        storage_mode = self.args.mm.data.dataset_param.basic_parameters.data_storage_mode
        if storage_mode == "combine":
            return "path"
        if storage_mode == "standard":
            return FILE_INFO
        raise NotImplementedError(f"Unsupported storage mode: {storage_mode}")

    def _write_data_info(self):
        """
        Write dataset metadata information (JSONL file, rank 0 only).

        Each line is a deep copy of a dataset sample whose FILE_INFO entry is
        rewritten to point at the sample's feature file (``features/<name>.pt``).
        """
        if self.rank != 0:
            return
        # Validate the storage mode BEFORE opening the file. Opening with "w"
        # truncates any existing data.jsonl, so an unsupported mode must fail
        # without destroying a previous run's metadata.
        source_file_key = self._resolve_source_file_key()
        print_rank_0("Writing dataset metadata information...")
        data_info_path = os.path.join(self.save_path, 'data.jsonl')
        with open(data_info_path, 'w', encoding="utf-8") as json_file:
            for data_sample in self.dataset.data_samples:
                pt_name = self._generate_safe_filename(data_sample[source_file_key])
                data_info = copy.deepcopy(data_sample)
                data_info[FILE_INFO] = f"features/{pt_name}"
                json_file.write(json.dumps(data_info) + '\n')
        print_rank_0(f"Dataset metadata written to: {data_info_path}")

    def extract_all(self):
        """
        Extract and save features for every batch produced by this rank's dataloader.

        Raises:
            Exception: any error from extraction or saving is logged with the rank and re-raised.
        """
        total_samples = len(self.dataset)
        print_rank_0(f"Starting feature extraction. Total samples: {total_samples}")
        processed = 0
        profiler = self._init_profiler()
        if profiler:
            profiler.start()
        try:
            # Iterate the dataloader itself, not enumerate(...): enumerate has no
            # __len__, which kept tqdm from showing a total.
            for batch in tqdm(self.dataloader):
                file_names, latents, latents_dict, prompt, prompt_mask = self._extract_single(batch)
                batch_size = latents.shape[0]
                for i in range(batch_size):
                    self._save_sample_features(
                        file_name=file_names[i],
                        latent=latents[i],
                        prompt=prompt,
                        prompt_mask=prompt_mask,
                        sample_idx=i,
                        latents_dict=latents_dict
                    )
                processed += batch_size
                if profiler:
                    profiler.step()
        except Exception as e:
            # Plain print so failures on non-zero ranks are visible too;
            # print_rank_0 would only report rank 0.
            print(f"[rank {self.rank}] Feature extraction failed: {e}", flush=True)
            raise
        finally:
            if profiler:
                profiler.stop()
        print(f"[rank {self.rank}] Extracted features for {processed} samples")

    def _init_profiler(self):
        """Return a Profiler if ``args.mm.tool`` has a ``profile`` config, else None."""
        if hasattr(self.args.mm.tool, "profile"):
            print_rank_0("Initializing performance profiler")
            return Profiler(self.args.mm.tool.profile)
        return None

    def _extract_single(
        self,
        batch: Dict[str, Any]
    ) -> Tuple[List[str], torch.Tensor, Dict[str, Any], Any, Any]:
        """
        Extract features from a batch of data.

        Note: ``batch`` is mutated. The video, prompt ids, prompt mask and file
        info entries are popped, and every remaining entry is forwarded as a
        keyword argument to ``vae.encode``.

        Returns:
            file_names: List of original file names
            latents: Extracted video features (tensor, indexed by sample along dim 0)
            latents_dict: Additional video features (dict)
            prompt: Extracted text features
            prompt_mask: Text attention masks

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("Received empty batch")
        video = batch.pop(VIDEO).to(self.device, dtype=self.ae_dtype)
        prompt_ids = batch.pop(PROMPT_IDS)
        prompt_mask = batch.pop(PROMPT_MASK)
        file_names = batch.pop(FILE_INFO)
        latents, latents_dict = self.vae.encode(video, **batch)
        prompt, prompt_mask = self.text_encoder.encode(prompt_ids, prompt_mask)
        return file_names, latents, latents_dict, prompt, prompt_mask

    def _save_sample_features(
        self,
        file_name: str,
        latent: torch.Tensor,
        prompt: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
        prompt_mask: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
        sample_idx: int,
        latents_dict: Optional[Dict[str, Any]] = None
    ):
        """
        Save extracted features for a single sample to ``features_dir``.

        Args:
            file_name: original source path; it determines the ``.pt`` file name.
            latent: this sample's video latent.
            prompt: batched text features (tensor or list/tuple of tensors).
            prompt_mask: batched text masks (same layout as ``prompt``).
            sample_idx: index of this sample inside the batched ``prompt`` / ``latents_dict``.
            latents_dict: optional extra batched outputs from the VAE.
        """
        pt_name = self._generate_safe_filename(file_name)
        save_path = os.path.join(self.features_dir, pt_name)
        data_to_save = {
            VIDEO: latent.cpu(),
            PROMPT_IDS: self._extract_prompt_component(prompt, sample_idx),
            PROMPT_MASK: self._extract_prompt_component(prompt_mask, sample_idx)
        }
        if latents_dict:
            for key, value in latents_dict.items():
                item = value[sample_idx]
                data_to_save[key] = item.cpu() if isinstance(item, torch.Tensor) else item
        # Write to a per-rank temp file, then atomically rename it into place.
        # A crash mid-write can no longer leave a truncated .pt at the final
        # path, and two ranks never share a temp file.
        tmp_path = f"{save_path}.tmp{getattr(self, 'rank', 0)}"
        torch.save(data_to_save, tmp_path)
        os.replace(tmp_path, save_path)

    def _extract_prompt_component(
        self,
        prompt: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor], None],
        idx: int
    ) -> Union[torch.Tensor, List[torch.Tensor], None]:
        """
        Return the ``idx``-th sample of ``prompt``, moved to CPU.

        A list/tuple (e.g. outputs of several text encoders) becomes a list of
        per-sample entries. ``None`` (e.g. an encoder that returns no mask) is
        passed through unchanged instead of raising TypeError.
        """
        if prompt is None:
            return None
        if isinstance(prompt, (list, tuple)):
            return [p[idx].cpu() for p in prompt]
        return prompt[idx].cpu()

    def _prepare_data(self) -> Tuple[Any, Any]:
        """Build the dataset and a dataloader sharded over the data-parallel group."""
        dataset = build_mm_dataset(self.args.mm.data.dataset_param)
        dataloader = build_mm_dataloader(
            dataset,
            self.args.mm.data.dataloader_param,
            process_group=mpu.get_data_parallel_group(),
            dataset_param=self.args.mm.data.dataset_param,
        )
        print_rank_0(f"Prepared dataset with {len(dataset)} samples")
        return dataset, dataloader

    def _prepare_models(self) -> Tuple[torch.nn.Module, torch.nn.Module]:
        """Build the VAE (cast to ``ae_dtype``) and text encoder on ``device`` in eval mode."""
        vae = AEModel(self.args.mm.model.ae)
        vae = vae.to(self.device, dtype=self.ae_dtype).eval()
        text_encoder = TextEncoder(self.args.mm.model.text_encoder)
        text_encoder = text_encoder.to(self.device).eval()
        print_rank_0("Models initialized and moved to evaluation mode")
        return vae, text_encoder

    @staticmethod
    def _generate_safe_filename(file_path: str) -> str:
        """
        Generate a feature-file name from the basename of ``file_path``.

        Every "." in the basename is replaced by "_", and ".pt" is appended.
        Only the basename is used, so files with the same name in different
        directories map to the same output file.

        Example:
            Input: "/path/to/video.mp4"
            Output: "video_mp4.pt"
        """
        base_name = os.path.basename(file_path)
        safe_name = base_name.replace(".", "_") + ".pt"
        return safe_name
def main():
    """Script entry point: build the extractor, then run extraction over all batches."""
    print_rank_0("Starting feature extraction process")
    pipeline = FeatureExtractor()
    pipeline.extract_all()
    print_rank_0("Feature extraction completed successfully")


if __name__ == "__main__":
    main()