From 2d6205ea3f2b1b61f4eb3063ca3ccf49138c4465 Mon Sep 17 00:00:00 2001
From: yinin
Date: Sat, 26 Mar 2022 16:26:20 +0800
Subject: [PATCH] patch

---
 .../RawNet2/Pre-trained_model/model_RawNet2_original_code.py  | 2 +-
 python/RawNet2/dataloader.py                                  | 4 ++--
 python/RawNet2/parser.py                                      | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py
index c5981fc..9e3df1d 100755
--- a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py
+++ b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py
@@ -265,7 +265,7 @@ class RawNet(nn.Module):
         
         self.sig = nn.Sigmoid()
         
-    def forward(self, x, y = 0, is_test=False, is_TS=False):
+    def forward(self, x, y = 0, is_test=True, is_TS=False):
         #follow sincNet recipe
         nb_samp = x.shape[0]
         len_seq = x.shape[1]
diff --git a/python/RawNet2/dataloader.py b/python/RawNet2/dataloader.py
index c1791e2..3dfa802 100644
--- a/python/RawNet2/dataloader.py
+++ b/python/RawNet2/dataloader.py
@@ -113,8 +113,8 @@ class TA_Dataset_VoxCeleb2(data.Dataset):
 
 		if not self.return_label:
 			return list_X
-		y = self.labels[ID.split('/')[0]]
-		return list_X, y 
+		#y = self.labels[ID.split('/')[0]]
+		return list_X, ID 
 
 	def _normalize_scale(self, x):
 		'''
diff --git a/python/RawNet2/parser.py b/python/RawNet2/parser.py
index bea9112..4c15f34 100644
--- a/python/RawNet2/parser.py
+++ b/python/RawNet2/parser.py
@@ -14,7 +14,7 @@ def str2bool(v):
 def get_args():
     parser = argparse.ArgumentParser()
     #dir
-    parser.add_argument('-name', type = str, required = True)
+    parser.add_argument('-name', type = str, default = 'rawnet2')
     parser.add_argument('-save_dir', type = str, default = 'DNNs/')
     parser.add_argument('-DB', type = str, default = 'DB/VoxCeleb1/')
     parser.add_argument('-DB_vox2', type = str, default = 'DB/VoxCeleb2/')
-- 
2.33.0.windows.2