from typing import Optional, Dict, Union
import tensorflow as tf
# Keys of the dictionary returned by DLRM.build_model.
LOSS_OP_NAME = "loss"
LABEL_OP_NAME = "label"
PRED_OP_NAME = "pred"
# Not referenced in the visible code. Presumably the number of sparse
# (categorical) and dense (numeric) input features per sample -- TODO confirm
# against the data pipeline that feeds the model.
SPARSE_FEATURE_LEN = 26
DENSE_FEATURE_LEN = 13
class DLRM:
    """Deep Learning Recommendation Model (DLRM) built with the TF1 graph API.

    The network consists of a three-layer "bottom" MLP applied to the dense
    features, a pairwise dot-product interaction between the bottom-MLP output
    and the embedding vectors, and a five-layer "top" MLP that produces the
    logits.  All layers are created inside the ``"mlp"`` variable scope with
    ``reuse=tf.AUTO_REUSE``.
    """

    def __init__(self, toml_config: Dict[str, Dict[str, Union[str, int, float]]]):
        """Read the layer widths and the L1 regularisation scale from the config.

        Args:
            toml_config: Parsed configuration whose ``"model"`` table provides
                ``l1_regularizer``, ``bottom_stack_dnn{1..3}_shape`` and
                ``top_stack_dnn{1..5}_shape``.

        Raises:
            KeyError: If the ``"model"`` table or one of those entries is missing.
        """
        self._variable_scope_name = "mlp"
        model_config = toml_config["model"]
        self._l1_regularizer = model_config["l1_regularizer"]
        self._bottom_stack_dnn1_shape = model_config["bottom_stack_dnn1_shape"]
        self._bottom_stack_dnn2_shape = model_config["bottom_stack_dnn2_shape"]
        self._bottom_stack_dnn3_shape = model_config["bottom_stack_dnn3_shape"]
        self._top_stack_dnn1_shape = model_config["top_stack_dnn1_shape"]
        self._top_stack_dnn2_shape = model_config["top_stack_dnn2_shape"]
        self._top_stack_dnn3_shape = model_config["top_stack_dnn3_shape"]
        self._top_stack_dnn4_shape = model_config["top_stack_dnn4_shape"]
        self._top_stack_dnn5_shape = model_config["top_stack_dnn5_shape"]

    def build_model(
        self,
        embedding: Optional["tf.Tensor"] = None,
        dense_feature: Optional["tf.Tensor"] = None,
        label: Optional["tf.Tensor"] = None,
        seed: Optional[int] = None,
    ) -> Dict[str, "tf.Tensor"]:
        """Build the DLRM graph and return its loss, prediction and label tensors.

        Args:
            embedding: Embedding features.  Concatenated along axis 1 with the
                expanded bottom-MLP output, so it is assumed to be
                ``(batch, num_embeddings, dim)`` with ``dim`` equal to the
                bottom-stack output width -- TODO confirm against the caller.
            dense_feature: Dense features fed to the bottom MLP.
            label: Labels; cast to ``float32`` and compared against the logits,
                so the shape must match the top-stack output.
            seed: Seed passed to every kernel and bias initializer.

        Returns:
            A dictionary with the mean sigmoid cross-entropy loss
            (``LOSS_OP_NAME``), the sigmoid of the logits (``PRED_OP_NAME``) and
            the unmodified ``label`` tensor (``LABEL_OP_NAME``).

        Raises:
            ValueError: If ``embedding``, ``dense_feature`` or ``label`` is
                ``None``.  The parameters default to ``None`` only for
                signature compatibility; all three are required.

        Note:
            The returned loss does not include the L1 kernel-regularisation
            terms registered by the dense layers.  Presumably the caller adds
            them from the regularisation-loss collection -- TODO confirm.
        """
        # Fail early with a clear message instead of an opaque TensorFlow error
        # raised halfway through graph construction.
        required_inputs = (
            ("embedding", embedding),
            ("dense_feature", dense_feature),
            ("label", label),
        )
        missing = [name for name, value in required_inputs if value is None]
        if missing:
            raise ValueError(
                "build_model requires non-None inputs; missing: " + ", ".join(missing)
            )

        with tf.variable_scope(self._variable_scope_name, reuse=tf.AUTO_REUSE):
            dense_embedding_vec = self._bottom_stack(dense_feature, seed)
            # Treat the bottom-MLP output as one more feature vector next to
            # the embeddings so that it takes part in the pairwise interaction.
            dense_embedding = tf.expand_dims(dense_embedding_vec, 1)
            interaction_args = tf.concat([dense_embedding, embedding], axis=1)
            interaction_output = self._dot_interaction(interaction_args)
            feature_interaction_output = tf.concat(
                [dense_embedding_vec, interaction_output], axis=1
            )
            logits = self._top_stack(feature_interaction_output, seed)

            label_float = tf.cast(label, dtype=tf.float32)
            per_sample_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=label_float, logits=logits
            )
            loss = tf.reduce_mean(per_sample_loss)
            prediction = tf.sigmoid(logits)
            return {LOSS_OP_NAME: loss, PRED_OP_NAME: prediction, LABEL_OP_NAME: label}

    @staticmethod
    def _dot_interaction(input_tensor: "tf.Tensor") -> "tf.Tensor":
        """Return the flattened strictly-lower-triangular pairwise dot products.

        Args:
            input_tensor: Tensor indexed as ``(batch, num_features, dim)``.

        Returns:
            A ``(batch, num_features * num_features)`` tensor.  Entries on and
            above the diagonal of each interaction matrix are zero; they are
            kept (not dropped) so the output width is ``num_features ** 2``.
        """
        num_features = tf.shape(input_tensor)[1]
        batch_size = tf.shape(input_tensor)[0]
        xactions = tf.matmul(input_tensor, input_tensor, transpose_b=True)
        ones = tf.ones_like(xactions, dtype=tf.float32)
        # band_part(ones, 0, -1) is 1 on and above the diagonal, so its
        # complement selects the strictly lower triangle (each unordered pair
        # of distinct features exactly once).
        strictly_lower_mask = 1 - tf.linalg.band_part(ones, 0, -1)
        activations = strictly_lower_mask * xactions
        return tf.reshape(activations, (batch_size, num_features * num_features))

    def _dense_layer(
        self,
        input_tensor: "tf.Tensor",
        units: int,
        name: str,
        seed: Optional[int],
        activation: Optional[str] = "relu",
    ) -> "tf.Tensor":
        """Create one fully connected layer with the model-wide configuration.

        Kernel and bias both use a ``fan_avg`` / ``normal`` variance-scaling
        initializer seeded with ``seed``; the kernel is L1-regularised with the
        scale read from the configuration.

        Args:
            input_tensor: Layer input.
            units: Output width of the layer.
            name: Layer name (determines the variable names).
            seed: Seed for both initializers.
            activation: Activation passed to ``tf.layers.dense``; ``None`` for
                a linear layer.

        Returns:
            The layer output tensor.
        """
        return tf.layers.dense(
            input_tensor,
            units,
            activation=activation,
            name=name,
            use_bias=True,
            kernel_initializer=tf.compat.v1.variance_scaling_initializer(
                mode="fan_avg", distribution="normal", seed=seed
            ),
            bias_initializer=tf.compat.v1.variance_scaling_initializer(
                mode="fan_avg", distribution="normal", seed=seed
            ),
            kernel_regularizer=tf.contrib.layers.l1_regularizer(self._l1_regularizer),
        )

    def _bottom_stack(self, input_tensor: "tf.Tensor", seed: Optional[int]) -> "tf.Tensor":
        """Apply the three ReLU layers ``bs1``..``bs3`` to the dense features.

        Args:
            input_tensor: Dense feature tensor.
            seed: Seed for the layer initializers.

        Returns:
            Output of the last bottom-stack layer.
        """
        layer_specs = (
            ("bs1", self._bottom_stack_dnn1_shape),
            ("bs2", self._bottom_stack_dnn2_shape),
            ("bs3", self._bottom_stack_dnn3_shape),
        )
        hidden = input_tensor
        for layer_name, units in layer_specs:
            hidden = self._dense_layer(hidden, units, layer_name, seed)
        return hidden

    def _top_stack(self, input_tensor: "tf.Tensor", seed: Optional[int]) -> "tf.Tensor":
        """Apply the ReLU layers ``ts1``..``ts4`` and the linear layer ``ts5``.

        Args:
            input_tensor: Concatenated bottom-stack output and interactions.
            seed: Seed for the layer initializers.

        Returns:
            The logits produced by ``ts5`` (no activation applied).
        """
        hidden_specs = (
            ("ts1", self._top_stack_dnn1_shape),
            ("ts2", self._top_stack_dnn2_shape),
            ("ts3", self._top_stack_dnn3_shape),
            ("ts4", self._top_stack_dnn4_shape),
        )
        hidden = input_tensor
        for layer_name, units in hidden_specs:
            hidden = self._dense_layer(hidden, units, layer_name, seed)
        # The final layer stays linear: the sigmoid is applied by the loss and
        # by the prediction op in build_model.
        return self._dense_layer(
            hidden, self._top_stack_dnn5_shape, "ts5", seed, activation=None
        )