05360171创建于 2022年3月18日历史提交
diff --git a/egs/aishell/local/score.sh b/egs/aishell/local/score.sh
index 103fbea..542a7a0 100755
--- a/egs/aishell/local/score.sh
+++ b/egs/aishell/local/score.sh
@@ -14,7 +14,7 @@ fi
 dir=$1
 dic=$2
 
-json2trn.py ${dir}/data.json ${dic} ${dir}/ref.trn ${dir}/hyp.trn
+python3.7 ../../src/utils/json2trn.py ${dir}/data.json ${dic} ${dir}/ref.trn ${dir}/hyp.trn
 
 if [ ! -z ${nlsyms} ]; then
   cp ${dir}/ref.trn ${dir}/ref.trn.org
diff --git a/egs/aishell/run.sh b/egs/aishell/run.sh
index 90e91f9..5725699 100644
--- a/egs/aishell/run.sh
+++ b/egs/aishell/run.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # -- IMPORTANT
-data=/home/work_nfs/common/data # Modify to your aishell data path
+data=/home/xy # Modify to your aishell data path
 stage=-1  # Modify to control start from witch stage
 # --
 
@@ -129,71 +129,3 @@ if [ $stage -le 2 ]; then
              data/$data ${dict} > ${feat_dir}/data.json
     done
 fi
-
-if [ -z ${tag} ]; then
-    expdir=exp/train_m${LFR_m}_n${LFR_n}_in${d_input}_elayer${n_layers_enc}_head${n_head}_k${d_k}_v${d_v}_model${d_model}_inner${d_inner}_drop${dropout}_pe${pe_maxlen}_emb${d_word_vec}_dlayer${n_layers_dec}_share${tgt_emb_prj_weight_sharing}_ls${label_smoothing}_epoch${epochs}_shuffle${shuffle}_bs${batch_size}_bf${batch_frames}_mli${maxlen_in}_mlo${maxlen_out}_k${k}_warm${warmup_steps}
-    if ${do_delta}; then
-        expdir=${expdir}_delta
-    fi
-else
-    expdir=exp/train_${tag}
-fi
-mkdir -p ${expdir}
-
-if [ ${stage} -le 3 ]; then
-    echo "stage 3: Network Training"
-    ${cuda_cmd} --gpu ${ngpu} ${expdir}/train.log \
-        train.py \
-        --train-json ${feat_train_dir}/data.json \
-        --valid-json ${feat_dev_dir}/data.json \
-        --dict ${dict} \
-        --LFR_m ${LFR_m} \
-        --LFR_n ${LFR_n} \
-        --d_input $d_input \
-        --n_layers_enc $n_layers_enc \
-        --n_head $n_head \
-        --d_k $d_k \
-        --d_v $d_v \
-        --d_model $d_model \
-        --d_inner $d_inner \
-        --dropout $dropout \
-        --pe_maxlen $pe_maxlen \
-        --d_word_vec $d_word_vec \
-        --n_layers_dec $n_layers_dec \
-        --tgt_emb_prj_weight_sharing $tgt_emb_prj_weight_sharing \
-        --label_smoothing ${label_smoothing} \
-        --epochs $epochs \
-        --shuffle $shuffle \
-        --batch-size $batch_size \
-        --batch_frames $batch_frames \
-        --maxlen-in $maxlen_in \
-        --maxlen-out $maxlen_out \
-        --k $k \
-        --warmup_steps $warmup_steps \
-        --save-folder ${expdir} \
-        --checkpoint $checkpoint \
-        --continue-from "$continue_from" \
-        --print-freq ${print_freq} \
-        --visdom $visdom \
-        --visdom_lr $visdom_lr \
-        --visdom_epoch $visdom_epoch \
-        --visdom-id "$visdom_id"
-fi
-
-if [ ${stage} -le 4 ]; then
-    echo "stage 4: Decoding"
-    decode_dir=${expdir}/decode_test_beam${beam_size}_nbest${nbest}_ml${decode_max_len}
-    mkdir -p ${decode_dir}
-    ${cuda_cmd} --gpu ${ngpu} ${decode_dir}/decode.log \
-        recognize.py \
-        --recog-json ${feat_test_dir}/data.json \
-        --dict $dict \
-        --result-label ${decode_dir}/data.json \
-        --model-path ${expdir}/final.pth.tar \
-        --beam-size $beam_size \
-        --nbest $nbest \
-        --decode-max-len $decode_max_len
-
-    # Compute CER
-    local/score.sh --nlsyms ${nlsyms} ${decode_dir} ${dict}
-fi
diff --git a/src/transformer/attention.py b/src/transformer/attention.py
index 8906fc0..e3eac88 100644
--- a/src/transformer/attention.py
+++ b/src/transformer/attention.py
@@ -49,7 +49,9 @@ class MultiHeadAttention(nn.Module):
         v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
 
         if mask is not None:
+            mask = mask.to(torch.int32)
             mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
+            mask = mask.to(torch.bool)
 
         output, attn = self.attention(q, k, v, mask=mask)
 
diff --git a/src/utils/utils.py b/src/utils/utils.py
index baa3bfa..c63b24e 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -2,10 +2,11 @@
 IGNORE_ID = -1
 
 
-def pad_list(xs, pad_value):
+def pad_list(xs, pad_value, max_len = None):
     # From: espnet/src/nets/e2e_asr_th.py: pad_list()
     n_batch = len(xs)
-    max_len = max(x.size(0) for x in xs)
+    if max_len is None:
+        max_len = max(x.size(0) for x in xs)
     pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value)
     for i in range(n_batch):
         pad[i, :xs[i].size(0)] = xs[i]