from typing import List
import tensorflow as tf
from config import Config
class Model:
    """Shared-bottom MLP with two sigmoid binary-classification heads.

    The concatenated input embedding goes through fully connected hidden
    layers of widths ``self._layer_dims`` and then into two independent
    single-logit output heads. Calling the model creates (or, under
    ``AUTO_REUSE``, reuses) the variables in the ``"mlp"`` variable scope and
    records one mean sigmoid cross-entropy loss per head.
    """

    # Collections every weight / bias variable is registered in. Kept as
    # tuples and copied per call so no caller can mutate the shared constant.
    _WEIGHT_COLLECTIONS = (tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts")
    _BIAS_COLLECTIONS = (tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias")

    def __init__(self):
        # Widths of the shared hidden layers, input side first.
        self._layer_dims = [1024, 512, 256, 128]
        # Static width of the concatenated embedding; set in __call__.
        self._emb_dim = None
        # One scalar loss tensor per head, in (label_0, label_1) order.
        self._loss_list = []
        # [emb_dim] + hidden widths + [1]; set in __call__.
        self._all_layer_dims = None
        # Hidden-layer weights and biases, in layer order.
        self._h_w = []
        self._h_b = []
        # Output-head variables, one weight/bias pair per head.
        self._h_w_head_0 = None
        self._h_w_head_1 = None
        self._h_b_head_0 = None
        self._h_b_head_1 = None

    def __call__(self, embedding_list: List[tf.Tensor], label_0: tf.Tensor, label_1: tf.Tensor):
        """Build the network and record one loss per head.

        Args:
            embedding_list: Tensors concatenated along axis 1 to form the MLP
                input. The concatenated last dimension must be statically known.
            label_0: Targets for head 0; cast to the logit dtype and flattened.
            label_1: Targets for head 1; cast to the logit dtype and flattened.

        Raises:
            ValueError: if ``embedding_list`` is empty or the concatenated
                embedding has an unknown last dimension.

        Calling the model again reuses the variables (``AUTO_REUSE``) and
        replaces, rather than extends, the per-call layer and loss lists.
        """
        if not embedding_list:
            raise ValueError("embedding_list must contain at least one tensor.")

        # Reset per-call state: appending on every call would make _forward
        # apply the hidden layers twice and would accumulate stale losses.
        self._h_w = []
        self._h_b = []
        self._loss_list = []

        with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE):
            embedding = tf.concat(embedding_list, axis=1)
            self._emb_dim = embedding.shape.as_list()[-1]
            if self._emb_dim is None:
                raise ValueError(
                    "The concatenated embedding must have a statically known last dimension."
                )
            self._all_layer_dims = [self._emb_dim] + self._layer_dims + [1]

            num_hidden = len(self._all_layer_dims) - 2
            for i in range(num_hidden):
                self._h_w.append(
                    self._make_weight("h%d_w" % (i + 1), self._all_layer_dims[i:i + 2], seed_offset=i)
                )
                self._h_b.append(
                    self._make_bias("h%d_b" % (i + 1), [self._all_layer_dims[i + 1]])
                )

            # Both heads map the last hidden width to a single logit. Creation
            # order and names match the hidden layers' scheme so checkpoints load.
            head_w_shape = self._all_layer_dims[num_hidden:num_hidden + 2]
            head_b_shape = [self._all_layer_dims[num_hidden + 1]]
            self._h_w_head_0 = self._make_weight("h_w_head_0", head_w_shape, seed_offset=num_hidden)
            self._h_b_head_0 = self._make_bias("h_b_head_0", head_b_shape)
            self._h_w_head_1 = self._make_weight("h_w_head_1", head_w_shape, seed_offset=num_hidden + 1)
            self._h_b_head_1 = self._make_bias("h_b_head_1", head_b_shape)

            logit_list = self._forward(embedding)
            for logit, label in zip(logit_list, (label_0, label_1)):
                # sigmoid_cross_entropy_with_logits needs labels with the same
                # dtype and shape as the rank-1 logits.
                label = tf.reshape(tf.cast(label, logit.dtype), [-1])
                basic_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=label)
                self._loss_list.append(tf.reduce_mean(basic_loss))

    @staticmethod
    def _derive_seed(offset):
        """Return a per-variable op seed derived from ``Config.random_seed``.

        Initializer ops that share an op seed (and shape) draw identical values
        under a fixed graph seed, e.g. the two ``[128, 1]`` head weights. The
        offset keeps runs reproducible while decorrelating variables. A ``None``
        base seed stays ``None`` (non-deterministic initialisation).
        """
        base_seed = Config.random_seed
        return None if base_seed is None else base_seed + offset

    def _make_weight(self, name, shape, seed_offset):
        """Create or reuse a float32 weight matrix drawn from a truncated normal."""
        return tf.compat.v1.get_variable(
            name,
            shape=shape,
            # Keyword arguments on purpose: this initializer's positional order
            # is (mean, stddev), so passing (-0.01, 0.01) would centre every
            # weight at -0.01 instead of 0.
            initializer=tf.compat.v1.truncated_normal_initializer(
                mean=0.0, stddev=0.01, seed=self._derive_seed(seed_offset)
            ),
            dtype=tf.float32,
            collections=list(self._WEIGHT_COLLECTIONS),
        )

    def _make_bias(self, name, shape):
        """Create or reuse a zero-initialised float32 bias vector."""
        return tf.compat.v1.get_variable(
            name,
            shape=shape,
            initializer=tf.compat.v1.zeros_initializer(),
            dtype=tf.float32,
            collections=list(self._BIAS_COLLECTIONS),
        )

    @property
    def loss(self):
        """Sum of the per-head losses recorded by the most recent call.

        Raises:
            ValueError: if the model has not been called yet.
        """
        if not self._loss_list:
            raise ValueError("No losses recorded; call the model before reading `loss`.")
        return tf.add_n(self._loss_list)

    @property
    def loss_list(self):
        """Per-head scalar losses, in (label_0, label_1) order."""
        return self._loss_list

    def _forward(self, embedding):
        """Run the shared hidden layers and both heads; return ``[logit_0, logit_1]``.

        Each returned logit is flattened to rank 1.
        """
        hidden_output = tf.reshape(embedding, [-1, self._emb_dim])
        # NOTE(review): ReLU is applied *before* each affine layer. The raw input
        # embedding is therefore also rectified (negative components zeroed), and
        # the last hidden layer is rectified inside the heads. Confirm the input
        # rectification is intended.
        for h_w, h_b in zip(self._h_w, self._h_b):
            hidden_output = tf.matmul(tf.nn.relu(hidden_output), h_w) + h_b
        logit_0 = self._output_head(hidden_output, self._h_w_head_0, self._h_b_head_0)
        logit_1 = self._output_head(hidden_output, self._h_w_head_1, self._h_b_head_1)
        return [logit_0, logit_1]

    @staticmethod
    def _output_head(hidden_output, h_w, h_b):
        """Apply ReLU and then a single-unit affine layer; return a rank-1 logit."""
        logit = tf.matmul(tf.nn.relu(hidden_output), h_w) + h_b
        return tf.reshape(logit, [-1])