05360171创建于 2022年3月18日历史提交
diff --git a/lib/PraNet_Res2Net.py b/lib/PraNet_Res2Net.py
index 25c723c..5786508 100644
--- a/lib/PraNet_Res2Net.py
+++ b/lib/PraNet_Res2Net.py
@@ -100,7 +100,7 @@ class PraNet(nn.Module):
     def __init__(self, channel=32):
         super(PraNet, self).__init__()
         # ---- ResNet Backbone ----
-        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
+        self.resnet = res2net50_v1b_26w_4s(pretrained=False)
         # ---- Receptive Field Block like module ----
         self.rfb2_1 = RFB_modified(512, channel)
         self.rfb3_1 = RFB_modified(1024, channel)
@@ -140,10 +140,13 @@ class PraNet(nn.Module):
         x4_rfb = self.rfb4_1(x4)        # channel -> 32
 
         ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb)
+        ra5_feat = ra5_feat.reshape(ra5_feat.size(1), ra5_feat.size(0), ra5_feat.size(2), ra5_feat.size(3))
         lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear')    # NOTES: Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
+        lateral_map_5 = lateral_map_5.reshape(lateral_map_5.size(1), lateral_map_5.size(0), lateral_map_5.size(2), lateral_map_5.size(3))
 
         # ---- reverse attention branch_4 ----
         crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear')
+        crop_4 = crop_4.reshape(crop_4.size(1), crop_4.size(0), crop_4.size(2), crop_4.size(3))
         x = -1*(torch.sigmoid(crop_4)) + 1
         x = x.expand(-1, 2048, -1, -1).mul(x4)
         x = self.ra4_conv1(x)
@@ -152,10 +155,13 @@ class PraNet(nn.Module):
         x = F.relu(self.ra4_conv4(x))
         ra4_feat = self.ra4_conv5(x)
         x = ra4_feat + crop_4
-        lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear')  # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)
+        temp = x.reshape(x.size(1), x.size(0), x.size(2), x.size(3))
+        lateral_map_4 = F.interpolate(temp, scale_factor=32, mode='bilinear')  # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)
+        lateral_map_4 = lateral_map_4.reshape(lateral_map_4.size(1), lateral_map_4.size(0), lateral_map_4.size(2), lateral_map_4.size(3))
 
         # ---- reverse attention branch_3 ----
-        crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear')
+        crop_3 = F.interpolate(temp, scale_factor=2, mode='bilinear')
+        crop_3 = crop_3.reshape(crop_3.size(1), crop_3.size(0), crop_3.size(2), crop_3.size(3))
         x = -1*(torch.sigmoid(crop_3)) + 1
         x = x.expand(-1, 1024, -1, -1).mul(x3)
         x = self.ra3_conv1(x)
@@ -163,10 +169,13 @@ class PraNet(nn.Module):
         x = F.relu(self.ra3_conv3(x))
         ra3_feat = self.ra3_conv4(x)
         x = ra3_feat + crop_3
-        lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear')  # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)
+        temp1 = x.reshape(x.size(1), x.size(0), x.size(2), x.size(3))
+        lateral_map_3 = F.interpolate(temp1, scale_factor=16, mode='bilinear')  # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)
+        lateral_map_3 = lateral_map_3.reshape(lateral_map_3.size(1), lateral_map_3.size(0), lateral_map_3.size(2), lateral_map_3.size(3))
 
         # ---- reverse attention branch_2 ----
-        crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear')
+        crop_2 = F.interpolate(temp1, scale_factor=2, mode='bilinear')
+        crop_2 = crop_2.reshape(crop_2.size(1), crop_2.size(0), crop_2.size(2), crop_2.size(3))
         x = -1*(torch.sigmoid(crop_2)) + 1
         x = x.expand(-1, 512, -1, -1).mul(x2)
         x = self.ra2_conv1(x)
@@ -174,7 +183,9 @@ class PraNet(nn.Module):
         x = F.relu(self.ra2_conv3(x))
         ra2_feat = self.ra2_conv4(x)
         x = ra2_feat + crop_2
-        lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear')   # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
+        temp2 = x.reshape(x.size(1), x.size(0), x.size(2), x.size(3))
+        lateral_map_2 = F.interpolate(temp2, scale_factor=8, mode='bilinear')   # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
+        lateral_map_2 = lateral_map_2.reshape(lateral_map_2.size(1), lateral_map_2.size(0), lateral_map_2.size(2), lateral_map_2.size(3))
 
         return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2
 
@@ -183,4 +194,4 @@ if __name__ == '__main__':
     ras = PraNet().cuda()
     input_tensor = torch.randn(1, 3, 352, 352).cuda()
 
-    out = ras(input_tensor)
\ No newline at end of file
+    out = ras(input_tensor)