diff --git a/models.py b/models.py
index 9b04cfc..8b7d788 100755
--- a/models.py
+++ b/models.py
@@ -172,8 +172,12 @@ class RCF(nn.Module):
def crop(variable, th, tw):
h, w = variable.shape[2], variable.shape[3]
- x1 = int(round((w - tw) / 2.))
- y1 = int(round((h - th) / 2.))
+ if isinstance(th, torch.Tensor):
+ x1 = int(torch.round((w - tw) / 2.))
+ y1 = int(torch.round((h - th) / 2.))
+ else:
+ x1 = int(round((w - tw) / 2.))
+ y1 = int(round((h - th) / 2.))
return variable[:, :, y1 : y1 + th, x1 : x1 + tw]