39222781创建于 2025年9月23日历史提交
diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py
index 5c6f838..0bfaff2 100644
--- a/scripts/export_onnx_model.py
+++ b/scripts/export_onnx_model.py
@@ -6,8 +6,12 @@
 
 import torch
 
+from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
+from segment_anything.modeling.sam import Sam
 from segment_anything import sam_model_registry
 from segment_anything.utils.onnx import SamOnnxModel
+import onnx
+from onnx.external_data_helper import convert_model_to_external_data
 
 import argparse
 import warnings
@@ -24,11 +28,30 @@ parser = argparse.ArgumentParser(
 )
 
 parser.add_argument(
-    "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
+    "--checkpoint",
+    type=str,
+    required=True,
+    help="The path to the SAM model checkpoint.",
+)
+
+parser.add_argument(
+    "--encoder-output",
+    type=str,
+    required=True,
+    help="The filename to save the encoder ONNX model to.",
 )
 
 parser.add_argument(
-    "--output", type=str, required=True, help="The filename to save the ONNX model to."
+    "--encoder-data-file",
+    type=str,
+    help="The filename to save the external data for encoder ONNX model to. Use this if the encoder model is too large to be saved in a single file.",
+)
+
+parser.add_argument(
+    "--decoder-output",
+    type=str,
+    required=True,
+    help="The filename to save the decoder ONNX model to.",
 )
 
 parser.add_argument(
@@ -56,11 +79,21 @@ parser.add_argument(
 )
 
 parser.add_argument(
-    "--quantize-out",
+    "--quantize-encoder-out",
+    type=str,
+    default=None,
+    help=(
+        "If set, will quantize the encoder model and save it with this name. "
+        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
+    ),
+)
+
+parser.add_argument(
+    "--quantize-decoder-out",
     type=str,
     default=None,
     help=(
-        "If set, will quantize the model and save it with this name. "
+        "If set, will quantize the decoder model and save it with this name. "
         "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
     ),
 )
@@ -97,7 +130,9 @@ parser.add_argument(
 def run_export(
     model_type: str,
     checkpoint: str,
-    output: str,
+    encoder_output: str,
+    encoder_data_file: str,
+    decoder_output: str,
     opset: int,
     return_single_mask: bool,
     gelu_approximate: bool = False,
@@ -107,6 +142,74 @@ def run_export(
     print("Loading model...")
     sam = sam_model_registry[model_type](checkpoint=checkpoint)
 
+    export_encoder(sam, encoder_output, opset, encoder_data_file)
+
+    export_decoder(
+        sam,
+        decoder_output,
+        opset,
+        return_single_mask,
+        gelu_approximate,
+        use_stability_score,
+        return_extra_metrics,
+    )
+
+
+def export_encoder(sam: Sam, output: str, opset: int, encoder_data_file: str):
+    dynamic_axes = {
+        "x": {0: "batch"},
+    }
+    dummy_inputs = {
+        "x": torch.randn(1, 3, 1024, 1024, dtype=torch.float),
+    }
+    _ = sam.image_encoder(**dummy_inputs)
+
+    output_names = ["image_embeddings"]
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+        warnings.filterwarnings("ignore", category=UserWarning)
+        print(f"Exporing onnx model to {output}...")
+        torch.onnx.export(
+            sam.image_encoder,
+            tuple(dummy_inputs.values()),
+            output,
+            export_params=True,
+            verbose=False,
+            opset_version=opset,
+            do_constant_folding=True,
+            input_names=list(dummy_inputs.keys()),
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+        )
+
+    if encoder_data_file:
+        onnx_model = onnx.load(output)
+        convert_model_to_external_data(
+            onnx_model,
+            all_tensors_to_one_file=True,
+            location=encoder_data_file,
+            size_threshold=1024,
+            convert_attribute=False,
+        )
+        onnx.save_model(onnx_model, output)
+
+    if onnxruntime_exists:
+        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
+        ort_session = onnxruntime.InferenceSession(output)
+        _ = ort_session.run(None, ort_inputs)
+        print("Encoder has successfully been run with ONNXRuntime.")
+
+
+def export_decoder(
+    sam: Sam,
+    output: str,
+    opset: int,
+    return_single_mask: bool,
+    gelu_approximate: bool,
+    use_stability_score: bool,
+    return_extra_metrics: bool,
+):
     onnx_model = SamOnnxModel(
         model=sam,
         return_single_mask=return_single_mask,
@@ -129,16 +232,17 @@ def run_export(
     mask_input_size = [4 * x for x in embed_size]
     dummy_inputs = {
         "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
-        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
+        "point_coords": torch.randint(
+            low=0, high=1024, size=(1, 5, 2), dtype=torch.float
+        ),
         "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
         "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
         "has_mask_input": torch.tensor([1], dtype=torch.float),
-        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
     }
 
     _ = onnx_model(**dummy_inputs)
 
-    output_names = ["masks", "iou_predictions", "low_res_masks"]
+    output_names = ["iou_predictions", "low_res_masks"]
 
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
@@ -164,7 +268,7 @@ def run_export(
         providers = ["CPUExecutionProvider"]
         ort_session = onnxruntime.InferenceSession(output, providers=providers)
         _ = ort_session.run(None, ort_inputs)
-        print("Model has successfully been run with ONNXRuntime.")
+        print("Decoder has successfully been run with ONNXRuntime.")
 
 
 def to_numpy(tensor):
@@ -176,7 +280,9 @@ if __name__ == "__main__":
     run_export(
         model_type=args.model_type,
         checkpoint=args.checkpoint,
-        output=args.output,
+        encoder_output=args.encoder_output,
+        encoder_data_file=args.encoder_data_file,
+        decoder_output=args.decoder_output,
         opset=args.opset,
         return_single_mask=args.return_single_mask,
         gelu_approximate=args.gelu_approximate,
@@ -184,18 +290,34 @@ if __name__ == "__main__":
         return_extra_metrics=args.return_extra_metrics,
     )
 
-    if args.quantize_out is not None:
+    if args.quantize_encoder_out is not None:
         assert onnxruntime_exists, "onnxruntime is required to quantize the model."
         from onnxruntime.quantization import QuantType  # type: ignore
         from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore
 
-        print(f"Quantizing model and writing to {args.quantize_out}...")
+        print(f"Quantizing encoder model and writing to {args.quantize_encoder_out}...")
         quantize_dynamic(
-            model_input=args.output,
-            model_output=args.quantize_out,
+            model_input=args.encoder_output,
+            model_output=args.quantize_encoder_out,
             optimize_model=True,
             per_channel=False,
             reduce_range=False,
             weight_type=QuantType.QUInt8,
         )
         print("Done!")
+
+    if args.quantize_decoder_out is not None:
+        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
+        from onnxruntime.quantization import QuantType  # type: ignore
+        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore
+
+        print(f"Quantizing decoder model and writing to {args.quantize_decoder_out}...")
+        quantize_dynamic(
+            model_input=args.decoder_output,
+            model_output=args.quantize_decoder_out,
+            optimize_model=True,
+            per_channel=False,
+            reduce_range=False,
+            weight_type=QuantType.QUInt8,
+        )
+        print("Done!")
\ No newline at end of file
diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py
index 66351d9..31d622c 100644
--- a/segment_anything/modeling/image_encoder.py
+++ b/segment_anything/modeling/image_encoder.py
@@ -253,8 +253,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
     """
     B, H, W, C = x.shape
 
-    pad_h = (window_size - H % window_size) % window_size
-    pad_w = (window_size - W % window_size) % window_size
+    pad_h = 6
+    pad_w = 6
     if pad_h > 0 or pad_w > 0:
         x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
     Hp, Wp = H + pad_h, W + pad_w
@@ -322,6 +322,15 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
     return rel_pos_resized[relative_coords.long()]
 
 
+def forge_einsum(equation, a, b):
+    if equation == 'bhwc,hkc->bhwk':
+        return torch.sum(a.unsqueeze(3) * b.unsqueeze(0).unsqueeze(2), dim=4)
+    elif equation == 'bhwc,wkc->bhwk':
+        return torch.sum(a.unsqueeze(3) * b.unsqueeze(0).unsqueeze(1), dim=4)
+    else:
+        raise Exception('Unkown equation')
+
+
 def add_decomposed_rel_pos(
     attn: torch.Tensor,
     q: torch.Tensor,
@@ -351,8 +360,8 @@ def add_decomposed_rel_pos(
 
     B, _, dim = q.shape
     r_q = q.reshape(B, q_h, q_w, dim)
-    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
-    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+    rel_h = forge_einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = forge_einsum("bhwc,wkc->bhwk", r_q, Rw)
 
     attn = (
         attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py
index 5d2fdb0..ee8da94 100644
--- a/segment_anything/modeling/mask_decoder.py
+++ b/segment_anything/modeling/mask_decoder.py
@@ -123,9 +123,15 @@ class MaskDecoder(nn.Module):
         tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
 
         # Expand per-image data in batch direction to be per-mask
-        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        N = tokens.shape[0]
+        B, C, H, W = image_embeddings.shape
+        src = image_embeddings.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)
+
         src = src + dense_prompt_embeddings
-        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+
+        B, C, H, W = image_pe.shape
+        pos_src = image_pe.unsqueeze(1).expand(B, N, C, H, W).reshape(B * N, C, H, W)
+
         b, c, h, w = src.shape
 
         # Run the transformer
diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py
index 3196bdf..e718afc 100644
--- a/segment_anything/utils/onnx.py
+++ b/segment_anything/utils/onnx.py
@@ -112,7 +112,6 @@ class SamOnnxModel(nn.Module):
         point_labels: torch.Tensor,
         mask_input: torch.Tensor,
         has_mask_input: torch.Tensor,
-        orig_im_size: torch.Tensor,
     ):
         sparse_embedding = self._embed_points(point_coords, point_labels)
         dense_embedding = self._embed_masks(mask_input, has_mask_input)
@@ -131,14 +130,4 @@ class SamOnnxModel(nn.Module):
 
         if self.return_single_mask:
             masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
-
-        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
-
-        if self.return_extra_metrics:
-            stability_scores = calculate_stability_score(
-                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
-            )
-            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
-            return upscaled_masks, scores, stability_scores, areas, masks
-
-        return upscaled_masks, scores, masks
+        return scores, masks