diff --git a/pytorch/data_utils.py b/pytorch/data_utils.py
index df762a7..1642018 100644
--- a/pytorch/data_utils.py
+++ b/pytorch/data_utils.py
@@ -1,12 +1,11 @@
 import os, sys
 import glob
-
-from collections import Counter, OrderedDict
 import numpy as np
 import torch
 
 from utils.vocabulary import Vocab
 
+
 class LMOrderedIterator(object):
     def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
         """
@@ -95,8 +94,6 @@ class LMShuffledIterator(object):
         n_retain = 0
 
         while True:
-            # data   : [n_retain+bptt x bsz]
-            # target : [bptt x bsz]
             data[n_retain:].fill_(-1)
             target.fill_(-1)
 
@@ -214,7 +211,7 @@ class Corpus(object):
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
             self.test  = self.vocab.encode_file(
-                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
+                os.path.join(path, 'test.py.txt'), ordered=False, add_double_eos=True)
 
     def get_iterator(self, split, *args, **kwargs):
         if split == 'train':
@@ -223,7 +220,7 @@ class Corpus(object):
             elif self.dataset == 'lm1b':
                 kwargs['shuffle'] = True
                 data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
-        elif split in ['valid', 'test']:
+        elif split in ['valid', 'test', 'onnx']:
             data = self.valid if split == 'valid' else self.test
             if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                 data_iter = LMOrderedIterator(data, *args, **kwargs)
@@ -259,12 +256,13 @@ def get_lm_corpus(datadir, dataset):
 
     return corpus
 
+
 if __name__ == '__main__':
     import argparse
-    parser = argparse.ArgumentParser(description='unit test')
-    parser.add_argument('--datadir', type=str, default='../data/text8',
+    parser = argparse.ArgumentParser(description='unit test.py')
+    parser.add_argument('--datadir', type=str, default='../data/enwik8',
                         help='location of the data corpus')
-    parser.add_argument('--dataset', type=str, default='text8',
+    parser.add_argument('--dataset', type=str, default='enwik8',
                         choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
                         help='dataset name')
     args = parser.parse_args()
diff --git a/pytorch/mem_transformer.py b/pytorch/mem_transformer.py
index 45147df..baf1bc0 100644
--- a/pytorch/mem_transformer.py
+++ b/pytorch/mem_transformer.py
@@ -9,22 +9,126 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 sys.path.append('utils')
-from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
-from log_uniform_sampler import LogUniformSampler, sample_logits
+from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
+from utils.log_uniform_sampler import LogUniformSampler, sample_logits
+import pdb
+import traceback
+pdb.set_trace = lambda:1
+
+
+def einsum1(eq, opand):
+    """
+    ibnd,jbnd->ijbn
+    """
+    tmp = torch.einsum(eq, opand)
+    if not torch.onnx.is_in_onnx_export():
+        return tmp
+    try:
+        x0, y0 = opand
+        assert len(x0.shape) == 4
+        assert x0.shape[1:] == y0.shape[1:], "bad shape {}, {}".format(x0.shape, y0.shape)
+        i, b, n, d = x0.shape
+        j, _, _, _ = y0.shape
+        x = x0.clone()
+        y = y0.clone()
+        x = x.reshape(i, b * n, d).permute(1, 0, 2)
+        y = y.reshape(j, b * n, d).permute(1, 2, 0)
+        z = torch.bmm(x, y)
+        z = z.permute(1, 2, 0).reshape(i, j, b, n)
+        assert tmp.equal(z)
+        return z
+    except Exception as e:
+        #print('str(e):\t\t', str(e))
+        #print('repr(e):\t', repr(e))
+        #print('traceback.print_exc():', traceback.print_exc())
+        #print('traceback.format_exc():\n%s' % traceback.format_exc())
+        pdb.set_trace()
+
+
+def einsum2(eq, opand):
+    """
+    ibnd,jnd->ijbn
+    """
+    tmp = torch.einsum(eq, opand)
+    if not torch.onnx.is_in_onnx_export():
+        return tmp
+    try:
+        x0, y0 = opand
+        assert len(y0.shape) == 3
+        i, b, n, d = x0.shape
+        j, _, _ = y0.shape
+        x = x0.clone()
+        y = y0.clone()
+        y = y.unsqueeze(1).expand(j, b, n, d)
+        x = x.reshape(i, b * n, d).permute(1, 0, 2)
+        y = y.reshape(j, b * n, d).permute(1, 2, 0)
+        z = torch.bmm(x, y)
+        z = z.permute(1, 2, 0).reshape(i, j, b, n)
+        assert tmp.equal(z)
+        return z
+    except:
+        pdb.set_trace()
+
+
+def einsum3(eq, opand):
+    """
+    ijbn,jbnd->ibnd
+    """
+    tmp = torch.einsum(eq, opand)
+    if not torch.onnx.is_in_onnx_export():
+        return tmp
+    try:
+        x0, y0 = opand
+        assert len(x0.shape) == 4
+        assert x0.shape[2:] == y0.shape[1:3], "bad shape {}, {}".format(x0.shape, y0.shape)
+        i, j, b, n = x0.shape
+        j, _, _, d = y0.shape
+        x = x0.clone()
+        y = y0.clone()
+        x = x.reshape(i, j, b * n).permute(2, 0, 1)
+        y = y.reshape(j, b * n, d).permute(1, 0, 2)
+        z = torch.bmm(x, y)
+        z = z.permute(1, 0, 2).reshape(i, b, n, d)
+        assert tmp.equal(z)
+        return z
+    except:
+        pdb.set_trace()
+
+
+def triu_onnx(x, diagonal=0):
+    assert len(x.shape) == 2
+    m, l = x.shape
+    mask = torch.arange(l, device=x.device).expand(m, l)
+    arange = torch.arange(m, device=x.device)
+    arange = arange.unsqueeze(-1)
+    if diagonal:
+        arange = arange + diagonal
+    mask = mask >= arange
+    return x.masked_fill(mask==0, 0)
+
+
+def tril_onnx(x, diagonal=0):
+    return x - triu_onnx(x, diagonal)
+
+
+torch.triu = triu_onnx
+torch.tril = tril_onnx
+
+
+
 
 class PositionalEmbedding(nn.Module):
     def __init__(self, demb):
         super(PositionalEmbedding, self).__init__()
-
+        
         self.demb = demb
-
         inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
         self.register_buffer('inv_freq', inv_freq)
 
     def forward(self, pos_seq, bsz=None):
+        # print(torch.cuda.synchronize(), "打点1")
         sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
         pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
-
         if bsz is not None:
             return pos_emb[:,None,:].expand(-1, bsz, -1)
         else:
@@ -110,20 +214,20 @@ class MultiHeadAttn(nn.Module):
         head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
 
         # [qlen x klen x bsz x n_head]
-        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
+        attn_score = einsum1('ibnd,jbnd->ijbn', (head_q, head_k))
         attn_score.mul_(self.scale)
         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
             elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
 
         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
         attn_prob = self.dropatt(attn_prob)
 
         # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
+        attn_vec = einsum3('ijbn,jbnd->ibnd', (attn_prob, head_v))
         attn_vec = attn_vec.contiguous().view(
             attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
 
@@ -198,7 +302,8 @@ class RelMultiHeadAttn(nn.Module):
 
         x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
 
-        x = x_padded[1:].view_as(x)
+        #x = x_padded[1:].view_as(x)
+        x = x_padded[1:].view(x.shape)
 
         if zero_triu:
             ones = torch.ones((x.size(0), x.size(1)))
@@ -212,12 +317,13 @@ class RelMultiHeadAttn(nn.Module):
 class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
     def __init__(self, *args, **kwargs):
         super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
-
+        
         self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
 
     def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
         qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
-
+        
+        pdb.set_trace()
         if mems is not None:
             cat = torch.cat([mems, w], 0)
             if self.pre_lnorm:
@@ -247,31 +353,35 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
 
         #### compute attention score
         rw_head_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head
-        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
+        AC = einsum1('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
 
         rr_head_q = w_head_q + r_r_bias
-        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x klen x bsz x n_head
+        BD = einsum2('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x klen x bsz x n_head
         BD = self._rel_shift(BD)
 
         # [qlen x klen x bsz x n_head]
         attn_score = AC + BD
         attn_score.mul_(self.scale)
 
+
         #### compute attention probability
+        ##############################################################################################################################33
+        # edit this for Warning
+
         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
-                attn_score = attn_score.float().masked_fill(
-                    attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
+                attn_score = attn_score.float().masked_fill(attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
             elif attn_mask.dim() == 3:
-                attn_score = attn_score.float().masked_fill(
-                    attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
+                attn_score = attn_score.float().masked_fill(attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)
+
+        ################################################################################################################################
 
         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
         attn_prob = self.dropatt(attn_prob)
 
         #### compute attention vector
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
+        attn_vec = einsum3('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
 
         # [qlen x bsz x n_head x d_head]
         attn_vec = attn_vec.contiguous().view(
@@ -335,8 +445,8 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         #### compute attention score
         rw_head_q = w_head_q + r_w_bias[None]                                   # qlen x bsz x n_head x d_head
 
-        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
-        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))                  # qlen x klen x bsz x n_head
+        AC = einsum1('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
+        B_ = einsum2('ibnd,jnd->ijbn', (w_head_q, r_emb))                  # qlen x klen x bsz x n_head
         D_ = r_bias[None, :, None]                                              # 1    x klen x 1   x n_head
         BD = self._rel_shift(B_ + D_)
 
@@ -347,16 +457,16 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         #### compute attention probability
         if attn_mask is not None and attn_mask.any().item():
             if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
             elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
+                attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
 
         # [qlen x klen x bsz x n_head]
         attn_prob = F.softmax(attn_score, dim=1)
         attn_prob = self.dropatt(attn_prob)
 
         #### compute attention vector
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
+        attn_vec = einsum3('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
 
         # [qlen x bsz x n_head x d_head]
         attn_vec = attn_vec.contiguous().view(
@@ -384,7 +494,7 @@ class DecoderLayer(nn.Module):
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
-
+      
         output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                mems=mems)
         output = self.pos_ff(output)
@@ -402,7 +512,7 @@ class RelLearnableDecoderLayer(nn.Module):
                                      pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
-
+        
         output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
@@ -415,13 +525,14 @@ class RelPartialLearnableDecoderLayer(nn.Module):
                  **kwargs):
         super(RelPartialLearnableDecoderLayer, self).__init__()
 
+          
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                             d_head, dropout, **kwargs)
+       
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, 
                                      pre_lnorm=kwargs.get('pre_lnorm'))
-
+      
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
-
         output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                                attn_mask=dec_attn_mask,
                                mems=mems)
@@ -434,7 +545,7 @@ class AdaptiveEmbedding(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 
                  sample_softmax=False):
         super(AdaptiveEmbedding, self).__init__()
-
+        
         self.n_token = n_token
         self.d_embed = d_embed
 
@@ -462,6 +573,7 @@ class AdaptiveEmbedding(nn.Module):
                 self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
 
     def forward(self, inp):
+        
         if self.div_val == 1:
             embed = self.emb_layers[0](inp)
             if self.d_proj != self.d_embed:
@@ -492,7 +604,8 @@ class AdaptiveEmbedding(nn.Module):
 
         return embed
 
-class MemTransformerLM(nn.Module):
+
+class MemTransformerLM(nn.Module):    # 打点11,model入口
     def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                  dropout, dropatt, tie_weight=True, d_embed=None, 
                  div_val=1, tie_projs=[False], pre_lnorm=False,
@@ -509,8 +622,10 @@ class MemTransformerLM(nn.Module):
         self.n_head = n_head
         self.d_head = d_head
 
+
         self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, 
                                           div_val=div_val)
+      
 
         self.drop = nn.Dropout(dropout)
 
@@ -524,7 +639,7 @@ class MemTransformerLM(nn.Module):
         self.attn_type = attn_type
 
         self.layers = nn.ModuleList()
-        if attn_type == 0: # the default attention
+        if attn_type == 0: # the default attention     # 执行           
             for i in range(n_layer):
                 self.layers.append(
                     RelPartialLearnableDecoderLayer(
@@ -532,7 +647,8 @@ class MemTransformerLM(nn.Module):
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm)
                 )
-        elif attn_type == 1: # learnable embeddings
+           
+        elif attn_type == 1: # learnable embeddings     # 未执行
             for i in range(n_layer):
                 self.layers.append(
                     RelLearnableDecoderLayer(
@@ -540,7 +656,7 @@ class MemTransformerLM(nn.Module):
                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                         dropatt=dropatt, pre_lnorm=pre_lnorm)
                 )
-        elif attn_type in [2, 3]: # absolute embeddings
+        elif attn_type in [2, 3]: # absolute embeddings     # 未执行
             for i in range(n_layer):
                 self.layers.append(
                     DecoderLayer(
@@ -550,7 +666,7 @@ class MemTransformerLM(nn.Module):
 
         self.sample_softmax = sample_softmax
         # use sampled softmax
-        if sample_softmax > 0:
+        if sample_softmax > 0:    # 不执行
             self.out_layer = nn.Linear(d_model, n_token)
             if tie_weight:
                 self.out_layer.weight = self.word_emb.weight
@@ -558,10 +674,10 @@ class MemTransformerLM(nn.Module):
             self.sampler = LogUniformSampler(n_token, sample_softmax)
 
         # use adaptive softmax (including standard softmax)
-        else:
+        else:     # 执行
             self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, 
                                                     cutoffs, div_val=div_val)
-
+            
             if tie_weight:
                 for i in range(len(self.crit.out_layers)):
                     self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
@@ -582,8 +698,8 @@ class MemTransformerLM(nn.Module):
         self.sample_softmax = -1
 
     def _create_params(self):
-        if self.attn_type == 0: # default attention
-            self.pos_emb = PositionalEmbedding(self.d_model)
+        if self.attn_type == 0: # default attention     # 执行           
+            self.pos_emb = PositionalEmbedding(self.d_model)           
             self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
             self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
         elif self.attn_type == 1: # learnable
@@ -598,8 +714,8 @@ class MemTransformerLM(nn.Module):
         elif self.attn_type == 3: # absolute deeper SA
             self.r_emb = nn.Parameter(torch.Tensor(
                     self.n_layer, self.max_klen, self.n_head, self.d_head))
-
-    def reset_length(self, tgt_len, ext_len, mem_len):
+        
+    def reset_length(self, tgt_len, ext_len, mem_len):        
         self.tgt_len = tgt_len
         self.mem_len = mem_len
         self.ext_len = ext_len
@@ -619,7 +735,7 @@ class MemTransformerLM(nn.Module):
     def _update_mems(self, hids, mems, qlen, mlen):
         # does not deal with None
         if mems is None: return None
-
+        
         # mems is not None
         assert len(hids) == len(mems), 'len(hids) != len(mems)'
 
@@ -633,7 +749,6 @@ class MemTransformerLM(nn.Module):
             end_idx = mlen + max(0, qlen - 0 - self.ext_len)
             beg_idx = max(0, end_idx - self.mem_len)
             for i in range(len(hids)):
-
                 cat = torch.cat([mems[i], hids[i]], dim=0)
                 new_mems.append(cat[beg_idx:end_idx].detach())
 
@@ -641,9 +756,7 @@ class MemTransformerLM(nn.Module):
 
     def _forward(self, dec_inp, mems=None):
         qlen, bsz = dec_inp.size()
-
         word_emb = self.word_emb(dec_inp)
-
         mlen = mems[0].size(0) if mems is not None else 0
         klen = mlen + qlen
         if self.same_length:
@@ -663,19 +776,21 @@ class MemTransformerLM(nn.Module):
         if self.attn_type == 0: # default
             pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, 
                                    dtype=word_emb.dtype)
+
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)
             pos_emb = self.pos_emb(pos_seq)
 
             core_out = self.drop(word_emb)
             pos_emb = self.drop(pos_emb)
-
-            hids.append(core_out)
-            for i, layer in enumerate(self.layers):
+        
+            hids.append(core_out) 
+            for i, layer in enumerate(self.layers):               
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, pos_emb, self.r_w_bias,
                         self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
+            
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
             hids.append(core_out)
@@ -727,7 +842,7 @@ class MemTransformerLM(nn.Module):
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
                 hids.append(core_out)
-
+        pdb.set_trace()
         core_out = self.drop(core_out)
 
         new_mems = self._update_mems(hids, mems, mlen, qlen)
@@ -735,13 +850,18 @@ class MemTransformerLM(nn.Module):
         return core_out, new_mems
 
     def forward(self, data, target, *mems):
+        
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
+        
         if not mems: mems = self.init_mems()
-
-        tgt_len = target.size(0)
+        
+        if torch.onnx.is_in_onnx_export():
+            tgt_len = target.size(0).numpy()
+        else:
+            tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
 
         pred_hid = hidden[-tgt_len:]
@@ -751,8 +871,11 @@ class MemTransformerLM(nn.Module):
                 self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
-            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
-            loss = loss.view(tgt_len, -1)
+            if torch.onnx.is_in_onnx_export():
+                loss = self.crit(pred_hid.reshape(-1, pred_hid.size(-1)), target.reshape(-1))
+            else:
+                loss = self.crit(pred_hid.reshape(-1, pred_hid.size(-1)), target.reshape(-1))
+                loss = loss.reshape(tgt_len, -1)
 
         if new_mems is None:
             return [loss]
@@ -762,7 +885,7 @@ class MemTransformerLM(nn.Module):
 if __name__ == '__main__':
     import argparse
 
-    parser = argparse.ArgumentParser(description='unit test')
+    parser = argparse.ArgumentParser(description='unit test.py')
 
     parser.add_argument('--n_layer', type=int, default=4, help='')
     parser.add_argument('--n_rel_layer', type=int, default=4, help='')
@@ -774,7 +897,7 @@ if __name__ == '__main__':
     parser.add_argument('--dropout', type=float, default=0.0, help='')
     parser.add_argument('--cuda', action='store_true', help='')
     parser.add_argument('--seed', type=int, default=1111, help='')
-    parser.add_argument('--multi_gpu', action='store_true', help='')
+    parser.add_argument('--multi_gpu',  action='store_true', help='')
 
     args = parser.parse_args()
 
@@ -801,12 +924,12 @@ if __name__ == '__main__':
                             d_embed=d_embed, div_val=div_val, 
                             tie_projs=tie_projs, pre_lnorm=True,
                             tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 
-                            cutoffs=cutoffs, attn_type=0).to(device)
-
-            print(sum(p.numel() for p in model.parameters()))
+                            cutoffs=cutoffs, attn_type=0)
+    
+            #print(sum(p.numel() for p in model.parameters()))
 
             mems = tuple()
             for idx, (inp, tgt, seqlen) in enumerate(diter):
-                print('batch {}'.format(idx))
+                #print('batch {}'.format(idx))
                 out = model(inp, tgt, *mems)
                 mems = out[1:]
diff --git a/pytorch/utils/adaptive_softmax.py b/pytorch/utils/adaptive_softmax.py
index 68ae016..68c59f8 100644
--- a/pytorch/utils/adaptive_softmax.py
+++ b/pytorch/utils/adaptive_softmax.py
@@ -67,6 +67,10 @@ class AdaptiveLogSoftmax(nn.Module):
             head_logprob_i = head_logprob.index_select(0, indices_i)
 
             if i == 0:
+
+                print(f'target_i[:,None]: {target_i[:, None]}')
+                print(f'target_i[:,None].shape: {target_i[:, None].shape}')
+
                 logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
             else:
                 weight_i = weight[l_idx:h_idx]
@@ -77,6 +81,8 @@ class AdaptiveLogSoftmax(nn.Module):
                 tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
                 tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
 
+                print(f'target_i[:,None]: {target_i[:, None]}')
+                print(f'target_i[:,None].shape: {target_i[:, None].shape}')
                 logprob_i = head_logprob_i[:, -i] \
                           + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
 
diff --git a/pytorch/utils/data_parallel.py b/pytorch/utils/data_parallel.py
index d7e1811..dd4041e 100644
--- a/pytorch/utils/data_parallel.py
+++ b/pytorch/utils/data_parallel.py
@@ -68,6 +68,12 @@ class BalancedDataParallel(DataParallel):
         if self.gpu0_bsz == 0:
             replicas = replicas[1:]
         outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
+
+        print(f'outputs: {outputs}')
+        print(f'type(outputs): {type(outputs)}')
+        print(f'len(outputs): {len(outputs)}')
+        print(f'self.output_device: {self.output_device}')
+
         return self.gather(outputs, self.output_device)
 
     def parallel_apply(self, replicas, device_ids, inputs, kwargs):
diff --git a/pytorch/utils/exp_utils.py b/pytorch/utils/exp_utils.py
index e44f7c2..568c3e6 100644
--- a/pytorch/utils/exp_utils.py
+++ b/pytorch/utils/exp_utils.py
@@ -1,8 +1,6 @@
 import functools
-import os, shutil
-
-import numpy as np
-
+import os
+import shutil
 import torch
 
 
@@ -10,31 +8,18 @@ def logging(s, log_path, print_=True, log_=True):
     if print_:
         print(s)
     if log_:
-        with open(log_path, 'a+') as f_log:
+        with open(log_path, 'a+') as f_log:   
             f_log.write(s + '\n')
 
-def get_logger(log_path, **kwargs):
-    return functools.partial(logging, log_path=log_path, **kwargs)
 
-def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
-    if debug:
-        print('Debug Mode : no experiment dir created')
-        return functools.partial(logging, log_path=None, log_=False)
-
-    if not os.path.exists(dir_path):
-        os.makedirs(dir_path)
+def get_logger(log_path, **kwargs):   
+    return functools.partial(logging, log_path=log_path, **kwargs)  
 
+def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
     print('Experiment dir : {}'.format(dir_path))
-    if scripts_to_save is not None:
-        script_path = os.path.join(dir_path, 'scripts')
-        if not os.path.exists(script_path):
-            os.makedirs(script_path)
-        for script in scripts_to_save:
-            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
-            shutil.copyfile(script, dst_file)
+    return get_logger('log.txt')
 
-    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
 
 def save_checkpoint(model, optimizer, path, epoch):
-    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
-    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
+    torch.save(model, 'model_{}.pt'.format(epoch))
+    torch.save(optimizer.state_dict(), 'optimizer_{}.pt'.format(epoch))
diff --git a/pytorch/utils/log_uniform_sampler.py b/pytorch/utils/log_uniform_sampler.py
index 503f635..e1a631c 100644
--- a/pytorch/utils/log_uniform_sampler.py
+++ b/pytorch/utils/log_uniform_sampler.py
@@ -78,41 +78,6 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
     return logits
 
 
-# class LogUniformSampler(object):
-#     def __init__(self, range_max, unique=False):
-#         """
-#         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
-#             `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
-#         """
-#         self.range_max = range_max
-#         log_indices = torch.arange(1., range_max+2., 1.).log_()
-#         self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
-
-#         self.unique = unique
-
-#         if self.unique:
-#             self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
-
-#     def sample(self, n_sample, labels):
-#         pos_sample, new_labels = labels.unique(return_inverse=True)
-#         n_pos_sample = pos_sample.size(0)
-#         n_neg_sample = n_sample - n_pos_sample
-
-#         if self.unique:
-#             self.exclude_mask.index_fill_(0, pos_sample, 1)
-#             sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
-#             self.exclude_mask.index_fill_(0, pos_sample, 0)
-#         else:
-#             sample_dist = self.dist
-
-#         neg_sample = torch.multinomial(sample_dist, n_neg_sample)
-
-#         sample = torch.cat([pos_sample, neg_sample])
-#         sample_prob = self.dist[sample]
-
-#         return new_labels, sample, sample_prob
-
-
 if __name__ == '__main__':
     S, B = 3, 4
     n_vocab = 10000
@@ -121,20 +86,7 @@ if __name__ == '__main__':
 
     labels = torch.LongTensor(S, B).random_(0, n_vocab)
 
-    # sampler = LogUniformSampler(n_vocab, unique=False)
-    # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
-
     sampler = LogUniformSampler(n_vocab, unique=True)
-    # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
-
-    # print('true_probs', true_probs.numpy().tolist())
-    # print('samp_probs', samp_probs.numpy().tolist())
-    # print('neg_samples', neg_samples.numpy().tolist())
-
-    # print('sum', torch.sum(sampler.dist).item())
-
-    # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
-
     embedding = nn.Embedding(n_vocab, H)
     bias = torch.zeros(n_vocab)
     inputs = torch.Tensor(S, B, H).normal_()
diff --git a/pytorch/utils/proj_adaptive_softmax.py b/pytorch/utils/proj_adaptive_softmax.py
index a0fbfeb..941dc46 100644
--- a/pytorch/utils/proj_adaptive_softmax.py
+++ b/pytorch/utils/proj_adaptive_softmax.py
@@ -9,11 +9,11 @@ import torch.nn.functional as F
 CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
 CUDA_MINOR = int(torch.version.cuda.split('.')[1])
 
+
 class ProjectedAdaptiveLogSoftmax(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  keep_order=False):
         super(ProjectedAdaptiveLogSoftmax, self).__init__()
-
         self.n_token = n_token
         self.d_embed = d_embed
         self.d_proj = d_proj
@@ -83,8 +83,10 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                         self.out_layers[0].bias, self.out_projs[0])
+
             nll = -F.log_softmax(logit, dim=-1) \
                     .gather(1, target.unsqueeze(1)).squeeze(1)
+
         else:
             # construct weights and biases
             weights, biases = [], []
diff --git a/pytorch/utils/vocabulary.py b/pytorch/utils/vocabulary.py
index b6b8249..2cb2091 100644
--- a/pytorch/utils/vocabulary.py
+++ b/pytorch/utils/vocabulary.py
@@ -1,8 +1,8 @@
 import os
 from collections import Counter, OrderedDict
-
 import torch
 
+
 class Vocab(object):
     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
                  delimiter=None, vocab_file=None):
@@ -14,6 +14,7 @@ class Vocab(object):
         self.delimiter = delimiter
         self.vocab_file = vocab_file
 
+
     def tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case
@@ -33,10 +34,13 @@ class Vocab(object):
         else:
             return symbols
 
+
+
     def count_file(self, path, verbose=False, add_eos=False):
         if verbose: print('counting file {} ...'.format(path))
         assert os.path.exists(path)
 
+        # if not verbose
         sents = []
         with open(path, 'r', encoding='utf-8') as f:
             for idx, line in enumerate(f):
@@ -48,6 +52,7 @@ class Vocab(object):
 
         return sents
 
+
     def count_sents(self, sents, verbose=False):
         """
             sents : a list of sentences, each a list of tokenized symbols