diff -Nur ./b/losses/center_loss.py ./a/losses/center_loss.py
@@ -17,11 +17,7 @@
self.num_classes = num_classes
self.feat_dim = feat_dim
self.use_gpu = use_gpu
-
- if self.use_gpu:
- self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
- else:
- self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
+ self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
def forward(self, x, labels):
"""
diff -Nur ./b/modelling/bases.py ./a/modelling/bases.py
@@ -170,11 +170,11 @@
def validation_step(self, batch, batch_idx):
self.backbone.eval()
self.bn.eval()
- x, class_labels, camid, idx = batch
+ x = batch
with torch.no_grad():
_, emb = self.backbone(x)
emb = self.bn(emb)
- return {"emb": emb, "labels": class_labels, "camid": camid, "idx": idx}
+ return emb
@rank_zero_only
def validation_create_centroids(
@@ -291,10 +291,6 @@
log_data = {"mAP": mAP}
- # TODO This line below is hacky, but it works when grad_monitoring is active
- self.trainer.logger_connector.callback_metrics.update(log_data)
- log_data = {**log_data, **topks}
- self.trainer.logger.log_metrics(log_data, step=self.trainer.current_epoch)
def validation_epoch_end(self, outputs):
if self.trainer.global_rank == 0 and self.trainer.local_rank == 0:
@@ -384,7 +380,7 @@
return masks, labels_list_copy
@rank_zero_only
- def test_step(self, batch, batch_idx):
+ def test_step(self, batch, batch_idx=1):
ret = self.validation_step(batch, batch_idx)
return ret
diff -Nur ./b/utils/reid_metric.py ./a/utils/reid_metric.py
@@ -74,20 +74,11 @@
self.num_query = num_query
self.max_rank = max_rank
self.feat_norm = feat_norm
- self.current_epoch = pl_module.trainer.current_epoch
self.hparms = pl_module.hparams
self.dist_func = get_dist_func(self.hparms.SOLVER.DISTANCE_FUNC)
self.pl_module = pl_module
- try:
- self.save_root_dir = pl_module.trainer.logger.log_dir
- except:
- self.save_root_dir = pl_module.trainer.logger[0].log_dir
- try:
- self.dataset = pl_module.trainer.val_dataloaders[0].dataset.samples
- except:
- self.dataset = pl_module.trainer.test_dataloaders[0].dataset.samples
# @staticmethod
def _commpute_batches_double(self, qf, gf):