import os
import json
import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from pycocotools import mask as mask_util
from ais_bench.infer.interface import InferSession
from transformers.models.sam3 import Sam3Processor
# Static sizes used to pad the text input (SEQ_LEN) and to reshape the model's
# mask / query outputs (MASK_SIZE, NUM_QUERIES).
SEQ_LEN = 32
MASK_SIZE = 288
NUM_QUERIES = 200

# SA-Co/Gold ground-truth annotation files, keyed by subset name. Every subset
# has three annotation files (variants a, b, c) that follow one naming scheme.
saco_gold_gts = {
    subset: [f"gold_{subset}_merged_{variant}_release_test.json" for variant in "abc"]
    for subset in (
        "metaclip",
        "sa1b",
        "crowded",
        "fg_food",
        "fg_sports_equipment",
        "attributes",
        "wiki_common",
    )
}
def load_processor(model_path):
    """Return a ``Sam3Processor`` read from the local directory *model_path*.

    ``local_files_only=True`` is passed, so only files already on disk are used.
    """
    print(f"Loading processor from: {model_path}")
    sam3_processor = Sam3Processor.from_pretrained(
        model_path,
        local_files_only=True,
    )
    print("Processor loaded successfully")
    return sam3_processor
def load_om_model(model_path, device_id):
    """Create an ais_bench ``InferSession`` for the OM file *model_path* on device *device_id*."""
    print(f"Loading OM model: {model_path}, device={device_id}")
    om_session = InferSession(device_id, model_path)
    print("OM model loaded successfully")
    return om_session
def collect_samples(gt_json_path, image_root):
    """Collect evaluation samples from a ground-truth JSON file.

    An entry of ``data["images"]`` is kept only when it is flagged
    ``is_instance_exhaustive``, has a non-empty ``text_input`` and its image file
    exists under *image_root*. Entries are returned in file order.

    Args:
        gt_json_path: Path to the ground-truth JSON (must contain an "images" list).
        image_root: Directory that each entry's ``file_name`` is relative to.

    Returns:
        list of tuples: (image_path, text_input, image_id, height, width, category_id).
        ``category_id`` is the entry's ``queried_category``, or 1 if absent.
    """
    # Decode explicitly as UTF-8: text prompts may contain non-ASCII characters
    # and the platform's default encoding is not guaranteed to be UTF-8.
    with open(gt_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    samples = []
    for img in data['images']:
        # Apply the cheap filters before reading required keys or touching the
        # filesystem, so skipped entries never raise.
        if not img.get('is_instance_exhaustive', False):
            continue
        text_input = img.get('text_input', '')
        if not text_input:
            continue
        img_path = os.path.join(image_root, img['file_name'])
        if not os.path.exists(img_path):
            continue
        samples.append((
            img_path,
            text_input,
            img['id'],
            img['height'],
            img['width'],
            img.get('queried_category', 1),
        ))
    return samples
def rle_encode(orig_mask, return_areas=False):
    """Encode a batch of binary masks in COCO RLE format.

    This emulates ``pycocotools.mask.encode``, but the run-length computation
    uses torch ops. It therefore runs on whatever device ``orig_mask`` lives on;
    only the final per-mask conversion goes through ``pycocotools``.

    Args:
        orig_mask (torch.Tensor): Masks of shape (N, H, W) with dtype=torch.bool.
        return_areas (bool): If True, also store each mask's foreground pixel
            count under the "area" key of its RLE dict. Default is False.

    Returns:
        list[dict]: One compressed RLE dict per mask. Each has "size" = [H, W]
        and "counts" decoded to a UTF-8 ``str``, plus "area" if requested. An
        empty list is returned when ``orig_mask`` has no elements.

    Raises:
        ValueError: If ``orig_mask`` is not 3-D.
        TypeError: If ``orig_mask`` is not a boolean tensor.
    """
    if orig_mask.ndim != 3:
        raise ValueError(f"Mask must be of shape (N, H, W), got {orig_mask.ndim} dimensions")
    if orig_mask.dtype != torch.bool:
        raise TypeError(f"Mask must have dtype=torch.bool, got {orig_mask.dtype}")
    if orig_mask.numel() == 0:
        return []
    # COCO RLE walks pixels in column-major order, so swap H and W before
    # flattening each mask.
    mask = orig_mask.transpose(1, 2)
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # differences[:, k] is True wherever a new run starts at flat position k.
    # Column 0 is True only if the first pixel is foreground, which yields a
    # leading background run of length 0. The extra last column (left True)
    # marks the end of the mask.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    # Flat positions of every run boundary, concatenated across masks in mask order.
    _, change_indices = torch.where(differences)
    # boundaries[i] is the end offset of mask i's slice inside change_indices.
    try:
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Fall back to computing the cumsum on CPU if it fails on the tensor's device.
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
    # Convert absolute boundary positions into run lengths within each mask's
    # slice: every element except the slice's first becomes
    # (position - previous position). The clone preserves the original
    # positions, because the subtraction below is done in place.
    change_indices_clone = change_indices.clone()
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        change_indices[beg + 1: end] -= change_indices_clone[beg: end - 1]
    change_indices = change_indices.tolist()
    batch_rles = []
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        # Convert the uncompressed counts list into pycocotools' compressed RLE.
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        # Compressed counts come back as bytes; decode so the dict is JSON-serializable.
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)
    return batch_rles
def prepare_batch_inputs(batch_samples, processor, model_batch_size):
    """Build the model input tensors for one batch of samples.

    The model is run with a fixed batch size. A short batch is therefore padded
    up to *model_batch_size* by repeating its first sample; the returned
    ``actual_batch_size`` tells the caller how many leading rows are real.

    Args:
        batch_samples: Sample tuples as produced by ``collect_samples``
            (image path at index 0, text prompt at index 1).
        processor: SAM3 processor exposing ``image_processor`` and ``tokenizer``.
        model_batch_size: Batch size the model is run with.

    Returns:
        pixel_values, input_ids, attention_mask, actual_batch_size

    Raises:
        ValueError: If *batch_samples* is empty or larger than *model_batch_size*.
    """
    actual_bs = len(batch_samples)
    if actual_bs == 0:
        raise ValueError("batch_samples must not be empty")
    if actual_bs > model_batch_size:
        raise ValueError(
            f"batch of {actual_bs} samples exceeds model batch size {model_batch_size}"
        )

    images = []
    for sample in batch_samples:
        # Context manager closes the file handle once the RGB copy exists.
        with Image.open(sample[0]) as img:
            images.append(img.convert('RGB'))
    pixel_values = processor.image_processor(images, return_tensors="pt")["pixel_values"]

    tokenizer = processor.tokenizer
    text_inputs = tokenizer(
        [s[1] for s in batch_samples],
        return_tensors="pt",
        truncation=True,
        padding='max_length',
        max_length=SEQ_LEN,
    )
    input_ids = text_inputs["input_ids"]
    attn_mask = text_inputs["attention_mask"]

    # Defensive: force the sequence length to exactly SEQ_LEN even if the
    # tokenizer did not honour max_length.
    cur_len = input_ids.shape[1]
    if cur_len > SEQ_LEN:
        input_ids = input_ids[:, :SEQ_LEN]
        attn_mask = attn_mask[:, :SEQ_LEN]
    elif cur_len < SEQ_LEN:
        pad_len = SEQ_LEN - cur_len
        # Pad ids with the tokenizer's own pad token (0 only as a fallback);
        # padded positions are masked out via attention_mask = 0.
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0
        input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=pad_id)
        attn_mask = torch.nn.functional.pad(attn_mask, (0, pad_len), value=0)

    if actual_bs < model_batch_size:
        # Fill the batch with copies of the first sample; these rows are dummies.
        pad_num = model_batch_size - actual_bs
        pixel_values = torch.cat([pixel_values, pixel_values[0:1].repeat(pad_num, 1, 1, 1)], dim=0)
        input_ids = torch.cat([input_ids, input_ids[0:1].repeat(pad_num, 1)], dim=0)
        attn_mask = torch.cat([attn_mask, attn_mask[0:1].repeat(pad_num, 1)], dim=0)
    return pixel_values, input_ids, attn_mask, actual_bs
def model_infer(session, inputs):
    """Pass *inputs* straight to ``session.infer`` and return whatever it produces."""
    run = session.infer
    results = run(inputs)
    return results
def process_batch_core(session, processor, pixel_values, input_ids, attn_mask,
                       batch_samples, actual_batch_size, model_batch_size, threshold,
                       mask_threshold=None):
    """Run the OM model on one batch and convert the results to COCO predictions.

    Args:
        session: OM inference session, invoked through ``model_infer``.
        processor: SAM3 processor whose ``post_process_instance_segmentation``
            turns the raw outputs into per-image masks, boxes and scores.
        pixel_values, input_ids, attn_mask: Model inputs (torch tensors) padded
            to *model_batch_size* rows.
        batch_samples: The real sample tuples. Index 2 is the image id; indices
            3 and 4 are the original height and width.
        actual_batch_size: Number of real samples; rows beyond it are padding
            and are dropped.
        model_batch_size: Batch size the model was run with.
        threshold: Score threshold passed to post-processing.
        mask_threshold: Mask binarization threshold passed to post-processing.
            Defaults to *threshold* (the value previously always used).

    Returns:
        List of COCO-style result dicts with ``image_id``, ``category_id``
        (always 1), ``bbox`` as ``[x, y, width, height]``, ``score`` and an
        RLE ``segmentation``.
    """
    from types import SimpleNamespace

    if mask_threshold is None:
        mask_threshold = threshold

    pixel_np = np.ascontiguousarray(pixel_values.numpy().astype(np.float32))
    input_np = np.ascontiguousarray(input_ids.numpy().astype(np.int64))
    attn_np = np.ascontiguousarray(attn_mask.numpy().astype(np.int64))
    outputs = model_infer(session, [pixel_np, input_np, attn_np])

    # Assumes the OM graph emits (masks, logits, boxes, presence, semantic_seg)
    # in this order with these shapes. Slicing to n drops the padded rows.
    n = actual_batch_size
    bs = model_batch_size
    # post_process_instance_segmentation reads the outputs as attributes, so a
    # plain namespace is enough.
    model_outputs = SimpleNamespace(
        pred_masks=torch.from_numpy(outputs[0].reshape(bs, NUM_QUERIES, MASK_SIZE, MASK_SIZE))[:n],
        pred_logits=torch.from_numpy(outputs[1].reshape(bs, NUM_QUERIES))[:n],
        pred_boxes=torch.from_numpy(outputs[2].reshape(bs, NUM_QUERIES, 4))[:n],
        presence_logits=torch.from_numpy(outputs[3].reshape(bs, 1))[:n],
        semantic_seg=torch.from_numpy(outputs[4].reshape(bs, 1, MASK_SIZE, MASK_SIZE))[:n],
    )
    # Original image sizes as (height, width) for resizing masks and boxes.
    target_sizes = [[s[3], s[4]] for s in batch_samples]
    batch_results = processor.post_process_instance_segmentation(
        model_outputs,
        threshold=threshold,
        mask_threshold=mask_threshold,
        target_sizes=target_sizes,
    )

    all_preds = []
    for sample, result in zip(batch_samples, batch_results):
        scores = result.get('scores')
        if scores is None or len(scores) == 0:
            continue
        img_id = sample[2]
        masks_rle = rle_encode(result['masks'].bool())
        for score, (x1, y1, x2, y2), rle in zip(scores.tolist(), result['boxes'].tolist(), masks_rle):
            all_preds.append({
                "image_id": img_id,
                "category_id": 1,
                # Post-processing returns corner boxes (x1, y1, x2, y2); the COCO
                # results format expects [x, y, width, height].
                "bbox": [x1, y1, x2 - x1, y2 - y1],
                "score": score,
                "segmentation": rle,
            })
    return all_preds
def worker_process(args_tuple):
    """Run inference for one SA-Co/Gold subset and write its predictions to disk.

    Intended to be executed in a separate process, one call per subset.

    Args:
        args_tuple: ``(subset_key, gt_file, img_root, device_id, model_path,
            processor_path, batch_size, threshold, output_dir)``.

    Returns:
        ``subset_key`` on success (including when the subset has no valid
        samples), or ``None`` if any exception occurred. Errors are printed with
        their traceback instead of being raised.
    """
    (subset_key, gt_file, img_root, device_id,
     model_path, processor_path, batch_size, threshold, output_dir) = args_tuple
    print(f"[Device {device_id}] Processing subset: {subset_key}")
    try:
        # Collect samples first, so an empty subset does not load the processor
        # or put the OM model on the device for nothing.
        samples = collect_samples(gt_file, img_root)
        if not samples:
            print(f"[Device {device_id}] No valid samples for {subset_key}, skipping.")
            return subset_key
        processor = load_processor(processor_path)
        session = load_om_model(model_path, device_id)
        all_preds = []
        total_samples = len(samples)
        with tqdm(total=total_samples, desc=f"[Device {device_id}] {subset_key}", unit='img') as pbar:
            for start in range(0, total_samples, batch_size):
                batch = samples[start:start + batch_size]
                pixel_values, input_ids, attn_mask, actual_bs = prepare_batch_inputs(
                    batch, processor, batch_size
                )
                all_preds.extend(process_batch_core(
                    session, processor, pixel_values, input_ids, attn_mask,
                    batch, actual_bs, batch_size, threshold
                ))
                pbar.update(len(batch))
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f'predictions_{subset_key}.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(all_preds, f)
        print(f"[Device {device_id}] Subset {subset_key} finished. Saved {len(all_preds)} predictions -> {output_file}")
        return subset_key
    except Exception as e:
        # Top-level worker boundary: report and signal failure via None so the
        # parent process keeps processing the other subsets.
        import traceback
        print(f"[Device {device_id}] Subset {subset_key} failed: {e}\n{traceback.format_exc()}")
        return None
def _build_tasks(args, devices):
    """Return one ``worker_process`` argument tuple per available subset.

    A subset is included only if its first ('a') annotation file and its image
    directory exist. Devices are assigned round-robin in subset order.
    """
    gt_dir = os.path.join(args.dataset_root, 'gt-annotations')
    tasks = []
    for subset_key, gt_file_list in saco_gold_gts.items():
        gt_file = os.path.join(gt_dir, gt_file_list[0])
        if not os.path.exists(gt_file):
            print(f"Skipping {subset_key}: GT file does not exist")
            continue
        # sa1b has its own image folder; every other subset uses metaclip-images.
        image_dir = 'sa1b-images' if subset_key == "sa1b" else 'metaclip-images'
        img_root = os.path.join(args.dataset_root, image_dir)
        if not os.path.exists(img_root):
            print(f"Skipping {subset_key}: image directory does not exist")
            continue
        device_id = devices[len(tasks) % len(devices)]
        tasks.append((
            subset_key, gt_file, img_root, device_id,
            args.model_path, args.processor_path,
            args.batch_size, args.threshold, args.output_dir,
        ))
    return tasks


def _run_tasks(tasks, devices, procs_per_card):
    """Run *tasks* with one spawn-based process pool per device.

    Returns:
        Subset keys whose worker raised or returned a falsy result.
    """
    tasks_by_device = {d: [] for d in devices}
    for task in tasks:
        tasks_by_device[task[3]].append(task)
    executors = []
    future_to_subset = {}
    try:
        for dev, dev_tasks in tasks_by_device.items():
            if not dev_tasks:
                continue
            print(f"Device {dev}: {len(dev_tasks)} task(s), max workers: {procs_per_card}")
            executor = ProcessPoolExecutor(
                max_workers=procs_per_card,
                mp_context=mp.get_context('spawn'),
            )
            executors.append(executor)
            for task in dev_tasks:
                future_to_subset[executor.submit(worker_process, task)] = task[0]
        failed = []
        for future in tqdm(as_completed(future_to_subset), total=len(future_to_subset),
                           desc="Overall progress"):
            subset_key = future_to_subset[future]
            try:
                result = future.result()
            except Exception as e:
                # e.g. a worker process that died abruptly.
                print(f"Task failed: {e}")
                failed.append(subset_key)
                continue
            if result:
                print(f"Completed: {result}")
            else:
                # worker_process prints its own error and returns None on failure.
                failed.append(subset_key)
        return failed
    finally:
        # Always release the pools, even if the result loop is interrupted.
        for executor in executors:
            executor.shutdown(wait=True)


def main():
    """Command-line entry point: evaluate the OM model on the SA-Co/Gold subsets.

    Subsets are spread round-robin over the given devices. Each device gets its
    own spawn-based pool of ``--procs_per_card`` worker processes. The final
    summary reports which subsets failed, if any.
    """
    parser = argparse.ArgumentParser(description='SAM3 OM Model Evaluation')
    parser.add_argument('--dataset_root', type=str, required=True,
                        help='Root directory of the dataset')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to the OM model file')
    parser.add_argument('--processor_path', type=str, required=True,
                        help='Path to the SAM3 processor (HuggingFace model)')
    parser.add_argument('--output_dir', type=str, default='./gold_predictions',
                        help='Output directory for prediction JSON files')
    parser.add_argument('--device', type=str, default='0',
                        help='Comma-separated list of device IDs (e.g., "0,1,2,3")')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Inference batch size per device')
    parser.add_argument('--procs_per_card', type=int, default=1,
                        help='Maximum number of concurrent processes per NPU')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Score threshold for mask and box filtering')
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if args.procs_per_card < 1:
        parser.error("--procs_per_card must be a positive integer")
    devices = [int(d.strip()) for d in args.device.split(',') if d.strip()]
    if not devices:
        raise ValueError("At least one device must be specified.")
    print(f"Available devices: {devices}, max processes per card: {args.procs_per_card}")
    tasks = _build_tasks(args, devices)
    if not tasks:
        print("No valid tasks to process.")
        return
    mp.set_start_method('spawn', force=True)
    failed = _run_tasks(tasks, devices, args.procs_per_card)
    if failed:
        print(f"{len(failed)} of {len(tasks)} subset(s) failed: {', '.join(sorted(failed))}")
    else:
        print("All subsets processed successfully.")
if __name__ == "__main__":
    # Only start the evaluation when run as a script; spawned worker processes
    # re-import this module and must not re-enter main().
    main()