@@ -1,37 +1,41 @@
import os
-
+import torch
+import torch_npu
+import numpy as np
+import torchair
+import matplotlib.pyplot as plt
import cv2
import imageio
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
# use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
-if torch.cuda.get_device_properties(0).major >= 8:
- # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
-import time
+torch.autocast(device_type="npu", dtype=torch.float16)
from sam2.build_sam import build_sam2_camera_predictor
+import time
+# device = torch.device("npu:0")
+device = torch.device("npu")
+torch_npu.npu.set_compile_mode(jit_compile=False)
-sam2_checkpoint = "../checkpoints/sam2.1_hiera_small.pt"
+sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
-predictor = build_sam2_camera_predictor(model_cfg, sam2_checkpoint)
+predictor = build_sam2_camera_predictor(model_cfg, sam2_checkpoint, device=device)
+config = torchair.CompilerConfig()
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+# 基于NPU backend进行compile
+predictor = torch.compile(predictor, backend=npu_backend)
-cap = cv2.VideoCapture("../notebooks/videos/aquarium/aquarium.mp4")
+cap = cv2.VideoCapture("./notebooks/videos/aquarium/aquarium.mp4")
if_init = False
-tracking_i = 0
+frame_list = []
while True:
+ start_time = time.perf_counter()
+
ret, frame = cap.read()
if not ret:
break
@@ -45,19 +49,18 @@ while True:
if_init = True
ann_frame_idx = 0 # the frame index we interact with
-
- # First annotation
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
+ # Let's add a positive click at (x, y) = (210, 350) to get started
+
+
##! add points, `1` means positive click and `0` means negative click
- points = np.array([[600, 255]], dtype=np.float32)
- labels = np.array([1], dtype=np.int32)
+ # points = np.array([[660, 267]], dtype=np.float32)
+ # labels = np.array([1], dtype=np.int32)
- _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
- frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
- )
+ # _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
+ # frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
+ # )
- # Second annotation
- ann_obj_id = 2
## ! add bbox
bbox = np.array([[600, 214], [765, 286]], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
@@ -74,56 +77,23 @@ while True:
# )
else:
- out_obj_ids, out_mask_logits = predictor.track(frame)
- tracking_i += 1
-
- if tracking_i == 100:
- predictor.add_conditioning_frame(frame)
-
- ## ! add new bbox
- bbox = np.array([[450, 280], [520, 340]], dtype=np.float32)
- ann_obj_id = 2
- predictor.add_new_prompt_during_track(
- bbox=bbox,
- obj_id=ann_obj_id,
- if_new_target=False,
- clear_old_points=False,
- )
-
- if tracking_i == 160:
- predictor.add_conditioning_frame(frame)
-
- # ! add new point
- points = np.array([[460, 270]], dtype=np.float32)
- labels = np.array([1], dtype=np.int32)
- ann_obj_id = 1
- predictor.add_new_prompt_during_track(
- point=points,
- labels=labels,
- obj_id=ann_obj_id,
- if_new_target=False,
- clear_old_points=False,
- )
-
- all_mask = np.zeros((height, width, 3), dtype=np.uint8)
- all_mask[..., 1] = 255
+ out_obj_ids, out_mask_logits = predictor.track(frame) #
+
+ all_mask = np.zeros((height, width, 1), dtype=np.uint8)
# print(all_mask.shape)
for i in range(0, len(out_obj_ids)):
out_mask = (out_mask_logits[i] > 0.0).permute(1, 2, 0).cpu().numpy().astype(
np.uint8
) * 255
+
+ all_mask = cv2.bitwise_or(all_mask, out_mask)
- hue = (i + 3) / (len(out_obj_ids) + 3) * 255
- all_mask[out_mask[..., 0] == 255, 0] = hue
- all_mask[out_mask[..., 0] == 255, 2] = 255
-
- all_mask = cv2.cvtColor(all_mask, cv2.COLOR_HSV2RGB)
+ all_mask = cv2.cvtColor(all_mask, cv2.COLOR_GRAY2RGB)
frame = cv2.addWeighted(frame, 1, all_mask, 0.5, 0)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- cv2.imshow("frame", frame)
-
- if cv2.waitKey(1) & 0xFF == ord("q"):
- break
-
+ frame_list.append(frame.copy())
+ end_time = time.perf_counter()
+ elapsed_ms = (end_time - start_time) * 1000
+ print(f"当前帧处理耗时 {elapsed_ms:.2f} ms")
cap.release()
-# gif = imageio.mimsave("./result.gif", frame_list, "GIF", duration=0.00085)
+gif = imageio.mimsave("./result.gif", frame_list, "GIF", duration=0.00085)
\ No newline at end of file
@@ -233,6 +233,12 @@ def apply_rotary_enc(
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
- freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+ # freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+ freqs_cis = freqs_cis.unsqueeze(2)
+ freqs_cis_real = freqs_cis.real
+ freqs_cis_imag = freqs_cis.imag
+ freqs_cis_real_after = freqs_cis_real.expand(-1, -1, r, -1, -1)
+ freqs_cis_imag_after = freqs_cis_imag.expand(-1, -1, r, -1, -1)
+ freqs_cis = torch.complex(freqs_cis_real_after, freqs_cis_imag_after).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
@@ -10,6 +10,7 @@ import cv2
import numpy as np
import torch
+import torch_npu
import torch.nn.functional as F
from tqdm import tqdm
@@ -19,7 +20,8 @@ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores
# torch._dynamo.config.capture_dynamic_output_shape_ops = True
-
+# device = torch.device("npu:0")
+device = torch.device("npu")
class SAM2CameraPredictor(SAM2Base):
"""The predictor class to handle user interactions and manage inference states."""
@@ -53,17 +55,21 @@ class SAM2CameraPredictor(SAM2Base):
):
if isinstance(img, np.ndarray):
img_np = img
- img_np = cv2.resize(img_np, (image_size, image_size)) / 255.0
+ img_np = cv2.resize(img_np, (image_size, image_size))
height, width = img.shape[:2]
else:
img_np = (
- np.array(img.convert("RGB").resize((image_size, image_size))) / 255.0
+ np.array(img.convert("RGB").resize((image_size, image_size)))
)
width, height = img.size
- img = torch.from_numpy(img_np).permute(2, 0, 1).float()
+ # img = torch.from_numpy(img_np).permute(2, 0, 1).float()
+
+ # img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+ # img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+ img = torch.from_numpy(img_np).npu().permute(2, 0, 1).float() / 255.0
- img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
- img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+ img_mean = torch.tensor(img_mean, dtype=torch.float32, device=device)[:, None, None]
+ img_std = torch.tensor(img_std, dtype=torch.float32, device=device)[:, None, None]
img -= img_mean
img /= img_std
return img, width, height
@@ -109,11 +115,11 @@ class SAM2CameraPredictor(SAM2Base):
self.condition_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
- self.condition_state["device"] = torch.device("cuda")
+ self.condition_state["device"] = device
if offload_state_to_cpu:
self.condition_state["storage_device"] = torch.device("cpu")
else:
- self.condition_state["storage_device"] = torch.device("cuda")
+ self.condition_state["storage_device"] = device
# inputs on each frame
self.condition_state["point_inputs_per_obj"] = {}
self.condition_state["mask_inputs_per_obj"] = {}
@@ -861,7 +867,7 @@ class SAM2CameraPredictor(SAM2Base):
storage_device = self.condition_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
- maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(torch.float16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
@@ -1055,7 +1061,7 @@ class SAM2CameraPredictor(SAM2Base):
if backbone_out is None:
# Cache miss -- we will run inference on a single image
image = (
- self.condition_state["images"][frame_idx].cuda().float().unsqueeze(0)
+ self.condition_state["images"][frame_idx].to(device).float().unsqueeze(0)
)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
@@ -1082,7 +1088,7 @@ class SAM2CameraPredictor(SAM2Base):
###
def _get_feature(self, img, batch_size):
- image = img.cuda().float().unsqueeze(0)
+ image = img.to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
expanded_image = image.expand(batch_size, -1, -1, -1)
expanded_backbone_out = {
@@ -1144,7 +1150,7 @@ class SAM2CameraPredictor(SAM2Base):
storage_device = self.condition_state["storage_device"]
maskmem_features = current_out["maskmem_features"]
if maskmem_features is not None:
- maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(torch.float16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
pred_masks_gpu = current_out["pred_masks"]
# potentially fill holes in the predicted masks
@@ -1195,7 +1201,7 @@ class SAM2CameraPredictor(SAM2Base):
# optionally offload the output to CPU memory to save GPU space
storage_device = self.condition_state["storage_device"]
- maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(torch.float16)
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
maskmem_pos_enc = self._get_maskmem_pos_enc(
@@ -1405,7 +1411,7 @@ class SAM2CameraPredictorVOS(SAM2CameraPredictor):
NO_OBJ_SCORE,
)
- # convert masks from possibly bfloat16 (or float16) to float32
+ # convert masks from possibly float16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
low_res_multimasks = low_res_multimasks.float()
high_res_multimasks = F.interpolate(