05360171创建于 2022年3月18日历史提交
diff --git a/common/camera.py b/common/camera.py
index 4a835bc..0669d2d 100644
--- a/common/camera.py
+++ b/common/camera.py
@@ -87,4 +87,54 @@ def project_to_2d_linear(X, camera_params):
     
     XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
     
-    return f*XX + c
\ No newline at end of file
+    return f*XX + c
+
+def prepare_data(dataset, keypoints):
+    print('Preparing data...')
+    for subject in dataset.subjects():
+        for action in dataset[subject].keys():
+            anim = dataset[subject][action]
+            
+            if 'positions' in anim:
+                positions_3d = []
+                for cam in anim['cameras']:
+                    pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation'])
+                    pos_3d[:, 1:] -= pos_3d[:, :1] # Remove global offset, but keep trajectory in first position
+                    positions_3d.append(pos_3d)
+                anim['positions_3d'] = positions_3d
+
+    print('Loading 2D detections...')
+    keypoints_metadata = keypoints['metadata'].item()
+    keypoints_symmetry = keypoints_metadata['keypoints_symmetry']
+    kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
+    joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())
+    keypoints = keypoints['positions_2d'].item()
+
+    for subject in dataset.subjects():
+        assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject)
+        for action in dataset[subject].keys():
+            assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject)
+            if 'positions_3d' not in dataset[subject][action]:
+                continue
+                
+            for cam_idx in range(len(keypoints[subject][action])):
+                
+                # We check for >= instead of == because some videos in H3.6M contain extra frames
+                mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0]
+                assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length
+                
+                if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
+                    # Shorten sequence
+                    keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length]
+
+            assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d'])
+            
+    for subject in keypoints.keys():
+        for action in keypoints[subject]:
+            for cam_idx, kps in enumerate(keypoints[subject][action]):
+                # Normalize camera frame
+                cam = dataset.cameras()[subject][cam_idx]
+                kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
+                keypoints[subject][action][cam_idx] = kps
+
+    return dataset, keypoints, kps_left, kps_right, joints_left, joints_right
\ No newline at end of file
diff --git a/common/generators.py b/common/generators.py
index 7fdcd3e..649cf14 100644
--- a/common/generators.py
+++ b/common/generators.py
@@ -200,6 +200,7 @@ class UnchunkedGenerator:
         self.cameras = [] if cameras is None else cameras
         self.poses_3d = [] if poses_3d is None else poses_3d
         self.poses_2d = poses_2d
+        self.max_length = 6115
         
     def num_frames(self):
         count = 0
@@ -217,8 +218,10 @@ class UnchunkedGenerator:
         for seq_cam, seq_3d, seq_2d in zip_longest(self.cameras, self.poses_3d, self.poses_2d):
             batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0)
             batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0)
+            delta = self.max_length - seq_2d.shape[0] - 2 * self.pad
+            assert delta >= 0, f"self.max_length:{self.max_length}, seq_2d.shape[0]:{seq_2d.shape[0]}, seq_3d.shape[0]:{seq_3d.shape[0]}, self.pad:{self.pad}, delta:{delta}"
             batch_2d = np.expand_dims(np.pad(seq_2d,
-                            ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
+                            ((self.pad + self.causal_shift, self.pad - self.causal_shift + delta), (0, 0), (0, 0)),
                             'edge'), axis=0)
             if self.augment:
                 # Append flipped version
@@ -236,4 +239,4 @@ class UnchunkedGenerator:
                 batch_2d[1, :, :, 0] *= -1
                 batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left]
 
-            yield batch_cam, batch_3d, batch_2d
\ No newline at end of file
+            yield batch_cam, batch_3d, batch_2d, delta
\ No newline at end of file
diff --git a/common/model.py b/common/model.py
index abbe969..36939b6 100644
--- a/common/model.py
+++ b/common/model.py
@@ -29,8 +29,8 @@ class TemporalModelBase(nn.Module):
         self.relu = nn.ReLU(inplace=True)
         
         self.pad = [ filter_widths[0] // 2 ]
-        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)
-        self.shrink = nn.Conv1d(channels, num_joints_out*3, 1)
+        self.expand_bn = nn.BatchNorm2d(channels, momentum=0.1)
+        self.shrink = nn.Conv2d(channels, num_joints_out*3, 1)
         
 
     def set_bn_momentum(self, momentum):
@@ -68,8 +68,12 @@ class TemporalModelBase(nn.Module):
         sz = x.shape[:3]
         x = x.view(x.shape[0], x.shape[1], -1)
         x = x.permute(0, 2, 1)
+
+        x.unsqueeze_(-1)
         
         x = self._forward_blocks(x)
+
+        x.squeeze_(-1)
         
         x = x.permute(0, 2, 1)
         x = x.view(sz[0], -1, self.num_joints_out, 3)
@@ -99,7 +103,7 @@ class TemporalModel(TemporalModelBase):
         """
         super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)
         
-        self.expand_conv = nn.Conv1d(num_joints_in*in_features, channels, filter_widths[0], bias=False)
+        self.expand_conv = nn.Conv2d(num_joints_in*in_features, channels, (filter_widths[0],1), bias=False)
         
         layers_conv = []
         layers_bn = []
@@ -110,13 +114,13 @@ class TemporalModel(TemporalModelBase):
             self.pad.append((filter_widths[i] - 1)*next_dilation // 2)
             self.causal_shift.append((filter_widths[i]//2 * next_dilation) if causal else 0)
             
-            layers_conv.append(nn.Conv1d(channels, channels,
-                                         filter_widths[i] if not dense else (2*self.pad[-1] + 1),
+            layers_conv.append(nn.Conv2d(channels, channels,
+                                         (filter_widths[i],1) if not dense else ((2*self.pad[-1] + 1),1),
                                          dilation=next_dilation if not dense else 1,
                                          bias=False))
-            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
-            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
-            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
+            layers_bn.append(nn.BatchNorm2d(channels, momentum=0.1))
+            layers_conv.append(nn.Conv2d(channels, channels, 1, dilation=1, bias=False))
+            layers_bn.append(nn.BatchNorm2d(channels, momentum=0.1))
             
             next_dilation *= filter_widths[i]
             
@@ -164,7 +168,7 @@ class TemporalModelOptimized1f(TemporalModelBase):
         """
         super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)
         
-        self.expand_conv = nn.Conv1d(num_joints_in*in_features, channels, filter_widths[0], stride=filter_widths[0], bias=False)
+        self.expand_conv = nn.Conv2d(num_joints_in*in_features, channels, (filter_widths[0],1), stride=filter_widths[0], bias=False)
         
         layers_conv = []
         layers_bn = []
@@ -175,10 +179,10 @@ class TemporalModelOptimized1f(TemporalModelBase):
             self.pad.append((filter_widths[i] - 1)*next_dilation // 2)
             self.causal_shift.append((filter_widths[i]//2) if causal else 0)
             
-            layers_conv.append(nn.Conv1d(channels, channels, filter_widths[i], stride=filter_widths[i], bias=False))
-            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
-            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
-            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
+            layers_conv.append(nn.Conv2d(channels, channels, (filter_widths[i],1), stride=filter_widths[i], bias=False))
+            layers_bn.append(nn.BatchNorm2d(channels, momentum=0.1))
+            layers_conv.append(nn.Conv2d(channels, channels, 1, dilation=1, bias=False))
+            layers_bn.append(nn.BatchNorm2d(channels, momentum=0.1))
             next_dilation *= filter_widths[i]
             
         self.layers_conv = nn.ModuleList(layers_conv)
diff --git a/common/utils.py b/common/utils.py
index b410826..b429097 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -44,4 +44,30 @@ def wrap(func, *args, unsqueeze=False):
 def deterministic_random(min_value, max_value, data):
     digest = hashlib.sha256(data.encode()).digest()
     raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False)
-    return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value
\ No newline at end of file
+    return int(raw_value / (2**32 - 1) * (max_value - min_value)) + min_value
+
+
+def fetch_actions(args, actions, keypoints, dataset):
+    out_poses_3d = []
+    out_poses_2d = []
+
+    for subject, action in actions:
+        poses_2d = keypoints[subject][action]
+        for i in range(len(poses_2d)): # Iterate across cameras
+            out_poses_2d.append(poses_2d[i])
+
+        poses_3d = dataset[subject][action]['positions_3d']
+        assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
+        for i in range(len(poses_3d)): # Iterate across cameras
+            out_poses_3d.append(poses_3d[i])
+
+    stride = args.downsample
+    if stride > 1:
+        # Downsample as requested
+        for i in range(len(out_poses_2d)):
+            out_poses_2d[i] = out_poses_2d[i][::stride]
+            if out_poses_3d is not None:
+                out_poses_3d[i] = out_poses_3d[i][::stride]
+        
+    return out_poses_3d, out_poses_2d
+
diff --git a/run.py b/run.py
index 2dca16c..5da3acc 100644
--- a/run.py
+++ b/run.py
@@ -22,11 +22,22 @@ from common.model import *
 from common.loss import *
 from common.generators import ChunkedGenerator, UnchunkedGenerator
 from time import time
-from common.utils import deterministic_random
+from common.utils import deterministic_random, fetch_actions
+from collections import OrderedDict
 
 args = parse_args()
 print(args)
 
+def proc_nodes_module(checkpoint):
+    new_state_dict = OrderedDict()
+    for k, v in checkpoint.items():
+        if "module." in k:
+            name = k.replace("module.", "")
+        else:
+            name = k
+        new_state_dict[name] = v
+    return new_state_dict
+
 try:
     # Create checkpoint directory if it does not exist
     os.makedirs(args.checkpoint)
@@ -48,53 +59,56 @@ elif args.dataset.startswith('custom'):
 else:
     raise KeyError('Invalid dataset')
 
-print('Preparing data...')
-for subject in dataset.subjects():
-    for action in dataset[subject].keys():
-        anim = dataset[subject][action]
-        
-        if 'positions' in anim:
-            positions_3d = []
-            for cam in anim['cameras']:
-                pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation'])
-                pos_3d[:, 1:] -= pos_3d[:, :1] # Remove global offset, but keep trajectory in first position
-                positions_3d.append(pos_3d)
-            anim['positions_3d'] = positions_3d
-
-print('Loading 2D detections...')
 keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz', allow_pickle=True)
-keypoints_metadata = keypoints['metadata'].item()
-keypoints_symmetry = keypoints_metadata['keypoints_symmetry']
-kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
-joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())
-keypoints = keypoints['positions_2d'].item()
-
-for subject in dataset.subjects():
-    assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject)
-    for action in dataset[subject].keys():
-        assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject)
-        if 'positions_3d' not in dataset[subject][action]:
-            continue
+
+dataset, keypoints = prepare_data(dataset, keypoints)
+
+# print('Preparing data...')
+# for subject in dataset.subjects():
+#     for action in dataset[subject].keys():
+#         anim = dataset[subject][action]
+        
+#         if 'positions' in anim:
+#             positions_3d = []
+#             for cam in anim['cameras']:
+#                 pos_3d = world_to_camera(anim['positions'], R=cam['orientation'], t=cam['translation'])
+#                 pos_3d[:, 1:] -= pos_3d[:, :1] # Remove global offset, but keep trajectory in first position
+#                 positions_3d.append(pos_3d)
+#             anim['positions_3d'] = positions_3d
+
+# print('Loading 2D detections...')
+# keypoints_metadata = keypoints['metadata'].item()
+# keypoints_symmetry = keypoints_metadata['keypoints_symmetry']
+# kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
+# joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())
+# keypoints = keypoints['positions_2d'].item()
+
+# for subject in dataset.subjects():
+#     assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject)
+#     for action in dataset[subject].keys():
+#         assert action in keypoints[subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject)
+#         if 'positions_3d' not in dataset[subject][action]:
+#             continue
             
-        for cam_idx in range(len(keypoints[subject][action])):
+#         for cam_idx in range(len(keypoints[subject][action])):
             
-            # We check for >= instead of == because some videos in H3.6M contain extra frames
-            mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0]
-            assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length
+#             # We check for >= instead of == because some videos in H3.6M contain extra frames
+#             mocap_length = dataset[subject][action]['positions_3d'][cam_idx].shape[0]
+#             assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length
             
-            if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
-                # Shorten sequence
-                keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length]
+#             if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
+#                 # Shorten sequence
+#                 keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length]
 
-        assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d'])
+#         assert len(keypoints[subject][action]) == len(dataset[subject][action]['positions_3d'])
         
-for subject in keypoints.keys():
-    for action in keypoints[subject]:
-        for cam_idx, kps in enumerate(keypoints[subject][action]):
-            # Normalize camera frame
-            cam = dataset.cameras()[subject][cam_idx]
-            kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
-            keypoints[subject][action][cam_idx] = kps
+# for subject in keypoints.keys():
+#     for action in keypoints[subject]:
+#         for cam_idx, kps in enumerate(keypoints[subject][action]):
+#             # Normalize camera frame
+#             cam = dataset.cameras()[subject][cam_idx]
+#             kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
+#             keypoints[subject][action][cam_idx] = kps
 
 subjects_train = args.subjects_train.split(',')
 subjects_semi = [] if not args.subjects_unlabeled else args.subjects_unlabeled.split(',')
@@ -205,9 +219,11 @@ if args.resume or args.evaluate:
     chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
     print('Loading checkpoint', chk_filename)
     checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
-    print('This model was trained for {} epochs'.format(checkpoint['epoch']))
-    model_pos_train.load_state_dict(checkpoint['model_pos'])
-    model_pos.load_state_dict(checkpoint['model_pos'])
+    if 'epoch' in checkpoint:
+        print('This model was trained for {} epochs'.format(checkpoint['epoch']))
+    state_dict= proc_nodes_module(checkpoint['model_pos'])
+    model_pos_train.load_state_dict(state_dict)
+    model_pos.load_state_dict(state_dict)
     
     if args.evaluate and 'model_traj' in checkpoint:
         # Load trajectory model if it contained in the checkpoint (e.g. for inference in the wild)
@@ -798,30 +814,6 @@ else:
             all_actions[action_name].append((subject, action))
             all_actions_by_subject[subject][action_name].append((subject, action))
 
-    def fetch_actions(actions):
-        out_poses_3d = []
-        out_poses_2d = []
-
-        for subject, action in actions:
-            poses_2d = keypoints[subject][action]
-            for i in range(len(poses_2d)): # Iterate across cameras
-                out_poses_2d.append(poses_2d[i])
-
-            poses_3d = dataset[subject][action]['positions_3d']
-            assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
-            for i in range(len(poses_3d)): # Iterate across cameras
-                out_poses_3d.append(poses_3d[i])
-
-        stride = args.downsample
-        if stride > 1:
-            # Downsample as requested
-            for i in range(len(out_poses_2d)):
-                out_poses_2d[i] = out_poses_2d[i][::stride]
-                if out_poses_3d is not None:
-                    out_poses_3d[i] = out_poses_3d[i][::stride]
-        
-        return out_poses_3d, out_poses_2d
-
     def run_evaluation(actions, action_filter=None):
         errors_p1 = []
         errors_p2 = []
@@ -838,7 +830,7 @@ else:
                 if not found:
                     continue
 
-            poses_act, poses_2d_act = fetch_actions(actions[action_key])
+            poses_act, poses_2d_act = fetch_actions(args, actions[action_key], keypoints, dataset)
             gen = UnchunkedGenerator(None, poses_act, poses_2d_act,
                                      pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
                                      kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)