import os



# Allow more than one copy of the Intel OpenMP runtime to be loaded in this
# process. Presumably this avoids the "OMP: Error #15 ... already initialized"
# abort when several of the libraries imported below ship their own OpenMP
# runtime (TODO confirm which ones). It is set before those imports on purpose.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"



import cv2

import numpy as np

import matplotlib.pyplot as plt

from pathlib import Path

from ultralytics import SAM

from xml.etree.ElementTree import Element, SubElement, tostring

from xml.dom.minidom import parseString



# ==============================
# Configuration (normally only IMAGE_FOLDER needs to be changed)
# ==============================
IMAGE_FOLDER = r"C:\Users\13672\PycharmProjects\PythonProject\data"  # <- change to your image folder
MODEL_NAME = "sam2.1_s.pt"  # "sam_b.pt", "sam_l.pt" are also supported; passed to SAM(...)
# ==============================

# Create the Annotations directory automatically, as a sibling of IMAGE_FOLDER
# (i.e. inside IMAGE_FOLDER's parent directory). One <image stem>.xml is written
# there per annotated image.
# NOTE(review): mkdir() is called without parents=True, so this raises if
# IMAGE_FOLDER's parent directory does not exist.
ANNOTATIONS_DIR = Path(IMAGE_FOLDER).parent / "Annotations"
ANNOTATIONS_DIR.mkdir(exist_ok=True)





# ==============================
# Helper: convert a binary mask to a bounding box
# ==============================

def mask_to_bbox(mask):
    """Return the tight bounding box around the non-zero pixels of *mask*.

    Args:
        mask: 2-D array indexed as ``mask[row, col]``; any non-zero element
            counts as foreground.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints, with the max
        coordinates inclusive, or ``None`` when the mask has no foreground.
    """
    foreground = np.argwhere(mask)  # one (row, col) pair per foreground pixel
    if foreground.size == 0:
        return None
    (top, left), (bottom, right) = foreground.min(axis=0), foreground.max(axis=0)
    return [int(left), int(top), int(right), int(bottom)]





# ==============================
# Load the initial classes: prefer classes.txt, otherwise prompt the user
# ==============================

# Class list file, resolved relative to the current working directory
# (not relative to this script or to IMAGE_FOLDER).
CLASSES_FILE = Path("classes.txt")


def _clean_class_names(names):
    """Return *names* stripped, without empty entries or duplicates, in first-seen order."""
    cleaned = []
    for name in names:
        name = name.strip()
        if name and name not in cleaned:
            cleaned.append(name)
    return cleaned


def _ask_class_names_gui():
    """Prompt for comma-separated class names in a Tk dialog.

    Returns the raw text, or None if the dialog was cancelled. Raises if Tk is
    unavailable (tkinter not installed, no display, Tcl errors, ...).
    """
    import tkinter as tk
    from tkinter import simpledialog

    root = tk.Tk()
    try:
        root.withdraw()
        return simpledialog.askstring(
            "Initial Classes",
            "Enter initial class names (comma-separated, e.g., cat,dog):"
        )
    finally:
        # Always release the hidden Tk root, even if the dialog fails.
        root.destroy()


def get_initial_class_names():
    """Return the initial list of class names.

    Reads CLASSES_FILE (one class per line) when it exists and lists at least
    one class. Otherwise asks for comma-separated names, through a Tk dialog
    when available or the console as a fallback, and saves a non-empty result
    to CLASSES_FILE.

    Returns:
        list[str]: unique, non-empty, stripped class names. May be empty if
        the console prompt yields nothing.

    Raises:
        SystemExit: if the Tk dialog is cancelled or submitted empty.
    """
    if CLASSES_FILE.exists():
        # utf-8-sig transparently drops a BOM (e.g. from Windows Notepad),
        # which would otherwise become part of the first class name.
        with open(CLASSES_FILE, 'r', encoding='utf-8-sig') as f:
            classes = _clean_class_names(f)
        if classes:
            print(f"📚 Loaded {len(classes)} classes from {CLASSES_FILE}: {classes}")
            return classes
        # An empty file must not prevent prompting for classes again.
        print(f"⚠️ {CLASSES_FILE} lists no classes; asking for new ones.")

    try:
        user_input = _ask_class_names_gui()
    except Exception:  # GUI unavailable: fall back to the console
        print("GUI not available. Using console input.")
        user_input = input("Enter initial class names (comma-separated): ")
    else:
        if not user_input:
            raise SystemExit("No classes provided. Exiting.")

    classes = _clean_class_names(user_input.split(','))
    if classes:
        # Only persist a non-empty list: an empty classes.txt would make every
        # later run load zero classes and exit without prompting.
        with open(CLASSES_FILE, 'w', encoding='utf-8') as f:
            f.write('\n'.join(classes))
        print(f"✅ Saved initial classes to {CLASSES_FILE}")
    return classes





# Global, mutable class list, filled once at import time. It is not rebound
# afterwards, only mutated: confirm_current_object() may append classes the
# user adds at runtime.
CLASS_NAMES = get_initial_class_names()

if not CLASS_NAMES:
    raise SystemExit("❌ No valid classes provided.")

# Predefined RGB colors (0-255 per channel), picked by a class's index in
# CLASS_NAMES. There are 20 entries; with more classes the colors repeat
# because the lookup uses `index % len(COLORS)`.
# NOTE: (128, 128, 128) is also the gray used for classes not in CLASS_NAMES.
COLORS = [

    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),

    (0, 255, 255), (128, 0, 128), (255, 165, 0), (0, 128, 0), (128, 128, 0),

    (128, 0, 0), (0, 128, 128), (192, 192, 192), (128, 128, 128), (0, 0, 128),

    (128, 128, 255), (255, 128, 128), (128, 255, 128), (255, 128, 255), (128, 255, 255)

]





# ==============================
# Pascal VOC XML writer
# ==============================

def _build_voc_annotation(image_path, width, height, depth, objects):
    """Build the VOC ``<annotation>`` element for one image.

    Objects whose ``'bbox'`` is None are skipped.

    Returns:
        tuple: (annotation Element, number of ``<object>`` entries written).
    """
    annotation = Element('annotation')
    SubElement(annotation, 'folder').text = image_path.parent.name
    SubElement(annotation, 'filename').text = image_path.name
    SubElement(annotation, 'path').text = str(image_path)

    source = SubElement(annotation, 'source')
    SubElement(source, 'database').text = 'Custom'

    size = SubElement(annotation, 'size')
    SubElement(size, 'width').text = str(width)
    SubElement(size, 'height').text = str(height)
    SubElement(size, 'depth').text = str(depth)

    SubElement(annotation, 'segmented').text = '0'

    written = 0
    for obj in objects:
        bbox = obj.get('bbox')
        if bbox is None:
            continue
        xmin, ymin, xmax, ymax = bbox

        obj_elem = SubElement(annotation, 'object')
        SubElement(obj_elem, 'name').text = obj['class_name']
        SubElement(obj_elem, 'pose').text = 'Unspecified'
        SubElement(obj_elem, 'truncated').text = '0'
        SubElement(obj_elem, 'difficult').text = '0'

        bndbox = SubElement(obj_elem, 'bndbox')
        for tag, value in (('xmin', xmin), ('ymin', ymin), ('xmax', xmax), ('ymax', ymax)):
            SubElement(bndbox, tag).text = str(value)
        written += 1
    return annotation, written


def save_voc_xml(image_path, objects, output_dir):
    """Write *objects* for *image_path* as a Pascal VOC XML file.

    The file is ``<output_dir>/<image stem>.xml`` and is overwritten if it
    exists. The image is read only to obtain its width, height and depth.

    Args:
        image_path: path (str or Path) of the annotated image.
        objects: dicts with at least 'class_name' and 'bbox'
            ([xmin, ymin, xmax, ymax]); entries with bbox None are skipped.
        output_dir: existing directory (str or Path) for the XML file.

    If the image cannot be read, a warning is printed and nothing is written.
    """
    image_path = Path(image_path)
    output_dir = Path(output_dir)

    img = cv2.imread(str(image_path))
    if img is None:
        print(f"⚠️ Failed to read image: {image_path}")
        return
    h, w = img.shape[:2]
    depth = img.shape[2] if img.ndim == 3 else 1

    annotation, written = _build_voc_annotation(image_path, w, h, depth, objects)

    # Pretty-print the root element itself, so no XML declaration is emitted
    # (no fixed-offset slicing of the document string needed).
    xml_str = parseString(tostring(annotation, 'utf-8')).documentElement.toprettyxml(indent="  ")

    xml_path = output_dir / (image_path.stem + '.xml')
    with open(xml_path, 'w', encoding='utf-8') as f:
        f.write(xml_str)
    print(f"📁 Saved {written} objects to {xml_path}")





# ==============================
# Global state: list of images to annotate
# ==============================

SUPPORTED_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

image_dir = Path(IMAGE_FOLDER)
if not image_dir.is_dir():
    raise FileNotFoundError(f"IMAGE_FOLDER is not an existing directory: {IMAGE_FOLDER}")

# Only regular files count: a sub-folder named e.g. "x.jpg" would otherwise be
# listed and then fail to load.
image_paths = sorted(
    p for p in image_dir.iterdir()
    if p.is_file() and p.suffix.lower() in SUPPORTED_EXTS
)
if not image_paths:
    raise ValueError(f"No images found in {IMAGE_FOLDER}")

print(f"Found {len(image_paths)} images.")
print(f"Initial classes: {CLASS_NAMES}")



# SAM model used for point-prompted segmentation in confirm_current_object().
model = SAM(MODEL_NAME)

# Per-image interaction state (module globals mutated by the handlers below).
current_index = 0  # index into image_paths of the displayed image
current_points = []  # pending prompt points [x, y] from mouse clicks (axes data coords)
current_labels = []  # one label per point: 1 = foreground (left click), 0 = background (right click)
current_class_idx = 0  # index into CLASS_NAMES of the suggested class
confirmed_objects = []  # each element: {'mask', 'class_name', 'color', 'bbox'}
image_rgb = None  # current image as an RGB array; set by load_image()

# ==============================
# Drawing and interaction
# ==============================

fig, ax = plt.subplots(figsize=(14, 9))

plt.subplots_adjust(left=0.05, right=0.95, top=0.92, bottom=0.05)





def load_image(index):
    """Read ``image_paths[index]`` into the global RGB buffer used by redraw().

    Raises:
        ValueError: if OpenCV cannot read/decode the file.
    """
    global image_rgb
    path = image_paths[index]
    bgr = cv2.imread(str(path))
    if bgr is not None:
        # OpenCV loads BGR; matplotlib expects RGB.
        image_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return
    raise ValueError(f"Failed to load image: {path}")





def redraw(mask_alpha=0.4):
    """Re-render the current image with confirmed objects and pending points.

    Each confirmed object is shown as a semi-transparent colored mask plus a
    bounding box labelled with its class name. Pending prompt points are drawn
    as circles (green = foreground, red = background). The title shows the
    image position, the suggested class and the key bindings.

    Args:
        mask_alpha: opacity (0-1) of the mask tint.
    """
    ax.clear()
    ax.imshow(image_rgb)

    h, w = image_rgb.shape[:2]
    # RGBA overlay whose alpha stays 0 outside the masks, so only masked
    # pixels get tinted when it is drawn on top of the image.
    overlay = np.zeros((h, w, 4), dtype=np.float32)
    for obj in confirmed_objects:
        rgb = tuple(c / 255 for c in obj['color'])

        mask = obj['mask']
        # Tint only when the mask matches the image size; the box is drawn anyway.
        if mask.shape == (h, w):
            overlay[mask == 1] = (*rgb, mask_alpha)

        if obj.get('bbox') is not None:
            x_min, y_min, x_max, y_max = obj['bbox']
            ax.add_patch(plt.Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=2,
                edgecolor=rgb,
                facecolor='none'
            ))
            # Class name just above the box.
            ax.text(x_min, y_min - 5, obj['class_name'],
                    color='white', fontsize=8,
                    bbox=dict(facecolor=rgb, alpha=0.7))

    if overlay[..., 3].any():
        ax.imshow(overlay)

    # Pending prompt points.
    for (x, y), label in zip(current_points, current_labels):
        color = 'green' if label == 1 else 'red'
        ax.add_patch(plt.Circle((x, y), radius=8, color=color, fill=False, linewidth=2))

    # Title: progress, suggested class, key bindings.
    title = f"[{current_index + 1}/{len(image_paths)}] {image_paths[current_index].name}"
    if CLASS_NAMES:
        current_cls = CLASS_NAMES[current_class_idx] if 0 <= current_class_idx < len(CLASS_NAMES) else "N/A"
        title += f"\nClass: {current_cls} | Press 1-{min(9, len(CLASS_NAMES))} to switch"
    title += "\nL-click: FG | R-click: BG | Enter: Confirm Obj | S: Save | U: Undo | R: Reset Points | ←→: Nav | Q: Quit"
    ax.set_title(title, fontsize=9)
    ax.axis("off")
    # draw_idle() is the recommended way to refresh from event callbacks.
    fig.canvas.draw_idle()





def _segment_current_object():
    """Run SAM on the current image with the pending prompt points.

    Returns:
        tuple: (binary uint8 mask, bbox list or None if the mask is empty).
    """
    results = model(
        str(image_paths[current_index]),
        points=[current_points],
        labels=[current_labels]
    )
    mask = results[0].masks.data[0].cpu().numpy()
    binary_mask = (mask > 0.5).astype(np.uint8)
    return binary_mask, mask_to_bbox(binary_mask)


def _ask_object_class(default_class):
    """Ask which class the new object belongs to.

    Uses a Tk dialog with an editable drop-down when possible, otherwise the
    console. Returns the entered name (None if the dialog was cancelled).
    """
    try:
        import tkinter as tk
        from tkinter import ttk, simpledialog

        # Dialog with a combobox: pick an existing class or type a new one.
        class ClassSelectorDialog(simpledialog.Dialog):
            def __init__(self, parent, title, class_list, default=""):
                self.class_list = class_list
                self.default = default
                self.result_class = None
                super().__init__(parent, title)

            def body(self, master):
                tk.Label(master, text="Select or enter object class:").grid(row=0, column=0, sticky="w", padx=10,
                                                                            pady=5)
                self.combo = ttk.Combobox(master, values=self.class_list, width=30)
                self.combo.set(self.default)
                self.combo.grid(row=1, column=0, padx=10, pady=5)
                self.combo.focus()
                return self.combo  # initial focus

            def apply(self):
                self.result_class = self.combo.get().strip()

        root = tk.Tk()
        try:
            root.withdraw()
            dialog = ClassSelectorDialog(root, "Confirm Object Class", CLASS_NAMES, default_class)
            return dialog.result_class
        finally:
            # Release the hidden root even if the dialog fails.
            root.destroy()

    except Exception as e:
        print(f"\n⚠️ GUI not available ({e}). Falling back to console input.")
        print(f"Detected object. Suggested class: '{default_class}'")
        print("Available classes:", ", ".join(CLASS_NAMES))
        user_class = input("Enter class name (press Enter to accept suggestion): ").strip()
        if user_class == "":
            user_class = default_class
        return user_class


def _register_class_if_new(user_class):
    """Offer to add an unknown class to CLASS_NAMES and rewrite CLASSES_FILE."""
    if user_class in CLASS_NAMES:
        return
    try:
        import tkinter as tk
        from tkinter import messagebox

        root = tk.Tk()
        try:
            root.withdraw()
            add_new = messagebox.askyesno("New Class",
                                          f"Class '{user_class}' is not in the list.\nAdd it permanently?")
        finally:
            root.destroy()
    except Exception:
        add_new = input(f"Class '{user_class}' is new. Add to global list? (y/n): ").strip().lower() == 'y'

    if add_new:
        CLASS_NAMES.append(user_class)
        with open(CLASSES_FILE, 'w', encoding='utf-8') as f:
            f.write('\n'.join(CLASS_NAMES))
        print(f"➕ Added new class: '{user_class}' (total: {len(CLASS_NAMES)})")
    else:
        print(f"ℹ️ Using class '{user_class}' without adding to global list.")


def _class_color(user_class):
    """Return the display color for a class; gray for classes not in CLASS_NAMES."""
    if user_class in CLASS_NAMES:
        return COLORS[CLASS_NAMES.index(user_class) % len(COLORS)]
    return (128, 128, 128)


def confirm_current_object():
    """Segment the pending points with SAM and add the result as an object.

    Steps: run SAM, ask for (and optionally register) the class, then store
    mask, class name, color and bbox in confirmed_objects, clear the pending
    points and redraw. Empty masks or empty class names are discarded. Errors
    are printed with a traceback instead of propagating into the matplotlib
    event loop.
    """
    # Rebinds these two globals; the lists used elsewhere are only mutated.
    global current_points, current_labels

    if not current_points:
        print("⚠️ No points to confirm.")
        return

    try:
        binary_mask, bbox = _segment_current_object()
        if bbox is None:
            print("⚠️ SAM produced empty mask. Discarded.")
            return

        # Suggested class: the one currently selected with keys 1-9.
        default_class = CLASS_NAMES[current_class_idx] if 0 <= current_class_idx < len(CLASS_NAMES) else ""
        user_class = _ask_object_class(default_class)
        if not user_class:
            print("❌ Empty class name. Discarded object.")
            return

        _register_class_if_new(user_class)

        confirmed_objects.append({
            'mask': binary_mask,
            'class_name': user_class,
            'color': _class_color(user_class),
            'bbox': bbox
        })

        current_points, current_labels = [], []
        redraw()
        print(f"✅ Confirmed object with class: '{user_class}' and bbox: {bbox}")

    except Exception as e:
        print(f"❌ Error during confirmation: {e}")
        import traceback
        traceback.print_exc()





def save_current_annotations():
    """Save the current image's confirmed objects as VOC XML in ANNOTATIONS_DIR.

    Prints a warning and writes nothing when no object has been confirmed.
    """
    if not confirmed_objects:
        print("⚠️ No confirmed objects to save.")
        return
    save_voc_xml(image_paths[current_index], confirmed_objects, ANNOTATIONS_DIR)





def reset_current_points():
    """Discard the pending (unconfirmed) prompt points; confirmed objects stay."""
    global current_points, current_labels
    current_labels = []
    current_points = []
    redraw()





def undo_last_object():
    """Remove the most recently confirmed object of the current image."""
    if not confirmed_objects:
        print("⚠️ No object to undo.")
        return
    confirmed_objects.pop()
    redraw()
    print("↩️ Undid last object.")





def _go_to_image(new_index):
    """Auto-save the current objects, then switch to image *new_index*.

    If *new_index* equals the current index (already at the first or last
    image), the confirmed objects and points stay on screen. A later save
    therefore still contains them, instead of overwriting the XML with only
    objects added afterwards.
    """
    global current_index, confirmed_objects, current_points, current_labels, current_class_idx
    if confirmed_objects:
        save_current_annotations()
    if new_index == current_index:
        print("ℹ️ Already at the first/last image.")
        return
    current_index = new_index
    confirmed_objects, current_points, current_labels = [], [], []
    current_class_idx = 0
    load_image(current_index)
    redraw()


def next_image():
    """Go to the next image (auto-saves; stays put at the last image)."""
    _go_to_image(min(current_index + 1, len(image_paths) - 1))


def prev_image():
    """Go to the previous image (auto-saves; stays put at the first image)."""
    _go_to_image(max(current_index - 1, 0))





def on_click(event):
    """Mouse handler: left click adds a foreground point, right click a background point.

    Clicks outside the image axes and other buttons are ignored.
    """
    button_to_label = {1: 1, 3: 0}  # 1 = left -> foreground, 3 = right -> background
    outside_image = event.inaxes != ax or event.xdata is None or event.ydata is None
    if outside_image or event.button not in button_to_label:
        return
    current_points.append([event.xdata, event.ydata])
    current_labels.append(button_to_label[event.button])
    redraw()





def on_key(event):
    """Keyboard handler for class switching and annotation commands.

    1-9: choose the suggested class (only digits for existing classes);
    Enter: confirm object; S: save; U: undo; R: reset points;
    Left/Right: navigate; Q: close the window. Letter keys are case-insensitive.
    """
    global current_class_idx

    key = event.key
    # Matplotlib reports None for keys it cannot name; ignore them instead of
    # crashing on .lower().
    if key is None:
        return

    # Switch class with 1-9, limited to the number of known classes.
    class_keys = {str(i) for i in range(1, min(10, len(CLASS_NAMES) + 1))}
    if key in class_keys:
        current_class_idx = int(key) - 1
        redraw()
        return

    actions = {
        'enter': confirm_current_object,
        's': save_current_annotations,
        'u': undo_last_object,
        'r': reset_current_points,
        'right': next_image,
        'left': prev_image,
        'q': lambda: plt.close(fig),
    }
    # Single characters match case-insensitively; named keys match exactly.
    action = actions.get(key.lower() if len(key) == 1 else key)
    if action is not None:
        action()





# ==============================
# Startup
# ==============================

load_image(current_index)
redraw()

fig.canvas.mpl_connect('button_press_event', on_click)
fig.canvas.mpl_connect('key_press_event', on_key)

# Matplotlib's default navigation shortcuts reuse some of this tool's keys.
# For example, 's' opens the "save figure" dialog, 'r' resets the view and 'q'
# closes the window, so each press would trigger two actions. Remove this
# tool's keys from the default keymaps.
_TOOL_KEYS = {'s', 'q', 'r', 'u', 'left', 'right', 'enter'}
for _keymap_name in [name for name in plt.rcParams if name.startswith('keymap.')]:
    plt.rcParams[_keymap_name] = [
        key for key in plt.rcParams[_keymap_name] if key.lower() not in _TOOL_KEYS
    ]

print("\n" + "=" * 60)
print("🚀 SAM VOC Annotator (Mask + BBox + Class Confirmation) Ready!")
print(f"Images: {len(image_paths)} | Initial Classes: {len(CLASS_NAMES)}")
print("=" * 60)
print("Controls:")
print("  1-9       : Switch current class suggestion")
print("  Left-click: Add foreground point (green)")
print("  Right-click:Add background point (red)")
print("  Enter     : Run SAM → Confirm class → Add mask + bbox")
print("  S         : Save all confirmed objects as VOC XML")
print("  U         : Undo last confirmed object")
print("  R         : Reset current points (not confirmed objects)")
print("  ← / →     : Navigate images (auto-save before switch)")
print("  Q         : Quit")
print("=" * 60)

plt.show()