import json
from PIL import Image
import cv2
import numpy as np
from openai import OpenAI
from mm.core.frame_selector.frame_selector import KRangFrameSelector
from common import imgs_to_base64_list, create_messages, validate_qa_inputs, extract_frames_from_video, send_messages
REWRITE_SYSTEM_PROMPT = """<no think>. 你是一个视频处理助手。你的任务是分析用户查询并选择适当的视频采样策略,以严格的JSON格式输出结果。
## 采样策略决策:
- 当查询中提到特定对象、动作或内容时:使用基于关键词的采样,函数名为"query_keyframe"
- 当查询指定时间范围或均匀覆盖时:使用固定步长采样,函数名为"uniform_sample"
- 如果查询中同时包含对象和时间范围,优先使用带有时间参数的"query_keyframe"
## 查询解析指南:
- 对于"function_name",根据上述决策进行设置
- 在"params"中:
- "query":对于"query_keyframe",应包含一个重新分析的查询字符串。对于"uniform_sample",应设置为空字符串
- "start_time":根据你的查询,你认为答案可能出现在的时间区间的开始时间,解析为整数(秒);否则为null
- "end_time":根据你的查询,你认为答案可能出现在的时间区间的结束时间,解析为整数(秒);否则为null
- "sample_nums":设置采样帧数。对于"query_keyframe"使用16;对于"uniform_sample"使用24
## 输出格式:
严格返回以下JSON对象,不包含任何额外内容:
{
"function": {
"function_name": "字符串,'均匀采样'或'查询关键帧'",
"params": {
"query": "字符串,重新分析的查询字符串,用于query_keyframe,均匀采样时为空,使用中文输出。",
"start_time": "整数或null",
"end_time": "整数或null",
"sample_nums": "整数"
}
}
}
## 示例:
用户:"红色小轿车是什么时候出现和消失的?"
→ {
"function": {
"function_name": "query_keyframe",
"params": {
"query": "红色小轿车",
"start_time": null,
"end_time": null,
"sample_nums": 16
}
}
}
用户:"在30秒到120秒之间均匀采样"
→ {
"function": {
"function_name": "uniform_sample",
"params": {
"query": "",
"start_time": 30,
"end_time": 120,
"sample_nums": 24
}
}
}
"""
QA_SYSTEM_PROMPT = """你是一个严格遵守输出格式的视频分析工具。请根据从视频里选择的和问题相关的帧,严格按照以下要求来回答用户的问题。
## 核心原则
1. 时间来源唯一性:所有时间信息必须且只能基于"帧索引-时间映射表"。禁止使用视频内显示的任何时间戳、时钟或字幕时间。
2. 基于实际内容:仅分析提供的视频帧描述内容,不推测未显示的内容。
3. 严格匹配用户问题:只分析与用户问题直接相关的事件或物体。如果视频中存在相似但不是问题所指的物体,不应进行分析。
4. 全面分析:仔细分析所有帧,找到所有和问题相关的帧。
## 帧索引-时间映射表(唯一时间源)
映射表格式:第X帧: mm:ss
映射表内容:
{frame_indices_times}
## 时间处理规范
- 格式要求:所有时间必须格式化为"mm:ss"(分钟和秒不足两位时补零)
- 起止确定:
- 开始时间:物体/事件首次可辨识(即使模糊)的帧对应时间
- 结束时间:物体/事件最后一次明确可见的帧对应时间
## 输出格式
- 必须输出有效的JSON数组,无任何额外文本
- 每个对象包含且仅包含:
{{
"start_time": "mm:ss",
"end_time": "mm:ss"
}}
- 找到匹配事件:按时间顺序返回所有时间片段
- 未找到事件:返回空数组 []
"""
class RangeDetectionQaDemo:
def __init__(
self,
model_path: str,
device_list: list,
similar_threshold: float = 0.03,
similar_threshold_image: float = 0.015,
vlm_url: str = None,
api_key: str = "NONE",
vlm_model_name: str = None,
model_type: str = 'cn_clip',
):
self.frame_selector = KRangFrameSelector(
model_path, device_list[0], model_type, similar_threshold, similar_threshold_image
)
self.vlm_client = OpenAI(base_url=vlm_url, api_key=api_key)
self.vlm_model_name = vlm_model_name
self._frame_cache = {}
self._supported_functions = {
'uniform_sample': self._execute_uniform_sample,
'query_keyframe': self._execute_query_keyframe,
}
@staticmethod
def _get_frame_times(frame_indices):
return [f"第{i}帧:{idx // 60:02d}:{idx % 60:02d}" for i, idx in enumerate(frame_indices, 1)]
@staticmethod
def _make_time_valid(start_time, end_time, frame_len):
start_time = start_time if start_time is not None else 0
end_time = min(end_time if end_time is not None else frame_len - 1, frame_len - 1)
if start_time == end_time:
start_time, end_time = 0, frame_len - 1
return int(start_time), int(end_time)
@staticmethod
def _uniform_sample(start_time, end_time, frames, sample_num):
frame_indices = np.linspace(start_time, end_time, sample_num, dtype=int)
return frame_indices, [frames[i] for i in frame_indices]
def qa(self, query: str, video_path: str, sample_num: int):
validate_qa_inputs(video_path, query, sample_num)
frames, fps, total_frames = extract_frames_from_video(video_path, cache=self._frame_cache)
sorted_frame_indices, selected_frames = self._select_frames(query, frames, fps, total_frames, sample_num)
frame_indices_times = self._get_frame_times(sorted_frame_indices)
if len(selected_frames) >= 10:
selected_frames = [cv2.resize(frame, (384, 384)) for frame in selected_frames]
selected_frames = [Image.fromarray(frame) for frame in selected_frames]
system_prompt = {
"role": "system",
"content": QA_SYSTEM_PROMPT.format(
frame_indices_times="\n".join(f"- {item}" for item in frame_indices_times)
),
}
q = (
f"仔细分析视频内容,然后找到与问题相关的视频帧。在用户没有明确要求分析视频中出现的时间戳的情况下,"
f"禁止使用视频中出现的时间戳来作为问题的答案, 必须根据时间戳映射表来进行回答。你需要回答的问题如下:{query}"
)
base64_frames = imgs_to_base64_list(selected_frames)
user_messages = create_messages(q, base64_frames)
return send_messages(self.vlm_client, self.vlm_model_name, [system_prompt, user_messages])
def _select_frames(self, query, frames, fps, total_frames, sample_num):
rewrite_messages = [
{"role": "system", "content": REWRITE_SYSTEM_PROMPT},
{
"role": "user",
"content": f"根据视频的文字记录,提取与查询项相关的关键词和短语, 你需要提取的查询如下:{query}。"
f"考虑视频的上下文,包括其时长为{total_frames}和帧率为{fps},"
f"以识别出最显著的术语、实体和概念。优先考虑与查询项的相关性,并提供一份简洁的列表。",
},
]
new_query = send_messages(self.vlm_client, self.vlm_model_name, rewrite_messages, max_tokens=512)
return self._execute(new_query, frames, sample_num)
def _execute(self, model_response, frames, sample_num):
try:
rewrite_query = json.loads(model_response)
function_name = rewrite_query["function"]["function_name"]
params = rewrite_query["function"]["params"]
except Exception:
function_name, params = "uniform_sample", {}
handler = self._supported_functions.get(function_name, self._execute_uniform_sample)
return handler(frames, params, sample_num)
@staticmethod
def _execute_uniform_sample(frames, params, sample_num):
start_time = params.get("start_time", 0)
end_time = params.get("end_time", len(frames) - 1)
start_time, end_time = RangeDetectionQaDemo._make_time_valid(start_time, end_time, len(frames))
return RangeDetectionQaDemo._uniform_sample(start_time, end_time, frames, sample_num)
def _execute_query_keyframe(self, frames, params, sample_num):
query = params["query"]
start_time = params.get("start_time", 0)
end_time = params.get("end_time", len(frames) - 1)
start_time, end_time = self._make_time_valid(start_time, end_time, len(frames))
return self.frame_selector.select_keyframes(query, frames[int(start_time) : int(end_time)], sample_num)
def k_range_frame_selector_example():
demo = RangeDetectionQaDemo(
model_path="/path/to/chinese-clip-vit-large-patch14-336px",
device_list=[0],
vlm_url="http://127.0.0.1:8111/v1",
api_key="None",
vlm_model_name="Qwen2.5-VL-32B-Instruct",
model_type='cn_clip',
)
video_path = "/path/to/test_video.mp4"
query = "红色小轿车是什么时候出现和消失的"
sample_num = 16
answer = demo.qa(query, video_path, sample_num)
print(f"查询: {query}")
print(f"回答: {answer}")
if __name__ == "__main__":
k_range_frame_selector_example()