import torch
from torch import nn
from torch.nn import functional as F
import argparse
import sys
sys.path.append('./')
import models
from inference import checkpoint_from_distributed, unwrap_distributed, load_and_setup_model, prepare_input_sequence
from common.utils import to_gpu, get_mask_from_lengths
def parse_args(parser):
"""
Parse commandline arguments.
"""
parser.add_argument('--tacotron2', type=str,
help='full path to the Tacotron2 model checkpoint file')
parser.add_argument('-o', '--output', type=str, required=True,
help='Directory for the exported Tacotron 2 ONNX model')
parser.add_argument('--fp16', action='store_true',
help='Export with half precision to ONNX')
parser.add_argument('-bs', '--batch_size', type=int, default=4,
help='Batch size')
parser.add_argument('-t', '--iter', type=int, default=1,
help='layers of one decoder net')
return parser
def encoder_infer(self, x, input_lengths):
device = x.device
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x.to(device))), 0.5, False)
x = x.transpose(1, 2)
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True)
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(
outputs, batch_first=True)
lens = input_lengths*2
return outputs, lens
class Encoder(torch.nn.Module):
def __init__(self, tacotron2):
super(Encoder, self).__init__()
self.tacotron2 = tacotron2
self.tacotron2.encoder.lstm.flatten_parameters()
self.infer = encoder_infer
dec = tacotron2.decoder
self.p_attention_dropout = dec.p_attention_dropout
self.p_decoder_dropout = dec.p_decoder_dropout
self.prenet = dec.prenet
self.prenet.infer = prenet_infer
self.attention_rnn = nn.LSTM(dec.prenet_dim + dec.encoder_embedding_dim,
dec.attention_rnn_dim, 1)
lstmcell2lstm_params(self.attention_rnn, dec.attention_rnn)
self.attention_rnn.flatten_parameters()
self.attention_layer = dec.attention_layer
self.decoder_rnn = nn.LSTM(dec.attention_rnn_dim + dec.encoder_embedding_dim,
dec.decoder_rnn_dim, 1)
lstmcell2lstm_params(self.decoder_rnn, dec.decoder_rnn)
self.decoder_rnn.flatten_parameters()
self.linear_projection = dec.linear_projection
self.gate_layer = dec.gate_layer
def decode(self, decoder_input, in_attention_hidden, in_attention_cell,
in_decoder_hidden, in_decoder_cell, in_attention_weights,
in_attention_weights_cum, in_attention_context, memory,
processed_memory, mask):
cell_input = torch.cat((decoder_input, in_attention_context), -1)
_, (out_attention_hidden, out_attention_cell) = self.attention_rnn(
cell_input.unsqueeze(0), (in_attention_hidden.unsqueeze(0),
in_attention_cell.unsqueeze(0)))
out_attention_hidden = out_attention_hidden.squeeze(0)
out_attention_cell = out_attention_cell.squeeze(0)
out_attention_hidden = F.dropout(
out_attention_hidden, self.p_attention_dropout, False)
attention_weights_cat = torch.cat(
(in_attention_weights.unsqueeze(1),
in_attention_weights_cum.unsqueeze(1)), dim=1)
out_attention_context, out_attention_weights = self.attention_layer(
out_attention_hidden, memory, processed_memory,
attention_weights_cat, mask)
out_attention_weights_cum = in_attention_weights_cum + out_attention_weights
decoder_input_tmp = torch.cat(
(out_attention_hidden, out_attention_context), -1)
_, (out_decoder_hidden, out_decoder_cell) = self.decoder_rnn(
decoder_input_tmp.unsqueeze(0), (in_decoder_hidden.unsqueeze(0),
in_decoder_cell.unsqueeze(0)))
out_decoder_hidden = out_decoder_hidden.squeeze(0)
out_decoder_cell = out_decoder_cell.squeeze(0)
out_decoder_hidden = F.dropout(
out_decoder_hidden, self.p_decoder_dropout, False)
decoder_hidden_attention_context = torch.cat(
(out_decoder_hidden, out_attention_context), 1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return (decoder_output, gate_prediction, out_attention_hidden,
out_attention_cell, out_decoder_hidden, out_decoder_cell,
out_attention_weights, out_attention_weights_cum, out_attention_context)
def forward(self, sequence,
sequence_lengths,
decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
mask,
not_finished,
mel_lengths):
embedded_inputs = self.tacotron2.embedding(sequence).transpose(1, 2)
memory, lens = self.infer(self.tacotron2.encoder, embedded_inputs, sequence_lengths)
processed_memory = self.tacotron2.decoder.attention_layer.memory_layer(memory)
gate_threshold = 0.5
decoder_input_pr = self.prenet.infer(self.prenet, decoder_input)
(mel_output, gate_output,
attention_hidden, attention_cell,
decoder_hidden, decoder_cell,
attention_weights, attention_weights_cum,
attention_context) = self.decode(decoder_input_pr, attention_hidden, attention_cell, decoder_hidden,
decoder_cell, attention_weights, attention_weights_cum,
attention_context, memory, processed_memory, mask)
mel_outputs = mel_output.unsqueeze(2)
gate_outputs = gate_output.unsqueeze(2)
t = torch.sigmoid(gate_output)
dec = torch.le(t, gate_threshold).to(torch.int32).squeeze(1)
not_finished = not_finished * dec
mel_lengths += not_finished
decoder_input = mel_output
return (mel_output, attention_hidden,
attention_cell, decoder_hidden, decoder_cell,
attention_weights, attention_weights_cum, attention_context, memory,
processed_memory, mask, mel_outputs, gate_outputs, not_finished, mel_lengths)
class Postnet(torch.nn.Module):
def __init__(self, tacotron2):
super(Postnet, self).__init__()
self.tacotron2 = tacotron2
def forward(self, mel_outputs):
mel_outputs_postnet = self.tacotron2.postnet(mel_outputs)
return mel_outputs + mel_outputs_postnet
def lstmcell2lstm_params(lstm_mod, lstmcell_mod):
lstm_mod.weight_ih_l0 = torch.nn.Parameter(lstmcell_mod.weight_ih)
lstm_mod.weight_hh_l0 = torch.nn.Parameter(lstmcell_mod.weight_hh)
lstm_mod.bias_ih_l0 = torch.nn.Parameter(lstmcell_mod.bias_ih)
lstm_mod.bias_hh_l0 = torch.nn.Parameter(lstmcell_mod.bias_hh)
def prenet_infer(self, x):
x1 = x[:]
for linear in self.layers:
x1 = F.relu(linear(x1))
x0 = x1[0].unsqueeze(0)
mask = torch.le(torch.rand(256, device='cpu').to(x.dtype), 0.5).to(x.dtype)
tmp = torch.ones([x1.size(0), x1.size(1)])
mask = mask.expand_as(tmp)
x1 = x1*mask*2.0
return x1
class DecoderIter(torch.nn.Module):
def __init__(self, tacotron2, iter):
super(DecoderIter, self).__init__()
self.tacotron2 = tacotron2
dec = tacotron2.decoder
self.max_decoder_steps = iter
self.p_attention_dropout = dec.p_attention_dropout
self.p_decoder_dropout = dec.p_decoder_dropout
self.prenet = dec.prenet
self.prenet.infer = prenet_infer
self.attention_rnn = nn.LSTM(dec.prenet_dim + dec.encoder_embedding_dim,
dec.attention_rnn_dim, 1)
lstmcell2lstm_params(self.attention_rnn, dec.attention_rnn)
self.attention_rnn.flatten_parameters()
self.attention_layer = dec.attention_layer
self.decoder_rnn = nn.LSTM(dec.attention_rnn_dim + dec.encoder_embedding_dim,
dec.decoder_rnn_dim, 1)
lstmcell2lstm_params(self.decoder_rnn, dec.decoder_rnn)
self.decoder_rnn.flatten_parameters()
self.linear_projection = dec.linear_projection
self.gate_layer = dec.gate_layer
def decode(self, decoder_input, in_attention_hidden, in_attention_cell,
in_decoder_hidden, in_decoder_cell, in_attention_weights,
in_attention_weights_cum, in_attention_context, memory,
processed_memory, mask):
cell_input = torch.cat((decoder_input, in_attention_context), -1)
_, (out_attention_hidden, out_attention_cell) = self.attention_rnn(
cell_input.unsqueeze(0), (in_attention_hidden.unsqueeze(0),
in_attention_cell.unsqueeze(0)))
out_attention_hidden = out_attention_hidden.squeeze(0)
out_attention_cell = out_attention_cell.squeeze(0)
out_attention_hidden = F.dropout(
out_attention_hidden, self.p_attention_dropout, False)
attention_weights_cat = torch.cat(
(in_attention_weights.unsqueeze(1),
in_attention_weights_cum.unsqueeze(1)), dim=1)
out_attention_context, out_attention_weights = self.attention_layer(
out_attention_hidden, memory, processed_memory,
attention_weights_cat, mask)
out_attention_weights_cum = in_attention_weights_cum + out_attention_weights
decoder_input_tmp = torch.cat(
(out_attention_hidden, out_attention_context), -1)
_, (out_decoder_hidden, out_decoder_cell) = self.decoder_rnn(
decoder_input_tmp.unsqueeze(0), (in_decoder_hidden.unsqueeze(0),
in_decoder_cell.unsqueeze(0)))
out_decoder_hidden = out_decoder_hidden.squeeze(0)
out_decoder_cell = out_decoder_cell.squeeze(0)
out_decoder_hidden = F.dropout(
out_decoder_hidden, self.p_decoder_dropout, False)
decoder_hidden_attention_context = torch.cat(
(out_decoder_hidden, out_attention_context), 1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return (decoder_output, gate_prediction, out_attention_hidden,
out_attention_cell, out_decoder_hidden, out_decoder_cell,
out_attention_weights, out_attention_weights_cum, out_attention_context)
def forward(self,
decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory,
mask,
mel_output_input,
gate_output_input,
not_finished,
mel_lengths):
i = 0
gate_threshold = 0.5
while i < self.max_decoder_steps:
decoder_input_pr = self.prenet.infer(self.prenet, decoder_input)
(mel_output,
gate_output,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context) = self.decode(decoder_input_pr,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory,
mask)
mel_output_input = torch.cat(
(mel_output_input, mel_output.unsqueeze(2)), dim = 2)
gate_output_input = torch.cat(
(gate_output_input, gate_output.unsqueeze(2)), dim = 2)
t = torch.sigmoid(gate_output)
dec = torch.le(t, gate_threshold).to(torch.int32).squeeze(1)
not_finished = not_finished * dec
mel_lengths += not_finished
decoder_input = mel_output
i = i + 1
return (decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory, mask,
mel_output_input[:, :, 1:], gate_output_input[:, :, 1:], not_finished, mel_lengths)
def test_inference(encoder, decoder_iter, postnet):
encoder.eval()
decoder_iter.eval()
postnet.eval()
sys.path.append('./tensorrt')
from inference_trt import init_decoder_inputs
texts = ["Hello World, good day."]
sequences, sequence_lengths = prepare_input_sequence(texts)
measurements = {}
print("Running Tacotron2 Encoder")
with torch.no_grad():
memory, processed_memory, lens = encoder(sequences, sequence_lengths)
print("Running Tacotron2 Decoder")
device = memory.device
dtype = memory.dtype
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device = device)
not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device = device)
mel_outputs, gate_outputs, alignments = (torch.zeros(1), torch.zeros(1), torch.zeros(1))
gate_threshold = 0.6
max_decoder_steps = 1000
first_iter = True
(decoder_input, attention_hidden, attention_cell, decoder_hidden,
decoder_cell, attention_weights, attention_weights_cum,
attention_context, memory, processed_memory,
mask) = init_decoder_inputs(memory, processed_memory, sequence_lengths)
while True:
with torch.no_grad():
(mel_output, gate_output,
attention_hidden, attention_cell,
decoder_hidden, decoder_cell,
attention_weights, attention_weights_cum,
attention_context) = decoder_iter(decoder_input, attention_hidden, attention_cell, decoder_hidden,
decoder_cell, attention_weights, attention_weights_cum,
attention_context, memory, processed_memory, mask)
if first_iter:
mel_outputs = torch.unsqueeze(mel_output, 2)
gate_outputs = torch.unsqueeze(gate_output, 2)
alignments = torch.unsqueeze(attention_weights, 2)
first_iter = False
else:
mel_outputs = torch.cat((mel_outputs, torch.unsqueeze(mel_output, 2)), 2)
gate_outputs = torch.cat((gate_outputs, torch.unsqueeze(gate_output, 2)), 2)
alignments = torch.cat((alignments, torch.unsqueeze(attention_weights, 2)), 2)
dec = torch.le(torch.sigmoid(gate_output), gate_threshold).to(torch.int32).squeeze(1)
not_finished = not_finished*dec
mel_lengths += not_finished
if torch.sum(not_finished) == 0:
print("Stopping after ", mel_outputs.size(2), " decoder steps")
break
if mel_outputs.size(2) == max_decoder_steps:
print("Warning! Reached max decoder steps")
break
decoder_input = mel_output
print("Running Tacotron2 PostNet")
with torch.no_grad():
mel_outputs_postnet = postnet(mel_outputs)
return mel_outputs_postnet
def main():
parser = argparse.ArgumentParser(
description='PyTorch Tacotron 2 export to TRT')
parser = parse_args(parser)
args, _ = parser.parse_known_args()
tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
fp16_run=args.fp16, cpu_run=True)
iter = args.iter
opset_version = 11
batch_size = args.batch_size
sequences = torch.randint(low=0, high=148, size=(batch_size, 50),
dtype=torch.long)
sequence_lengths = torch.IntTensor([sequences.size(1)] * batch_size).int()
memory = torch.randn((batch_size, sequence_lengths[0], 512))
if args.fp16:
memory = memory.half()
memory_lengths = sequence_lengths
decoder_input = tacotron2.decoder.get_go_frame(memory)
mask = get_mask_from_lengths(memory_lengths)
(attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
processed_memory) = tacotron2.decoder.initialize_decoder_states(memory)
not_finished = torch.ones([batch_size], dtype=torch.int32)
mel_lengths = torch.zeros([batch_size], dtype=torch.int32)
dummy_input = (sequences,
sequence_lengths,
decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
mask,
not_finished,
mel_lengths)
encoder = Encoder(tacotron2)
encoder.eval()
with torch.no_grad():
encoder(*dummy_input)
torch.onnx.export(encoder, dummy_input, args.output+"/"+"encoder.onnx",
opset_version=opset_version,
do_constant_folding=True,
input_names=["sequences",
"sequence_lengths",
"decoder_input",
"attention_hidden",
"attention_cell",
"decoder_hidden",
"decoder_cell",
"attention_weights",
"attention_weights_cum",
"attention_context",
"mask",
"not_finished_input",
"mel_lengths_input"],
output_names=["decoder_output",
"out_attention_hidden",
"out_attention_cell",
"out_decoder_hidden",
"out_decoder_cell",
"out_attention_weights",
"out_attention_weights_cum",
"out_attention_context",
"memory",
"processed_memory",
"mask",
"mel_outputs",
"gate_outputs",
"not_finished_output",
"mel_lengths_output"],
dynamic_axes={"sequences": {1: "text_seq"},
"memory": {1: "mem_seq"},
"processed_memory": {1: "mem_seq"},
"attention_weights": {1: "seq_len"},
"attention_weights_cum": {1: "seq_len"},
"mask": {1: "seq_len"},
"out_attention_weights": {1: "seq_len"},
"out_attention_weights_cum": {1: "seq_len"}
})
(mel_output, attention_hidden,
attention_cell, decoder_hidden, decoder_cell,
attention_weights, attention_weights_cum, attention_context, memory,
processed_memory, mask, mel_outputs, gate_outputs, not_finished, mel_lengths) = encoder(sequences,
sequence_lengths,
decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
mask,
not_finished,
mel_lengths)
decoder_input = mel_output
decoder_iter = DecoderIter(tacotron2, iter)
memory = torch.randn((batch_size, sequence_lengths[0], 512))
if args.fp16:
memory = memory.half()
memory_lengths = sequence_lengths
mask = get_mask_from_lengths(memory_lengths)
(attention_hidden_t,
attention_cell_t,
decoder_hidden_t,
decoder_cell_t,
attention_weights_t,
attention_weights_cum_t,
attention_context_t,
processed_memory) = tacotron2.decoder.initialize_decoder_states(memory)
mel_output_input = torch.randn((batch_size, 80, 1))
gate_output_input = torch.randn((batch_size, 1, 1))
dummy_input = (decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory,
mask,
mel_output_input,
gate_output_input,
not_finished,
mel_lengths)
decoder_iter = DecoderIter(tacotron2, iter)
decoder_iter.eval()
with torch.no_grad():
decoder_iter(*dummy_input)
torch.onnx.export(decoder_iter, dummy_input, args.output+"/"+"decoder_iter.onnx",
opset_version=opset_version,
do_constant_folding=True,
use_external_data_format=True,
input_names=["decoder_input",
"attention_hidden",
"attention_cell",
"decoder_hidden",
"decoder_cell",
"attention_weights",
"attention_weights_cum",
"attention_context",
"memory",
"processed_memory",
"mask",
"mel_output_input",
"gate_output_input",
"not_finished_input",
"mel_lengths_input"],
output_names=["decoder_input_out",
"attention_hidden_out",
"attention_cell_out",
"decoder_hidden_out",
"decoder_cell_out",
"attention_weights_out",
"attention_weights_cum_out",
"attention_context_out",
"memory_out",
"processed_memory_out",
"mask_out",
"mel_output_cat",
"gate_output_cat",
"not_finished_out",
"mel_lengths_out"],
dynamic_axes={"attention_weights":{1:"seq_len"},
"attention_weights_cum":{1:"seq_len"},
"memory":{1:"seq_len"},
"processed_memory":{1:"seq_len"},
"mask" : {1:"seq_len"}
})
postnet = Postnet(tacotron2)
dummy_input = torch.randn((batch_size, 80, 620))
if args.fp16:
dummy_input = dummy_input.half()
torch.onnx.export(postnet, dummy_input, args.output+"/"+"postnet.onnx",
opset_version=opset_version,
do_constant_folding=True,
input_names=["mel_outputs"],
output_names=["mel_outputs_postnet"],
dynamic_axes={"mel_outputs": {2: "mel_seq"},
"mel_outputs_postnet": {2: "mel_seq"}})
if __name__ == '__main__':
main()