diff --git a/PyTorch/Translation/GNMT/seq2seq/models/encoder.py b/PyTorch/Translation/GNMT/seq2seq/models/encoder.py
index e9a234e7..c4f7037f 100644
--- a/PyTorch/Translation/GNMT/seq2seq/models/encoder.py
+++ b/PyTorch/Translation/GNMT/seq2seq/models/encoder.py
@@ -97,7 +97,7 @@ class ResidualRecurrentEncoder(nn.Module):
 
         # bidirectional layer
         x = self.dropout(x)
-        x = pack_padded_sequence(x, lengths.cpu().numpy(),
+        x = pack_padded_sequence(x, lengths.cpu(),
                                  batch_first=self.batch_first)
         x, _ = self.rnn_layers[0](x)
         x, _ = pad_packed_sequence(x, batch_first=self.batch_first)
diff --git a/PyTorch/Translation/GNMT/seq2seq/models/gnmt.py b/PyTorch/Translation/GNMT/seq2seq/models/gnmt.py
index 5643852c..bc251e96 100644
--- a/PyTorch/Translation/GNMT/seq2seq/models/gnmt.py
+++ b/PyTorch/Translation/GNMT/seq2seq/models/gnmt.py
@@ -19,7 +19,9 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import torch
 import torch.nn as nn
+from torch.nn.functional import log_softmax
 
 import seq2seq.data.config as config
 from seq2seq.models.decoder import ResidualRecurrentDecoder
@@ -32,7 +34,7 @@ class GNMT(Seq2Seq):
     GNMT v2 model
     """
     def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2,
-                 batch_first=False, share_embedding=True):
+                 batch_first=False, share_embedding=True, max_seq_len=6):
         """
         Constructor for the GNMT v2 model.
 
@@ -49,6 +51,8 @@ class GNMT(Seq2Seq):
 
         super(GNMT, self).__init__(batch_first=batch_first)
 
+        self.max_seq_len = max_seq_len
+
         if share_embedding:
             embedder = nn.Embedding(vocab_size, hidden_size,
                                     padding_idx=config.PAD)
@@ -66,7 +70,22 @@ class GNMT(Seq2Seq):
 
     def forward(self, input_encoder, input_enc_len, input_decoder):
         context = self.encode(input_encoder, input_enc_len)
-        context = (context, input_enc_len, None)
-        output, _, _ = self.decode(input_decoder, context)
+        context = [context, input_enc_len, None]
+
+        device = input_encoder.device
+        translation = torch.zeros([input_encoder.shape[0], 1], dtype=torch.int32, device=device)
+
+        translation += config.BOS
+        words = input_decoder
+        word_view = (-1, 1)
+
+        for idx in range(1, self.max_seq_len):
+            words = words.view(word_view)
+
+            logits, _, context = self.decode(words, context, True)
+            logprobs = log_softmax(logits, dim=-1)
+            words = torch.argmax(logprobs, dim=-1, keepdim=False).int()
+
+            translation = torch.cat((translation, words), dim=1)
 
-        return output
+        return translation