from typing import List, Optional
import bentoml
import numpy as np
from pydantic import Field
from vrag.bentos.aks import AksBlipConfig, AksBlipArgs, AksBlipService
from vrag.bentos.mmdino import MMDinoArgs, MMDinoConfig, MMDINOService
from vrag.logger import logger
from vrag.shared import first_available, ConfigBase, retry_async_request, vrag_service
from vrag.tools.imagehash import ImageHasher
from vrag.tools.np_cacher import get_cacher
from vrag.tools.query import Query
from vrag.tools.scene import SceneDescriber
from vrag.tools.selecters import indexed
from vrag.types import DETDoc, DetectionResult, MMDINODetectionBatchResult, FrameExtraction
class DetectionArgs(AksBlipArgs, MMDinoArgs):
    """Service-level arguments for ``DetectionService``.

    Extends the AKS/BLIP and MM-DINO argument models with detection options.
    Each ``default_<name>`` field is the fallback that
    ``DetectionServiceConfig.merge_config`` uses for the request field
    ``<name>`` when a request leaves it unset (``None``).
    """

    # Passed to get_cacher() when DetectionService is constructed.
    detection_cache_size: int = Field(4096, ge=0)
    """LRU cache capacity for detection service results."""
    # Master switch; DetectionService.detect reads the merged value as `use_det`.
    default_use_det: bool = True
    """Whether to use object detection by default."""
    # --- Frame de-duplication (perceptual hashing) before keyframe selection ---
    default_det_dedup_frames: bool = Field(True)
    """Whether to deduplicate frames before detection by default."""
    default_det_dedup_threshold: int = Field(2, ge=0)
    """Default Hamming distance threshold for frame deduplication before detection."""
    default_det_dedup_block_size: int = Field(12, ge=8)
    """Default block size for perceptual hashing in frame deduplication."""
    # --- Which parts of the scene-graph description to generate ---
    default_det_location: bool = Field(True)
    """Whether to include location descriptions in detection results by default."""
    default_det_relation: bool = Field(True)
    """Whether to include spatial relation descriptions in detection results by default."""
    default_det_number: bool = Field(True)
    """Whether to include object count descriptions in detection results by default."""
    # When True, detect() returns selected keyframes without running MM-DINO.
    default_retrieve_frame_only: bool = Field(True)
    """Whether to only retrieve keyframes without generating scene descriptions by default."""
class DetectionServiceConfig(ConfigBase):
    """Per-request overrides for ``DetectionService``.

    Every field is optional. ``None`` means "use the service default", and
    ``merge_config`` fills those defaults in from the module-level ``args``.
    """

    aks: Optional[AksBlipConfig] = None
    mmdino: Optional[MMDinoConfig] = None
    use_det: Optional[bool] = None
    det_dedup_frames: Optional[bool] = None
    det_dedup_threshold: Optional[int] = None
    det_dedup_block_size: Optional[int] = None
    det_location: Optional[bool] = None
    det_relation: Optional[bool] = None
    det_number: Optional[bool] = None
    retrieve_frame_only: Optional[bool] = None

    @staticmethod
    def merge_config(config: Optional["DetectionServiceConfig"]) -> "DetectionServiceConfig":
        """Build a fully-populated config from ``config`` and the service defaults.

        Args:
            config: Caller overrides, or ``None`` to take every default.

        Returns:
            A new ``DetectionServiceConfig``. The nested ``aks``/``mmdino``
            configs are resolved by their own ``merge_config``. Each scalar
            field is taken from ``config`` via ``first_available`` against the
            matching ``args.default_*`` value. When ``config`` is ``None``, the
            defaults are used as-is.
        """
        source_aks = None if config is None else config.aks
        source_mmdino = None if config is None else config.mmdino
        nested_aks = AksBlipConfig.merge_config(source_aks)
        nested_mmdino = MMDinoConfig.merge_config(source_mmdino)

        # Scalar field name -> service-level fallback value.
        fallbacks = {
            "use_det": args.default_use_det,
            "det_dedup_frames": args.default_det_dedup_frames,
            "det_dedup_threshold": args.default_det_dedup_threshold,
            "det_dedup_block_size": args.default_det_dedup_block_size,
            "det_location": args.default_det_location,
            "det_relation": args.default_det_relation,
            "det_number": args.default_det_number,
            "retrieve_frame_only": args.default_retrieve_frame_only,
        }
        if config is None:
            scalars = fallbacks
        else:
            scalars = {
                field_name: first_available(getattr(config, field_name), fallback)
                for field_name, fallback in fallbacks.items()
            }
        return DetectionServiceConfig(aks=nested_aks, mmdino=nested_mmdino, **scalars)
# Module-level DetectionArgs instance obtained through bentoml.use_arguments();
# presumably populated from the arguments the bento is served with — confirm.
# Evaluated at import time. It is read by @vrag_service(args) below, and at call
# time by DetectionServiceConfig.merge_config and DetectionService.__init__.
args = bentoml.use_arguments(DetectionArgs).override()
@vrag_service(args)
class DetectionService:
    """Select query-relevant keyframes and optionally describe them.

    Per request (see ``detect``):
    1. Merge the per-request config with the service defaults.
    2. Optionally drop near-duplicate frames via perceptual hashing.
    3. Pick keyframes with the AKS/BLIP service.
    4. Unless only keyframes were requested, run MM-DINO detection on the
       selected frames and render each result as scene-graph text.
    """

    aks = bentoml.depends(AksBlipService)
    mmdino = bentoml.depends(MMDINOService)

    def __init__(self) -> None:
        # One cacher backs both the frame hasher and the description memo.
        self._cacher = get_cacher(args.detection_cache_size)
        self._image_hasher = ImageHasher.with_cacher(self._cacher)
        logger.info("DetectionService initialized.")

    @bentoml.api
    async def detect(
        self, query: Query, frame_extraction: FrameExtraction, config: Optional[DetectionServiceConfig] = None
    ) -> DetectionResult:
        """Select keyframes relevant to ``query`` and optionally describe them.

        Args:
            query: Parsed query. Its scene description (falling back to the
                filtered targets) drives keyframe selection. ``query.det``, when
                set with a scene occurrence count, overrides the frame count and
                the location/relation/number options.
            frame_extraction: Frames and their timestamps.
            config: Optional per-request overrides. Unset fields take the
                service defaults.

        Returns:
            A ``DetectionResult``. ``det_top_idx`` is passed through
            ``indexed(dedup_indices, ...)``, presumably mapping it back to
            pre-dedup frame positions. ``det_docs`` pairs each description with
            its frame timestamp and is only filled when descriptions are
            generated. The result is empty when detection is disabled, the query
            has nothing to search for, or no keyframe is selected.
        """
        merged_config = DetectionServiceConfig.merge_config(config)
        scene_desc = query.access_scene_desc or query.access_filtered_targets
        # Skip when detection is switched off OR there is nothing to look for.
        # With the previous `and`, `use_det=False` was ignored whenever a scene
        # description existed. An empty query still paid for de-duplication
        # before ending in an empty result.
        if not merged_config.use_det or not scene_desc:
            return DetectionResult()

        if query.det and query.det.scene_occurrence_count:
            msg = f"Only select {query.det.scene_occurrence_count} key frames as the query specified."
            logger.info(msg)
            merged_config.aks.target_frame_count = query.det.scene_occurrence_count
            # NOTE(review): num/rel/loc overrides only apply when a frame count is
            # also given — confirm this coupling is intended.
            merged_config.det_number = query.det.num
            merged_config.det_relation = query.det.rel
            merged_config.det_location = query.det.loc

        # Identity mapping unless de-duplication narrows the frame set.
        dedup_indices = list(range(len(frame_extraction.frame_timestamps)))
        if merged_config.det_dedup_frames:
            msg = (
                f"Applying deduplication: threshold={merged_config.det_dedup_threshold}, "
                f"block_size={merged_config.det_dedup_block_size}."
            )
            logger.debug(msg)
            dedup_indices = await self._image_hasher.get_unique_frame_indices_async(
                frame_extraction.frames,
                merged_config.det_dedup_threshold,
                merged_config.det_dedup_block_size,
            )
            frame_extraction = frame_extraction.slice(dedup_indices)

        # scene_desc is guaranteed non-empty here (checked above).
        det_top_idx = await retry_async_request(
            lambda: self.aks.select_keyframes(
                frames=frame_extraction.frames, queries=scene_desc, config=merged_config.aks
            ),
            "detection_aks_sample",
        )
        msg = f"Selected {len(det_top_idx)} keyframes: {det_top_idx}"
        logger.debug(msg)
        if not det_top_idx:
            logger.warning("Not select any keyframes, returning empty result.")
            return DetectionResult()

        selected_frames = [frame_extraction.frames[i] for i in det_top_idx]
        timestamps = [frame_extraction.frame_timestamps[i] for i in det_top_idx]
        det_docs: List[DETDoc] = []
        if not merged_config.retrieve_frame_only and query.access_filtered_targets:
            results: List[str] = await self._describe_scenes_inner(
                frames=selected_frames, prompt=query.access_filtered_targets, config=merged_config
            )
            det_docs = list(zip(results, timestamps, strict=True))
        return DetectionResult(det_docs=det_docs, det_top_idx=indexed(dedup_indices, det_top_idx))

    async def _describe_scenes_inner(
        self, frames: List[np.ndarray], prompt: List[str], config: Optional[DetectionServiceConfig] = None
    ) -> List[str]:
        """Run MM-DINO on ``frames`` and return one scene description per result.

        Args:
            frames: Frames to describe. The width/height used for the
                descriptions are read from ``frames[0].shape``.
            prompt: Detection prompts forwarded to ``ov_detect``.
            config: Merged service config. If ``None``, the service defaults
                are used.

        Returns:
            One description string per entry in the detector's ``results``.
            Returns ``[]`` when ``frames`` is empty.
        """
        if len(frames) == 0:
            return []
        if config is None:
            # The signature allows None, but everything below dereferences it.
            config = DetectionServiceConfig.merge_config(None)
        key = _get_cache_key(
            prompt, config.det_location, config.det_relation, config.det_number, config.mmdino.mmdino_threshold
        )

        # NOTE(review): the key function ignores its arguments. The memo key
        # covers prompt, description options and the mmdino threshold, but not
        # `frames` itself. Presumably the np cacher mixes array contents into the
        # key — confirm; otherwise different clips with the same prompt would
        # share cached descriptions.
        @self._cacher.cached_with(lambda *_, **__: key)
        async def _describe(frames: List[np.ndarray], location: bool, relation: bool, number: bool) -> List[str]:
            detection_results: MMDINODetectionBatchResult = await retry_async_request(
                lambda: self.mmdino.ov_detect(frames, prompt, config.mmdino), "detection_mmdino_detect"
            )
            frame_height, frame_width = frames[0].shape[:2]
            return [
                SceneDescriber.from_detection_result(
                    det_result, frame_width, frame_height
                ).generate_scene_graph_description(location_desc=location, relation_desc=relation, number_desc=number)
                for det_result in detection_results.results
            ]

        return await _describe(frames, config.det_location, config.det_relation, config.det_number)
def _get_cache_key(prompts: List[str], location: bool, relation: bool, number: bool, threshold: float, *_, **__):
return f"{''.join(sorted(prompts))}{location}{relation}{number}{str(threshold)}"