from itertools import zip_longest
import numpy as np
import torch.distributed as dist
import math
class ChunkedGenerator:
    """
    Batched data generator, used for training.
    The sequences are split into equal-length chunks and padded as necessary.
    Every replica (rank) iterates over its own interleaved shard of the chunks.

    Arguments:
    args -- namespace providing `sampler` (must be truthy; only the distributed sampler is
            implemented), `drop_last` (skip a trailing batch smaller than batch_size),
            `rank` (index of this replica) and `num_gpus` (number of replicas, non-zero)
    batch_size -- the batch size to use for training
    cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
    poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
    poses_2d -- list of input 2D keypoints, one element for each video
    chunk_length -- number of output frames to predict for each training example (usually 1)
    pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
    causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
    shuffle -- randomly shuffle the dataset before each epoch
    random_seed -- initial seed to use for the random generator (the epoch number is added to it)
    augment -- augment the dataset by flipping poses horizontally
    kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
    joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
    endless -- keep iterating epoch after epoch and remember the position inside the epoch

    Raises:
    NotImplementedError -- if `args.sampler` is falsy
    """

    def __init__(self, args, batch_size, cameras, poses_3d, poses_2d,
                 chunk_length, pad=0, causal_shift=0,
                 shuffle=True, random_seed=1234,
                 augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None,
                 endless=False):
        assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d))
        assert cameras is None or len(cameras) == len(poses_2d)

        # Build the list of chunks: (sequence index, start frame, end frame, flip flag).
        pairs = []
        for i, seq_2d in enumerate(poses_2d):
            n_frames = seq_2d.shape[0]
            assert poses_3d is None or poses_3d[i].shape[0] == n_frames
            n_chunks = (n_frames + chunk_length - 1) // chunk_length
            # Centre the chunks on the sequence; the overhang is edge-padded later.
            offset = (n_chunks * chunk_length - n_frames) // 2
            bounds = np.arange(n_chunks + 1) * chunk_length - offset
            n_pairs = len(bounds) - 1
            seq_index = np.repeat(i, n_pairs)
            augment_vector = np.full(n_pairs, False, dtype=bool)
            pairs += zip(seq_index, bounds[:-1], bounds[1:], augment_vector)
            if augment:
                pairs += zip(seq_index, bounds[:-1], bounds[1:], ~augment_vector)

        # Pre-allocated batch buffers, reused (overwritten) for every yielded batch.
        if cameras is not None:
            self.batch_cam = np.empty((batch_size, cameras[0].shape[-1]))
        if poses_3d is not None:
            self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[0].shape[-2], poses_3d[0].shape[-1]))
        self.batch_2d = np.empty((batch_size, chunk_length + 2 * pad, poses_2d[0].shape[-2], poses_2d[0].shape[-1]))

        # NOTE: num_batches counts batches over ALL chunks, not per replica.
        self.num_batches = (len(pairs) + batch_size - 1) // batch_size
        self.batch_size = batch_size
        print(f"random seed is set to {random_seed}")
        self.seed = random_seed
        self.epoch = 0
        self.pairs = pairs
        self.shuffle = shuffle
        self.pad = pad
        self.causal_shift = causal_shift
        self.endless = endless
        self.state = None

        self.sampler = args.sampler
        self.drop_last = args.drop_last
        self.rank = args.rank
        self.num_replicas = args.num_gpus
        assert self.num_replicas != 0
        if args.sampler:
            # Number of chunks each replica processes per epoch.
            self.num_samples = math.ceil(len(self.pairs) / self.num_replicas)
        else:
            raise NotImplementedError("only support DDP sampler")
        # Chunk count after padding so that it divides evenly among the replicas.
        self.total_size = self.num_samples * self.num_replicas

        self.cameras = cameras
        self.poses_3d = poses_3d
        self.poses_2d = poses_2d

        self.augment = augment
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right

    def num_frames(self):
        """Return num_batches * batch_size (computed over all chunks, not per replica)."""
        return self.num_batches * self.batch_size

    def set_epoch(self, epoch):
        """Set the epoch number that is added to the seed when shuffling."""
        self.epoch = epoch

    def augment_enabled(self):
        """Return whether flip augmentation is enabled."""
        return self.augment

    def next_pairs(self):
        """
        Return (start batch index, un-sharded chunk list) for the next pass.

        A saved state (endless mode) is returned as-is; otherwise a fresh pass starts at
        batch 0, shuffled with seed + epoch when shuffling is enabled. Note that the
        shuffled chunks come back as a NumPy array, the unshuffled ones as a list.
        """
        if self.state is not None:
            return self.state
        if self.shuffle:
            print(f"self.epoch:{self.epoch} for rank:{self.rank}")
            random = np.random.RandomState(self.seed + self.epoch)
            pairs = random.permutation(self.pairs)
        else:
            pairs = self.pairs
        return 0, pairs

    def _shard_pairs(self, pairs):
        """Pad the chunks up to total_size by wrapping around, then pick this rank's share."""
        sample_padding = self.total_size - len(self.pairs)
        assert sample_padding < len(self.pairs)
        if sample_padding > 0:
            if isinstance(pairs, np.ndarray):
                # `+` on arrays would add element-wise instead of appending rows.
                pairs = np.concatenate((pairs, pairs[:sample_padding]), axis=0)
            else:
                pairs = pairs + pairs[:sample_padding]
        assert self.total_size == len(pairs), f"total_size:{self.total_size}, len(pairs):{len(pairs)}, sample_padding:{sample_padding}"
        return pairs[self.rank:self.total_size:self.num_replicas]

    @staticmethod
    def _padded_window(seq, start, end):
        """Return seq[start:end], edge-padded wherever the window leaves the sequence."""
        low = max(start, 0)
        high = min(end, seq.shape[0])
        pad_left = low - start
        pad_right = end - high
        window = seq[low:high]
        if pad_left != 0 or pad_right != 0:
            window = np.pad(window, ((pad_left, pad_right), (0, 0), (0, 0)), 'edge')
        return window

    def next_epoch(self):
        """
        Generate the batches of one epoch (or of all epochs when `endless` is set).

        Yields (batch_cam, batch_3d, batch_2d); batch_cam / batch_3d are None when no
        cameras / 3D poses were given. The yielded arrays are views into buffers that
        are overwritten by the next batch, so copy them if they must outlive the step.
        """
        enabled = True
        while enabled:
            start_idx, full_pairs = self.next_pairs()
            assert len(full_pairs) == len(self.pairs)
            pairs = self._shard_pairs(full_pairs) if self.sampler else full_pairs

            # Iterate over this replica's batches only; using the global num_batches
            # here would produce trailing empty batches on every rank.
            local_batches = (len(pairs) + self.batch_size - 1) // self.batch_size
            for b_i in range(start_idx, local_batches):
                chunks = pairs[b_i * self.batch_size:(b_i + 1) * self.batch_size]
                if self.drop_last and len(chunks) < self.batch_size:
                    continue

                for i, (seq_i, start_3d, end_3d, flip) in enumerate(chunks):
                    # 2D input window: the output window widened by the receptive field.
                    start_2d = start_3d - self.pad - self.causal_shift
                    end_2d = end_3d + self.pad - self.causal_shift
                    self.batch_2d[i] = self._padded_window(self.poses_2d[seq_i], start_2d, end_2d)
                    if flip:
                        # Mirror the first coordinate and swap left/right keypoints.
                        self.batch_2d[i, :, :, 0] *= -1
                        self.batch_2d[i, :, self.kps_left + self.kps_right] = self.batch_2d[i, :, self.kps_right + self.kps_left]

                    if self.poses_3d is not None:
                        self.batch_3d[i] = self._padded_window(self.poses_3d[seq_i], start_3d, end_3d)
                        if flip:
                            self.batch_3d[i, :, :, 0] *= -1
                            self.batch_3d[i, :, self.joints_left + self.joints_right] = \
                                self.batch_3d[i, :, self.joints_right + self.joints_left]

                    if self.cameras is not None:
                        self.batch_cam[i] = self.cameras[seq_i]
                        if flip:
                            # NOTE(review): entries 2 and 7 are presumably the horizontal
                            # camera parameters -- confirm against the camera layout.
                            self.batch_cam[i, 2] *= -1
                            self.batch_cam[i, 7] *= -1

                if self.endless:
                    # Save the un-sharded chunks so that a resumed pass can be
                    # sharded again in exactly the same way.
                    self.state = (b_i + 1, full_pairs)

                n_items = len(chunks)
                out_cam = None if self.cameras is None else self.batch_cam[:n_items]
                out_3d = None if self.poses_3d is None else self.batch_3d[:n_items]
                yield out_cam, out_3d, self.batch_2d[:n_items]

            if self.endless:
                self.state = None
            else:
                enabled = False
class UnchunkedGenerator:
    """
    Non-batched data generator, used for testing.
    Sequences are returned one at a time (i.e. batch size = 1), without chunking.
    Every replica (rank) iterates over its own interleaved shard of the sequences.

    If data augmentation is enabled, the batches contain two sequences (i.e. batch size = 2),
    the second of which is a mirrored version of the first.

    Arguments:
    args -- namespace providing `rank` (index of this replica) and `num_gpus` (number of replicas, non-zero)
    cameras -- list of cameras, one element for each video (optional, used for semi-supervised training)
    poses_3d -- list of ground-truth 3D poses, one element for each video (optional, used for supervised training)
    poses_2d -- list of input 2D keypoints, one element for each video
    pad -- 2D input padding to compensate for valid convolutions, per side (depends on the receptive field)
    causal_shift -- asymmetric padding offset when causal convolutions are used (usually 0 or "pad")
    augment -- augment the dataset by flipping poses horizontally
    kps_left and kps_right -- list of left/right 2D keypoints if flipping is enabled
    joints_left and joints_right -- list of left/right 3D joints if flipping is enabled
    """

    def __init__(self, args, cameras, poses_3d, poses_2d, pad=0, causal_shift=0,
                 augment=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None):
        assert poses_3d is None or len(poses_3d) == len(poses_2d)
        assert cameras is None or len(cameras) == len(poses_2d)

        self.augment = augment
        self.kps_left = kps_left
        self.kps_right = kps_right
        self.joints_left = joints_left
        self.joints_right = joints_right

        self.pad = pad
        self.causal_shift = causal_shift
        # Missing inputs are stored as empty lists so they can be zipped uniformly.
        self.cameras = [] if cameras is None else cameras
        self.poses_3d = [] if poses_3d is None else poses_3d
        self.poses_2d = poses_2d

        self.rank = args.rank
        self.num_replicas = args.num_gpus
        assert self.num_replicas != 0
        # Sequences per replica, and the padded total that divides evenly among replicas.
        self.num_samples = math.ceil(len(poses_2d) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def num_frames(self):
        """Return the total number of frames over all 2D sequences (not per replica)."""
        return sum(p.shape[0] for p in self.poses_2d)

    def augment_enabled(self):
        """Return whether flip augmentation is enabled."""
        return self.augment

    def set_augment(self, augment):
        """Enable or disable flip augmentation."""
        self.augment = augment

    def _shard(self, items, sample_padding):
        """Pad a non-empty list by wrapping around, then pick this rank's share."""
        items = list(items)
        # An empty list stands for "not provided" and must stay empty.
        if items and sample_padding > 0:
            items = items + items[:sample_padding]
        return items[self.rank:self.total_size:self.num_replicas]

    def next_epoch(self):
        """
        Generate this replica's sequences, one batch per sequence.

        Yields (batch_cam, batch_3d, batch_2d), each with a leading batch axis of size 1
        (2 when augmentation is enabled); batch_cam / batch_3d are None when no cameras /
        3D poses were given. When the sequence count is not divisible by the number of
        replicas, the first sequences are repeated so every replica yields equally often.
        """
        sample_padding = self.total_size - len(self.poses_2d)
        assert sample_padding < len(self.poses_2d)

        # Cameras, 3D poses and 2D poses are padded identically so they stay aligned.
        sub_cameras = self._shard(self.cameras, sample_padding)
        sub_poses_3d = self._shard(self.poses_3d, sample_padding)
        sub_poses_2d = self._shard(self.poses_2d, sample_padding)

        for seq_cam, seq_3d, seq_2d in zip_longest(sub_cameras, sub_poses_3d, sub_poses_2d):
            batch_cam = None if seq_cam is None else np.expand_dims(seq_cam, axis=0)
            batch_3d = None if seq_3d is None else np.expand_dims(seq_3d, axis=0)
            batch_2d = np.expand_dims(np.pad(seq_2d,
                                             ((self.pad + self.causal_shift, self.pad - self.causal_shift), (0, 0), (0, 0)),
                                             'edge'), axis=0)
            if self.augment:
                # Append a mirrored copy; np.concatenate copies, so the stored data is untouched.
                if batch_cam is not None:
                    batch_cam = np.concatenate((batch_cam, batch_cam), axis=0)
                    # NOTE(review): entries 2 and 7 are presumably the horizontal
                    # camera parameters -- confirm against the camera layout.
                    batch_cam[1, 2] *= -1
                    batch_cam[1, 7] *= -1

                if batch_3d is not None:
                    batch_3d = np.concatenate((batch_3d, batch_3d), axis=0)
                    batch_3d[1, :, :, 0] *= -1
                    batch_3d[1, :, self.joints_left + self.joints_right] = batch_3d[1, :, self.joints_right + self.joints_left]

                batch_2d = np.concatenate((batch_2d, batch_2d), axis=0)
                batch_2d[1, :, :, 0] *= -1
                batch_2d[1, :, self.kps_left + self.kps_right] = batch_2d[1, :, self.kps_right + self.kps_left]

            yield batch_cam, batch_3d, batch_2d