"""Library for implementing cascade (sequences) of different neural modules.
Authors
* Peter Plantinga 2020
"""
import torch
import inspect
import logging
import operator
import functools
from speechbrain.nnet.linear import Linear
from speechbrain.utils.callchains import lengths_arg_exists
logger = logging.getLogger(__name__)
class Sequential(torch.nn.ModuleDict):
    """A sequence of modules with potentially inferring shape on construction.

    If layers are passed with names, these can be referenced with dot notation.

    Arguments
    ---------
    input_shape : iterable
        A list or tuple of ints or None, representing the expected shape of an
        input tensor. None represents a variable-length dimension. If no
        ``input_shape`` is passed, no shape inference will be performed.
    *layers, **named_layers
        The inputs are treated as a list of layers to be
        applied in sequence. The output shape of each layer is used to
        infer the shape of the following layer. If a tuple is returned,
        only the shape of the first element is used to determine input
        shape of the next layer (e.g. RNN returns output, hidden).

    Example
    -------
    >>> inputs = torch.rand(10, 40, 50)
    >>> model = Sequential(input_shape=inputs.shape)
    >>> model.append(Linear, n_neurons=100, layer_name="layer1")
    >>> model.append(Linear, n_neurons=200, layer_name="layer2")
    >>> outputs = model(inputs)
    >>> outputs.shape
    torch.Size([10, 40, 200])
    >>> outputs = model.layer1(inputs)
    >>> outputs.shape
    torch.Size([10, 40, 100])
    """

    def __init__(self, *layers, input_shape=None, **named_layers):
        super().__init__()

        # At least one of: positional layers, named layers, or an input shape.
        if not layers and input_shape is None and not named_layers:
            raise ValueError("Must pass either layers or input shape")

        # Not read by any method of this class; retained so the attribute
        # stays available to code that may access it.
        self.length_layers = []

        # Replace unknown (None) dimensions so a dummy tensor can be built.
        self.input_shape = input_shape
        if input_shape and None in input_shape:
            self.input_shape = list(input_shape)
            for i, dim in enumerate(self.input_shape):
                # Use 1 for an unknown first dimension to keep dummies small.
                if i == 0 and dim is None:
                    dim = 1
                # Any remaining falsy dimension (None or 0) becomes 256.
                self.input_shape[i] = dim or 256

        # Positional layers get automatic names; keyword layers keep theirs.
        for layer in layers:
            self.append(layer)
        for name, layer in named_layers.items():
            self.append(layer, layer_name=name)

    def append(self, layer, *args, layer_name=None, **kwargs):
        """Add a layer to the list of layers, inferring shape if necessary.

        Arguments
        ---------
        layer : A torch.nn.Module class or object
            If the layer is a class, it should accept an argument called
            ``input_shape`` which will be inferred and passed. If the layer
            is a module object, it is added as-is.
        layer_name : str
            The name of the layer, for reference. Defaults to the current
            number of layers. If the name is in use (whether it was passed
            explicitly or generated), ``_{count}`` will be appended.
        *args, **kwargs
            These are passed to the layer if it is constructed.

        Raises
        ------
        ValueError
            If ``layer`` is not a module instance by the time it is
            registered (e.g. a class was passed without ``input_shape``).
        """
        # Default name is the positional index of the new layer.
        if layer_name is None:
            layer_name = str(len(self))

        # De-duplicate the name. This also covers generated names, which can
        # collide with an earlier explicit name such as "2"; without this
        # check the earlier layer would be silently replaced.
        if layer_name in self:
            index = 0
            while f"{layer_name}_{index}" in self:
                index += 1
            layer_name = f"{layer_name}_{index}"

        # Construct the layer if it accepts an ``input_shape`` argument.
        if self.input_shape:
            argspec = inspect.getfullargspec(layer)
            if "input_shape" in argspec.args + argspec.kwonlyargs:
                input_shape = self.get_output_shape()
                layer = layer(*args, input_shape=input_shape, **kwargs)

        # Register the layer; a TypeError here is reported as a usage error.
        try:
            self.add_module(layer_name, layer)
        except TypeError as err:
            raise ValueError(
                "Must pass `input_shape` at initialization and use "
                "modules that take `input_shape` to infer shape when "
                "using `append()`."
            ) from err

    def get_output_shape(self):
        """Returns expected shape of the output.

        Computed by passing dummy input constructed with the
        ``self.input_shape`` attribute.
        """
        with torch.no_grad():
            dummy_input = torch.zeros(self.input_shape)
            dummy_output = self(dummy_input)
        return dummy_output.shape

    def forward(self, x):
        """Applies layers in sequence, passing only the first element of tuples.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to run through the network.

        Returns
        -------
        The output of the last layer (or ``x`` itself if there are no layers).
        """
        for layer in self.values():
            x = layer(x)
            # Layers such as RNNs return tuples; keep only the first element.
            if isinstance(x, tuple):
                x = x[0]
        return x
class LengthsCapableSequential(Sequential):
    """Sequential container whose ``forward`` optionally receives ``lengths``.

    Helpful when the sequence contains RNNs, where padding should be
    avoided, or certain feature normalization layers.

    This module is not jit-able: the compiler cannot know ahead of time
    whether the lengths will be passed, and not every layer accepts the
    length parameter.
    """

    def __init__(self, *args, **kwargs):
        # One flag per layer, in insertion order. It has to exist before the
        # parent constructor runs, because that constructor calls ``append``.
        self.takes_lengths = []
        super().__init__(*args, **kwargs)

    def append(self, *args, **kwargs):
        """Add a layer and record whether its ``forward`` takes ``lengths``.

        All arguments are forwarded unchanged to the parent ``append``.
        """
        super().append(*args, **kwargs)
        # The most recently registered layer is the last value.
        *_, newest_layer = self.values()
        accepts_lengths = lengths_arg_exists(newest_layer.forward)
        self.takes_lengths.append(accepts_lengths)

    def forward(self, x, lengths=None):
        """Applies layers in sequence, passing only the first element of tuples.

        The ``lengths`` argument is handed to every layer whose ``forward()``
        was recorded as accepting a ``lengths`` argument (e.g. RNNs).

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to run through the network.
        lengths : torch.Tensor
            The relative lengths of each signal in the tensor.
        """
        for module, wants_lengths in zip(self.values(), self.takes_lengths):
            result = module(x, lengths=lengths) if wants_lengths else module(x)
            # Keep only the first element of tuple outputs.
            x = result[0] if isinstance(result, tuple) else result
        return x
class ModuleList(torch.nn.Module):
    """This class implements a wrapper to torch.nn.ModuleList with a forward()
    method to forward all the layers sequentially.

    For some pretrained models with the older SpeechBrain implementation of
    the Sequential class, users can use this class to load those pretrained
    models.

    Arguments
    ---------
    *layers : torch class
        Torch objects to be put in a ModuleList.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        """Applies the computation pipeline.

        Each layer is applied in order; if a layer returns a tuple, only its
        first element is passed on.
        """
        for layer in self.layers:
            x = layer(x)
            if isinstance(x, tuple):
                x = x[0]
        return x

    def append(self, module):
        """Appends module to the layers list."""
        self.layers.append(module)

    def extend(self, modules):
        """Appends every module of an iterable to the layers list."""
        self.layers.extend(modules)

    def insert(self, index, module):
        """Inserts module into the layers list before position ``index``."""
        # ``index`` must be forwarded; calling insert with only the module
        # raises a TypeError for the missing positional argument.
        self.layers.insert(index, module)
class ConnectBlocks(torch.nn.Module):
    """Connect a sequence of blocks with shortcut connections.

    Note: in ``forward`` the shortcut starts as the input tensor itself and
    is combined with block outputs according to ``shortcut_type``.

    Arguments
    ---------
    input_shape : tuple
        The shape of the expected input tensor. It is used to build the
        dummy inputs from which the shapes of later blocks are inferred.
    shortcut_type : str
        One of:
        * "residual" - the original input is combined with the output of
          every block, and the result is fed to the next block,
        * "dense" - the combination of the shortcut and the block output
          becomes both the next block's input and the new shortcut,
        * "skip" - each block output is combined into the shortcut, which
          is the final output; blocks are chained on their raw outputs.
    shortcut_projection : bool
        Whether to add a linear projection layer to the shortcut connection
        before combining with the output, to handle different sizes.
    shortcut_combine_fn : function
        A function that takes the shortcut and the block output (in that
        order) and returns their combination. Defaults to ``torch.add``.
        NOTE(review): the value is called directly, so string names such as
        "add" are not resolved by this class.

    Example
    -------
    >>> inputs = torch.rand(10, 100, 20)
    >>> model = ConnectBlocks(
    ...     input_shape=inputs.shape, shortcut_projection=True
    ... )
    >>> model.append(Linear, n_neurons=10)
    >>> model.append(Linear, n_neurons=10, end_of_block=True)
    >>> model.append(Linear, n_neurons=10)
    >>> model.append(Linear, n_neurons=10, end_of_block=True)
    >>> outputs = model(inputs)
    >>> outputs.shape
    torch.Size([10, 100, 10])
    """

    def __init__(
        self,
        input_shape,
        shortcut_type="residual",
        shortcut_projection=False,
        shortcut_combine_fn=torch.add,
    ):
        super().__init__()

        self.first_input_shape = input_shape
        self.block_input_shape = input_shape
        # True when the next appended layer should open a new block.
        self.new_block = True
        self.blocks = torch.nn.ModuleList()
        if shortcut_type not in ("residual", "dense", "skip"):
            raise ValueError(
                "'shortcut_type' must be one of 'residual', 'dense', or 'skip'"
            )
        self.shortcut_type = shortcut_type
        self.shortcut_projection = shortcut_projection
        # Projections exist only when requested: one per finished block.
        if shortcut_projection:
            self.projections = torch.nn.ModuleList()
        self.shortcut_combine_fn = shortcut_combine_fn

    def append(self, layer, *args, **kwargs):
        """Appends the specified module to the shortcut model.

        Arguments
        ---------
        layer : torch.nn.Module class
            This layer will get initialized with *args and **kwargs. Also,
            the argument ``input_shape`` will be passed if the layer takes it.
        *args, **kwargs
            Passed unchanged to the layer **EXCEPT** the kwarg ``end_of_block``
            which is used to indicate that the shortcut should be added in.
        """
        if self.new_block:
            self.blocks.append(Sequential(input_shape=self.block_input_shape))
            self.new_block = False

        # ``end_of_block`` is consumed here and never reaches the layer.
        end_of_block = kwargs.pop("end_of_block", False)

        self.blocks[-1].append(layer, *args, **kwargs)

        if not end_of_block:
            return

        # Probe the finished block with a dummy input to learn its output
        # shape. No gradients are needed for shape inference.
        with torch.no_grad():
            dummy_input = torch.zeros(self.block_input_shape)
            dummy_output = self.blocks[-1](dummy_input)

        # Initialize the projection for this block if requested.
        if self.shortcut_projection:
            # Product of all output dimensions after the first two.
            projection_size = functools.reduce(
                operator.mul, dummy_output.shape[2:], 1
            )

            # Residual shortcuts always come from the original input, so
            # the projection (and the dummy shortcut) use its shape.
            if self.shortcut_type == "residual":
                shape = self.first_input_shape
                dummy_input = torch.zeros(self.first_input_shape)
            else:
                shape = self.block_input_shape

            self.projections.append(
                Linear(
                    n_neurons=projection_size,
                    input_shape=shape,
                    bias=False,
                    combine_dims=True,
                )
            )

        # Prepare for the next block: its input shape is the shape of the
        # combined shortcut and block output.
        self.new_block = True
        with torch.no_grad():
            dummy_output = self._combine(dummy_input, dummy_output, -1)
        self.block_input_shape = dummy_output.shape

    def forward(self, x):
        """Runs the blocks in order, applying the configured shortcuts.

        Arguments
        ---------
        x : torch.Tensor
            The inputs to the replicated modules.

        Returns
        -------
        The accumulated shortcut for ``"skip"``; otherwise the output
        after the last block's combination.
        """
        shortcut = x

        for i, block in enumerate(self.blocks):
            x = block(x)

            if self.shortcut_type == "skip":
                shortcut = self._combine(shortcut, x, i)
            if self.shortcut_type == "dense":
                x = shortcut = self._combine(shortcut, x, i)
            if self.shortcut_type == "residual":
                x = self._combine(shortcut, x, i)

        if self.shortcut_type == "skip":
            return shortcut
        else:
            return x

    def _combine(self, shortcut, x, block_index=0):
        """Handle combining shortcut with outputs.

        When projections are enabled, the shortcut is first passed through
        the projection at ``block_index`` and reshaped to the shape of ``x``.
        """
        if self.shortcut_projection:
            shortcut = self.projections[block_index](shortcut)
            shortcut = shortcut.reshape(x.shape)

        return self.shortcut_combine_fn(shortcut, x)