@@ -97,7 +97,7 @@ class ResidualRecurrentEncoder(nn.Module):
# bidirectional layer
x = self.dropout(x)
- x = pack_padded_sequence(x, lengths.cpu().numpy(),
+ x = pack_padded_sequence(x, lengths.cpu(),
batch_first=self.batch_first)
x, _ = self.rnn_layers[0](x)
x, _ = pad_packed_sequence(x, batch_first=self.batch_first)
@@ -19,7 +19,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import torch
import torch.nn as nn
+from torch.nn.functional import log_softmax
import seq2seq.data.config as config
from seq2seq.models.decoder import ResidualRecurrentDecoder
@@ -32,7 +34,7 @@ class GNMT(Seq2Seq):
GNMT v2 model
"""
def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2,
- batch_first=False, share_embedding=True):
+ batch_first=False, share_embedding=True, max_seq_len=6):
"""
Constructor for the GNMT v2 model.
@@ -49,6 +51,8 @@ class GNMT(Seq2Seq):
super(GNMT, self).__init__(batch_first=batch_first)
+ self.max_seq_len = max_seq_len
+
if share_embedding:
embedder = nn.Embedding(vocab_size, hidden_size,
padding_idx=config.PAD)
@@ -66,7 +70,22 @@ class GNMT(Seq2Seq):
def forward(self, input_encoder, input_enc_len, input_decoder):
context = self.encode(input_encoder, input_enc_len)
- context = (context, input_enc_len, None)
- output, _, _ = self.decode(input_decoder, context)
+ context = [context, input_enc_len, None]
+ # Greedy decoding: build `translation` step by step instead of returning decoder logits.
+ device = input_encoder.device
+ translation = torch.zeros([input_encoder.shape[0], 1], dtype=torch.int32, device=device)
+ # zeros + config.BOS: every row of the (shape[0], 1) int32 buffer starts with config.BOS.
+ translation += config.BOS
+ words = input_decoder  # the first decode step is seeded with input_decoder
+ word_view = (-1, 1)
+ # NOTE(review): shape[0] above and dim=1 below presumably assume batch-first input - confirm.
+ for idx in range(1, self.max_seq_len):  # idx is unused; runs max_seq_len - 1 steps
+ words = words.view(word_view)  # reshape to a single column
+ # decode() yields (logits, _, context); the returned context feeds the next step.
+ logits, _, context = self.decode(words, context, True)  # NOTE(review): True is presumably an inference flag - verify
+ logprobs = log_softmax(logits, dim=-1)  # NOTE(review): argmax below should match argmax of raw logits - confirm
+ words = torch.argmax(logprobs, dim=-1, keepdim=False).int()
+ # Append this step's argmax indices (cast by .int()) to translation along dim=1.
+ translation = torch.cat((translation, words), dim=1)
- return output
+ return translation