import os
import re
import gc
import html
import math
import copy
import random
import urllib.parse as ul
from fractions import Fraction
from collections import Counter
from logging import getLogger
from typing import Any, Dict, Optional, Tuple, Union, Sequence, Literal, Type
from abc import ABC, abstractmethod
from pathlib import Path
import typing
import av
import torch
import torchvision
import numpy as np
from torchvision import get_video_backend
from torchvision.io.video import (
_align_audio_frames,
_check_av_available,
_read_from_stream,
_video_opt,
)
from mindspeed_mm.data.data_utils.utils import DecordInit
from mindspeed_mm.utils.utils import Registry
# Accepted frame-axis orderings: "THWC" (Time, Height, Width, Channel)
# or "TCHW" (Time, Channel, Height, Width).
VideoLayoutType = Literal["THWC", "TCHW"]
# Accepted container kinds for returned frame data: NumPy ndarrays or PyTorch tensors.
VideoArrayType = Literal["numpy", "torch"]
class Video(ABC):
    """
    Abstract interface shared by all video-reading backends.

    Construction stores the requested options, validates them and then
    immediately asks the concrete subclass to load the video.
    """

    def __init__(self, video_path: str, layout: VideoLayoutType = "TCHW", array_type: VideoArrayType = "torch"):
        """
        Store the options, validate them and load the video.

        Args:
            video_path: Path to the video file (wrapped in ``pathlib.Path``).
            layout (VideoLayoutType):
                Axis order of returned frames. Options:
                    - "TCHW": Time, Channel, Height, Width (default)
                    - "THWC": Time, Height, Width, Channel
            array_type (VideoArrayType):
                Container type of returned frames. Options:
                    - "torch": PyTorch tensors (default)
                    - "numpy": NumPy ndarrays

        Raises:
            ValueError: If ``layout`` or ``array_type`` is not one of the
                supported literals.
        """
        self.video_path = Path(video_path)
        self.layout = layout
        self.array_type = array_type
        # Validate before loading so a bad option fails without touching the file.
        self._validate_params()
        self._load_data()

    def _validate_params(self):
        """Reject layout / array-type values outside the declared literals."""
        allowed_layouts = typing.get_args(VideoLayoutType)
        if self.layout not in allowed_layouts:
            raise ValueError(f"Invalid video layout type: {self.layout}")

        allowed_array_types = typing.get_args(VideoArrayType)
        if self.array_type not in allowed_array_types:
            raise ValueError(f"Invalid video array type: {self.array_type}")

    @abstractmethod
    def _load_data(self):
        """
        Backend-specific loading hook, invoked once from ``__init__``.

        Implementations decide what state to keep and how load failures
        are reported.
        """

    @abstractmethod
    def get_batch(self, frame_indices: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Return the frames selected by ``frame_indices`` in the backend's output format.
        """

    @abstractmethod
    def get_video_fps(self) -> float:
        """
        Return the frame rate of the video.

        Returns:
            Frames per second as a floating point value
        """

    @abstractmethod
    def get_len(self) -> int:
        """
        Return the number of frames.
        """
@Registry.register
class DecordVideo(Video):
    """
    Video reader backed by a ``DecordInit`` decoder that is shared at class level.

    Class Attributes:
        _decoder: Lazily created ``DecordInit`` instance, reused by later loads.
    """

    _decoder: Optional[object] = None

    @classmethod
    def _init_decoder(cls):
        """
        Create the class-level decoder on first use.

        Does nothing when a decoder has already been created.
        """
        if cls._decoder is not None:
            return
        cls._decoder = DecordInit()

    def _load_data(self):
        """Open the video through the shared decoder and keep the reader in ``self.vframes``."""
        self._init_decoder()
        shared_decoder = self.__class__._decoder
        self.vframes = shared_decoder(str(self.video_path))

    def get_batch(self, frame_indices: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Fetch the requested frames and convert them to the configured layout / container.

        NOTE(review): the transpose below assumes the decoder yields frames as
        (T, H, W, C) -- confirm against the decoder in use.
        """
        frames = self.vframes.get_batch(frame_indices).asnumpy()
        if self.layout == "TCHW":
            # Move the last axis to position 1.
            frames = frames.transpose(0, 3, 1, 2)
        if self.array_type != "torch":
            return frames
        return torch.from_numpy(frames)

    def get_video_fps(self) -> float:
        """Return the average FPS reported by the underlying reader."""
        return self.vframes.get_avg_fps()

    def get_len(self) -> int:
        """Return the number of frames reported by the underlying reader."""
        return len(self.vframes)

    def get_frame_timestamp(self, frame_index: int) -> float:
        """Forward to the underlying reader's ``get_frame_timestamp`` for ``frame_index``."""
        return self.vframes.get_frame_timestamp(frame_index)

    def next(self):
        """Forward to the underlying reader's ``next()``."""
        return self.vframes.next()
@Registry.register
class TorchvisionVideo(Video):
    """
    Torchvision-based video reader implementation.

    The file is read once with ``torchvision.io.read_video`` at load time;
    batches are then index lookups into the resulting frame tensor.
    """

    def _load_data(self):
        """Read the video into ``self.vframes`` and keep the returned metadata."""
        # The second return value (audio frames) is not used by this reader.
        self.vframes, _, self.metadata = torchvision.io.read_video(
            str(self.video_path), pts_unit="sec", output_format=self.layout
        )

    def get_batch(self, frame_indices: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Return the frames at ``frame_indices``.

        Args:
            frame_indices: Indices used to index the loaded frame tensor.

        Returns:
            A ``torch.Tensor`` when ``array_type`` is "torch", otherwise a
            NumPy ndarray. The axis order was fixed by ``layout`` at load time.
        """
        video_data = self.vframes[frame_indices]
        # The container kind is chosen by `array_type`. The previous check
        # compared `layout` (always "THWC"/"TCHW") against "numpy", so the
        # conversion could never run.
        if self.array_type == "numpy":
            video_data = video_data.numpy()
        return video_data

    def get_video_fps(self) -> float:
        """Return ``metadata["video_fps"]``, or ``None`` when the key is absent."""
        return self.metadata.get("video_fps")

    def get_len(self) -> int:
        """Return the number of loaded frames."""
        return len(self.vframes)
@Registry.register
class AvVideo(Video):
    """
    AV-based video reader implementation.

    The file is read once with ``read_video_av`` at load time; batches are
    then index lookups into the resulting frame tensor.
    """

    def _load_data(self):
        """Read the video into ``self.vframes`` and keep the returned metadata."""
        # The second return value (audio frames) is not used by this reader.
        self.vframes, _, self.metadata = read_video_av(
            str(self.video_path), pts_unit="sec", output_format=self.layout
        )

    def get_batch(self, frame_indices: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Return the frames at ``frame_indices``.

        Args:
            frame_indices: Indices used to index the loaded frame tensor.

        Returns:
            A ``torch.Tensor`` when ``array_type`` is "torch", otherwise a
            NumPy ndarray. The axis order was fixed by ``layout`` at load time.
        """
        video_data = self.vframes[frame_indices]
        # The container kind is chosen by `array_type`. The previous check
        # compared `layout` (always "THWC"/"TCHW") against "numpy", so the
        # conversion could never run.
        if self.array_type == "numpy":
            video_data = video_data.numpy()
        return video_data

    def get_video_fps(self) -> float:
        """Return ``metadata["video_fps"]``, or ``None`` when the key is absent."""
        return self.metadata.get("video_fps")

    def get_len(self) -> int:
        """Return the number of loaded frames."""
        return len(self.vframes)
class VideoReader:
    """
    Factory that resolves a registered video backend and builds readers from it.
    """

    def __init__(self, video_reader_type=None):
        """
        Look up the reader class for the requested backend.

        Args:
            video_reader_type: Key passed to ``Registry.get_class`` to select
                the reader class.
        """
        self._reader_cls = Registry.get_class(video_reader_type)

    def __call__(self, video_path, layout: VideoLayoutType = "TCHW", array_type: VideoArrayType = "torch"):
        """
        Build a reader for ``video_path`` using the resolved backend class.

        Args:
            video_path (str/Path): Path to the video file
            layout (VideoLayoutType):
                Requested axis order of returned frames. Default is "TCHW".
            array_type (VideoArrayType):
                Requested container type of returned frames. Default is "torch".

        Returns:
            An instance of the reader class resolved in ``__init__``.
        """
        reader_options = {"layout": layout, "array_type": array_type}
        reader_cls = self._reader_cls
        return reader_cls(video_path, **reader_options)
def read_video_av(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames

    When the active torchvision video backend is not "pyav" the call is
    delegated to ``_video_opt._read_video``; otherwise the file is decoded
    with PyAV directly.

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = "pts", float / Fraction if pts_unit = "sec", optional):
            The start presentation time of the video
        end_pts (int if pts_unit = "pts", float / Fraction if pts_unit = "sec", optional):
            The end presentation time. ``None`` means "until the end".
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either "pts" or "sec". Defaults to "pts".
        output_format (str, optional): The format of the output video tensors. Can be either
            "THWC" (default) or "TCHW". Matching is case-insensitive.

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        ValueError: If ``output_format`` is not "THWC"/"TCHW", or if
            ``end_pts`` is smaller than ``start_pts`` (pyav backend).
        RuntimeError: If ``filename`` does not exist.
        Exception: Any error raised while opening or decoding the container
            propagates unchanged after the container has been closed.
    """
    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        container = av.open(filename, metadata_errors="ignore")
        # No `except` clause here: the previous `except av.AVError as ex: raise ex`
        # only re-raised the same exception, and looking up `av.AVError` fails on
        # PyAV releases that no longer expose that name. Errors propagate as-is;
        # `finally` still guarantees the container is released.
        try:
            if container.streams.audio:
                audio_timebase = container.streams.audio[0].time_base
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    video_stream,
                    {"video": 0},
                )
                video_fps = video_stream.average_rate
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_stream = container.streams.audio[0]
                audio_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    audio_stream,
                    {"audio": 0},
                )
                info["audio_fps"] = audio_stream.rate
        finally:
            container.close()
            # Drop the reference and force a collection right away
            # (presumably to release decoder memory promptly -- TODO confirm).
            del container
            gc.collect()

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            # Placeholder with zero frames so the permute below still works.
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                # Convert the second-based bounds into audio time-base ticks
                # before handing them to the alignment helper.
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info