"""
RT-DETR OM inference on COCO2017 images
"""
import argparse
import os
import random
from pathlib import Path
import cv2
import numpy as np
from ais_bench.infer.interface import InferSession
# The 80 COCO detection class names, indexed by the contiguous class id (0..79)
# that is used to look up model labels elsewhere in this script.
# Names follow the official COCO category names (e.g. 'cell phone', 'hair drier').
COCO_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
]
def preprocess_image(image_path, input_size=640):
    """Load an image from disk and turn it into the model's input tensor.

    The image is resized to ``input_size x input_size`` (aspect ratio is not
    preserved), converted from OpenCV's BGR order to RGB, scaled to ``[0, 1]``
    as float32 and laid out channel-first with a leading batch axis of 1.

    Args:
        image_path: Path (``str`` or ``os.PathLike``) of the image file.
        input_size: Side length of the square network input.

    Returns:
        ``(img_input, img, (orig_h, orig_w))`` where ``img_input`` has shape
        ``(1, 3, input_size, input_size)``, ``img`` is the original image as
        read by OpenCV, and ``(orig_h, orig_w)`` is its size in pixels.
        Returns ``None`` if the file cannot be read or decoded.
    """
    # str() so pathlib.Path objects work with OpenCV builds that only accept str.
    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread signals a missing/undecodable file by returning None, not by raising.
        return None
    orig_h, orig_w = img.shape[:2]

    img_resized = cv2.resize(img, (input_size, input_size))
    img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
    img_normalized = img_rgb.astype(np.float32) / 255.0
    # HWC -> CHW, then add the batch axis. transpose only returns a strided view,
    # so copy into a C-contiguous buffer before it is handed to the runtime.
    img_input = np.ascontiguousarray(np.transpose(img_normalized, (2, 0, 1))[np.newaxis])
    return img_input, img, (orig_h, orig_w)
def draw_boxes(image, labels, boxes, scores, threshold=0.5):
    """Draw green bounding boxes with "name: score" captions onto ``image``.

    Drawing happens in place on ``image`` (OpenCV drawing functions mutate
    their argument); the same array is also returned for convenience.

    Args:
        image: Image array to draw on.
        labels: Class ids, looked up in ``COCO_CLASSES``.
        boxes: Per-detection ``(x1, y1, x2, y2)`` corners in ``image`` pixels.
        scores: Per-detection confidence.
        threshold: Detections with ``score < threshold`` are skipped.

    Returns:
        ``image``, with the boxes drawn on it.
    """
    for label, box, score in zip(labels, boxes, scores):
        if score < threshold:
            continue
        x1, y1, x2, y2 = (int(coord) for coord in box)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        class_id = int(label)
        # Fall back to the raw id rather than raising IndexError for ids outside the table.
        if 0 <= class_id < len(COCO_CLASSES):
            name = COCO_CLASSES[class_id]
        else:
            name = f"class_{class_id}"
        label_text = f"{name}: {score:.2f}"

        # putText anchors at the text baseline; keep the caption on-screen when
        # the box touches the top edge of the image.
        text_y = max(y1 - 10, 15)
        cv2.putText(image, label_text, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    return image
def main():
    """Run RT-DETR OM inference on one COCO image and save a visualization.

    Command-line flags select the OM model, the image, the network input size,
    the score threshold, the NPU device and the output path. The image is given
    with ``--image``; otherwise a random ``val2017/*.jpg`` under ``--data-path``
    is used, and it is reproducible with ``--seed``. Detections with
    ``score >= --threshold`` are printed and drawn on a copy of the original
    image, which is written to ``--output``.

    Exits with status 1 in three cases: no usable image is found, preprocessing
    fails, or the result image cannot be written.
    """
    parser = argparse.ArgumentParser(description='RT-DETR OM inference on COCO2017')
    parser.add_argument('-m', '--model', required=True, help='OM model file')
    parser.add_argument('-i', '--image', default=None, help='Image path (random if not specified)')
    parser.add_argument('-d', '--data-path', default='./dataset/coco', help='COCO dataset path')
    parser.add_argument('-o', '--output', default='result.jpg', help='Output image path')
    parser.add_argument('-s', '--input-size', type=int, default=640, help='Input size')
    parser.add_argument('-t', '--threshold', type=float, default=0.5, help='Confidence threshold')
    parser.add_argument('--device', type=int, default=0, help='NPU device ID')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for picking the image when --image is not given')
    args = parser.parse_args()

    if args.image:
        image_path = args.image
    else:
        val_dir = Path(args.data_path) / 'val2017'
        # Sorted so a given --seed always picks the same file (glob order is filesystem-dependent).
        images = sorted(val_dir.glob('*.jpg'))
        if not images:
            print(f"Error: No images found in {val_dir}")
            raise SystemExit(1)
        image_path = str(random.Random(args.seed).choice(images))
    print(f"Image: {image_path}")

    # Validate and preprocess the image before the (comparatively slow) model load.
    if not Path(image_path).is_file():
        print(f"Error: Image not found: {image_path}")
        raise SystemExit(1)
    preprocess_result = preprocess_image(image_path, args.input_size)
    if preprocess_result is None:
        print("Failed to preprocess image.")
        raise SystemExit(1)
    img_input, orig_img, (orig_h, orig_w) = preprocess_result

    # Second model input: original size as [[width, height]]. Boxes are drawn
    # directly on the original image below, so the exported model presumably
    # rescales them with this — confirm against the export.
    orig_target_sizes = np.array([[orig_w, orig_h]], dtype=np.int64)
    print(f"Original size: {orig_h}x{orig_w}")
    print(f"Input shape: {img_input.shape}")

    print(f"Loading OM model: {args.model}")
    session = InferSession(device_id=args.device, model_path=args.model)
    print("Running inference...")
    outputs = session.infer([img_input, orig_target_sizes])
    # Assumed output order: labels, boxes, scores, each with a leading batch
    # axis (batch of 1 here) — confirm against the exported model.
    labels = outputs[0][0]
    boxes = outputs[1][0]
    scores = outputs[2][0]
    print(f"\nDetections: {len(labels)}")
    print(f"Labels shape: {labels.shape}")
    print(f"Boxes shape: {boxes.shape}")
    print(f"Scores shape: {scores.shape}")

    keep = scores >= args.threshold
    labels, boxes, scores = labels[keep], boxes[keep], scores[keep]
    print(f"\nDetections above threshold {args.threshold}: {len(labels)}")
    for i, (label, box, score) in enumerate(zip(labels, boxes, scores), start=1):
        class_id = int(label)
        # Avoid IndexError if the model emits ids outside the class-name table.
        if 0 <= class_id < len(COCO_CLASSES):
            name = COCO_CLASSES[class_id]
        else:
            name = f"class_{class_id}"
        print(f" {i}. {name}: {score:.3f} at {box}")

    result_img = draw_boxes(orig_img.copy(), labels, boxes, scores, args.threshold)
    # cv2.imwrite reports failure (missing directory, unsupported extension)
    # by returning False rather than raising.
    if not cv2.imwrite(args.output, result_img):
        print(f"Error: Failed to write result to: {args.output}")
        raise SystemExit(1)
    print(f"\nResult saved to: {args.output}")
if __name__ == '__main__':
    # Script entry point: run the CLI only when executed directly, not when imported.
    main()