from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from pathlib import Path
from typing import List
import cv2
import numpy as np
from decord import VideoReader, cpu
from vrag.logger import logger
from vrag.shared import into_u8_frames
from vrag.types import FrameExtraction
def smart_resize_batch(frames: np.ndarray, target_h: int = 720) -> np.ndarray:
    """
    Batch smart resize video frames so they fit a pixel budget.

    The budget is ``target_h * (target_h * 16 // 9)`` pixels, i.e. the pixel
    count of a 16:9 frame of height ``target_h``. The check uses the height
    and width of the array (shared by all frames). If the frame pixel count is
    within the budget, ``frames`` is returned unchanged. Otherwise every frame
    is scaled by the same factor, preserving aspect ratio, so that it fits
    the budget.

    Args:
        frames: Numpy array of shape (N, H, W, C).
        target_h: Target height of frames.

    Returns:
        Numpy array of shape (N, NEW_H, NEW_W, C). If ``frames`` is None or
        has no frames, an empty array of shape (0,) is returned.

    Raises:
        ValueError: If ``frames`` is not 4-D, ``target_h`` is not positive, or
            the frame height/width is not positive.
    """
    if frames is None or len(frames) == 0:
        return np.empty((0,))
    if frames.ndim != 4:
        raise ValueError("Frames must in shape of (N, H, W, C)")
    if target_h <= 0:
        raise ValueError(f"Target frame height must be positive, but get target_h: {target_h}")
    first_h, first_w = frames.shape[1:3]
    if first_h <= 0 or first_w <= 0:
        raise ValueError(f"Frames height and width must be positive, but get height: {first_h} and width: {first_w}")

    baseline_width = target_h * 16 // 9
    max_pixels = target_h * baseline_width
    original_pixels = first_h * first_w
    if original_pixels <= max_pixels:
        msg = f"Original pixels {original_pixels} <= limit {max_pixels}, keep original sizes."
        logger.debug(msg)
        return frames

    msg = f"Original pixels {original_pixels} exceed limit {max_pixels}, need resizing."
    logger.debug(msg)
    # Uniform scale factor so that (H * k) * (W * k) ~= max_pixels.
    k = (float(max_pixels) / original_pixels) ** 0.5
    # Clamp to >= 1: for extreme aspect ratios int() can truncate one side to 0,
    # which cv2.resize rejects. The clamp may leave such frames above the budget.
    new_h = max(1, int(first_h * k))
    new_w = max(1, int(first_w * k))

    def process_single_frame(img: np.ndarray) -> np.ndarray:
        # cv2.resize takes dsize as (width, height).
        resized = cv2.resize(into_u8_frames(img), (new_w, new_h), interpolation=cv2.INTER_AREA)
        # cv2 drops the channel axis for single-channel images; restore it so
        # the stacked result stays (N, NEW_H, NEW_W, C).
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        return resized

    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_single_frame, frames))
    return np.stack(results, axis=0)
def process_video(
    video_path: Path,
    max_frames_num: int,
    fps: float = 1.0,
    force_sample: bool = False,
    decord_workers: int = 0,
    resolution: int = 720,
) -> FrameExtraction:
    """
    Extract frames from a video.

    The video is opened with decord on the CPU. Frames are read at the
    indices chosen by ``compute_frame_idx``, cast to uint8 if needed, and
    passed through ``smart_resize_batch`` with ``resolution`` as the target
    height.

    Args:
        video_path: Path to video file.
        max_frames_num: Maximum number of frames sampled from video.
        fps: Target frames per second for extraction.
        force_sample: If True, sample exactly ``max_frames_num`` frames
            uniformly instead of stepping at ``fps``.
        decord_workers: Number of decord decoding threads.
        resolution: Target height for resizing frames.

    Returns:
        FrameExtraction with the resized frames, per-frame timestamps
        (frame index / average fps), the average fps, the video duration
        (frame count / average fps) and the total frame count.

    Raises:
        ValueError: If decord reports a non-positive (or NaN) average fps,
            which would otherwise fail later as a division by zero.
    """
    msg = f"Extraction frames from {video_path} with FPS={fps}."
    logger.debug(msg)
    vr = VideoReader(video_path.resolve().as_posix(), ctx=cpu(), num_threads=decord_workers)
    total_frame_num: int = len(vr)
    avg_fps: float = vr.get_avg_fps()
    # `not x > 0` also rejects NaN, which `x <= 0` would let through.
    if not avg_fps > 0:
        raise ValueError(f"Video {video_path} reports an invalid average fps: {avg_fps}")
    video_time: float = total_frame_num / avg_fps

    frame_idx = compute_frame_idx(force_sample, avg_fps, total_frame_num, fps, max_frames_num)
    # copy=False: no copy when decord already returned uint8.
    sampled_frames = vr.get_batch(frame_idx).asnumpy().astype(np.uint8, copy=False)

    return FrameExtraction(
        frames=smart_resize_batch(sampled_frames, resolution),
        frame_timestamps=[i / avg_fps for i in frame_idx],
        avg_fps=avg_fps,
        video_duration=video_time,
        total_frame_num=total_frame_num,
    )
@lru_cache
def compute_frame_idx(
    force_sample: bool,
    avg_fps: float,
    total_frame_num: int,
    sample_fps: float,
    max_frame_num: int,
) -> List[int]:
    """
    Compute frame indices with fps and frame num.

    1. If ``force_sample``, return ``max_frame_num`` indices evenly spaced
       over ``[0, total_frame_num - 1]``. Indices repeat when
       ``max_frame_num`` exceeds the frame count.
    2. Otherwise, take every ``step``-th frame starting at 0, where
       ``step = max(1, round(avg_fps / sample_fps))``, keeping no more than
       ``max_frame_num`` indices.

    Results are memoised with ``functools.lru_cache``, so repeated calls with
    the same arguments return the same list object. Callers must not mutate
    the returned list.

    Args:
        force_sample: Force uniform sampling of ``max_frame_num`` frames.
        avg_fps: Average fps from raw video.
        total_frame_num: Total frame num from raw video.
        sample_fps: Target fps to sample frames.
        max_frame_num: Max frame num to sample frames.

    Returns:
        List of frame indices. It is empty when the video has no frames.
    """
    # An empty video has no valid index. Without this guard the uniform branch
    # would produce bogus indices such as [0, 0, -1] from linspace(0, -1, n).
    if total_frame_num <= 0:
        return []

    if force_sample:
        uniform_sampled_frames: np.ndarray = np.linspace(0, total_frame_num - 1, max_frame_num, dtype=int)
        return uniform_sampled_frames.tolist()

    step: int = max(1, round(avg_fps / sample_fps))
    # Stop early so the range never yields more than max_frame_num indices.
    limit_index = min(total_frame_num, max_frame_num * step)
    frame_idx = list(range(0, limit_index, step))
    return frame_idx[:max_frame_num]