import sys
import cv2
import numpy as np
from PyQt5.QtWidgets import (QApplication, QMainWindow, QLabel, QPushButton,
QVBoxLayout, QHBoxLayout, QWidget, QFileDialog,
QGridLayout, QGroupBox, QSlider, QMessageBox, QCheckBox)
from PyQt5.QtCore import Qt, QPoint, pyqtSignal
from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor
from dt_backend import Detector, InferenceThread
from tracker import ByteTrackHandler
import os
# Presumably works around the Intel OpenMP "duplicate runtime" abort (several
# native libraries each bundling their own OpenMP runtime).
# NOTE(review): this runs *after* cv2 / numpy / dt_backend are imported above;
# it only helps if the runtime reads the variable later than that — confirm,
# or move it above those imports.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")
class ROIDisplayLabel(QLabel):
    """Display label with mouse-driven ROI (region of interest) drawing.

    While ``draw_mode`` is True:
    - pressing the left button starts a rubber-band rectangle;
    - moving the mouse resizes it;
    - releasing the button finalizes it.

    A finalized rectangle is stored in ``roi_rect`` and emitted through
    ``roi_selected`` as ``(x1, y1, x2, y2)`` with ``x1 <= x2`` and ``y1 <= y2``.
    It stays painted on the label until ``roi_rect`` is reset to None.

    NOTE(review): the coordinates are *label widget* pixels, not source-image
    pixels. The pixmap shown in the label is scaled/centered by the caller,
    so consumers that filter detections may need to map them (``original_size``
    / ``displayed_rect`` look intended for that but are never populated
    here) — confirm what the backend expects.
    """

    # Emitted with the finalized ROI as an (x1, y1, x2, y2) tuple.
    roi_selected = pyqtSignal(tuple)

    # A drag narrower or shorter than this (widget pixels) is treated as a
    # stray click and does not replace the current ROI.
    MIN_ROI_SIZE = 2

    def __init__(self, *args, **kwargs):
        """Forward QLabel constructor arguments (e.g. initial text and/or parent)."""
        super().__init__(*args, **kwargs)
        self.is_drawing = False          # True between left press and release
        self.start_point = QPoint()      # drag anchor
        self.end_point = QPoint()        # current / final drag corner
        self.roi_rect = None             # finalized (x1, y1, x2, y2) or None
        self.temp_pixmap = None
        # Full-size pixmap of the last displayed frame; set by the owner window.
        self.base_pixmap = None
        self.draw_mode = False           # toggled by the owner window
        self.original_size = (0, 0)
        self.displayed_rect = None

    @staticmethod
    def _normalized_rect(p1, p2):
        """Return (x1, y1, x2, y2) spanning two QPoints with x1 <= x2 and y1 <= y2."""
        return (min(p1.x(), p2.x()), min(p1.y(), p2.y()),
                max(p1.x(), p2.x()), max(p1.y(), p2.y()))

    def mousePressEvent(self, event):
        """Start a rubber-band ROI on a left click while in draw mode."""
        if self.draw_mode and event.button() == Qt.LeftButton:
            self.is_drawing = True
            self.start_point = event.pos()
            self.end_point = self.start_point
            self.update()

    def mouseMoveEvent(self, event):
        """Track the dragged corner and repaint the rubber band."""
        if self.is_drawing and self.draw_mode:
            self.end_point = event.pos()
            self.update()

    def mouseReleaseEvent(self, event):
        """Finalize the ROI on left release and emit ``roi_selected``.

        Degenerate rectangles (a plain click without a real drag) are ignored
        so they cannot replace a valid ROI with an empty region.
        """
        if self.is_drawing and self.draw_mode and event.button() == Qt.LeftButton:
            self.is_drawing = False
            self.end_point = event.pos()
            rect = self._normalized_rect(self.start_point, self.end_point)
            x1, y1, x2, y2 = rect
            if x2 - x1 >= self.MIN_ROI_SIZE and y2 - y1 >= self.MIN_ROI_SIZE:
                self.roi_rect = rect
                self.roi_selected.emit(rect)
            self.update()

    def paintEvent(self, event):
        """Paint the label, then overlay the in-progress or the finalized ROI.

        The in-progress rectangle is dashed; the finalized one is solid.
        """
        super().paintEvent(event)
        if self.draw_mode and self.is_drawing:
            rect = self._normalized_rect(self.start_point, self.end_point)
            line_style = Qt.DashLine
        elif self.roi_rect is not None:
            rect = self.roi_rect
            line_style = Qt.SolidLine
        else:
            return
        x1, y1, x2, y2 = rect
        painter = QPainter(self)
        try:
            painter.setPen(QPen(QColor(0, 255, 255), 2, line_style))
            painter.drawRect(x1, y1, x2 - x1, y2 - y1)
        finally:
            # End explicitly rather than relying on garbage collection.
            painter.end()
class MainWindow(QMainWindow):
    """Main window of the object detection / multi-object tracking GUI.

    The user picks a model file, loads an image or a video, tunes the
    detection and tracking thresholds, optionally draws an ROI, and starts
    processing. Inference runs on an ``InferenceThread`` (from ``dt_backend``).
    Its signals drive the display and status widgets. Tracking is delegated to
    a single ``ByteTrackHandler`` that is handed to the thread when enabled.
    """

    def __init__(self):
        super().__init__()
        # Path of the loaded image/video and whether it is a still image.
        self.resource_path = None
        self.is_image = False
        # Detector weights file; loaded by load_detector() below.
        self.model_path = "yolo11n.pt"
        self.inference_thread = None
        self.detector = None
        # Detection thresholds forwarded via set_parameters(confidence, iou).
        self.confidence = 0.5
        self.iou_threshold = 0.7
        # Tracking settings used to construct the ByteTrackHandler.
        self.use_tracking = False
        self.tracker = None
        self.track_thresh = 0.5
        self.track_buffer = 30
        # ROI state; current_roi is the (x1, y1, x2, y2) tuple from the display label.
        self.use_roi = False
        self.current_roi = None
        self.last_video_position = 0
        # Last frame position reported by the inference thread.
        self.current_frame_pos = 0
        self.init_ui()
        self.load_detector()
        self.init_tracker()

    def init_ui(self):
        """Build every widget and layout and connect their signals."""
        self.setWindowTitle("目标检测与多目标跟踪系统")
        self.setGeometry(100, 100, 1200, 800)
        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)

        # Frame display with ROI drawing; stretch factor 7 keeps it dominant.
        self.display_label = ROIDisplayLabel("请加载视频或图片")
        self.display_label.setAlignment(Qt.AlignCenter)
        self.display_label.setStyleSheet("background-color: #222; color: #aaa; font-size: 14px;")
        self.display_label.roi_selected.connect(self.on_roi_selected)
        main_layout.addWidget(self.display_label, 7)

        # Threshold sliders map the integer range 1..99 to 0.01..0.99.
        params_group = QGroupBox("检测与跟踪参数调节")
        params_layout = QGridLayout()
        self.conf_label = QLabel(f"置信度阈值: {self.confidence:.2f}")
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(1, 99)
        self.conf_slider.setValue(int(self.confidence * 100))
        self.conf_slider.valueChanged.connect(self.on_conf_changed)
        params_layout.addWidget(QLabel("置信度阈值:"), 0, 0)
        params_layout.addWidget(self.conf_slider, 0, 1)
        params_layout.addWidget(self.conf_label, 0, 2)
        self.track_thresh_label = QLabel(f"跟踪置信度阈值: {self.track_thresh:.2f}")
        self.track_thresh_slider = QSlider(Qt.Horizontal)
        self.track_thresh_slider.setRange(1, 99)
        self.track_thresh_slider.setValue(int(self.track_thresh * 100))
        self.track_thresh_slider.valueChanged.connect(self.on_track_thresh_changed)
        params_layout.addWidget(QLabel("跟踪置信度阈值:"), 1, 0)
        params_layout.addWidget(self.track_thresh_slider, 1, 1)
        params_layout.addWidget(self.track_thresh_label, 1, 2)
        params_group.setLayout(params_layout)
        main_layout.addWidget(params_group)

        # ROI and tracking controls.
        control_layout = QHBoxLayout()
        self.enable_roi_btn = QPushButton("启用ROI绘制")
        self.enable_roi_btn.clicked.connect(self.enable_roi_drawing)
        self.clear_roi_btn = QPushButton("清除ROI")
        self.clear_roi_btn.clicked.connect(self.clear_roi)
        self.clear_roi_btn.setEnabled(False)
        # A checkable push button (not a QCheckBox); clicked(bool) carries the state.
        self.use_roi_checkbox = QPushButton("使用ROI检测")
        self.use_roi_checkbox.setCheckable(True)
        self.use_roi_checkbox.clicked.connect(self.toggle_roi_usage)
        self.use_roi_checkbox.setEnabled(False)
        self.tracking_checkbox = QCheckBox("启用多目标跟踪")
        self.tracking_checkbox.stateChanged.connect(self.toggle_tracking)
        self.tracking_status_label = QLabel("跟踪状态: 未启用")
        self.tracking_status_label.setStyleSheet("color: #64748b;")
        control_layout.addWidget(self.enable_roi_btn)
        control_layout.addWidget(self.clear_roi_btn)
        control_layout.addWidget(self.use_roi_checkbox)
        control_layout.addWidget(self.tracking_checkbox)
        control_layout.addWidget(self.tracking_status_label)
        main_layout.addLayout(control_layout)

        # Model selection row.
        model_layout = QHBoxLayout()
        self.select_model_btn = QPushButton("选择模型文件")
        self.select_model_btn.clicked.connect(self.select_model)
        self.current_model_label = QLabel(f"当前模型: {os.path.basename(self.model_path)}")
        self.current_model_label.setStyleSheet("color: #333; font-style: italic;")
        model_layout.addWidget(self.select_model_btn)
        model_layout.addWidget(self.current_model_label, 1)
        main_layout.addLayout(model_layout)

        # Resource loading and run controls.
        load_layout = QHBoxLayout()
        self.load_image_btn = QPushButton("加载图片")
        self.load_image_btn.clicked.connect(lambda: self.load_resource(True))
        self.load_video_btn = QPushButton("加载视频")
        self.load_video_btn.clicked.connect(lambda: self.load_resource(False))
        self.start_btn = QPushButton("开始处理")
        self.start_btn.clicked.connect(self.start_processing)
        self.pause_btn = QPushButton("暂停")
        self.pause_btn.clicked.connect(self.pause_processing)
        self.pause_btn.setEnabled(False)
        self.stop_btn = QPushButton("停止")
        self.stop_btn.clicked.connect(self.stop_processing)
        self.stop_btn.setEnabled(False)
        load_layout.addWidget(self.load_image_btn)
        load_layout.addWidget(self.load_video_btn)
        load_layout.addWidget(self.start_btn)
        load_layout.addWidget(self.pause_btn)
        load_layout.addWidget(self.stop_btn)
        main_layout.addLayout(load_layout)

        self.status_label = QLabel("就绪")
        main_layout.addWidget(self.status_label)
        self.setCentralWidget(main_widget)

    def on_conf_changed(self):
        """Read the confidence slider and push the new threshold to the detector.

        The running inference thread is updated when there is one. Otherwise the
        detector itself is updated.
        """
        self.confidence = self.conf_slider.value() / 100.0
        self.conf_label.setText(f"置信度阈值: {self.confidence:.2f}")
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_parameters(self.confidence, self.iou_threshold)
        elif self.detector:
            self.detector.set_parameters(self.confidence, self.iou_threshold)

    def on_track_thresh_changed(self):
        """Read the tracking-threshold slider and apply it to the tracker."""
        self.track_thresh = self.track_thresh_slider.value() / 100.0
        self.track_thresh_label.setText(f"跟踪置信度阈值: {self.track_thresh:.2f}")
        if self.tracker:
            self.tracker.track_thresh = self.track_thresh

    def select_model(self):
        """Let the user pick a model file, then stop processing and reload the detector.

        Cancelling the dialog leaves the current model and any running
        processing untouched.
        """
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择模型文件", "", "PyTorch模型 (*.pt);;所有文件 (*)"
        )
        if not file_path:
            return
        # Don't swap the detector out from under a running inference thread.
        self.stop_processing()
        self.model_path = file_path
        self.current_model_label.setText(f"当前模型: {os.path.basename(file_path)}")
        self.load_detector()

    def load_detector(self):
        """(Re)create the Detector from ``self.model_path`` and report the outcome.

        On failure ``self.detector`` is reset to None, so start_processing
        refuses to run with a stale model. The error is shown in the status
        bar and in a dialog.
        """
        model_name = os.path.basename(self.model_path)
        try:
            self.status_label.setText(f"正在加载模型: {model_name}...")
            self.detector = Detector(self.model_path)
            self.status_label.setText(f"模型加载成功: {model_name}")
        except Exception as e:  # any backend failure is reported to the user
            self.detector = None
            error_msg = f"模型加载失败: {str(e)}"
            self.status_label.setText(error_msg)
            QMessageBox.critical(self, "模型错误", error_msg)

    def _apply_roi(self):
        """Push the current ROI state to the running inference thread and the tracker."""
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_roi(self.use_roi, self.current_roi)
        if self.tracker:
            self.tracker.set_roi(self.use_roi, self.current_roi)

    def on_roi_selected(self, roi_rect):
        """Store a newly drawn ROI and enable the ROI controls.

        The ROI is forwarded to the running thread as well as the tracker, so
        redrawing it mid-run takes effect immediately.
        NOTE(review): roi_rect is in display-label pixels — confirm the backend
        does not expect source-image coordinates.
        """
        self.current_roi = roi_rect
        self.clear_roi_btn.setEnabled(True)
        self.use_roi_checkbox.setEnabled(True)
        self._apply_roi()

    def toggle_roi_usage(self, checked):
        """Enable or disable ROI-restricted detection/tracking."""
        self.use_roi = checked
        self._apply_roi()

    def enable_roi_drawing(self):
        """Toggle ROI drawing mode on the display label and update the button text."""
        self.display_label.draw_mode = not self.display_label.draw_mode
        self.enable_roi_btn.setText("禁用ROI绘制" if self.display_label.draw_mode else "启用ROI绘制")

    def clear_roi(self):
        """Remove the ROI, disable ROI usage, and reset the tracker."""
        self.current_roi = None
        self.display_label.roi_rect = None
        self.display_label.update()
        self.clear_roi_btn.setEnabled(False)
        # setChecked does not emit clicked(), so ROI usage is turned off manually below.
        self.use_roi_checkbox.setChecked(False)
        # Without an ROI there is nothing to use.
        self.use_roi_checkbox.setEnabled(False)
        self.use_roi = False
        if self.tracker:
            self.tracker.reset()
        self._apply_roi()

    def _show_rgb_frame(self, frame_rgb):
        """Show an RGB uint8 array (h, w, 3) scaled to fit the display label.

        A full-size pixmap is kept in ``display_label.base_pixmap`` so that
        resizeEvent can rescale from the original instead of a scaled copy.
        """
        h, w, c = frame_rgb.shape
        qimg = QImage(frame_rgb.data, w, h, w * c, QImage.Format_RGB888)
        # fromImage copies the pixels, so the numpy buffer need not outlive this call.
        pixmap = QPixmap.fromImage(qimg)
        self.display_label.base_pixmap = pixmap
        self.display_label.setPixmap(pixmap.scaled(
            self.display_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))

    def load_resource(self, is_image):
        """Let the user pick an image (``is_image=True``) or a video and show it.

        The selection is committed only after a file was actually chosen (and,
        for images, successfully decoded). Cancelling therefore keeps
        ``resource_path`` and ``is_image`` consistent with each other.
        """
        file_filter = "图片文件 (*.jpg *.jpeg *.png *.bmp)" if is_image else "视频文件 (*.mp4 *.avi *.mov *.mkv)"
        file_path, _ = QFileDialog.getOpenFileName(self, "选择文件", "", file_filter)
        if not file_path:
            return
        frame = None
        if is_image:
            frame = cv2.imread(file_path)
            if frame is None:
                # cv2.imread returns None instead of raising on unreadable files.
                QMessageBox.warning(self, "错误", f"无法读取图片: {os.path.basename(file_path)}")
                return
        self.stop_processing()
        self.resource_path = file_path
        self.is_image = is_image
        self.status_label.setText(f"已加载: {os.path.basename(file_path)}")
        if is_image:
            self._show_rgb_frame(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        else:
            # No frame to rescale until processing produces one.
            self.display_label.base_pixmap = None
            self.display_label.setText(f"已加载视频: {os.path.basename(file_path)}")

    def init_tracker(self):
        """Create the ByteTrackHandler from the current tracking and ROI settings."""
        self.tracker = ByteTrackHandler(
            track_thresh=self.track_thresh,
            track_buffer=self.track_buffer,
            use_roi=self.use_roi,
            roi_rect=self.current_roi
        )

    def toggle_tracking(self, state):
        """Enable or disable tracking from the checkbox state.

        A running thread is updated immediately. Otherwise the setting applies
        on the next start.
        """
        self.use_tracking = (state == Qt.Checked)
        self.tracking_status_label.setText(f"跟踪状态: {'已启用' if self.use_tracking else '未启用'}")
        status = "已启用" if self.use_tracking else "已禁用"
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_tracking(
                self.use_tracking,
                self.tracker if self.use_tracking else None
            )
            self.status_label.setText(f"多目标跟踪{status}")
        else:
            self.status_label.setText(f"多目标跟踪{status}(需重新开始处理生效)")

    def start_processing(self):
        """Start a new InferenceThread on the loaded resource.

        The current tracking and ROI settings are applied to the thread before
        it starts. Any previous run is stopped first.
        """
        if not self.resource_path or not self.detector:
            QMessageBox.warning(self, "错误", "请先加载模型和资源")
            return
        self.stop_processing()
        self.inference_thread = InferenceThread(
            self.resource_path, self.detector, self.is_image
        )
        self.inference_thread.set_tracking(
            self.use_tracking,
            self.tracker if self.use_tracking else None
        )
        # Without this, an ROI enabled before starting was never sent to the thread.
        self.inference_thread.set_roi(self.use_roi, self.current_roi)
        self.inference_thread.update_frame_signal.connect(self.update_display)
        self.inference_thread.update_stats_signal.connect(self.update_stats)
        self.inference_thread.process_finished_signal.connect(self.process_finished)
        self.inference_thread.error_occurred_signal.connect(self.show_error)
        self.inference_thread.frame_position_updated.connect(self.update_frame_position)
        self.inference_thread.start()
        self.start_btn.setEnabled(False)
        # Pausing only makes sense for videos.
        self.pause_btn.setEnabled(not self.is_image)
        self.stop_btn.setEnabled(True)
        mode = "跟踪模式" if self.use_tracking else "检测模式"
        self.status_label.setText(f"开始处理 - {mode}")

    def update_display(self, frame):
        """Show a processed frame received from the inference thread.

        The frame is assumed to be a BGR (h, w, 3) array — confirm against
        dt_backend.
        """
        self._show_rgb_frame(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    def update_stats(self, person_count, vehicle_count):
        """Show per-frame person and vehicle counts in the status bar."""
        self.status_label.setText(f"行人: {person_count}, 车辆: {vehicle_count}")

    def pause_processing(self):
        """Toggle pause/resume on the running thread and relabel the button."""
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.pause()
            self.pause_btn.setText("恢复" if self.inference_thread.paused else "暂停")

    def stop_processing(self):
        """Ask the inference thread to stop (if running or paused) and reset the buttons."""
        if self.inference_thread and (self.inference_thread.isRunning() or self.inference_thread.paused):
            self.inference_thread.stop()
        self.start_btn.setEnabled(True)
        self.pause_btn.setEnabled(False)
        self.pause_btn.setText("暂停")
        self.stop_btn.setEnabled(False)

    def process_finished(self):
        """Reset the controls when the inference thread reports completion."""
        self.start_btn.setEnabled(True)
        self.pause_btn.setEnabled(False)
        self.pause_btn.setText("暂停")
        self.stop_btn.setEnabled(False)
        self.status_label.setText("处理完成")

    def show_error(self, message):
        """Report an error from the inference thread in the status bar and a dialog."""
        self.status_label.setText(f"错误: {message}")
        QMessageBox.warning(self, "处理错误", message)

    def update_frame_position(self, pos):
        """Record the latest frame position reported by the inference thread."""
        self.current_frame_pos = pos

    def resizeEvent(self, event):
        """Rescale the last shown frame from its full-size copy when the window resizes."""
        # resource_path is checked first so display_label is only touched after a load.
        if self.resource_path and self.display_label.base_pixmap is not None:
            self.display_label.setPixmap(self.display_label.base_pixmap.scaled(
                self.display_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
        super().resizeEvent(event)

    def closeEvent(self, event):
        """Stop the inference thread before the window closes.

        InferenceThread appears to be a QThread (it uses isRunning/start).
        Destroying a QThread that is still running crashes the process, so the
        wait below is bounded to avoid hanging the close.
        """
        self.stop_processing()
        if self.inference_thread is not None:
            self.inference_thread.wait(3000)
        super().closeEvent(event)
if __name__ == "__main__":
    # High-DPI scaling must be enabled before the QApplication is constructed.
    QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    # Propagate the event loop's return code as the process exit status.
    exit_code = application.exec_()
    sys.exit(exit_code)