import os
import sys
import copy
import time
import yaml
import argparse
import numpy as np
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torch.onnx
from collections import OrderedDict
import ssl
sys.path.append('./')
from models.model_ctc import *
# Lookup tables mapping the string names used in the YAML config to the
# corresponding torch.nn classes.
supported_rnn = {'nn.LSTM': nn.LSTM, 'nn.GRU': nn.GRU, 'nn.RNN': nn.RNN}
supported_activate = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}

# Command-line interface of the checkpoint -> ONNX export script.
parser = argparse.ArgumentParser(description='cnn_lstm_ctc')
parser.add_argument('--conf', default='conf/ctc_config.yaml',
                    help='conf file with argument of LSTM and training')
# type=int makes argparse reject a non-numeric batch size with a usage message
# instead of letting it fail later inside main().
parser.add_argument('--batchsize', type=int, default=1,
                    help='batchsize for transfer onnx batch')
class Vocab(object):
    """Word <-> index mapping built from a plain-text vocabulary file.

    Indices 0 and 1 are reserved for "blank" and "UNK"; every new word read
    from ``vocab_file`` gets the next free index, in order of first appearance.

    Attributes:
        vocab_file: path the vocabulary is read from.
        word2index: dict mapping word -> integer index.
        index2word: dict mapping integer index -> word.
        word2count: dict mapping word -> number of occurrences read from the file.
        n_words: number of entries in the mapping, reserved tokens included.
    """

    def __init__(self, vocab_file):
        self.vocab_file = vocab_file
        self.word2index = {"blank": 0, "UNK": 1}
        self.index2word = {0: "blank", 1: "UNK"}
        self.word2count = {}
        self.n_words = 2
        self.read_lang()

    def add_sentence(self, sentence):
        """Add every single-space-separated token of ``sentence`` to the vocabulary."""
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        """Register ``word`` (giving it a new index if unseen) and bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            # The reserved tokens are in word2index but not in word2count, so
            # `+= 1` would raise KeyError when the file contains "blank"/"UNK".
            self.word2count[word] = self.word2count.get(word, 0) + 1

    def read_lang(self):
        """Populate the vocabulary from ``self.vocab_file``.

        Each line is stripped and split on single spaces. When a line has more
        than one field, the first field is dropped (presumably an utterance
        id -- TODO confirm the file format) and the remaining fields are added.
        A single-field line is added as-is.

        NOTE(review): splitting on ' ' means blank lines or repeated spaces add
        an empty-string "word". This is kept deliberately. ``n_words`` is used
        as the model's class count, so changing it could make the vocabulary
        disagree with an existing checkpoint.
        """
        print("Reading vocabulary from {}".format(self.vocab_file))
        with open(self.vocab_file, 'r') as rf:
            for line in rf:
                fields = line.strip().split(' ')
                sen = ' '.join(fields[1:]) if len(fields) > 1 else fields[0]
                self.add_sentence(sen)
        print("Vocabulary size is {}".format(self.n_words))
def proc_nodes_module(checkpoint, AttrName):
    """Return ``checkpoint[AttrName]`` as an ``OrderedDict`` with a leading ``"module."`` removed from each key.

    Keys without that prefix are copied unchanged and the original key order
    is preserved. If two keys collapse to the same name, the later value wins.
    The prefix is presumably the one added when the model was saved while
    wrapped in ``nn.DataParallel`` -- confirm for the checkpoints in use.
    """
    prefix = "module."

    def _strip_prefix(key):
        return key[len(prefix):] if key.startswith(prefix) else key

    return OrderedDict(
        (_strip_prefix(key), value) for key, value in checkpoint[AttrName].items()
    )
def run_epoch(epoch_id, model, data_iter, loss_fn, device,
              opts, sum_writer, optimizer=None, print_every=20, is_training=True):
    """Run one profiled pass over ``data_iter``, training or evaluating ``model``.

    Each step runs inside ``torch.autograd.profiler.profile``, and its Chrome
    trace is written to ``prof/<step>_cuda_lstm.prof``.

    Args:
        epoch_id: index of the current epoch. It is combined with the step
            index to form the TensorBoard global step.
        model: module called as ``model(inputs)``. Its output size is unpacked
            as ``(out_len, batch_size, _)``, and the argmax over the last
            dimension is taken as the prediction. It must provide
            ``compute_wer(pred, pred_sizes, targets, target_sizes)``, which
            returns ``(errors, tokens)``.
        data_iter: sized iterable yielding
            ``(inputs, input_sizes, targets, target_sizes, utt_list)``.
        loss_fn: called as ``loss_fn(out, targets, input_sizes, target_sizes)``
            (presumably ``nn.CTCLoss`` -- confirm with the caller).
        device: device the batch tensors are moved to.
        opts: options object. ``opt_level`` and ``use_gpu`` are read when training.
        sum_writer: writer whose ``add_scalar`` logs the running training loss/WER.
        optimizer: optimizer to step. Required when ``is_training`` is true.
        print_every: not used by this implementation; kept for compatibility.
        is_training: when true, put the model in train mode and run backward
            plus ``optimizer.step()``. Otherwise evaluate only.

    Returns:
        ``(1 - errors / tokens, average_loss)``. The first element is NaN when
        no reference tokens were counted (e.g. an empty ``data_iter``).
    """
    if is_training:
        model.train()
    else:
        model.eval()
    # export_chrome_trace() below writes into prof/, which may not exist yet.
    os.makedirs('prof', exist_ok=True)

    batch_time = 0
    data_time = 0
    total_loss = 0
    total_tokens = 0
    total_errs = 0

    def running_wer():
        # WER is undefined until at least one reference token has been counted.
        return total_errs / total_tokens if total_tokens else float('nan')

    i = 0  # keeps `i + 1` meaningful after the loop when data_iter is empty
    steps_per_epoch = len(data_iter)
    end = time.time()
    # NOTE(review): this function was recovered from a copy with no
    # indentation. The nesting below is a reconstruction -- confirm it:
    #   - the model step runs inside the profiler;
    #   - TensorBoard logging happens inside the training branch.
    for i, data in enumerate(data_iter):
        data_time += (time.time() - end)
        global_step = epoch_id * steps_per_epoch + i
        inputs, input_sizes, targets, target_sizes, _ = data
        with torch.autograd.profiler.profile(record_shapes=True, use_cuda=True) as prof:
            # Build the autograd graph only when backward() is going to run.
            with torch.set_grad_enabled(is_training):
                inputs = inputs.to(device)
                input_sizes = input_sizes.to(device)
                targets = targets.to(device)
                target_sizes = target_sizes.to(device)
                out = model(inputs)
                out_len, batch_size, _ = out.size()
                # input_sizes is scaled by the output length (presumably it
                # holds per-utterance fractions of the padded length -- confirm).
                input_sizes = (input_sizes * out_len).long()
                loss = loss_fn(out, targets, input_sizes, target_sizes)
                loss /= batch_size
                total_loss += loss.item()
                _, index = torch.max(out, dim=-1)
                batch_errs, batch_tokens = model.compute_wer(
                    index.transpose(0, 1).cpu().numpy(), input_sizes.cpu().numpy(),
                    targets.cpu().numpy(), target_sizes.cpu().numpy())
                total_errs += batch_errs
                total_tokens += batch_tokens
                if is_training:
                    optimizer.zero_grad()
                    if opts.opt_level and opts.use_gpu:
                        # NOTE(review): `amp` (presumably apex.amp) is not
                        # imported explicitly in this file. Unless the star
                        # import of models.model_ctc provides it, this branch
                        # raises NameError.
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    sum_writer.add_scalar('Accuary/train/total_loss', total_loss / (i + 1), global_step)
                    sum_writer.add_scalar('Accuary/train/total_wer', running_wer(), global_step)
        # Exported after the profiler context has exited.
        prof.export_chrome_trace('prof/' + str(i) + "_cuda_lstm.prof")
        batch_time += (time.time() - end)
        if is_training:
            print('Epoch: [%d] [%d / %d], Time %.6f Data %.6f s, total_loss = %.5f s, total_wer = %.5f' % (
                epoch_id, i + 1, steps_per_epoch, batch_time / (i + 1), data_time / (i + 1),
                total_loss / (i + 1), running_wer()))
        end = time.time()
    average_loss = total_loss / (i + 1)
    return 1 - running_wer(), average_loss
class Config:
    """Mutable option container for the export script.

    Only two class-level defaults are declared here. ``main`` copies every key
    of the loaded YAML config onto an instance with ``setattr``, so all other
    options are attached dynamically.
    """

    # Defaults; overridden per instance when the config provides them.
    dropout = 0.1
    batch_size = 4
def seed_everything(seed):
    """Make runs reproducible by seeding every RNG the script touches.

    This does three things:
    - seeds Python's ``random``, NumPy and torch (the CPU generator, the
      current CUDA device and all CUDA devices);
    - exports ``PYTHONHASHSEED``;
    - switches cuDNN to deterministic, non-benchmarking mode.
    """
    # NOTE(review): the interpreter reads PYTHONHASHSEED at start-up, so this
    # presumably only affects child processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    # `cudnn` is the alias imported for torch.backends.cudnn, so setting the
    # flags through it configures torch.backends.cudnn itself.
    cudnn.deterministic = True
    cudnn.benchmark = False
def _config_literal(value):
    """Return the Python value of a CNN option read from the YAML config.

    String values are evaluated as Python expressions. Anything YAML already
    parsed into a non-string (a list, ``None``, ...) is returned unchanged.
    """
    # SECURITY(review): eval() executes arbitrary code taken from the config
    # file, so only load trusted configs. ast.literal_eval would be safer if
    # the configs only ever contain plain literals.
    if isinstance(value, str):
        return eval(value)
    return value


def _build_cnn_param(opts):
    """Assemble the ``cnn_param`` dict passed to ``CTC_Model`` from the CNN options."""
    channel = _config_literal(opts.channel)
    kernel_size = _config_literal(opts.kernel_size)
    stride = _config_literal(opts.stride)
    padding = _config_literal(opts.padding)
    pooling = _config_literal(opts.pooling)
    layers = []
    for layer in range(opts.layers):
        # One entry per CNN layer:
        # [channel, kernel_size, stride, padding, pooling (or None)].
        layers.append([
            channel[layer], kernel_size[layer], stride[layer], padding[layer],
            pooling[layer] if pooling is not None else None,
        ])
    return {
        'batch_norm': opts.batch_norm,
        'activate_function': supported_activate[opts.activation_function],
        'layer': layers,
    }


def main(conf, batchsize,
         checkpoint_path="./checkpoint/ctc_fbank_cnn/ctc_best_model.pth",
         seq_len=390, feat_dim=243):
    """Rebuild the CTC model from ``conf``, load its checkpoint and export it to ONNX.

    Args:
        conf: mapping loaded from the YAML config. Every key is copied onto a
            ``Config`` instance and printed.
        batchsize: batch dimension of the dummy input (anything ``int()`` accepts).
        checkpoint_path: checkpoint file whose ``'state_dict'`` entry is loaded.
        seq_len: second dimension of the dummy input used to trace the model.
        feat_dim: third dimension of the dummy input used to trace the model.

    Writes ``lstm_ctc_<batchsize>batch.onnx`` to the current directory.
    """
    batch_size = int(batchsize)

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    checkpoint['state_dict'] = proc_nodes_module(checkpoint, 'state_dict')

    opts = Config()
    for k, v in conf.items():
        setattr(opts, k, v)
        print('{:50}:{}'.format(k, v))
    if getattr(opts, 'seed', None) is not None:
        seed_everything(opts.seed)

    vocab = Vocab(opts.vocab_file)
    num_class = vocab.n_words
    rnn_param = {
        "rnn_input_size": opts.rnn_input_size,
        "rnn_hidden_size": opts.rnn_hidden_size,
        "rnn_layers": opts.rnn_layers,
        "rnn_type": supported_rnn[opts.rnn_type],
        "bidirectional": opts.bidirectional,
        "batch_norm": opts.batch_norm,
    }
    model = CTC_Model(add_cnn=opts.add_cnn, cnn_param=_build_cnn_param(opts),
                      rnn_param=rnn_param, num_class=num_class, drop_out=opts.drop_out)
    model = model.to('cpu')

    load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    # strict=False tolerates key mismatches, so report them. Any parameter the
    # checkpoint does not provide keeps its freshly initialised value in the
    # exported graph.
    missing = getattr(load_result, 'missing_keys', None) or []
    unexpected = getattr(load_result, 'unexpected_keys', None) or []
    if missing:
        print("Warning: parameters not found in checkpoint: {}".format(missing))
    if unexpected:
        print("Warning: checkpoint entries not used by the model: {}".format(unexpected))
    model.eval()

    dummy_input = torch.randn(batch_size, seq_len, feat_dim, device='cpu')
    output_file = "lstm_ctc_{}batch.onnx".format(str(batch_size))
    # No dynamic_axes are passed, so the dummy input's shape (including the
    # batch size) is fixed in the exported graph. The batch size is therefore
    # encoded in the file name.
    torch.onnx.export(model, dummy_input, output_file,
                      input_names=["actual_input_1"], output_names=["output1"],
                      opset_version=11)
if __name__ == '__main__':
    # SECURITY(review): this disables HTTPS certificate verification for the
    # whole process. Nothing visible in this file makes HTTPS requests, so it
    # may be removable -- confirm before deleting.
    ssl._create_default_https_context = ssl._create_unverified_context
    args = parser.parse_args()
    batchsize = args.batchsize
    config_path = args.conf
    try:
        with open(config_path, 'r') as config_file:
            conf = yaml.safe_load(config_file)
    except (OSError, UnicodeDecodeError, yaml.YAMLError):
        conf = None
    # safe_load returns None for an empty file; main() needs a mapping.
    if not isinstance(conf, dict):
        print("No input config or config file missing, please check.")
        sys.exit(1)
    main(conf, batchsize)