diff -Nur ./b/losses/center_loss.py ./a/losses/center_loss.py
--- ./b/losses/center_loss.py	2023-01-05 02:28:47.726602949 +0000
+++ ./a/losses/center_loss.py	2023-01-05 02:45:51.326615490 +0000
@@ -17,11 +17,7 @@
         self.num_classes = num_classes
         self.feat_dim = feat_dim
         self.use_gpu = use_gpu
-
-        if self.use_gpu:
-            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
-        else:
-            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
+        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
 
     def forward(self, x, labels):
         """
diff -Nur ./b/modelling/bases.py ./a/modelling/bases.py
--- ./b/modelling/bases.py	2023-01-05 02:28:47.726602949 +0000
+++ ./a/modelling/bases.py	2023-01-05 02:47:38.838616807 +0000
@@ -170,11 +170,11 @@
     def validation_step(self, batch, batch_idx):
         self.backbone.eval()
         self.bn.eval()
-        x, class_labels, camid, idx = batch
+        x = batch
         with torch.no_grad():
             _, emb = self.backbone(x)
             emb = self.bn(emb)
-        return {"emb": emb, "labels": class_labels, "camid": camid, "idx": idx}
+        return emb
 
     @rank_zero_only
     def validation_create_centroids(
@@ -291,10 +291,6 @@
 
         log_data = {"mAP": mAP}
 
-        # TODO This line below is hacky, but it works when grad_monitoring is active
-        self.trainer.logger_connector.callback_metrics.update(log_data)
-        log_data = {**log_data, **topks}
-        self.trainer.logger.log_metrics(log_data, step=self.trainer.current_epoch)
 
     def validation_epoch_end(self, outputs):
         if self.trainer.global_rank == 0 and self.trainer.local_rank == 0:
@@ -384,7 +380,7 @@
         return masks, labels_list_copy
 
     @rank_zero_only
-    def test_step(self, batch, batch_idx):
+    def test_step(self, batch, batch_idx=1):
         ret = self.validation_step(batch, batch_idx)
         return ret
 
diff -Nur ./b/utils/reid_metric.py ./a/utils/reid_metric.py
--- ./b/utils/reid_metric.py	2023-01-05 02:28:47.730602949 +0000
+++ ./a/utils/reid_metric.py	2023-01-05 02:45:51.326615490 +0000
@@ -74,20 +74,11 @@
         self.num_query = num_query
         self.max_rank = max_rank
         self.feat_norm = feat_norm
-        self.current_epoch = pl_module.trainer.current_epoch
         self.hparms = pl_module.hparams
         self.dist_func = get_dist_func(self.hparms.SOLVER.DISTANCE_FUNC)
         self.pl_module = pl_module
 
-        try:
-            self.save_root_dir = pl_module.trainer.logger.log_dir
-        except:
-            self.save_root_dir = pl_module.trainer.logger[0].log_dir
 
-        try:
-            self.dataset = pl_module.trainer.val_dataloaders[0].dataset.samples
-        except:
-            self.dataset = pl_module.trainer.test_dataloaders[0].dataset.samples
 
     # @staticmethod
     def _commpute_batches_double(self, qf, gf):