# 2d62036f — created 2025-08-30 (commit history)
import sys

import cv2

import numpy as np

from PyQt5.QtWidgets import (QApplication, QMainWindow, QLabel, QPushButton,

                             QVBoxLayout, QHBoxLayout, QWidget, QFileDialog,

                             QGridLayout, QGroupBox, QSlider, QMessageBox, QCheckBox)

from PyQt5.QtCore import Qt, QPoint, pyqtSignal

from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QColor

from dt_backend import Detector, InferenceThread

from tracker import ByteTrackHandler

import os

# Let the process continue when more than one OpenMP runtime gets loaded
# (the Intel OpenMP runtime otherwise aborts with "OMP: Error #15").
# This only silences the duplicate-runtime check; it does not reconcile them.
# NOTE(review): the flag is set after the cv2/numpy/backend imports above;
# confirm it still takes effect here, or move it before those imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"



class ROIDisplayLabel(QLabel):
    """Display label that lets the user drag out a rectangular ROI.

    While ``draw_mode`` is True, pressing the left mouse button and dragging
    paints a dashed cyan rubber band over the label. Releasing the button
    emits ``roi_selected`` with ``(x1, y1, x2, y2)``, where ``x1 < x2`` and
    ``y1 < y2``.

    Coordinate space of the emitted tuple:

    * When a pixmap is currently shown and the size of the unscaled frame is
      known (from ``base_pixmap`` or, failing that, a non-zero
      ``original_size``), the rectangle is mapped from widget pixels to
      pixels of the unscaled frame and clamped to its bounds.
    * Otherwise the rectangle is emitted in widget pixels.

    Selections that end up with zero width or height (a plain click, or a
    drag entirely outside the shown image) are discarded without emitting.
    """

    roi_selected = pyqtSignal(tuple)

    def __init__(self, parent=None):
        # QLabel's constructor accepts either a parent widget or the initial
        # label text as its first positional argument; both are forwarded.
        super().__init__(parent)
        self.is_drawing = False        # True between press and release of a drag
        self.start_point = QPoint()    # widget-space drag start
        self.end_point = QPoint()      # widget-space drag end / current position
        self.roi_rect = None           # last emitted ROI tuple
        self.temp_pixmap = None        # NOTE(review): not used by this class
        self.base_pixmap = None        # unscaled frame; assigned by the owner
        self.draw_mode = False         # ROI drawing on/off; toggled by the owner
        self.original_size = (0, 0)    # (width, height) of the unscaled frame, if known
        self.displayed_rect = None     # (x, y, w, h) of the shown pixmap at the last mapping

    def mousePressEvent(self, event):
        """Start a drag when drawing is enabled and the left button is pressed."""
        if self.draw_mode and event.button() == Qt.LeftButton:
            self.is_drawing = True
            self.start_point = event.pos()
            self.end_point = self.start_point
            self.update()

    def mouseMoveEvent(self, event):
        """Track the drag end point and repaint the rubber band."""
        if self.is_drawing and self.draw_mode:
            self.end_point = event.pos()
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the drag and emit the (possibly mapped) ROI rectangle."""
        if event.button() != Qt.LeftButton or not self.is_drawing:
            return
        # Always leave drawing state, even if draw_mode was switched off
        # mid-drag; otherwise is_drawing would stay stuck at True.
        self.is_drawing = False
        if not self.draw_mode:
            self.update()
            return

        self.end_point = event.pos()
        x1 = min(self.start_point.x(), self.end_point.x())
        y1 = min(self.start_point.y(), self.end_point.y())
        x2 = max(self.start_point.x(), self.end_point.x())
        y2 = max(self.start_point.y(), self.end_point.y())

        rect = self._to_frame_coords(x1, y1, x2, y2)
        # A click without a drag (or a drag outside the image) yields an empty
        # rectangle, which would be a meaningless ROI; ignore it.
        if rect[2] > rect[0] and rect[3] > rect[1]:
            self.roi_rect = rect
            self.roi_selected.emit(rect)
        self.update()

    def _to_frame_coords(self, x1, y1, x2, y2):
        """Map a widget-space rectangle onto the unscaled frame when possible.

        Returns the input unchanged when no pixmap is shown or the unscaled
        frame size is unknown.
        """
        unchanged = (x1, y1, x2, y2)
        shown = self.pixmap()
        if shown is None or shown.isNull() or shown.width() <= 0 or shown.height() <= 0:
            return unchanged

        if self.base_pixmap is not None and not self.base_pixmap.isNull():
            frame_w, frame_h = self.base_pixmap.width(), self.base_pixmap.height()
        else:
            frame_w, frame_h = self.original_size
        if frame_w <= 0 or frame_h <= 0:
            return unchanged

        # Reproduce where QLabel paints its pixmap: inside contentsRect()
        # shrunk by margin(), positioned by alignment().
        # NOTE(review): right-to-left layout mirroring is not handled.
        margin = self.margin()
        area = self.contentsRect().adjusted(margin, margin, -margin, -margin)
        flags = int(self.alignment())
        off_x, off_y = area.x(), area.y()
        if flags & int(Qt.AlignRight):
            off_x += area.width() - shown.width()
        elif flags & int(Qt.AlignHCenter):
            off_x += area.width() // 2 - shown.width() // 2
        if flags & int(Qt.AlignVCenter):
            off_y += area.height() // 2 - shown.height() // 2
        elif flags & int(Qt.AlignBottom):
            off_y += area.height() - shown.height()
        self.displayed_rect = (off_x, off_y, shown.width(), shown.height())

        scale_x = frame_w / shown.width()
        scale_y = frame_h / shown.height()

        def clamp(value, upper):
            return max(0, min(int(round(value)), upper))

        return (
            clamp((x1 - off_x) * scale_x, frame_w),
            clamp((y1 - off_y) * scale_y, frame_h),
            clamp((x2 - off_x) * scale_x, frame_w),
            clamp((y2 - off_y) * scale_y, frame_h),
        )

    def paintEvent(self, event):
        """Paint the label, then overlay the dashed rubber band during a drag."""
        super().paintEvent(event)
        # The band no longer requires base_pixmap: nothing guaranteed that it
        # was set, which made the band invisible.
        if not (self.draw_mode and self.is_drawing):
            return
        painter = QPainter(self)
        try:
            painter.setPen(QPen(QColor(0, 255, 255), 2, Qt.DashLine))
            painter.drawRect(
                min(self.start_point.x(), self.end_point.x()),
                min(self.start_point.y(), self.end_point.y()),
                abs(self.end_point.x() - self.start_point.x()),
                abs(self.end_point.y() - self.start_point.y()),
            )
        finally:
            painter.end()





class MainWindow(QMainWindow):
    """Main window of the object-detection / multi-object-tracking GUI.

    Owns the widgets, a ``Detector`` loaded from ``model_path``, one
    ``ByteTrackHandler`` and, while processing, an ``InferenceThread`` whose
    signals drive the display and the status label. Confidence, tracking
    threshold, ROI and tracking on/off are forwarded to the detector, the
    tracker and the running thread as the user changes them.
    """

    # Upper bound (ms) on how long stop_processing() blocks waiting for the
    # worker thread to exit before its Python reference is replaced.
    _STOP_WAIT_MS = 3000

    def __init__(self):
        super().__init__()
        self.resource_path = None      # path of the loaded image/video
        self.is_image = False          # True when resource_path is an image
        self.model_path = "yolo11n.pt"
        self.inference_thread = None
        self.detector = None
        self.confidence = 0.5          # default confidence threshold
        self.iou_threshold = 0.7       # default IoU threshold

        # Tracking state
        self.use_tracking = False
        self.tracker = None
        self.track_thresh = 0.5
        self.track_buffer = 30

        # ROI state
        self.use_roi = False
        self.current_roi = None
        self.last_video_position = 0
        self.current_frame_pos = 0     # last position reported by the worker thread

        self.init_ui()
        self.load_detector()
        self.init_tracker()

    def init_ui(self):
        """Build all widgets and layouts and connect their signals."""
        self.setWindowTitle("目标检测与多目标跟踪系统")
        self.setGeometry(100, 100, 1200, 800)

        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)

        # Display area
        self.display_label = ROIDisplayLabel("请加载视频或图片")
        self.display_label.setAlignment(Qt.AlignCenter)
        self.display_label.setStyleSheet("background-color: #222; color: #aaa; font-size: 14px;")
        # Without an explicit minimum, QLabel's minimum size hint follows the
        # pixmap it shows, so the window could never shrink again once a frame
        # had been scaled up to a large label.
        self.display_label.setMinimumSize(1, 1)
        self.display_label.roi_selected.connect(self.on_roi_selected)
        main_layout.addWidget(self.display_label, 7)

        # Parameter sliders
        params_group = QGroupBox("检测与跟踪参数调节")
        params_layout = QGridLayout()

        # Detection confidence threshold (slider is in percent)
        self.conf_label = QLabel(f"置信度阈值: {self.confidence:.2f}")
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(1, 99)
        self.conf_slider.setValue(int(self.confidence * 100))
        self.conf_slider.valueChanged.connect(self.on_conf_changed)
        params_layout.addWidget(QLabel("置信度阈值:"), 0, 0)
        params_layout.addWidget(self.conf_slider, 0, 1)
        params_layout.addWidget(self.conf_label, 0, 2)

        # Tracking confidence threshold (slider is in percent)
        self.track_thresh_label = QLabel(f"跟踪置信度阈值: {self.track_thresh:.2f}")
        self.track_thresh_slider = QSlider(Qt.Horizontal)
        self.track_thresh_slider.setRange(1, 99)
        self.track_thresh_slider.setValue(int(self.track_thresh * 100))
        self.track_thresh_slider.valueChanged.connect(self.on_track_thresh_changed)
        params_layout.addWidget(QLabel("跟踪置信度阈值:"), 1, 0)
        params_layout.addWidget(self.track_thresh_slider, 1, 1)
        params_layout.addWidget(self.track_thresh_label, 1, 2)

        params_group.setLayout(params_layout)
        main_layout.addWidget(params_group)

        # ROI and tracking controls
        control_layout = QHBoxLayout()

        self.enable_roi_btn = QPushButton("启用ROI绘制")
        self.enable_roi_btn.clicked.connect(self.enable_roi_drawing)
        self.clear_roi_btn = QPushButton("清除ROI")
        self.clear_roi_btn.clicked.connect(self.clear_roi)
        self.clear_roi_btn.setEnabled(False)
        # A checkable push button despite the attribute name.
        self.use_roi_checkbox = QPushButton("使用ROI检测")
        self.use_roi_checkbox.setCheckable(True)
        self.use_roi_checkbox.clicked.connect(self.toggle_roi_usage)
        self.use_roi_checkbox.setEnabled(False)

        self.tracking_checkbox = QCheckBox("启用多目标跟踪")
        self.tracking_checkbox.stateChanged.connect(self.toggle_tracking)
        self.tracking_status_label = QLabel("跟踪状态: 未启用")
        self.tracking_status_label.setStyleSheet("color: #64748b;")

        control_layout.addWidget(self.enable_roi_btn)
        control_layout.addWidget(self.clear_roi_btn)
        control_layout.addWidget(self.use_roi_checkbox)
        control_layout.addWidget(self.tracking_checkbox)
        control_layout.addWidget(self.tracking_status_label)
        main_layout.addLayout(control_layout)

        # Model selection
        model_layout = QHBoxLayout()
        self.select_model_btn = QPushButton("选择模型文件")
        self.select_model_btn.clicked.connect(self.select_model)
        self.current_model_label = QLabel(f"当前模型: {os.path.basename(self.model_path)}")
        self.current_model_label.setStyleSheet("color: #333; font-style: italic;")
        model_layout.addWidget(self.select_model_btn)
        model_layout.addWidget(self.current_model_label, 1)  # stretch so the name has room
        main_layout.addLayout(model_layout)

        # Resource loading and processing control
        load_layout = QHBoxLayout()
        self.load_image_btn = QPushButton("加载图片")
        self.load_image_btn.clicked.connect(lambda: self.load_resource(True))
        self.load_video_btn = QPushButton("加载视频")
        self.load_video_btn.clicked.connect(lambda: self.load_resource(False))
        self.start_btn = QPushButton("开始处理")
        self.start_btn.clicked.connect(self.start_processing)
        self.pause_btn = QPushButton("暂停")
        self.pause_btn.clicked.connect(self.pause_processing)
        self.pause_btn.setEnabled(False)
        self.stop_btn = QPushButton("停止")
        self.stop_btn.clicked.connect(self.stop_processing)
        self.stop_btn.setEnabled(False)

        load_layout.addWidget(self.load_image_btn)
        load_layout.addWidget(self.load_video_btn)
        load_layout.addWidget(self.start_btn)
        load_layout.addWidget(self.pause_btn)
        load_layout.addWidget(self.stop_btn)
        main_layout.addLayout(load_layout)

        # Status line
        self.status_label = QLabel("就绪")
        main_layout.addWidget(self.status_label)

        self.setCentralWidget(main_widget)

    def on_conf_changed(self):
        """Apply the confidence slider value (percent -> 0..1) to detection."""
        self.confidence = self.conf_slider.value() / 100.0
        self.conf_label.setText(f"置信度阈值: {self.confidence:.2f}")

        # While a run is active the thread receives the update; otherwise the
        # detector is updated directly.
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_parameters(self.confidence, self.iou_threshold)
        elif self.detector:
            self.detector.set_parameters(self.confidence, self.iou_threshold)

    def on_track_thresh_changed(self):
        """Apply the tracking-threshold slider value (percent -> 0..1) to the tracker."""
        self.track_thresh = self.track_thresh_slider.value() / 100.0
        self.track_thresh_label.setText(f"跟踪置信度阈值: {self.track_thresh:.2f}")

        if self.tracker:
            self.tracker.track_thresh = self.track_thresh

    def select_model(self):
        """Stop processing, let the user pick a model file, and reload the detector."""
        self.stop_processing()

        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择模型文件", "", "PyTorch模型 (*.pt);;所有文件 (*)"
        )
        if file_path:
            self.model_path = file_path
            self.current_model_label.setText(f"当前模型: {os.path.basename(file_path)}")
            self.load_detector()

    def load_detector(self):
        """(Re)load ``self.detector`` from ``self.model_path``.

        On success the current confidence/IoU thresholds are applied so the
        detector matches what the sliders show. On failure ``self.detector``
        is reset to None (so a stale model is never used under a new name)
        and an error dialog is shown.
        """
        model_name = os.path.basename(self.model_path)
        try:
            self.status_label.setText(f"正在加载模型: {model_name}...")
            self.detector = Detector(self.model_path)
            self.detector.set_parameters(self.confidence, self.iou_threshold)
            self.status_label.setText(f"模型加载成功: {model_name}")
        except Exception as e:  # UI boundary: report any load failure to the user
            self.detector = None
            error_msg = f"模型加载失败: {str(e)}"
            self.status_label.setText(error_msg)
            QMessageBox.critical(self, "模型错误", error_msg)

    def on_roi_selected(self, roi_rect):
        """Store a newly drawn ROI and push it to the tracker and running thread."""
        self.current_roi = roi_rect
        self.clear_roi_btn.setEnabled(True)
        self.use_roi_checkbox.setEnabled(True)
        # Keep a running thread in sync, as toggle_roi_usage does; otherwise a
        # redrawn ROI would only take effect on the next run.
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_roi(self.use_roi, roi_rect)
        if self.tracker:
            self.tracker.set_roi(self.use_roi, roi_rect)

    def toggle_roi_usage(self, checked):
        """Enable/disable ROI-restricted detection."""
        self.use_roi = checked
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_roi(checked, self.current_roi)
        if self.tracker:
            self.tracker.set_roi(checked, self.current_roi)

    def enable_roi_drawing(self):
        """Toggle the label's ROI drawing mode and update the button caption."""
        self.display_label.draw_mode = not self.display_label.draw_mode
        self.enable_roi_btn.setText("禁用ROI绘制" if self.display_label.draw_mode else "启用ROI绘制")

    def clear_roi(self):
        """Forget the ROI and turn ROI-restricted detection off everywhere."""
        self.current_roi = None
        self.display_label.roi_rect = None
        self.display_label.update()
        self.clear_roi_btn.setEnabled(False)
        self.use_roi_checkbox.setChecked(False)
        # With no ROI left there is nothing to enable, matching the initial state.
        self.use_roi_checkbox.setEnabled(False)
        self.use_roi = False

        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.set_roi(False, None)
        if self.tracker:
            self.tracker.reset()
            self.tracker.set_roi(False, None)

    @staticmethod
    def _read_image(path):
        """Decode an image file as a BGR array, or return None if unreadable.

        Reads the bytes with numpy and decodes them with OpenCV instead of
        ``cv2.imread``, which cannot open many non-ASCII (e.g. Chinese) paths
        on Windows and returns None for unreadable files.
        """
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    def _show_bgr_frame(self, frame_bgr):
        """Show a BGR frame scaled to the label and remember the unscaled copy.

        The unscaled pixmap is kept in ``display_label.base_pixmap`` so that
        resizeEvent can rescale it and the label can map ROI coordinates.
        """
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        h, w, c = frame_rgb.shape
        qimg = QImage(frame_rgb.data, w, h, w * c, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qimg)
        self.display_label.base_pixmap = pixmap
        self.display_label.original_size = (w, h)
        self.display_label.setPixmap(pixmap.scaled(
            self.display_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))

    def load_resource(self, is_image):
        """Let the user pick an image (``is_image``) or a video and preview it.

        Any running processing is stopped first. State (``resource_path`` /
        ``is_image``) only changes once a file was chosen and, for images,
        successfully decoded, so cancelling the dialog keeps the previous
        resource usable.
        """
        self.stop_processing()
        file_filter = "图片文件 (*.jpg *.jpeg *.png *.bmp)" if is_image else "视频文件 (*.mp4 *.avi *.mov *.mkv)"
        file_path, _ = QFileDialog.getOpenFileName(self, "选择文件", "", file_filter)
        if not file_path:
            return

        frame = None
        if is_image:
            frame = self._read_image(file_path)
            if frame is None:
                QMessageBox.warning(self, "错误", f"无法读取图片: {os.path.basename(file_path)}")
                return

        self.is_image = is_image
        self.resource_path = file_path
        self.status_label.setText(f"已加载: {os.path.basename(file_path)}")
        if is_image:
            self._show_bgr_frame(frame)
        else:
            # Drop any previous image so resizeEvent doesn't bring it back.
            self.display_label.base_pixmap = None
            self.display_label.setText(f"已加载视频: {os.path.basename(file_path)}")

    def init_tracker(self):
        """Create the ByteTrack handler from the current tracking/ROI settings."""
        self.tracker = ByteTrackHandler(
            track_thresh=self.track_thresh,
            track_buffer=self.track_buffer,
            use_roi=self.use_roi,
            roi_rect=self.current_roi,
        )

    def toggle_tracking(self, state):
        """Switch multi-object tracking on/off (``state`` from QCheckBox.stateChanged)."""
        self.use_tracking = (state == Qt.Checked)
        self.tracking_status_label.setText(f"跟踪状态: {'已启用' if self.use_tracking else '未启用'}")

        if self.inference_thread and self.inference_thread.isRunning():
            # Update the running thread immediately.
            self.inference_thread.set_tracking(
                self.use_tracking,
                self.tracker if self.use_tracking else None,
            )
            status = "已启用" if self.use_tracking else "已禁用"
            self.status_label.setText(f"多目标跟踪{status}")
        else:
            # Takes effect on the next start_processing().
            state_text = "已启用" if self.use_tracking else "已禁用"
            self.status_label.setText(f"多目标跟踪{state_text}(需重新开始处理生效)")

    def start_processing(self):
        """Start a new InferenceThread on the loaded resource with the current settings."""
        if not self.resource_path or not self.detector:
            QMessageBox.warning(self, "错误", "请先加载模型和资源")
            return

        self.stop_processing()
        self.inference_thread = InferenceThread(
            self.resource_path, self.detector, self.is_image
        )

        # Begin each run with a fresh tracker so tracks from a previous run
        # are not carried over. reset() is followed by set_roi() in case reset
        # also clears the ROI (NOTE(review): confirm in ByteTrackHandler).
        if self.tracker:
            self.tracker.reset()
            self.tracker.set_roi(self.use_roi, self.current_roi)

        self.inference_thread.set_tracking(
            self.use_tracking,
            self.tracker if self.use_tracking else None,
        )
        # The ROI may have been drawn/enabled before this thread existed;
        # toggle_roi_usage/on_roi_selected only reach a running thread.
        self.inference_thread.set_roi(self.use_roi, self.current_roi)

        self.inference_thread.update_frame_signal.connect(self.update_display)
        self.inference_thread.update_stats_signal.connect(self.update_stats)
        self.inference_thread.process_finished_signal.connect(self.process_finished)
        self.inference_thread.error_occurred_signal.connect(self.show_error)
        self.inference_thread.frame_position_updated.connect(self.update_frame_position)

        self.inference_thread.start()
        self.start_btn.setEnabled(False)
        self.pause_btn.setEnabled(not self.is_image)
        self.stop_btn.setEnabled(True)

        mode = "跟踪模式" if self.use_tracking else "检测模式"
        self.status_label.setText(f"开始处理 - {mode}")

    def update_display(self, frame):
        """Show a processed BGR frame emitted by the worker thread."""
        self._show_bgr_frame(frame)

    def update_stats(self, person_count, vehicle_count):
        """Show the per-frame person/vehicle counts."""
        self.status_label.setText(f"行人: {person_count}, 车辆: {vehicle_count}")

    def pause_processing(self):
        """Toggle pause/resume on the running thread and relabel the button."""
        if self.inference_thread and self.inference_thread.isRunning():
            self.inference_thread.pause()
            self.pause_btn.setText("恢复" if self.inference_thread.paused else "暂停")

    def stop_processing(self):
        """Stop the worker thread (if any) and reset the control buttons.

        After requesting the stop, waits up to ``_STOP_WAIT_MS`` for the
        thread to exit. Callers replace ``self.inference_thread`` right after
        this, and a QThread destroyed while still running aborts the process.
        """
        thread = self.inference_thread
        if thread and (thread.isRunning() or thread.paused):
            thread.stop()
            if thread.isRunning():
                thread.wait(self._STOP_WAIT_MS)
        self.start_btn.setEnabled(True)
        self.pause_btn.setEnabled(False)
        self.pause_btn.setText("暂停")
        self.stop_btn.setEnabled(False)

    def process_finished(self):
        """Reset the controls when the worker reports it has finished."""
        self.start_btn.setEnabled(True)
        self.pause_btn.setEnabled(False)
        self.pause_btn.setText("暂停")
        self.stop_btn.setEnabled(False)
        self.status_label.setText("处理完成")

    def show_error(self, message):
        """Report a worker error in the status line and a dialog."""
        self.status_label.setText(f"错误: {message}")
        QMessageBox.warning(self, "处理错误", message)

    def update_frame_position(self, pos):
        """Remember the latest frame position reported by the worker."""
        self.current_frame_pos = pos

    def resizeEvent(self, event):
        """Rescale the last shown image/frame to the new label size."""
        base = self.display_label.base_pixmap
        if base is not None and not base.isNull():
            self.display_label.setPixmap(base.scaled(
                self.display_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
        super().resizeEvent(event)

    def closeEvent(self, event):
        """Stop the worker thread before the window closes."""
        self.stop_processing()
        super().closeEvent(event)





if __name__ == "__main__":
    # Qt requires this attribute to be set before the QApplication is created.
    QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    # Run the event loop and propagate its exit status to the shell.
    exit_code = app.exec_()
    sys.exit(exit_code)