@@ -1,3 +1,18 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
import math
from typing import List, Union
@@ -6,7 +21,8 @@
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
-from torch.cuda.amp import autocast
+#from torch.cuda.amp import autocast
+
from torch.nn import CTCLoss
from deepspeech_pytorch.configs.train_config import SpectConfig, BiDirectionalConfig, OptimConfig, AdamConfig, \
@@ -57,6 +73,7 @@
:return: Masked output from the module
"""
for module in self.seq_module:
+ # lengths = x[3]
x = module(x)
mask = torch.BoolTensor(x.size()).fill_(0)
if x.is_cuda:
@@ -211,8 +228,13 @@
target_decoder=self.evaluation_decoder
)
- def forward(self, x, lengths):
+ def forward(self, x):
+
+ lengths = x[1]
+ x = x[0]
+
lengths = lengths.cpu().int()
+ lengths.cpu().int()
output_lengths = self.get_seq_lens(lengths)
x, _ = self.conv(x, output_lengths)
@@ -222,7 +244,6 @@
for rnn in self.rnns:
x = rnn(x, output_lengths)
-
if not self.bidirectional: # no need for lookahead layer in bidirectional
x = self.lookahead(x)
@@ -246,8 +267,8 @@
inputs, targets, input_percentages, target_sizes = batch
input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
inputs = inputs.to(self.device)
- with autocast(enabled=self.precision == 16):
- out, output_sizes = self(inputs, input_sizes)
+# with autocast(enabled=self.precision == 16):
+ out, output_sizes = self(inputs, input_sizes)
decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes)
self.wer(
preds=out,