# 2d62036f 创建于 2025年8月30日（历史提交）
import cv2

import numpy as np

from PyQt5.QtCore import Qt, QThread, pyqtSignal, QMutex, QWaitCondition

from ultralytics import YOLO









class Detector:
    """YOLO-based detector for persons and vehicles.

    Wraps model loading, class filtering, confidence/IoU thresholds and an
    optional region of interest (ROI). It can draw either raw detections or
    tracker output. ``detect_raw`` returns plain NumPy arrays so the results
    can be fed to a multi-object tracker.
    """

    def __init__(self, model_path):
        """Load the YOLO model at *model_path* and set default parameters.

        Raises:
            ValueError: if the model cannot be loaded.
        """
        self.model = self._load_model(model_path)
        self.conf = 0.5  # confidence threshold passed to the model
        self.iou = 0.5  # IoU threshold passed to the model

        # Class ids used for filtering and counting.
        self.person_class = 0
        # NOTE(review): for COCO-trained weights ids 2/3/5/7 are car/motorcycle/
        # bus/truck, but 18 is "sheep" - confirm it is intended for this model.
        self.vehicle_classes = {2, 3, 5, 7, 18}

        # ROI settings.
        self.use_roi = False
        self.roi_rect = None  # (x1, y1, x2, y2) in full-frame pixel coordinates

    def _load_model(self, model_path):
        """Load and return a YOLO model, converting any failure to ValueError."""
        try:
            return YOLO(model_path)
        except Exception as e:
            # Chain the original exception so the real cause stays in the traceback.
            raise ValueError(f"模型加载失败: {str(e)}") from e

    def set_parameters(self, conf, iou):
        """Update the confidence and IoU thresholds used by later detections."""
        self.conf = conf
        self.iou = iou

    def set_roi(self, use_roi, roi_rect=None):
        """Enable/disable the ROI and set its rectangle (x1, y1, x2, y2)."""
        self.use_roi = use_roi
        self.roi_rect = roi_rect

    @staticmethod
    def _empty_detections():
        """Return an empty (detections, class_ids) pair with the documented shapes."""
        return np.empty((0, 5), dtype=float), np.empty((0,), dtype=int)

    def _clamped_roi(self, width, height):
        """Return the ROI clipped to a width x height frame as ints, or None.

        Returns None when the ROI is disabled or unset. The returned rectangle
        may be empty (x1 == x2 or y1 == y2) when the ROI lies outside the frame.
        """
        if not (self.use_roi and self.roi_rect):
            return None
        # int() so that float coordinates can be used for slicing.
        x1, y1, x2, y2 = (int(v) for v in self.roi_rect)
        x1 = max(0, min(x1, width))
        y1 = max(0, min(y1, height))
        x2 = max(x1, min(x2, width))
        y2 = max(y1, min(y2, height))
        return x1, y1, x2, y2

    def detect_raw(self, frame):
        """Run detection and return unprocessed results (e.g. for a tracker).

        Args:
            frame: image array indexed as ``frame[row, col]`` (height first).

        Returns:
            ``(detections, class_ids)``. ``detections`` is a float array of shape
            (N, 5) with rows ``(x1, y1, x2, y2, score)`` in full-frame
            coordinates. ``class_ids`` is an int array of shape (N,). N may be 0.
        """
        height, width = frame.shape[:2]
        offset_x, offset_y = 0, 0

        # Optionally crop to the ROI; boxes are shifted back afterwards.
        roi = self._clamped_roi(width, height)
        if roi is not None:
            x1, y1, x2, y2 = roi
            if x2 <= x1 or y2 <= y1:
                # The ROI does not overlap the frame: nothing to detect.
                return self._empty_detections()
            frame = frame[y1:y2, x1:x2]
            offset_x, offset_y = x1, y1

        results = self.model(
            frame,
            classes=[self.person_class] + list(self.vehicle_classes),
            conf=self.conf,
            iou=self.iou,
            stream=False,
        )

        # Keep persons as well as vehicles: both draw methods count both kinds.
        wanted = {self.person_class} | set(self.vehicle_classes)
        detections = []
        class_ids = []
        for result in results:
            boxes = result.boxes.xyxy.cpu().numpy()  # rows of [x1, y1, x2, y2]
            scores = result.boxes.conf.cpu().numpy()
            classes = result.boxes.cls.cpu().numpy().astype(int)

            for (bx1, by1, bx2, by2), score, cls_id in zip(boxes, scores, classes):
                if cls_id not in wanted:
                    continue
                detections.append([bx1 + offset_x, by1 + offset_y,
                                   bx2 + offset_x, by2 + offset_y, score])
                class_ids.append(cls_id)

        if not detections:
            # np.array([]) would have shape (0,), not the documented (0, 5).
            return self._empty_detections()
        return np.array(detections, dtype=float), np.array(class_ids, dtype=int)

    def _draw_roi(self, image):
        """Outline the ROI (if enabled) on *image* in place and label it "ROI"."""
        if not (self.use_roi and self.roi_rect):
            return
        rx1, ry1, rx2, ry2 = (int(v) for v in self.roi_rect)  # cv2 needs ints
        color = (0, 255, 255)
        cv2.rectangle(image, (rx1, ry1), (rx2, ry2), color, 2)
        cv2.putText(image, "ROI", (rx1, ry1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def draw_detections(self, frame, detections, class_ids):
        """Draw raw detections and the ROI on a copy of *frame*.

        Args:
            frame: image to draw on (not modified).
            detections: rows of ``(x1, y1, x2, y2, score)``.
            class_ids: class id per detection row.

        Returns:
            ``(image, person_count, vehicle_count)``. Detections of other
            classes are neither drawn nor counted.
        """
        canvas = frame.copy()
        person_count = 0
        vehicle_count = 0

        for det, cls_id in zip(detections, class_ids):
            x1, y1, x2, y2, score = det
            x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))

            if cls_id == self.person_class:
                person_count += 1
                color = (0, 255, 0)  # BGR green: person
                label = f"person: {score:.2f}"
            elif cls_id in self.vehicle_classes:
                vehicle_count += 1
                color = (255, 0, 0)  # BGR blue: vehicle
                label = f"vehicle: {score:.2f}"
            else:
                continue

            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
            cv2.putText(canvas, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        self._draw_roi(canvas)
        return canvas, person_count, vehicle_count

    def draw_tracked_results(self, frame, tracked_results):
        """Draw tracker output (boxes labelled with their track ID) on a copy of *frame*.

        Args:
            frame: image to draw on (not modified).
            tracked_results: iterable of
                ``(x1, y1, x2, y2, track_id, cls_id, score)`` rows.

        Returns:
            ``(image, person_count, vehicle_count)``.
        """
        canvas = frame.copy()
        person_count = 0
        vehicle_count = 0

        for track in tracked_results:
            x1, y1, x2, y2, track_id, cls_id, score = track
            x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))

            if cls_id == self.person_class:
                person_count += 1
                color = (0, 255, 0)  # BGR green: person
            elif cls_id in self.vehicle_classes:
                vehicle_count += 1
                color = (255, 0, 0)  # BGR blue: vehicle
            else:
                continue

            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
            cv2.putText(canvas, f"ID: {int(track_id)}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        self._draw_roi(canvas)
        return canvas, person_count, vehicle_count





class InferenceThread(QThread):
    """Worker thread that runs detection, optionally with multi-object tracking.

    Processes either a single image or a video file and reports results through
    Qt signals. State shared with the GUI thread (running/paused flags, tracking
    settings, playback position) is read and written under ``self.mutex``.
    """

    update_frame_signal = pyqtSignal(object)  # processed frame
    update_stats_signal = pyqtSignal(int, int)  # (person count, vehicle count)
    process_finished_signal = pyqtSignal()
    error_occurred_signal = pyqtSignal(str)
    frame_position_updated = pyqtSignal(int)  # frames read so far (video only)

    def __init__(self, path, detector, is_image=False, use_tracking=False, tracker=None):
        super().__init__()
        self.path = path
        self.detector = detector
        self.is_image = is_image

        # Tracking settings; the worker re-reads them under the mutex every frame.
        self.use_tracking = use_tracking
        self.tracker = tracker

        # Thread control. ``paused`` is public so the GUI can read it.
        self.running = False
        self.paused = False
        self.mutex = QMutex()
        self.cond = QWaitCondition()
        self.cap = None

        # Video position management.
        self.current_frame_pos = 0  # frame index the next run starts from
        self.total_frames = 0  # as reported by CAP_PROP_FRAME_COUNT (0 if unknown)

    def set_parameters(self, conf, iou):
        """Update the detector's confidence/IoU thresholds under the mutex."""
        self.mutex.lock()
        try:
            if self.detector:
                self.detector.set_parameters(conf, iou)
        finally:
            self.mutex.unlock()

    def set_roi(self, use_roi, roi_rect=None):
        """Update the ROI on the detector and on the tracker, under the mutex."""
        self.mutex.lock()
        try:
            if self.detector:
                self.detector.set_roi(use_roi, roi_rect)
            # Keep the tracker's ROI settings in step with the detector's.
            if self.tracker:
                self.tracker.use_roi = use_roi
                self.tracker.roi_rect = roi_rect
        finally:
            self.mutex.unlock()

    def set_tracking(self, use_tracking, tracker=None):
        """Enable/disable tracking; a running video picks this up on its next frame.

        Writes the same ``use_tracking`` / ``tracker`` attributes that the
        constructor sets and the processing code reads. Passing ``tracker=None``
        keeps the currently configured tracker.
        """
        self.mutex.lock()
        try:
            self.use_tracking = use_tracking
            if tracker is not None:
                self.tracker = tracker
        finally:
            self.mutex.unlock()

    def get_current_position(self):
        """Return the saved video position (frame index)."""
        self.mutex.lock()
        try:
            return self.current_frame_pos
        finally:
            self.mutex.unlock()

    def set_start_position(self, pos):
        """Set the frame index the next video run starts from."""
        self.mutex.lock()
        try:
            self.current_frame_pos = pos
        finally:
            self.mutex.unlock()

    def run(self):
        """Thread entry point: process the image or video, then clean up.

        Any exception is reported through ``error_occurred_signal``. In that case
        ``process_finished_signal`` is not emitted.
        """
        try:
            self.mutex.lock()
            try:
                self.running = True
                self.paused = False
            finally:
                self.mutex.unlock()

            if self.is_image:
                self._process_image()
            else:
                self._process_video()

            self.process_finished_signal.emit()
        except Exception as e:
            self.error_occurred_signal.emit(f"处理错误: {str(e)}")
        finally:
            self._cleanup()

    def _tracking_snapshot(self):
        """Return ``(use_tracking, tracker)`` read together under the mutex."""
        self.mutex.lock()
        try:
            return self.use_tracking, self.tracker
        finally:
            self.mutex.unlock()

    def _process_image(self):
        """Detect (and optionally track) objects in a single image file."""
        frame = cv2.imread(self.path)
        if frame is None:
            self.error_occurred_signal.emit("无法加载图片文件")
            return

        frame = cv2.resize(frame, (1280, 720))
        use_tracking, tracker = self._tracking_snapshot()
        detections, class_ids = self.detector.detect_raw(frame)

        # The tracker is only consulted when there is at least one detection.
        if use_tracking and tracker is not None and len(detections) > 0:
            tracked_results = tracker.update(detections, class_ids)
            processed_frame, person_count, vehicle_count = \
                self.detector.draw_tracked_results(frame, tracked_results)
        else:
            processed_frame, person_count, vehicle_count = \
                self.detector.draw_detections(frame, detections, class_ids)

        self.update_frame_signal.emit(processed_frame)
        self.update_stats_signal.emit(person_count, vehicle_count)

    def _seek_to_start(self):
        """Record the frame count and seek to the saved start position, if any."""
        total = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.mutex.lock()
        try:
            self.total_frames = max(total, 0)
            start = self.current_frame_pos
        finally:
            self.mutex.unlock()
        if 0 < self.total_frames <= start:
            start = 0  # saved position is at/after the end: start over
        if start > 0:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, start)

    def _await_next_frame(self):
        """Block while paused; return ``(use_tracking, tracker)``, or None once stopped."""
        self.mutex.lock()
        try:
            # The while loop tolerates spurious wake-ups and lets stop() end a pause.
            while self.running and self.paused:
                self.cond.wait(self.mutex)
            if not self.running:
                return None
            return self.use_tracking, self.tracker
        finally:
            self.mutex.unlock()

    def _set_position(self, pos, notify=True):
        """Store the playback position under the mutex and optionally emit it."""
        self.mutex.lock()
        try:
            self.current_frame_pos = pos
        finally:
            self.mutex.unlock()
        if notify:
            self.frame_position_updated.emit(pos)

    def _process_video(self):
        """Process a video frame by frame until it ends or stop() is called.

        Starts from ``current_frame_pos`` (see set_start_position). Pause/stop
        requests are honoured between frames. The tracking settings are re-read
        for every frame, so set_tracking() takes effect immediately.
        """
        self.cap = cv2.VideoCapture(self.path)
        if not self.cap.isOpened():
            self.error_occurred_signal.emit("无法打开视频文件")
            return

        try:
            self._seek_to_start()
            while True:
                state = self._await_next_frame()
                if state is None:
                    break  # stop requested
                use_tracking, tracker = state

                ret, frame = self.cap.read()
                if not ret:
                    # No more frames: the next run starts from the beginning.
                    self._set_position(0, notify=False)
                    break
                self._set_position(int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)))

                frame = cv2.resize(frame, (1280, 720))
                detections, class_ids = self.detector.detect_raw(frame)

                if use_tracking and tracker is not None:
                    tracked_results = tracker.update(detections, class_ids)
                    processed_frame, p_count, v_count = \
                        self.detector.draw_tracked_results(frame, tracked_results)
                else:
                    processed_frame, p_count, v_count = \
                        self.detector.draw_detections(frame, detections, class_ids)

                self.update_frame_signal.emit(processed_frame)
                self.update_stats_signal.emit(p_count, v_count)
                self.msleep(33)  # fixed ~33 ms delay between frames
        finally:
            self.cap.release()

    def _cleanup(self):
        """Release the capture if still open and reset the control flags."""
        self.mutex.lock()
        try:
            if self.cap is not None and self.cap.isOpened():
                self.cap.release()
            self.running = False
            self.paused = False
        finally:
            self.mutex.unlock()

    def pause(self):
        """Toggle pause; resuming wakes the waiting worker."""
        self.mutex.lock()
        try:
            self.paused = not self.paused
            if not self.paused:
                self.cond.wakeAll()
        finally:
            self.mutex.unlock()

    def stop(self):
        """Request the worker to stop and wait for it to finish.

        The playback position is kept up to date by the worker itself, so the
        capture is not touched from the calling thread.
        """
        self.mutex.lock()
        try:
            self.running = False
            self.paused = False
            self.cond.wakeAll()
        finally:
            self.mutex.unlock()

        # Wait for the thread to finish.
        # NOTE(review): Qt documents QThread.terminate() as unsafe (it can kill
        # the thread mid-operation); kept as a last resort after 1 s.
        if not self.wait(1000):
            self.terminate()
        self.wait()