import os
import time
import warnings
import tensorflow as tf
from sklearn.metrics import roc_auc_score
from mpi4py import MPI
import numpy as np
from npu_bridge.npu_init import *
from config import Config
from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET
from mx_rec.core.asc.manager import start_asc_pipeline
from mx_rec.core.embedding import create_table, sparse_lookup
from mx_rec.core.feature_process import EvictHook
from mx_rec.graph.modifier import modify_graph_and_start_emb_cache, GraphModifierHook
from mx_rec.constants.constants import ASCEND_TIMESTAMP
from mx_rec.util.initialize import ConfigInitializer, init, terminate_config_initializer
import rec_sdk_common
from mx_rec.util.variable import get_dense_and_sparse_variable
import examples.model_common as cm
from examples.model_common import (
sess_config, create_feature_spec_list, clear_saved_model, evaluate_fix, make_batch_and_iterator
)
from model import MyModel
from demo_logger import logger
from optimizer import get_dense_and_sparse_optimizer
# Set the device "sat mode" to 0 at import time, before any graph is built below.
# NOTE(review): `npu_plugin` is presumably provided by the star import of
# npu_bridge.npu_init; the meaning of mode 0 is not visible in this file -- confirm.
npu_plugin.set_device_sat_mode(0)
# Seed handed to MyModel.build_model() by model_forward().
DENSE_HASHTABLE_SEED = 128
# NOTE(review): not referenced anywhere in the visible part of this file.
SPARSE_HASHTABLE_SEED = 128
# Override attributes on the shared examples.model_common module (imported as `cm`).
cm.MODEL_NAME = "MMOE"
cm.logger = logger
def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph):
    """Build the embedding lookups and the MyModel graph for one batch.

    One ``sparse_lookup`` is created per (feature, hash table) pair. When more
    than one lookup is produced the results are summed element-wise, then the
    result is reduced over axis 1 and handed to ``MyModel.build_model``.

    Relies on the module-level ``cfg`` (created in the ``__main__`` section)
    for ``cfg.send_count``.

    Args:
        feature_list: Feature specs, paired positionally with
            ``hash_table_list`` (``zip`` stops at the shorter of the two).
        hash_table_list: Embedding tables to look up.
        batch: Mapping providing ``"sparse_feature"``, ``"dense_feature"``
            and ``"label"``.
        is_train: Forwarded to ``sparse_lookup`` and ``build_model``.
        modify_graph: Forwarded to ``sparse_lookup``. When truthy, the lookup
            input is ``batch["sparse_feature"]`` instead of the feature spec.

    Returns:
        Whatever ``MyModel.build_model`` returns (the callers in this file
        read ``"loss"`` and ``"pred"`` from it via ``.get``).

    Raises:
        ValueError: If no lookup was produced (an input list is empty).
    """
    logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, "
                 f"hash_table_list: {len(hash_table_list)}")

    embedding_list = []
    for feature, hash_table in zip(feature_list, hash_table_list):
        # Key the input choice off the `modify_graph` argument that is also
        # passed to sparse_lookup, rather than the global cm.MODIFY_GRAPH_FLAG,
        # so the two can never disagree for a caller that passes its own value.
        lookup_input = batch["sparse_feature"] if modify_graph else feature
        embedding = sparse_lookup(hash_table, lookup_input, cfg.send_count, dim=None, is_train=is_train,
                                  name="user_embedding_lookup", modify_graph=modify_graph, batch=batch,
                                  access_and_evict_config=None)
        embedding_list.append(embedding)

    if not embedding_list:
        raise ValueError("the length of embedding_list must be greater than or equal to 1.")

    if len(embedding_list) == 1:
        emb = embedding_list[0]
    else:
        # Multiple lookups: sum them element-wise.
        emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False)
    emb = tf.reduce_sum(emb, axis=1)

    my_model = MyModel()
    model_output = my_model.build_model(embedding=emb,
                                        dense_feature=batch["dense_feature"],
                                        label=batch["label"],
                                        is_training=is_train,
                                        seed=DENSE_HASHTABLE_SEED)
    return model_output
def evaluate():
    """Run one evaluation pass and return ``(auc_income, auc_mat, mean_log_loss)``.

    Reads the module-level names ``sess``, ``eval_model``, ``eval_iterator``,
    ``rank_size`` and ``cfg`` that the ``__main__`` section creates. The pass
    ends when the dataset raises ``OutOfRangeError`` or after
    ``cm.eval_steps`` successful steps, whichever comes first.

    Returns:
        Tuple of ROC-AUC for the first label column, ROC-AUC for the second
        label column, and the mean of all collected loss values.
    """
    print("read_test dataset")
    if cm.MODIFY_GRAPH_FLAG:
        eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(False).get("label")
        sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)])
    else:
        eval_label = eval_model.get("label")
        sess.run([eval_iterator.initializer])

    losses = []
    income_preds, mat_preds = [], []
    income_labels, mat_labels = [], []
    fetches = [eval_model.get("loss"), eval_model.get("pred"), eval_label]

    print("eval begin")
    step = 0
    while True:
        step += 1
        tick = time.time()
        try:
            eval_loss, pred, label = sess.run(fetches)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted before cm.eval_steps was reached.
            break
        elapsed = time.time() - tick
        qps_eval = (1 / elapsed) * rank_size * cfg.batch_size

        losses.extend(eval_loss.reshape(-1))
        # pred holds one prediction array per task; label has one column per task.
        income_preds.extend(pred[0].reshape(-1))
        mat_preds.extend(pred[1].reshape(-1))
        income_labels.extend(label[:, 0].reshape(-1))
        mat_labels.extend(label[:, 1].reshape(-1))

        print(f"eval current_steps: {step}, qps: {qps_eval}")
        if step == cm.eval_steps:
            break

    return (
        roc_auc_score(income_labels, income_preds),
        roc_auc_score(mat_labels, mat_preds),
        np.mean(losses),
    )
if __name__ == "__main__":
    # Training entry point for the MMOE demo: builds the input pipelines, the
    # embedding table and the train/eval graphs, then trains until the input
    # pipeline raises OutOfRangeError, evaluating periodically.
    #
    # NOTE: cfg, sess, eval_model, eval_iterator and rank_size must remain
    # module-level names -- model_forward() and evaluate() read them as globals.
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    warnings.filterwarnings("ignore")

    comm = MPI.COMM_WORLD
    clear_saved_model()
    # Wait for every rank before continuing past the saved-model cleanup.
    comm.Barrier()

    cm.train_steps = 1000
    cm.eval_steps = 1500
    cfg = Config()
    use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0)))
    logger.info(f"USE_DYNAMIC:{use_dynamic}")
    init(train_steps=cm.train_steps, eval_steps=cm.eval_steps,
         use_dynamic=use_dynamic, use_dynamic_expansion=cm.use_dynamic_expansion)
    rank_id = rec_sdk_common.communication.hccl.hccl_info.get_rank_id()

    # ---- input pipelines -------------------------------------------------
    # Feature specs carry a timestamp only when cm.use_faae is enabled.
    use_timestamp = bool(cm.use_faae)
    feature_spec_list_train = create_feature_spec_list(cfg, use_timestamp=use_timestamp)
    feature_spec_list_eval = create_feature_spec_list(cfg, use_timestamp=use_timestamp)

    train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True,
                                                          dump_graph=True, is_use_faae=cm.use_faae)
    eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False,
                                                        dump_graph=False, is_use_faae=cm.use_faae)
    logger.info(f"train_batch: {train_batch}")

    # ---- embedding table and models --------------------------------------
    if cm.use_faae:
        cfg.dev_vocab_size = cfg.dev_vocab_size // 2
    optimizer_list = [get_dense_and_sparse_optimizer(cfg)]

    emb_initializer = tf.constant_initializer(value=0.1)
    sparse_hashtable = create_table(
        key_dtype=cfg.key_type,
        dim=tf.TensorShape([cfg.emb_dim]),
        name="sparse_embeddings",
        emb_initializer=emb_initializer,
        is_dp=cm.USE_DP,
        **cfg.get_emb_table_cfg()
    )
    if cm.use_faae:
        tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, train_batch["timestamp"])

    # With multi-lookup the same table is looked up twice.
    sparse_hashtable_list = [sparse_hashtable, sparse_hashtable] if cm.use_multi_lookup else [sparse_hashtable]
    train_model = model_forward(feature_spec_list_train, sparse_hashtable_list, train_batch,
                                is_train=True, modify_graph=cm.MODIFY_GRAPH_FLAG)
    eval_model = model_forward(feature_spec_list_eval, sparse_hashtable_list, eval_batch,
                               is_train=False, modify_graph=cm.MODIFY_GRAPH_FLAG)

    # ---- gradients and train op ------------------------------------------
    dense_variables, sparse_variables = get_dense_and_sparse_variable()
    # Dense variables first; the sparse entry goes last because the code below
    # treats grads[-1] as the sparse gradient.
    trainable_variables = list(dense_variables)
    if cm.use_dynamic_expansion:
        trainable_variables.append(tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)[0])
    else:
        trainable_variables.extend(sparse_variables)

    rank_size = rec_sdk_common.communication.hccl.hccl_info.get_rank_size()
    # NOTE(review): the divisor is hard-coded to 8.0 regardless of rank_size
    # (presumably an 8-device average). Kept as-is to preserve training
    # numerics -- confirm whether it should follow rank_size.
    grad_divisor = 8.0
    train_ops = []
    for loss, (dense_optimizer, sparse_optimizer) in zip([train_model.get("loss")], optimizer_list):
        grads = dense_optimizer.compute_gradients(loss, var_list=trainable_variables)

        # Every pair except the last is handled as a dense gradient.
        avg_grads = []
        for grad, var in grads[:-1]:
            if grad is None:
                continue
            if rank_size > 1:
                grad = hccl_ops.allreduce(grad, "sum")
            avg_grads.append((grad / grad_divisor, var))
        train_ops.append(dense_optimizer.apply_gradients(avg_grads))

        # NOTE(review): grads[-1] is a single (grad, var) pair, so this list is
        # [grad, var]; it is zipped against the sparse targets below exactly as
        # the original code did -- confirm this is the intended pairing.
        sparse_grads = list(grads[-1])
        if cm.use_dynamic_expansion:
            sparse_targets = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
        else:
            print("sparse_grads_tensor:", sparse_grads)
            sparse_targets = sparse_variables
        grads_and_vars = list(zip(sparse_grads, sparse_targets))
        train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))

    # Collapse all update ops into a single no-op that depends on them.
    with tf.control_dependencies(train_ops):
        train_ops = tf.no_op()

    # Keep only the first two learning-rate entries as the fetch list used in the loop.
    cfg.learning_rate = [cfg.learning_rate[0], cfg.learning_rate[1]]

    if cm.MODIFY_GRAPH_FLAG:
        modify_graph_and_start_emb_cache(dump_graph=True)
    else:
        start_asc_pipeline()

    # ---- session ---------------------------------------------------------
    hook_list = []
    if cm.use_faae:
        hook_evict = EvictHook(evict_enable=True, evict_time_interval=120)
        hook_list.append(hook_evict)
    if cm.MODIFY_GRAPH_FLAG:
        hook_list.append(GraphModifierHook(modify_graph=False))

    if cm.use_faae:
        sess = tf.compat.v1.train.MonitoredTrainingSession(
            hooks=hook_list,
            config=sess_config(dump_data=False)
        )
        # The monitored session finalizes the graph; re-open it for later graph edits.
        sess.graph._unsafe_unfinalize()
    else:
        sess = tf.compat.v1.Session(config=sess_config(dump_data=False))
        sess.run(tf.compat.v1.global_variables_initializer())

    try:
        # Resolve the dataset initializer only after the session exists, as before.
        if cm.MODIFY_GRAPH_FLAG:
            train_init_op = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
        else:
            train_init_op = train_iterator.initializer
        sess.run(train_init_op)

        best_auc_income = 0
        best_auc_mat = 0
        iteration_per_loop = 10
        # Use the variable (not a second literal) so the logged step count
        # cannot drift from what the device loop is configured with.
        train_ops = util.set_iteration_per_loop(sess, train_ops, iteration_per_loop)
        # Number of sess.run calls between evaluations; at least 1 so the
        # modulo below can never divide by zero.
        eval_interval = max(1, cm.train_steps // iteration_per_loop)

        i = 0
        while True:
            i += 1
            step = i * iteration_per_loop
            logger.info(f"################ training at step {step} ################")
            start_time = time.time()
            try:
                grad, loss, lr, global_step = sess.run([train_ops, train_model.get("loss"),
                                                        cfg.learning_rate, cfg.global_step])
            except tf.errors.OutOfRangeError:
                logger.info("Encounter the end of Sequence for training.")
                break
            cost_time = time.time() - start_time
            qps = (1 / cost_time) * rank_size * cfg.batch_size * iteration_per_loop

            logger.info(f"step: {step}; training loss: {loss}")
            logger.info(f"step: {step}; grad: {grad}")
            logger.info(f"step: {step}; lr: {lr}")
            logger.info(f"global step: {global_step}")
            logger.info(f"step: {step}; current sess cost time: {cost_time:.10f}; current QPS: {qps}")
            logger.info(f"training at step:{step}, table[{sparse_hashtable.table_name}], "
                        f"table size:{sparse_hashtable.size()}, table capacity:{sparse_hashtable.capacity()}")

            if i % eval_interval == 0:
                if cm.interval is not None:
                    test_auc_income, test_auc_mat, test_mean_log_loss = evaluate_fix(step, sess,
                                                                                     eval_model, eval_iterator)
                else:
                    test_auc_income, test_auc_mat, test_mean_log_loss = evaluate()
                print("Test auc income: {};Test auc mat: {} ;log_loss: {} ".format(test_auc_income,
                                                                                  test_auc_mat, test_mean_log_loss))
                best_auc_income = max(best_auc_income, test_auc_income)
                best_auc_mat = max(best_auc_mat, test_auc_mat)
                logger.info(f"training step: {step}, best auc income: "
                            f"{best_auc_income} , best auc mat: {best_auc_mat}")
    finally:
        # Release the session and the config initializer even if training or
        # evaluation raised; the exception (if any) still propagates.
        sess.close()
        terminate_config_initializer()
    logger.info("Demo done!")