@@ -100,7 +100,7 @@ class PraNet(nn.Module):
def __init__(self, channel=32):
super(PraNet, self).__init__()
# ---- ResNet Backbone ----
- self.resnet = res2net50_v1b_26w_4s(pretrained=True)
+ self.resnet = res2net50_v1b_26w_4s(pretrained=False)
# ---- Receptive Field Block like module ----
self.rfb2_1 = RFB_modified(512, channel)
self.rfb3_1 = RFB_modified(1024, channel)
@@ -140,10 +140,13 @@ class PraNet(nn.Module):
x4_rfb = self.rfb4_1(x4) # channel -> 32
ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb)
+ ra5_feat = ra5_feat.reshape(ra5_feat.size(1), ra5_feat.size(0), ra5_feat.size(2), ra5_feat.size(3))
lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear') # NOTES: Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
+ lateral_map_5 = lateral_map_5.reshape(lateral_map_5.size(1), lateral_map_5.size(0), lateral_map_5.size(2), lateral_map_5.size(3))
# ---- reverse attention branch_4 ----
crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear')
+ crop_4 = crop_4.reshape(crop_4.size(1), crop_4.size(0), crop_4.size(2), crop_4.size(3))
x = -1*(torch.sigmoid(crop_4)) + 1
x = x.expand(-1, 2048, -1, -1).mul(x4)
x = self.ra4_conv1(x)
@@ -152,10 +155,13 @@ class PraNet(nn.Module):
x = F.relu(self.ra4_conv4(x))
ra4_feat = self.ra4_conv5(x)
x = ra4_feat + crop_4
- lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear') # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)
+ temp = x.reshape(x.size(1), x.size(0), x.size(2), x.size(3))
+ lateral_map_4 = F.interpolate(temp, scale_factor=32, mode='bilinear') # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)
+ lateral_map_4 = lateral_map_4.reshape(lateral_map_4.size(1), lateral_map_4.size(0), lateral_map_4.size(2), lateral_map_4.size(3))
# ---- reverse attention branch_3 ----
- crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear')
+ crop_3 = F.interpolate(temp, scale_factor=2, mode='bilinear')
+ crop_3 = crop_3.reshape(crop_3.size(1), crop_3.size(0), crop_3.size(2), crop_3.size(3))
x = -1*(torch.sigmoid(crop_3)) + 1
x = x.expand(-1, 1024, -1, -1).mul(x3)
x = self.ra3_conv1(x)
@@ -163,10 +169,13 @@ class PraNet(nn.Module):
x = F.relu(self.ra3_conv3(x))
ra3_feat = self.ra3_conv4(x)
x = ra3_feat + crop_3
- lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear') # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)
+ temp1 = x.reshape(x.size(1), x.size(0), x.size(2), x.size(3))
+ lateral_map_3 = F.interpolate(temp1, scale_factor=16, mode='bilinear') # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)
+ lateral_map_3 = lateral_map_3.reshape(lateral_map_3.size(1), lateral_map_3.size(0), lateral_map_3.size(2), lateral_map_3.size(3))
# ---- reverse attention branch_2 ----
- crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear')
+ crop_2 = F.interpolate(temp1, scale_factor=2, mode='bilinear')
+ crop_2 = crop_2.reshape(crop_2.size(1), crop_2.size(0), crop_2.size(2), crop_2.size(3))
x = -1*(torch.sigmoid(crop_2)) + 1
x = x.expand(-1, 512, -1, -1).mul(x2)
x = self.ra2_conv1(x)
@@ -174,7 +183,9 @@ class PraNet(nn.Module):
x = F.relu(self.ra2_conv3(x))
ra2_feat = self.ra2_conv4(x)
x = ra2_feat + crop_2
- lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear') # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
+ temp2 = x.reshape(x.size(1), x.size(0), x.size(2), x.size(3))
+ lateral_map_2 = F.interpolate(temp2, scale_factor=8, mode='bilinear') # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
+ lateral_map_2 = lateral_map_2.reshape(lateral_map_2.size(1), lateral_map_2.size(0), lateral_map_2.size(2), lateral_map_2.size(3))
return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2
@@ -183,4 +194,4 @@ if __name__ == '__main__':
ras = PraNet().cuda()
input_tensor = torch.randn(1, 3, 352, 352).cuda()
- out = ras(input_tensor)
\ No newline at end of file
+ out = ras(input_tensor)