from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
class QREmbeddingBag(nn.Module):
    r"""Computes sums or means over two 'bags' of embeddings, one using the quotient
    of the indices and the other using the remainder of the indices, without
    instantiating the intermediate embeddings, then performs an operation to
    combine these.

    Each index ``i`` is split into ``i // num_collisions`` (looked up in the
    quotient table) and ``i % num_collisions`` (looked up in the remainder
    table). Each table is reduced per bag with
    :func:`torch.nn.functional.embedding_bag` using ``mode``, and the two
    reduced results are then combined with ``operation``.

    QREmbeddingBag also supports per-sample weights as an argument to the forward
    pass. This scales the output of the Embedding before performing a weighted
    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
    only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
    :attr:`per_sample_weights`.

    Known Issues:
    Autograd breaks with multiple GPUs. It breaks only with multiple embeddings.

    Args:
        num_categories (int): total number of unique categories. The input indices must be in
                              0, 1, ..., num_categories - 1.
        embedding_dim (int or list): list of sizes for each embedding vector in each table. If ``"add"``
                              or ``"mult"`` operation are used, these embedding dimensions must be
                              the same. If a single embedding_dim is used (an int or a
                              one-element list), then it will use this embedding_dim for both
                              embedding tables.
        num_collisions (int): number of collisions to enforce. This is also the number of rows
                              of the remainder table.
        operation (string, optional): ``"concat"``, ``"add"``, or ``"mult"``. Specifies the operation
                                      to compose embeddings. ``"concat"`` concatenates the embeddings,
                                      ``"add"`` sums the embeddings, and ``"mult"`` multiplies
                                      (component-wise) the embeddings.
                                      Default: ``"mult"``
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
                                                Note: this option is not supported when ``mode="max"``.
        mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
                                 ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
                                 into consideration. ``"mean"`` computes the average of the values
                                 in the bag, ``"max"`` computes the max value over each bag.
                                 Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. the weight matrices will be a sparse tensor.
                                 Note: this option is not supported when ``mode="max"``.
        _weight (sequence of two Tensors, optional): initial values for the quotient table
                                 (``_weight[0]``) and the remainder table (``_weight[1]``). Their
                                 shapes must match the table shapes listed under Attributes.

    Attributes:
        weight_q (Tensor): the learnable quotient table of shape
                         `(ceil(num_categories / num_collisions), embedding_dim[0])`.
        weight_r (Tensor): the learnable remainder table of shape
                         `(num_collisions, embedding_dim[1])`.
        Unless ``_weight`` is given, both tables are initialized by
        ``nn.init.uniform_(w, sqrt(1 / num_categories))``.

    Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
        :attr:`per_sample_weights` (Tensor, optional)

        - If :attr:`input` is 2D of shape `(B, N)`,

          it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
          this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
          :attr:`offsets` is ignored and required to be ``None`` in this case.

        - If :attr:`input` is 1D of shape `(N)`,

          it will be treated as a concatenation of multiple bags (sequences).
          :attr:`offsets` is required to be a 1D tensor containing the
          starting index positions of each bag in :attr:`input`. Therefore,
          for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
          having ``B`` bags. Empty bags (i.e., having 0-length) will have
          returned vectors filled by zeros.

        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.

    Output shape: `(B, embedding_dim)` for ``"add"`` and ``"mult"``;
        `(B, embedding_dim[0] + embedding_dim[1])` for ``"concat"``.
    """
    __constants__ = ['num_categories', 'embedding_dim', 'num_collisions',
                     'operation', 'max_norm', 'norm_type', 'scale_grad_by_freq',
                     'mode', 'sparse']

    def __init__(self, num_categories, embedding_dim, num_collisions,
                 operation='mult', max_norm=None, norm_type=2.,
                 scale_grad_by_freq=False, mode='mean', sparse=False,
                 _weight=None):
        super(QREmbeddingBag, self).__init__()

        assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'

        self.num_categories = num_categories

        # Normalize ``embedding_dim`` to a pair [quotient_dim, remainder_dim].
        if isinstance(embedding_dim, int):
            self.embedding_dim = [embedding_dim, embedding_dim]
        elif len(embedding_dim) == 1:
            # Unwrap the single element; duplicating the sequence itself
            # would yield a nested list that cannot be used as a dimension.
            self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
        else:
            self.embedding_dim = embedding_dim

        self.num_collisions = num_collisions
        self.operation = operation
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        # Element-wise composition needs both tables to share one width.
        if self.operation == 'add' or self.operation == 'mult':
            assert self.embedding_dim[0] == self.embedding_dim[1], \
                'Embedding dimensions do not match!'

        # Row counts: [quotient table, remainder table].
        self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
                               num_collisions]

        if _weight is None:
            self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
            self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
            self.reset_parameters()
        else:
            assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
                'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
            assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
                'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
            self.weight_q = Parameter(_weight[0])
            self.weight_r = Parameter(_weight[1])

        self.mode = mode
        self.sparse = sparse

    def reset_parameters(self):
        """Re-initialize both embedding tables in place."""
        # NOTE(review): the single positional argument is ``a`` (the lower
        # bound) of nn.init.uniform_, so the upper bound stays at its default.
        # Confirm this asymmetric range is intended before changing it.
        bound = np.sqrt(1 / self.num_categories)
        nn.init.uniform_(self.weight_q, bound)
        nn.init.uniform_(self.weight_r, bound)

    def forward(self, input, offsets=None, per_sample_weights=None):
        """Look up, reduce and combine quotient/remainder embeddings.

        Args:
            input: tensor of category indices (1D with ``offsets`` or 2D).
            offsets: bag start positions for 1D ``input``; ``None`` for 2D.
            per_sample_weights: optional per-index weights forwarded to both
                ``embedding_bag`` calls.

        Returns:
            Tensor with one row per bag, composed according to
            ``self.operation``.
        """
        # Integer floor division. Going through true (floating-point)
        # division first would round large indices before truncation and
        # produce a wrong, possibly out-of-range, quotient.
        input_q = torch.div(input, self.num_collisions, rounding_mode='floor').long()
        input_r = torch.remainder(input, self.num_collisions).long()

        embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)
        embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)

        # ``operation`` was validated in __init__, so one branch always runs.
        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = embed_q + embed_r
        elif self.operation == 'mult':
            embed = embed_q * embed_r

        return embed

    def extra_repr(self):
        """Return the argument summary shown in the module's ``repr``."""
        s = '{num_embeddings}, {embedding_dim}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        s += ', mode={mode}'
        # Placeholders are filled from instance attributes of the same name.
        return s.format(**self.__dict__)