import glob
import importlib
import logging
import shutil
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar
import av
import fsspec
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
def get_safe_default_codec() -> str:
    """Return the name of the default video decoding backend for this platform.

    Returns:
        "torchcodec" when the `torchcodec` package can be located, otherwise "pyav"
        (a warning is logged in that case).
    """
    # A bare `import importlib` does not guarantee that the `importlib.util` submodule
    # is loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"

    logging.warning("'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder")
    return "pyav"
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str | None = None,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
Args:
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
Returns:
torch.Tensor: Decoded frames.
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated to the requested timestamps of a video

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path to the video file.
        timestamps: Timestamps (in seconds) of the frames to extract.
        tolerance_s: Maximum allowed distance, in seconds, between a requested timestamp and
            the timestamp of the frame returned for it.
        backend: torchvision video backend name passed to `torchvision.set_video_backend`.
        log_loaded_timestamps: Whether to log the timestamp of every decoded frame.

    Returns:
        The frame closest to each requested timestamp, stacked in the order of `timestamps`,
        converted to float32 and divided by 255.

    Raises:
        AssertionError: If the closest decoded frame of any requested timestamp is not
            strictly within `tolerance_s`.
    """
    video_path = str(video_path)

    torchvision.set_video_backend(backend)
    # Keyframe-only seeking is used with the "pyav" backend.
    keyframes_only = backend == "pyav"

    reader = torchvision.io.VideoReader(video_path, "video")

    loaded_frames = []
    loaded_ts = []
    try:
        # Decode every frame between the first and the last requested timestamp.
        first_ts = min(timestamps)
        last_ts = max(timestamps)
        reader.seek(first_ts, keyframes_only=keyframes_only)

        for frame in reader:
            current_ts = frame["pts"]
            if log_loaded_timestamps:
                logging.info(f"frame loaded at timestamp={current_ts:.4f}")
            loaded_frames.append(frame["data"])
            loaded_ts.append(current_ts)
            if current_ts >= last_ts:
                break
    finally:
        # Release the underlying container even when seeking/decoding fails.
        if backend == "pyav":
            reader.container.close()
        reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # Pairwise L1 distance between requested and decoded timestamps; pick the nearest frame.
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    # Raised explicitly (instead of `assert`) so the check also runs under `python -O`.
    if not is_within_tol.all():
        raise AssertionError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=}). "
            "It means that the closest frame that can be loaded from the video is too far away in time. "
            "This might be due to synchronization issues with timestamps during data collection. "
            "To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
            f"\nbackend: {backend}"
        )

    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # Convert to float32 and scale by 1/255.
    closest_frames = closest_frames.type(torch.float32) / 255

    # Internal sanity check: one frame per requested timestamp.
    assert len(timestamps) == len(closest_frames)
    return closest_frames
def decode_video_frames_npu(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using NPU backend.

    The video is read with `torchvision.io.read_video` and the frames are then picked by
    index (`round(timestamp * fps)`). The timestamp attributed to a picked frame is the
    nominal `index / fps`, not a timestamp read from the stream.

    Args:
        video_path: Path to the video file.
        timestamps: List of timestamps to extract frames.
        tolerance_s: Allowed deviation in seconds for frame retrieval.
        log_loaded_timestamps: Whether to log loaded timestamps.

    Returns:
        The frame closest to each requested timestamp, in the order of `timestamps`,
        as a float32 tensor divided by 255 and moved to the CPU.

    Raises:
        ImportError: If `torch_npu` or `torchvision_npu` cannot be imported.
        ValueError: If none of the requested timestamps maps to a frame of the video.
        AssertionError: If the closest loaded frame of any requested timestamp is not
            strictly within `tolerance_s`.

    Note: This function requires torch_npu and torchvision_npu to be installed.
    Make sure to source the CANN environment before using
    """
    try:
        import torch_npu
        import torchvision_npu

        torch_npu.npu.current_stream().set_data_preprocess_stream(True)
        torchvision_npu._set_video_backend('npu')
    except ImportError as e:
        raise ImportError(
            "NPU backend requires torch_npu and torchvision_npu. Please install them and source CANN environment:\n"
        ) from e

    video_path = str(video_path)

    torchvision.set_video_backend('npu')

    video_full, _, info = torchvision.io.read_video(video_path, pts_unit='sec')

    # Falls back to 30 fps when the metadata does not report a frame rate.
    video_fps = info.get('video_fps', 30.0)

    # `round` (not `int`) so that a timestamp such as 7 / 30, whose product with the fps is
    # 6.999..., maps to frame 7 instead of frame 6. This matches the torchcodec decoder.
    frame_indices = [round(ts * video_fps) for ts in timestamps]

    loaded_frames = []
    loaded_ts = []

    total_frames = video_full.shape[0] if len(video_full.shape) == 4 else 0

    for idx in frame_indices:
        # Skip out-of-range indices; a negative index would otherwise wrap to the end.
        if 0 <= idx < total_frames:
            loaded_frames.append(video_full[idx])
            ts = idx / video_fps
            loaded_ts.append(ts)
            if log_loaded_timestamps:
                logging.info(f"Frame loaded at timestamp={ts:.4f}")

    if len(loaded_frames) == 0:
        raise ValueError(f"No frames found for timestamps: {timestamps}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # Pairwise L1 distance between requested and loaded timestamps; pick the nearest frame.
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    # Raised explicitly (instead of `assert`) so the check also runs under `python -O`.
    if not is_within_tol.all():
        raise AssertionError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=}). "
            "It means that the closest frame that can be loaded from the video is too far away in time. "
            "This might be due to synchronization issues with timestamps during data collection. "
            "To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
            "\nbackend: npu"
        )

    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    closest_frames = closest_frames.float() / 255.0
    closest_frames = closest_frames.cpu()

    # Internal sanity check: one frame per requested timestamp.
    assert len(timestamps) == len(closest_frames)
    return closest_frames
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization."""

    def __init__(self):
        # Maps a video path to (decoder, file handle the decoder was created from).
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one.

        Args:
            video_path: Path of the video, opened through `fsspec.open`.

        Returns:
            The torchcodec `VideoDecoder` associated with `video_path`.

        Raises:
            ImportError: If `torchcodec` cannot be located.
        """
        # A bare `import importlib` does not guarantee that `importlib.util` is loaded.
        import importlib.util

        if importlib.util.find_spec("torchcodec"):
            from torchcodec.decoders import VideoDecoder
        else:
            raise ImportError("torchcodec is required but not available.")

        video_path = str(video_path)

        with self._lock:
            if video_path not in self._cache:
                file_handle = fsspec.open(video_path).__enter__()
                try:
                    decoder = VideoDecoder(file_handle, seek_mode="approximate")
                except Exception:
                    # Do not leak the open file when the decoder cannot be created.
                    file_handle.close()
                    raise
                self._cache[video_path] = (decoder, file_handle)

            return self._cache[video_path][0]

    def clear(self):
        """Clear the cache and close file handles."""
        with self._lock:
            for _, file_handle in self._cache.values():
                file_handle.close()
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
class FrameTimestampError(ValueError):
    """Error signalling that the retrieved timestamps exceed the queried ones."""
# Module-level decoder cache used by `decode_video_frames_torchcodec` when the caller
# does not pass its own `decoder_cache`.
_default_decoder_cache = VideoDecoderCache()
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    log_loaded_timestamps: bool = False,
    decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Args:
        video_path: Path to the video file.
        timestamps: List of timestamps to extract frames.
        tolerance_s: Allowed deviation in seconds for frame retrieval.
        log_loaded_timestamps: Whether to log loaded timestamps.
        decoder_cache: Optional decoder cache instance. Uses default if None.

    Returns:
        The frame closest to each requested timestamp, in the order of `timestamps`,
        divided by 255 and converted to float32.

    Raises:
        AssertionError: If the closest loaded frame of any requested timestamp is not
            strictly within `tolerance_s`.
        FrameTimestampError: If the number of retrieved frames differs from the number
            of requested timestamps.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    if decoder_cache is None:
        decoder_cache = _default_decoder_cache

    # Reuse a cached decoder for this path when there is one.
    decoder = decoder_cache.get_decoder(str(video_path))

    loaded_ts = []
    loaded_frames = []

    # Convert the requested timestamps to frame indices using the average fps of the video.
    metadata = decoder.metadata
    average_fps = metadata.average_fps
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
        pts_s = pts.item()
        loaded_frames.append(frame)
        loaded_ts.append(pts_s)
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts_s:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # Pairwise L1 distance between requested and loaded timestamps; pick the nearest frame.
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    # Raised explicitly (instead of `assert`) so the check also runs under `python -O`.
    if not is_within_tol.all():
        raise AssertionError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=}). "
            "It means that the closest frame that can be loaded from the video is too far away in time. "
            "This might be due to synchronization issues with timestamps during data collection. "
            "To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
        )

    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # Scale by 1/255 and convert to float32.
    closest_frames = (closest_frames / 255.0).type(torch.float32)

    if len(timestamps) != len(closest_frames):
        # Report timestamps (the original message built sets out of the image tensors).
        raise FrameTimestampError(
            f"Retrieved timestamps differ from queried: retrieved {closest_ts.tolist()}, queried {list(timestamps)}"
        )
    return closest_frames
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: int | None = av.logging.ERROR,
    overwrite: bool = False,
) -> None:
    """Encode the `frame-XXXXXX.png` images of a directory into a video file.

    More info on ffmpeg arguments tuning on `benchmark/video/README.md`

    Args:
        imgs_dir: Directory containing images named `frame-` followed by six digits and `.png`.
        video_path: Path of the video file to write. Parent directories are created.
        fps: Frame rate given to the output stream.
        vcodec: Video codec; one of "h264", "hevc" or "libsvtav1".
        pix_fmt: Pixel format of the output stream. "yuv444p" is replaced by "yuv420p"
            for "libsvtav1" and "hevc".
        g: Passed to the encoder as the `g` option when not None.
        crf: Passed to the encoder as the `crf` option when not None.
        fast_decode: When non-zero, passed as `svtav1-params=fast-decode=<value>` for
            "libsvtav1", and as `tune=fastdecode` for the other codecs.
        log_level: Level set on the "libav" logger during encoding; None leaves logging untouched.
        overwrite: When False and `video_path` already exists, encoding is skipped.

    Raises:
        ValueError: If `vcodec` is not supported.
        FileNotFoundError: If no matching image is found in `imgs_dir`.
        OSError: If the video file does not exist after encoding.
    """
    if vcodec not in ["h264", "hevc", "libsvtav1"]:
        raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")

    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    if video_path.exists() and not overwrite:
        logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
        return

    video_path.parent.mkdir(parents=True, exist_ok=True)

    if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
        logging.warning(f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'")
        pix_fmt = "yuv420p"

    # Collect the input frames, ordered by the number embedded in the file name.
    template = "frame-" + ("[0-9]" * 6) + ".png"
    input_list = sorted(glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0]))

    if len(input_list) == 0:
        raise FileNotFoundError(f"No images found in {imgs_dir}.")

    # The output resolution is taken from the first image.
    with Image.open(input_list[0]) as dummy_image:
        width, height = dummy_image.size

    video_options = {}

    if g is not None:
        video_options["g"] = str(g)

    if crf is not None:
        video_options["crf"] = str(crf)

    if fast_decode:
        key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        video_options[key] = value

    if log_level is not None:
        logging.getLogger("libav").setLevel(log_level)

    try:
        with av.open(str(video_path), "w") as output:
            output_stream = output.add_stream(vcodec, fps, options=video_options)
            output_stream.pix_fmt = pix_fmt
            output_stream.width = width
            output_stream.height = height

            for input_data in input_list:
                with Image.open(input_data) as input_image:
                    input_image = input_image.convert("RGB")
                    input_frame = av.VideoFrame.from_image(input_image)
                    packet = output_stream.encode(input_frame)
                    if packet:
                        output.mux(packet)

            # Flush the encoder.
            packet = output_stream.encode()
            if packet:
                output.mux(packet)
    finally:
        # Restore PyAV logging even when encoding fails.
        if log_level is not None:
            av.logging.restore_default_callback()

    if not video_path.exists():
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def concatenate_video_files(input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True):
"""
Concatenate multiple video files into a single video file using pyav.
This function takes a list of video input file paths and concatenates them into a single
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
concatenation without re-encoding.
Args:
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
Note:
- Creates a temporary directory for intermediate files that is cleaned up after use.
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
codec, resolution, and frame rate for proper concatenation.
"""
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
for input_path in input_video_paths:
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
tmp_concatenate_file.flush()
tmp_concatenate_path = tmp_concatenate_file.name
input_container = av.open(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
)
stream_map = {}
for input_stream in input_container.streams:
if input_stream.type in ("video", "audio", "subtitle"):
stream_map[input_stream.index] = output_container.add_stream_from_template(
template=input_stream, opaque=True
)
stream_map[input_stream.index].time_base = input_stream.time_base
for packet in input_container.demux():
if packet.stream.index not in stream_map:
continue
if packet.dts is None:
continue
output_stream = stream_map[packet.stream.index]
packet.stream = output_stream
output_container.mux(packet)
input_container.close()
output_container.close()
shutil.move(tmp_output_video_path, output_video_path)
Path(tmp_concatenate_path).unlink()
@dataclass
class VideoFrame:
    """
    Feature type describing a video frame stored in a dataset.

    A value is a struct with a string `path` and a float32 `timestamp`.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow storage type shared by all instances.
    pa_type: ClassVar[Any] = pa.struct(
        [
            ("path", pa.string()),
            ("timestamp", pa.float32()),
        ]
    )
    # Feature name; excluded from `__init__` and from the repr.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        """Return the Arrow type of this feature."""
        return self.pa_type
# Register `VideoFrame` under the name "VideoFrame" while ignoring the UserWarning whose
# message matches the text below (presumably emitted by `register_feature` itself).
# The filter only applies inside this `catch_warnings` block.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
    """Read the properties of the first audio stream of a media file.

    Args:
        video_path: Path to the media file.

    Returns:
        `{"has_audio": False}` when the file has no audio stream. Otherwise a dict with
        the keys "audio.channels", "audio.codec", "audio.bit_rate", "audio.sample_rate",
        "audio.bit_depth", "audio.channel_layout" and "has_audio" set to True.
    """
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    audio_info = {}
    try:
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                # No audio stream in this file.
                return {"has_audio": False}

            audio_info["audio.channels"] = audio_stream.channels
            audio_info["audio.codec"] = audio_stream.codec.canonical_name
            audio_info["audio.bit_rate"] = audio_stream.bit_rate
            audio_info["audio.sample_rate"] = audio_stream.sample_rate
            audio_info["audio.bit_depth"] = audio_stream.format.bits
            audio_info["audio.channel_layout"] = audio_stream.layout.name
            audio_info["has_audio"] = True
    finally:
        # Restore PyAV logging on every exit path, including unexpected errors.
        av.logging.restore_default_callback()

    return audio_info
def get_video_info(video_path: Path | str) -> dict:
    """Read the properties of the first video stream of a file, plus its audio properties.

    Args:
        video_path: Path to the video file.

    Returns:
        An empty dict when the file has no video stream. Otherwise a dict with the keys
        "video.height", "video.width", "video.codec", "video.pix_fmt", "video.is_depth_map"
        (always False), "video.fps" (the stream's base rate converted with `int`) and
        "video.channels", updated with the result of `get_audio_info(video_path)`.

    Raises:
        ValueError: If the number of channels cannot be determined from the pixel format.
    """
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    video_info = {}
    try:
        with av.open(str(video_path), "r") as video_file:
            try:
                video_stream = video_file.streams.video[0]
            except IndexError:
                # No video stream in this file.
                return {}

            video_info["video.height"] = video_stream.height
            video_info["video.width"] = video_stream.width
            video_info["video.codec"] = video_stream.codec.canonical_name
            video_info["video.pix_fmt"] = video_stream.pix_fmt
            video_info["video.is_depth_map"] = False

            # `int` truncates a fractional frame rate.
            video_info["video.fps"] = int(video_stream.base_rate)

            pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
            video_info["video.channels"] = pixel_channels
    finally:
        # Restore PyAV logging on every exit path, including unexpected errors.
        av.logging.restore_default_callback()

    video_info.update(**get_audio_info(video_path))

    return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of channels implied by a pixel format name.

    Args:
        pix_fmt: Pixel format name, e.g. "yuv420p" or "rgba".

    Returns:
        1, 4 or 3 depending on which marker substring is found in `pix_fmt`.

    Raises:
        ValueError: If no known marker is found in `pix_fmt`.
    """
    # Checked in order: the 4-channel markers must come before the 3-channel ones
    # because "rgba"/"yuva" also contain "rgb"/"yuv".
    channels_by_markers = (
        (1, ("gray", "depth", "monochrome")),
        (4, ("rgba", "yuva")),
        (3, ("rgb", "yuv")),
    )
    for channel_count, markers in channels_by_markers:
        if any(marker in pix_fmt for marker in markers):
            return channel_count
    raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
    """
    Return the duration, in seconds, of a video file opened with PyAV.

    Args:
        video_path: Path to the video file.

    Returns:
        The duration of the first video stream (`duration * time_base`), or the
        container duration divided by `av.time_base` when the stream has none.
    """
    with av.open(str(video_path)) as container:
        first_video_stream = container.streams.video[0]

        if video_stream_has_no_duration := first_video_stream.duration is None:
            # Fall back to the container-level duration.
            return float(container.duration / av.time_base)

        return float(first_video_stream.duration * first_video_stream.time_base)
class VideoEncodingManager:
    """
    Context manager that encodes pending episodes and cleans up images when the block exits.

    On exit it:
    - Batch-encodes the episodes not yet encoded (`episodes_since_last_encoding`)
    - Finalizes the dataset
    - Removes the image directories of the interrupted episode when an exception occurred
    - Removes the `images` directory when it no longer contains any PNG file

    Exceptions raised inside the block are never suppressed.

    Args:
        dataset: The LeRobotDataset instance
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        interrupted = exc_type is not None

        # Encode the episodes recorded since the last batch encoding.
        if self.dataset.episodes_since_last_encoding > 0:
            logging.info(
                "Exception occurred. Encoding remaining episodes before exit..."
                if interrupted
                else "Recording stopped. Encoding remaining episodes..."
            )

            end_ep = self.dataset.num_episodes
            start_ep = end_ep - self.dataset.episodes_since_last_encoding
            logging.info(
                f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
                f"from episode {start_ep} to {end_ep - 1}"
            )
            self.dataset._batch_save_episode_video(start_ep, end_ep)

        self.dataset.finalize()

        # The interrupted episode is the one right after the last saved episode.
        if interrupted:
            interrupted_episode_index = self.dataset.num_episodes
            for key in self.dataset.meta.video_keys:
                first_frame_path = self.dataset._get_image_file_path(
                    episode_index=interrupted_episode_index, image_key=key, frame_index=0
                )
                episode_img_dir = first_frame_path.parent
                if not episode_img_dir.exists():
                    continue
                logging.debug(
                    f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
                )
                shutil.rmtree(episode_img_dir)

        # Drop the images directory only when no PNG file is left anywhere below it.
        img_dir = self.dataset.root / "images"
        png_files = list(img_dir.rglob("*.png"))
        if png_files:
            logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
        elif img_dir.exists():
            shutil.rmtree(img_dir)
            logging.debug("Cleaned up empty images directory")

        # Never suppress an exception raised inside the block.
        return False