import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: tanh(first half) * sigmoid(second half) of ``input_a + input_b``.

    Args:
        input_a, input_b: tensors of the same shape, indexed as (batch, channel, time).
        n_channels: indexable whose element 0 is the split point along dim 1.
            Channels before it feed the tanh branch; the rest feed the sigmoid branch.

    Returns:
        Elementwise product of the two branches.
    """
    split = n_channels[0]
    pre_activation = input_a + input_b
    tanh_branch = pre_activation[:, :split, :].tanh()
    sigmoid_branch = pre_activation[:, split:, :].sigmoid()
    return tanh_branch * sigmoid_branch
class WaveGlowLoss(torch.nn.Module):
    """Flow negative log-likelihood, normalized by the number of elements of ``z``.

    loss = (sum(z**2) / (2 * sigma**2) - sum_k sum(log_s_k) - sum_k log_det_W_k)
           / (z.size(0) * z.size(1) * z.size(2))
    """
    def __init__(self, sigma=1.0):
        super(WaveGlowLoss, self).__init__()
        self.sigma = sigma

    def forward(self, model_output):
        """Compute the loss from ``(z, log_s_list, log_det_W_list)``.

        ``log_s_list`` and ``log_det_W_list`` are paired element by element.
        Neither list nor any tensor in it is modified.
        """
        z, log_s_list, log_det_W_list = model_output
        # Accumulate out of place. An in-place `+=` on an element taken from
        # log_det_W_list would modify the caller's tensor.
        log_s_total = 0
        log_det_W_total = 0
        for log_s, log_det_W in zip(log_s_list, log_det_W_list):
            log_s_total = log_s_total + torch.sum(log_s)
            log_det_W_total = log_det_W_total + log_det_W
        loss = torch.sum(z * z) / (2 * self.sigma * self.sigma) - log_s_total - log_det_W_total
        return loss / (z.size(0) * z.size(1) * z.size(2))
class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If reverse=True it does convolution with
    inverse
    """
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv2d(c, c, kernel_size=(1, 1), stride=1, padding=0,
                                    bias=False)
        # Random orthogonal initialisation via QR. Prefer torch.linalg.qr and
        # fall back to torch.qr on PyTorch versions that lack it.
        try:
            qr = torch.linalg.qr
        except AttributeError:
            qr = torch.qr
        W = qr(torch.empty(c, c, dtype=torch.float32).normal_())[0]
        # Flip one column so det(W) = +1; torch.logdet is undefined (nan) for det < 0.
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        self.conv.weight.data = W.contiguous().view(c, c, 1, 1)

    def forward(self, z, reverse=False):
        """Apply W (or W^-1 when ``reverse``) along the channel dim of ``z``.

        Args:
            z: tensor unpacked as (batch_size, group_size, n_of_groups).
            reverse: if True, convolve with the cached inverse weight and
                return only the result.

        Returns:
            ``z_out`` when ``reverse``; otherwise ``(z_out, log_det_W)``, where
            ``log_det_W = batch_size * n_of_groups * logdet(W)`` is cast to
            float16 on ``z``'s device. NOTE(review): the fixed float16 appears to
            assume mixed-precision training.
        """
        batch_size, group_size, n_of_groups = z.size()
        # Explicit indexing (rather than squeeze()) keeps W 2-D even when c == 1.
        W = self.conv.weight[:, :, 0, 0]
        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Computed once in float32 and cached on the module; it is not
                # refreshed if the weights change afterwards.
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None, None])
                if z.dtype == torch.float16:
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            # conv2d needs a 4-D input and weight: add and then drop a trailing
            # spatial dim, mirroring the forward path.
            z = F.conv2d(z.unsqueeze(-1), self.W_inverse, bias=None, stride=1, padding=0)
            return z.squeeze(-1)
        # logdet is evaluated on the CPU in float32, then moved back to z's device.
        log_det_W = batch_size * n_of_groups * torch.logdet(W.cpu().float())
        log_det_W = log_det_W.to(z.device).half()
        # Out-of-place unsqueeze so the caller's tensor keeps its shape.
        z = self.conv(z.unsqueeze(-1)).squeeze(-1)
        return z, log_det_W
class WN(torch.nn.Module):
    """
    WaveNet-like network used inside the affine coupling layer.

    Unlike WaveNet, the convolutions are non-causal (symmetric padding). The
    dilation doubles on every layer and is never reset.
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        weight_norm = torch.nn.utils.weight_norm
        self.start = weight_norm(torch.nn.Conv2d(n_in_channels, n_channels, 1),
                                 name='weight')

        # The output projection starts at zero, so the layer initially emits zeros.
        end = torch.nn.Conv2d(n_channels, 2 * n_in_channels, 1)
        with torch.no_grad():
            end.weight.zero_()
            end.bias.zero_()
        self.end = end

        # A single conditioning conv produces 2 * n_channels features per layer.
        self.cond_layer = weight_norm(
            torch.nn.Conv2d(n_mel_channels, 2 * n_channels * n_layers, 1),
            name='weight')

        for layer_idx in range(n_layers):
            dilation = 2 ** layer_idx
            pad = (kernel_size - 1) * dilation // 2
            self.in_layers.append(weight_norm(
                torch.nn.Conv2d(n_channels, 2 * n_channels, (kernel_size, 1),
                                dilation=(dilation, 1), padding=(pad, 0)),
                name='weight'))
            # The last layer needs only skip channels, since there is no next residual.
            is_last = layer_idx == n_layers - 1
            out_channels = n_channels if is_last else 2 * n_channels
            self.res_skip_layers.append(weight_norm(
                torch.nn.Conv2d(n_channels, out_channels, 1), name='weight'))

    def forward(self, forward_input):
        """Map ``(audio, spect)`` to the ``end`` projection of the summed skip outputs.

        1-D signals are run through Conv2d layers by temporarily appending a
        trailing size-1 dimension.
        """
        audio, spect = forward_input
        audio = self.start(audio.unsqueeze(-1)).squeeze_(-1)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])
        spect = self.cond_layer(spect.unsqueeze(-1)).squeeze_(-1)

        width = 2 * self.n_channels
        last_idx = self.n_layers - 1
        layers = zip(self.in_layers, self.res_skip_layers)
        for idx, (in_layer, res_skip_layer) in enumerate(layers):
            cond = spect[:, idx * width:(idx + 1) * width, :]
            pre = in_layer(audio.unsqueeze(-1)).squeeze_(-1)
            acts = fused_add_tanh_sigmoid_multiply(pre, cond, n_channels_tensor)
            res_skip_acts = res_skip_layer(acts.unsqueeze(-1)).squeeze_(-1)
            if idx == last_idx:
                output = output + res_skip_acts
            else:
                audio = audio + res_skip_acts[:, :self.n_channels, :]
                output = output + res_skip_acts[:, self.n_channels:, :]
        return self.end(output.unsqueeze(-1)).squeeze_(-1)
class ConvTranse1D(torch.nn.ConvTranspose1d):
    """ConvTranspose1d whose forward calls ``F.conv_transpose1d`` directly.

    The constructor takes exactly ``torch.nn.ConvTranspose1d``'s leading
    arguments, in the same order.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros'):
        super(ConvTranse1D, self).__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, output_padding, groups, bias,
            dilation, padding_mode)

    def forward(self, input, output_size=None):
        """Transposed 1-D convolution of ``input``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
        if output_size is None:
            # With no requested output size, the padding is simply the configured
            # value. Using it directly avoids the private `_output_padding`
            # helper, whose signature differs between PyTorch versions.
            output_padding = self.output_padding
        else:
            # NOTE(review): newer PyTorch versions of `_output_padding` also take
            # `num_spatial_dims` (and `dilation`); this call matches the older
            # signature.
            output_padding = self._output_padding(input, output_size, self.stride,
                                                  self.padding, self.kernel_size)
        return F.conv_transpose1d(input,
                                  self.weight,
                                  self.bias,
                                  self.stride,
                                  self.padding,
                                  output_padding,
                                  self.groups,
                                  self.dilation)
class WaveGlow(torch.nn.Module):
    """WaveGlow flow: affine coupling layers (``WN``) interleaved with
    invertible 1x1 convolutions, conditioned on an upsampled mel spectrogram.

    Every ``n_early_every`` flows (excluding flow 0), ``n_early_size``
    channels are split off ("early output") and skip the remaining flows.
    """
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()
        # Fixed upsampler: transposed conv with kernel 1024 and stride 256.
        self.upsample = ConvTranse1D(n_mel_channels, n_mel_channels, 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = n_group // 2
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - self.n_early_size // 2
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels * n_group, **WN_config))
        # Channels left after the last early split; infer() samples noise of this size.
        self.n_remaining_channels = n_remaining_channels

    def _sample_noise(self, spect, n_channels):
        """Standard-normal noise of shape (spect.size(0), n_channels, spect.size(2)).

        Uses float16 if ``spect`` is float16 and float32 otherwise, on
        ``spect``'s device.
        """
        dtype = torch.float16 if spect.dtype == torch.float16 else torch.float32
        return torch.empty(spect.size(0), n_channels, spect.size(2),
                           dtype=dtype, device=spect.device).normal_()

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time

        Returns ``(z, log_s_list, log_det_W_list)``. ``z`` concatenates the early
        outputs and the final flow output along dim 1; the two lists hold one
        entry per flow.
        """
        spect, audio = forward_input
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        device = spect.device
        # Trim to the audio length (a no-op when the lengths are equal) and always
        # unfold, so the 4-D permute below is valid in every case. The unfold runs
        # on the CPU in float32 and the result is moved back as float16.
        # NOTE(review): presumably a device-support workaround; confirm.
        spect = spect.cpu().float()
        spect = spect[:, :, :audio.size(1)]
        spect = spect.unfold(2, self.n_group, self.n_group)
        spect = spect.to(device).half()
        spect = spect.permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []
        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:, :self.n_early_size, :])
                audio = audio[:, self.n_early_size:, :]
            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = audio.size(1) // 2
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]
            output = self.WN[k]((audio_0, spect))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s) * audio_1 + b
            log_s_list.append(log_s)
            audio = torch.cat([audio_0, audio_1], 1)
        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list

    def infer(self, spect, sigma=1.0):
        """Sample audio for ``spect`` by running the flows in reverse.

        Args:
            spect: mel spectrogram, batch x n_mel_channels x frames.
            sigma: standard deviation scale for the sampled noise.

        Returns:
            Detached tensor of shape (batch, samples).
        """
        spect = self.upsample(spect)
        # Drop the trailing kernel_size - stride samples added by the upsampler.
        # An explicit end index keeps this correct even if the cutoff were 0.
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :max(spect.size(2) - time_cutoff, 0)]
        device = spect.device
        # The unfold runs on the CPU, as in forward().
        spect = spect.cpu()
        spect = spect.unfold(2, self.n_group, self.n_group)
        spect = spect.to(device)
        spect = spect.permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = sigma * self._sample_noise(spect, self.n_remaining_channels)
        for k in reversed(range(self.n_flows)):
            n_half = audio.size(1) // 2
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]
            output = self.WN[k]((audio_0, spect))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b) / torch.exp(s)
            audio = torch.cat([audio_0, audio_1], 1)
            audio = self.convinv[k](audio, reverse=True)
            if k % self.n_early_every == 0 and k > 0:
                z = self._sample_noise(spect, self.n_early_size)
                audio = torch.cat((sigma * z, audio), 1)
        audio = audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).detach()
        return audio

    @staticmethod
    def remove_weightnorm(model):
        """Remove weight norm from every WN submodule of ``model``; return ``model``."""
        waveglow = model
        # `wn` rather than `WN`, to avoid shadowing the WN class.
        for wn in waveglow.WN:
            wn.start = torch.nn.utils.remove_weight_norm(wn.start)
            wn.in_layers = remove(wn.in_layers)
            wn.cond_layer = torch.nn.utils.remove_weight_norm(wn.cond_layer)
            wn.res_skip_layers = remove(wn.res_skip_layers)
        return waveglow
def remove(conv_list):
    """Strip weight normalization from every module in ``conv_list``.

    The modules are modified in place by ``remove_weight_norm``. They are
    returned, in the same order, inside a new ``torch.nn.ModuleList``.
    """
    stripped = [torch.nn.utils.remove_weight_norm(conv) for conv in conv_list]
    return torch.nn.ModuleList(stripped)