# Commit e84b8a85, created on 2018-01-13 (commit history)
#!/usr/bin/env python
#-*-coding:utf-8-*-

from tensorflow.contrib.crf import viterbi_decode

import numpy as np

# def decode(logits, trans, lengths, num_tags):
#     small = -1000
#
#     start = np.asarray([[small] * num_tags + [0 , small]])
#     end = np.asarray([[small] * num_tags + [small, 0]])
#     path = []
#
#     for logit, length in zip(logits,lengths):
#         logit = logit[ :length]
#         pad = small * np.ones([length, 2])
#         logit = np.concatenate([logit, pad], axis=-1)
#
#         logit = np.concatenate([start, logit, end], axis = 0)
#         viterbi , viterbi_score = viterbi_decode(logit, trans)
#         path.append(np.array(viterbi[1: -1]))
#
#     return path


def vdecode(logits, trans, sequence_lengths, tag_num):
    """Viterbi-decode every sequence of a batch with a virtual start tag.

    Each sequence's scores are truncated to its real length, widened by one
    extra column (the virtual start tag) and prefixed with one extra row
    (the virtual start step) before being handed to ``viterbi_decode``.
    The tag chosen for that extra first step is dropped from the result.

    Args:
        logits: Iterable of per-sequence score matrices. Each is sliced as
            ``logit[:length]`` and concatenated along axis 1 with a single
            column, so it is presumably ``(max_len, tag_num)`` -- confirm
            against the caller.
        trans: Transition parameters passed unchanged to
            ``viterbi_decode``. Presumably ``(tag_num + 1, tag_num + 1)`` to
            account for the added start tag -- confirm against the model.
        sequence_lengths: Iterable of lengths, paired positionally with
            ``logits``.
        tag_num: Number of real tags (columns) in each score matrix.

    Returns:
        A list with one entry per sequence: the decoded tag path returned
        by ``viterbi_decode`` with its first element (the virtual start
        step) removed.
    """
    small = -1000.0
    # Score row for the virtual first step: every real tag gets the large
    # negative score, the extra (last) column gets 0.
    start_row = np.asarray([[small] * tag_num + [0]])

    def _decode_single(seq_logits, seq_len):
        # Keep only the first ``seq_len`` steps of this sequence.
        trimmed = seq_logits[:seq_len]
        # Extra column for the virtual start tag, filled with the large
        # negative score on every real step.
        start_tag_col = small * np.ones([seq_len, 1])
        widened = np.concatenate([trimmed, start_tag_col], axis=1)
        # Prepend the virtual start step.
        full_score = np.concatenate([start_row, widened], axis=0)
        best_path, _ = viterbi_decode(full_score, trans)
        # Drop the tag decoded for the virtual start step.
        return best_path[1:]

    return [
        _decode_single(seq_logits, seq_len)
        for seq_logits, seq_len in zip(logits, sequence_lengths)
    ]