from __future__ import annotations
import importlib
import logging
import os
import random
import sys
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Literal
import torch
from torch.utils.data import Dataset
if "ltx_core" not in sys.modules:
sys.modules["ltx_core"] = importlib.import_module("mindspeed_mm.fsdp.models.ltx2.ltx_core")
from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier, get_pixel_coords
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
from mindspeed_mm.fsdp.utils.register import data_register
logger = logging.getLogger(__name__)
@dataclass
class LTX2PrecomputedBasicArgs:
    """Options consumed by ``LTX2PrecomputedDataset``.

    The per-field comments describe only how the dataset code in this file
    reads each value.
    """

    # Dataset root directory. When it contains a ``.precomputed`` sub-directory,
    # that sub-directory is used as the root instead.
    dataset_dir: str = "/data/ltx2"
    # Sub-directory of the root that is searched recursively for latent ``*.pt`` files.
    latents_dir: str = "latents"
    # Sub-directory of the root that holds the condition file for each latent file.
    conditions_dir: str = "conditions"
    # Optional cap: the list of matched samples is sliced to ``[:max_samples]``.
    max_samples: int | None = None
    # Fallback frames-per-second used when a latent payload has no ``fps`` entry.
    fps: float = 24.0
    # Probability (drawn once per collated batch) of treating the first frame's
    # tokens as clean conditioning tokens; values <= 0 disable the feature.
    first_frame_conditioning_p: float = 0.0
    # When true, every sample must have an audio latent file and audio tensors are emitted.
    with_audio: bool = False
    # Sub-directory of the root that holds audio latents (only used when ``with_audio`` is true).
    audio_latents_dir: str = "audio_latents"
    # Channel count used to recognize audio latent layouts and build audio positions.
    audio_channels: int = 8
    # Mel-bin count used to recognize audio latent layouts and build audio positions.
    audio_mel_bins: int = 16
    # NOTE(review): not read anywhere in the visible dataset code - confirm it is still needed.
    audio_prompt_key: str = "audio_prompt_embeds"
    # Noise-level sampling strategy selected in ``_sample_sigmas``.
    timestep_sampling_mode: Literal["uniform", "shifted_logit_normal"] = "shifted_logit_normal"
    # Optional sampler overrides: ``min_value``/``max_value`` for "uniform";
    # ``std``/``min_tokens``/``max_tokens``/``min_shift``/``max_shift`` for "shifted_logit_normal".
    timestep_sampling_params: dict[str, float] = field(default_factory=dict)
@data_register.register("ltx2_precomputed")
class LTX2PrecomputedDataset(Dataset):
"""Load LTX2 precomputed latents/conditions and emit training-ready velocity targets.
Expected layout:
{root}/.precomputed/latents/*.pt
{root}/.precomputed/conditions/*.pt
"""
def __init__(self, basic_param, preprocess_param=None, dataset_param=None, **kwargs):
    """Index matched latent / condition (and optional audio-latent) files on disk.

    Args:
        basic_param: Dataset options, resolved through ``_resolve_basic_param``
            into ``LTX2PrecomputedBasicArgs`` fields.
        preprocess_param: Accepted for signature compatibility; unused.
        dataset_param: Optional config whose ``ltx2_dataset_custom`` entry
            supplies lower-priority values for the same options.
        **kwargs: Accepted for signature compatibility; unused.

    Raises:
        FileNotFoundError: The latents, conditions, or (with audio enabled)
            audio-latents directory does not exist.
        ValueError: No latent files exist, no latent has all of its companion
            files, ``max_samples`` is not positive, or (distributed runs)
            there are fewer samples than ranks.
    """
    super().__init__()
    # Only present so the constructor matches the caller's signature.
    _ = (preprocess_param, kwargs)
    resolved_basic_param = self._resolve_basic_param(basic_param, dataset_param)
    self.args = LTX2PrecomputedBasicArgs(**resolved_basic_param)
    self._video_patchifier = VideoLatentPatchifier(patch_size=1)
    self._audio_patchifier = AudioPatchifier(patch_size=1)
    self._video_scale_factors = SpatioTemporalScaleFactors.default()

    # Prefer ``<dataset_dir>/.precomputed`` when that directory exists.
    root = Path(self.args.dataset_dir).expanduser().resolve()
    if (root / ".precomputed").is_dir():
        root = root / ".precomputed"
    self.latents_root = root / self.args.latents_dir
    self.conditions_root = root / self.args.conditions_dir
    self.audio_latents_root = root / self.args.audio_latents_dir if self.args.with_audio else None

    if not self.latents_root.is_dir():
        raise FileNotFoundError(f"Latents directory not found: {self.latents_root}")
    if not self.conditions_root.is_dir():
        raise FileNotFoundError(f"Conditions directory not found: {self.conditions_root}")
    if self.args.with_audio and (self.audio_latents_root is None or not self.audio_latents_root.is_dir()):
        raise FileNotFoundError(f"Audio latents directory not found: {self.audio_latents_root}")

    # Sorted so that indexing (and the ``max_samples`` prefix) is deterministic.
    latent_files = sorted(self.latents_root.rglob("*.pt"))
    if not latent_files:
        raise ValueError(f"No latent .pt files found under: {self.latents_root}")

    latent_prefix = "latent_"
    pairs: list[tuple[Path, Path, Path | None]] = []
    missing_conditions = 0
    missing_audio = 0
    for latent_path in latent_files:
        rel = latent_path.relative_to(self.latents_root)
        cond_path = self.conditions_root / rel
        # Fallback naming scheme: ``latent_<id>.pt`` pairs with ``condition_<id>.pt``.
        if not cond_path.exists() and latent_path.name.startswith(latent_prefix):
            cond_path = self.conditions_root / rel.with_name(
                f"condition_{latent_path.stem[len(latent_prefix):]}.pt"
            )
        if not cond_path.exists():
            missing_conditions += 1
            continue

        audio_latent_path: Path | None = None
        if self.args.with_audio:
            audio_latent_path = self.audio_latents_root / rel
            if not audio_latent_path.exists():
                missing_audio += 1
                continue
        pairs.append((latent_path, cond_path, audio_latent_path))

    # Incomplete samples are still skipped, but no longer silently.
    if missing_conditions or missing_audio:
        logger.warning(
            "Skipped %d latent file(s) without a matching condition file and "
            "%d without a matching audio latent under %s",
            missing_conditions,
            missing_audio,
            root,
        )
    if not pairs:
        raise ValueError("No matched (latents, conditions) file pairs were found.")

    if self.args.max_samples is not None:
        # A zero/negative cap would silently produce an empty or tail-trimmed dataset.
        if self.args.max_samples <= 0:
            raise ValueError(
                f"`max_samples` must be a positive integer or None, got {self.args.max_samples}."
            )
        pairs = pairs[: self.args.max_samples]

    pairs = self._trim_pairs_for_data_parallel(pairs)
    self.sample_pairs = pairs
@staticmethod
def _get_world_size() -> int:
if torch.distributed.is_available() and torch.distributed.is_initialized():
return int(torch.distributed.get_world_size())
env_world_size = os.getenv("WORLD_SIZE")
if env_world_size is not None:
try:
return int(env_world_size)
except ValueError:
return 1
return 1
@classmethod
def _trim_pairs_for_data_parallel(cls, pairs: list[tuple[Path, Path, Path | None]]) -> list[tuple[Path, Path, Path | None]]:
"""Drop tail samples so per-rank sample counts stay aligned with drop_last=True."""
world_size = max(cls._get_world_size(), 1)
if world_size <= 1:
return pairs
usable = (len(pairs) // world_size) * world_size
if usable == 0:
raise ValueError(
f"Dataset has {len(pairs)} samples, smaller than world size {world_size}. "
"Cannot shard evenly for distributed training."
)
if usable != len(pairs):
dropped = len(pairs) - usable
logger.warning(
f"Trimming {dropped} tail samples from LTX2 dataset for even DP sharding: {len(pairs)} -> {usable} (world_size={world_size})"
)
return pairs[:usable]
return pairs
@staticmethod
def _to_dict(obj: Any) -> dict[str, Any]:
if obj is None:
return {}
if isinstance(obj, dict):
return obj
if hasattr(obj, "to_dict"):
return obj.to_dict()
return {}
@classmethod
def _resolve_basic_param(cls, basic_param: Any, dataset_param: Any) -> dict[str, Any]:
    """Merge dataset options and keep only ``LTX2PrecomputedBasicArgs`` fields.

    Values come from two places, lowest priority first:

    1. ``dataset_param["ltx2_dataset_custom"]``
    2. ``basic_param``

    Each source may be a dict or an object exposing ``to_dict``; the original
    code only honored plain dicts for ``basic_param`` and
    ``ltx2_dataset_custom`` and silently ignored everything else, which made
    the dataset fall back to its defaults. Keys that are not dataclass fields
    are dropped so unrelated config entries cannot break construction.

    Args:
        basic_param: Primary option source (dict, ``to_dict`` object, or None).
        dataset_param: Secondary config (dict, ``to_dict`` object, or None).

    Returns:
        Keyword arguments suitable for ``LTX2PrecomputedBasicArgs(**result)``.
    """
    dataset_custom = cls._to_dict(dataset_param).get("ltx2_dataset_custom", {})
    # Later entries win, so ``basic_param`` overrides the custom section.
    merged: dict[str, Any] = {**cls._to_dict(dataset_custom), **cls._to_dict(basic_param)}
    valid_keys = {f.name for f in fields(LTX2PrecomputedBasicArgs)}
    return {key: value for key, value in merged.items() if key in valid_keys}
def __len__(self) -> int:
return len(self.sample_pairs)
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Return one raw sample in upstream ltx-trainer style.
The actual training input preparation (noise/timesteps/positions/targets)
is done in ``prepare_training_inputs`` during the train step.
"""
latent_path, cond_path, audio_latent_path = self.sample_pairs[idx]
latent_data = self._safe_torch_load(latent_path)
cond_data = self._safe_torch_load(cond_path)
latent_data = self._normalize_video_latents(latent_data)
sample: dict[str, Any] = {
"latents": latent_data,
"conditions": cond_data,
"idx": idx,
}
if self.args.with_audio:
if audio_latent_path is None:
raise ValueError("Audio training is enabled but audio latent path is missing.")
audio_data = self._safe_torch_load(audio_latent_path)
sample["audio_latents"] = audio_data
return sample
def _sample_sigmas(self, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
mode = self.args.timestep_sampling_mode
params = self.args.timestep_sampling_params or {}
if mode == "uniform":
min_value = float(params.get("min_value", 0.0))
max_value = float(params.get("max_value", 1.0))
if max_value <= min_value:
raise ValueError(
f"`timestep_sampling_params.max_value` ({max_value}) must be larger than "
f"`min_value` ({min_value}) for uniform mode."
)
sigmas = torch.rand((batch_size,), device=device, dtype=torch.float32) * (max_value - min_value) + min_value
return sigmas
if mode == "shifted_logit_normal":
std = float(params.get("std", 1.0))
min_tokens = float(params.get("min_tokens", 1024))
max_tokens = float(params.get("max_tokens", 4096))
min_shift = float(params.get("min_shift", 0.95))
max_shift = float(params.get("max_shift", 2.05))
if max_tokens <= min_tokens:
raise ValueError(
f"`timestep_sampling_params.max_tokens` ({max_tokens}) must be larger than "
f"`min_tokens` ({min_tokens}) for shifted_logit_normal mode."
)
m = (max_shift - min_shift) / (max_tokens - min_tokens)
b = min_shift - m * min_tokens
shift = m * float(seq_len) + b
sigmas = torch.sigmoid(torch.randn((batch_size,), device=device, dtype=torch.float32) * std + shift)
return sigmas
raise ValueError(
f"Unsupported timestep_sampling_mode: {mode}. "
f"Expected one of ['uniform', 'shifted_logit_normal']."
)
def _prepare_base_inputs(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Build the noise-free tensors for a collated batch.

    The video latents are patchified, positions are built from the latent
    metadata, condition tensors are extracted, and (when audio is enabled)
    audio latents and positions are added. No random sampling happens here.

    Args:
        batch: Collated batch with ``latents`` and ``conditions`` entries and
            optionally ``idx`` and ``audio_latents``.

    Returns:
        Dict with ``video_latent_clean``, ``video_positions``,
        ``video_height``, ``video_width``, ``context``, ``context_mask`` and,
        when available, ``sample_idx``, ``audio_latent_clean`` and
        ``audio_positions``.

    Raises:
        ValueError: Empty metadata, or audio enabled without ``audio_latents``.
        TypeError: Metadata of an unsupported type.
    """

    def _first_scalar(x: Any) -> float:
        # Metadata can be a number, a tensor, or a (nested) list/tuple;
        # only the first element is read in every case.
        if isinstance(x, (int, float)):
            return float(x)
        if isinstance(x, torch.Tensor):
            if x.numel() == 0:
                raise ValueError("Expected non-empty tensor for scalar metadata.")
            return float(x.reshape(-1)[0].item())
        if isinstance(x, (list, tuple)):
            if len(x) == 0:
                raise ValueError("Expected non-empty list/tuple for scalar metadata.")
            return _first_scalar(x[0])
        raise TypeError(f"Unsupported scalar metadata type: {type(x).__name__}")

    latent_payload = batch["latents"]
    video_tokens = self._video_patchifier.patchify(latent_payload["latents"])

    # The first sample's metadata is applied to the whole batch.
    num_frames = int(_first_scalar(latent_payload["num_frames"]))
    height = int(_first_scalar(latent_payload["height"]))
    width = int(_first_scalar(latent_payload["width"]))
    fps_entry = latent_payload.get("fps", None)
    fps = float(self.args.fps) if fps_entry is None else float(_first_scalar(fps_entry))

    batch_size = video_tokens.shape[0]
    dtype, device = video_tokens.dtype, video_tokens.device

    prepared: dict[str, torch.Tensor] = {
        "video_latent_clean": video_tokens,
        "video_positions": self._build_video_positions(
            num_frames=num_frames,
            height=height,
            width=width,
            fps=fps,
            dtype=dtype,
            batch_size=batch_size,
            device=device,
        ),
        "video_height": torch.tensor(height, dtype=torch.int64, device=device),
        "video_width": torch.tensor(width, dtype=torch.int64, device=device),
    }
    prepared.update(self._extract_condition_inputs(batch["conditions"]))

    if "idx" in batch:
        sample_indices = batch["idx"]
        # Any other container type is ignored and no ``sample_idx`` is emitted.
        if isinstance(sample_indices, torch.Tensor):
            prepared["sample_idx"] = sample_indices.to(device=device, dtype=torch.int64)
        elif isinstance(sample_indices, list):
            prepared["sample_idx"] = torch.tensor(sample_indices, dtype=torch.int64, device=device)

    if not self.args.with_audio:
        return prepared

    if "audio_latents" not in batch:
        raise ValueError("Audio training is enabled but `audio_latents` is missing in batch.")
    audio_tokens = self._extract_audio_latents(batch["audio_latents"]).to(dtype)
    prepared["audio_latent_clean"] = audio_tokens
    prepared["audio_positions"] = self._build_audio_positions(
        num_steps=audio_tokens.shape[1],
        dtype=dtype,
        batch_size=batch_size,
        device=device,
    )
    return prepared
def _prepare_training_inputs(self, batch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Prepare stochastic model inputs in upstream ltx-trainer style at train step time.
Expects batch to already have been processed by _prepare_base_inputs (i.e., contains
'video_latent_clean', 'video_positions', 'context', 'context_mask', etc.).
"""
video_latents = batch["video_latent_clean"]
batch_size = video_latents.shape[0]
seq_len = video_latents.shape[1]
device = video_latents.device
sigmas = self._sample_sigmas(batch_size, seq_len, device=device)
video_noise = torch.randn_like(video_latents)
sigmas_expanded = sigmas.view(-1, 1, 1)
noisy_video = (1.0 - sigmas_expanded) * video_latents + sigmas_expanded * video_noise
video_targets = video_noise - video_latents
height = int(batch["video_height"].reshape(-1)[0].item())
width = int(batch["video_width"].reshape(-1)[0].item())
conditioning_mask = self._create_first_frame_conditioning_mask(
batch_size=batch_size,
sequence_length=seq_len,
height=height,
width=width,
device=device,
first_frame_conditioning_p=self.args.first_frame_conditioning_p,
)
noisy_video = torch.where(conditioning_mask.unsqueeze(-1), video_latents, noisy_video)
video_timesteps = self._create_per_token_timesteps(conditioning_mask, sigmas)
video_loss_mask = ~conditioning_mask
prepared: dict[str, torch.Tensor] = {
"video_latent": noisy_video,
"video_target_velocity": video_targets,
"video_timesteps": video_timesteps,
"video_positions": batch["video_positions"],
"context": batch["context"],
"context_mask": batch["context_mask"],
"loss_mask": video_loss_mask,
}
if "sample_idx" in batch:
prepared["sample_idx"] = batch["sample_idx"]
if self.args.with_audio and "audio_latent_clean" in batch:
audio_latents = batch["audio_latent_clean"]
audio_seq_len = audio_latents.shape[1]
audio_noise = torch.randn_like(audio_latents)
noisy_audio = (1.0 - sigmas_expanded) * audio_latents + sigmas_expanded * audio_noise
audio_targets = audio_noise - audio_latents
audio_timesteps = sigmas.view(-1, 1).expand(-1, audio_seq_len).clone()
prepared.update(
{
"audio_latent": noisy_audio,
"audio_target_velocity": audio_targets,
"audio_timesteps": audio_timesteps,
"audio_positions": batch["audio_positions"],
}
)
return prepared
@staticmethod
def _create_per_token_timesteps(conditioning_mask: torch.Tensor, sampled_sigma: torch.Tensor) -> torch.Tensor:
expanded_sigma = sampled_sigma.view(-1, 1).expand_as(conditioning_mask).clone()
return torch.where(conditioning_mask, torch.zeros_like(expanded_sigma), expanded_sigma)
@staticmethod
def _create_first_frame_conditioning_mask(
batch_size: int,
sequence_length: int,
height: int,
width: int,
device: torch.device,
first_frame_conditioning_p: float = 0.0,
) -> torch.Tensor:
conditioning_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
first_frame_end_idx = height * width
if first_frame_conditioning_p > 0 and random.random() < first_frame_conditioning_p:
if first_frame_end_idx < sequence_length:
conditioning_mask[:, :first_frame_end_idx] = True
return conditioning_mask
def collate_fn(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
    """Collate raw samples and turn them into training inputs.

    Samples are merged recursively: tensors are stacked along a new leading
    dimension (all shapes must match), dicts are merged key by key using the
    first sample's keys, bool/int/float scalars become 1-D tensors, and any
    other value type is kept as a plain list. The merged batch is then passed
    through ``_prepare_base_inputs`` and ``_prepare_training_inputs``.

    Raises:
        ValueError: Tensors under the same key differ in shape (or are mixed
            with non-tensors).
    """
    # ``bool`` must come before ``int`` because ``bool`` is a subclass of ``int``.
    scalar_dtypes = ((bool, torch.bool), (int, torch.int64), (float, torch.float32))

    def _collate_values(values: list[Any], key_path: str) -> Any:
        first = values[0]
        if isinstance(first, torch.Tensor):
            for candidate in values:
                if not isinstance(candidate, torch.Tensor) or candidate.shape != first.shape:
                    raise ValueError(f"Shape mismatch for `{key_path}`")
            return torch.stack(values, dim=0)
        if isinstance(first, dict):
            return {
                k: _collate_values([v[k] for v in values], f"{key_path}.{k}" if key_path else k)
                for k in first
            }
        for python_type, tensor_dtype in scalar_dtypes:
            if isinstance(first, python_type):
                return torch.tensor(values, dtype=tensor_dtype)
        return values

    collated = _collate_values(features, "")
    return self._prepare_training_inputs(self._prepare_base_inputs(collated))
@staticmethod
def _safe_torch_load(path: Path) -> dict:
try:
return torch.load(path, map_location="cpu", weights_only=True)
except Exception:
return torch.load(path, map_location="cpu")
@staticmethod
def _normalize_video_latents(latent_data: dict[str, Any]) -> dict[str, Any]:
data = dict(latent_data)
latents = data["latents"]
if latents.dim() == 2:
num_frames = int(data["num_frames"])
height = int(data["height"])
width = int(data["width"])
channels = latents.shape[-1]
latents = latents.reshape(num_frames, height, width, channels).permute(3, 0, 1, 2).contiguous()
data["latents"] = latents
elif latents.dim() != 4:
raise ValueError(f"Unsupported latent shape: {tuple(latents.shape)}")
return data
def _build_video_positions(
    self,
    num_frames: int,
    height: int,
    width: int,
    fps: float,
    dtype: torch.dtype,
    batch_size: int = 1,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Compute position coordinates for the video tokens.

    Patch grid bounds for a latent of the given size are obtained from the
    video patchifier, converted with ``get_pixel_coords`` and cast to
    ``dtype``. Index 0 along dimension 1 is then divided by ``fps`` (clamped
    to at least ``1e-6`` to avoid dividing by zero).

    NOTE(review): index 0 along dimension 1 is presumably the temporal axis,
    which would turn frame coordinates into seconds - confirm against
    ``get_pixel_coords``.
    """
    # NOTE(review): the channel count is hard-coded to 128 - confirm it
    # matches the latent channels of the model in use.
    grid_shape = VideoLatentShape(
        frames=num_frames,
        height=height,
        width=width,
        batch=batch_size,
        channels=128,
    )
    latent_bounds = self._video_patchifier.get_patch_grid_bounds(
        output_shape=grid_shape,
        device=device,
    )
    positions = get_pixel_coords(
        latent_coords=latent_bounds,
        scale_factors=self._video_scale_factors,
        causal_fix=True,
    ).to(dtype)
    safe_fps = max(fps, 1e-6)
    positions[:, 0, ...] = positions[:, 0, ...] / safe_fps
    return positions
def _extract_audio_latents(self, audio_data: dict[str, Any]) -> torch.Tensor:
latents = audio_data["latents"]
if latents.dim() == 2:
if latents.shape[-1] != self.args.audio_channels * self.args.audio_mel_bins:
raise ValueError(
"2D audio latents expected shape [T, C*F], "
f"got last_dim={latents.shape[-1]}, expected {self.args.audio_channels * self.args.audio_mel_bins}."
)
return latents.unsqueeze(0)
if latents.dim() == 3:
if latents.shape[0] == self.args.audio_channels and latents.shape[-1] == self.args.audio_mel_bins:
return self._audio_patchifier.patchify(latents.unsqueeze(0))
if latents.shape[-1] == self.args.audio_channels * self.args.audio_mel_bins:
return latents
raise ValueError(
"3D audio latents expected either [C, T, F] or [B, T, C*F], "
f"got shape: {tuple(latents.shape)}."
)
if latents.dim() == 4:
return self._audio_patchifier.patchify(latents)
raise ValueError(f"Unsupported audio latent shape: {tuple(latents.shape)}")
def _build_audio_positions(
    self,
    num_steps: int,
    dtype: torch.dtype,
    batch_size: int = 1,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Compute position coordinates for the audio tokens.

    The audio patchifier's grid bounds for a latent with ``num_steps`` frames
    (channel and mel-bin counts come from the dataset options) are returned
    cast to ``dtype``.
    """
    audio_shape = AudioLatentShape(
        frames=num_steps,
        mel_bins=self.args.audio_mel_bins,
        batch=batch_size,
        channels=self.args.audio_channels,
    )
    grid_bounds = self._audio_patchifier.get_patch_grid_bounds(
        output_shape=audio_shape,
        device=device,
    )
    return grid_bounds.to(dtype)
@staticmethod
def _normalize_context(context: torch.Tensor) -> torch.Tensor:
if context.dim() == 2:
context = context.unsqueeze(0)
elif context.dim() != 3:
raise ValueError(f"Unsupported context shape: {tuple(context.shape)}")
return context.to(torch.float32)
@staticmethod
def _normalize_context_mask(context: torch.Tensor, context_mask: torch.Tensor | None) -> torch.Tensor:
if context_mask is None:
context_mask = torch.ones(context.shape[:2], dtype=torch.bool, device=context.device)
if context_mask.dim() == 1:
context_mask = context_mask.unsqueeze(0)
elif context_mask.dim() != 2:
raise ValueError(f"Unsupported context mask shape: {tuple(context_mask.shape)}")
return context_mask
def _extract_condition_inputs(self, cond_data: dict[str, Any]) -> dict[str, torch.Tensor]:
if "prompt_embeds" not in cond_data:
raise KeyError("Condition file must contain `prompt_embeds` for LTX2 training.")
context = self._normalize_context(cond_data["prompt_embeds"])
context_mask = self._normalize_context_mask(context, cond_data.get("prompt_attention_mask"))
return {
"context": context,
"context_mask": context_mask,
}