.../SpeechRecognition/Jasper/common/features.py | 5 ++---
PyTorch/SpeechRecognition/Jasper/common/helpers.py | 8 ++++----
PyTorch/SpeechRecognition/Jasper/jasper/model.py | 14 +++++++++++---
3 files changed, 17 insertions(+), 10 deletions(-)
@@ -5,7 +5,7 @@ import librosa
import torch
import torch.nn as nn
-from apex import amp
+# from apex import amp
class BaseFeatures(nn.Module):
@@ -46,8 +46,7 @@ class BaseFeatures(nn.Module):
dtype = audio.dtype
audio = audio.float()
if optim_level == 1:
- with amp.disable_casts():
- feat, feat_lens = self.calculate_features(audio, audio_lens)
+ pass
else:
feat, feat_lens = self.calculate_features(audio, audio_lens)
@@ -17,7 +17,7 @@ import os
import re
from collections import OrderedDict
-from apex import amp
+# from apex import amp
import torch
import torch.distributed as dist
@@ -234,7 +234,7 @@ class Checkpointer(object):
'state_dict': unwrap_ddp(model).state_dict(),
'ema_state_dict': unwrap_ddp(ema_model).state_dict() if ema_model is not None else None,
'optimizer': optimizer.state_dict(),
- 'amp': amp.state_dict() if self.use_amp else None,
+ 'amp': None,
}
if is_best:
@@ -293,8 +293,8 @@ class Checkpointer(object):
optimizer.load_state_dict(checkpoint['optimizer'])
- if self.use_amp:
- amp.load_state_dict(checkpoint['amp'])
+ # if self.use_amp:
+ # amp.load_state_dict(checkpoint['amp'])
meta['start_epoch'] = checkpoint.get('epoch')
meta['best_wer'] = checkpoint.get('best_wer', meta['best_wer'])
@@ -66,14 +66,22 @@ class MaskedConv1d(nn.Conv1d):
self.masked = masked
def get_seq_len(self, lens):
- return ((lens + 2 * self.padding[0] - self.dilation[0]
- * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1)
+ if torch.onnx.is_in_onnx_export():
+ return ((lens + 2. * self.padding[0] - self.dilation[0]
+ * (self.kernel_size[0] - 1.) - 1.) // self.stride[0] + 1.).int()
+ else:
+ return ((lens + 2 * self.padding[0] - self.dilation[0]
+ * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1)
def forward(self, x, x_lens=None):
if self.masked:
max_len = x.size(2)
idxs = torch.arange(max_len, dtype=x_lens.dtype, device=x_lens.device)
- mask = idxs.expand(x_lens.size(0), max_len) >= x_lens.unsqueeze(1)
+ if torch.onnx.is_in_onnx_export():
+ temp = torch.zeros(x_lens.size(0), max_len)
+ mask = idxs.expand_as(temp) >= x_lens.unsqueeze(1)
+ else:
+ mask = idxs.expand(x_lens.size(0), max_len) >= x_lens.unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
x_lens = self.get_seq_len(x_lens)
--