Binary files ./b/GMA/core/__pycache__/corr.cpython-37.pyc and ./a/GMA/core/__pycache__/corr.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/extractor.cpython-37.pyc and ./a/GMA/core/__pycache__/extractor.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/gma.cpython-37.pyc and ./a/GMA/core/__pycache__/gma.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/network.cpython-37.pyc and ./a/GMA/core/__pycache__/network.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/update.cpython-37.pyc and ./a/GMA/core/__pycache__/update.cpython-37.pyc differ
diff -Nur ./b/GMA/core/datasets.py ./a/GMA/core/datasets.py
@@ -59,6 +59,7 @@
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
+ print(flow.shape)
if self.occ_list is not None:
occ = frame_utils.read_gen(self.occ_list[index])
@@ -125,7 +126,7 @@
class MpiSintel(FlowDataset):
- def __init__(self, aug_params=None, split='training', root='/home/zac/data/Sintel', dstype='clean',
+ def __init__(self, aug_params=None, split='training', root='./data/Sintel', dstype='clean',
occlusion=False, segmentation=False):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
diff -Nur ./b/GMA/core/gma.py ./a/GMA/core/gma.py
@@ -58,19 +58,11 @@
q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k))
q = self.scale * q
-
- if self.args.position_only:
- sim = self.pos_emb(q)
-
- elif self.args.position_and_content:
- sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
- sim_pos = self.pos_emb(q)
- sim = sim_content + sim_pos
-
- else:
- sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
-
- sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)')
+ b,h,x,y,d = q.shape
+ q = q.reshape(b,h,x*y,d)
+ b,h,x,y,d = k.shape
+ k = k.reshape(b,h,x*y,d).transpose(2,3)
+ sim = torch.matmul(q, k)
attn = sim.softmax(dim=-1)
return attn
@@ -104,7 +96,7 @@
v = self.to_v(fmap)
v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)
- out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = torch.matmul(attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
if self.project is not None:
diff -Nur ./b/GMA/core/network.py ./a/GMA/core/network.py
@@ -69,7 +69,7 @@
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8 * H, 8 * W)
- def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
+ def forward(self, image1, image2, iters=6, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
@@ -121,9 +121,4 @@
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
- flow_predictions.append(flow_up)
-
- if test_mode:
- return coords1 - coords0, flow_up
-
- return flow_predictions
+ return flow_up
Binary files ./b/GMA/core/utils/__pycache__/__init__.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/__init__.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/augmentor.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/augmentor.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/flow_viz.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/flow_viz.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/frame_utils.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/frame_utils.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/utils.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/utils.cpython-37.pyc differ
diff -Nur ./b/GMA/core/utils/utils.py ./a/GMA/core/utils/utils.py
@@ -2,6 +2,7 @@
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
+from mmcv.ops.point_sample import bilinear_grid_sample
# from torch_scatter import scatter_softmax, scatter_add
@@ -64,7 +65,7 @@
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
- img = F.grid_sample(img, grid, align_corners=True)
+ img = bilinear_grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)