@@ -6,8 +6,12 @@
import torch
+from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
+from segment_anything.modeling.sam import Sam
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
+import onnx
+from onnx.external_data_helper import convert_model_to_external_data
import argparse
import warnings
@@ -24,11 +28,30 @@ parser = argparse.ArgumentParser(
)
parser.add_argument(
- "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="The path to the SAM model checkpoint.",
+)
+
+parser.add_argument(
+ "--encoder-output",
+ type=str,
+ required=True,
+ help="The filename to save the encoder ONNX model to.",
)
parser.add_argument(
- "--output", type=str, required=True, help="The filename to save the ONNX model to."
+ "--encoder-data-file",
+ type=str,
+ help="The filename to save the external data for encoder ONNX model to. Use this if the encoder model is too large to be saved in a single file.",
+)
+
+parser.add_argument(
+ "--decoder-output",
+ type=str,
+ required=True,
+ help="The filename to save the decoder ONNX model to.",
)
parser.add_argument(
@@ -56,11 +79,21 @@ parser.add_argument(
)
parser.add_argument(
- "--quantize-out",
+ "--quantize-encoder-out",
+ type=str,
+ default=None,
+ help=(
+ "If set, will quantize the encoder model and save it with this name. "
+ "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
+ ),
+)
+
+parser.add_argument(
+ "--quantize-decoder-out",
type=str,
default=None,
help=(
- "If set, will quantize the model and save it with this name. "
+ "If set, will quantize the decoder model and save it with this name. "
"Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
),
)
@@ -97,7 +130,9 @@ parser.add_argument(
def run_export(
model_type: str,
checkpoint: str,
- output: str,
+ encoder_output: str,
+ encoder_data_file: str,
+ decoder_output: str,
opset: int,
return_single_mask: bool,
gelu_approximate: bool = False,
@@ -107,6 +142,74 @@ def run_export(
print("Loading model...")
sam = sam_model_registry[model_type](checkpoint=checkpoint)
+ export_encoder(sam, encoder_output, opset, encoder_data_file)
+
+ export_decoder(
+ sam,
+ decoder_output,
+ opset,
+ return_single_mask,
+ gelu_approximate,
+ use_stability_score,
+ return_extra_metrics,
+ )
+
+
+def export_encoder(sam: Sam, output: str, opset: int, encoder_data_file: str):
+ dynamic_axes = {
+ "x": {0: "batch"},
+ }
+ dummy_inputs = {
+ "x": torch.randn(1, 3, 1024, 1024, dtype=torch.float),
+ }
+ _ = sam.image_encoder(**dummy_inputs)
+
+ output_names = ["image_embeddings"]
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ print(f"Exporing onnx model to {output}...")
+ torch.onnx.export(
+ sam.image_encoder,
+ tuple(dummy_inputs.values()),
+ output,
+ export_params=True,
+ verbose=False,
+ opset_version=opset,
+ do_constant_folding=True,
+ input_names=list(dummy_inputs.keys()),
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ )
+
+ if encoder_data_file:
+ onnx_model = onnx.load(output)
+ convert_model_to_external_data(
+ onnx_model,
+ all_tensors_to_one_file=True,
+ location=encoder_data_file,
+ size_threshold=1024,
+ convert_attribute=False,
+ )
+ onnx.save_model(onnx_model, output)
+
+ if onnxruntime_exists:
+ ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
+ ort_session = onnxruntime.InferenceSession(output)
+ _ = ort_session.run(None, ort_inputs)
+ print("Encoder has successfully been run with ONNXRuntime.")
+
+
+def export_decoder(
+ sam: Sam,
+ output: str,
+ opset: int,
+ return_single_mask: bool,
+ gelu_approximate: bool,
+ use_stability_score: bool,
+ return_extra_metrics: bool,
+):
onnx_model = SamOnnxModel(
model=sam,
return_single_mask=return_single_mask,
@@ -129,16 +232,17 @@ def run_export(
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
- "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
+ "point_coords": torch.randint(
+ low=0, high=1024, size=(1, 5, 2), dtype=torch.float
+ ),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
- "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
_ = onnx_model(**dummy_inputs)
- output_names = ["masks", "iou_predictions", "low_res_masks"]
+ output_names = ["iou_predictions", "low_res_masks"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
@@ -164,7 +268,7 @@ def run_export(
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(output, providers=providers)
_ = ort_session.run(None, ort_inputs)
- print("Model has successfully been run with ONNXRuntime.")
+ print("Decoder has successfully been run with ONNXRuntime.")
def to_numpy(tensor):
@@ -176,7 +280,9 @@ if __name__ == "__main__":
run_export(
model_type=args.model_type,
checkpoint=args.checkpoint,
- output=args.output,
+ encoder_output=args.encoder_output,
+ encoder_data_file=args.encoder_data_file,
+ decoder_output=args.decoder_output,
opset=args.opset,
return_single_mask=args.return_single_mask,
gelu_approximate=args.gelu_approximate,
@@ -184,18 +290,34 @@ if __name__ == "__main__":
return_extra_metrics=args.return_extra_metrics,
)
- if args.quantize_out is not None:
+ if args.quantize_encoder_out is not None:
assert onnxruntime_exists, "onnxruntime is required to quantize the model."
from onnxruntime.quantization import QuantType # type: ignore
from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
- print(f"Quantizing model and writing to {args.quantize_out}...")
+ print(f"Quantizing encoder model and writing to {args.quantize_encoder_out}...")
quantize_dynamic(
- model_input=args.output,
- model_output=args.quantize_out,
+ model_input=args.encoder_output,
+ model_output=args.quantize_encoder_out,
optimize_model=True,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QUInt8,
)
print("Done!")
+
+ if args.quantize_decoder_out is not None:
+ assert onnxruntime_exists, "onnxruntime is required to quantize the model."
+ from onnxruntime.quantization import QuantType # type: ignore
+ from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
+
+ print(f"Quantizing decoder model and writing to {args.quantize_decoder_out}...")
+ quantize_dynamic(
+ model_input=args.decoder_output,
+ model_output=args.quantize_decoder_out,
+ optimize_model=True,
+ per_channel=False,
+ reduce_range=False,
+ weight_type=QuantType.QUInt8,
+ )
+ print("Done!")
\ No newline at end of file
@@ -253,8 +253,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
"""
B, H, W, C = x.shape
- pad_h = (window_size - H % window_size) % window_size
- pad_w = (window_size - W % window_size) % window_size
+ pad_h = 6
+ pad_w = 6
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
@@ -322,6 +322,15 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
return rel_pos_resized[relative_coords.long()]
+def forge_einsum(equation, a, b):
+ if equation == 'bhwc,hkc->bhwk':
+ return torch.sum(a.unsqueeze(3) * b.unsqueeze(0).unsqueeze(2), dim=4)
+ elif equation == 'bhwc,wkc->bhwk':
+ return torch.sum(a.unsqueeze(3) * b.unsqueeze(0).unsqueeze(1), dim=4)
+ else:
+ raise Exception('Unkown equation')
+
+
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
@@ -351,8 +360,8 @@ def add_decomposed_rel_pos(
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
- rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
- rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+ rel_h = forge_einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = forge_einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
@@ -123,9 +123,15 @@ class MaskDecoder(nn.Module):
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
- src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ N = tokens.shape[0]
+ B, C, H, W = image_embeddings.shape
+ src = image_embeddings.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)
+
src = src + dense_prompt_embeddings
- pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+
+ B, C, H, W = image_pe.shape
+ pos_src = image_pe.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)
+
b, c, h, w = src.shape
# Run the transformer
@@ -112,7 +112,6 @@ class SamOnnxModel(nn.Module):
point_labels: torch.Tensor,
mask_input: torch.Tensor,
has_mask_input: torch.Tensor,
- orig_im_size: torch.Tensor,
):
sparse_embedding = self._embed_points(point_coords, point_labels)
dense_embedding = self._embed_masks(mask_input, has_mask_input)
@@ -131,14 +130,4 @@ class SamOnnxModel(nn.Module):
if self.return_single_mask:
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
-
- upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
-
- if self.return_extra_metrics:
- stability_scores = calculate_stability_score(
- upscaled_masks, self.model.mask_threshold, self.stability_score_offset
- )
- areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
- return upscaled_masks, scores, stability_scores, areas, masks
-
- return upscaled_masks, scores, masks
+ return scores, masks