diff -uprN origin/SeMask-Segmentation/SeMask-FPN/configs/_base_/models/semfpn_semask_swin.py SeMask-Segmentation/SeMask-FPN/configs/_base_/models/semfpn_semask_swin.py
@@ -1,5 +1,5 @@
# model settings
-norm_cfg = dict(type='SyncBN', requires_grad=True)
+norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
diff -uprN origin/SeMask-Segmentation/SeMask-FPN/mmseg/models/backbones/semask_swin_transformer.py SeMask-Segmentation/SeMask-FPN/mmseg/models/backbones/semask_swin_transformer.py
@@ -48,7 +48,8 @@ def window_partition(x, window_size):
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ x = x.view(B, torch.div(H, window_size, rounding_mode='floor'), window_size,
+ torch.div(W, window_size, rounding_mode='floor'), window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
@@ -64,9 +65,15 @@ def window_reverse(windows, window_size,
Returns:
x: (B, H, W, C)
"""
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ B_divide = torch.true_divide(H * W, window_size)
+ B_divide = torch.true_divide(B_divide, window_size)
+ B = torch.true_divide(windows.shape[0], B_divide)
+ B = B.type(torch.int64)
+ x = torch.reshape(windows, (B, torch.div(H, window_size, rounding_mode='floor'),
+ torch.div(W, window_size, rounding_mode='floor'),
+ window_size, window_size, -1))
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ x = torch.reshape(x, (B, H, W, -1))
return x
@@ -126,7 +133,7 @@ class WindowAttention(nn.Module):
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, torch.div(C, self.num_heads, rounding_mode='floor')).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
@@ -139,7 +146,7 @@ class WindowAttention(nn.Module):
if mask is not None:
nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(torch.div(B_, nW, rounding_mode='floor'), nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
@@ -205,7 +212,7 @@ class SwinTransformerBlock(nn.Module):
"""
B, L, C = x.shape
H, W = self.H, self.W
- assert L == H * W, "input feature has wrong size"
+ # assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
@@ -243,8 +250,7 @@ class SwinTransformerBlock(nn.Module):
else:
x = shifted_x
- if pad_r > 0 or pad_b > 0:
- x = x[:, :H, :W, :].contiguous()
+ x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
@@ -394,9 +400,8 @@ class SWSeMaskBlock(nn.Module):
x = shifted_x
sem_map = shifted_sem
- if pad_r > 0 or pad_b > 0:
- x = x[:, :H, :W, :].contiguous()
- sem_map = sem_map[:, :H, :W, :].contiguous()
+ x = x[:, :H, :W, :].contiguous()
+ sem_map = sem_map[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
sem_map = sem_map.view(B, H * W, K)
@@ -478,15 +483,23 @@ class PatchMerging(nn.Module):
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
+ # assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
- pad_input = (H % 2 == 1) or (W % 2 == 1)
- if pad_input:
- x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
-
+ H_pad = H % 2
+ W_pad = W % 2
+ pad_input = abs(H_pad * W_pad - abs(H_pad - W_pad))
+ zero = torch.zeros(1, dtype=torch.int64)
+ pad_input = torch.add(zero, pad_input)
+ one = F.one_hot(pad_input, num_classes=2)
+ one = torch.reshape(one, (2, 1, 1, 1, 1))
+ all_x = torch.stack([x, F.pad(x, (0, 0, 0, W % 2, 0, H % 2))])
+ x = torch.mul(one, all_x)
+ x_1 = x[0, :, :, :, :]
+ x_2 = x[1, :, :, :, :]
+ x = x_1 + x_2
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
@@ -595,8 +608,10 @@ class BasicLayer(nn.Module):
"""
# calculate attention mask for SW-MSA
- Hp = int(np.ceil(H / self.window_size)) * self.window_size
- Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ Hp = np.ceil(torch.true_divide(H, self.window_size)) * self.window_size
+ Wp = np.ceil(torch.true_divide(W, self.window_size)) * self.window_size
+ Hp = Hp.type(torch.int32)
+ Wp = Wp.type(torch.int32)
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
@@ -630,7 +645,7 @@ class BasicLayer(nn.Module):
if self.downsample is not None:
x_down = self.downsample(x, H, W)
- Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ Wh, Ww = torch.div(H + 1, 2, rounding_mode="floor"), torch.div(W + 1, 2, rounding_mode="floor")
return x, seg_map, H, W, x_down, Wh, Ww
else:
return x, seg_map, H, W, x, H, W
@@ -662,13 +677,6 @@ class PatchEmbed(nn.Module):
def forward(self, x):
"""Forward function."""
- # padding
- _, _, H, W = x.size()
- if W % self.patch_size[1] != 0:
- x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
- if H % self.patch_size[0] != 0:
- x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
-
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
diff -uprN origin/SeMask-Segmentation/SeMask-FPN/mmseg/models/segmentors/encoder_decoder.py SeMask-Segmentation/SeMask-FPN/mmseg/models/segmentors/encoder_decoder.py
@@ -272,8 +272,6 @@ class EncoderDecoder(BaseSegmentor):
seg_logit = self.inference(img, img_meta, rescale)
seg_pred = seg_logit.argmax(dim=1)
if torch.onnx.is_in_onnx_export():
- # our inference backend only support 4D output
- seg_pred = seg_pred.unsqueeze(0)
return seg_pred
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim
diff -uprN origin/SeMask-Segmentation/SeMask-FPN/mmseg/ops/wrappers.py SeMask-Segmentation/SeMask-FPN/mmseg/ops/wrappers.py
@@ -24,8 +24,7 @@ def resize(input,
'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
- if isinstance(size, torch.Size):
- size = tuple(int(x) for x in size)
+
return F.interpolate(input, size, scale_factor, mode, align_corners)