import os
import argparse
import cv2
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools import mask as maskUtils
from ais_bench.infer.interface import InferSession
from sam_preprocessing_pytorch import encoder_preprocessing, decoder_preprocessing
from sam_postprocessing_pytorch import sam_postprocessing
def rle_to_mask(rle, h, w):
"""COCO segmentation → binary mask (h,w) uint8."""
if isinstance(rle, list):
rles = maskUtils.frPyObjects(rle, h, w)
rle = maskUtils.merge(rles)
elif isinstance(rle, dict) and isinstance(rle.get("counts"), list):
rle = maskUtils.frPyObjects(rle, h, w)
return maskUtils.decode(rle).astype(np.uint8)
def compute_iou(pred_mask, gt_mask):
pred = (pred_mask > 0).astype(np.uint8)
gt = (gt_mask > 0).astype(np.uint8)
inter = (pred & gt).sum()
union = (pred | gt).sum()
return float(inter) / float(union) if union > 0 else 0.0
def coco_bbox_to_xyxy(bbox_xywh):
x, y, w, h = bbox_xywh
return [x, y, x + w, y + h]
def encoder_infer(session_encoder, x):
encoder_outputs = session_encoder.infer([x])
image_embedding = encoder_outputs[0]
return image_embedding
def decoder_infer(session_decoder, decoder_inputs):
decoder_outputs = session_decoder.infer(decoder_inputs, mode="dymdims", custom_sizes=[1000, 1000000])
low_res_masks = decoder_outputs[1]
return low_res_masks
def save_mask_overlay(masks, image, save_dir, image_name):
overlay = image.copy()
alpha = 0.5
for mask in masks:
if mask.sum() == 0:
continue
color = np.random.randint(0, 255, (3,), dtype=np.uint8)
overlay[mask > 0] = (overlay[mask > 0] * (1 - alpha) + color * alpha).astype(np.uint8)
base, ext = os.path.splitext(image_name)
save_path = os.path.join(save_dir, f"{base}_sam_pre{ext}")
cv2.imwrite(save_path, overlay)
def evaluate_sam_on_coco(coco_root, save_path, encoder, decoder, max_instances=0):
ann_file = os.path.join(coco_root, "annotations", "instances_val2017.json")
img_root = os.path.join(coco_root, "val2017")
if not os.path.isfile(ann_file):
raise FileNotFoundError(f"COCO annotations not found: {ann_file}")
if not os.path.isdir(img_root):
raise FileNotFoundError(f"COCO val2017 images not found: {img_root}")
coco = COCO(ann_file)
img_ids = coco.getImgIds()
session_encoder = encoder
session_decoder = decoder
ious = []
counted = 0
for img_id in tqdm(img_ids, desc="Evaluating"):
img_info = coco.loadImgs(img_id)[0]
img_path = os.path.join(img_root, img_info["file_name"])
image = cv2.imread(img_path)
H, W = image.shape[:2]
x = encoder_preprocessing(image)
image_embedding = encoder_infer(session_encoder, x)
ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
anns = coco.loadAnns(ann_ids)
mask_list = []
for ann in anns:
if max_instances > 0 and counted >= max_instances:
break
box_xyxy = coco_bbox_to_xyxy(ann["bbox"])
decoder_inputs = decoder_preprocessing(image_embedding, box=box_xyxy, image=image)
low_res_masks = decoder_infer(session_decoder, decoder_inputs)
masks = sam_postprocessing(low_res_masks, image)
pred2d = masks[0][0].astype(np.uint8)
mask_list.append(pred2d)
pred_bin = pred2d.astype(np.uint8)
gt_mask = rle_to_mask(ann["segmentation"], H, W)
iou = compute_iou(pred_bin, gt_mask)
ious.append(iou)
counted += 1
if save_path is not None and len(mask_list) > 0:
save_mask_overlay(mask_list, image, save_path, img_info["file_name"])
if max_instances > 0 and counted >= max_instances:
break
miou = float(np.mean(ious)) if counted > 0 else 0.0
print("\n=========== COCO Evaluation (Box Prompt) ===========")
print(f"Instances Evaluated : {counted}")
print(f"Mean IoU (mIoU) : {miou:.4f}")
print("====================================================\n")
def check_device_range_valid(value):
min_value = 0
max_value = 255
if ',' in value:
ilist = [int(v) for v in value.split(',')]
for ivalue in ilist[:2]:
if ivalue < min_value or ivalue > max_value:
raise argparse.ArgumentTypeError("{} of device:{} is invalid. valid value range is [{}, {}]".format(
ivalue, value, min_value, max_value))
return ilist[:2]
else:
ivalue = int(value)
if ivalue < min_value or ivalue > max_value:
raise argparse.ArgumentTypeError("device:{} is invalid. valid value range is [{}, {}]".format(
ivalue, min_value, max_value))
return ivalue
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='./datasets/', help='input path to coco dataset')
parser.add_argument('--save-path', type=str, default=None, help='output path to image')
parser.add_argument('--encoder-model-path', type=str, default='./models/encoder_sim.om', help='path to encoder model')
parser.add_argument('--decoder-model-path', type=str, default='./models/decoder_sim.om', help='path to decoder model')
parser.add_argument('--device-id', type=check_device_range_valid, default=0, help='NPU device id.')
parser.add_argument('--max-instances', type=int, default=0, help='Maximum number of instances to evaluate (0 = all).')
args = parser.parse_args()
if args.save_path and not os.path.exists(args.save_path):
os.makedirs(os.path.realpath(args.save_path), mode=0o744)
session_encoder = InferSession(args.device_id, args.encoder_model_path)
session_decoder = InferSession(args.device_id, args.decoder_model_path)
evaluate_sam_on_coco(
args.dataset_path,
args.save_path,
session_encoder,
session_decoder,
max_instances=args.max_instances
)
if __name__ == "__main__":
main()