Binary files ./b/GMA/core/__pycache__/corr.cpython-37.pyc and ./a/GMA/core/__pycache__/corr.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/extractor.cpython-37.pyc and ./a/GMA/core/__pycache__/extractor.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/gma.cpython-37.pyc and ./a/GMA/core/__pycache__/gma.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/network.cpython-37.pyc and ./a/GMA/core/__pycache__/network.cpython-37.pyc differ
Binary files ./b/GMA/core/__pycache__/update.cpython-37.pyc and ./a/GMA/core/__pycache__/update.cpython-37.pyc differ
diff -Nur ./b/GMA/core/datasets.py ./a/GMA/core/datasets.py
--- ./b/GMA/core/datasets.py	2022-11-15 08:24:39.806162468 +0000
+++ ./a/GMA/core/datasets.py	2022-11-15 06:30:06.850078261 +0000
@@ -59,6 +59,7 @@
             flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
         else:
             flow = frame_utils.read_gen(self.flow_list[index])
+            print(flow.shape)
 
         if self.occ_list is not None:
             occ = frame_utils.read_gen(self.occ_list[index])
@@ -125,7 +126,7 @@
 
 
 class MpiSintel(FlowDataset):
-    def __init__(self, aug_params=None, split='training', root='/home/zac/data/Sintel', dstype='clean',
+    def __init__(self, aug_params=None, split='training', root='./data/Sintel', dstype='clean',
                  occlusion=False, segmentation=False):
         super(MpiSintel, self).__init__(aug_params)
         flow_root = osp.join(root, split, 'flow')
diff -Nur ./b/GMA/core/gma.py ./a/GMA/core/gma.py
--- ./b/GMA/core/gma.py	2022-11-15 08:24:39.806162468 +0000
+++ ./a/GMA/core/gma.py	2022-11-15 01:20:19.705850530 +0000
@@ -58,19 +58,11 @@
 
         q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k))
         q = self.scale * q
-
-        if self.args.position_only:
-            sim = self.pos_emb(q)
-
-        elif self.args.position_and_content:
-            sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
-            sim_pos = self.pos_emb(q)
-            sim = sim_content + sim_pos
-
-        else:
-            sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
-
-        sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)')
+        b,h,x,y,d = q.shape
+        q = q.reshape(b,h,x*y,d)
+        b,h,x,y,d = k.shape
+        k = k.reshape(b,h,x*y,d).transpose(2,3)
+        sim = torch.matmul(q, k)
         attn = sim.softmax(dim=-1)
 
         return attn
@@ -104,7 +96,7 @@
 
         v = self.to_v(fmap)
         v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = torch.matmul(attn, v)
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
 
         if self.project is not None:
diff -Nur ./b/GMA/core/network.py ./a/GMA/core/network.py
--- ./b/GMA/core/network.py	2022-11-15 08:24:39.806162468 +0000
+++ ./a/GMA/core/network.py	2022-11-15 01:18:23.925849112 +0000
@@ -69,7 +69,7 @@
         up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
         return up_flow.reshape(N, 2, 8 * H, 8 * W)
 
-    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
+    def forward(self, image1, image2, iters=6, flow_init=None, upsample=True, test_mode=False):
         """ Estimate optical flow between pair of frames """
 
         image1 = 2 * (image1 / 255.0) - 1.0
@@ -121,9 +121,4 @@
             else:
                 flow_up = self.upsample_flow(coords1 - coords0, up_mask)
 
-            flow_predictions.append(flow_up)
-
-        if test_mode:
-            return coords1 - coords0, flow_up
-
-        return flow_predictions
+        return flow_up
Binary files ./b/GMA/core/utils/__pycache__/__init__.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/__init__.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/augmentor.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/augmentor.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/flow_viz.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/flow_viz.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/frame_utils.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/frame_utils.cpython-37.pyc differ
Binary files ./b/GMA/core/utils/__pycache__/utils.cpython-37.pyc and ./a/GMA/core/utils/__pycache__/utils.cpython-37.pyc differ
diff -Nur ./b/GMA/core/utils/utils.py ./a/GMA/core/utils/utils.py
--- ./b/GMA/core/utils/utils.py	2022-11-15 08:24:39.806162468 +0000
+++ ./a/GMA/core/utils/utils.py	2022-11-15 02:32:49.401903823 +0000
@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 import numpy as np
 from scipy import interpolate
+from mmcv.ops.point_sample import bilinear_grid_sample
 # from torch_scatter import scatter_softmax, scatter_add
 
 
@@ -64,7 +65,7 @@
     ygrid = 2*ygrid/(H-1) - 1
 
     grid = torch.cat([xgrid, ygrid], dim=-1)
-    img = F.grid_sample(img, grid, align_corners=True)
+    img = bilinear_grid_sample(img, grid, align_corners=True)
 
     if mask:
         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)