import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import SAM
from xml.etree.ElementTree import Element, SubElement, tostring
from xml.dom.minidom import parseString
# Folder that holds the images to annotate (machine-specific absolute path).
IMAGE_FOLDER = r"C:\Users\13672\PycharmProjects\PythonProject\data"
# Checkpoint name handed to the ultralytics SAM(...) constructor below.
MODEL_NAME = "sam2.1_s.pt"
# VOC XML files go into an "Annotations" directory next to the image folder.
ANNOTATIONS_DIR = Path(IMAGE_FOLDER).parent / "Annotations"
# Created at import time; exist_ok keeps re-runs from failing.
# NOTE: parents are not created, so the image folder's parent must already exist.
ANNOTATIONS_DIR.mkdir(exist_ok=True)
def mask_to_bbox(mask):
    """Return the tight bounding box around the non-zero pixels of a 2-D mask.

    Args:
        mask: 2-D array-like; every non-zero / True entry counts as foreground.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain ints (inclusive pixel
        indices), or ``None`` when the mask has no foreground pixel.
    """
    hits = np.argwhere(mask)  # one (row, col) pair per foreground pixel
    if not hits.size:
        return None
    (top, left), (bottom, right) = hits.min(axis=0), hits.max(axis=0)
    return [int(edge) for edge in (left, top, right, bottom)]
# Class list persisted between runs, one class name per line (relative to the CWD).
CLASSES_FILE = Path("classes.txt")


def _parse_class_list(text):
    """Split comma-separated *text* into stripped, non-empty class names."""
    return [name.strip() for name in text.split(',') if name.strip()]


def _ask_classes_via_gui():
    """Ask for the class list in a Tk dialog; return the raw text, or None on cancel."""
    import tkinter as tk
    from tkinter import simpledialog

    root = tk.Tk()
    try:
        root.withdraw()
        return simpledialog.askstring(
            "Initial Classes",
            "Enter initial class names (comma-separated, e.g., cat,dog):"
        )
    finally:
        # Always release the hidden root window, even if the dialog failed.
        root.destroy()


def get_initial_class_names():
    """Return the initial list of class names, asking the user if necessary.

    Names are read from ``CLASSES_FILE`` when it exists and holds at least
    one non-blank line.  Otherwise the user is asked for a comma-separated
    list, first through a Tk dialog and, if the GUI cannot be used, on the
    console.  A non-empty answer is written back to ``CLASSES_FILE``.

    Returns:
        list[str]: the class names; empty when the user typed nothing usable.
        An empty list is never written to disk, so the next run asks again
        instead of loading a file with no classes.

    Raises:
        SystemExit: if the GUI dialog is cancelled or left empty.
    """
    if CLASSES_FILE.exists():
        with open(CLASSES_FILE, 'r', encoding='utf-8') as f:
            classes = [line.strip() for line in f if line.strip()]
        if classes:
            print(f"📚 Loaded {len(classes)} classes from {CLASSES_FILE}: {classes}")
            return classes
        print(f"⚠️ {CLASSES_FILE} contains no class names; asking for them instead.")

    # Only the GUI call is guarded: parsing or file errors must not be
    # misreported as "GUI not available".
    used_gui = True
    try:
        user_input = _ask_classes_via_gui()
    except Exception:
        used_gui = False
        print("GUI not available. Using console input.")
        user_input = input("Enter initial class names (comma-separated): ")

    if used_gui and not user_input:
        raise SystemExit("No classes provided. Exiting.")

    classes = _parse_class_list(user_input)
    if classes:
        with open(CLASSES_FILE, 'w', encoding='utf-8') as f:
            f.write('\n'.join(classes))
        print(f"✅ Saved initial classes to {CLASSES_FILE}")
    return classes
# Resolve the class list once at start-up; without classes nothing can be labelled.
CLASS_NAMES = get_initial_class_names()
if not CLASS_NAMES:
    raise SystemExit("❌ No valid classes provided.")

# Display colours as 0-255 triples, looked up by class position; the lookup
# site wraps the index with a modulo, so more than 20 classes reuse colours.
COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 0, 128),
    (255, 165, 0),
    (0, 128, 0),
    (128, 128, 0),
    (128, 0, 0),
    (0, 128, 128),
    (192, 192, 192),
    (128, 128, 128),
    (0, 0, 128),
    (128, 128, 255),
    (255, 128, 128),
    (128, 255, 128),
    (255, 128, 255),
    (128, 255, 255),
]
def save_voc_xml(image_path, objects, output_dir):
    """Write a Pascal-VOC-style XML annotation file for one image.

    The image is read once to obtain its width and height.  Every object
    that carries a bounding box becomes one ``<object>`` element; objects
    whose ``'bbox'`` is missing or ``None`` are skipped.

    Args:
        image_path: ``pathlib.Path`` of the annotated image.
        objects: iterable of dicts with a ``'class_name'`` key and an
            optional ``'bbox'`` key holding ``[xmin, ymin, xmax, ymax]``.
        output_dir: ``pathlib.Path`` of the directory that receives
            ``<image stem>.xml``.

    Returns:
        None.  If the image cannot be read, a warning is printed and no
        file is written.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"⚠️ Failed to read image: {image_path}")
        return
    h, w = img.shape[:2]

    annotation = Element('annotation')
    SubElement(annotation, 'folder').text = image_path.parent.name
    SubElement(annotation, 'filename').text = image_path.name
    SubElement(annotation, 'path').text = str(image_path)
    source = SubElement(annotation, 'source')
    SubElement(source, 'database').text = 'Custom'
    size = SubElement(annotation, 'size')
    SubElement(size, 'width').text = str(w)
    SubElement(size, 'height').text = str(h)
    # cv2.imread's default flag always yields a 3-channel image.
    SubElement(size, 'depth').text = '3'
    SubElement(annotation, 'segmented').text = '0'

    # Only objects with a bounding box are written, and only those are counted.
    boxed_objects = [obj for obj in objects if obj.get('bbox') is not None]
    for obj in boxed_objects:
        xmin, ymin, xmax, ymax = obj['bbox']
        obj_elem = SubElement(annotation, 'object')
        SubElement(obj_elem, 'name').text = obj['class_name']
        SubElement(obj_elem, 'pose').text = 'Unspecified'
        SubElement(obj_elem, 'truncated').text = '0'
        SubElement(obj_elem, 'difficult').text = '0'
        bndbox = SubElement(obj_elem, 'bndbox')
        SubElement(bndbox, 'xmin').text = str(xmin)
        SubElement(bndbox, 'ymin').text = str(ymin)
        SubElement(bndbox, 'xmax').text = str(xmax)
        SubElement(bndbox, 'ymax').text = str(ymax)

    # Pretty-print the root element only.  This omits the "<?xml ...?>"
    # declaration without slicing a hard-coded number of characters.
    document = parseString(tostring(annotation, 'utf-8'))
    xml_str = document.documentElement.toprettyxml(indent=" ")

    xml_path = output_dir / (image_path.stem + '.xml')
    with open(xml_path, 'w', encoding='utf-8') as f:
        f.write(xml_str)
    print(f"📁 Saved {len(boxed_objects)} objects to {xml_path}")
# Lower-case file suffixes treated as images ('.tif' is the common short
# form of '.tiff').
SUPPORTED_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

# Sorted so the navigation order is stable between runs.  Directories whose
# name happens to end in an image suffix are skipped.
image_paths = sorted(
    p for p in Path(IMAGE_FOLDER).iterdir()
    if p.is_file() and p.suffix.lower() in SUPPORTED_EXTS
)
if not image_paths:
    raise ValueError(f"No images found in {IMAGE_FOLDER}")
print(f"Found {len(image_paths)} images.")
print(f"Initial classes: {CLASS_NAMES}")

model = SAM(MODEL_NAME)

# --- Mutable annotation state shared by the event callbacks ---------------
current_index = 0        # index into image_paths of the image being shown
current_points = []      # pending click prompts as [x, y] pairs
current_labels = []      # 1 = foreground, 0 = background; parallel to current_points
current_class_idx = 0    # index into CLASS_NAMES of the suggested class
confirmed_objects = []   # dicts with 'mask', 'class_name', 'color', 'bbox'
image_rgb = None         # current image as an RGB array, set by load_image()

fig, ax = plt.subplots(figsize=(14, 9))
plt.subplots_adjust(left=0.05, right=0.95, top=0.92, bottom=0.05)
def load_image(index):
    """Read ``image_paths[index]`` from disk and store it in ``image_rgb``.

    OpenCV decodes to BGR, so the channels are converted to RGB before the
    module-level ``image_rgb`` is replaced.

    Raises:
        ValueError: if OpenCV cannot decode the file; ``image_rgb`` is left
            untouched in that case.
    """
    global image_rgb
    decoded_bgr = cv2.imread(str(image_paths[index]))
    if decoded_bgr is not None:
        image_rgb = cv2.cvtColor(decoded_bgr, cv2.COLOR_BGR2RGB)
        return
    raise ValueError(f"Failed to load image: {image_paths[index]}")
def redraw():
    """Repaint the figure from the current module-level annotation state.

    Draws, in order: the current image, a semi-transparent colour layer for
    every confirmed mask, each confirmed object's bounding box and class
    label, the pending click prompts (green = foreground, red = background)
    and a title containing the image counter, the suggested class and the
    key bindings.
    """
    mask_alpha = 115  # 0-255 opacity of the mask layer (roughly 45 %)

    ax.clear()
    ax.imshow(image_rgb)

    if confirmed_objects:
        # RGBA layer that is fully transparent except where a mask is set.
        # Assumes each mask has the same height/width as image_rgb.
        height, width = image_rgb.shape[:2]
        overlay = np.zeros((height, width, 4), dtype=np.uint8)
        for obj in confirmed_objects:
            region = obj['mask'] == 1
            overlay[region, :3] = obj['color']
            overlay[region, 3] = mask_alpha
        ax.imshow(overlay)

    for obj in confirmed_objects:
        bbox = obj.get('bbox')
        if bbox is None:
            continue
        x_min, y_min, x_max, y_max = bbox
        rgb_unit = tuple(c / 255 for c in obj['color'])  # matplotlib wants 0-1 floats
        rect = plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor=rgb_unit,
            facecolor='none'
        )
        ax.add_patch(rect)
        ax.text(x_min, y_min - 5, obj['class_name'],
                color='white', fontsize=8,
                bbox=dict(facecolor=rgb_unit, alpha=0.7))

    for (x, y), label in zip(current_points, current_labels):
        color = 'green' if label == 1 else 'red'
        circle = plt.Circle((x, y), radius=8, color=color, fill=False, linewidth=2)
        ax.add_patch(circle)

    title = f"[{current_index + 1}/{len(image_paths)}] {image_paths[current_index].name}"
    if CLASS_NAMES:
        current_cls = CLASS_NAMES[current_class_idx] if 0 <= current_class_idx < len(CLASS_NAMES) else "N/A"
        title += f"\nClass: {current_cls} | Press 1-{min(9, len(CLASS_NAMES))} to switch"
    title += "\nL-click: FG | R-click: BG | Enter: Confirm Obj | S: Save | U: Undo | R: Reset Points | ←→: Nav | Q: Quit"
    ax.set_title(title, fontsize=9)
    # Address the axes explicitly rather than pyplot's implicit current axes.
    ax.axis("off")
    fig.canvas.draw()
def _segment_current_points():
    """Run SAM on the pending click prompts; return a uint8 mask, or None if SAM gave no mask."""
    results = model(
        str(image_paths[current_index]),
        points=[current_points],
        labels=[current_labels]
    )
    masks = results[0].masks if results else None
    # NOTE(review): assumes `masks` is None or has an empty `.data` when
    # nothing was segmented - confirm against the ultralytics Results API.
    if masks is None or len(masks.data) == 0:
        return None
    mask = masks.data[0].cpu().numpy()
    return (mask > 0.5).astype(np.uint8)


def _ask_object_class(default_class):
    """Ask which class the new object has (Tk dialog, console fallback); may return None or ''."""
    try:
        import tkinter as tk
        from tkinter import ttk, simpledialog

        class ClassSelectorDialog(simpledialog.Dialog):
            def __init__(self, parent, title, class_list, default=""):
                self.class_list = class_list
                self.default = default
                self.result_class = None  # stays None when the dialog is cancelled
                super().__init__(parent, title)

            def body(self, master):
                tk.Label(master, text="Select or enter object class:").grid(
                    row=0, column=0, sticky="w", padx=10, pady=5)
                self.combo = ttk.Combobox(master, values=self.class_list, width=30)
                self.combo.set(self.default)
                self.combo.grid(row=1, column=0, padx=10, pady=5)
                self.combo.focus()
                return self.combo

            def apply(self):
                self.result_class = self.combo.get().strip()

        root = tk.Tk()
        try:
            root.withdraw()
            dialog = ClassSelectorDialog(root, "Confirm Object Class", CLASS_NAMES, default_class)
            return dialog.result_class
        finally:
            # Release the hidden root even if the dialog raised.
            root.destroy()
    except Exception as e:
        print(f"\n⚠️ GUI not available ({e}). Falling back to console input.")
        print(f"Detected object. Suggested class: '{default_class}'")
        print("Available classes:", ", ".join(CLASS_NAMES))
        user_class = input("Enter class name (press Enter to accept suggestion): ").strip()
        # On the console an empty answer means "accept the suggestion".
        return user_class or default_class


def _offer_to_register_class(user_class):
    """Ask whether an unknown class should be appended to CLASS_NAMES and CLASSES_FILE."""
    try:
        import tkinter as tk
        from tkinter import messagebox

        root = tk.Tk()
        try:
            root.withdraw()
            add_new = messagebox.askyesno(
                "New Class",
                f"Class '{user_class}' is not in the list.\nAdd it permanently?")
        finally:
            root.destroy()
    except Exception:
        add_new = input(f"Class '{user_class}' is new. Add to global list? (y/n): ").strip().lower() == 'y'

    if add_new:
        CLASS_NAMES.append(user_class)
        with open(CLASSES_FILE, 'w', encoding='utf-8') as f:
            f.write('\n'.join(CLASS_NAMES))
        print(f"➕ Added new class: '{user_class}' (total: {len(CLASS_NAMES)})")
    else:
        print(f"ℹ️ Using class '{user_class}' without adding to global list.")


def confirm_current_object():
    """Turn the pending click prompts into a confirmed, classified object.

    Steps: run SAM on the pending points, derive the mask's bounding box,
    ask the user for the class (pre-filled with the currently suggested
    class), optionally register a previously unknown class, append the
    object to ``confirmed_objects``, clear the pending points and redraw.

    The object is discarded with a printed message when there are no
    pending points, when SAM yields no or an empty mask, or when no class
    name is given.  This function is called from a matplotlib key callback,
    so any unexpected exception is reported with a traceback rather than
    propagated.
    """
    global current_points, current_labels

    if not current_points:
        print("⚠️ No points to confirm.")
        return

    try:
        binary_mask = _segment_current_points()
        bbox = None if binary_mask is None else mask_to_bbox(binary_mask)
        if bbox is None:
            print("⚠️ SAM produced empty mask. Discarded.")
            return

        default_class = CLASS_NAMES[current_class_idx] if 0 <= current_class_idx < len(CLASS_NAMES) else ""
        user_class = _ask_object_class(default_class)
        if not user_class:
            print("❌ Empty class name. Discarded object.")
            return

        if user_class not in CLASS_NAMES:
            _offer_to_register_class(user_class)

        # Classes that were not registered are drawn in neutral grey.
        if user_class in CLASS_NAMES:
            color = COLORS[CLASS_NAMES.index(user_class) % len(COLORS)]
        else:
            color = (128, 128, 128)

        confirmed_objects.append({
            'mask': binary_mask,
            'class_name': user_class,
            'color': color,
            'bbox': bbox
        })
        current_points, current_labels = [], []
        redraw()
        print(f"✅ Confirmed object with class: '{user_class}' and bbox: {bbox}")
    except Exception as e:
        print(f"❌ Error during confirmation: {e}")
        import traceback
        traceback.print_exc()
def save_current_annotations():
    """Write the confirmed objects of the current image to a VOC XML file.

    Prints a warning instead of writing a file when nothing is confirmed.
    """
    if not confirmed_objects:
        print("⚠️ No confirmed objects to save.")
        return
    save_voc_xml(image_paths[current_index], confirmed_objects, ANNOTATIONS_DIR)
def reset_current_points():
    """Discard the pending click prompts and repaint; confirmed objects are kept."""
    global current_points, current_labels
    current_points = []
    current_labels = []
    redraw()
def undo_last_object():
    """Remove the most recently confirmed object and repaint, if there is one."""
    if not confirmed_objects:
        print("⚠️ No object to undo.")
        return
    confirmed_objects.pop()
    redraw()
    print("↩️ Undid last object.")
def _switch_image(step):
    """Auto-save, then move *step* images forward or backward, clamped to the list bounds."""
    global current_index, confirmed_objects, current_points, current_labels, current_class_idx

    if confirmed_objects:
        save_current_annotations()

    target = min(max(current_index + step, 0), len(image_paths) - 1)
    if target == current_index:
        # Already at the first/last image.  Keep the on-screen annotations:
        # clearing them here would let a later save overwrite the XML that
        # was just written with fewer objects.
        print("ℹ️ Already at the first/last image.")
        return

    current_index = target
    confirmed_objects, current_points, current_labels = [], [], []
    current_class_idx = 0
    load_image(current_index)
    redraw()


def next_image():
    """Save the current annotations (if any) and show the next image.

    On the last image this only saves; the displayed state is left as is.
    """
    _switch_image(+1)


def prev_image():
    """Save the current annotations (if any) and show the previous image.

    On the first image this only saves; the displayed state is left as is.
    """
    _switch_image(-1)
def on_click(event):
    """Mouse handler: record a foreground (left) or background (right) prompt.

    Clicks outside the image axes and clicks with any other mouse button
    are ignored.  The point is stored in data coordinates and the figure is
    repainted.
    """
    if event.inaxes != ax:
        return
    if event.xdata is None or event.ydata is None:
        return
    if event.button not in (1, 3):
        return
    prompt_label = 1 if event.button == 1 else 0  # 1 = foreground, 0 = background
    current_points.append([event.xdata, event.ydata])
    current_labels.append(prompt_label)
    redraw()
def on_key(event):
    """Keyboard handler that dispatches the annotator's shortcuts.

    Bindings:
        1-9           select the suggested class (only digits that map to
                      an existing class are accepted)
        enter         confirm the object described by the pending points
        s / u / r     save / undo last object / reset pending points
        right / left  next / previous image
        q             close the figure

    Letter shortcuts are case-insensitive.  Events without a key name
    (``event.key is None``, which matplotlib can deliver for keys it cannot
    name) are ignored instead of raising ``AttributeError``.
    """
    global current_class_idx

    key = event.key
    if key is None:
        return

    # Digit keys: "1" selects CLASS_NAMES[0], and so on up to "9".
    if len(key) == 1 and "1" <= key <= "9" and int(key) <= len(CLASS_NAMES):
        current_class_idx = int(key) - 1
        redraw()
        return

    exact_bindings = {
        'enter': confirm_current_object,
        'right': next_image,
        'left': prev_image,
    }
    letter_bindings = {
        's': save_current_annotations,
        'u': undo_last_object,
        'r': reset_current_points,
        'q': lambda: plt.close(fig),
    }
    action = exact_bindings.get(key) or letter_bindings.get(key.lower())
    if action is not None:
        action()
load_image(current_index)
redraw()

# Several of this tool's keys are also matplotlib default shortcuts
# ('s' opens the save-figure dialog, 'r' resets the view, left/right walk
# the view history).  Remove just those bindings so one key press triggers
# only the annotator's action; all other default shortcuts stay available.
for rc_name, taken_key in (
    ("keymap.save", "s"),
    ("keymap.home", "r"),
    ("keymap.back", "left"),
    ("keymap.forward", "right"),
):
    if rc_name in plt.rcParams:
        plt.rcParams[rc_name] = [k for k in plt.rcParams[rc_name] if k != taken_key]

fig.canvas.mpl_connect('button_press_event', on_click)
fig.canvas.mpl_connect('key_press_event', on_key)

print("\n" + "=" * 60)
print("🚀 SAM VOC Annotator (Mask + BBox + Class Confirmation) Ready!")
print(f"Images: {len(image_paths)} | Initial Classes: {len(CLASS_NAMES)}")
print("=" * 60)
print("Controls:")
print(" 1-9 : Switch current class suggestion")
print(" Left-click: Add foreground point (green)")
print(" Right-click:Add background point (red)")
print(" Enter : Run SAM → Confirm class → Add mask + bbox")
print(" S : Save all confirmed objects as VOC XML")
print(" U : Undo last confirmed object")
print(" R : Reset current points (not confirmed objects)")
print(" ← / → : Navigate images (auto-save before switch)")
print(" Q : Quit")
print("=" * 60)

# Blocks until the figure window is closed.
plt.show()