import cv2
import numpy as np
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QMutex, QWaitCondition
from ultralytics import YOLO
class Detector:
    """Object-detection wrapper around a YOLO model, designed to work with a tracker.

    Detects persons and a configurable set of "vehicle" class IDs, optionally
    restricted to a rectangular region of interest (ROI), and draws results
    (raw detections or tracked objects) onto frames.
    """

    # Drawing colours as BGR tuples (OpenCV's default channel order).
    PERSON_COLOR = (0, 255, 0)
    VEHICLE_COLOR = (255, 0, 0)
    ROI_COLOR = (0, 255, 255)

    def __init__(self, model_path):
        self.model = self._load_model(model_path)
        self.conf = 0.5  # confidence threshold passed to the model
        self.iou = 0.5  # NMS IoU threshold passed to the model
        self.person_class = 0
        # NOTE(review): in the 80-class COCO label map 2/3/5/7 are
        # car/motorcycle/bus/truck but 18 is "sheep" — confirm 18 is intended
        # (or that the model uses a custom label map).
        self.vehicle_classes = {2, 3, 5, 7, 18}
        self.use_roi = False
        self.roi_rect = None  # (x1, y1, x2, y2) in full-frame pixel coordinates

    def _load_model(self, model_path):
        """Load the YOLO model.

        Raises:
            ValueError: if the model cannot be loaded (original error chained).
        """
        try:
            return YOLO(model_path)
        except Exception as e:
            raise ValueError(f"模型加载失败: {str(e)}") from e

    def set_parameters(self, conf, iou):
        """Update the confidence and NMS IoU thresholds used by the model."""
        self.conf = conf
        self.iou = iou

    def set_roi(self, use_roi, roi_rect=None):
        """Enable/disable the ROI and set its (x1, y1, x2, y2) rectangle."""
        self.use_roi = use_roi
        self.roi_rect = roi_rect

    def _roi_active(self):
        """Return True when detection and drawing should use ``roi_rect``."""
        # `is not None` + len() also works for numpy arrays, whose truth value is ambiguous.
        return bool(self.use_roi) and self.roi_rect is not None and len(self.roi_rect) > 0

    def _clamped_roi(self, frame_shape):
        """Clamp ``roi_rect`` to the frame and return integer (x1, y1, x2, y2).

        The result satisfies 0 <= x1 <= x2 <= width and 0 <= y1 <= y2 <= height,
        so it may describe an empty region.
        """
        height, width = frame_shape[:2]
        x1, y1, x2, y2 = (int(v) for v in self.roi_rect)
        x1 = max(0, min(x1, width))
        y1 = max(0, min(y1, height))
        x2 = max(x1, min(x2, width))
        y2 = max(y1, min(y2, height))
        return x1, y1, x2, y2

    @staticmethod
    def _empty_result():
        """Return the (detections, class_ids) pair used when nothing is found."""
        return np.empty((0, 5), dtype=np.float32), np.empty((0,), dtype=int)

    def detect_raw(self, frame):
        """Run the model on *frame* and return unprocessed detections for tracking.

        When the ROI is enabled, only the ROI crop is passed to the model and
        the boxes are shifted back into full-frame coordinates.

        Returns:
            (detections, class_ids): a float32 array of shape (N, 5) holding
            (x1, y1, x2, y2, score), and an int array of shape (N,). Both are
            empty with shapes (0, 5) and (0,) when nothing is detected.
        """
        offset_x, offset_y = 0, 0
        if self._roi_active():
            x1, y1, x2, y2 = self._clamped_roi(frame.shape)
            if x2 <= x1 or y2 <= y1:
                # The ROI lies outside the frame or has zero area, so there is nothing to run on.
                return self._empty_result()
            frame = frame[y1:y2, x1:x2]
            offset_x, offset_y = x1, y1

        results = self.model(
            frame,
            classes=[self.person_class] + list(self.vehicle_classes),
            conf=self.conf,
            iou=self.iou,
            stream=False,
        )

        # Keep persons as well as vehicles. The model is queried for both, and
        # the draw methods and the stats counters handle both.
        wanted = {self.person_class} | set(self.vehicle_classes)
        detections = []
        class_ids = []
        for result in results:
            boxes = result.boxes.xyxy.cpu().numpy()
            scores = result.boxes.conf.cpu().numpy()
            classes = result.boxes.cls.cpu().numpy().astype(int)
            for (bx1, by1, bx2, by2), score, cls_id in zip(boxes, scores, classes):
                if cls_id not in wanted:
                    continue
                detections.append(
                    [bx1 + offset_x, by1 + offset_y, bx2 + offset_x, by2 + offset_y, score]
                )
                class_ids.append(int(cls_id))

        if not detections:
            return self._empty_result()
        return np.asarray(detections, dtype=np.float32), np.asarray(class_ids, dtype=int)

    def _class_style(self, cls_id):
        """Return (name, BGR colour) for a drawable class ID, or None to skip it."""
        if cls_id == self.person_class:
            return "person", self.PERSON_COLOR
        if cls_id in self.vehicle_classes:
            return "vehicle", self.VEHICLE_COLOR
        return None

    @staticmethod
    def _label_origin(x1, y1):
        """Return the text origin 10 px above a box, kept inside the image at the top edge."""
        return x1, max(y1 - 10, 10)

    def _draw_roi(self, image):
        """Outline the configured ROI on *image* in place, if the ROI is enabled."""
        if not self._roi_active():
            return
        rx1, ry1, rx2, ry2 = (int(v) for v in self.roi_rect)
        cv2.rectangle(image, (rx1, ry1), (rx2, ry2), self.ROI_COLOR, 2)
        cv2.putText(image, "ROI", self._label_origin(rx1, ry1),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.ROI_COLOR, 2)

    def draw_detections(self, frame, detections, class_ids):
        """Draw raw detections on a copy of *frame*.

        Args:
            frame: image to annotate (left unmodified).
            detections: rows of (x1, y1, x2, y2, score).
            class_ids: class ID for each detection row.

        Returns:
            (annotated_frame, person_count, vehicle_count).
        """
        canvas = frame.copy()
        counts = {"person": 0, "vehicle": 0}
        for det, cls_id in zip(detections, class_ids):
            style = self._class_style(cls_id)
            if style is None:
                continue
            name, color = style
            counts[name] += 1
            x1, y1, x2, y2 = (int(v) for v in det[:4])
            score = det[4]
            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
            cv2.putText(canvas, f"{name}: {score:.2f}", self._label_origin(x1, y1),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        self._draw_roi(canvas)
        return canvas, counts["person"], counts["vehicle"]

    def draw_tracked_results(self, frame, tracked_results):
        """Draw tracked objects (with their track IDs) on a copy of *frame*.

        Args:
            frame: image to annotate (left unmodified).
            tracked_results: rows of (x1, y1, x2, y2, track_id, cls_id, score).
                This is assumed to be the tracker's output layout; confirm it
                against the tracker implementation.

        Returns:
            (annotated_frame, person_count, vehicle_count).
        """
        canvas = frame.copy()
        counts = {"person": 0, "vehicle": 0}
        for track in tracked_results:
            x1, y1, x2, y2, track_id, cls_id, _score = track
            style = self._class_style(cls_id)
            if style is None:
                continue
            name, color = style
            counts[name] += 1
            x1, y1, x2, y2 = (int(v) for v in (x1, y1, x2, y2))
            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
            cv2.putText(canvas, f"ID: {int(track_id)}", self._label_origin(x1, y1),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        self._draw_roi(canvas)
        return canvas, counts["person"], counts["vehicle"]
class InferenceThread(QThread):
    """Worker thread that runs detection, with optional multi-object tracking,
    on a single image or a video file and reports results through Qt signals."""

    update_frame_signal = pyqtSignal(object)  # annotated frame
    update_stats_signal = pyqtSignal(int, int)  # (person_count, vehicle_count)
    process_finished_signal = pyqtSignal()
    error_occurred_signal = pyqtSignal(str)
    # NOTE(review): not emitted by this class — confirm whether the GUI expects
    # per-frame position updates.
    frame_position_updated = pyqtSignal(int)

    FRAME_SIZE = (1280, 720)  # (width, height) passed to cv2.resize for every frame
    FRAME_INTERVAL_MS = 33  # sleep between video frames

    def __init__(self, path, detector, is_image=False, use_tracking=False, tracker=None):
        super().__init__()
        self.path = path
        self.detector = detector
        self.is_image = is_image
        self.use_tracking = use_tracking
        self.tracker = tracker
        self.running = False
        self.paused = False
        self.mutex = QMutex()  # guards the flags/fields shared with callers of the setters
        self.cond = QWaitCondition()  # wakes the worker when unpaused or stopped
        self.cap = None
        self.current_frame_pos = 0  # frame index the next video run starts from
        self.total_frames = 0  # filled from the capture when a video is opened

    def set_parameters(self, conf, iou):
        """Update the detector thresholds.

        The update is serialized with other setters via the mutex. The worker
        calls the detector outside the lock.
        """
        self.mutex.lock()
        try:
            if self.detector is not None:
                self.detector.set_parameters(conf, iou)
        finally:
            self.mutex.unlock()

    def set_roi(self, use_roi, roi_rect=None):
        """Set the ROI on the detector and, if present, mirror it onto the tracker."""
        self.mutex.lock()
        try:
            if self.detector is not None:
                self.detector.set_roi(use_roi, roi_rect)
            if self.tracker is not None:
                self.tracker.use_roi = use_roi
                self.tracker.roi_rect = roi_rect
        finally:
            self.mutex.unlock()

    def set_tracking(self, use_tracking, tracker=None):
        """Enable or disable tracking and set the tracker (None clears it).

        This writes the same attributes that `__init__`, the image path and the
        video loop read. The video loop picks up the change on the next frame.
        """
        self.mutex.lock()
        try:
            self.use_tracking = use_tracking
            self.tracker = tracker
        finally:
            self.mutex.unlock()

    def get_current_position(self):
        """Return the recorded video frame position."""
        self.mutex.lock()
        try:
            return self.current_frame_pos
        finally:
            self.mutex.unlock()

    def set_start_position(self, pos):
        """Set the frame index from which the next video run starts."""
        self.mutex.lock()
        try:
            self.current_frame_pos = pos
        finally:
            self.mutex.unlock()

    def run(self):
        """Thread entry point: process the image or video, then always clean up."""
        try:
            self.mutex.lock()
            try:
                self.running = True
                self.paused = False
            finally:
                self.mutex.unlock()
            if self.is_image:
                self._process_image()
            else:
                self._process_video()
            self.process_finished_signal.emit()
        except Exception as e:
            self.error_occurred_signal.emit(f"处理错误: {str(e)}")
        finally:
            self._cleanup()

    def _render(self, frame, detections, class_ids, tracker):
        """Draw detections, or tracked results when *tracker* is not None.

        Returns:
            (annotated_frame, person_count, vehicle_count).
        """
        if tracker is not None:
            tracked_results = tracker.update(detections, class_ids)
            return self.detector.draw_tracked_results(frame, tracked_results)
        return self.detector.draw_detections(frame, detections, class_ids)

    def _process_image(self):
        """Detect (and optionally track) on a single image file and emit the result."""
        frame = cv2.imread(self.path)
        if frame is None:
            self.error_occurred_signal.emit("无法加载图片文件")
            return
        frame = cv2.resize(frame, self.FRAME_SIZE)

        self.mutex.lock()
        try:
            tracker = self.tracker if self.use_tracking else None
        finally:
            self.mutex.unlock()

        detections, class_ids = self.detector.detect_raw(frame)
        if len(detections) == 0:
            # In a single image, no detections means nothing to track.
            tracker = None
        processed_frame, person_count, vehicle_count = self._render(
            frame, detections, class_ids, tracker)
        self.update_frame_signal.emit(processed_frame)
        self.update_stats_signal.emit(person_count, vehicle_count)

    def _process_video(self):
        """Read, detect, optionally track, and emit frames until EOF or stop().

        Starts from `current_frame_pos` and records the next frame index after
        each read, so a later run resumes where this one stopped. The position
        resets to 0 at end of file.
        """
        self.cap = cv2.VideoCapture(self.path)
        if not self.cap.isOpened():
            self.error_occurred_signal.emit("无法打开视频文件")
            return

        self.mutex.lock()
        try:
            start_pos = self.current_frame_pos
            self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            self.mutex.unlock()
        if start_pos > 0:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_pos)

        reached_end = False
        while True:
            self.mutex.lock()
            try:
                # Wait in a loop to guard against spurious wake-ups. stop() clears
                # `running`, which also releases a paused worker.
                while self.paused and self.running:
                    self.cond.wait(self.mutex)
                if not self.running:
                    break
                tracker = self.tracker if self.use_tracking else None
            finally:
                self.mutex.unlock()

            ret, frame = self.cap.read()
            if not ret:
                reached_end = True
                break

            # The position is recorded from the worker thread so that stop()
            # never touches the capture concurrently.
            self.mutex.lock()
            try:
                self.current_frame_pos = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
            finally:
                self.mutex.unlock()

            frame = cv2.resize(frame, self.FRAME_SIZE)
            detections, class_ids = self.detector.detect_raw(frame)
            processed_frame, p_count, v_count = self._render(
                frame, detections, class_ids, tracker)
            self.update_frame_signal.emit(processed_frame)
            self.update_stats_signal.emit(p_count, v_count)
            self.msleep(self.FRAME_INTERVAL_MS)

        if reached_end:
            # Playback finished, so the next run starts from the beginning again.
            self.mutex.lock()
            try:
                self.current_frame_pos = 0
            finally:
                self.mutex.unlock()
        self.cap.release()

    def _cleanup(self):
        """Release the capture if it is still open and reset the run flags."""
        self.mutex.lock()
        try:
            if self.cap is not None and self.cap.isOpened():
                self.cap.release()
            self.running = False
            self.paused = False
        finally:
            self.mutex.unlock()

    def pause(self):
        """Toggle pause. Resuming wakes the waiting worker."""
        self.mutex.lock()
        try:
            self.paused = not self.paused
            if not self.paused:
                self.cond.wakeAll()
        finally:
            self.mutex.unlock()

    def stop(self):
        """Ask the worker to stop and wait for it, terminating it after 1 s."""
        self.mutex.lock()
        try:
            self.running = False
            self.paused = False
            self.cond.wakeAll()
        finally:
            self.mutex.unlock()
        if not self.wait(1000):
            # NOTE(review): terminate() may kill the thread mid-inference or while
            # it holds the mutex. It is kept to match the original timeout behaviour.
            self.terminate()
            self.wait()