import time
from easydict import EasyDict as edict
import tensorflow as tf
# Global model configuration. `loss_mode` is not read by any code visible in
# this file; presumably a training harness elsewhere consumes it — verify.
model_cfg = edict(loss_mode="batch")
# Keys of the dictionary returned by MyModel.build_model.
LOSS_OP_NAME, LABEL_OP_NAME, VAR_LIST, PRED_OP_NAME = (
    "loss",
    "label",
    "variable",
    "pred",
)
class MyModel:
    """Wide & Deep binary classifier built with the TF1 ``tf.layers`` graph API.

    ``build_model`` creates every variable under the ``wide_deep`` variable
    scope (with ``tf.AUTO_REUSE``). It returns the loss, the sigmoid
    predictions, the label tensor and the trainable variables of that scope.
    """

    # Widths of the deep tower's hidden (ReLU) layers, applied in order.
    # Subclasses or instances may override this to change the architecture.
    HIDDEN_UNITS = (256, 128, 64)

    def __init__(self):
        # Not read by build_model; kept for external callers that may use it.
        self.kernel_init = None
        # Set by build_model to a BinaryCrossentropy(from_logits=True) instance.
        self._loss_fn = None
        # Set by build_model from model_args.is_training; forwarded to the
        # dropout and batch-norm layers as their `training` argument.
        self.is_training = None

    def build_model(self, model_args):
        """Build the Wide & Deep graph and return its key tensors.

        Args:
            model_args: attribute-style container that provides
                ``wide_embedding`` and ``deep_embedding`` (rank-3 tensors,
                presumably ``(batch, num_features, emb_dim)``), ``label``,
                ``is_training``, ``seed`` (initializer seed), ``dropout_rate``
                (dropout is applied only when ``0 < rate < 1``) and
                ``batch_norm`` (truthy enables batch normalization).

        Returns:
            dict with keys:
                LOSS_OP_NAME: loss from ``BinaryCrossentropy(from_logits=True)``
                    with Keras' default reduction.
                PRED_OP_NAME: ``sigmoid`` of the combined logits, shape (batch, 1).
                LABEL_OP_NAME: the ``label`` passed in, unchanged.
                VAR_LIST: trainable variables created under this variable scope.

        Raises:
            ValueError: if an embedding tensor is not rank 3.
        """
        wide_embedding = model_args.wide_embedding
        deep_embedding = model_args.deep_embedding
        # Assumed binary 0/1 targets broadcast-compatible with (batch, 1) — confirm.
        label = model_args.label
        is_training = model_args.is_training
        seed = model_args.seed
        dropout_rate = model_args.dropout_rate
        batch_norm = model_args.batch_norm

        with tf.variable_scope("wide_deep", reuse=tf.AUTO_REUSE) as scope:
            self._loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
            self.is_training = is_training

            wide_output = self._wide_logits(wide_embedding)
            deep_output = self._deep_logits(
                deep_embedding, seed, dropout_rate, batch_norm)

            # The two logits are averaged (0.5 * sum). NOTE: the op named
            # 'total_logits' is the *unscaled* sum; the scaled tensor below is
            # what feeds the loss and the prediction.
            total_logits = 0.5 * tf.add(wide_output, deep_output, name='total_logits')
            loss = self._loss_fn(label, total_logits)
            prediction = tf.sigmoid(total_logits)

            # get_collection filters names with re.match. A bare 'wide_deep'
            # prefix would also collect 'wide_deep_1/...' variables and would
            # find nothing if this scope is nested in an outer one. Match this
            # scope's full name plus the separator instead.
            trainable_variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope.name + '/')

        return {LOSS_OP_NAME: loss,
                PRED_OP_NAME: prediction,
                LABEL_OP_NAME: label,
                VAR_LIST: trainable_variables}

    @staticmethod
    def _wide_logits(wide_embedding):
        """Return the wide logit: the sum of channel 0 over all wide features, shape (batch, 1)."""
        if wide_embedding.shape.ndims != 3:
            raise ValueError(
                "wide_embedding must be rank 3, got shape %s" % (wide_embedding.shape,))
        # Only the first entry of the last axis is used. Slicing it out already
        # yields a rank-2 (batch, wide_num) tensor, so no reshape is needed, and
        # the batch dimension may stay dynamic.
        wide_input = wide_embedding[:, :, 0]
        return tf.reduce_sum(wide_input, axis=1, keepdims=True)

    def _deep_logits(self, deep_embedding, seed, dropout_rate, batch_norm):
        """Return the deep-tower logit (batch, 1) from the flattened deep embeddings."""
        _, deep_num, deep_emb_dim = deep_embedding.shape.as_list()
        # Use -1 for the batch dimension so an unknown (None) batch size works.
        # The feature width must be static because tf.layers.dense needs a
        # known input depth.
        net = tf.reshape(deep_embedding, shape=(-1, deep_num * deep_emb_dim))

        use_dropout = dropout_rate is not None and 0.0 < dropout_rate < 1.0
        for i, unit in enumerate(self.HIDDEN_UNITS):
            net = tf.layers.dense(net, units=unit, activation='relu', name=f'hidden_layer_{i}',
                                  kernel_initializer=tf.glorot_uniform_initializer(seed=seed),
                                  bias_initializer=tf.zeros_initializer())
            if use_dropout:
                net = tf.layers.dropout(net, dropout_rate, training=self.is_training)
            if batch_norm:
                # tf.layers.batch_normalization registers its moving-statistics
                # updates in tf.GraphKeys.UPDATE_OPS. The train op (built outside
                # this class) must depend on them, or the inference statistics
                # never update.
                net = tf.layers.batch_normalization(net, training=self.is_training)

        return tf.layers.dense(net, units=1, activation=None, name='deep_output',
                               kernel_initializer=tf.glorot_uniform_initializer(seed=seed),
                               bias_initializer=tf.zeros_initializer())
# Module-level MyModel instance created at import time. __init__ only sets
# placeholder attributes to None; build_model(model_args) must still be
# called to create the graph.
my_model = MyModel()