diff --git a/configs/config.py b/configs/config.py
index 221bf28..e924f95 100644
--- a/configs/config.py
+++ b/configs/config.py
@@ -16,7 +16,7 @@ def get_config():
     parser.add_argument('--train_source', type=str, default='Omni6DPose')
     parser.add_argument('--val_source', type=str, default='Omni6DPose')
     parser.add_argument('--test_source', type=str, default='Omni6DPose')
-    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--device', type=str, default='npu:0')
     parser.add_argument('--num_points', type=int, default=1024)
     parser.add_argument('--per_obj', type=str, default='')
     parser.add_argument('--num_workers', type=int, default=32)
@@ -34,12 +34,18 @@ def get_config():
     parser.add_argument('--s_theta_mode', type=str, default='score') 
     parser.add_argument('--norm_energy', type=str, default='identical')
     parser.add_argument('--dino', type=str, default='pointwise') # none / global / pointwise
+    parser.add_argument('--pretrained_dino_model_path', type=str, default=None,
+                        help='Path to DINOv2 OM model for inference. If None, uses torch.hub PyTorch model.')
     parser.add_argument('--scale_embedding', type=int, default=180)
     
     
     """ training """
     parser.add_argument('--agent_type', type=str, default='score', help='one of the [score, energy, energy_with_ranking, scale]')
     parser.add_argument('--pretrained_score_model_path', type=str)
+    parser.add_argument('--pretrained_pointnet2_score_model_path', type=str, default=None,
+                        help='Path to PointNet2 OM model (from scorenet) for decoupled inference')
+    parser.add_argument('--pretrained_pointnet2_energy_model_path', type=str, default=None,
+                        help='Path to PointNet2 OM model (from energynet) for energy stage')
     parser.add_argument('--pretrained_energy_model_path', type=str)
     parser.add_argument('--pretrained_scale_model_path', type=str)
     parser.add_argument('--distillation', default=False, action='store_true')
diff --git a/datasets/datasets_infer.py b/datasets/datasets_infer.py
index 3ebfeea..61dc944 100644
--- a/datasets/datasets_infer.py
+++ b/datasets/datasets_infer.py
@@ -2,7 +2,6 @@ import numpy as np
 import cv2
 import torch
 import copy
-import open3d as o3d
 
 from cutoop.data_loader import Dataset, ImageMetaData
 from utils.datasets_utils import aug_bbox_eval, get_2d_coord_np, crop_resize_by_warp_affine
@@ -138,7 +137,7 @@ class InferDataset(object):
         data['roi_rgb_'] = torch.as_tensor(np.ascontiguousarray(roi_rgb_), dtype=torch.uint8).contiguous()
         data['roi_xs'] = torch.as_tensor(np.ascontiguousarray(xs), dtype=torch.int64).contiguous()
         data['roi_ys'] = torch.as_tensor(np.ascontiguousarray(ys), dtype=torch.int64).contiguous()
-        data['roi_center_dir'] = torch.as_tensor(pixel2xyz(img_height, img_height, bbox_center, intrinsics), dtype=torch.float32).contiguous()
+        data['roi_center_dir'] = torch.tensor(pixel2xyz(img_height, img_height, bbox_center, intrinsics), dtype=torch.float32).contiguous()
 
         return data
     
diff --git a/datasets/datasets_omni6dpose.py b/datasets/datasets_omni6dpose.py
index aa6cfd0..f196a86 100644
--- a/datasets/datasets_omni6dpose.py
+++ b/datasets/datasets_omni6dpose.py
@@ -616,7 +616,36 @@ def process_batch(batch_sample,
     processed_sample['zero_mean_gt_pose'][:, -3:] -= zero_mean
     processed_sample['pts_center'] = zero_mean
 
-    return processed_sample 
+    return processed_sample
+
+
+def process_batch_numpy(batch_sample, pose_mode='quat_wxyz'):
+    """Numpy-only version of process_batch for OM inference path.
+
+    Only produces the keys needed by OM inference: pts, roi_rgb, roi_xs, roi_ys, pts_center.
+    Skips gt_pose, sym_info, zero_mean_pts etc. which are only used for evaluation metrics.
+    """
+    processed_sample = {}
+
+    # pts: [bs, 1024, 3]
+    pts = batch_sample['pcl_in'].cpu().numpy().astype(np.float32)
+    processed_sample['pts'] = pts
+
+    # pts_center = mean of pts
+    zero_mean = np.mean(pts[:, :, :3], axis=1, keepdims=True)  # [bs, 1, 3]
+    processed_sample['pts_center'] = zero_mean[:, 0, :]  # [bs, 3]
+
+    # roi_rgb: [bs, 3, imgsize, imgsize]
+    roi_rgb = batch_sample['roi_rgb'].cpu().numpy().astype(np.float32)
+    processed_sample['roi_rgb'] = roi_rgb
+
+    # roi_xs, roi_ys: [bs, 1024]
+    roi_xs = batch_sample['roi_xs'].cpu().numpy().astype(np.int64)
+    roi_ys = batch_sample['roi_ys'].cpu().numpy().astype(np.int64)
+    processed_sample['roi_xs'] = roi_xs
+    processed_sample['roi_ys'] = roi_ys
+
+    return processed_sample
     
 
 if __name__ == '__main__':
diff --git a/datasets/datasets_tracking.py b/datasets/datasets_tracking.py
index c6c0ed7..d3d6fc9 100644
--- a/datasets/datasets_tracking.py
+++ b/datasets/datasets_tracking.py
@@ -113,6 +113,9 @@ class Omni6DPoseDataSet(data.Dataset):
 
         self.per_obj = per_obj
         self.per_obj_id = None
+        self._cached_depth = None
+        self._cached_mask = None
+        self._cached_rgb = None
 
         tmp = []
         for img_path in self.img_list:
@@ -144,10 +147,22 @@ class Omni6DPoseDataSet(data.Dataset):
         obj = valid_objects[index % self.num_valid]
         inst_name = obj.meta.oid
 
-        rgb = Dataset.load_color(img_path + "color.png")
-        depth = Dataset.load_depth(img_path + ('depth_syn' if self.cfg.perfect_depth else 'depth') + '.exr')
+        # Load the frame; on failure reuse the last frame this dataset instance loaded successfully.
+        try:
+            rgb = Dataset.load_color(img_path + "color.png")
+            depth = Dataset.load_depth(img_path + ('depth_syn' if self.cfg.perfect_depth else 'depth') + '.exr')
+            mask = Dataset.load_mask(img_path + 'mask.exr')
+        except Exception as e:
+            # Nothing cached yet: re-raise the real load error instead of failing later on None.
+            if self._cached_depth is None or self._cached_mask is None or self._cached_rgb is None:
+                raise
+            print(f"[WARN] Failed to load data for {img_path}: {e}, using cached data from previous frame")
+            depth, mask, rgb = self._cached_depth, self._cached_mask, self._cached_rgb
+            used_cache = True
+        else:
+            self._cached_depth, self._cached_mask, self._cached_rgb = depth, mask, rgb
+            used_cache = False
         depth[depth > 1e3] = 0
-        mask = Dataset.load_mask(img_path + 'mask.exr')
         if not (mask.shape[:2] == depth.shape[:2] == rgb.shape[:2]):
             assert 0
             return self.__getitem__((index + 1) % self.__len__())
@@ -302,6 +317,7 @@ class Omni6DPoseDataSet(data.Dataset):
         data_dict['class_name'] = obj.meta.class_name
         data_dict['object_name'] = inst_name
         data_dict['is_valid'] = 1
+        data_dict['_corrupted'] = used_cache
 
         # xyz = depth2xyz(depth, intrinsics)
         # choose = np.logical_and(mask == inst_idx, depth > 0).flatten().nonzero()[0]
diff --git a/networks/gf_algorithms/samplers.py b/networks/gf_algorithms/samplers.py
index 5b1131d..f24b69a 100755
--- a/networks/gf_algorithms/samplers.py
+++ b/networks/gf_algorithms/samplers.py
@@ -203,8 +203,14 @@ def cond_ode_sampler(
         # num_steps, from T -> eps
         t_eval = np.linspace(T, eps, num_steps)
     res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', t_eval=t_eval)
-    xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
-    x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
+
+    # Print ODE solver statistics
+    print(f"ODE solver: {len(res.t)} steps computed (adaptive)")
+    print(f"Function evaluations: {res.nfev}")
+    print(f"Time range: {res.t[0]:.6f} → {res.t[-1]:.6f}")
+
+    xs = torch.tensor(res.y, device=device, dtype=torch.float32).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
+    x = torch.tensor(res.y[:, -1], device=device, dtype=torch.float32).reshape(shape) # [bs, pose_dim]
     # denoise, using the predictor step in P-C sampler
     if denoise:
         # Reverse diffusion predictor for denoising
@@ -251,10 +257,10 @@ def cond_edm_sampler(
         data, denoised = decoder(data)
         # recover data
         data['sampled_pose'], data['t'] = x_, t_
-        return denoised.to(torch.float64)
+        return denoised.to(torch.float32)
 
     # Main sampling loop.
-    x_next = latents.to(torch.float64) * t_steps[0]
+    x_next = latents.to(torch.float32) * t_steps[0]
     xs = []
     for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
         x_cur = x_next
diff --git a/networks/gf_algorithms/sde.py b/networks/gf_algorithms/sde.py
index bf98620..9c62eb3 100755
--- a/networks/gf_algorithms/sde.py
+++ b/networks/gf_algorithms/sde.py
@@ -19,13 +19,22 @@ def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
 
 def ve_sde(t, sigma_min=0.01, sigma_max=90):
     sigma = sigma_min * (sigma_max / sigma_min) ** t
-    drift_coeff = torch.tensor(0)
-    diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
+    device = t.device if hasattr(t, 'device') else 'cpu'
+    drift_coeff = torch.tensor(0, device=device)
+    diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=device))
+    return drift_coeff, diffusion_coeff
+
+def ve_sde_numpy(t, sigma_min=0.01, sigma_max=90):
+    """Pure numpy version of ve_sde for ODE solver integration."""
+    sigma = sigma_min * (sigma_max / sigma_min) ** t
+    drift_coeff = 0.0
+    diffusion_coeff = float(sigma) * np.sqrt(2 * (np.log(sigma_max) - np.log(sigma_min)))
     return drift_coeff, diffusion_coeff
 
 def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
     _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
-    return torch.randn(*shape) * sigma_max_prior
+    gen = torch.Generator().manual_seed(0)  # fixed seed kept local so the global RNG state is not reset on every call
+    return torch.randn(*shape, generator=gen, dtype=torch.float32) * sigma_max_prior
 
 #----- VP SDE -----
 #------------------
@@ -42,7 +51,7 @@ def vp_sde(t, beta_0=0.1, beta_1=20):
     return drift_coeff, diffusion_coeff
 
 def vp_prior(shape, beta_0=0.1, beta_1=20):
-    return torch.randn(*shape)
+    return torch.randn(*shape, dtype=torch.float32)
 
 #----- sub-VP SDE -----
 #----------------------
@@ -70,12 +79,13 @@ def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
     return mean, std
 
 def edm_sde(t, sigma_min=0.002, sigma_max=80):
-    drift_coeff = torch.tensor(0)
+    device = t.device if hasattr(t, 'device') else 'cpu'
+    drift_coeff = torch.tensor(0, device=device)
     diffusion_coeff = torch.sqrt(2 * t)
     return drift_coeff, diffusion_coeff
 
 def edm_prior(shape, sigma_min=0.002, sigma_max=80):
-    return torch.randn(*shape) * sigma_max
+    return torch.randn(*shape, dtype=torch.float32) * sigma_max
 
 def init_sde(sde_mode):
     # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
diff --git a/networks/posenet.py b/networks/posenet.py
index 52da316..e9eae01 100644
--- a/networks/posenet.py
+++ b/networks/posenet.py
@@ -17,7 +17,6 @@ from configs.config import get_config
 from utils.genpose_utils import encode_axes
 
 
-
 class GFObjectPose(nn.Module):
     dino_name = 'dinov2_vits14'
     dino_dim = 384
@@ -47,32 +46,9 @@ class GFObjectPose(nn.Module):
             self.embedding_dim = GFObjectPose.embedding_dim
         
         ''' encode pts '''
-        if self.cfg.pts_encoder == 'pointnet':
-            assert cfg.dino != 'pointwise' # not supported yet
-            self.pts_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024)
-        elif self.cfg.pts_encoder == 'pointnet2':
-            if cfg.dino == 'pointwise':
-                self.pts_encoder = Pointnet2ClsMSGFus(self.dino_dim)
-            else:
-                self.pts_encoder = Pointnet2ClsMSG(0)
-        elif self.cfg.pts_encoder == 'pointnet_and_pointnet2':
-            assert cfg.dino != 'pointwise' # not supported yet
-            self.pts_pointnet_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024)
-            self.pts_pointnet2_encoder = Pointnet2ClsMSG(0)
-            self.fusion_layer = nn.Linear(2048, 1024)
-            self.act = nn.ReLU()
-        else:
-            raise NotImplementedError
-        
-        ''' score network'''
-        # if self.cfg.sde_mode == 'edm':
-        #     self.pose_score_net = PoseDecoderNet(
-        #         self.marginal_prob_fn,
-        #         sigma_data=1.4148, 
-        #         pose_mode=self.cfg.pose_mode, 
-        #         regression_head=self.cfg.regression_head
-        #     )
-        # else:
+
+        self.pts_encoder = Pointnet2ClsMSGFus(self.dino_dim)
+
         per_point_feat = False
         if self.cfg.agent_type == 'score':
             self.pose_score_net = PoseScoreNet(
@@ -98,19 +74,28 @@ class GFObjectPose(nn.Module):
 
         Args:
             data (dict): batch example without pointcloud feature. {'pts': [bs, num_pts, 3], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]}
+            precomputed_rgb_feat (torch.Tensor, optional): Pre-computed DINOv2 features [B, 1024, 384].
+                If provided, will skip internal DINOv2 computation and use these features directly.
         Returns:
             data (dict): batch example with pointcloud feature. {'pts': [bs, num_pts, 3], 'pts_feat': [bs, c], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]}
         """
         pts = data['pts']
         if self.cfg.dino == 'pointwise':
-            roi_rgb = data['roi_rgb']
-            feat = self.dino.get_intermediate_layers(roi_rgb)[0]
-            xs = data['roi_xs'] // 14
-            ys = data['roi_ys'] // 14
-            pos = xs * 16 + ys
-            pos = torch.unsqueeze(pos, -1).expand(-1, -1, self.dino_dim)
-            rgb_feat = torch.gather(feat, 1, pos)
-            rgb_feat.requires_grad_(False)
+            # Prefer caller-supplied DINOv2 features; `data` is a dict, so look them up by key (getattr never sees dict entries).
+            precomputed_rgb_feat = data.get('precomputed_rgb_feat')
+            # Compare against None explicitly: truth-testing a multi-element tensor raises RuntimeError.
+            if precomputed_rgb_feat is not None:
+                rgb_feat = precomputed_rgb_feat.to(pts.device)
+            else:
+                # Original path: compute DINOv2 features internally
+                roi_rgb = data['roi_rgb']
+                feat = self.dino.get_intermediate_layers(roi_rgb)[0]
+                xs = data['roi_xs'] // 14
+                ys = data['roi_ys'] // 14
+                pos = xs * 16 + ys
+                pos = torch.unsqueeze(pos, -1).expand(-1, -1, self.dino_dim)
+                rgb_feat = torch.gather(feat, 1, pos)
+                rgb_feat.requires_grad_(False)
         if self.cfg.pts_encoder == 'pointnet':
             assert 0
             pts_feat = self.pts_encoder(pts.permute(0, 2, 1))    # -> (bs, 3, 1024)
@@ -194,6 +179,7 @@ class GFObjectPose(nn.Module):
                 'pts_feat': [bs, c]
                 'sampled_pose': [bs, pose_dim]
                 't': [bs, 1]
+                'precomputed_rgb_feat': [bs, 1024, 384] (optional)
             }
         '''
         if mode == 'score':
diff --git a/networks/posenet_agent.py b/networks/posenet_agent.py
index e36c154..d7ce2e0 100644
--- a/networks/posenet_agent.py
+++ b/networks/posenet_agent.py
@@ -91,9 +91,6 @@ class PoseNet(nn.Module):
         else:
             net = self.get_network('ScaleNet')
         net = net.to(self.cfg.device)
-        if self.cfg.parallel:
-            device_ids = list(range(self.cfg.num_gpu))
-            net = nn.DataParallel(net, device_ids=device_ids).cuda()
         return net
     
 
@@ -167,14 +164,14 @@ class PoseNet(nn.Module):
         if not os.path.exists(load_path):
             raise ValueError("Checkpoint {} not exists.".format(load_path))
 
-        checkpoint = torch.load(load_path)
+        checkpoint = torch.load(load_path, map_location=self.cfg.device)
         print("Loading checkpoint from {} ...".format(load_path))
-        
+
         if isinstance(self.net, nn.DataParallel):
             self.net.module.load_state_dict(checkpoint['model_state_dict'])
         else:
             self.net.load_state_dict(checkpoint['model_state_dict'])
-        
+
         if not load_model_only:
             self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
diff --git a/networks/pts_encoder/pointnet2.py b/networks/pts_encoder/pointnet2.py
index cac0563..6ed1a3d 100644
--- a/networks/pts_encoder/pointnet2.py
+++ b/networks/pts_encoder/pointnet2.py
@@ -240,6 +240,7 @@ class Pointnet2ClsMSGFus(nn.Module):
                 )
             )
             channel_in = channel_out + input_channels
+        self.SA_modules[-1].forward = self.SA_modules[-1].forward_npoint_none
 
 
     def _break_up_pc(self, pc):
@@ -258,19 +259,30 @@ class Pointnet2ClsMSGFus(nn.Module):
         # features: bs * F * npoints
 
         l_xyz, l_features = [xyz], [features]
-        for i in range(len(self.SA_modules)):
-            if i != 0:
-                l_features[i] = torch.concatenate([l_features[i], features], dim=1) # concatenate
+
+        # first
+        li_xyz, li_features, idx = self.SA_modules[0](l_xyz[0], l_features[0], return_idx=True)
+        l_xyz.append(li_xyz)
+        l_features.append(li_features)
+        features = torch.gather(features, 2, 
+                    torch.unsqueeze(idx.type(torch.int64), 1).expand(-1, features.shape[1], -1))
+        # middle
+        for i in range(1,len(self.SA_modules)-1):
+            l_features[i] = torch.concatenate([l_features[i], features], dim=1) # concatenate
             li_xyz, li_features, idx = self.SA_modules[i](l_xyz[i], l_features[i], return_idx=True)
             l_xyz.append(li_xyz)
             l_features.append(li_features)
-            if idx != None:
-                features = torch.gather(
-                    features, 2, 
-                    torch.unsqueeze(idx.type(torch.int64), 1).expand(-1, features.shape[1], -1)
-                ) # only keep features of remaining points
-            else:
-                assert i == len(self.SA_modules) - 1
+
+            features = torch.gather(features, 2, 
+                torch.unsqueeze(idx.type(torch.int64), 1).expand(-1, features.shape[1], -1)) 
+        # last (index computed directly so it is defined even when the middle loop above never runs)
+        i = len(self.SA_modules) - 1
+        l_features[i] = torch.concatenate([l_features[i], features], dim=1) # concatenate
+        li_xyz, li_features, idx = self.SA_modules[i](l_xyz[i], l_features[i], return_idx=True)
+        l_xyz.append(li_xyz)
+        l_features.append(li_features)
+        assert i == len(self.SA_modules) - 1
+
         return l_features[-1].squeeze(-1)
 
 
diff --git a/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_modules.py b/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_modules.py
index 4c95848..25951ff 100755
--- a/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_modules.py
+++ b/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_modules.py
@@ -27,42 +27,35 @@ class _PointnetSAModuleBase(nn.Module):
             new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
             new_idx: (B, npoint) tensor of indices
         """
-        new_features_list = []
 
         xyz_flipped = xyz.transpose(1, 2).contiguous()
         idx = None
-        if new_xyz is None:
-            if self.npoint is not None:
-                idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
-                new_xyz = pointnet2_utils.gather_operation(
-                    xyz_flipped,
-                    idx
-                ).transpose(1, 2).contiguous()
-            else:
-                new_xyz = None
+        idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+        new_xyz = pointnet2_utils.gather_operation(
+            xyz_flipped,
+            idx
+        ).transpose(1, 2).contiguous()
+        return self.calculate_xyz_features_idx(xyz, features, new_xyz, idx)
+
+
+    def forward_npoint_none(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None, return_idx=False ,index=None):
+        return self.calculate_xyz_features_idx(xyz, features, None, None)
+
+    def calculate_xyz_features_idx(self, xyz, features, new_xyz, idx):
+        new_features_list = []
 
         for i in range(len(self.groupers)):
             new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
 
             new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
 
-            if self.pool_method == 'max_pool':
-                new_features = F.max_pool2d(
-                    new_features, kernel_size=[1, new_features.size(3)]
-                )  # (B, mlp[-1], npoint, 1)
-            elif self.pool_method == 'avg_pool':
-                new_features = F.avg_pool2d(
-                    new_features, kernel_size=[1, new_features.size(3)]
-                )  # (B, mlp[-1], npoint, 1)
-            else:
-                raise NotImplementedError
+            new_features = torch.amax(new_features, dim=3, keepdim=True)
 
             new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
             new_features_list.append(new_features)
 
-        if return_idx:
-            return new_xyz, torch.cat(new_features_list, dim=1), idx
-        return new_xyz, torch.cat(new_features_list, dim=1)
+
+        return new_xyz, torch.cat(new_features_list, dim=1), idx
 
 
 class PointnetSAModuleMSG(_PointnetSAModuleBase):
diff --git a/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_utils.py b/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_utils.py
index 97a5466..6ae8ccf 100755
--- a/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_utils.py
+++ b/networks/pts_encoder/pointnet2_utils/pointnet2/pointnet2_utils.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 from typing import Tuple
 import sys
 
-import pointnet2_cuda as pointnet2
+import pointnet2_ops as pointnet2
 
 
 class FurthestPointSampling(Function):
@@ -20,14 +20,7 @@ class FurthestPointSampling(Function):
         :return:
              output: (B, npoint) tensor containing the set
         """
-        assert xyz.is_contiguous()
-
-        B, N, _ = xyz.size()
-        output = torch.cuda.IntTensor(B, npoint)
-        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
-
-        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
-        return output
+        return pointnet2._furthest_point_sampling(xyz, npoint)
 
     @staticmethod
     def backward(xyz, a=None):
@@ -48,26 +41,13 @@ class GatherOperation(Function):
         :return:
             output: (B, C, npoint)
         """
-        assert features.is_contiguous()
-        assert idx.is_contiguous()
-
-        B, npoint = idx.size()
-        _, C, N = features.size()
-        output = torch.cuda.FloatTensor(B, C, npoint)
-
-        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
-
-        ctx.for_backwards = (idx, C, N)
-        return output
+        ctx.for_backwards = (idx, features.shape[1], features.shape[2])
+        return pointnet2._gather_points(features, idx)
 
     @staticmethod
     def backward(ctx, grad_out):
         idx, C, N = ctx.for_backwards
-        B, npoint = idx.size()
-
-        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
-        grad_out_data = grad_out.data.contiguous()
-        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+        grad_features = pointnet2._gather_points_grad(grad_out, idx, N)
         return grad_features, None
 
 
@@ -92,8 +72,9 @@ class ThreeNN(Function):
 
         B, N, _ = unknown.size()
         m = known.size(1)
-        dist2 = torch.cuda.FloatTensor(B, N, 3)
-        idx = torch.cuda.IntTensor(B, N, 3)
+        device = known.device
+        dist2 = torch.empty((B, N, 3), dtype=torch.float32, device=device)
+        idx = torch.empty((B, N, 3), dtype=torch.int32, device=device)
 
         pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
         return torch.sqrt(dist2), idx
@@ -125,8 +106,9 @@ class ThreeInterpolate(Function):
 
         B, c, m = features.size()
         n = idx.size(1)
+        device = features.device
         ctx.three_interpolate_for_backward = (idx, weight, m)
-        output = torch.cuda.FloatTensor(B, c, n)
+        output = torch.empty((B, c, n), dtype=torch.float32, device=device)
 
         pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
         return output
@@ -143,8 +125,8 @@ class ThreeInterpolate(Function):
         """
         idx, weight, m = ctx.three_interpolate_for_backward
         B, c, n = grad_out.size()
-
-        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+        device = weight.device
+        grad_features = Variable(torch.empty((B, c, m), dtype=torch.float32, device=device).zero_())
         grad_out_data = grad_out.data.contiguous()
 
         pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
@@ -165,17 +147,8 @@ class GroupingOperation(Function):
         :return:
             output: (B, C, npoint, nsample) tensor
         """
-        assert features.is_contiguous()
-        assert idx.is_contiguous()
-
-        B, nfeatures, nsample = idx.size()
-        _, C, N = features.size()
-        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
-
-        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
-
-        ctx.for_backwards = (idx, N)
-        return output
+        ctx.for_backwards = (idx, features.shape[2])
+        return pointnet2._group_points(features, idx)
 
     @staticmethod
     def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -186,12 +159,7 @@ class GroupingOperation(Function):
             grad_features: (B, C, N) gradient of the features
         """
         idx, N = ctx.for_backwards
-
-        B, C, npoint, nsample = grad_out.size()
-        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
-
-        grad_out_data = grad_out.data.contiguous()
-        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+        grad_features = pointnet2._group_points_grad(grad_out, idx, N)
         return grad_features, None
 
 
@@ -211,15 +179,7 @@ class BallQuery(Function):
         :return:
             idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
         """
-        assert new_xyz.is_contiguous()
-        assert xyz.is_contiguous()
-
-        B, N, _ = xyz.size()
-        npoint = new_xyz.size(1)
-        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
-
-        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
-        return idx
+        return pointnet2._ball_query(new_xyz, xyz, radius, nsample)
 
     @staticmethod
     def backward(ctx, a=None):
@@ -255,7 +215,7 @@ class QueryAndGroup(nn.Module):
         if features is not None:
             grouped_features = grouping_operation(features, idx)
             if self.use_xyz:
-                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)
             else:
                 new_features = grouped_features
         else:
diff --git a/requirements.txt b/requirements.txt
index a437ae6..6fa4786 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,9 +3,9 @@ scipy==1.12.0
 numpy==1.26.3
 tensorboardX==2.6.2.2
 tensorboard==2.17.0
-open3d==0.18.0
-pyrealsense2==2.55.1.6486
 ipdb
 matplotlib
 tqdm
-scikit-learn
\ No newline at end of file
+scikit-learn
+# open3d==0.18.0 # If use camera, need to install open3d.
+# pyrealsense2==2.56.5.9235 # If use camera, need to install pyrealsense2.
\ No newline at end of file
diff --git a/runners/evaluation_single.py b/runners/evaluation_single.py
index 7211d15..5615c40 100644
--- a/runners/evaluation_single.py
+++ b/runners/evaluation_single.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
 import _pickle as cPickle
 import pickle
 import torch
+import random
 import torch.nn as nn
 import torch.nn.functional as F
 import copy
@@ -22,10 +23,13 @@ from ipdb import set_trace
 
 from networks.posenet_agent import PoseNet
 from networks.reward import sort_poses_by_energy, ranking_loss
-from datasets.datasets_omni6dpose import Omni6DPoseDataSet, array_to_SymLabel, array_to_CameraIntrinsicsBase, process_batch
+from om_wrappers import create_score_network, create_ode_sampler
+from networks.gf_algorithms.sde import init_sde
+from datasets.datasets_omni6dpose import Omni6DPoseDataSet, array_to_SymLabel, array_to_CameraIntrinsicsBase, process_batch, process_batch_numpy
 from utils.metrics import get_rot_matrix
 from utils.transforms import matrix_to_quaternion, quaternion_to_matrix
 from utils.misc import average_quaternion_batch
+from utils.genpose_utils import get_pose_dim
 from utils.so3_visualize import visualize_so3
 from utils.visualize import create_grid_image
 from cutoop.eval_utils import DetectMatch, Metrics
@@ -41,6 +45,21 @@ torch.cuda.manual_seed(cfg.seed)
 random.seed(cfg.seed)
 np.random.seed(cfg.seed)
 
+# Performance statistics
+perf_stats = {
+    'score_time': [],
+    'score_samples': 0,
+    'energy_time': [],
+    'energy_samples': 0,
+    'aggregate_time': [],
+    'aggregate_samples': 0,
+    'scale_time': [],
+    'scale_samples': 0,
+    'bbox_time': [],
+    'bbox_samples': 0,
+}
+
+
 def get_dataloader():
     dataset = Omni6DPoseDataSet(
         cfg=cfg,
@@ -59,12 +78,57 @@ def get_dataloader():
         shuffle=False,
         num_workers=cfg.num_workers,
         persistent_workers=True,
-        drop_last=False,
+        drop_last=True,
         pin_memory=True,
     )
     return dataloader
 
-dataloader = get_dataloader()
+
+_dino_model = None
+
+def get_dino_model():
+    """Lazy load DINOv2 model to save memory."""
+    global _dino_model
+    if _dino_model is None and cfg.dino != 'none':
+        dino_om_path = getattr(cfg, 'pretrained_dino_model_path', None)
+        if dino_om_path is not None and dino_om_path.endswith('.om'):
+            print(f"Loading DINOv2 OM model: {dino_om_path}")
+            from om_wrappers import DINOv2Wrapper
+            _dino_model = DINOv2Wrapper(dino_om_path, device=cfg.device)
+            print("DINOv2 OM loaded successfully")
+        else:
+            print("Loading DINOv2 model for preprocessing...")
+            import torch.hub
+            _dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(cfg.device)
+            _dino_model.requires_grad_(False)
+            print("DINOv2 loaded successfully")
+    return _dino_model
+
+
+def extract_dino_features(batch_sample):
+    """Extract DINOv2 features as a preprocessing step.
+
+    Returns None when get_dino_model() yields no model (cfg.dino == 'none').
+    """
+    dino = get_dino_model()
+    if dino is None:
+        return None
+    roi_rgb = batch_sample['roi_rgb']  # [B, 3, H, W]
+    roi_xs = batch_sample['roi_xs']    # [B, 1024]
+    roi_ys = batch_sample['roi_ys']    # [B, 1024]
+
+    # OM path: DINOv2Wrapper returns numpy, keep as numpy throughout
+    if is_om_model:  # NOTE(review): is_om_model is not defined in the visible hunks - confirm it is a module global
+        return dino(roi_rgb, roi_xs, roi_ys)
+
+    # PyTorch path: original get_intermediate_layers + gather
+    feat = dino.get_intermediate_layers(roi_rgb)[0]  # [B, 256, 384]
+    pos = (roi_xs // 14) * 16 + roi_ys // 14  # 224x224 input -> 16x16 feature map
+    pos = torch.unsqueeze(pos, -1).expand(-1, -1, 384)
+    rgb_feat = torch.gather(feat, 1, pos)  # [B, 1024, 384]
+    rgb_feat.requires_grad_(False)
+    return rgb_feat
+
 
 def inference_score(save_path):
     if os.path.exists(save_path):
@@ -77,16 +141,24 @@ def inference_score(save_path):
 
     all_pred_pose = []
     all_score_feature = []
+    total_samples = 0
 
     for i, test_batch in enumerate(tqdm(dataloader, desc="score sampling")):
+        start_time = time.time()
         batch_sample = process_batch(
-            batch_sample = test_batch, 
-            device=cfg.device, 
+            batch_sample = test_batch,
+            device=cfg.device,
             pose_mode=cfg.pose_mode,
         )
+
+        # Extract DINOv2 features as preprocessing (outside the model)
+        rgb_feat = extract_dino_features(batch_sample)
+        if rgb_feat is not None:
+            batch_sample['precomputed_rgb_feat'] = rgb_feat
+
         pred_results = score_agent.pred_func(
-            data=batch_sample, 
-            repeat_num=cfg.eval_repeat_num, 
+            data=batch_sample,
+            repeat_num=cfg.eval_repeat_num,
             T0=cfg.T0,
             return_average_res=False,
             return_process=False
@@ -97,41 +169,222 @@ def inference_score(save_path):
             'pts_feat': batch_sample['pts_feat'].cpu(),
             'rgb_feat': (None if batch_sample['rgb_feat'] is None else batch_sample['rgb_feat'].cpu()),
         })
+        elapsed = time.time() - start_time
+        perf_stats['score_time'].append(elapsed)
+        total_samples += pred_pose.shape[0]
+        perf_stats['score_samples'] = total_samples
+
         if i % 4 == 3:
             gc.collect()
-    
     pickle.dump((all_pred_pose, all_score_feature), open(save_path, 'wb'))
 
+def inference_score_decoupled(save_path):
+    """
+    Unified decoupled inference using ScoreNetworkWrapper and ODESamplerExternal.
+
+    Supports both PyTorch and OM models:
+    - PyTorch: Uses internal PyTorch PointNet2 encoder
+    - OM: Uses separate PointNet2 OM + ScoreNet OM (if available)
+
+    This unified interface enables direct comparison between PyTorch and OM outputs
+    for debugging precision issues.
+
+    Args:
+        save_path: Pickle path for the cached (all_pred_pose, all_score_feature) tuple;
+            when it already exists nothing is run.
+    Returns:
+        The score network wrapper, or None when save_path already existed.
+    """
+    if os.path.exists(save_path):
+        return
+
+    # Initialize SDE components
+    prior_fn, _, sde_fn, sampling_eps, _ = init_sde('ve')
+
+    pointnet2_om_path = getattr(cfg, 'pretrained_pointnet2_score_model_path', None)
+
+    # Create Score Network wrapper (automatically uses OM or PyTorch)
+    score_net = create_score_network(
+        checkpoint_path=cfg.pretrained_score_model_path,
+        device=cfg.device,
+        pointnet2_om_path=pointnet2_om_path  # Pass None for PyTorch, path for OM
+    )
+    if is_om_model:
+        from networks.gf_algorithms.sde import ve_sde_numpy
+        sde_coeff_fn = ve_sde_numpy
+    else:
+        sde_coeff_fn = sde_fn
+    sde_dir = {'prior_fn': prior_fn, 'sde_fn': sde_coeff_fn}
+    # Create ODE sampler with the Score Network
+    sampler = create_ode_sampler(
+        score_network=score_net,
+        sde=sde_dir,
+        device=cfg.device
+    )
+
+    all_pred_pose = []
+    all_score_feature = []
+    total_samples = 0
+
+    print(f"\nRunning unified decoupled inference ({'OM' if score_net.is_om else 'PyTorch'})...")
+    for i, test_batch in enumerate(tqdm(dataloader, desc="score sampling")):
+        start_time = time.time()
+
+        if is_om_model:
+            batch_sample = process_batch_numpy(test_batch, pose_mode=cfg.pose_mode)
+        else:
+            batch_sample = process_batch(
+                batch_sample=test_batch,
+                device=cfg.device,
+                pose_mode=cfg.pose_mode,
+            )
+
+        # Extract DINOv2 features as preprocessing
+        rgb_feat = extract_dino_features(batch_sample)
+        if rgb_feat is not None:
+            batch_sample['precomputed_rgb_feat'] = rgb_feat
+
+        # Extract point cloud features using unified interface
+        pts_feat = score_net.extract_pts_feat(pts=batch_sample['pts'], rgb_feat=rgb_feat)
+
+        # Get batch info
+        bs = batch_sample['pts'].shape[0]
+        pose_dim = get_pose_dim(cfg.pose_mode)
+        pts_center = batch_sample.get('pts_center', None)
+
+        # Generate random initial values for all repeats at once
+        init_x_all = prior_fn((bs, cfg.eval_repeat_num, pose_dim), T=cfg.T0).cpu().numpy()
+        init_x_repeated = init_x_all.reshape(bs * cfg.eval_repeat_num, pose_dim)
+
+        # Repeat features and init_x to process all at once
+        if is_om_model:
+            pts_feat_repeated = np.repeat(pts_feat[np.newaxis, ...], cfg.eval_repeat_num, axis=1).reshape(bs * cfg.eval_repeat_num, -1)
+            pts_center_repeated = None if pts_center is None else \
+                np.repeat(pts_center[:, np.newaxis, :], cfg.eval_repeat_num, axis=1).reshape(bs * cfg.eval_repeat_num, -1)
+        else:
+            pts_feat_repeated = pts_feat.unsqueeze(1).repeat(1, cfg.eval_repeat_num, 1).view(bs * cfg.eval_repeat_num, -1)
+            pts_center_repeated = None if pts_center is None else \
+                pts_center.unsqueeze(1).repeat(1, cfg.eval_repeat_num, 1).view(bs * cfg.eval_repeat_num, -1)
+
+        # Single call to sampler for all repeats
+        with torch.no_grad():
+            _, sampled_pose = sampler.sample(
+                pts_feat=pts_feat_repeated,
+                rgb_feat=None,
+                batch_size=bs * cfg.eval_repeat_num,
+                pose_dim=pose_dim,
+                T=cfg.T0,
+                eps=sampling_eps,
+                rtol=1e-5,
+                atol=1e-5,
+                denoise=True,
+                init_x=init_x_repeated,
+                pts_center=pts_center_repeated
+            )
+
+        # Reshape result from [bs*repeat_num, pose_dim] to [bs, repeat_num, pose_dim]
+        if is_om_model:
+            pred_pose = sampled_pose.reshape(bs, cfg.eval_repeat_num, pose_dim)
+        else:
+            pred_pose = sampled_pose.view(bs, cfg.eval_repeat_num, pose_dim)
+
+        # Save pred_pose and features (NOTE(review): inference_score caches its features via .cpu(); this path does not - confirm)
+        all_pred_pose.append(pred_pose)
+        all_score_feature.append({
+            'pts_feat': pts_feat,
+            'rgb_feat': rgb_feat,
+        })
+
+        elapsed = time.time() - start_time
+        perf_stats['score_time'].append(elapsed)
+        total_samples += pred_pose.shape[0]
+        perf_stats['score_samples'] = total_samples
+        if i % 4 == 3:
+            gc.collect()
+
+    # Write the cache through a context manager so the file handle is closed deterministically.
+    with open(save_path, 'wb') as cache_file:
+        pickle.dump((all_pred_pose, all_score_feature), cache_file)
+    print(f"Unified decoupled inference complete! ({'OM' if score_net.is_om else 'PyTorch'})")
+    return score_net
+
+
 def inference_energy(score_path, save_path):
     if os.path.exists(save_path):
         return
     assert os.path.exists(score_path)
     all_pred_pose, _ = pickle.load(open(score_path, 'rb'))
 
-    cfg.agent_type = 'energy'
-    energy_agent = PoseNet(cfg)
-    energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
-    energy_agent.eval()
+    if is_om_model:
+        from om_wrappers import EnergyNetWrapper, PointNet2EncoderWrapper
+        energy_net = EnergyNetWrapper(cfg.pretrained_energy_model_path, device=cfg.device)
+        # Load PointNet2 from energy checkpoint for pts_feat extraction
+        pointnet2_encoder = PointNet2EncoderWrapper(cfg.pretrained_pointnet2_energy_model_path, device=cfg.device)
+        print(f"Using EnergyNet OM: {cfg.pretrained_energy_model_path}")
+        print(f"Using PointNet2 (from energy): {cfg.pretrained_pointnet2_energy_model_path}")
+    else:
+        cfg.agent_type = 'energy'
+        energy_agent = PoseNet(cfg)
+        energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
+        energy_agent.eval()
 
     all_pred_energy = []
+    total_samples = 0
 
     for i, test_batch in enumerate(tqdm(dataloader, desc="energy")):
-        batch_sample = process_batch(
-            batch_sample = test_batch, 
-            device=cfg.device, 
-            pose_mode=cfg.pose_mode,
-        )
-        pred_energy = energy_agent.get_energy(
-            data=batch_sample, 
-            pose_samples=all_pred_pose[i], 
-            T=1e-5,
-            mode='test', 
-            extract_feature=True
-        )
-        all_pred_energy.append(pred_energy.cpu())
+        start_time = time.time()
+        if is_om_model:
+            batch_sample = process_batch_numpy(test_batch, pose_mode=cfg.pose_mode)
+        else:
+            batch_sample = process_batch(
+                batch_sample = test_batch,
+                device=cfg.device,
+                pose_mode=cfg.pose_mode,
+            )
+
+        # Extract DINOv2 features as preprocessing (outside the model)
+        rgb_feat = extract_dino_features(batch_sample)
+        batch_sample['precomputed_rgb_feat'] = rgb_feat
+
+        if is_om_model:
+            bs = batch_sample['pts'].shape[0]
+            repeat_num = all_pred_pose[i].shape[1]
+
+            pointcloud = np.concatenate([batch_sample['pts'], rgb_feat], axis=-1)  # assumes rgb_feat is an array (DINO enabled) - TODO confirm
+            pts_feat = pointnet2_encoder(pointcloud)
+
+            # Repeat each sample's feature repeat_num times, keeping one sample's copies adjacent
+            repeated_pts_feat = np.repeat(pts_feat[np.newaxis, ...], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+
+            # Prepare sampled_pose with pts_center subtracted (astype returns a copy, so the cached poses are not modified)
+            pose_samples = all_pred_pose[i].reshape(bs * repeat_num, -1).astype(np.float32)
+            pts_center = batch_sample['pts_center']
+            repeated_pts_center = np.repeat(pts_center[:, np.newaxis, :], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+            pose_samples[:, -3:] -= repeated_pts_center
+
+            T = 1e-5
+            t = np.full((bs * repeat_num, 1), T, dtype=np.float32)
+
+            # OM inference, reshaped back to one row of repeat_num energies per sample
+            pred_energy = energy_net(repeated_pts_feat, pose_samples, t).reshape(bs, repeat_num, -1)
+        else:
+            pred_energy = energy_agent.get_energy(
+                data=batch_sample,
+                pose_samples=all_pred_pose[i],
+                T=1e-5,
+                mode='test',
+                extract_feature=True
+            )
+        all_pred_energy.append(pred_energy)  # NOTE(review): the replaced code appended pred_energy.cpu(); confirm the PyTorch path may cache device tensors
+
+        elapsed = time.time() - start_time
+        perf_stats['energy_time'].append(elapsed)
+        total_samples += pred_energy.shape[0]
+        perf_stats['energy_samples'] = total_samples
+
         if i % 4 == 3:
             gc.collect()
-    
+
     pickle.dump(all_pred_energy, open(save_path, 'wb'))
 
 def aggregate_pose(score_path, energy_path, save_path):
@@ -139,17 +392,20 @@ def aggregate_pose(score_path, energy_path, save_path):
         return
     assert os.path.exists(score_path)
     all_pred_pose, _ = pickle.load(open(score_path, 'rb'))
-    if energy_path is not None:
-        assert os.path.exists(energy_path)
-        all_pred_energy = pickle.load(open(energy_path, 'rb'))
-    else:
-        all_pred_energy = [torch.ones(*(all_pred_pose[i].shape[:2]), 2) 
-                           for i in range(len(all_pred_pose))]
+
+    assert os.path.exists(energy_path)
+    all_pred_energy = pickle.load(open(energy_path, 'rb'))
+    # ensure tensors (OM energy path saves numpy)
+    if is_om_model:
+        all_pred_pose = [torch.from_numpy(i) for i in all_pred_pose]
+        all_pred_energy = [torch.from_numpy(i) for i in all_pred_energy]
 
     all_aggregated_pose = []
-    
+    total_samples = 0
+
     for i, (pred_pose, pred_energy) in enumerate(tqdm(zip(all_pred_pose, all_pred_energy), desc="aggregate")):
-        sorted_pose, sorted_energy = sort_poses_by_energy(pred_pose, pred_energy)
+        start_time = time.time()
+        sorted_pose, _ = sort_poses_by_energy(pred_pose, pred_energy)
         bs = pred_pose.shape[0]
         retain_num = int(cfg.eval_repeat_num * cfg.retain_ratio)
         good_pose = sorted_pose[:, :retain_num, :]
@@ -173,6 +429,12 @@ def aggregate_pose(score_path, energy_path, save_path):
         aggregated_pose[:, :3, :3] = quaternion_to_matrix(aggregated_quat_wxyz)
         aggregated_pose[:, :3, 3] = aggregated_trans
         all_aggregated_pose.append(aggregated_pose)
+
+        elapsed = time.time() - start_time
+        perf_stats['aggregate_time'].append(elapsed)
+        total_samples += bs
+        perf_stats['aggregate_samples'] = total_samples
+
         if i % 10 == 9:
             gc.collect()
     
@@ -181,6 +443,8 @@ def aggregate_pose(score_path, energy_path, save_path):
 def inference_scale(score_path, aggregate_path, save_path):
     if os.path.exists(save_path):
         return
+    scale_path = getattr(cfg, 'pretrained_scale_model_path', None)  # None selects the bbox-only branch below
+    assert scale_path is None or os.path.exists(scale_path)
     assert os.path.exists(score_path)
     _, all_score_feature = pickle.load(open(score_path, 'rb'))
     assert os.path.exists(aggregate_path)
@@ -188,8 +452,10 @@ def inference_scale(score_path, aggregate_path, save_path):
 
     if cfg.pretrained_scale_model_path is None:
         all_final_length = []
+        total_samples = 0
 
         for i, test_batch in enumerate(tqdm(dataloader, desc="bbox")):
+            start_time = time.time()
             pcl: torch.Tensor = test_batch['pcl_in'] # [bs, 1024, 3]
             rotation: torch.Tensor = all_aggregated_pose[i][:, :3, :3] # [bs, 3, 3]
             rotation_t = torch.transpose(rotation, 1, 2) # [bs, 3, 3]
@@ -205,37 +471,66 @@ def inference_scale(score_path, aggregate_path, save_path):
             bbox_length *= 2
             all_final_length.append(bbox_length.cpu())
 
+            elapsed = time.time() - start_time
+            perf_stats['bbox_time'].append(elapsed)
+            total_samples += pcl.shape[0]
+            perf_stats['bbox_samples'] = total_samples
+
             if i % 10 == 9:
                 gc.collect()
 
         pickle.dump((all_aggregated_pose, all_final_length), open(save_path, 'wb'))
         return
-    
-    cfg.agent_type = 'scale'
-    scale_agent = PoseNet(cfg)
-    scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
-    scale_agent.eval()
+
+    if is_om_model:
+        from om_wrappers import ScaleNetWrapper
+        scale_net = ScaleNetWrapper(scale_path, device=cfg.device)
+        print(f"Using ScaleNet OM: {scale_path}")
+    else:
+        cfg.agent_type = 'scale'
+        scale_agent = PoseNet(cfg)
+        scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
+        scale_agent.eval()
 
     all_final_pose = []
     all_final_length = []
+    total_samples = 0
 
     for i, test_batch in enumerate(tqdm(dataloader, desc="scale")):
-        batch_sample = process_batch(
-            batch_sample = test_batch, 
-            device=cfg.device, 
-            pose_mode=cfg.pose_mode,
-        )
-        batch_sample.update({key: (None if value is None else value.to(cfg.device)) 
-                             for key, value in all_score_feature[i].items()})
-        batch_sample['axes'] = all_aggregated_pose[i][:, :3, :3].to(cfg.device)
-        cal_mat, length = scale_agent.pred_scale_func(batch_sample)
+        start_time = time.time()
+
+        pts_feat = all_score_feature[i]['pts_feat']
+        axes = all_aggregated_pose[i][:, :3, :3]
+
+        if is_om_model:
+            with torch.no_grad():
+                length = torch.as_tensor(scale_net(pts_feat, axes))  # tensor either way, so the shared .cpu() below works
+            cal_mat = axes  # pred_scale_func returns axes unchanged ("historical reasons")
+        else:
+            axes = axes.to(cfg.device)
+            batch_sample = process_batch(
+                batch_sample=test_batch,
+                device=cfg.device,
+                pose_mode=cfg.pose_mode,
+            )
+            batch_sample.update({key: (None if value is None else value.to(cfg.device))
+                                 for key, value in all_score_feature[i].items()})
+            batch_sample['axes'] = axes
+            cal_mat, length = scale_agent.pred_scale_func(batch_sample)
+
         final_pose = all_aggregated_pose[i].clone()
         final_pose[:, :3, :3] = cal_mat.cpu()
         all_final_pose.append(final_pose.cpu())
         all_final_length.append(length.cpu())
+
+        elapsed = time.time() - start_time
+        perf_stats['scale_time'].append(elapsed)
+        total_samples += length.shape[0]
+        perf_stats['scale_samples'] = total_samples
+
         if i % 4 == 3:
             gc.collect()
-    
+
     pickle.dump((all_final_pose, all_final_length), open(save_path, 'wb'))
 
 def get_detect_match(cls_path, save_path):
@@ -312,6 +607,52 @@ def print_metrics(dm_path, criterion_path, save_path):
     
     metrics.dump_json(save_path)
 
+def print_performance_stats():
+    """Print per-stage batch/sample counts, timing and throughput, then whole-pipeline totals."""
+    print("\n" + "=" * 60)
+    print("Performance Statistics")
+    print("=" * 60)
+    stage_table = (
+        ('Score Network', 'score_time', 'score_samples'),
+        ('Energy Network', 'energy_time', 'energy_samples'),
+        ('Pose Aggregation', 'aggregate_time', 'aggregate_samples'),
+        ('Scale Network', 'scale_time', 'scale_samples'),
+        ('Bbox Calculation', 'bbox_time', 'bbox_samples'),
+    )
+
+    for label, time_key, samples_key in stage_table:
+        if len(perf_stats[time_key]) == 0:
+            continue  # no batch was timed for this stage
+        batch_times = perf_stats[time_key]
+        sample_count = perf_stats[samples_key]
+        stage_seconds = sum(batch_times)
+        fps = sample_count / stage_seconds if stage_seconds > 0 else 0
+        latency_ms = 1000 / fps if fps > 0 else 0
+        print(f"\n{label}:")
+        print(f"  Total batches: {len(batch_times)}")
+        print(f"  Total samples: {sample_count}")
+        print(f"  Total time: {stage_seconds:.3f}s")
+        print(f"  Avg batch time: {stage_seconds / len(batch_times):.3f}s")
+        print(f"  FPS: {fps:.3f}")
+        print(f"  Avg latency per sample: {latency_ms:.2f}ms")
+
+    # Pipeline totals: score + energy + aggregation + (scale if it was timed, otherwise bbox)
+    total_samples = perf_stats['score_samples']
+    score_seconds = sum(perf_stats['score_time'])
+    energy_seconds = sum(perf_stats.get('energy_time', [0]))
+    aggregate_seconds = sum(perf_stats.get('aggregate_time', [0]))
+    size_times = perf_stats.get('scale_time', []) if perf_stats['scale_time'] else perf_stats.get('bbox_time', [])
+    total_time = score_seconds + energy_seconds + aggregate_seconds + sum(size_times)
+    if total_time > 0:
+        overall_fps = total_samples / total_time
+        print("\n" + "=" * 60)
+        print("Overall Pipeline:")
+        print(f"  Total samples: {total_samples}")
+        print(f"  Total time: {total_time:.3f}s")
+        print(f"  Overall FPS: {overall_fps:.3f}")
+        print(f"  Avg latency per sample: {1000 / overall_fps:.2f}ms")
+        print("=" * 60)
+
 def visualize_pose_distribution(score_path, dm_path):
     all_pred_pose, _ = pickle.load(open(score_path, 'rb'))
     all_dm: DetectMatch = pickle.load(open(dm_path, 'rb'))
@@ -334,35 +675,44 @@ def visualize_pose_distribution(score_path, dm_path):
             all_dm.draw_image(index=index)
             set_trace()
 
-os.makedirs(f'results/evaluation_results/{cfg.result_dir}', exist_ok=True)
+if __name__ == '__main__':
+    dataloader = get_dataloader()
+    os.makedirs(f'results/evaluation_results/{cfg.result_dir}', exist_ok=True)
 
-score_model_name = '_'.join(cfg.pretrained_score_model_path.split('/')[-2:])
-score_save_path = f'results/evaluation_results/{cfg.result_dir}/score_prediction_{score_model_name}.pkl'
-inference_score(score_save_path)
+    score_model_name = '_'.join(cfg.pretrained_score_model_path.split('/')[-2:])
+    score_save_path = f'results/evaluation_results/{cfg.result_dir}/score_prediction_{score_model_name}.pkl'
 
-aggregate_save_path = f'results/evaluation_results/{cfg.result_dir}/aggregated.pkl'
-if cfg.pretrained_energy_model_path is not None:
-    energy_model_name = '_'.join(cfg.pretrained_energy_model_path.split('/')[-2:])
+    is_om_model = cfg.pretrained_score_model_path.endswith('.om')
+    if not is_om_model:
+        import torch_npu
+        torch_npu.npu.set_compile_mode(jit_compile=False)
+    score_net = inference_score_decoupled(score_save_path)
+    if is_om_model and score_net:
+        score_net.release()
+        del score_net
+        gc.collect()
+
+    aggregate_save_path = f'results/evaluation_results/{cfg.result_dir}/aggregated.pkl'
+    energy_om_path = getattr(cfg, 'pretrained_energy_om_path', None) or getattr(cfg, 'pretrained_energy_model_path', None)
+
+    energy_model_name = '_'.join(energy_om_path.split('/')[-2:])
     energy_save_path = f'results/evaluation_results/{cfg.result_dir}/energy_prediction_{energy_model_name}.pkl'
+    pointnet2_energy_om_path = getattr(cfg, 'pretrained_pointnet2_energy_model_path', None)
     inference_energy(score_save_path, energy_save_path)
     aggregate_pose(score_save_path, energy_save_path, aggregate_save_path)
-else:
-    aggregate_pose(score_save_path, None, aggregate_save_path)
-
-if cfg.pretrained_scale_model_path is not None:
     scale_model_name = '_'.join(cfg.pretrained_scale_model_path.split('/')[-2:])
-else:
-    scale_model_name = 'scale-none'
-cls_save_path = f'results/evaluation_results/{cfg.result_dir}/scale_prediction_{scale_model_name}.pkl'
-inference_scale(score_save_path, aggregate_save_path, cls_save_path)
 
-dm_save_path = f'results/evaluation_results/{cfg.result_dir}/detect_match.pkl'
-get_detect_match(cls_save_path, dm_save_path)
+    cls_save_path = f'results/evaluation_results/{cfg.result_dir}/scale_prediction_{scale_model_name}.pkl'
+    inference_scale(score_save_path, aggregate_save_path, cls_save_path)
+
+    dm_save_path = f'results/evaluation_results/{cfg.result_dir}/detect_match.pkl'
+    get_detect_match(cls_save_path, dm_save_path)
+
+    criterion_save_path = f'results/evaluation_results/{cfg.result_dir}/criterion.pkl'
+    get_criterion(dm_save_path, criterion_save_path)
 
-criterion_save_path = f'results/evaluation_results/{cfg.result_dir}/criterion.pkl'
-get_criterion(dm_save_path, criterion_save_path)
+    metrics_save_path = f'results/evaluation_results/{cfg.result_dir}/metrics.json'
+    print_metrics(dm_save_path, criterion_save_path, metrics_save_path)
 
-metrics_save_path = f'results/evaluation_results/{cfg.result_dir}/metrics.json'
-print_metrics(dm_save_path, criterion_save_path, metrics_save_path)
-# visualize_pose_distribution(score_save_path, dm_save_path)
-os._exit(0)
\ No newline at end of file
+    # Print performance statistics
+    print_performance_stats()
diff --git a/runners/evaluation_tracking.py b/runners/evaluation_tracking.py
index 2aedfa9..7cebc04 100644
--- a/runners/evaluation_tracking.py
+++ b/runners/evaluation_tracking.py
@@ -69,55 +69,189 @@ def get_dataloader(data_dir: str):
     )
     return iter(dataloader)
 
-cfg.agent_type = 'score'
-score_agent = PoseNet(cfg)
-score_agent.load_ckpt(model_dir=cfg.pretrained_score_model_path, model_path=True, load_model_only=True)
-score_agent.eval()
-
-cfg.agent_type = 'energy'
-energy_agent = PoseNet(cfg)
-energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
-energy_agent.eval()
-
-if cfg.pretrained_scale_model_path:
-    cfg.agent_type = 'scale'
-    scale_agent = PoseNet(cfg)
-    scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
-    scale_agent.eval()
+# PTH/OM model loading
+is_om_model = cfg.pretrained_score_model_path.endswith('.om')
 
-def work_batch(test_batch, prev_pose):
-    batch_sample = process_batch(
-        batch_sample = test_batch, 
-        device=cfg.device, 
-        pose_mode=cfg.pose_mode,
-    )
-    
-    _prev_pose = prev_pose.clone()
-    _prev_pose[:, -3:] -= batch_sample['pts_center']
+if not is_om_model:
+    import torch_npu
+    torch_npu.npu.set_compile_mode(jit_compile=False)
     cfg.agent_type = 'score'
-    score_pred_results, _ = score_agent.pred_func(
-        data=batch_sample, 
-        repeat_num=cfg.eval_repeat_num, 
-        T0=cfg.T0,
-        init_x=_prev_pose,
-        return_average_res=False,
-        return_process=False,
-    )
-    score_feature = {
-        'pts_feat': batch_sample['pts_feat'].clone(),
-        'rgb_feat': (None if batch_sample['rgb_feat'] is None else batch_sample['rgb_feat'].clone()),
-    }
-    
+    score_agent = PoseNet(cfg)
+    score_agent.load_ckpt(model_dir=cfg.pretrained_score_model_path, model_path=True, load_model_only=True)
+    score_agent.eval()
+
     cfg.agent_type = 'energy'
-    energy_pred_results = energy_agent.get_energy(
-        data=batch_sample, 
-        pose_samples=score_pred_results, 
-        T=1e-5,
-        mode='test', 
-        extract_feature=True
+    energy_agent = PoseNet(cfg)
+    energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
+    energy_agent.eval()
+
+    if cfg.pretrained_scale_model_path:
+        cfg.agent_type = 'scale'
+        scale_agent = PoseNet(cfg)
+        scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
+        scale_agent.eval()
+else:
+    from om_wrappers import (create_score_network, create_ode_sampler,
+                             DINOv2Wrapper, EnergyNetWrapper,
+                             PointNet2EncoderWrapper, ScaleNetWrapper)
+    from networks.gf_algorithms.sde import init_sde, ve_sde_numpy
+    from datasets.datasets_omni6dpose import process_batch_numpy
+    from utils.misc import get_pose_dim
+
+    prior_fn, _, sde_fn, sampling_eps, _ = init_sde('ve')
+    pointnet2_score_om_path = getattr(cfg, 'pretrained_pointnet2_score_model_path', None)
+
+    score_net = create_score_network(
+        checkpoint_path=cfg.pretrained_score_model_path,
+        device=cfg.device,
+        pointnet2_om_path=pointnet2_score_om_path,
     )
+    sde_dir = {'prior_fn': prior_fn, 'sde_fn': ve_sde_numpy}
+    sampler = create_ode_sampler(score_network=score_net, sde=sde_dir, device=cfg.device)
+
+    energy_net = EnergyNetWrapper(cfg.pretrained_energy_model_path, device=cfg.device)
+    pointnet2_energy_encoder = PointNet2EncoderWrapper(
+        cfg.pretrained_pointnet2_energy_model_path, device=cfg.device)
+    print(f"Using EnergyNet OM: {cfg.pretrained_energy_model_path}")
+
+    if cfg.pretrained_scale_model_path:
+        scale_net = ScaleNetWrapper(cfg.pretrained_scale_model_path, device=cfg.device)
+        print(f"Using ScaleNet OM: {cfg.pretrained_scale_model_path}")
+
+    if cfg.dino != 'none':
+        dino_om_path = getattr(cfg, 'pretrained_dino_model_path', None)
+        if dino_om_path is not None and dino_om_path.endswith('.om'):
+            dino_model = DINOv2Wrapper(dino_om_path, device=cfg.device)
+            print(f"Using DINOv2 OM: {dino_om_path}")
+        else:
+            import torch.hub
+            dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(cfg.device)
+            dino_model.requires_grad_(False)
+            print("Using DINOv2 PyTorch")
+        # Per-batch DINO feature lookup; this definition exists only in the OM branch of the module setup.
+        def extract_dino_features(batch_sample):
+            roi_rgb = batch_sample['roi_rgb']
+            roi_xs = batch_sample['roi_xs']
+            roi_ys = batch_sample['roi_ys']
+            if is_om_model:  # NOTE(review): always True in this OM-only branch, so the torch.hub fallback is also called this way - confirm
+                return dino_model(roi_rgb, roi_xs, roi_ys)
+            feat = dino_model.get_intermediate_layers(roi_rgb)[0]
+            xs = roi_xs // 14
+            ys = roi_ys // 14
+            pos = xs * 16 + ys  # flat index into the patch-token axis of feat
+            pos = torch.unsqueeze(pos, -1).expand(-1, -1, 384)
+            rgb_feat = torch.gather(feat, 1, pos)
+            rgb_feat.requires_grad_(False)
+            return rgb_feat
+    else:
+        def extract_dino_features(batch_sample):
+            return None  # DINO disabled: no rgb features
+
+def work_batch(test_batch, prev_pose):
+    """Run one tracking step starting from prev_pose; returns (detect_match, updated prev_pose)."""
+    if is_om_model:
+        batch_sample = process_batch_numpy(test_batch, pose_mode=cfg.pose_mode)
+    else:
+        batch_sample = process_batch(
+            batch_sample = test_batch,
+            device=cfg.device,
+            pose_mode=cfg.pose_mode,
+        )
+
+    bs = prev_pose.shape[0]
+    pose_dim = get_pose_dim(cfg.pose_mode) if is_om_model else prev_pose.shape[1]
+    repeat_num = cfg.eval_repeat_num
+
+    if is_om_model:
+        # OM score path
+        t0 = time.time()
+        # DINOv2 + PointNet2 feature extraction
+        rgb_feat = extract_dino_features(batch_sample)
+        pts_feat = score_net.extract_pts_feat(batch_sample['pts'], rgb_feat)
+
+        # Construct init_x: repeat prev_pose per sample and add prior noise (all host-side numpy)
+        _prev_pose = prev_pose.cpu().numpy().copy()
+        _prev_pose[:, -3:] -= batch_sample['pts_center']
+        noise = prior_fn((bs * repeat_num, pose_dim), T=cfg.T0).cpu().numpy()  # .cpu(): Tensor.numpy() needs a host tensor
+        prev_pose_repeated = np.repeat(_prev_pose, repeat_num, axis=0)
+        init_x_repeated = prev_pose_repeated + noise
+
+        pts_feat_repeated = np.repeat(pts_feat[np.newaxis, ...], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+        _, sampled_pose = sampler.sample(
+            pts_feat=pts_feat_repeated,
+            rgb_feat=None,
+            batch_size=bs * repeat_num,
+            pose_dim=pose_dim,
+            T=cfg.T0,
+            eps=sampling_eps,
+            rtol=1e-5,
+            atol=1e-5,
+            denoise=True,
+            init_x=init_x_repeated,
+            pts_center=None if batch_sample.get('pts_center') is None else
+                np.repeat(batch_sample['pts_center'][:, np.newaxis, :], repeat_num, axis=1).reshape(bs * repeat_num, -1),
+        )
+        score_pred_results = sampled_pose.reshape(bs, repeat_num, pose_dim)
 
-    sorted_pose, sorted_energy = sort_poses_by_energy(score_pred_results, energy_pred_results)
+        score_feature = {
+            'pts_feat': pts_feat,
+            'rgb_feat': rgb_feat,
+        }
+
+        # OM energy
+        pts_with_rgb = np.concatenate([batch_sample['pts'], rgb_feat], axis=-1)  # [bs, 1024, 387]
+        pts_feat_energy = pointnet2_energy_encoder(pts_with_rgb)
+
+        pose_samples = score_pred_results.reshape(bs * repeat_num, -1).astype(np.float32)
+        pose_samples[:, -3:] -= np.repeat(batch_sample['pts_center'], repeat_num, axis=0)
+        t = np.full((bs * repeat_num, 1), 1e-5, dtype=np.float32)
+        pts_feat_repeated_energy = np.repeat(pts_feat_energy[np.newaxis, ...], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+
+        with torch.no_grad():
+            pred_energy = energy_net(pts_feat_repeated_energy, pose_samples, t)
+        energy_pred_results = pred_energy.reshape(bs, repeat_num, -1)
+        perf_stats['score_time'].append(time.time() - t0)
+        perf_stats['score_samples'] += bs
+
+    else:
+        # PTH path (original) 
+        t0 = time.time()
+        _prev_pose = prev_pose.clone()
+        _prev_pose[:, -3:] -= batch_sample['pts_center']
+        cfg.agent_type = 'score'
+        score_pred_results, _ = score_agent.pred_func(
+            data=batch_sample,
+            repeat_num=cfg.eval_repeat_num,
+            T0=cfg.T0,
+            init_x=_prev_pose,
+            return_average_res=False,
+            return_process=False,
+        )
+        score_feature = {
+            'pts_feat': batch_sample['pts_feat'].clone(),
+            'rgb_feat': (None if batch_sample['rgb_feat'] is None else batch_sample['rgb_feat'].clone()),
+        }
+
+        cfg.agent_type = 'energy'
+        energy_pred_results = energy_agent.get_energy(
+            data=batch_sample,
+            pose_samples=score_pred_results,
+            T=1e-5,
+            mode='test',
+            extract_feature=True
+        )
+        perf_stats['score_time'].append(time.time() - t0)
+        perf_stats['score_samples'] += bs
+
+    # Convert numpy to tensor for sort and aggregate operations
+    if is_om_model:
+        score_pred_results = torch.from_numpy(score_pred_results)
+        energy_pred_results = torch.from_numpy(energy_pred_results)
+
+    # aggregate + scale
+    t0 = time.time()
+    sorted_pose, sorted_energy = sort_poses_by_energy(
+        score_pred_results, energy_pred_results)
     bs = score_pred_results.shape[0]
     retain_num = int(cfg.eval_repeat_num * cfg.retain_ratio)
     good_pose = sorted_pose[:, :retain_num, :]
@@ -146,28 +280,36 @@ def work_batch(test_batch, prev_pose):
     gt_length = test_batch['bbox_side_len'].numpy()
 
     if cfg.pretrained_scale_model_path:
-        cfg.agent_type = 'scale'
-        batch_sample.update(score_feature)
-        batch_sample['axes'] = aggregated_pose[:, :3, :3].to(cfg.device)
-        with torch.no_grad():
-            pred_length = scale_agent.net(batch_sample) 
-        pred_length = pred_length.cpu().numpy()
+        if is_om_model:
+            pred_length = scale_net(pts_feat, aggregated_pose[:, :3, :3])
+        else:
+            cfg.agent_type = 'scale'
+            batch_sample.update(score_feature)
+            batch_sample['axes'] = aggregated_pose[:, :3, :3].to(cfg.device)
+            with torch.no_grad():
+                pred_length = scale_agent.net(batch_sample)
+            pred_length = pred_length.cpu().numpy()
     else:
         pred_length = np.ones((pred_pose.shape[0], 3))
 
     detect_match = DetectMatch(
-        gt_affine=gt_pose, gt_size=gt_length, 
-        gt_sym_labels=array_to_SymLabel(test_batch['sym_info']), 
+        gt_affine=gt_pose, gt_size=gt_length,
+        gt_sym_labels=array_to_SymLabel(test_batch['sym_info']),
         gt_class_labels=test_batch['class_label'],
         pred_affine=pred_pose, pred_size=pred_length,
         # image_path=[path + 'color.png' for path in test_batch['path']],
         camera_intrinsics=array_to_CameraIntrinsicsBase(test_batch['intrinsics'])
     )
+    perf_stats['aggregate_time'].append(time.time() - t0)
+    perf_stats['aggregate_samples'] += bs
 
-    prev_pose = torch.zeros_like(prev_pose, device=cfg.device)
+    prev_pose = torch.zeros_like(
+        prev_pose, 
+        device=cfg.device if not is_om_model else 'cpu'
+    )
     prev_pose[:, :-3] = get_pose_representation(aggregated_pose[:, :3, :3], cfg.pose_mode)
     prev_pose[:, -3:] = aggregated_pose[:, :3, 3]
-    
+
     return detect_match, prev_pose
 
 img_list = Dataset.glob_prefix(root = cfg.data_path)
@@ -212,6 +354,13 @@ pbar = tqdm(total=total_objects)
 for i in range(30):
     add_dataloader()
 
+perf_stats = dict(
+    score_time=[],        # wall-clock seconds of each score + energy pass
+    score_samples=0,      # samples pushed through the score + energy pass
+    aggregate_time=[],    # wall-clock seconds of each aggregate + scale pass
+    aggregate_samples=0,  # samples pushed through the aggregate + scale pass
+)
+
 while 1:
     test_batch = []
     prev_pose = []
@@ -233,15 +382,24 @@ while 1:
                 f.write(dataloader.save_path + '\n')
             dd.add(dataloader)
             continue
-        test_batch.append(batch)
         length = dataloader._dataset.num_valid
+        if batch.get('_corrupted', torch.tensor(False)).any():
+            print(f"[SKIP] Corrupted frame in {dataloader.save_path}")
+            pbar.update(length)
+            continue
+        test_batch.append(batch)
         try:
             prev_pose.append(dataloader.prev_pose)
         except:
-            pose = torch.zeros(length, get_pose_dim(cfg.pose_mode), device=cfg.device) # on gpu
+            pose = torch.zeros(
+                length, get_pose_dim(cfg.pose_mode), 
+                device=cfg.device if not is_om_model else 'cpu'
+            )
             assert batch['affine'].shape[0] == length, set_trace()
             for j in range(length):
-                noise_gt_pose = add_noise_to_RT(batch['affine'][j].to(cfg.device).unsqueeze(0))[0]
+                noise_gt_pose = add_noise_to_RT(batch['affine'][j].to(
+                    cfg.device if not is_om_model else 'cpu'
+                ).unsqueeze(0))[0]
                 pose[j, :-3] = get_pose_representation(
                     noise_gt_pose[:3, :3].unsqueeze(0), 
                     pose_mode=cfg.pose_mode
@@ -252,7 +410,11 @@ while 1:
         if split_pos[-1][0] > cfg.batch_size - 8:
             break
     if test_batch == []:
-        break
+        for dl in dd:
+            dataloaders.remove(dl)
+        if len(dataloaders) == 0:
+            break
+        continue
     
     keys = {key for key, value in test_batch[0].items() if type(value) != list}
     test_batch = {
@@ -277,6 +439,43 @@ while 1:
 
 pbar.close()
 
+# Print performance statistics collected in perf_stats
+print("\n" + "="*60)
+print("Performance Statistics")
+print("="*60)
+
+stages = [
+    ('Score + Energy', 'score_time', 'score_samples'),
+    ('Aggregate + Scale', 'aggregate_time', 'aggregate_samples'),
+]
+
+for stage_name, time_key, samples_key in stages:
+    times, samples = perf_stats[time_key], perf_stats[samples_key]
+    if times:
+        t_total = sum(times)
+        fps = samples / t_total if t_total > 0 else 0
+        print(f"\n{stage_name}:")
+        print(f"  Total batches: {len(times)}")
+        print(f"  Total samples: {samples}")
+        print(f"  Total time: {t_total:.3f}s")
+        print(f"  Avg batch time: {t_total/len(times):.3f}s")
+        print(f"  FPS: {fps:.3f}")
+        if fps > 0:  # without throughput there is no latency to report (and no blank line)
+            print(f"  Avg latency per sample: {1000/fps:.2f}ms")
+
+total_samples = perf_stats['score_samples']  # each sample is counted once, via the score stage
+total_time = sum(perf_stats['score_time']) + sum(perf_stats['aggregate_time'])
+
+if total_time > 0 and total_samples > 0:  # both required: the latency line divides by the FPS
+    overall_fps = total_samples / total_time
+    print(f"\n{'='*60}")
+    print("Overall Pipeline:")
+    print(f"  Total samples: {total_samples}")
+    print(f"  Total time: {total_time:.3f}s")
+    print(f"  Overall FPS: {overall_fps:.3f}")
+    print(f"  Avg latency per sample: {1000/overall_fps:.2f}ms")
+    print("="*60)
+
 all_dm = []
 all_crit = []
 for path in tqdm(video_paths):
diff --git a/runners/infer.py b/runners/infer.py
index 963c6a0..b622583 100644
--- a/runners/infer.py
+++ b/runners/infer.py
@@ -1,15 +1,19 @@
 import os
 import sys
+
+# Force Qt's offscreen platform plugin so Qt-based GUI code (e.g. OpenCV highgui) runs without a display
+os.environ['QT_QPA_PLATFORM'] = 'offscreen'
 import numpy as np
 from tqdm import tqdm
 import pickle
 import torch
+import torch_npu
+torch_npu.npu.set_compile_mode(jit_compile=False)
 import random
 import gc
 import cv2
-import open3d as o3d
-import pyrealsense2 as rs
-import pyrealsense2 as rs
+# import open3d as o3d  # Not needed for offline inference
+# import pyrealsense2 as rs  # Not needed for offline inference
 import numpy as np
 import glob
 
@@ -25,9 +29,9 @@ from utils.so3_visualize import visualize_so3
 from cutoop.eval_utils import DetectMatch, Metrics
 from configs.config import get_config
 from datasets.datasets_infer import InferDataset
-
-from flask import Flask, request
-flask_app = Flask(__name__)
+from runners.evaluation_single import apply_spec_ops_patches
+# from flask import Flask, request
+# flask_app = Flask(__name__)
 
 
 class GenPose2:
@@ -243,11 +247,13 @@ def visualize_pose(data:InferDataset, all_final_pose, all_final_length, visualiz
     all_final_length = all_final_length[0].cpu().numpy()
 
     for index, (obj_pose, obj_length) in enumerate(zip(all_final_pose, all_final_length)):
-        if visualize_pts:
-            pts = data.get_objects()['pts'].cpu().numpy()[index]
-            pcd = o3d.geometry.PointCloud()
-            pcd.points = o3d.utility.Vector3dVector(pts)
-            o3d.visualization.draw_geometries([pcd])
+        # open3d is not needed for offline inference, so the point-cloud preview below is disabled.
+        # if visualize_pts:
+        #     pts = data.get_objects()['pts'].cpu().numpy()[index]
+        #     pcd = o3d.geometry.PointCloud()
+        #     pcd.points = o3d.utility.Vector3dVector(pts)
+        #     o3d.visualization.draw_geometries([pcd])
+        #     print(f"Object {index}: visualize_pts is not supported (open3d disabled)")
         color_img = DetectMatch._draw_image(
             vis_img=color_img,
             pred_affine=obj_pose,
@@ -264,17 +270,19 @@ def visualize_pose(data:InferDataset, all_final_pose, all_final_length, visualiz
             thickness=True,
         )
     
-    if visualize_image:
-        cv2.namedWindow('rgb')
-        cv2.imshow('rgb', color_img)
-        cv2.waitKey() 
-        cv2.destroyAllWindows()
+    # The interactive preview is not needed for offline inference; the annotated image is returned instead.
+    # if visualize_image:
+    #     cv2.namedWindow('rgb')
+    #     cv2.imshow('rgb', color_img)
+    #     cv2.waitKey()
+    #     cv2.destroyAllWindows()
     return color_img
 
 
 def main():
     ######################################## PARAMETERS ########################################
-    DATA_PATH = 'data/Omni6DPose/ROPE/000007'                 # Path to the data
+    DATA_PATH = 'omin6dpose-000a/ROPE/000000/'                 # Path to the data
+    RESULT_DIR = 'result_images'                               # Output directory for result images
     TRACKING = True                                           # Tracking mode
 
     # Tracking parameter, if the relative pose between the current frame and the previous frame
@@ -286,8 +294,13 @@ def main():
     ENERGY_MODEL_PATH='results/ckpts/EnergyNet/energynet.pth'  # Path to the energy model
     SCALE_MODEL_PATH='results/ckpts/ScaleNet/scalenet.pth'     # Path to the scale model
     PREV_POSE = None                                           # Previous pose
+    apply_spec_ops_patches()
     ######################################## PARAMETERS ########################################
 
+    # Create result directory
+    os.makedirs(RESULT_DIR, exist_ok=True)
+    print(f"Results will be saved to: {os.path.abspath(RESULT_DIR)}")
+
     ''' load data '''
     # Get data from image file
     color_images = sorted(glob.glob(DATA_PATH + '/*_color.png'))
@@ -297,20 +310,26 @@ def main():
         scale_model_path=SCALE_MODEL_PATH,
     )
     
-    cv2.namedWindow('rgb')
+    # cv2.namedWindow('rgb')  # Not needed for offline inference
     for index, color_image in enumerate(tqdm(color_images)):
         data_prefix = color_image.replace('color.png', '')
         data = InferDataset.alternetive_init(data_prefix, img_size=GenPose2.cfg.img_size, device=GenPose2.cfg.device, n_pts=GenPose2.cfg.num_points)
         pose, length = GenPose2.inference(data, PREV_POSE, TRACKING, TRACKING_T0)
         color_image_w_pose = visualize_pose(data, pose, length, visualize_image=False)
+
+        # Save the annotated frame, e.g. "000123_color.png" -> RESULT_DIR/"000123_result.png"
+        output_filename = os.path.basename(color_image).replace('_color.png', '_result.png')
+        output_path = os.path.join(RESULT_DIR, output_filename)
+        if not cv2.imwrite(output_path, color_image_w_pose):  # imwrite signals failure by returning False
+            print(f"[WARN] Failed to write result image: {output_path}")
+
         PREV_POSE = pose
-        cv2.imshow('rgb', color_image_w_pose)
-        cv2.waitKey(1) 
+        # cv2.imshow('rgb', color_image_w_pose)  # Not needed for offline inference
+        # cv2.waitKey(1)
 
-    cv2.destroyAllWindows()    
+    # cv2.destroyAllWindows()  # Not needed for offline inference    
 
 
 if __name__ == '__main__':
     main()
 
-
diff --git a/scripts/eval_single.sh b/scripts/eval_single.sh
index 6386f80..c2122d4 100644
--- a/scripts/eval_single.sh
+++ b/scripts/eval_single.sh
@@ -3,10 +3,10 @@ CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
 --pretrained_score_model_path results/ckpts/ScoreNet/scorenet.pth \
 --pretrained_energy_model_path results/ckpts/EnergyNet/energynet.pth \
 --pretrained_scale_model_path results/ckpts/ScaleNet/scalenet.pth \
---data_path Omni6DPose_ROPE_PATH \
+--data_path omni6dpose-000000/ROPE/ \
 --sampler_mode ode \
 --percentage_data_for_test 1.0 \
---batch_size 128 \
+--batch_size 32 \
 --seed 0 \
 --result_dir single \
 --eval_repeat_num 50 \
@@ -14,4 +14,5 @@ CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
 --T0 0.55 \
 --dino pointwise \
 --num_worker 32 \
---real_drop 3 \
\ No newline at end of file
+--real_drop 3 \
+--device npu:0
\ No newline at end of file
diff --git a/scripts/eval_single_om.sh b/scripts/eval_single_om.sh
new file mode 100644
index 0000000..e0f94fa
--- /dev/null
+++ b/scripts/eval_single_om.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# OM Model Evaluation Script
+# Uses separate PointNet2 OM and ScoreNet OM models for better performance
+ASCEND_RT_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
+--pretrained_dino_model_path om_models/dinov2_vits14.om \
+--pretrained_pointnet2_score_model_path om_models/pointnet2_from_score.om \
+--pretrained_pointnet2_energy_model_path om_models/pointnet2_from_energy.om \
+--pretrained_score_model_path om_models/scorenet.om \
+--pretrained_energy_model_path om_models/energynet.om \
+--pretrained_scale_model_path om_models/scalenet.om \
+--data_path omni6dpose-000000/ROPE/ \
+--sampler_mode ode \
+--percentage_data_for_test 1.0 \
+--batch_size 16 \
+--seed 0 \
+--result_dir single_om \
+--eval_repeat_num 50 \
+--clustering 1 \
+--T0 0.55 \
+--dino pointwise \
+--num_workers 32 \
+--real_drop 3 \
+--device npu:0
diff --git a/scripts/eval_tracking.sh b/scripts/eval_tracking.sh
index fe9c504..881df8a 100644
--- a/scripts/eval_tracking.sh
+++ b/scripts/eval_tracking.sh
@@ -3,10 +3,10 @@ CUDA_VISIBLE_DEVICES=0 python runners/evaluation_tracking.py \
 --pretrained_score_model_path results/ckpts/ScoreNet/scorenet.pth \
 --pretrained_energy_model_path results/ckpts/EnergyNet/energynet.pth \
 --pretrained_scale_model_path results/ckpts/ScaleNet/scalenet.pth \
---data_path Omni6DPose_ROPE_PATH \
+--data_path omni6dpose-000000/ROPE/ \
 --sampler_mode ode \
 --percentage_data_for_test 1.0 \
---batch_size 128 \
+--batch_size 16 \
 --seed 0 \
 --result_dir tracking \
 --eval_repeat_num 50 \
diff --git a/scripts/eval_tracking_om.sh b/scripts/eval_tracking_om.sh
new file mode 100644
index 0000000..0712822
--- /dev/null
+++ b/scripts/eval_tracking_om.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# OM Model Tracking Evaluation Script
+ASCEND_RT_VISIBLE_DEVICES=0 python runners/evaluation_tracking.py \
+--pretrained_dino_model_path om_models/dinov2_vits14.om \
+--pretrained_pointnet2_score_model_path om_models/pointnet2_from_score.om \
+--pretrained_pointnet2_energy_model_path om_models/pointnet2_from_energy.om \
+--pretrained_score_model_path om_models/scorenet.om \
+--pretrained_energy_model_path om_models/energynet.om \
+--pretrained_scale_model_path om_models/scalenet.om \
+--data_path omni6dpose-000000/ROPE/ \
+--sampler_mode ode \
+--percentage_data_for_test 1.0 \
+--batch_size 4 \
+--seed 0 \
+--result_dir tracking_om \
+--eval_repeat_num 50 \
+--clustering 1 \
+--T0 0.25 \
+--dino pointwise \
+--num_workers 32 \
+--device npu:0
diff --git a/utils/misc.py b/utils/misc.py
index 8332686..8bd34a6 100644
--- a/utils/misc.py
+++ b/utils/misc.py
@@ -302,6 +302,24 @@ def normalize_rotation(rotation, rotation_mode):
         raise NotImplementedError
     return rotation
 
+
+def normalize_rotation_numpy(rotation, rotation_mode):
+    """Numpy version of normalize_rotation for OM inference (rot_matrix mode only).
+
+    Orthonormalises columns 0:3 and 3:6 of `rotation` (Gram-Schmidt), writes them
+    back in place and returns the same array; other columns are left untouched.
+    `rotation_mode` is accepted but not inspected.
+    """
+    eps = 1e-12  # default eps of torch's F.normalize; keeps zero-length axes from yielding NaN
+    a1 = rotation[:, :3]
+    a2 = rotation[:, 3:6]
+    b1 = a1 / np.maximum(np.linalg.norm(a1, axis=-1, keepdims=True), eps)
+    b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
+    b2 = b2 / np.maximum(np.linalg.norm(b2, axis=-1, keepdims=True), eps)
+    rotation[:, :3] = b1
+    rotation[:, 3:6] = b2
+    return rotation
+
     
 if __name__ == '__main__':
     quat = torch.randn(2, 3, 4)
diff --git a/utils/transforms/rotation_conversions.py b/utils/transforms/rotation_conversions.py
index df68525..11efc01 100755
--- a/utils/transforms/rotation_conversions.py
+++ b/utils/transforms/rotation_conversions.py
@@ -553,6 +553,13 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
     return quaternions[..., 1:] / sin_half_angles_over_angles
 
 
+
+def _simple_cross(a, b):
+    """Cross product of `a` and `b` over the last axis (components 0, 1, 2)."""
+    ax, ay, az, bx, by, bz = (v[..., i] for v in (a, b) for i in range(3))
+    parts = (ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx)
+    return torch.stack(parts, dim=-1)
+
 def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
     """
     Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
@@ -573,7 +580,7 @@ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
     b1 = F.normalize(a1, dim=-1)
     b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
     b2 = F.normalize(b2, dim=-1)
-    b3 = torch.cross(b1, b2, dim=-1)
+    b3 = _simple_cross(b1, b2)
     return torch.stack((b1, b2, b3), dim=-2)