from typing import Union, Tuple
from PIL.Image import Image
import torch
import numpy as np
from sam2.utils.transforms import SAM2Transforms
IMAGE_SIZE = 1024
sam2Transforms = SAM2Transforms(
resolution=IMAGE_SIZE,
mask_threshold=0.0
)
def encoder_preprocessing(image: Union[np.ndarray, Image]):
"""
Arguments:
image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
with pixel values in [0, 255].
return:
input_image (np.ndarray): 1x3xHxW
"""
if not isinstance(image, Image) and not isinstance(image, np.ndarray):
raise NotImplementedError("Image format not supported")
input_image = sam2Transforms(image)
input_image = input_image[None, ...].numpy()
return input_image
def decoder_preprocessing(high_res_feats_0: np.ndarray, high_res_feats_1: np.ndarray, \
image_embedding: np.ndarray, image_orig_hw: Tuple[int, int], \
point_coords: np.ndarray = None, point_labels: np.ndarray = None, box: np.ndarray = None, \
mask_input: np.ndarray = None):
"""
Prepare decoder inputs for SAM2 inference.
This function processes user prompts (points, boxes, masks) and combines them
with encoder features to create the input format expected by SAM2ImageDecoder.
Args:
high_res_feats_0: High-resolution feature map from encoder, shape (1, 32, 256, 256).
high_res_feats_1: Medium-resolution feature map from encoder, shape (1, 64, 128, 128).
image_embedding: Low-resolution image embedding from encoder, shape (1, 256, 64, 64).
image_orig_hw: Original image dimensions as (height, width) tuple, e.g., (1080, 1920).
Used to denormalize prompt coordinates.
point_coords: Optional point prompt coordinates, shape (N, 2) or (1, N, 2).
Coordinates can be normalized [0, 1] or pixel values. If None, no point prompts.
point_labels: Optional point prompt labels, shape (N,) or (1, N).
Each value is 1 (foreground), 0 (background), or -1 (ignore). Must be supplied
if point_coords is supplied.
box: Optional box prompt, shape (4,) or (1, 4) in [x1, y1, x2, y2] format.
Coordinates can be normalized [0, 1] or pixel values. If None, no box prompt.
mask_input: Optional low-resolution mask input from previous iteration,
shape (1, 256, 256) or (1, 1, 256, 256). Used for iterative refinement.
If None, will be replaced with zeros.
Returns:
decoder_inputs: List of 7 numpy arrays ready for SAM2ImageDecoder inference:
[0] image_embedding: Shape (1, 256, 64, 64)
[1] high_res_feats_0: Shape (1, 32, 256, 256)
[2] high_res_feats_1: Shape (1, 64, 128, 128)
[3] point_coords: Shape (1, N, 2), normalized to [0, 1]. N=number of points+boxes.
[4] point_labels: Shape (1, N), values in {0, 1, 2, 3} where:
0-1: point foreground/background
2-3: box top-left/bottom-right corners
[5] mask_input: Shape (1, 1, 256, 256), zeros if no mask provided
[6] has_mask_input: Shape (1,), binary flag (1.0 if mask provided, 0.0 otherwise)
Note:
- Box prompts are converted to 2 point prompts (top-left and bottom-right corners)
"""
mask_input, point_coords, point_labels, box = _prep_prompts(
point_coords=point_coords, point_labels=point_labels, box=box, mask_logits=mask_input, normalize_coords=True, image_orig_hw=image_orig_hw
)
if point_coords is not None:
concat_points = (point_coords, point_labels)
else:
concat_points = None
if box is not None:
box_coords = box.reshape(-1, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int8)
box_labels = box_labels.repeat(box.size(0), 1)
if concat_points is not None:
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
concat_points = (concat_coords, concat_labels)
else:
concat_points = (box_coords, box_labels)
if concat_points is not None:
point_coords, point_labels = concat_points[0].cpu().numpy(), concat_points[1].cpu().numpy()
else:
point_coords = np.zeros((1, 1, 2), dtype=np.float32)
point_labels = np.zeros((1, 1), dtype=np.int8)
if mask_input is None:
has_mask_input = np.zeros(1, dtype=np.int8)
mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
else:
has_mask_input = np.ones(1, dtype=np.int8)
decoder_inputs = [image_embedding, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input]
return decoder_inputs
def _prep_prompts(
point_coords, point_labels, box, mask_logits, normalize_coords, image_orig_hw
):
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
if point_coords is not None:
if point_labels is None:
raise ValueError("point_labels must be supplied if point_coords is supplied.")
point_coords = torch.as_tensor(
point_coords, dtype=torch.float
)
unnorm_coords = sam2Transforms.transform_coords(
point_coords, normalize=normalize_coords, orig_hw=image_orig_hw
)
labels = torch.as_tensor(point_labels, dtype=torch.int)
if len(unnorm_coords.shape) == 2:
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
if box is not None:
box = torch.as_tensor(box, dtype=torch.float)
unnorm_box = sam2Transforms.transform_boxes(
box, normalize=normalize_coords, orig_hw=image_orig_hw
)
if mask_logits is not None:
mask_input = torch.as_tensor(
mask_logits, dtype=torch.float
)
if len(mask_input.shape) == 3:
mask_input = mask_input[None, :, :, :]
mask_input = mask_input
return mask_input, unnorm_coords, labels, unnorm_box