# Commit 1def4e0e "add openvino" (author: Your Name), created 2025-12-15.
import numpy as np
import cv2
from rknn.api import RKNN
import argparse
import itertools
from transformers import AutoTokenizer
import os

# ===== Configuration (kept unchanged) =====
# Minimum combined score (best class score * box confidence) a candidate
# needs to survive filtering.
OBJ_THRESH = 0.25
# Overlap ratio above which a lower-scoring box of the same class is dropped.
NMS_THRESH = 0.45

# Network input size in pixels (both entries are equal).
IMG_SIZE = [640, 640]
# Number of token ids per prompt handed to the text encoder.
SEQUENCE_LEN = 20
# Token id used to right-pad prompts shorter than SEQUENCE_LEN.
PAD_VALUE = 49407

# Default prompt vocabulary; also the label table indexed by class id.
CLASSES = (
    "person", "bicycle", "car", "motorbike", "aeroplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "sofa", "pottedplant", "bed",
    "diningtable", "toilet", "tvmonitor", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
)

# ===== Utility functions (kept unchanged) =====
def text_tokenizer(text, model_name):
    """Tokenize prompts with a pretrained tokenizer and return their ids.

    Args:
        text: Iterable of prompt groups (each an iterable of strings). The
            groups are flattened by one level before tokenization.
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        A NumPy array built from the tokenizer's ``input_ids``; the
        tokenizer is called with ``padding=True``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    prompts = [prompt for group in text for prompt in group]
    encoded = tokenizer(text=prompts, return_tensors='pt', padding=True)
    return np.array(encoded['input_ids'])

def letter_box(img, new_shape, pad_color=(0, 0, 0)):
    """Resize an image to fit ``new_shape`` and pad the remainder.

    The image is scaled by a single ratio so that it fits inside the target,
    then constant-colour borders are added on both sides of each axis.

    Args:
        img: Image array whose first two dimensions are (height, width).
        new_shape: Target size as an int (square) or a (height, width)
            sequence.
        pad_color: Border colour passed to ``cv2.copyMakeBorder``.

    Returns:
        ``(padded_image, ratio, (dw, dh))`` where ``ratio`` is the scale
        factor applied and ``dw``/``dh`` are the per-side padding amounts
        (floats, before rounding) along width and height.
    """
    src_h, src_w = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # One ratio for both axes keeps the aspect ratio intact.
    r = min(dst_h / src_h, dst_w / src_w)
    scaled_w = int(round(src_w * r))
    scaled_h = int(round(src_h * r))

    # Leftover space, split evenly between the two sides of each axis.
    dw = (dst_w - scaled_w) / 2
    dh = (dst_h - scaled_h) / 2

    if (src_w, src_h) != (scaled_w, scaled_h):
        img = cv2.resize(img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 nudges send a half-pixel remainder to the bottom/right side.
    top = int(round(dh - 0.1))
    bottom = int(round(dh + 0.1))
    left = int(round(dw - 0.1))
    right = int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_color)
    return img, r, (dw, dh)

def filter_boxes(boxes, box_confidences, box_class_probs):
    """Keep candidates whose combined score reaches ``OBJ_THRESH``.

    Args:
        boxes: Array of candidate boxes, one row per candidate.
        box_confidences: Per-candidate confidence; flattened to 1-D here.
        box_class_probs: Per-candidate class scores along the last axis.

    Returns:
        ``(boxes, classes, scores)`` restricted to the surviving candidates,
        where ``classes`` is the arg-max class index and ``scores`` is the
        best class score multiplied by the box confidence.
    """
    confidence = box_confidences.reshape(-1)
    best_prob = np.max(box_class_probs, axis=-1)
    best_class = np.argmax(box_class_probs, axis=-1)
    combined = best_prob * confidence
    selected = np.where(combined >= OBJ_THRESH)
    return boxes[selected], best_class[selected], combined[selected]

def nms_boxes(boxes, scores):
    """Greedy non-maximum suppression over ``(x1, y1, x2, y2)`` boxes.

    Boxes are visited in descending score order; each kept box removes every
    remaining box whose overlap ratio with it exceeds ``NMS_THRESH``.

    Args:
        boxes: Array of shape ``(N, 4)`` holding corner coordinates.
        scores: Array of shape ``(N,)`` with one score per box.

    Returns:
        Integer index array (into ``boxes``) of the kept boxes, ordered by
        descending score. For empty input an empty *integer* array is
        returned so the result is always usable as an index.
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    areas = widths * heights

    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)

        # Intersection rectangle between the best box and every other one.
        ix1 = np.maximum(x1[best], x1[rest])
        iy1 = np.maximum(y1[best], y1[rest])
        ix2 = np.minimum(x1[best] + widths[best], x1[rest] + widths[rest])
        iy2 = np.minimum(y1[best] + heights[best], y1[rest] + heights[rest])

        inter_w = np.maximum(0.0, ix2 - ix1 + 1e-5)
        inter_h = np.maximum(0.0, iy2 - iy1 + 1e-5)
        inter = inter_w * inter_h

        overlap = inter / (areas[best] + areas[rest] - inter)
        order = rest[np.where(overlap <= NMS_THRESH)[0]]

    # An explicit integer dtype matters for the empty case: np.array([])
    # defaults to float64, which NumPy rejects as an index array.
    return np.array(keep, dtype=np.intp)

def box_process(position):
    """Decode per-cell edge distances into corner boxes in input pixels.

    Args:
        position: 4-D array whose axis 1 holds four values per grid cell and
            whose axes 2 and 3 are the grid height and width. Channels 0:2
            are subtracted from the cell centre and channels 2:4 are added
            to it.

    Returns:
        Array with the same batch and grid axes and four channels
        ``(x1, y1, x2, y2)``, scaled from grid units by the stride derived
        from ``IMG_SIZE``.
    """
    grid_h, grid_w = position.shape[2:4]

    # Column / row index of every cell, stacked as channels (x, y).
    xs, ys = np.meshgrid(np.arange(0, grid_w), np.arange(0, grid_h))
    grid = np.concatenate(
        (xs.reshape(1, 1, grid_h, grid_w), ys.reshape(1, 1, grid_h, grid_w)),
        axis=1,
    )

    # NOTE(review): the two stride entries are only interchangeable because
    # the input and the grids are square - confirm before using other sizes.
    stride = np.array([IMG_SIZE[1] // grid_h, IMG_SIZE[0] // grid_w]).reshape(1, 2, 1, 1)

    centres = grid + 0.5
    near_corner = (centres - position[:, 0:2, :, :]) * stride
    far_corner = (centres + position[:, 2:4, :, :]) * stride
    return np.concatenate((near_corner, far_corner), axis=1)

def postprocess(input_data):
    """Turn raw detector outputs into final boxes, class ids and scores.

    The outputs are read as three branches of equal length. Within each
    branch the first tensor is used as the class-score map and the second as
    the box map. Box confidence is fixed to 1, so the class score alone
    decides filtering. NMS is then applied separately for every class.

    Args:
        input_data: Sequence of 4-D arrays returned by the detector.

    Returns:
        ``(boxes, classes, scores)`` as concatenated arrays, or
        ``(None, None, None)`` when nothing survives.
    """
    branch_count = 3
    tensors_per_branch = len(input_data) // branch_count

    def _to_rows(feature_map):
        # Move channels last, then collapse batch and grid into rows.
        channels = feature_map.shape[1]
        return feature_map.transpose(0, 2, 3, 1).reshape(-1, channels)

    box_rows, conf_rows, score_rows = [], [], []
    for branch in range(branch_count):
        base = tensors_per_branch * branch
        class_map = input_data[base]
        box_rows.append(_to_rows(box_process(input_data[base + 1])))
        conf_rows.append(_to_rows(class_map))
        score_rows.append(_to_rows(np.ones_like(class_map[:, :1, :, :], dtype=np.float32)))

    boxes, classes, scores = filter_boxes(
        np.concatenate(box_rows),
        np.concatenate(score_rows),
        np.concatenate(conf_rows),
    )

    kept_boxes, kept_classes, kept_scores = [], [], []
    for class_id in set(classes):
        members = np.where(classes == class_id)
        cls_boxes = boxes[members]
        cls_ids = classes[members]
        cls_scores = scores[members]
        keep = nms_boxes(cls_boxes, cls_scores)
        if len(keep) > 0:
            kept_boxes.append(cls_boxes[keep])
            kept_classes.append(cls_ids[keep])
            kept_scores.append(cls_scores[keep])

    if not kept_boxes:
        return None, None, None

    return (
        np.concatenate(kept_boxes),
        np.concatenate(kept_classes),
        np.concatenate(kept_scores),
    )

def draw(image, boxes, scores, classes):
    """Draw each detection's rectangle and "<label> <score>" text in place.

    Args:
        image: Image array modified in place by OpenCV drawing calls.
        boxes: Iterable of ``(x0, y0, x1, y1)`` boxes; values are cast to int.
        scores: Iterable of scores, printed with two decimals.
        classes: Iterable of class ids used to index ``CLASSES``.
    """
    box_color = (255, 0, 0)
    text_color = (0, 0, 255)
    for box, score, cl in zip(boxes, scores, classes):
        left, top, right, bottom = (int(v) for v in box)
        cv2.rectangle(image, (left, top), (right, bottom), box_color, 2)
        caption = f'{CLASSES[cl]} {score:.2f}'
        # The caption sits 10 px above the box's top-left corner.
        cv2.putText(image, caption, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_color, 2)

# ===== Main video inference class =====
class YOLOWORLDVideo:
    """Run YOLO-World detection on a video file or camera with RKNN models.

    Text features are computed once at construction time with the text
    encoder model; the detector model is then kept loaded and fed one frame
    at a time by :meth:`run`.
    """

    def __init__(self, args):
        """Store the options, precompute text features, load the detector.

        Args:
            args: Namespace providing ``text_model``, ``yolo_world``,
                ``target``, ``input``, ``text``, ``save_output`` and
                ``output``.

        Raises:
            RuntimeError: If an RKNN model cannot be loaded or its runtime
                cannot be initialised.
        """
        self.text_model = args.text_model
        self.yolo_world = args.yolo_world
        self.target = args.target
        self.input_source = args.input  # video file path or camera ID
        self.text = args.text
        self.save_output = args.save_output
        self.output_path = args.output

        # Text features are computed only once.
        self.text_features = self._precompute_text_features()

        # The YOLO-World RKNN runtime is initialised only once.
        self.yolo_rknn = self._open_rknn(self.yolo_world, self.target)

    @staticmethod
    def _open_rknn(model_path, target):
        """Load an RKNN model and initialise its runtime, raising on failure."""
        rknn = RKNN(verbose=False)
        if rknn.load_rknn(model_path) != 0:
            raise RuntimeError(f"Failed to load RKNN model: {model_path}")
        if rknn.init_runtime(target=target) != 0:
            raise RuntimeError(f"Failed to init RKNN runtime for target: {target}")
        return rknn

    def _precompute_text_features(self):
        """Encode every prompt with the text model.

        Returns:
            The per-prompt outputs concatenated along axis 0 with a leading
            batch axis added (the original note gives the layout as
            ``[1, num_classes, 512]``).
        """
        input_ids = text_tokenizer(self.text, "openai/clip-vit-base-patch32")
        text_num, seq_len = input_ids.shape
        if seq_len >= SEQUENCE_LEN:
            input_data = input_ids[:, :SEQUENCE_LEN]
        else:
            # NOTE(review): this branch produces float32 while the truncation
            # branch keeps the tokenizer's dtype - confirm what the text
            # model expects before unifying them.
            input_data = np.full((text_num, SEQUENCE_LEN), PAD_VALUE, dtype=np.float32)
            input_data[:, :seq_len] = input_ids

        rknn = self._open_rknn(self.text_model, self.target)
        try:
            # One prompt per inference call.
            outputs = [
                rknn.inference(inputs=[input_data[i:i + 1, :]])[0]
                for i in range(text_num)
            ]
        finally:
            # Release the text encoder even if inference fails.
            rknn.release()

        text_feat = np.concatenate(outputs, axis=0)
        return np.expand_dims(text_feat, axis=0)

    def preprocess_frame(self, frame):
        """Letter-box a BGR frame to the network size and batch it.

        Returns:
            A float32 RGB array with a leading batch axis of size 1.
        """
        img, _, _ = letter_box(frame, new_shape=[IMG_SIZE[1], IMG_SIZE[0]], pad_color=(0, 0, 0))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.expand_dims(img.astype(np.float32), axis=0)
        return img

    def run(self):
        """Read frames, detect, draw, display and optionally save them.

        The loop ends when the source is exhausted or 'q' is pressed. The
        capture, the writer, the window and the detector runtime are
        released on exit, including when an exception interrupts the loop.
        """
        import time  # imported once here instead of on every frame

        # Open the video stream: try a camera ID first, else a file path.
        try:
            source = int(self.input_source)
        except ValueError:
            source = self.input_source

        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            print(f"Error: Cannot open video source: {source}")
            return

        # Frames are resized to the network input size before drawing, so
        # that is also the size handed to the writer.
        frame_size = (IMG_SIZE[0], IMG_SIZE[1])

        out = None
        frame_count = 0
        try:
            if self.save_output:
                video_fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                # The writer size must equal the size of the frames written;
                # using the source resolution here made OpenCV drop them.
                out = cv2.VideoWriter(self.output_path, fourcc, video_fps, frame_size)
                if not out.isOpened():
                    print(f"Warning: Cannot open output video: {self.output_path}")
                    out.release()
                    out = None

            print("Press 'q' to quit.")
            while True:
                ret, frame = cap.read()
                if not ret:
                    print("End of video or failed to read frame.")
                    break

                started = time.perf_counter()

                frame = cv2.resize(frame, frame_size)

                # Preprocess
                input_img = self.preprocess_frame(frame)

                # Inference
                outputs = self.yolo_rknn.inference(inputs=[input_img, self.text_features])

                # Postprocess
                boxes, classes, scores = postprocess(outputs)

                # Draw
                if boxes is not None:
                    draw(frame, boxes, scores, classes)

                elapsed = time.perf_counter() - started
                live_fps = 1 / elapsed if elapsed > 0 else 0.0
                text = "{:.2f}".format(live_fps) + ' fps'

                cv2.putText(frame, text, (20, 50), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 255, 0))
                resized = cv2.resize(frame, (640, 480))

                # Display
                cv2.imshow('YOLO-World Video', resized)
                if out is not None:
                    out.write(frame)

                frame_count += 1
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release resources
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()
            self.yolo_rknn.release()
        print(f"Processed {frame_count} frames.")

# ===== Entry point =====
def parse_args(argv=None):
    """Parse the command-line options for the video inference script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``text`` is always a list holding one sequence
        of prompt strings: the prompts given on the command line, or
        ``CLASSES`` when ``--text`` is omitted.

    Note:
        ``draw`` labels detections by indexing ``CLASSES``, so custom
        prompts change what is detected but not the label table.
    """
    parser = argparse.ArgumentParser(description='YOLO-World Video Inference with RKNN')
    parser.add_argument('--text_model', type=str, default='../model/clip_text.rknn')
    parser.add_argument('--yolo_world', type=str, default='../model/yolo_world_v2s.rknn')
    parser.add_argument('--target', type=str, required=True, help="e.g., 'rk3588', 'rk3566'")
    parser.add_argument('--input', type=str, default='/home/orangepi/Videos/166959951-1-208.mp4',
                        help="Video file path or camera ID")
    # nargs='+' keeps each prompt intact; type=list split it into characters.
    parser.add_argument('--text', type=str, nargs='+', default=None,
                        help="Prompt strings (default: the built-in class list)")
    parser.add_argument('--save_output', action='store_true', help="Save output video")
    parser.add_argument('--output', type=str, default='', help="Output video path")

    args = parser.parse_args(argv)

    if args.save_output and not args.output:
        parser.error("--output is required when --save_output is set")

    # The tokenizer helper flattens one level, so wrap the prompts once.
    args.text = [args.text] if args.text else [CLASSES]
    return args


if __name__ == '__main__':
    args = parse_args()

    # Create the output directory automatically when saving.
    if args.save_output:
        os.makedirs(os.path.dirname(args.output) or '.', exist_ok=True)

    detector = YOLOWORLDVideo(args)
    detector.run()