--- ./torchreid/models/densenet.py 2021-11-18 15:31:38.902515545 +0800
+++ ./densenet.py 2021-11-18 15:34:13.080681447 +0800
@@ -62,7 +62,8 @@
def forward(self, x, p=None):
x = self.base(x)
- x = F.avg_pool2d(x, x.size()[2:])
+ x_shape = [int(s) for s in x.shape[2:]]
+ x = F.avg_pool2d(x, x_shape)
f = x.view(x.size(0), -1)
if self.keyptaware and self.multitask:
f = torch.cat((f, p.float()), dim=1)