989d7c43创建于 2025年4月23日历史提交
import warnings

import torch
import torch_npu

warnings.filterwarnings(action='once', category=FutureWarning)


class BiLSTM(torch.nn.Module):
    r"""Applies an NPU compatible bidirectional LSTM operation to an input
    sequence.

    The implementation of this BidirectionalLSTM is mainly based on the principle of bidirectional LSTM.
    Since NPU do not support the parameter bidirectional in torch.nn.lstm to be True,
    we reimplement it by joining two unidirection LSTM together to form a bidirectional LSTM

    Paper: [Bidirectional recurrent neural networks]

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`


    Inputs: input, (h_0, c_0)
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence.
          The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          If the LSTM is bidirectional, num_directions should be 2, else it should be 1.
        - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial cell state for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.


    Outputs: output, (h_n, c_n)
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features `(h_t)` from the last layer of the LSTM,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
        - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the cell state for `t = seq_len`.


    Examples::
        >>> r = BiLSTM(512, 256)
        >>> input_tensor = torch.randn(26, 2560, 512)
        >>> output = r(input_tensor)
    """
    def __init__(self, input_size, hidden_size):
        super(BiLSTM, self).__init__()

        warnings.warn("torch_npu.contrib.BiLSTM is deprecated. "
                      "Please check document for replacement.", FutureWarning)
        self.fw_rnn = torch.nn.LSTM(input_size, hidden_size, bidirectional=False)
        self.bw_rnn = torch.nn.LSTM(input_size, hidden_size, bidirectional=False)

    def forward(self, inputs):
        input_fw = inputs
        recurrent_fw, _ = self.fw_rnn(input_fw)
        input_bw = torch.flip(inputs, [0])
        recurrent_bw, _ = self.bw_rnn(input_bw)
        recurrent_bw = torch.flip(recurrent_bw, [0])
        recurrent = torch.cat((recurrent_fw, recurrent_bw), 2)

        return recurrent