import argparse
import models
import sys
from pathlib import Path
import torch
import os
from waveglow import model as glow
sys.modules['glow'] = glow
def parse_args(parser):
    """Register the command-line options of this script on *parser*.

    Besides the top-level options, three argument groups are added:
    ``transform`` (pace and pitch transforms), ``Text processing parameters``
    and ``conditioning on additional attributes``.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with all options added.
    """
    # Not marked required: argparse never applies the default of a required
    # option, so the declared default would otherwise be dead.
    parser.add_argument('-i', '--input', type=str, default="phrases/tui_val100.tsv",
                        help='Full path to the input text (phrases separated by newlines)')
    parser.add_argument('-o', '--output', default=None,
                        help='Output folder to save audio (file per phrase)')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')
    parser.add_argument('--save-mels', action='store_true', help='')
    parser.add_argument('--cuda', action='store_true',
                        help='Run inference on a GPU using CUDA')
    parser.add_argument('--cudnn-benchmark', action='store_true',
                        help='Enable cudnn benchmark mode')
    parser.add_argument('--fastpitch', type=str,
                        help='Full path to the generator checkpoint file (skip to use ground truth mels)')
    parser.add_argument('--waveglow', type=str,
                        help='Full path to the WaveGlow model checkpoint file (skip to only generate mels)')
    parser.add_argument('-s', '--sigma-infer', default=0.9, type=float,
                        help='WaveGlow sigma')
    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
                        help='WaveGlow denoising')
    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--stft-hop-length', type=int, default=256,
                        help='STFT hop length for estimating audio length from mel size')
    parser.add_argument('--amp', action='store_true',
                        help='Inference with AMP')
    parser.add_argument('-bs', '--batch-size', type=int, default=64)
    parser.add_argument('--warmup-steps', type=int, default=0,
                        help='Warmup iterations before measuring performance')
    parser.add_argument('--repeats', type=int, default=1,
                        help='Repeat inference for benchmarking')
    parser.add_argument('--torchscript', action='store_true',
                        help='Apply TorchScript')
    parser.add_argument('--ema', action='store_true',
                        help='Use EMA averaged model (if saved in checkpoints)')
    parser.add_argument('--dataset-path', type=str,
                        help='Path to dataset (for loading extra data fields)')
    parser.add_argument('--speaker', type=int, default=0,
                        help='Speaker ID for a multi-speaker model')
    parser.add_argument('--p-arpabet', type=float, default=1.0, help='')
    parser.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
                        help='')
    parser.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
                        help='')

    transform = parser.add_argument_group('transform')
    transform.add_argument('--fade-out', type=int, default=10,
                           help='Number of fadeout frames at the end')
    transform.add_argument('--pace', type=float, default=1.0,
                           help='Adjust the pace of speech')
    transform.add_argument('--pitch-transform-flatten', action='store_true',
                           help='Flatten the pitch')
    transform.add_argument('--pitch-transform-invert', action='store_true',
                           help='Invert the pitch wrt mean value')
    transform.add_argument('--pitch-transform-amplify', type=float, default=1.0,
                           help='Amplify pitch variability, typical values are in the range (1.0, 3.0).')
    transform.add_argument('--pitch-transform-shift', type=float, default=0.0,
                           help='Raise/lower the pitch by <hz>')
    transform.add_argument('--pitch-transform-custom', action='store_true',
                           help='Apply the transform from pitch_transform.py')

    text_processing = parser.add_argument_group('Text processing parameters')
    text_processing.add_argument('--text-cleaners', nargs='*',
                                 default=['english_cleaners_v2'], type=str,
                                 help='Type of text cleaners for input text')
    text_processing.add_argument('--symbol-set', type=str, default='english_basic',
                                 help='Define symbol set for input text')

    cond = parser.add_argument_group('conditioning on additional attributes')
    cond.add_argument('--n-speakers', type=int, default=1,
                      help='Number of speakers in the model.')
    return parser
def load_model_from_ckpt(checkpoint_path, ema, model):
checkpoint_data = torch.load(checkpoint_path, map_location=torch.device("cpu"))
status = ''
if 'state_dict' in checkpoint_data:
sd = checkpoint_data['state_dict']
if ema and 'ema_state_dict' in checkpoint_data:
sd = checkpoint_data['ema_state_dict']
status += ' (EMA)'
elif ema and not 'ema_state_dict' in checkpoint_data:
print(f'WARNING: EMA weights missing for {checkpoint_data}')
if any(key.startswith('module.') for key in sd):
sd = {k.replace('module.', ''): v for k,v in sd.items()}
status += ' ' + str(model.load_state_dict(sd, strict=False))
else:
model = checkpoint_data['model']
print(f'Loaded {checkpoint_path}{status}')
return model
def load_and_setup_model(model_name, parser, checkpoint, amp, device,
                         unk_args=None, forward_is_infer=False, ema=True,
                         jitable=False):
    """Build *model_name*, optionally load *checkpoint*, and prepare it for inference.

    Model-specific options are added to *parser* by ``models.parse_model_args``
    and parsed from the command line with ``parse_known_args``.

    Args:
        model_name: Model identifier understood by the ``models`` module
            (``"WaveGlow"`` additionally gets weight-norm removal).
        parser: The script's argument parser.
        checkpoint: Checkpoint path, or ``None`` to keep the freshly built weights.
        amp: If true, convert the model to half precision.
        device: Device the returned model is moved to.
        unk_args: Optional list of command-line arguments the caller's parser
            did not recognise. It is updated in place to keep only the entries
            the model parser did not recognise either. Order and duplicates
            are not preserved because a set intersection is used.
        forward_is_infer: Forwarded to ``models.get_model``.
        ema: Forwarded to ``load_model_from_ckpt``.
        jitable: Forwarded to ``models.get_model``.

    Returns:
        The model in eval mode on *device*.
    """
    # A fresh list per call: the old ``unk_args=[]`` default was shared across
    # calls and mutated in place below.
    if unk_args is None:
        unk_args = []
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, model_unk_args = model_parser.parse_known_args()
    unk_args[:] = list(set(unk_args) & set(model_unk_args))

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device,
                             forward_is_infer=forward_is_infer,
                             jitable=jitable)

    if checkpoint is not None:
        model = load_model_from_ckpt(checkpoint, ema, model)

    if model_name == "WaveGlow":
        # NOTE(review): presumably a workaround for modules unpickled from an
        # older PyTorch version that lack this attribute -- confirm.
        for m in model.modules():
            m._non_persistent_buffers_set = set()
        model = model.remove_weightnorm(model)

    if amp:
        model.half()
    model.eval()
    return model.to(device)
def pth2onnx(model, dummy_input, output_file):
    """Export *model* to an ONNX file at *output_file*.

    The model is put in eval mode and exported with *dummy_input* as the
    example input, using ONNX opset 11. The single graph input is named
    ``input`` and only its axis 0 is dynamic (symbolic name ``bs``). Five
    output names are assigned. This presumes the model returns exactly these
    five outputs in this order -- confirm against the model's forward.
    """
    model.eval()
    graph_outputs = [
        'mel_out',
        'dec_lens',
        'dur_pred',
        'pitch_pred',
        'energy_pred',
    ]
    batch_axis_is_dynamic = {'input': {0: 'bs'}}
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=['input'],
        output_names=graph_outputs,
        dynamic_axes=batch_axis_is_dynamic,
        opset_version=11,
        verbose=False,
    )
def main():
    """Export a FastPitch generator to ``FastPitch.onnx``.

    Parses the command line and builds the FastPitch model on the CPU,
    loading the ``--fastpitch`` checkpoint if one is given. With
    ``--torchscript`` the model is scripted. It is then exported to ONNX using
    an all-zero dummy batch of token ids of shape ``(batch_size, 200)``.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()

    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    # Export always runs on the CPU; --cuda is not consulted here.
    device = torch.device('cpu')

    if args.fastpitch == 'SKIP':
        # There is nothing to export without a generator; fail with a clear
        # CLI error instead of an AttributeError on None inside pth2onnx.
        parser.error('--fastpitch SKIP is not supported: '
                     'a FastPitch model is required for ONNX export')

    generator = load_and_setup_model(
        'FastPitch', parser, args.fastpitch, args.amp, device,
        unk_args=unk_args, forward_is_infer=True, ema=args.ema,
        jitable=args.torchscript)
    if args.torchscript:
        generator = torch.jit.script(generator)

    # Zero-filled token ids. pth2onnx marks only axis 0 (batch) of the input
    # as dynamic, so the text length of 200 is baked into the exported graph.
    max_text_len = 200
    text_padded = torch.zeros(args.batch_size, max_text_len, dtype=torch.long)
    pth2onnx(model=generator, dummy_input=text_padded, output_file="FastPitch.onnx")
if __name__ == "__main__":
    # Run the ONNX export only when executed as a script, not on import.
    main()