diff --git a/demo/demo.py b/demo/demo.py
index 49e25f4..4adf03a 100755
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -1,37 +1,41 @@
 import os
-
+import torch
+import torch_npu
+import numpy as np
+import torchair
+import matplotlib.pyplot as plt
 import cv2
 import imageio
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
 
 
 # use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-import time
+torch.autocast(device_type="npu", dtype=torch.float16)
 
 from sam2.build_sam import build_sam2_camera_predictor
+import time
 
+# device = torch.device("npu:0")
+device = torch.device("npu")
+torch_npu.npu.set_compile_mode(jit_compile=False)
 
-sam2_checkpoint = "../checkpoints/sam2.1_hiera_small.pt"
+sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
 model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
 
-predictor = build_sam2_camera_predictor(model_cfg, sam2_checkpoint)
+predictor = build_sam2_camera_predictor(model_cfg, sam2_checkpoint, device=device)
 
+config = torchair.CompilerConfig()
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+# 基于NPU backend进行compile
+predictor = torch.compile(predictor, backend=npu_backend)
 
-cap = cv2.VideoCapture("../notebooks/videos/aquarium/aquarium.mp4")
+cap = cv2.VideoCapture("./notebooks/videos/aquarium/aquarium.mp4")
 
 if_init = False
-tracking_i = 0
+frame_list = []
 
 while True:
+    start_time = time.perf_counter()
+
     ret, frame = cap.read()
     if not ret:
         break
@@ -45,19 +49,18 @@ while True:
         if_init = True
 
         ann_frame_idx = 0  # the frame index we interact with
-
-        # First annotation
         ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)
+        # Let's add a positive click at (x, y) = (210, 350) to get started
+
+
         ##! add points, `1` means positive click and `0` means negative click
-        points = np.array([[600, 255]], dtype=np.float32)
-        labels = np.array([1], dtype=np.int32)
+        # points = np.array([[660, 267]], dtype=np.float32)
+        # labels = np.array([1], dtype=np.int32)
 
-        _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
-            frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
-        )
+        # _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
+        #     frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
+        # )
 
-        # Second annotation
-        ann_obj_id = 2
         ## ! add bbox
         bbox = np.array([[600, 214], [765, 286]], dtype=np.float32)
         _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
@@ -74,56 +77,23 @@ while True:
         # )
 
     else:
-        out_obj_ids, out_mask_logits = predictor.track(frame)
-        tracking_i += 1
-
-        if tracking_i == 100:
-            predictor.add_conditioning_frame(frame)
-
-            ## ! add new bbox
-            bbox = np.array([[450, 280], [520, 340]], dtype=np.float32)
-            ann_obj_id = 2
-            predictor.add_new_prompt_during_track(
-                bbox=bbox,
-                obj_id=ann_obj_id,
-                if_new_target=False,
-                clear_old_points=False,
-            )
-
-        if tracking_i == 160:
-            predictor.add_conditioning_frame(frame)
-
-            # ! add new point
-            points = np.array([[460, 270]], dtype=np.float32)
-            labels = np.array([1], dtype=np.int32)
-            ann_obj_id = 1
-            predictor.add_new_prompt_during_track(
-                point=points,
-                labels=labels,
-                obj_id=ann_obj_id,
-                if_new_target=False,
-                clear_old_points=False,
-            )
-
-        all_mask = np.zeros((height, width, 3), dtype=np.uint8)
-        all_mask[..., 1] = 255
+        out_obj_ids, out_mask_logits = predictor.track(frame) #
+
+        all_mask = np.zeros((height, width, 1), dtype=np.uint8)
         # print(all_mask.shape)
         for i in range(0, len(out_obj_ids)):
             out_mask = (out_mask_logits[i] > 0.0).permute(1, 2, 0).cpu().numpy().astype(
                 np.uint8
             ) * 255
+            
+            all_mask = cv2.bitwise_or(all_mask, out_mask)
 
-            hue = (i + 3) / (len(out_obj_ids) + 3) * 255
-            all_mask[out_mask[..., 0] == 255, 0] = hue
-            all_mask[out_mask[..., 0] == 255, 2] = 255
-
-        all_mask = cv2.cvtColor(all_mask, cv2.COLOR_HSV2RGB)
+        all_mask = cv2.cvtColor(all_mask, cv2.COLOR_GRAY2RGB)
         frame = cv2.addWeighted(frame, 1, all_mask, 0.5, 0)
     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    cv2.imshow("frame", frame)
-
-    if cv2.waitKey(1) & 0xFF == ord("q"):
-        break
-
+    frame_list.append(frame.copy())
+    end_time = time.perf_counter()
+    elapsed_ms = (end_time - start_time) * 1000
+    print(f"当前帧处理耗时 {elapsed_ms:.2f} ms")
 cap.release()
-# gif = imageio.mimsave("./result.gif", frame_list, "GIF", duration=0.00085)
+gif = imageio.mimsave("./result.gif", frame_list, "GIF", duration=0.00085)
\ No newline at end of file
diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py
index 898e671..ee775af 100644
--- a/sam2/modeling/position_encoding.py
+++ b/sam2/modeling/position_encoding.py
@@ -233,6 +233,12 @@ def apply_rotary_enc(
         else:
             # torch.repeat on complex numbers may not be supported on non-CUDA devices
             # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
-            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+            # freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+            freqs_cis = freqs_cis.unsqueeze(2)
+            freqs_cis_real = freqs_cis.real
+            freqs_cis_imag = freqs_cis.imag
+            freqs_cis_real_after = freqs_cis_real.expand(-1, -1, r, -1, -1)
+            freqs_cis_imag_after = freqs_cis_imag.expand(-1, -1, r, -1, -1)
+            freqs_cis = torch.complex(freqs_cis_real_after, freqs_cis_imag_after).flatten(2, 3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
diff --git a/sam2/sam2_camera_predictor.py b/sam2/sam2_camera_predictor.py
index 1cc11b6..90f735c 100755
--- a/sam2/sam2_camera_predictor.py
+++ b/sam2/sam2_camera_predictor.py
@@ -10,6 +10,7 @@ import cv2
 import numpy as np
 
 import torch
+import torch_npu
 import torch.nn.functional as F
 
 from tqdm import tqdm
@@ -19,7 +20,8 @@ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores
 
 # torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
-
+# device = torch.device("npu:0")
+device = torch.device("npu")
 class SAM2CameraPredictor(SAM2Base):
     """The predictor class to handle user interactions and manage inference states."""
 
@@ -53,17 +55,21 @@ class SAM2CameraPredictor(SAM2Base):
     ):
         if isinstance(img, np.ndarray):
             img_np = img
-            img_np = cv2.resize(img_np, (image_size, image_size)) / 255.0
+            img_np = cv2.resize(img_np, (image_size, image_size))
             height, width = img.shape[:2]
         else:
             img_np = (
-                np.array(img.convert("RGB").resize((image_size, image_size))) / 255.0
+                np.array(img.convert("RGB").resize((image_size, image_size)))
             )
             width, height = img.size
-        img = torch.from_numpy(img_np).permute(2, 0, 1).float()
+        # img = torch.from_numpy(img_np).permute(2, 0, 1).float()
+
+        # img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+        # img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+        img = torch.from_numpy(img_np).npu().permute(2, 0, 1).float() / 255.0
 
-        img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
-        img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+        img_mean = torch.tensor(img_mean, dtype=torch.float32, device=device)[:, None, None]
+        img_std = torch.tensor(img_std, dtype=torch.float32, device=device)[:, None, None]
         img -= img_mean
         img /= img_std
         return img, width, height
@@ -109,11 +115,11 @@ class SAM2CameraPredictor(SAM2Base):
         self.condition_state["offload_state_to_cpu"] = offload_state_to_cpu
         # the original video height and width, used for resizing final output scores
 
-        self.condition_state["device"] = torch.device("cuda")
+        self.condition_state["device"] = device
         if offload_state_to_cpu:
             self.condition_state["storage_device"] = torch.device("cpu")
         else:
-            self.condition_state["storage_device"] = torch.device("cuda")
+            self.condition_state["storage_device"] = device
         # inputs on each frame
         self.condition_state["point_inputs_per_obj"] = {}
         self.condition_state["mask_inputs_per_obj"] = {}
@@ -861,7 +867,7 @@ class SAM2CameraPredictor(SAM2Base):
         storage_device = self.condition_state["storage_device"]
         maskmem_features = current_out["maskmem_features"]
         if maskmem_features is not None:
-            maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(torch.float16)
             maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
         pred_masks_gpu = current_out["pred_masks"]
         # potentially fill holes in the predicted masks
@@ -1055,7 +1061,7 @@ class SAM2CameraPredictor(SAM2Base):
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             image = (
-                self.condition_state["images"][frame_idx].cuda().float().unsqueeze(0)
+                self.condition_state["images"][frame_idx].to(device).float().unsqueeze(0)
             )
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
@@ -1082,7 +1088,7 @@ class SAM2CameraPredictor(SAM2Base):
 
     ###
     def _get_feature(self, img, batch_size):
-        image = img.cuda().float().unsqueeze(0)
+        image = img.to(device).float().unsqueeze(0)
         backbone_out = self.forward_image(image)
         expanded_image = image.expand(batch_size, -1, -1, -1)
         expanded_backbone_out = {
@@ -1144,7 +1150,7 @@ class SAM2CameraPredictor(SAM2Base):
         storage_device = self.condition_state["storage_device"]
         maskmem_features = current_out["maskmem_features"]
         if maskmem_features is not None:
-            maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(torch.float16)
             maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
         pred_masks_gpu = current_out["pred_masks"]
         # potentially fill holes in the predicted masks
@@ -1195,7 +1201,7 @@ class SAM2CameraPredictor(SAM2Base):
 
         # optionally offload the output to CPU memory to save GPU space
         storage_device = self.condition_state["storage_device"]
-        maskmem_features = maskmem_features.to(torch.bfloat16)
+        maskmem_features = maskmem_features.to(torch.float16)
         maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(
@@ -1405,7 +1411,7 @@ class SAM2CameraPredictorVOS(SAM2CameraPredictor):
                 NO_OBJ_SCORE,
             )
 
-        # convert masks from possibly bfloat16 (or float16) to float32
+        # convert masks from possibly float16 (or float16) to float32
         # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
         low_res_multimasks = low_res_multimasks.float()
         high_res_multimasks = F.interpolate(