import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Optional, Tuple, TypeAlias
import cv2
import numpy as np
from vrag.logger import logger
from vrag.shared import into_u8_frames, once
from vrag.tools.np_cacher import CacherBase, get_cacher
# Perceptual-hash fingerprint of a single frame, represented as a plain Python int
# (produced by ``_get_frame_fingerprint`` and compared bitwise by ``_hamming_distance``).
FrameFingerPrint: TypeAlias = int
@dataclass
class ImageHasher:
    """Select visually distinct video frames by comparing per-frame perceptual hashes."""

    # Executor on which the per-frame fingerprint tasks are submitted.
    thread_pool: ThreadPoolExecutor
    # Cache backend whose ``cached_sync_with`` decorator wraps the hash computation.
    cacher: CacherBase

    @staticmethod
    @once
    def instance() -> "ImageHasher":
        """
        Return an ``ImageHasher`` built with the defaults of :meth:`new`.

        NOTE(review): ``once`` presumably memoizes the result so that this behaves as a
        shared singleton -- confirm against ``vrag.shared.once``.
        """
        return ImageHasher.new()

    @classmethod
    def new(cls, workers: Optional[int] = None, cap: int = 4096) -> "ImageHasher":
        """
        Build a hasher with a fresh thread pool and a cacher obtained from ``get_cacher``.

        Args:
            workers: ``max_workers`` for the thread pool; ``None`` uses the executor default.
            cap: Forwarded to ``get_cacher`` (presumably the cache capacity -- TODO confirm).
        """
        return cls(thread_pool=ThreadPoolExecutor(max_workers=workers), cacher=get_cacher(cap))

    @classmethod
    def with_cacher(cls, cacher: CacherBase, workers: Optional[int] = None) -> "ImageHasher":
        """
        Build a hasher with a fresh thread pool and a caller-supplied cacher.

        Args:
            cacher: Cache backend to use instead of the one returned by ``get_cacher``.
            workers: ``max_workers`` for the thread pool; ``None`` uses the executor default.
        """
        return cls(thread_pool=ThreadPoolExecutor(max_workers=workers), cacher=cacher)

    def get_unique_frame_indices(self, frames: np.ndarray, threshold: int = 5, block_size: int = 16) -> List[int]:
        """
        Identify indices of unique frames from a video stream array.

        Processes a 4D numpy array (N, H, W, C) and keeps every frame whose perceptual hash
        differs from the hash of its immediate predecessor by more than ``threshold`` bits.
        The comparison is always against the previous frame, not against the last kept one.
        The first frame is always kept.

        Args:
            frames: video frames.
            threshold: Hamming distance threshold for considering different.
            block_size: Grid size for fingerprint calculation.
        Returns:
            List[int]: Indices of frames selected to represent unique visual content, in
            ascending order. Empty input yields ``[]``; a single frame yields ``[0]``.
        Raises:
            ValueError: If ``frames`` is not a 4D array.
        """
        if frames.ndim != 4:
            raise ValueError(f"Frames expected 4D array (N, H, W, C), but get {frames.ndim}D")
        frames_num = frames.shape[0]
        if frames_num == 0:
            return []
        if frames_num == 1:
            # A lone frame has no predecessor to duplicate, so it is unique by definition.
            # No hashing is needed to decide that.
            return [0]

        frames_hashes = self._compute_hashes(frames, block_size)

        keep_indices = [0]
        for idx in range(1, frames_num):
            prev_hash = frames_hashes[idx - 1]
            curr_hash = frames_hashes[idx]
            # A missing hash on either side means the pair cannot be compared; the frame is dropped.
            if prev_hash is None or curr_hash is None:
                continue
            if _hamming_distance(prev_hash, curr_hash) > threshold:
                keep_indices.append(idx)

        msg = f"Discard duplicated frames: {frames_num - len(keep_indices)}, remaining {len(keep_indices)}"
        logger.debug(msg)
        return keep_indices

    async def get_unique_frame_indices_async(
        self, frames: np.ndarray, threshold: int = 5, block_size: int = 16
    ) -> List[int]:
        """Awaitable variant of :meth:`get_unique_frame_indices`, run through ``asyncio.to_thread``."""
        return await asyncio.to_thread(self.get_unique_frame_indices, frames, threshold, block_size)

    def _compute_hashes(self, frames: np.ndarray, block_size: int) -> List[Optional[int]]:
        """
        Compute one fingerprint per frame, routed through ``self.cacher``.

        NOTE(review): the exact keying is defined by ``CacherBase.cached_sync_with``; from
        here it looks like the frames array plus the ``block_size`` suffix -- confirm.
        """

        def _cache_suffix(block_size: int) -> str:
            return f"{block_size}"

        # Parameter names are kept as-is in case the cacher binds arguments by name.
        @self.cacher.cached_sync_with(_cache_suffix)
        def _compute(spare_frames: np.ndarray, block_size: int) -> List[Optional[int]]:
            return self._compute_hashes_inner(block_size, spare_frames)

        return _compute(frames, block_size)

    def _compute_hashes_inner(self, block_size: int, frames: np.ndarray) -> List[Optional[int]]:
        """
        Fingerprint every frame on the thread pool and return the hashes in frame order.

        Each task carries its own frame index, so results are written into the right slot.
        An exception raised inside a task propagates from ``future.result()``.
        """
        frames_num = frames.shape[0]
        futures = [
            self.thread_pool.submit(_compute_fingerprint_task, (i, frames[i], block_size))
            for i in range(frames_num)
        ]
        computed_hashes: List[Optional[int]] = [None] * frames_num
        for future in futures:
            frame_idx, frame_hash = future.result()
            computed_hashes[frame_idx] = frame_hash
        return computed_hashes
def _get_frame_fingerprint(frame: np.ndarray, block_size: int = 16) -> FrameFingerPrint:
    """
    Generate an integer fingerprint of a frame using an OpenCV DCT perceptual hash.

    The frame is converted to grayscale float32 and resized to a square of side
    ``block_size * 8``. It is then transformed with ``cv2.dct``. The top-left
    ``block_size x block_size`` coefficients are compared against their median, and the
    resulting bits are read in row-major order with the first bit as the most significant.

    Args:
        frame: Input image. A 2D array is used as grayscale as-is; a 3D array is passed
            through ``into_u8_frames`` and converted with ``cv2.COLOR_BGR2GRAY``.
        block_size: Side length of the low-frequency coefficient grid that is hashed.
    Returns:
        int: Non-negative integer carrying ``block_size ** 2`` hash bits.
    Raises:
        ValueError: If ``frame`` is neither 2D nor 3D.
    """
    if frame.ndim == 3:
        gray = cv2.cvtColor(into_u8_frames(frame), cv2.COLOR_BGR2GRAY)
    elif frame.ndim == 2:
        gray = frame
    else:
        raise ValueError(f"Input frame ndim: {frame.ndim} is not valid")

    # cv2.dct operates on floating-point arrays, so normalise the dtype first.
    if gray.dtype != np.float32:
        gray = gray.astype(np.float32)

    target_size = block_size * 8
    resized = cv2.resize(gray, (target_size, target_size))
    low_freq = cv2.dct(resized)[:block_size, :block_size]
    bits = low_freq > np.median(low_freq)

    # packbits flattens the matrix and packs it MSB-first, zero-padding the final byte.
    # Shifting the padding back out gives exactly the integer spelled by the bit sequence.
    packed = np.packbits(bits)
    pad_bits = -bits.size % 8
    return int.from_bytes(packed.tobytes(), "big") >> pad_bits
def _compute_fingerprint_task(task_args: Tuple[int, np.ndarray, int]) -> Tuple[int, FrameFingerPrint]:
    """
    Thread-pool work item: fingerprint one frame and echo its index back.

    Args:
        task_args: ``(frame_index, frame, block_size)`` packed as a single tuple so that the
            whole task can be handed to ``Executor.submit`` as one argument.
    Returns:
        ``(frame_index, fingerprint)`` so the caller can place the result by index.
    """
    frame_index, frame, block_size = task_args
    return frame_index, _get_frame_fingerprint(frame, block_size=block_size)
def _hamming_distance(hash0: FrameFingerPrint, hash1: FrameFingerPrint) -> int:
"""Calculate bitwise difference between two integer."""
return bin(hash0 ^ hash1).count("1")