from typing import List, Optional
import bentoml
import numpy as np
import torch
from PIL import Image
from pydantic import Field
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from vrag.logger import logger
from vrag.shared import ArgsBase, ConfigBase, first_available, into_u8_frames, vrag_service
from vrag.tools.np_cacher import get_cacher
from vrag.types import MMDINODetectionBatchResult, MMDINODetectionItem, MMDINODetectionResult
from vrag.tools.path_validator import validate_dir_exists
class MMDinoArgs(ArgsBase):
    """Service-wide MMDINO detector arguments (the per-request counterpart is MMDinoConfig)."""

    mmdino_model_path: str = ""
    """Local path to the MMDINO model directory."""
    mmdino_device: str = "npu:2"
    """Device for MMDINO model inference, e.g. 'npu:2' or 'cpu'."""
    mmdino_batch_size: int = Field(default=8, ge=1)
    """Batch size for MMDINO object detection inference."""
    mmdino_cache_size: int = Field(default=4096, ge=0)
    """LRU cache capacity for MMDINO detection results."""
    default_mmdino_threshold: float = Field(default=0.43, ge=0.0, le=1.0)
    """Default confidence threshold for object detection."""
class MMDinoConfig(ConfigBase):
    """Request-level MMDINO detection options; a field left as None means "use the service default"."""

    mmdino_threshold: Optional[float] = None

    @staticmethod
    def merge_config(config: Optional["MMDinoConfig"]) -> "MMDinoConfig":
        """Return a fresh, fully-resolved config for one request.

        When ``config`` is None the service default ``args.default_mmdino_threshold`` is
        used directly; otherwise the request's threshold is combined with that default via
        ``first_available`` (presumably picking the first non-None value — confirm in
        vrag.shared).
        """
        fallback_threshold = args.default_mmdino_threshold
        if config is not None:
            return MMDinoConfig(
                mmdino_threshold=first_available(config.mmdino_threshold, fallback_threshold)
            )
        return MMDinoConfig(mmdino_threshold=fallback_threshold)
# Module-level MMDINO arguments built from MMDinoArgs via bentoml.use_arguments.
# Read at call time by MMDinoConfig.merge_config and MMDINOService, and passed to
# @vrag_service when the service class is defined.
# NOTE(review): .override() is not defined in this file (presumably on ArgsBase) —
# confirm what it overrides (e.g. environment/CLI values).
args = bentoml.use_arguments(MMDinoArgs).override()
def _get_cache_key(prompt: List[str], threshold: float, *_, **__) -> str:
return "".join(sorted(prompt)) + str(threshold)
def _format_detection_prompt(prompts: List[str]) -> str:
return ". ".join(prompt.lower().strip() for prompt in prompts) + "."
@vrag_service(args)
class MMDINOService:
    """BentoML service wrapping an MMDINO zero-shot (open-vocabulary) object detector.

    The processor and model are loaded from ``args.mmdino_model_path`` onto
    ``args.mmdino_device``. Detection calls go through a cacher sized by
    ``args.mmdino_cache_size``.
    """

    def __init__(self):
        """Load the processor and model from the configured directory and build the cacher."""
        # validate_dir_exists presumably raises when the directory is missing — confirm.
        self.model_path = validate_dir_exists(args.mmdino_model_path, "MMDINO model")
        self.processor = AutoProcessor.from_pretrained(self.model_path, local_files_only=True)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            self.model_path, device_map=args.mmdino_device, local_files_only=True
        )
        self.model.eval()
        self._cacher = get_cacher(args.mmdino_cache_size)
        msg = f"MMDINOService initialized from {self.model_path}."
        logger.info(msg)

    @bentoml.api
    async def ov_detect(
        self, frames: List[np.ndarray], prompts: List[str], config: Optional[MMDinoConfig] = None
    ) -> MMDINODetectionBatchResult:
        """Detect the objects named in ``prompts`` in every frame.

        Args:
            frames: Images to run detection on.
            prompts: Object phrases, combined into a single text query.
            config: Optional per-request overrides (currently only the threshold). They
                are resolved against the service defaults via MMDinoConfig.merge_config.

        Returns:
            A batch result wrapping the output of the cached ``_detect_raw`` call, or an
            empty batch result when ``frames`` is empty.
        """
        # Nothing to detect: skip config merging and building the cached closure.
        if not frames:
            return MMDINODetectionBatchResult()
        merged_config = MMDinoConfig.merge_config(config)
        threshold = merged_config.mmdino_threshold
        key = _get_cache_key(prompts, threshold)

        # NOTE(review): the key function below ignores ``frames``. This relies on
        # cached_sync_with folding frame content into the cache key itself — confirm in
        # vrag.tools.np_cacher. Otherwise different frames with the same prompts and
        # threshold would share cached results.
        @self._cacher.cached_sync_with(lambda *_, **__: key)
        def _detect(frames: List[np.ndarray]) -> List[MMDINODetectionResult]:
            return self._detect_raw(frames, prompts, args.mmdino_batch_size, threshold)

        # NOTE(review): inference runs synchronously inside this coroutine, so it blocks
        # the event loop for the whole call. Consider offloading it once thread-safety of
        # the model and the cacher has been confirmed.
        return MMDINODetectionBatchResult(results=_detect(frames))

    def _detect_raw(
        self, frames: List[np.ndarray], prompts: List[str], batch_size: int, threshold: float = 0.35
    ) -> List[MMDINODetectionResult]:
        """Run uncached detection over ``frames`` in chunks of at most ``batch_size``.

        Args:
            frames: Input images. Each is passed through into_u8_frames and then
                Image.fromarray.
            prompts: Object phrases, formatted by _format_detection_prompt into one query
                that is shared by every image.
            batch_size: Maximum number of images per forward pass.
            threshold: Used as both the box and the text threshold in post-processing.

        Returns:
            One MMDINODetectionResult per input frame, in input order.
        """
        all_res: List[MMDINODetectionResult] = []
        text_prompt = _format_detection_prompt(prompts)
        for start in range(0, len(frames), batch_size):
            # Convert per batch so only one batch of PIL copies is alive at a time.
            batch_pil_images = [
                Image.fromarray(into_u8_frames(frame)) for frame in frames[start : start + batch_size]
            ]
            current_batch_size = len(batch_pil_images)
            msg = f"MMDINO detect batch [{start}:{start + current_batch_size}]"
            logger.debug(msg)
            inputs = self.processor(
                images=batch_pil_images, text=[text_prompt] * current_batch_size, return_tensors="pt"
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Rescale boxes to each image's original (height, width).
            target_sizes = [(img.height, img.width) for img in batch_pil_images]
            result_list = self.processor.post_process_grounded_object_detection(
                outputs, inputs.input_ids, threshold=threshold, text_threshold=threshold, target_sizes=target_sizes
            )
            all_res.extend(self._to_detection_result(result) for result in result_list)
        return all_res

    @staticmethod
    def _to_detection_result(result: dict) -> MMDINODetectionResult:
        """Convert one post-processed per-image result dict into an MMDINODetectionResult."""
        # Per the transformers GroundingDino docs, recent releases return the phrase
        # strings under "text_labels" and deprecate string values under "labels".
        # Fall back to "labels" for older releases.
        labels = result["text_labels"] if "text_labels" in result else result["labels"]
        return MMDINODetectionResult(
            items=[
                # Box coordinates are truncated to int, as before.
                MMDINODetectionItem(class_name=label, bbox=[int(c) for c in box.tolist()])
                for box, label in zip(result["boxes"], labels, strict=True)
            ]
        )