809918bf创建于 2022年12月20日历史提交
diff --git a/config.json b/config.json
index 7160474..608553c 100755
--- a/config.json
+++ b/config.json
@@ -12,7 +12,7 @@
         "with_tensorboard": false
     },
     "data_config": {
-        "training_files": "train_files.txt",
+        "training_files": "test_files.txt",
         "segment_length": 16000,
         "sampling_rate": 22050,
         "filter_length": 1024,
diff --git a/glow.py b/glow.py
index 7a76964..89fc584 100644
--- a/glow.py
+++ b/glow.py
@@ -248,48 +248,49 @@ class WaveGlow(torch.nn.Module):
         output_audio.append(audio)
         return torch.cat(output_audio,1), log_s_list, log_det_W_list
 
-    def infer(self, spect, sigma=1.0):
+    def infer(self, spect, sigma=0.9):
+        stride = 256
+        n_group = 8
+        z_size2 = (spect.size(2)*stride)//n_group
+        z = torch.randn(spect.size(0), n_group, z_size2)
         spect = self.upsample(spect)
         # trim conv artifacts. maybe pad spec to kernel multiple
         time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
         spect = spect[:, :, :-time_cutoff]
 
-        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
-        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
-
-        if spect.type() == 'torch.cuda.HalfTensor':
-            audio = torch.cuda.HalfTensor(spect.size(0),
-                                          self.n_remaining_channels,
-                                          spect.size(2)).normal_()
-        else:
-            audio = torch.cuda.FloatTensor(spect.size(0),
-                                           self.n_remaining_channels,
-                                           spect.size(2)).normal_()
+        mel_dim = 80
+        batch_size = spect.size(0)
+        length_spect_group = spect.size(2)//8
+        spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
+        spect = spect.permute(0, 2, 1, 3)
+        spect = spect.contiguous()
+        spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
+        spect = spect.permute(0, 2, 1)
+        spect = spect.contiguous()
 
-        audio = torch.autograd.Variable(sigma*audio)
+        audio = z[:, :self.n_remaining_channels, :]
+        z = z[:, self.n_remaining_channels:self.n_group, :]
+        audio = sigma*audio
 
         for k in reversed(range(self.n_flows)):
-            n_half = int(audio.size(1)/2)
-            audio_0 = audio[:,:n_half,:]
-            audio_1 = audio[:,n_half:,:]
+            n_half = int(audio.size(1) // 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:(n_half+n_half), :]
 
             output = self.WN[k]((audio_0, spect))
 
-            s = output[:, n_half:, :]
+            s = output[:, n_half:(n_half+n_half), :]
             b = output[:, :n_half, :]
-            audio_1 = (audio_1 - b)/torch.exp(s)
-            audio = torch.cat([audio_0, audio_1],1)
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
 
             audio = self.convinv[k](audio, reverse=True)
 
             if k % self.n_early_every == 0 and k > 0:
-                if spect.type() == 'torch.cuda.HalfTensor':
-                    z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
-                else:
-                    z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
-                audio = torch.cat((sigma*z, audio),1)
+                audio = torch.cat((z[:, :self.n_early_size, :], audio), 1)
+                z = z[:, self.n_early_size:self.n_group, :]
 
-        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
+        audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
         return audio
 
     @staticmethod
diff --git a/requirements.txt b/requirements.txt
index 20c8b3e..ae27cbb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-torch==1.0
-matplotlib==2.1.0
+torch==1.8.0
+matplotlib
 tensorflow
-numpy==1.13.3
+numpy
 inflect==0.2.5
 librosa==0.6.0
-scipy==1.0.0
+scipy
 tensorboardX==1.1
 Unidecode==1.0.22
 pillow