config/taichi-256.yaml | 2 +-
logger.py | 7 ++--
modules/dense_motion.py | 6 ++--
modules/generator.py | 5 +--
reconstruction.py | 79 ++++++++++++++++++++++++++++++++++-------
5 files changed, 77 insertions(+), 22 deletions(-)
@@ -9,7 +9,7 @@
# video id.
dataset_params:
# Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
- root_dir: data/taichi-png
+ root_dir: data/taichi
# Image shape, needed for staked .png format.
frame_shape: [256, 256, 3]
# In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
@@ -2,6 +2,8 @@ from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
+from point_grid_my import bilinear_grid_sample
+from inverse_without_cat import invmat
class DenseMotionNetwork(nn.Module):
@@ -53,7 +55,7 @@ class DenseMotionNetwork(nn.Module):
identity_grid = identity_grid.view(1, 1, h, w, 2)
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
if 'jacobian' in kp_driving:
- jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
+ jacobian = torch.matmul(kp_source['jacobian'], invmat(kp_driving['jacobian']))
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
@@ -74,7 +76,7 @@ class DenseMotionNetwork(nn.Module):
source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
- sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
+ sparse_deformed = bilinear_grid_sample(source_repeat, sparse_motions)
sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
return sparse_deformed
@@ -3,6 +3,7 @@ from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork
+from point_grid_my import bilinear_grid_sample
class OcclusionAwareGenerator(nn.Module):
@@ -54,7 +55,7 @@ class OcclusionAwareGenerator(nn.Module):
deformation = deformation.permute(0, 3, 1, 2)
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
deformation = deformation.permute(0, 2, 3, 1)
- return F.grid_sample(inp, deformation)
+ return bilinear_grid_sample(inp, deformation, align_corners=True)
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
@@ -80,7 +81,7 @@ class OcclusionAwareGenerator(nn.Module):
if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
- occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear', align_corners=True)
out = out * occlusion_map
output_dict["deformed"] = self.deform_input(source_image, deformation)
--
2.39.0.windows.2