import os
import time
from glob import glob
from typing import Dict, List, Optional, Tuple, Union
import tensorflow as tf
import numpy as np
from sklearn.metrics import roc_auc_score
from npu_bridge.hccl import hccl_ops
from npu_bridge.npu_init import util
import mxrec
from rec_sdk_common.communication.hccl.hccl_info import get_rank_size
from model import DLRM, LOSS_OP_NAME, LABEL_OP_NAME, PRED_OP_NAME, SPARSE_FEATURE_LEN, DENSE_FEATURE_LEN
from config import Config, sess_config
from optimizer import get_dense_and_sparse_optimizer
from logger import logger
class TaskRunner:
    """Build the DLRM training and evaluation graphs and drive them in one session.

    The runner owns a ``tf.compat.v1.Session``. ``run`` builds one forward graph
    for training and one for evaluation, trains for up to ``config.train_steps``
    session calls and evaluates every ``config.train_interval`` calls.
    """

    def __init__(
        self,
        config: Config,
        toml_config: Dict[str, Union[str, int, float]],
        seed: Optional[int] = None,
    ):
        """Store the configuration and open the session.

        Args:
            config: Run configuration. This class reads, among others,
                ``deterministic``, ``learning_rate``, ``global_step``,
                ``train_steps``, ``train_interval``, ``eval_steps``,
                ``batch_size``, ``rank_size``, ``rank_id`` and the data path
                settings from it.
            toml_config: Settings forwarded unchanged to ``DLRM``.
            seed: Seed handed to dataset shuffling, the embedding initializer
                and ``DLRM.build_model``; ``None`` leaves them unseeded.
        """
        self._config = config
        self._toml_config = toml_config
        self._seed = seed
        self._sess = tf.compat.v1.Session(
            config=sess_config(dump_data=False, is_deterministic=config.deterministic)
        )

    def __del__(self):
        """Close the session if one was created."""
        # __del__ also runs when __init__ raised before ``_sess`` was assigned,
        # so the attribute may be missing.
        session = getattr(self, "_sess", None)
        if session is not None:
            session.close()

    def run(self):
        """Train the model, evaluating periodically and tracking the best AUC.

        Training stops after ``config.train_steps`` session calls or as soon as
        the training iterator raises ``tf.errors.OutOfRangeError``.
        """
        train_model, train_iterator = self._model_forward(is_training=True)
        eval_model, eval_iterator = self._model_forward(is_training=False)

        train_ops = self._get_train_ops(train_model)
        # Collapse every update op into a single no-op so that fetching one
        # op triggers all of them.
        with tf.control_dependencies(train_ops):
            train_ops = tf.no_op()

        # Rebind the learning rates as a plain two-element list; it is fetched
        # with Session.run inside the loop below.
        self._config.learning_rate = [self._config.learning_rate[0], self._config.learning_rate[1]]

        self._sess.run(mxrec.get_init_hashtable_op())
        self._sess.run(tf.compat.v1.global_variables_initializer())
        self._sess.run(train_iterator.initializer)

        best_auc = 0
        # NOTE(review): presumably each Session.run of the wrapped op executes
        # this many iterations on the device; the logged step numbers and QPS
        # are scaled by it accordingly - confirm against the npu_bridge docs.
        iteration_per_loop = 10
        train_ops = util.set_iteration_per_loop(self._sess, train_ops, iteration_per_loop)

        start_step = 1
        for i in range(start_step, start_step + self._config.train_steps):
            reported_step = i * iteration_per_loop
            logger.info("################ training at step %d ################", reported_step)

            start_time = time.time()
            try:
                _, loss = self._sess.run([train_ops, train_model.get(LOSS_OP_NAME)])
                lr = self._sess.run(self._config.learning_rate)
                global_step = self._sess.run(self._config.global_step)
            except tf.errors.OutOfRangeError:
                logger.info("Encounter the end of Sequence for training.")
                break
            cost_time = time.time() - start_time

            qps = (1 / cost_time) * self._config.rank_size * self._config.batch_size * iteration_per_loop
            logger.info("Training step: %d; training loss: %s.", reported_step, loss)
            logger.info("Training step: %d; lr: %s.", reported_step, lr)
            logger.info("Global step: %d", global_step)
            logger.info("Training step: %d; current QPS: %s.", reported_step, qps)

            if i % self._config.train_interval == 0:
                test_auc, test_mean_log_loss = self._evaluate(eval_model, eval_iterator)
                logger.info("Test auc: %s; log_loss: %s", test_auc, test_mean_log_loss)
                best_auc = max(best_auc, test_auc)
                logger.info("Training step: %d, best auc: %s", reported_step, best_auc)

        logger.info("Training finished, the best auc is %s.", best_auc)

    def _get_train_ops(self, train_model: Dict[str, tf.Tensor]) -> List[tf.Tensor]:
        """Create the update ops for dense and sparse variables.

        Args:
            train_model: Model outputs of the training graph; the loss is read
                from key ``LOSS_OP_NAME``.

        Returns:
            A list holding the dense ``apply_gradients`` op, the sparse
            ``apply_gradients`` op, the global-step increment and the two
            learning-rate entries of the config.
        """
        sparse_variables = mxrec.get_sparse_embedding()
        dense_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
        # Dense variables come first, so the split point is simply their count.
        # Splitting with a negative ``len(sparse_variables)`` index would break
        # when there are no sparse variables, because ``[:-0]`` is empty.
        num_dense = len(dense_variables)
        trainable_variables = [*dense_variables, *sparse_variables]

        dense_optimizer, sparse_optimizer = get_dense_and_sparse_optimizer(self._config)
        loss = train_model.get(LOSS_OP_NAME)
        grads = tf.gradients(loss, trainable_variables)
        grads_and_vars = list(zip(grads, trainable_variables))

        rank_size = get_rank_size()
        dense_grads_and_vars = []
        for grad, var in grads_and_vars[:num_dense]:
            if grad is None:
                # The loss does not depend on this variable.
                continue
            if rank_size != 1:
                # Average the gradient over all ranks: sum-allreduce, then divide.
                grad = hccl_ops.allreduce(grad, "sum") / rank_size
            dense_grads_and_vars.append((grad, var))

        sparse_grads_and_vars = grads_and_vars[num_dense:]

        return [
            dense_optimizer.apply_gradients(dense_grads_and_vars),
            sparse_optimizer.apply_gradients(sparse_grads_and_vars),
            self._config.global_step.assign(self._config.global_step + 1),
            self._config.learning_rate[0],
            self._config.learning_rate[1],
        ]

    def _evaluate(
        self, eval_model: Dict[str, tf.Tensor], eval_iterator: tf.compat.v1.data.Iterator
    ) -> Tuple[float, np.ndarray]:
        """Run the evaluation graph and compute AUC and mean log loss.

        The iterator is re-initialized first. Evaluation ends after
        ``config.eval_steps`` batches or when the iterator is exhausted,
        whichever comes first.

        Args:
            eval_model: Model outputs of the evaluation graph; read at keys
                ``LOSS_OP_NAME``, ``PRED_OP_NAME`` and ``LABEL_OP_NAME``.
            eval_iterator: Initializable iterator feeding the evaluation graph.

        Returns:
            ``(auc, mean_log_loss)`` over all collected batches.
        """
        log_loss_list = []
        pred_list = []
        label_list = []
        fetches = [eval_model.get(LOSS_OP_NAME), eval_model.get(PRED_OP_NAME), eval_model.get(LABEL_OP_NAME)]

        self._sess.run([eval_iterator.initializer])

        eval_current_steps = 0
        finished = False
        while not finished:
            try:
                eval_current_steps += 1
                eval_start = time.time()
                eval_loss, pred, label = self._sess.run(fetches)
                eval_cost = time.time() - eval_start
                qps_eval = (1 / eval_cost) * get_rank_size() * self._config.batch_size

                # Flatten each fetched array and accumulate its elements.
                log_loss_list += list(eval_loss.reshape(-1))
                pred_list += list(pred.reshape(-1))
                label_list += list(label.reshape(-1))

                logger.info("Evaluation step: %d, qps: %s.", eval_current_steps, qps_eval)
                if eval_current_steps == self._config.eval_steps:
                    finished = True
            except tf.errors.OutOfRangeError:
                logger.info("Encounter the end of Sequence for evaluation.")
                finished = True

        auc = roc_auc_score(label_list, pred_list)
        mean_log_loss = np.mean(log_loss_list)
        return auc, mean_log_loss

    def _make_batch_and_iterator(
        self, is_training: bool = True, num_parallel: int = 8
    ) -> Tuple[Dict[str, tf.Tensor], tf.compat.v1.data.Iterator]:
        """Build the TFRecord input pipeline for training or evaluation.

        Each record packs ``config.line_per_sample`` lines, so the dataset is
        batched with ``config.batch_size // config.line_per_sample`` records
        and then reshaped to one line per row.

        Args:
            is_training: Selects the train or test file pattern and repeat
                count; shuffling is applied only when training.
            num_parallel: Parallelism used for reading files and for the
                ``map`` stages.

        Returns:
            ``(batch, iterator)`` where ``batch`` is the ``get_next`` dict with
            keys ``"label"``, ``"sparse_feature"`` and ``"dense_feature"``, and
            ``iterator`` must be initialized before use.
        """
        lines_per_record = self._config.line_per_sample

        def _extract_fn(data_record: tf.Tensor):
            """Parse one serialized example into its three fixed-length features."""
            features = {
                "label": tf.compat.v1.FixedLenFeature(shape=(lines_per_record,), dtype=tf.int64),
                "sparse_feature": tf.compat.v1.FixedLenFeature(
                    shape=(SPARSE_FEATURE_LEN * lines_per_record,), dtype=tf.int64
                ),
                "dense_feature": tf.compat.v1.FixedLenFeature(
                    shape=(DENSE_FEATURE_LEN * lines_per_record,), dtype=tf.float32
                ),
            }
            return tf.compat.v1.parse_single_example(data_record, features)

        def _reshape_fn(batch: Dict[str, tf.Tensor]):
            """Unpack the packed records into one row per line and log-scale dense features."""
            batch["label"] = tf.reshape(batch["label"], [-1, 1])
            dense = tf.reshape(batch["dense_feature"], [-1, DENSE_FEATURE_LEN])
            # NOTE(review): the 3.0 shift presumably keeps the log argument
            # positive for the expected input range - confirm against the
            # data preprocessing.
            batch["dense_feature"] = tf.math.log(tf.math.add(dense, 3.0))
            batch["sparse_feature"] = tf.reshape(batch["sparse_feature"], [-1, SPARSE_FEATURE_LEN])
            return batch

        if is_training:
            file_pattern = self._config.train_file_pattern
            repeat_count = self._config.train_epoch
        else:
            file_pattern = self._config.test_file_pattern
            repeat_count = self._config.test_epoch

        # glob returns matches in arbitrary order; sort them so every rank
        # reads the files in the same order and sharding/shuffling with a
        # fixed seed is reproducible.
        files_list = sorted(glob(os.path.join(self._config.data_path, file_pattern) + "/*.tfrecord"))

        dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel)
        batch_size = self._config.batch_size // lines_per_record
        dataset = dataset.shard(self._config.rank_size, self._config.rank_id)
        if is_training:
            dataset = dataset.shuffle(buffer_size=batch_size * 1000, seed=self._seed)
        dataset = dataset.repeat(repeat_count)
        dataset = dataset.map(_extract_fn, num_parallel_calls=num_parallel).batch(batch_size, drop_remainder=True)
        dataset = dataset.map(_reshape_fn, num_parallel_calls=num_parallel)
        dataset = dataset.prefetch(self._config.prefetch_num)

        # Use the compat.v1 helper rather than the deprecated Dataset method,
        # in line with the other tf.compat.v1 calls in this file.
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        batch = iterator.get_next()
        return batch, iterator

    def _model_forward(self, is_training: bool = True) -> Tuple[Dict[str, tf.Tensor], tf.compat.v1.data.Iterator]:
        """Build the input pipeline, embedding lookup and DLRM graph.

        Args:
            is_training: Forwarded to ``_make_batch_and_iterator`` to select
                the training or evaluation input pipeline.

        Returns:
            ``(model_output, iterator)``: the result of ``DLRM.build_model``
            and the iterator feeding it.
        """
        batch, iterator = self._make_batch_and_iterator(is_training)

        table = mxrec.get_embedding_table(
            name="sparse_embedding_table",
            dimension=self._config.emb_dim,
            device_vocabulary_size=self._config.dev_vocab_size,
            initializer=tf.compat.v1.variance_scaling_initializer(
                mode="fan_avg", distribution="normal", seed=self._seed
            ),
            key_dtype=self._config.key_type,
            value_dtype=self._config.value_type,
        )
        ids = batch.get("sparse_feature")
        embedding = mxrec.embedding_lookup(table, ids)

        model_output = DLRM(self._toml_config).build_model(
            embedding=embedding,
            dense_feature=batch.get("dense_feature"),
            label=batch.get("label"),
            seed=self._seed,
        )
        return model_output, iterator