@@ -16,7 +16,7 @@ def get_config():
parser.add_argument('--train_source', type=str, default='Omni6DPose')
parser.add_argument('--val_source', type=str, default='Omni6DPose')
parser.add_argument('--test_source', type=str, default='Omni6DPose')
- parser.add_argument('--device', type=str, default='cuda')
+ parser.add_argument('--device', type=str, default='npu:0')
parser.add_argument('--num_points', type=int, default=1024)
parser.add_argument('--per_obj', type=str, default='')
parser.add_argument('--num_workers', type=int, default=32)
@@ -34,12 +34,18 @@ def get_config():
parser.add_argument('--s_theta_mode', type=str, default='score')
parser.add_argument('--norm_energy', type=str, default='identical')
parser.add_argument('--dino', type=str, default='pointwise') # none / global / pointwise
+ parser.add_argument('--pretrained_dino_model_path', type=str, default=None,
+ help='Path to DINOv2 OM model for inference. If None, uses torch.hub PyTorch model.')
parser.add_argument('--scale_embedding', type=int, default=180)
""" training """
parser.add_argument('--agent_type', type=str, default='score', help='one of the [score, energy, energy_with_ranking, scale]')
parser.add_argument('--pretrained_score_model_path', type=str)
+ parser.add_argument('--pretrained_pointnet2_score_model_path', type=str, default=None,
+ help='Path to PointNet2 OM model (from scorenet) for decoupled inference')
+ parser.add_argument('--pretrained_pointnet2_energy_model_path', type=str, default=None,
+ help='Path to PointNet2 OM model (from energynet) for energy stage')
parser.add_argument('--pretrained_energy_model_path', type=str)
parser.add_argument('--pretrained_scale_model_path', type=str)
parser.add_argument('--distillation', default=False, action='store_true')
@@ -2,7 +2,6 @@ import numpy as np
import cv2
import torch
import copy
-import open3d as o3d
from cutoop.data_loader import Dataset, ImageMetaData
from utils.datasets_utils import aug_bbox_eval, get_2d_coord_np, crop_resize_by_warp_affine
@@ -138,7 +137,7 @@ class InferDataset(object):
data['roi_rgb_'] = torch.as_tensor(np.ascontiguousarray(roi_rgb_), dtype=torch.uint8).contiguous()
data['roi_xs'] = torch.as_tensor(np.ascontiguousarray(xs), dtype=torch.int64).contiguous()
data['roi_ys'] = torch.as_tensor(np.ascontiguousarray(ys), dtype=torch.int64).contiguous()
- data['roi_center_dir'] = torch.as_tensor(pixel2xyz(img_height, img_height, bbox_center, intrinsics), dtype=torch.float32).contiguous()
+ data['roi_center_dir'] = torch.tensor(pixel2xyz(img_height, img_height, bbox_center, intrinsics), dtype=torch.float32).contiguous()
return data
@@ -616,7 +616,36 @@ def process_batch(batch_sample,
processed_sample['zero_mean_gt_pose'][:, -3:] -= zero_mean
processed_sample['pts_center'] = zero_mean
- return processed_sample
+ return processed_sample
+
+
+def process_batch_numpy(batch_sample, pose_mode='quat_wxyz'):
+    """Convert a collated batch into the numpy dict used by the OM inference path.
+
+    Only pts, pts_center, roi_rgb, roi_xs and roi_ys are produced; gt_pose,
+    sym_info, zero_mean_pts etc. are left out. ``pose_mode`` is accepted but unused.
+    """
+
+    def _to_numpy(key, dtype):
+        # Bring the tensor to host memory and cast it to the requested dtype.
+        return batch_sample[key].cpu().numpy().astype(dtype)
+
+    # pts: [bs, 1024, 3]
+    cloud = _to_numpy('pcl_in', np.float32)
+    # pts_center: mean over the point axis, [bs, 3]
+    centroid = np.mean(cloud[:, :, :3], axis=1, keepdims=True)[:, 0, :]
+
+    return {
+        'pts': cloud,
+        'pts_center': centroid,
+        # roi_rgb: [bs, 3, imgsize, imgsize]
+        'roi_rgb': _to_numpy('roi_rgb', np.float32),
+        # roi_xs, roi_ys: [bs, 1024]
+        'roi_xs': _to_numpy('roi_xs', np.int64),
+        'roi_ys': _to_numpy('roi_ys', np.int64),
+    }
+
+
if __name__ == '__main__':
@@ -113,6 +113,9 @@ class Omni6DPoseDataSet(data.Dataset):
self.per_obj = per_obj
self.per_obj_id = None
+ self._cached_depth = None
+ self._cached_mask = None
+ self._cached_rgb = None
tmp = []
for img_path in self.img_list:
@@ -144,10 +147,22 @@ class Omni6DPoseDataSet(data.Dataset):
obj = valid_objects[index % self.num_valid]
inst_name = obj.meta.oid
- rgb = Dataset.load_color(img_path + "color.png")
- depth = Dataset.load_depth(img_path + ('depth_syn' if self.cfg.perfect_depth else 'depth') + '.exr')
+ try:
+ rgb = Dataset.load_color(img_path + "color.png")
+ depth = Dataset.load_depth(img_path + ('depth_syn' if self.cfg.perfect_depth else 'depth') + '.exr')
+ mask = Dataset.load_mask(img_path + 'mask.exr')
+ except Exception as e:
+ print(f"[WARN] Failed to load data for {img_path}: {e}, using cached data from previous frame")
+ depth = self._cached_depth
+ mask = self._cached_mask
+ rgb = self._cached_rgb
+ used_cache = True
+ else:
+ self._cached_depth = depth
+ self._cached_mask = mask
+ self._cached_rgb = rgb
+ used_cache = False
depth[depth > 1e3] = 0
- mask = Dataset.load_mask(img_path + 'mask.exr')
if not (mask.shape[:2] == depth.shape[:2] == rgb.shape[:2]):
assert 0
return self.__getitem__((index + 1) % self.__len__())
@@ -302,6 +317,7 @@ class Omni6DPoseDataSet(data.Dataset):
data_dict['class_name'] = obj.meta.class_name
data_dict['object_name'] = inst_name
data_dict['is_valid'] = 1
+ data_dict['_corrupted'] = used_cache
# xyz = depth2xyz(depth, intrinsics)
# choose = np.logical_and(mask == inst_idx, depth > 0).flatten().nonzero()[0]
@@ -203,8 +203,14 @@ def cond_ode_sampler(
# num_steps, from T -> eps
t_eval = np.linspace(T, eps, num_steps)
res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', t_eval=t_eval)
- xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
- x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
+
+ # Print ODE solver statistics
+ print(f"ODE solver: {len(res.t)} steps computed (adaptive)")
+ print(f"Function evaluations: {res.nfev}")
+ print(f"Time range: {res.t[0]:.6f} → {res.t[-1]:.6f}")
+
+ xs = torch.tensor(res.y, device=device, dtype=torch.float32).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim]
+ x = torch.tensor(res.y[:, -1], device=device, dtype=torch.float32).reshape(shape) # [bs, pose_dim]
# denoise, using the predictor step in P-C sampler
if denoise:
# Reverse diffusion predictor for denoising
@@ -251,10 +257,10 @@ def cond_edm_sampler(
data, denoised = decoder(data)
# recover data
data['sampled_pose'], data['t'] = x_, t_
- return denoised.to(torch.float64)
+ return denoised.to(torch.float32)
# Main sampling loop.
- x_next = latents.to(torch.float64) * t_steps[0]
+ x_next = latents.to(torch.float32) * t_steps[0]
xs = []
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next
@@ -19,13 +19,22 @@ def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
def ve_sde(t, sigma_min=0.01, sigma_max=90):
sigma = sigma_min * (sigma_max / sigma_min) ** t
- drift_coeff = torch.tensor(0)
- diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
+ device = t.device if hasattr(t, 'device') else 'cpu'
+ drift_coeff = torch.tensor(0, device=device)
+ diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=device))
+ return drift_coeff, diffusion_coeff
+
+def ve_sde_numpy(t, sigma_min=0.01, sigma_max=90):
+ """Pure numpy version of ve_sde for ODE solver integration."""
+ sigma = sigma_min * (sigma_max / sigma_min) ** t
+ drift_coeff = 0.0
+ diffusion_coeff = float(sigma) * np.sqrt(2 * (np.log(sigma_max) - np.log(sigma_min)))
return drift_coeff, diffusion_coeff
def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
_, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
- return torch.randn(*shape) * sigma_max_prior
+ fixed_rng = torch.Generator().manual_seed(0)  # deterministic draw without reseeding the global RNG
+ return torch.randn(*shape, generator=fixed_rng, dtype=torch.float32) * sigma_max_prior
#----- VP SDE -----
#------------------
@@ -42,7 +51,7 @@ def vp_sde(t, beta_0=0.1, beta_1=20):
return drift_coeff, diffusion_coeff
def vp_prior(shape, beta_0=0.1, beta_1=20):
- return torch.randn(*shape)
+ return torch.randn(*shape, dtype=torch.float32)
#----- sub-VP SDE -----
#----------------------
@@ -70,12 +79,13 @@ def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
return mean, std
def edm_sde(t, sigma_min=0.002, sigma_max=80):
- drift_coeff = torch.tensor(0)
+ device = t.device if hasattr(t, 'device') else 'cpu'
+ drift_coeff = torch.tensor(0, device=device)
diffusion_coeff = torch.sqrt(2 * t)
return drift_coeff, diffusion_coeff
def edm_prior(shape, sigma_min=0.002, sigma_max=80):
- return torch.randn(*shape) * sigma_max
+ return torch.randn(*shape, dtype=torch.float32) * sigma_max
def init_sde(sde_mode):
# the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
@@ -17,7 +17,6 @@ from configs.config import get_config
from utils.genpose_utils import encode_axes
-
class GFObjectPose(nn.Module):
dino_name = 'dinov2_vits14'
dino_dim = 384
@@ -47,32 +46,9 @@ class GFObjectPose(nn.Module):
self.embedding_dim = GFObjectPose.embedding_dim
''' encode pts '''
- if self.cfg.pts_encoder == 'pointnet':
- assert cfg.dino != 'pointwise' # not supported yet
- self.pts_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024)
- elif self.cfg.pts_encoder == 'pointnet2':
- if cfg.dino == 'pointwise':
- self.pts_encoder = Pointnet2ClsMSGFus(self.dino_dim)
- else:
- self.pts_encoder = Pointnet2ClsMSG(0)
- elif self.cfg.pts_encoder == 'pointnet_and_pointnet2':
- assert cfg.dino != 'pointwise' # not supported yet
- self.pts_pointnet_encoder = PointNetfeat(num_points=self.cfg.num_points, out_dim=1024)
- self.pts_pointnet2_encoder = Pointnet2ClsMSG(0)
- self.fusion_layer = nn.Linear(2048, 1024)
- self.act = nn.ReLU()
- else:
- raise NotImplementedError
-
- ''' score network'''
- # if self.cfg.sde_mode == 'edm':
- # self.pose_score_net = PoseDecoderNet(
- # self.marginal_prob_fn,
- # sigma_data=1.4148,
- # pose_mode=self.cfg.pose_mode,
- # regression_head=self.cfg.regression_head
- # )
- # else:
+
+ self.pts_encoder = Pointnet2ClsMSGFus(self.dino_dim)
+
per_point_feat = False
if self.cfg.agent_type == 'score':
self.pose_score_net = PoseScoreNet(
@@ -98,19 +74,28 @@ class GFObjectPose(nn.Module):
Args:
data (dict): batch example without pointcloud feature. {'pts': [bs, num_pts, 3], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]}
+ precomputed_rgb_feat (torch.Tensor, optional): Pre-computed DINOv2 features [B, 1024, 384].
+ If provided, will skip internal DINOv2 computation and use these features directly.
Returns:
data (dict): batch example with pointcloud feature. {'pts': [bs, num_pts, 3], 'pts_feat': [bs, c], 'sampled_pose': [bs, pose_dim], 't': [bs, 1]}
"""
pts = data['pts']
if self.cfg.dino == 'pointwise':
- roi_rgb = data['roi_rgb']
- feat = self.dino.get_intermediate_layers(roi_rgb)[0]
- xs = data['roi_xs'] // 14
- ys = data['roi_ys'] // 14
- pos = xs * 16 + ys
- pos = torch.unsqueeze(pos, -1).expand(-1, -1, self.dino_dim)
- rgb_feat = torch.gather(feat, 1, pos)
- rgb_feat.requires_grad_(False)
+ # Prefer features the caller stored under data['precomputed_rgb_feat'];
+ # `data` is a dict, so it must be queried with .get(), not getattr().
+ precomputed_rgb_feat = data.get('precomputed_rgb_feat', None)
+ if precomputed_rgb_feat is not None:  # truth-testing a multi-element tensor raises
+     rgb_feat = precomputed_rgb_feat.to(pts.device)
+ else:
+     # Fallback: compute the DINOv2 features inside the model
+     roi_rgb = data['roi_rgb']
+     feat = self.dino.get_intermediate_layers(roi_rgb)[0]
+     xs = data['roi_xs'] // 14
+     ys = data['roi_ys'] // 14
+     pos = xs * 16 + ys
+     pos = torch.unsqueeze(pos, -1).expand(-1, -1, self.dino_dim)
+     rgb_feat = torch.gather(feat, 1, pos)
+     rgb_feat.requires_grad_(False)
if self.cfg.pts_encoder == 'pointnet':
assert 0
pts_feat = self.pts_encoder(pts.permute(0, 2, 1)) # -> (bs, 3, 1024)
@@ -194,6 +179,7 @@ class GFObjectPose(nn.Module):
'pts_feat': [bs, c]
'sampled_pose': [bs, pose_dim]
't': [bs, 1]
+ 'precomputed_rgb_feat': [bs, 1024, 384] (optional)
}
'''
if mode == 'score':
@@ -91,9 +91,6 @@ class PoseNet(nn.Module):
else:
net = self.get_network('ScaleNet')
net = net.to(self.cfg.device)
- if self.cfg.parallel:
- device_ids = list(range(self.cfg.num_gpu))
- net = nn.DataParallel(net, device_ids=device_ids).cuda()
return net
@@ -167,14 +164,14 @@ class PoseNet(nn.Module):
if not os.path.exists(load_path):
raise ValueError("Checkpoint {} not exists.".format(load_path))
- checkpoint = torch.load(load_path)
+ checkpoint = torch.load(load_path, map_location=self.cfg.device)
print("Loading checkpoint from {} ...".format(load_path))
-
+
if isinstance(self.net, nn.DataParallel):
self.net.module.load_state_dict(checkpoint['model_state_dict'])
else:
self.net.load_state_dict(checkpoint['model_state_dict'])
-
+
if not load_model_only:
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
@@ -240,6 +240,7 @@ class Pointnet2ClsMSGFus(nn.Module):
)
)
channel_in = channel_out + input_channels
+ self.SA_modules[-1].forward = self.SA_modules[-1].forward_npoint_none
def _break_up_pc(self, pc):
@@ -258,19 +259,30 @@ class Pointnet2ClsMSGFus(nn.Module):
# features: bs * F * npoints
l_xyz, l_features = [xyz], [features]
- for i in range(len(self.SA_modules)):
- if i != 0:
- l_features[i] = torch.concatenate([l_features[i], features], dim=1) # concatenate
+
+ # first
+ li_xyz, li_features, idx = self.SA_modules[0](l_xyz[0], l_features[0], return_idx=True)
+ l_xyz.append(li_xyz)
+ l_features.append(li_features)
+ features = torch.gather(features, 2,
+ torch.unsqueeze(idx.type(torch.int64), 1).expand(-1, features.shape[1], -1))
+ # middle
+ for i in range(1,len(self.SA_modules)-1):
+ l_features[i] = torch.concatenate([l_features[i], features], dim=1) # concatenate
li_xyz, li_features, idx = self.SA_modules[i](l_xyz[i], l_features[i], return_idx=True)
l_xyz.append(li_xyz)
l_features.append(li_features)
- if idx != None:
- features = torch.gather(
- features, 2,
- torch.unsqueeze(idx.type(torch.int64), 1).expand(-1, features.shape[1], -1)
- ) # only keep features of remaining points
- else:
- assert i == len(self.SA_modules) - 1
+
+ features = torch.gather(features, 2,
+ torch.unsqueeze(idx.type(torch.int64), 1).expand(-1, features.shape[1], -1))
+ # last
+ i += 1
+ l_features[i] = torch.concatenate([l_features[i], features], dim=1) # concatenate
+ li_xyz, li_features, idx = self.SA_modules[i](l_xyz[i], l_features[i], return_idx=True)
+ l_xyz.append(li_xyz)
+ l_features.append(li_features)
+ assert i == len(self.SA_modules) - 1
+
return l_features[-1].squeeze(-1)
@@ -27,42 +27,35 @@ class _PointnetSAModuleBase(nn.Module):
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
new_idx: (B, npoint) tensor of indices
"""
- new_features_list = []
xyz_flipped = xyz.transpose(1, 2).contiguous()
idx = None
- if new_xyz is None:
- if self.npoint is not None:
- idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
- new_xyz = pointnet2_utils.gather_operation(
- xyz_flipped,
- idx
- ).transpose(1, 2).contiguous()
- else:
- new_xyz = None
+ idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+ new_xyz = pointnet2_utils.gather_operation(
+ xyz_flipped,
+ idx
+ ).transpose(1, 2).contiguous()
+ return self.calculate_xyz_features_idx(xyz, features, new_xyz, idx)
+
+
+ def forward_npoint_none(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None, return_idx=False ,index=None):
+ return self.calculate_xyz_features_idx(xyz, features, None, None)
+
+ def calculate_xyz_features_idx(self, xyz, features, new_xyz, idx):
+ new_features_list = []
for i in range(len(self.groupers)):
new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
- if self.pool_method == 'max_pool':
- new_features = F.max_pool2d(
- new_features, kernel_size=[1, new_features.size(3)]
- ) # (B, mlp[-1], npoint, 1)
- elif self.pool_method == 'avg_pool':
- new_features = F.avg_pool2d(
- new_features, kernel_size=[1, new_features.size(3)]
- ) # (B, mlp[-1], npoint, 1)
- else:
- raise NotImplementedError
+ new_features = torch.amax(new_features, dim=3, keepdim=True)
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
new_features_list.append(new_features)
- if return_idx:
- return new_xyz, torch.cat(new_features_list, dim=1), idx
- return new_xyz, torch.cat(new_features_list, dim=1)
+
+ return new_xyz, torch.cat(new_features_list, dim=1), idx
class PointnetSAModuleMSG(_PointnetSAModuleBase):
@@ -5,7 +5,7 @@ import torch.nn as nn
from typing import Tuple
import sys
-import pointnet2_cuda as pointnet2
+import pointnet2_ops as pointnet2
class FurthestPointSampling(Function):
@@ -20,14 +20,7 @@ class FurthestPointSampling(Function):
:return:
output: (B, npoint) tensor containing the set
"""
- assert xyz.is_contiguous()
-
- B, N, _ = xyz.size()
- output = torch.cuda.IntTensor(B, npoint)
- temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
-
- pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
- return output
+ return pointnet2._furthest_point_sampling(xyz, npoint)
@staticmethod
def backward(xyz, a=None):
@@ -48,26 +41,13 @@ class GatherOperation(Function):
:return:
output: (B, C, npoint)
"""
- assert features.is_contiguous()
- assert idx.is_contiguous()
-
- B, npoint = idx.size()
- _, C, N = features.size()
- output = torch.cuda.FloatTensor(B, C, npoint)
-
- pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
-
- ctx.for_backwards = (idx, C, N)
- return output
+ ctx.for_backwards = (idx, features.shape[1], features.shape[2])
+ return pointnet2._gather_points(features, idx)
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
- B, npoint = idx.size()
-
- grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
- grad_out_data = grad_out.data.contiguous()
- pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+ grad_features = pointnet2._gather_points_grad(grad_out, idx, N)
return grad_features, None
@@ -92,8 +72,9 @@ class ThreeNN(Function):
B, N, _ = unknown.size()
m = known.size(1)
- dist2 = torch.cuda.FloatTensor(B, N, 3)
- idx = torch.cuda.IntTensor(B, N, 3)
+ device = known.device
+ dist2 = torch.empty((B, N, 3), dtype=torch.float32, device=device)
+ idx = torch.empty((B, N, 3), dtype=torch.int32, device=device)
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@@ -125,8 +106,9 @@ class ThreeInterpolate(Function):
B, c, m = features.size()
n = idx.size(1)
+ device = features.device
ctx.three_interpolate_for_backward = (idx, weight, m)
- output = torch.cuda.FloatTensor(B, c, n)
+ output = torch.empty((B, c, n), dtype=torch.float32, device=device)
pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
return output
@@ -143,8 +125,8 @@ class ThreeInterpolate(Function):
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()
-
- grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+ device = weight.device
+ grad_features = Variable(torch.empty((B, c, m), dtype=torch.float32, device=device).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
@@ -165,17 +147,8 @@ class GroupingOperation(Function):
:return:
output: (B, C, npoint, nsample) tensor
"""
- assert features.is_contiguous()
- assert idx.is_contiguous()
-
- B, nfeatures, nsample = idx.size()
- _, C, N = features.size()
- output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
-
- pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
-
- ctx.for_backwards = (idx, N)
- return output
+ ctx.for_backwards = (idx, features.shape[2])
+ return pointnet2._group_points(features, idx)
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -186,12 +159,7 @@ class GroupingOperation(Function):
grad_features: (B, C, N) gradient of the features
"""
idx, N = ctx.for_backwards
-
- B, C, npoint, nsample = grad_out.size()
- grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
-
- grad_out_data = grad_out.data.contiguous()
- pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+ grad_features = pointnet2._group_points_grad(grad_out, idx, N)
return grad_features, None
@@ -211,15 +179,7 @@ class BallQuery(Function):
:return:
idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
"""
- assert new_xyz.is_contiguous()
- assert xyz.is_contiguous()
-
- B, N, _ = xyz.size()
- npoint = new_xyz.size(1)
- idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
-
- pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
- return idx
+ return pointnet2._ball_query(new_xyz, xyz, radius, nsample)
@staticmethod
def backward(ctx, a=None):
@@ -255,7 +215,7 @@ class QueryAndGroup(nn.Module):
if features is not None:
grouped_features = grouping_operation(features, idx)
if self.use_xyz:
- new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
+ new_features = torch.cat([grouped_xyz, grouped_features], dim=1)
else:
new_features = grouped_features
else:
@@ -3,9 +3,8 @@ scipy==1.12.0
numpy==1.26.3
tensorboardX==2.6.2.2
tensorboard==2.17.0
-open3d==0.18.0
-pyrealsense2==2.55.1.6486
-ipdb
matplotlib
tqdm
-scikit-learn
\ No newline at end of file
+scikit-learn
+# open3d==0.18.0 # Optional: install only if you use a camera.
+# pyrealsense2==2.56.5.9235 # Optional: install only if you use a camera.
\ No newline at end of file
@@ -8,6 +8,7 @@ from tqdm import tqdm
import _pickle as cPickle
import pickle
import torch
+import random
import torch.nn as nn
import torch.nn.functional as F
import copy
@@ -22,10 +23,13 @@ from ipdb import set_trace
from networks.posenet_agent import PoseNet
from networks.reward import sort_poses_by_energy, ranking_loss
-from datasets.datasets_omni6dpose import Omni6DPoseDataSet, array_to_SymLabel, array_to_CameraIntrinsicsBase, process_batch
+from om_wrappers import create_score_network, create_ode_sampler
+from networks.gf_algorithms.sde import init_sde
+from datasets.datasets_omni6dpose import Omni6DPoseDataSet, array_to_SymLabel, array_to_CameraIntrinsicsBase, process_batch, process_batch_numpy
from utils.metrics import get_rot_matrix
from utils.transforms import matrix_to_quaternion, quaternion_to_matrix
from utils.misc import average_quaternion_batch
+from utils.genpose_utils import get_pose_dim
from utils.so3_visualize import visualize_so3
from utils.visualize import create_grid_image
from cutoop.eval_utils import DetectMatch, Metrics
@@ -41,6 +45,21 @@ torch.cuda.manual_seed(cfg.seed)
random.seed(cfg.seed)
np.random.seed(cfg.seed)
+# Performance statistics
+perf_stats = {
+ 'score_time': [],
+ 'score_samples': 0,
+ 'energy_time': [],
+ 'energy_samples': 0,
+ 'aggregate_time': [],
+ 'aggregate_samples': 0,
+ 'scale_time': [],
+ 'scale_samples': 0,
+ 'bbox_time': [],
+ 'bbox_samples': 0,
+}
+
+
def get_dataloader():
dataset = Omni6DPoseDataSet(
cfg=cfg,
@@ -59,12 +78,57 @@ def get_dataloader():
shuffle=False,
num_workers=cfg.num_workers,
persistent_workers=True,
- drop_last=False,
+ drop_last=True,
pin_memory=True,
)
return dataloader
-dataloader = get_dataloader()
+
+_dino_model = None
+
+def get_dino_model():
+    """Return the module-level DINOv2 model, loading it on the first call that needs it."""
+    global _dino_model
+    # Already loaded, or DINO disabled: hand back the cached value unchanged.
+    if _dino_model is not None or cfg.dino == 'none':
+        return _dino_model
+    om_path = getattr(cfg, 'pretrained_dino_model_path', None)
+    if om_path is None or not om_path.endswith('.om'):
+        print("Loading DINOv2 model for preprocessing...")
+        import torch.hub
+        _dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(cfg.device)
+        _dino_model.requires_grad_(False)
+        print("DINOv2 loaded successfully")
+        return _dino_model
+    print(f"Loading DINOv2 OM model: {om_path}"); from om_wrappers import DINOv2Wrapper
+    _dino_model = DINOv2Wrapper(om_path, device=cfg.device)
+    print("DINOv2 OM loaded successfully"); return _dino_model
+
+
+def extract_dino_features(batch_sample):
+    """Extract DINOv2 features as a preprocessing step, outside the pose network.
+
+    Returns None when get_dino_model() yields no model (cfg.dino == 'none').
+    """
+    dino = get_dino_model()
+    if dino is None:
+        # Callers already test the result with `is not None`.
+        return None
+    roi_rgb = batch_sample['roi_rgb']  # [B, 3, H, W]
+    roi_xs = batch_sample['roi_xs']  # [B, 1024]
+    roi_ys = batch_sample['roi_ys']  # [B, 1024]
+
+    # OM path: DINOv2Wrapper returns numpy, keep as numpy throughout
+    if is_om_model:
+        return dino(roi_rgb, roi_xs, roi_ys)
+
+    # PyTorch path: patch tokens, then gather the token under each sampled pixel
+    feat = dino.get_intermediate_layers(roi_rgb)[0]  # [B, 256, 384]
+    pos = (roi_xs // 14) * 16 + roi_ys // 14  # 224x224 input -> 16x16 feature map
+    pos = torch.unsqueeze(pos, -1).expand(-1, -1, 384)
+    rgb_feat = torch.gather(feat, 1, pos)  # [B, 1024, 384]
+    return rgb_feat.requires_grad_(False)
+
def inference_score(save_path):
if os.path.exists(save_path):
@@ -77,16 +141,24 @@ def inference_score(save_path):
all_pred_pose = []
all_score_feature = []
+ total_samples = 0
for i, test_batch in enumerate(tqdm(dataloader, desc="score sampling")):
+ start_time = time.time()
batch_sample = process_batch(
- batch_sample = test_batch,
- device=cfg.device,
+ batch_sample = test_batch,
+ device=cfg.device,
pose_mode=cfg.pose_mode,
)
+
+ # Extract DINOv2 features as preprocessing (outside the model)
+ rgb_feat = extract_dino_features(batch_sample)
+ if rgb_feat is not None:
+ batch_sample['precomputed_rgb_feat'] = rgb_feat
+
pred_results = score_agent.pred_func(
- data=batch_sample,
- repeat_num=cfg.eval_repeat_num,
+ data=batch_sample,
+ repeat_num=cfg.eval_repeat_num,
T0=cfg.T0,
return_average_res=False,
return_process=False
@@ -97,41 +169,222 @@ def inference_score(save_path):
'pts_feat': batch_sample['pts_feat'].cpu(),
'rgb_feat': (None if batch_sample['rgb_feat'] is None else batch_sample['rgb_feat'].cpu()),
})
+ elapsed = time.time() - start_time
+ perf_stats['score_time'].append(elapsed)
+ total_samples += pred_pose.shape[0]
+ perf_stats['score_samples'] = total_samples
+
if i % 4 == 3:
gc.collect()
-
pickle.dump((all_pred_pose, all_score_feature), open(save_path, 'wb'))
+def inference_score_decoupled(save_path):
+    """
+    Unified decoupled inference using ScoreNetworkWrapper and ODESamplerExternal.
+
+    Supports both PyTorch and OM models:
+    - PyTorch: Uses internal PyTorch PointNet2 encoder
+    - OM: Uses separate PointNet2 OM + ScoreNet OM (if available)
+
+    This unified interface enables direct comparison between PyTorch and OM outputs
+    for debugging precision issues.
+
+    Args:
+        save_path: Path of the pickle cache; if it already exists nothing is run.
+    Returns:
+        The object returned by create_score_network, or None on a cache hit.
+    """
+    if os.path.exists(save_path):
+        return None
+
+    # Initialize SDE components
+    prior_fn, _, sde_fn, sampling_eps, _ = init_sde('ve')
+
+    pointnet2_om_path = getattr(cfg, 'pretrained_pointnet2_score_model_path', None)
+
+    # Create Score Network wrapper (automatically uses OM or PyTorch)
+    score_net = create_score_network(
+        checkpoint_path=cfg.pretrained_score_model_path,
+        device=cfg.device,
+        pointnet2_om_path=pointnet2_om_path  # Pass None for PyTorch, path for OM
+    )
+    if is_om_model:
+        from networks.gf_algorithms.sde import ve_sde_numpy
+        sde_coeff_fn = ve_sde_numpy
+    else:
+        sde_coeff_fn = sde_fn
+    sde_dir = {'prior_fn': prior_fn, 'sde_fn': sde_coeff_fn}
+    # Create ODE sampler with the Score Network
+    sampler = create_ode_sampler(
+        score_network=score_net,
+        sde=sde_dir,
+        device=cfg.device
+    )
+
+    all_pred_pose = []
+    all_score_feature = []
+    total_samples = 0
+    pose_dim = get_pose_dim(cfg.pose_mode)  # loop-invariant, computed once
+    repeat_num = cfg.eval_repeat_num
+
+    print(f"\nRunning unified decoupled inference ({'OM' if score_net.is_om else 'PyTorch'})...")
+    for i, test_batch in enumerate(tqdm(dataloader, desc="score sampling")):
+        start_time = time.time()
+
+        if is_om_model:
+            batch_sample = process_batch_numpy(test_batch, pose_mode=cfg.pose_mode)
+        else:
+            batch_sample = process_batch(
+                batch_sample=test_batch,
+                device=cfg.device,
+                pose_mode=cfg.pose_mode,
+            )
+
+        # Extract DINOv2 features as preprocessing
+        rgb_feat = extract_dino_features(batch_sample)
+        if rgb_feat is not None:
+            batch_sample['precomputed_rgb_feat'] = rgb_feat
+
+        # Extract point cloud features using unified interface
+        pts_feat = score_net.extract_pts_feat(
+            pts=batch_sample['pts'],
+            rgb_feat=rgb_feat
+        )
+
+        # Get batch info
+        bs = batch_sample['pts'].shape[0]
+        flat_bs = bs * repeat_num
+        pts_center = batch_sample.get('pts_center', None)
+
+        # Generate random initial values for all repeats at once
+        init_x_all = prior_fn((bs, repeat_num, pose_dim), T=cfg.T0).cpu().numpy()
+        init_x_repeated = init_x_all.reshape(flat_bs, pose_dim)
+
+        # Repeat features so each sample's row appears repeat_num times consecutively
+        if is_om_model:
+            pts_feat_repeated = np.repeat(pts_feat, repeat_num, axis=0).reshape(flat_bs, -1)
+            pts_center_repeated = None if pts_center is None else \
+                np.repeat(pts_center, repeat_num, axis=0).reshape(flat_bs, -1)
+        else:
+            pts_feat_repeated = pts_feat.unsqueeze(1).repeat(1, repeat_num, 1).view(flat_bs, -1)
+            pts_center_repeated = None if pts_center is None else \
+                pts_center.unsqueeze(1).repeat(1, repeat_num, 1).view(flat_bs, -1)
+
+        # Single call to sampler for all repeats
+        with torch.no_grad():
+            _, sampled_pose = sampler.sample(
+                pts_feat=pts_feat_repeated,
+                rgb_feat=None,
+                batch_size=flat_bs,
+                pose_dim=pose_dim,
+                T=cfg.T0,
+                eps=sampling_eps,
+                rtol=1e-5,
+                atol=1e-5,
+                denoise=True,
+                init_x=init_x_repeated,
+                pts_center=pts_center_repeated
+            )
+
+        # Reshape result from [bs*repeat_num, pose_dim] to [bs, repeat_num, pose_dim]
+        if is_om_model:
+            pred_pose = sampled_pose.reshape(bs, repeat_num, pose_dim)
+        else:
+            pred_pose = sampled_pose.view(bs, repeat_num, pose_dim)
+
+        # Save pred_pose and features
+        all_pred_pose.append(pred_pose)
+        all_score_feature.append({'pts_feat': pts_feat, 'rgb_feat': rgb_feat})
+
+        elapsed = time.time() - start_time
+        perf_stats['score_time'].append(elapsed)
+        total_samples += pred_pose.shape[0]
+        perf_stats['score_samples'] = total_samples
+        if i % 4 == 3:
+            gc.collect()
+
+    with open(save_path, 'wb') as f:  # close the cache file instead of leaking the handle
+        pickle.dump((all_pred_pose, all_score_feature), f)
+    print(f"Unified decoupled inference complete! ({'OM' if score_net.is_om else 'PyTorch'})")
+    return score_net
+
+
def inference_energy(score_path, save_path):
+ """Score every sampled pose with the energy model and pickle the per-batch results.
+
+ Returns immediately when ``save_path`` already exists. Otherwise the sampled
+ poses are read from the pickle at ``score_path`` and, for each batch of the
+ module-level ``dataloader``, an energy is computed for ``all_pred_pose[i]``:
+
+ * OM path (``is_om_model``): points are concatenated with the DINO features,
+ encoded by ``PointNet2EncoderWrapper`` and passed to ``EnergyNetWrapper``
+ at T=1e-5; the last three pose components have ``pts_center`` subtracted first.
+ * PyTorch path: ``PoseNet`` with ``agent_type='energy'`` via ``get_energy``.
+
+ Per-batch wall time is appended to ``perf_stats['energy_time']`` and the
+ running batch-size total is stored in ``perf_stats['energy_samples']``.
+ The list of per-batch energies is dumped to ``save_path``.
+ """
if os.path.exists(save_path):
return
assert os.path.exists(score_path)
all_pred_pose, _ = pickle.load(open(score_path, 'rb'))
- cfg.agent_type = 'energy'
- energy_agent = PoseNet(cfg)
- energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
- energy_agent.eval()
+ if is_om_model:
+ from om_wrappers import EnergyNetWrapper, PointNet2EncoderWrapper
+ energy_net = EnergyNetWrapper(cfg.pretrained_energy_model_path, device=cfg.device)
+ # Load PointNet2 from energy checkpoint for pts_feat extraction
+ pointnet2_encoder = PointNet2EncoderWrapper(cfg.pretrained_pointnet2_energy_model_path, device=cfg.device)
+ print(f"Using EnergyNet OM: {cfg.pretrained_energy_model_path}")
+ print(f"Using PointNet2 (from energy): {cfg.pretrained_pointnet2_energy_model_path}")
+ else:
+ cfg.agent_type = 'energy'
+ energy_agent = PoseNet(cfg)
+ energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
+ energy_agent.eval()
all_pred_energy = []
+ total_samples = 0
for i, test_batch in enumerate(tqdm(dataloader, desc="energy")):
- batch_sample = process_batch(
- batch_sample = test_batch,
- device=cfg.device,
- pose_mode=cfg.pose_mode,
- )
- pred_energy = energy_agent.get_energy(
- data=batch_sample,
- pose_samples=all_pred_pose[i],
- T=1e-5,
- mode='test',
- extract_feature=True
- )
- all_pred_energy.append(pred_energy.cpu())
+ start_time = time.time()
+ if is_om_model:
+ batch_sample = process_batch_numpy(test_batch, pose_mode=cfg.pose_mode)
+ else:
+ batch_sample = process_batch(
+ batch_sample = test_batch,
+ device=cfg.device,
+ pose_mode=cfg.pose_mode,
+ )
+
+ # Extract DINOv2 features as preprocessing (outside the model)
+ rgb_feat = extract_dino_features(batch_sample)
+ batch_sample['precomputed_rgb_feat'] = rgb_feat
+
+ if is_om_model:
+ bs = batch_sample['pts'].shape[0]
+ repeat_num = all_pred_pose[i].shape[1]
+
+ pointcloud = np.concatenate([batch_sample['pts'], rgb_feat], axis=-1)
+ pts_feat = pointnet2_encoder(pointcloud)
+
+ # Repeat pts_feat: each object's row is repeated repeat_num times back to back,
+ # the same row order pose_samples gets from its reshape below
+ # (assumes pts_feat is [bs, feat] - TODO confirm against PointNet2EncoderWrapper).
+ repeated_pts_feat = np.repeat(pts_feat[np.newaxis, ...], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+
+ # Prepare sampled_pose with pts_center subtracted
+ # astype() returns a copy, so the in-place subtraction leaves all_pred_pose[i] untouched.
+ pose_samples = all_pred_pose[i].reshape(bs * repeat_num, -1).astype(np.float32)
+ pts_center = batch_sample['pts_center']
+ repeated_pts_center = np.repeat(pts_center[:, np.newaxis, :], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+ pose_samples[:, -3:] -= repeated_pts_center
+
+ T = 1e-5
+ t = np.full((bs * repeat_num, 1), T, dtype=np.float32)
+
+ # OM inference
+ pred_energy = energy_net(repeated_pts_feat, pose_samples, t).reshape(bs, repeat_num, -1)
+ else:
+ # Copy the result to host memory, as the pre-patch code did, so per-batch
+ # energies do not pile up on the accelerator and are pickled as CPU tensors.
+ # .cpu() is a no-op for a tensor that already lives on the CPU.
+ pred_energy = energy_agent.get_energy(
+ data=batch_sample,
+ pose_samples=all_pred_pose[i],
+ T=1e-5,
+ mode='test',
+ extract_feature=True
+ ).cpu()
+ all_pred_energy.append(pred_energy)
+
+ elapsed = time.time() - start_time
+ perf_stats['energy_time'].append(elapsed)
+ total_samples += pred_energy.shape[0]
+ perf_stats['energy_samples'] = total_samples
+
if i % 4 == 3:
gc.collect()
-
+
pickle.dump(all_pred_energy, open(save_path, 'wb'))
def aggregate_pose(score_path, energy_path, save_path):
@@ -139,17 +392,20 @@ def aggregate_pose(score_path, energy_path, save_path):
return
assert os.path.exists(score_path)
all_pred_pose, _ = pickle.load(open(score_path, 'rb'))
- if energy_path is not None:
- assert os.path.exists(energy_path)
- all_pred_energy = pickle.load(open(energy_path, 'rb'))
- else:
- all_pred_energy = [torch.ones(*(all_pred_pose[i].shape[:2]), 2)
- for i in range(len(all_pred_pose))]
+
+ assert os.path.exists(energy_path)
+ all_pred_energy = pickle.load(open(energy_path, 'rb'))
+ # ensure tensors (OM energy path saves numpy)
+ if is_om_model:
+ all_pred_pose = [torch.from_numpy(i) for i in all_pred_pose]
+ all_pred_energy = [torch.from_numpy(i) for i in all_pred_energy]
all_aggregated_pose = []
-
+ total_samples = 0
+
for i, (pred_pose, pred_energy) in enumerate(tqdm(zip(all_pred_pose, all_pred_energy), desc="aggregate")):
- sorted_pose, sorted_energy = sort_poses_by_energy(pred_pose, pred_energy)
+ start_time = time.time()
+ sorted_pose, _ = sort_poses_by_energy(pred_pose, pred_energy)
bs = pred_pose.shape[0]
retain_num = int(cfg.eval_repeat_num * cfg.retain_ratio)
good_pose = sorted_pose[:, :retain_num, :]
@@ -173,6 +429,12 @@ def aggregate_pose(score_path, energy_path, save_path):
aggregated_pose[:, :3, :3] = quaternion_to_matrix(aggregated_quat_wxyz)
aggregated_pose[:, :3, 3] = aggregated_trans
all_aggregated_pose.append(aggregated_pose)
+
+ elapsed = time.time() - start_time
+ perf_stats['aggregate_time'].append(elapsed)
+ total_samples += bs
+ perf_stats['aggregate_samples'] = total_samples
+
if i % 10 == 9:
gc.collect()
@@ -181,6 +443,8 @@ def aggregate_pose(score_path, energy_path, save_path):
def inference_scale(score_path, aggregate_path, save_path):
if os.path.exists(save_path):
return
+ scale_path = getattr(cfg, 'pretrained_scale_model_path', None)
+ assert os.path.exists(scale_path)
assert os.path.exists(score_path)
_, all_score_feature = pickle.load(open(score_path, 'rb'))
assert os.path.exists(aggregate_path)
@@ -188,8 +452,10 @@ def inference_scale(score_path, aggregate_path, save_path):
if cfg.pretrained_scale_model_path is None:
all_final_length = []
+ total_samples = 0
for i, test_batch in enumerate(tqdm(dataloader, desc="bbox")):
+ start_time = time.time()
pcl: torch.Tensor = test_batch['pcl_in'] # [bs, 1024, 3]
rotation: torch.Tensor = all_aggregated_pose[i][:, :3, :3] # [bs, 3, 3]
rotation_t = torch.transpose(rotation, 1, 2) # [bs, 3, 3]
@@ -205,37 +471,66 @@ def inference_scale(score_path, aggregate_path, save_path):
bbox_length *= 2
all_final_length.append(bbox_length.cpu())
+ elapsed = time.time() - start_time
+ perf_stats['bbox_time'].append(elapsed)
+ total_samples += pcl.shape[0]
+ perf_stats['bbox_samples'] = total_samples
+
if i % 10 == 9:
gc.collect()
pickle.dump((all_aggregated_pose, all_final_length), open(save_path, 'wb'))
return
-
- cfg.agent_type = 'scale'
- scale_agent = PoseNet(cfg)
- scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
- scale_agent.eval()
+
+ if is_om_model:
+ from om_wrappers import ScaleNetWrapper
+ scale_net = ScaleNetWrapper(scale_path, device=cfg.device)
+ print(f"Using ScaleNet OM: {scale_path}")
+ else:
+ cfg.agent_type = 'scale'
+ scale_agent = PoseNet(cfg)
+ scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
+ scale_agent.eval()
all_final_pose = []
all_final_length = []
+ total_samples = 0
for i, test_batch in enumerate(tqdm(dataloader, desc="scale")):
- batch_sample = process_batch(
- batch_sample = test_batch,
- device=cfg.device,
- pose_mode=cfg.pose_mode,
- )
- batch_sample.update({key: (None if value is None else value.to(cfg.device))
- for key, value in all_score_feature[i].items()})
- batch_sample['axes'] = all_aggregated_pose[i][:, :3, :3].to(cfg.device)
- cal_mat, length = scale_agent.pred_scale_func(batch_sample)
+ start_time = time.time()
+
+ pts_feat = all_score_feature[i]['pts_feat']
+ axes = all_aggregated_pose[i][:, :3, :3]
+
+ if is_om_model:
+ with torch.no_grad():
+ length = scale_net(pts_feat, axes)
+ cal_mat = axes # pred_scale_func returns axes unchanged ("historical reasons")
+ else:
+ axes = axes.to(cfg.device)
+ batch_sample = process_batch(
+ batch_sample=test_batch,
+ device=cfg.device,
+ pose_mode=cfg.pose_mode,
+ )
+ batch_sample.update({key: (None if value is None else value.to(cfg.device))
+ for key, value in all_score_feature[i].items()})
+ batch_sample['axes'] = axes
+ cal_mat, length = scale_agent.pred_scale_func(batch_sample)
+
final_pose = all_aggregated_pose[i].clone()
final_pose[:, :3, :3] = cal_mat.cpu()
all_final_pose.append(final_pose.cpu())
all_final_length.append(length.cpu())
+
+ elapsed = time.time() - start_time
+ perf_stats['scale_time'].append(elapsed)
+ total_samples += length.shape[0]
+ perf_stats['scale_samples'] = total_samples
+
if i % 4 == 3:
gc.collect()
-
+
pickle.dump((all_final_pose, all_final_length), open(save_path, 'wb'))
def get_detect_match(cls_path, save_path):
@@ -312,6 +607,52 @@ def print_metrics(dm_path, criterion_path, save_path):
metrics.dump_json(save_path)
+def print_performance_stats():
+    """Print timing statistics for each inference stage and for the whole pipeline.
+
+    Reads the module-level ``perf_stats`` mapping, in which ``<stage>_time`` is a
+    list of per-batch wall-clock durations in seconds and ``<stage>_samples`` is
+    the number of samples that stage processed. Stages without any recorded
+    batch are left out of the report. Nothing is returned.
+    """
+    print("\n" + "="*60)
+    print("Performance Statistics")
+    print("="*60)
+
+    stages = [
+        ('Score Network', 'score_time', 'score_samples'),
+        ('Energy Network', 'energy_time', 'energy_samples'),
+        ('Pose Aggregation', 'aggregate_time', 'aggregate_samples'),
+        ('Scale Network', 'scale_time', 'scale_samples'),
+        ('Bbox Calculation', 'bbox_time', 'bbox_samples'),
+    ]
+
+    for stage_name, time_key, samples_key in stages:
+        # .get() keeps the report usable when a stage never registered its keys.
+        times = perf_stats.get(time_key, [])
+        if not times:
+            continue
+        samples = perf_stats.get(samples_key, 0)
+        total_time = sum(times)
+        avg_batch_time = total_time / len(times)
+        fps = samples / total_time if total_time > 0 else 0
+
+        print(f"\n{stage_name}:")
+        print(f" Total batches: {len(times)}")
+        print(f" Total samples: {samples}")
+        print(f" Total time: {total_time:.3f}s")
+        print(f" Avg batch time: {avg_batch_time:.3f}s")
+        print(f" FPS: {fps:.3f}")
+        print(f" Avg latency per sample: {1000/fps if fps > 0 else 0:.2f}ms")
+
+    # Calculate overall statistics, relative to the sample count of the score stage.
+    total_samples = perf_stats.get('score_samples', 0)
+    # Scale and bbox are alternatives for the last stage: bbox time is only
+    # counted when no scale time was recorded.
+    size_times = perf_stats.get('scale_time') or perf_stats.get('bbox_time', [])
+    total_time = sum(perf_stats.get('score_time', [])) + sum(perf_stats.get('energy_time', [])) + \
+        sum(perf_stats.get('aggregate_time', [])) + sum(size_times)
+
+    if total_time > 0:
+        overall_fps = total_samples / total_time
+        # Later stages can record time while the score stage recorded no samples;
+        # report 0 latency in that case (as the per-stage lines do) instead of
+        # dividing by zero.
+        overall_latency = 1000 / overall_fps if overall_fps > 0 else 0
+        print(f"\n{'='*60}")
+        print("Overall Pipeline:")
+        print(f" Total samples: {total_samples}")
+        print(f" Total time: {total_time:.3f}s")
+        print(f" Overall FPS: {overall_fps:.3f}")
+        print(f" Avg latency per sample: {overall_latency:.2f}ms")
+        print("="*60)
+
def visualize_pose_distribution(score_path, dm_path):
all_pred_pose, _ = pickle.load(open(score_path, 'rb'))
all_dm: DetectMatch = pickle.load(open(dm_path, 'rb'))
@@ -334,35 +675,44 @@ def visualize_pose_distribution(score_path, dm_path):
all_dm.draw_image(index=index)
set_trace()
-os.makedirs(f'results/evaluation_results/{cfg.result_dir}', exist_ok=True)
+if __name__ == '__main__':
+ # Staged evaluation: score -> energy -> aggregate -> scale -> detect/match ->
+ # criterion -> metrics, followed by the timing report. Every artefact is
+ # written under results/evaluation_results/<cfg.result_dir>/.
+ dataloader = get_dataloader()
+ os.makedirs(f'results/evaluation_results/{cfg.result_dir}', exist_ok=True)
-score_model_name = '_'.join(cfg.pretrained_score_model_path.split('/')[-2:])
-score_save_path = f'results/evaluation_results/{cfg.result_dir}/score_prediction_{score_model_name}.pkl'
-inference_score(score_save_path)
+ score_model_name = '_'.join(cfg.pretrained_score_model_path.split('/')[-2:])
+ score_save_path = f'results/evaluation_results/{cfg.result_dir}/score_prediction_{score_model_name}.pkl'
-aggregate_save_path = f'results/evaluation_results/{cfg.result_dir}/aggregated.pkl'
-if cfg.pretrained_energy_model_path is not None:
- energy_model_name = '_'.join(cfg.pretrained_energy_model_path.split('/')[-2:])
+ # A score checkpoint ending in '.om' switches every stage that tests
+ # is_om_model onto its OM branch.
+ is_om_model = cfg.pretrained_score_model_path.endswith('.om')
+ if not is_om_model:
+ import torch_npu
+ torch_npu.npu.set_compile_mode(jit_compile=False)
+ score_net = inference_score_decoupled(score_save_path)
+ # Free the OM score network once the score stage is done.
+ if is_om_model and score_net:
+ score_net.release()
+ del score_net
+ gc.collect()
+
+ aggregate_save_path = f'results/evaluation_results/{cfg.result_dir}/aggregated.pkl'
+ # NOTE(review): 'pretrained_energy_om_path' is tried first, then
+ # --pretrained_energy_model_path. If both are None the .split() below
+ # raises AttributeError - confirm the energy model is meant to be mandatory.
+ energy_om_path = getattr(cfg, 'pretrained_energy_om_path', None) or getattr(cfg, 'pretrained_energy_model_path', None)
+
+ energy_model_name = '_'.join(energy_om_path.split('/')[-2:])
energy_save_path = f'results/evaluation_results/{cfg.result_dir}/energy_prediction_{energy_model_name}.pkl'
+ # NOTE(review): pointnet2_energy_om_path is not read again in this block;
+ # inference_energy takes the path from cfg directly.
+ pointnet2_energy_om_path = getattr(cfg, 'pretrained_pointnet2_energy_model_path', None)
inference_energy(score_save_path, energy_save_path)
aggregate_pose(score_save_path, energy_save_path, aggregate_save_path)
-else:
- aggregate_pose(score_save_path, None, aggregate_save_path)
-
-if cfg.pretrained_scale_model_path is not None:
+ # NOTE(review): with the 'scale-none' fallback removed, the next line raises
+ # AttributeError when --pretrained_scale_model_path is not given.
scale_model_name = '_'.join(cfg.pretrained_scale_model_path.split('/')[-2:])
-else:
- scale_model_name = 'scale-none'
-cls_save_path = f'results/evaluation_results/{cfg.result_dir}/scale_prediction_{scale_model_name}.pkl'
-inference_scale(score_save_path, aggregate_save_path, cls_save_path)
-dm_save_path = f'results/evaluation_results/{cfg.result_dir}/detect_match.pkl'
-get_detect_match(cls_save_path, dm_save_path)
+ cls_save_path = f'results/evaluation_results/{cfg.result_dir}/scale_prediction_{scale_model_name}.pkl'
+ inference_scale(score_save_path, aggregate_save_path, cls_save_path)
+
+ dm_save_path = f'results/evaluation_results/{cfg.result_dir}/detect_match.pkl'
+ get_detect_match(cls_save_path, dm_save_path)
+
+ criterion_save_path = f'results/evaluation_results/{cfg.result_dir}/criterion.pkl'
+ get_criterion(dm_save_path, criterion_save_path)
-criterion_save_path = f'results/evaluation_results/{cfg.result_dir}/criterion.pkl'
-get_criterion(dm_save_path, criterion_save_path)
+ metrics_save_path = f'results/evaluation_results/{cfg.result_dir}/metrics.json'
+ print_metrics(dm_save_path, criterion_save_path, metrics_save_path)
-metrics_save_path = f'results/evaluation_results/{cfg.result_dir}/metrics.json'
-print_metrics(dm_save_path, criterion_save_path, metrics_save_path)
-# visualize_pose_distribution(score_save_path, dm_save_path)
-os._exit(0)
\ No newline at end of file
+ # Print performance statistics
+ print_performance_stats()
@@ -69,55 +69,189 @@ def get_dataloader(data_dir: str):
)
return iter(dataloader)
-cfg.agent_type = 'score'
-score_agent = PoseNet(cfg)
-score_agent.load_ckpt(model_dir=cfg.pretrained_score_model_path, model_path=True, load_model_only=True)
-score_agent.eval()
-
-cfg.agent_type = 'energy'
-energy_agent = PoseNet(cfg)
-energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
-energy_agent.eval()
-
-if cfg.pretrained_scale_model_path:
- cfg.agent_type = 'scale'
- scale_agent = PoseNet(cfg)
- scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
- scale_agent.eval()
+# PTH/OM model loading
+# The score checkpoint's extension picks the backend: '.om' selects the
+# om_wrappers models, anything else the PoseNet checkpoints.
+is_om_model = cfg.pretrained_score_model_path.endswith('.om')
-def work_batch(test_batch, prev_pose):
- batch_sample = process_batch(
- batch_sample = test_batch,
- device=cfg.device,
- pose_mode=cfg.pose_mode,
- )
-
- _prev_pose = prev_pose.clone()
- _prev_pose[:, -3:] -= batch_sample['pts_center']
+if not is_om_model:
+ import torch_npu
+ torch_npu.npu.set_compile_mode(jit_compile=False)
cfg.agent_type = 'score'
- score_pred_results, _ = score_agent.pred_func(
- data=batch_sample,
- repeat_num=cfg.eval_repeat_num,
- T0=cfg.T0,
- init_x=_prev_pose,
- return_average_res=False,
- return_process=False,
- )
- score_feature = {
- 'pts_feat': batch_sample['pts_feat'].clone(),
- 'rgb_feat': (None if batch_sample['rgb_feat'] is None else batch_sample['rgb_feat'].clone()),
- }
-
+ score_agent = PoseNet(cfg)
+ score_agent.load_ckpt(model_dir=cfg.pretrained_score_model_path, model_path=True, load_model_only=True)
+ score_agent.eval()
+
cfg.agent_type = 'energy'
- energy_pred_results = energy_agent.get_energy(
- data=batch_sample,
- pose_samples=score_pred_results,
- T=1e-5,
- mode='test',
- extract_feature=True
+ energy_agent = PoseNet(cfg)
+ energy_agent.load_ckpt(model_dir=cfg.pretrained_energy_model_path, model_path=True, load_model_only=True)
+ energy_agent.eval()
+
+ if cfg.pretrained_scale_model_path:
+ cfg.agent_type = 'scale'
+ scale_agent = PoseNet(cfg)
+ scale_agent.load_ckpt(model_dir=cfg.pretrained_scale_model_path, model_path=True, load_model_only=True)
+ scale_agent.eval()
+else:
+ from om_wrappers import (create_score_network, create_ode_sampler,
+ DINOv2Wrapper, EnergyNetWrapper,
+ PointNet2EncoderWrapper, ScaleNetWrapper)
+ from networks.gf_algorithms.sde import init_sde, ve_sde_numpy
+ from datasets.datasets_omni6dpose import process_batch_numpy
+ from utils.misc import get_pose_dim
+
+ # sde_fn is unpacked here, but the sampler below is handed ve_sde_numpy
+ # instead; prior_fn is the one init_sde result passed on to the sampler.
+ prior_fn, _, sde_fn, sampling_eps, _ = init_sde('ve')
+ pointnet2_score_om_path = getattr(cfg, 'pretrained_pointnet2_score_model_path', None)
+
+ score_net = create_score_network(
+ checkpoint_path=cfg.pretrained_score_model_path,
+ device=cfg.device,
+ pointnet2_om_path=pointnet2_score_om_path,
)
+ sde_dir = {'prior_fn': prior_fn, 'sde_fn': ve_sde_numpy}
+ sampler = create_ode_sampler(score_network=score_net, sde=sde_dir, device=cfg.device)
+
+ # The energy stage uses its own PointNet2 encoder, separate from the score one.
+ energy_net = EnergyNetWrapper(cfg.pretrained_energy_model_path, device=cfg.device)
+ pointnet2_energy_encoder = PointNet2EncoderWrapper(
+ cfg.pretrained_pointnet2_energy_model_path, device=cfg.device)
+ print(f"Using EnergyNet OM: {cfg.pretrained_energy_model_path}")
+
+ if cfg.pretrained_scale_model_path:
+ scale_net = ScaleNetWrapper(cfg.pretrained_scale_model_path, device=cfg.device)
+ print(f"Using ScaleNet OM: {cfg.pretrained_scale_model_path}")
+
+ # DINOv2 backend: a path ending in '.om' uses DINOv2Wrapper, otherwise the
+ # dinov2_vits14 model is loaded through torch.hub.
+ if cfg.dino != 'none':
+ dino_om_path = getattr(cfg, 'pretrained_dino_model_path', None)
+ if dino_om_path is not None and dino_om_path.endswith('.om'):
+ dino_model = DINOv2Wrapper(dino_om_path, device=cfg.device)
+ print(f"Using DINOv2 OM: {dino_om_path}")
+ else:
+ import torch.hub
+ dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(cfg.device)
+ dino_model.requires_grad_(False)
+ print("Using DINOv2 PyTorch")
+
+ # Returns the DINO features for the ROI crop in batch_sample, one per
+ # (roi_xs, roi_ys) coordinate pair.
+ def extract_dino_features(batch_sample):
+ roi_rgb = batch_sample['roi_rgb']
+ roi_xs = batch_sample['roi_xs']
+ roi_ys = batch_sample['roi_ys']
+ # NOTE(review): this branch keys on is_om_model, not on which DINO model
+ # was loaded above; with an OM score model but no '.om' DINO path the
+ # torch.hub model would be called with three positional arguments - confirm.
+ if is_om_model:
+ return dino_model(roi_rgb, roi_xs, roi_ys)
+ # PyTorch branch: take the first entry of get_intermediate_layers, turn the
+ # coordinates into a flat index (xs // 14) * 16 + (ys // 14) and gather
+ # the 384-wide feature stored at that index.
+ feat = dino_model.get_intermediate_layers(roi_rgb)[0]
+ xs = roi_xs // 14
+ ys = roi_ys // 14
+ pos = xs * 16 + ys
+ pos = torch.unsqueeze(pos, -1).expand(-1, -1, 384)
+ rgb_feat = torch.gather(feat, 1, pos)
+ rgb_feat.requires_grad_(False)
+ return rgb_feat
+ else:
+ # cfg.dino == 'none': there are no RGB features to extract.
+ def extract_dino_features(batch_sample):
+ return None
+
+def work_batch(test_batch, prev_pose):
+ if is_om_model:
+ batch_sample = process_batch_numpy(test_batch, pose_mode=cfg.pose_mode)
+ else:
+ batch_sample = process_batch(
+ batch_sample = test_batch,
+ device=cfg.device,
+ pose_mode=cfg.pose_mode,
+ )
+
+ bs = prev_pose.shape[0]
+ pose_dim = get_pose_dim(cfg.pose_mode) if is_om_model else prev_pose.shape[1]
+ repeat_num = cfg.eval_repeat_num
+
+ if is_om_model:
+ # OM score path
+ t0 = time.time()
+ # DINOv2 + PointNet2 feature extraction
+ rgb_feat = extract_dino_features(batch_sample)
+ pts_feat = score_net.extract_pts_feat(batch_sample['pts'], rgb_feat)
+
+ # Construct init_x: repeat prev_pose and add noise
+ _prev_pose = prev_pose.cpu().numpy().copy()
+ _prev_pose[:, -3:] -= batch_sample['pts_center']
+ noise = prior_fn((bs * repeat_num, pose_dim), T=cfg.T0).numpy()
+ prev_pose_repeated = np.repeat(_prev_pose, repeat_num, axis=0)
+ init_x_repeated = prev_pose_repeated + noise
+
+ pts_feat_repeated = np.repeat(pts_feat[np.newaxis, ...], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+
+ _, sampled_pose = sampler.sample(
+ pts_feat=pts_feat_repeated,
+ rgb_feat=None,
+ batch_size=bs * repeat_num,
+ pose_dim=pose_dim,
+ T=cfg.T0,
+ eps=sampling_eps,
+ rtol=1e-5,
+ atol=1e-5,
+ denoise=True,
+ init_x=init_x_repeated,
+ pts_center=None if batch_sample.get('pts_center') is None else
+ np.repeat(batch_sample['pts_center'][:, np.newaxis, :], repeat_num, axis=1).reshape(bs * repeat_num, -1),
+ )
+ score_pred_results = sampled_pose.reshape(bs, repeat_num, pose_dim)
- sorted_pose, sorted_energy = sort_poses_by_energy(score_pred_results, energy_pred_results)
+ score_feature = {
+ 'pts_feat': pts_feat,
+ 'rgb_feat': rgb_feat,
+ }
+
+ # OM energy
+ pts_with_rgb = np.concatenate([batch_sample['pts'], rgb_feat], axis=-1) # [bs, 1024, 387]
+ pts_feat_energy = pointnet2_energy_encoder(pts_with_rgb)
+
+ pose_samples = score_pred_results.reshape(bs * repeat_num, -1).astype(np.float32)
+ pose_samples[:, -3:] -= np.repeat(batch_sample['pts_center'], repeat_num, axis=0)
+ t = np.full((bs * repeat_num, 1), 1e-5, dtype=np.float32)
+ pts_feat_repeated_energy = np.repeat(pts_feat_energy[np.newaxis, ...], repeat_num, axis=1).reshape(bs * repeat_num, -1)
+
+ with torch.no_grad():
+ pred_energy = energy_net(pts_feat_repeated_energy, pose_samples, t)
+ energy_pred_results = pred_energy.reshape(bs, repeat_num, -1)
+ perf_stats['score_time'].append(time.time() - t0)
+ perf_stats['score_samples'] += bs
+
+ else:
+ # PTH path (original)
+ t0 = time.time()
+ _prev_pose = prev_pose.clone()
+ _prev_pose[:, -3:] -= batch_sample['pts_center']
+ cfg.agent_type = 'score'
+ score_pred_results, _ = score_agent.pred_func(
+ data=batch_sample,
+ repeat_num=cfg.eval_repeat_num,
+ T0=cfg.T0,
+ init_x=_prev_pose,
+ return_average_res=False,
+ return_process=False,
+ )
+ score_feature = {
+ 'pts_feat': batch_sample['pts_feat'].clone(),
+ 'rgb_feat': (None if batch_sample['rgb_feat'] is None else batch_sample['rgb_feat'].clone()),
+ }
+
+ cfg.agent_type = 'energy'
+ energy_pred_results = energy_agent.get_energy(
+ data=batch_sample,
+ pose_samples=score_pred_results,
+ T=1e-5,
+ mode='test',
+ extract_feature=True
+ )
+ perf_stats['score_time'].append(time.time() - t0)
+ perf_stats['score_samples'] += bs
+
+ # Convert numpy to tensor for sort and aggregate operations
+ if is_om_model:
+ score_pred_results = torch.from_numpy(score_pred_results)
+ energy_pred_results = torch.from_numpy(energy_pred_results)
+
+ # aggregate + scale
+ t0 = time.time()
+ sorted_pose, sorted_energy = sort_poses_by_energy(
+ score_pred_results, energy_pred_results)
bs = score_pred_results.shape[0]
retain_num = int(cfg.eval_repeat_num * cfg.retain_ratio)
good_pose = sorted_pose[:, :retain_num, :]
@@ -146,28 +280,36 @@ def work_batch(test_batch, prev_pose):
gt_length = test_batch['bbox_side_len'].numpy()
if cfg.pretrained_scale_model_path:
- cfg.agent_type = 'scale'
- batch_sample.update(score_feature)
- batch_sample['axes'] = aggregated_pose[:, :3, :3].to(cfg.device)
- with torch.no_grad():
- pred_length = scale_agent.net(batch_sample)
- pred_length = pred_length.cpu().numpy()
+ if is_om_model:
+ pred_length = scale_net(pts_feat, aggregated_pose[:, :3, :3])
+ else:
+ cfg.agent_type = 'scale'
+ batch_sample.update(score_feature)
+ batch_sample['axes'] = aggregated_pose[:, :3, :3].to(cfg.device)
+ with torch.no_grad():
+ pred_length = scale_agent.net(batch_sample)
+ pred_length = pred_length.cpu().numpy()
else:
pred_length = np.ones((pred_pose.shape[0], 3))
detect_match = DetectMatch(
- gt_affine=gt_pose, gt_size=gt_length,
- gt_sym_labels=array_to_SymLabel(test_batch['sym_info']),
+ gt_affine=gt_pose, gt_size=gt_length,
+ gt_sym_labels=array_to_SymLabel(test_batch['sym_info']),
gt_class_labels=test_batch['class_label'],
pred_affine=pred_pose, pred_size=pred_length,
# image_path=[path + 'color.png' for path in test_batch['path']],
camera_intrinsics=array_to_CameraIntrinsicsBase(test_batch['intrinsics'])
)
+ perf_stats['aggregate_time'].append(time.time() - t0)
+ perf_stats['aggregate_samples'] += bs
- prev_pose = torch.zeros_like(prev_pose, device=cfg.device)
+ prev_pose = torch.zeros_like(
+ prev_pose,
+ device=cfg.device if not is_om_model else 'cpu'
+ )
prev_pose[:, :-3] = get_pose_representation(aggregated_pose[:, :3, :3], cfg.pose_mode)
prev_pose[:, -3:] = aggregated_pose[:, :3, 3]
-
+
return detect_match, prev_pose
img_list = Dataset.glob_prefix(root = cfg.data_path)
@@ -212,6 +354,13 @@ pbar = tqdm(total=total_objects)
for i in range(30):
add_dataloader()
+# Timing accumulators filled in by work_batch. '*_time' collects one wall-clock
+# duration (seconds) per batch and '*_samples' is a running total of batch sizes.
+# The 'score_*' pair covers the score + energy step, the 'aggregate_*' pair the
+# aggregate + scale step.
+perf_stats = {
+ 'score_time': [],
+ 'score_samples': 0,
+ 'aggregate_time': [],
+ 'aggregate_samples': 0,
+}
+
while 1:
test_batch = []
prev_pose = []
@@ -233,15 +382,24 @@ while 1:
f.write(dataloader.save_path + '\n')
dd.add(dataloader)
continue
- test_batch.append(batch)
length = dataloader._dataset.num_valid
+ if batch.get('_corrupted', torch.tensor(False)).any():
+ print(f"[SKIP] Corrupted frame in {dataloader.save_path}")
+ pbar.update(length)
+ continue
+ test_batch.append(batch)
try:
prev_pose.append(dataloader.prev_pose)
except:
- pose = torch.zeros(length, get_pose_dim(cfg.pose_mode), device=cfg.device) # on gpu
+ pose = torch.zeros(
+ length, get_pose_dim(cfg.pose_mode),
+ device=cfg.device if not is_om_model else 'cpu'
+ )
assert batch['affine'].shape[0] == length, set_trace()
for j in range(length):
- noise_gt_pose = add_noise_to_RT(batch['affine'][j].to(cfg.device).unsqueeze(0))[0]
+ noise_gt_pose = add_noise_to_RT(batch['affine'][j].to(
+ cfg.device if not is_om_model else 'cpu'
+ ).unsqueeze(0))[0]
pose[j, :-3] = get_pose_representation(
noise_gt_pose[:3, :3].unsqueeze(0),
pose_mode=cfg.pose_mode
@@ -252,7 +410,11 @@ while 1:
if split_pos[-1][0] > cfg.batch_size - 8:
break
if test_batch == []:
- break
+ for dl in dd:
+ dataloaders.remove(dl)
+ if len(dataloaders) == 0:
+ break
+ continue
keys = {key for key, value in test_batch[0].items() if type(value) != list}
test_batch = {
@@ -277,6 +439,43 @@ while 1:
pbar.close()
+# Print performance statistics
+# perf_stats holds, per stage, a list of per-batch durations (seconds) and a
+# running sample count; both are reported per stage and then for the pipeline.
+print("\n" + "="*60)
+print("Performance Statistics")
+print("="*60)
+
+stages = [
+    ('Score + Energy', 'score_time', 'score_samples'),
+    ('Aggregate + Scale', 'aggregate_time', 'aggregate_samples'),
+]
+
+for stage_name, time_key, samples_key in stages:
+    times = perf_stats[time_key]
+    samples = perf_stats[samples_key]
+    if len(times) > 0:
+        t_total = sum(times)
+        fps = samples / t_total if t_total > 0 else 0
+        print(f"\n{stage_name}:")
+        print(f" Total batches: {len(times)}")
+        print(f" Total samples: {samples}")
+        print(f" Total time: {t_total:.3f}s")
+        print(f" Avg batch time: {t_total/len(times):.3f}s")
+        print(f" FPS: {fps:.3f}")
+        # Latency is undefined at 0 FPS, so the line is left out instead of
+        # being printed as an empty line.
+        if fps > 0:
+            print(f" Avg latency per sample: {1000/fps:.2f}ms")
+
+total_samples = perf_stats['score_samples']
+total_time = sum(perf_stats['score_time']) + sum(perf_stats['aggregate_time'])
+
+if total_time > 0:
+    overall_fps = total_samples / total_time
+    print(f"\n{'='*60}")
+    print(f"Overall Pipeline:")
+    print(f" Total samples: {total_samples}")
+    print(f" Total time: {total_time:.3f}s")
+    print(f" Overall FPS: {overall_fps:.3f}")
+    print(f" Avg latency per sample: {1000/overall_fps:.2f}ms")
+    print("="*60)
+
all_dm = []
all_crit = []
for path in tqdm(video_paths):
@@ -1,15 +1,19 @@
import os
import sys
+
+# Disable OpenCV GUI for headless environments
+os.environ['QT_QPA_PLATFORM'] = 'offscreen'
import numpy as np
from tqdm import tqdm
import pickle
import torch
+import torch_npu
+torch_npu.npu.set_compile_mode(jit_compile=False)
import random
import gc
import cv2
-import open3d as o3d
-import pyrealsense2 as rs
-import pyrealsense2 as rs
+# import open3d as o3d # Not needed for offline inference
+# import pyrealsense2 as rs # Not needed for offline inference
import numpy as np
import glob
@@ -25,9 +29,9 @@ from utils.so3_visualize import visualize_so3
from cutoop.eval_utils import DetectMatch, Metrics
from configs.config import get_config
from datasets.datasets_infer import InferDataset
-
-from flask import Flask, request
-flask_app = Flask(__name__)
+from runners.evaluation_single import apply_spec_ops_patches
+# from flask import Flask, request
+# flask_app = Flask(__name__)
class GenPose2:
@@ -243,11 +247,13 @@ def visualize_pose(data:InferDataset, all_final_pose, all_final_length, visualiz
all_final_length = all_final_length[0].cpu().numpy()
for index, (obj_pose, obj_length) in enumerate(zip(all_final_pose, all_final_length)):
- if visualize_pts:
- pts = data.get_objects()['pts'].cpu().numpy()[index]
- pcd = o3d.geometry.PointCloud()
- pcd.points = o3d.utility.Vector3dVector(pts)
- o3d.visualization.draw_geometries([pcd])
+ # open3d Not needed for offline inference
+ # if visualize_pts:
+ # pts = data.get_objects()['pts'].cpu().numpy()[index]
+ # pcd = o3d.geometry.PointCloud()
+ # pcd.points = o3d.utility.Vector3dVector(pts)
+ # o3d.visualization.draw_geometries([pcd])
+ # print(f"Object {index}: visualize_pts is not supported (open3d disabled)")
color_img = DetectMatch._draw_image(
vis_img=color_img,
pred_affine=obj_pose,
@@ -264,17 +270,19 @@ def visualize_pose(data:InferDataset, all_final_pose, all_final_length, visualiz
thickness=True,
)
- if visualize_image:
- cv2.namedWindow('rgb')
- cv2.imshow('rgb', color_img)
- cv2.waitKey()
- cv2.destroyAllWindows()
+ # Not needed for offline inference
+ # if visualize_image:
+ # cv2.namedWindow('rgb')
+ # cv2.imshow('rgb', color_img)
+ # cv2.waitKey()
+ # cv2.destroyAllWindows()
return color_img
def main():
######################################## PARAMETERS ########################################
- DATA_PATH = 'data/Omni6DPose/ROPE/000007' # Path to the data
+ DATA_PATH = 'omin6dpose-000a/ROPE/000000/' # Path to the data
+ RESULT_DIR = 'result_images' # Output directory for result images
TRACKING = True # Tracking mode
# Tracking parameter, if the relative pose between the current frame and the previous frame
@@ -286,8 +294,13 @@ def main():
ENERGY_MODEL_PATH='results/ckpts/EnergyNet/energynet.pth' # Path to the energy model
SCALE_MODEL_PATH='results/ckpts/ScaleNet/scalenet.pth' # Path to the scale model
PREV_POSE = None # Previous pose
+ apply_spec_ops_patches()
######################################## PARAMETERS ########################################
+ # Create result directory
+ os.makedirs(RESULT_DIR, exist_ok=True)
+ print(f"Results will be saved to: {os.path.abspath(RESULT_DIR)}")
+
''' load data '''
# Get data from image file
color_images = sorted(glob.glob(DATA_PATH + '/*_color.png'))
@@ -297,20 +310,26 @@ def main():
scale_model_path=SCALE_MODEL_PATH,
)
- cv2.namedWindow('rgb')
+ # cv2.namedWindow('rgb') # Not needed for offline inference
for index, color_image in enumerate(tqdm(color_images)):
data_prefix = color_image.replace('color.png', '')
data = InferDataset.alternetive_init(data_prefix, img_size=GenPose2.cfg.img_size, device=GenPose2.cfg.device, n_pts=GenPose2.cfg.num_points)
pose, length = GenPose2.inference(data, PREV_POSE, TRACKING, TRACKING_T0)
color_image_w_pose = visualize_pose(data, pose, length, visualize_image=False)
+
+ # Save result image to result_images directory
+ image_filename = os.path.basename(color_image) # e.g., "000123_color.png"
+ output_filename = image_filename.replace('color.png', '_result.png') # "000123_result.png"
+ output_path = os.path.join(RESULT_DIR, output_filename)
+ cv2.imwrite(output_path, color_image_w_pose)
+
PREV_POSE = pose
- cv2.imshow('rgb', color_image_w_pose)
- cv2.waitKey(1)
+ # cv2.imshow('rgb', color_image_w_pose) # Not needed for offline inference
+ # cv2.waitKey(1)
- cv2.destroyAllWindows()
+ # cv2.destroyAllWindows() # Not needed for offline inference
if __name__ == '__main__':
main()
-
@@ -3,10 +3,10 @@ CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
--pretrained_score_model_path results/ckpts/ScoreNet/scorenet.pth \
--pretrained_energy_model_path results/ckpts/EnergyNet/energynet.pth \
--pretrained_scale_model_path results/ckpts/ScaleNet/scalenet.pth \
+--data_path omni6dpose-000000/ROPE/ \
--sampler_mode ode \
--percentage_data_for_test 1.0 \
+--batch_size 32 \
--seed 0 \
--result_dir single \
--eval_repeat_num 50 \
@@ -14,4 +14,5 @@ CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
--T0 0.55 \
--dino pointwise \
--num_worker 32 \
\ No newline at end of file
+--real_drop 3 \
+--device npu:0
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,23 @@
+#!/bin/bash
+# OM Model Evaluation Script (single-frame evaluation, --device npu:0)
+# Uses separate PointNet2 OM and ScoreNet OM models for better performance
+CUDA_VISIBLE_DEVICES=0 python runners/evaluation_single.py \
+--pretrained_dino_model_path om_models/dinov2_vits14.om \
+--pretrained_pointnet2_score_model_path om_models/pointnet2_from_score.om \
+--pretrained_pointnet2_energy_model_path om_models/pointnet2_from_energy.om \
+--pretrained_score_model_path om_models/scorenet.om \
+--pretrained_energy_model_path om_models/energynet.om \
+--pretrained_scale_model_path om_models/scalenet.om \
+--data_path omni6dpose-000000/ROPE/ \
+--sampler_mode ode \
+--percentage_data_for_test 1.0 \
+--batch_size 16 \
+--seed 0 \
+--result_dir single_om \
+--eval_repeat_num 50 \
+--clustering 1 \
+--T0 0.55 \
+--dino pointwise \
+--num_workers 32 \
+--real_drop 3 \
+--device npu:0
@@ -3,10 +3,10 @@ CUDA_VISIBLE_DEVICES=0 python runners/evaluation_tracking.py \
--pretrained_score_model_path results/ckpts/ScoreNet/scorenet.pth \
--pretrained_energy_model_path results/ckpts/EnergyNet/energynet.pth \
--pretrained_scale_model_path results/ckpts/ScaleNet/scalenet.pth \
+--data_path omni6dpose-000000/ROPE/ \
--sampler_mode ode \
--percentage_data_for_test 1.0 \
+--batch_size 16 \
--seed 0 \
--result_dir tracking \
--eval_repeat_num 50 \
new file mode 100644
@@ -0,0 +1,21 @@
+#!/bin/bash
+# OM Model Tracking Evaluation Script (runs runners/evaluation_tracking.py with .om models, --device npu:0)
+CUDA_VISIBLE_DEVICES=0 python runners/evaluation_tracking.py \
+--pretrained_dino_model_path om_models/dinov2_vits14.om \
+--pretrained_pointnet2_score_model_path om_models/pointnet2_from_score.om \
+--pretrained_pointnet2_energy_model_path om_models/pointnet2_from_energy.om \
+--pretrained_score_model_path om_models/scorenet.om \
+--pretrained_energy_model_path om_models/energynet.om \
+--pretrained_scale_model_path om_models/scalenet.om \
+--data_path omni6dpose-000000/ROPE/ \
+--sampler_mode ode \
+--percentage_data_for_test 1.0 \
+--batch_size 4 \
+--seed 0 \
+--result_dir tracking_om \
+--eval_repeat_num 50 \
+--clustering 1 \
+--T0 0.25 \
+--dino pointwise \
+--num_workers 32 \
+--device npu:0
@@ -302,6 +302,24 @@ def normalize_rotation(rotation, rotation_mode):
raise NotImplementedError
return rotation
+
+def normalize_rotation_numpy(rotation, rotation_mode):
+    """Gram-Schmidt-normalize a batch of 6D rotations in place (NumPy, for OM inference).
+
+    Columns 0:3 and 3:6 of every row are orthonormalized; `rotation` itself is
+    modified and returned. `rotation_mode` is accepted for signature parity with
+    `normalize_rotation` but is not inspected: only the rot_matrix (6D) layout is handled.
+    """
+    a1 = rotation[:, :3]
+    a2 = rotation[:, 3:6]
+    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
+    # Remove the b1 component from a2, then rescale the remainder to unit length.
+    b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
+    b2 = b2 / np.linalg.norm(b2, axis=-1, keepdims=True)
+    rotation[:, :3] = b1
+    rotation[:, 3:6] = b2
+    return rotation
+
if __name__ == '__main__':
quat = torch.randn(2, 3, 4)
@@ -553,6 +553,13 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
return quaternions[..., 1:] / sin_half_angles_over_angles
+
+def _simple_cross(a, b):
+    """Return the cross product of `a` and `b` over the last axis, built from indexing and arithmetic only."""
+    ax, ay, az = a[..., 0], a[..., 1], a[..., 2]
+    bx, by, bz = b[..., 0], b[..., 1], b[..., 2]
+    return torch.stack([ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx], dim=-1)
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
"""
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
@@ -573,7 +580,7 @@ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
b1 = F.normalize(a1, dim=-1)
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
b2 = F.normalize(b2, dim=-1)
- b3 = torch.cross(b1, b2, dim=-1)
+ b3 = _simple_cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-2)