05360171创建于 2022年3月18日历史提交
diff --git a/models.py b/models.py
index 9b04cfc..8b7d788 100755
--- a/models.py
+++ b/models.py
@@ -172,8 +172,12 @@ class RCF(nn.Module):
 
 def crop(variable, th, tw):
         h, w = variable.shape[2], variable.shape[3]
-        x1 = int(round((w - tw) / 2.))
-        y1 = int(round((h - th) / 2.))
+        if isinstance(th, torch.Tensor):
+            x1 = int(torch.round((w - tw) / 2.))
+            y1 = int(torch.round((h - th) / 2.))
+        else:
+            x1 = int(round((w - tw) / 2.))
+            y1 = int(round((h - th) / 2.))
         return variable[:, :, y1 : y1 + th, x1 : x1 + tw]