import torch
try:
    # Optional compiled CUDA extension; within this module only CudaDotInteraction uses it.
    from dlrm.cuda_ext import dotBasedInteract
except ImportError:
    # Best-effort import: the module must stay importable when the extension is not
    # built or installed. Only ImportError is suppressed so that KeyboardInterrupt,
    # SystemExit and unrelated errors are no longer silently swallowed.
    pass
def padding_size(n: int) -> int:
    """Return the number of elements to add to ``n`` to reach the next multiple of 8.

    :param n: current number of elements
    :return: a value in ``[0, 7]``; 0 when ``n`` is already a multiple of 8
    """
    # With a positive divisor Python's modulo is never negative, so the remainder of
    # the negated count is exactly the distance up to the next multiple of 8.
    return (-n) % 8
class Interaction:
    """Interface shared by the feature-interaction strategies in this module.

    Both members raise ``NotImplementedError`` here and are meant to be overridden.
    """

    @property
    def num_interactions(self) -> int:
        """Number of interaction features; must be provided by subclasses."""
        raise NotImplementedError

    def interact(self, bottom_output, bottom_mlp_output):
        """
        Combine the interaction inputs; must be provided by subclasses.

        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output
        :return:
        """
        raise NotImplementedError
class DotInteraction(Interaction):
    """Pairwise dot-product interaction between the interaction input vectors."""

    def __init__(self, embedding_num: int, embedding_dim: int):
        """
        Interactions are among outputs of all the embedding tables and bottom MLP, a total of
        (num_embedding_tables + 1) vectors of size embedding_dim. ``dot`` product interaction computes the
        dot product between every pair of distinct vectors. Output of interaction has shape
        [batch_size, num_interactions].

        :param embedding_num: number of embedding tables
        :param embedding_dim: size of each interaction input vector
        """
        self._num_interaction_inputs = embedding_num + 1
        self._embedding_dim = embedding_dim
        # (row, col) coordinates of the strictly lower triangle of the pairwise dot-product
        # matrix, ordered by row and then column. torch.tril_indices always returns int64,
        # so the result is a valid index tensor even when there are no pairs at all.
        tril_indices = torch.tril_indices(
            self._num_interaction_inputs, self._num_interaction_inputs, offset=-1)
        # Keep the indices on the GPU when one is present, but stay usable on CPU-only hosts.
        if torch.cuda.is_available():
            tril_indices = tril_indices.cuda()
        self._tril_indices = tril_indices

    @property
    def _raw_num_interactions(self) -> int:
        """Unpadded output width: one column per vector pair plus embedding_dim columns."""
        num_pairs = (self._num_interaction_inputs * (self._num_interaction_inputs - 1)) // 2
        return num_pairs + self._embedding_dim

    @property
    def num_interactions(self) -> int:
        """Output width of :meth:`interact`, i.e. the raw width rounded up to a multiple of 8."""
        n = self._raw_num_interactions
        return n + padding_size(n)

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output: 2-D tensor concatenated in front of the pairwise products; the width
            reported by ``num_interactions`` assumes it is [batch_size, embedding_dim]
        :return: ``bottom_mlp_output``, the pairwise dot products and zero padding concatenated along dim 1
        """
        batch_size = bottom_output.size(0)

        # Gram matrix of every sample: entry (i, j) is the dot product of vectors i and j.
        interaction = torch.bmm(bottom_output, torch.transpose(bottom_output, 1, 2))

        # Align the index tensor with the input's device; this is a no-op when they already match.
        tril_indices = self._tril_indices.to(bottom_output.device)
        interaction_flat = interaction[:, tril_indices[0], tril_indices[1]]

        # Zero columns that round the output width up to the next multiple of 8.
        padding_dim = padding_size(self._raw_num_interactions)
        zeros_padding = torch.zeros(
            batch_size, padding_dim, dtype=bottom_output.dtype, device=bottom_output.device)

        interaction_output = torch.cat(
            (bottom_mlp_output, interaction_flat, zeros_padding), dim=1)
        return interaction_output
class CudaDotInteraction(Interaction):
    """Dot interaction that delegates the computation to ``dotBasedInteract``.

    The wrapped ``DotInteraction`` is used only to report ``num_interactions``.
    """

    def __init__(self, dot_interaction: DotInteraction):
        """
        :param dot_interaction: interaction object whose ``num_interactions`` this wrapper reports
        """
        self._dot_interaction = dot_interaction

    @property
    def num_interactions(self):
        """Number of interaction features, as reported by the wrapped ``DotInteraction``."""
        wrapped = self._dot_interaction
        return wrapped.num_interactions

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output
        :return: the result of ``dotBasedInteract`` for the two inputs
        """
        # ``dotBasedInteract`` comes from an optional import at module level; when that
        # import did not succeed the name is undefined and this call raises NameError.
        result = dotBasedInteract(bottom_output, bottom_mlp_output)
        return result
class CatInteraction(Interaction):
    """Concatenation-based interaction: flattens all interaction input vectors of a sample."""

    def __init__(self, embedding_num: int, embedding_dim: int):
        """
        Interactions are among outputs of all the embedding tables and bottom MLP, a total of
        (num_embedding_tables + 1) vectors of size embedding_dim. ``cat`` interaction concatenates all the
        vectors together. Output of interaction has shape [batch_size, num_interactions].

        :param embedding_num: number of embedding tables
        :param embedding_dim: size of each interaction input vector
        """
        self._num_interaction_inputs = embedding_num + 1
        self._embedding_dim = embedding_dim

    @property
    def num_interactions(self) -> int:
        """Output width of :meth:`interact`: number of input vectors times embedding_dim."""
        return self._num_interaction_inputs * self._embedding_dim

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output: not used by this interaction; accepted to match the common interface
        :return: ``bottom_output`` flattened to [batch_size, num_interactions]
        """
        # reshape behaves like view for contiguous inputs but also accepts non-contiguous
        # ones (copying only when a view is impossible) instead of raising.
        return bottom_output.reshape(-1, self.num_interactions)