"""Define the train, infer, eval and test processes."""
import os
import time
import collections
import numpy as np
import tensorflow as tf
import portalocker
from npu_bridge.npu_init import *
from mx_rec.util.initialize import ConfigInitializer
from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
import utils.util as util
import utils.metric as metric
from IO.iterator import FfmIterator
from IO.ffm_cache import FfmCache
from src.exDeepFM import ExtremeDeepFMModel
# Registry mapping ``hparams.model_type`` to the model class that implements it.
# get_model_creator() looks model types up here; only ``exDeepFM`` is registered.
MODEL_MAP = dict(exDeepFM=ExtremeDeepFMModel)
class TrainModel(collections.namedtuple("TrainModel", ["graph", "model", "iterator", "filenames"])):
    """Immutable bundle of everything a training/eval session needs.

    Fields:
        graph: the TensorFlow graph the model was built in.
        model: the model instance created by the model creator.
        iterator: the input iterator feeding the model.
        filenames: string placeholder that must be fed with the TFRecord file paths.
    """
def create_train_model(model_creator, hparams, scope=None):
    """Build the input pipeline and the model for training/evaluation.

    Args:
        model_creator: model class, called as
            ``model_creator(hparams, iterator=..., scope=scope)``.
        hparams: hyper-parameters; ``hparams.data_format`` selects the input iterator.
        scope: optional scope forwarded to the model.

    Returns:
        TrainModel bundling the default graph, the model, the iterator and the
        ``filenames`` string placeholder that must be fed with TFRecord paths.

    Raises:
        ValueError: if ``hparams.data_format`` is not supported.
    """
    data_format = hparams.data_format
    # Only FfmIterator is imported by this module. The former 'din'/'cccfnet'
    # branches referenced DinIterator/CCCFNetIterator, which are never imported,
    # so they crashed with NameError. Validate first, before adding any ops.
    if data_format != 'ffm':
        raise ValueError("not support {0} format data".format(data_format))
    filenames = tf.placeholder(tf.string, shape=[None])
    batch_input = FfmIterator(tf.data.TFRecordDataset(filenames))
    model = model_creator(hparams, iterator=batch_input, scope=scope)
    return TrainModel(
        graph=tf.get_default_graph(),
        model=model,
        iterator=batch_input,
        filenames=filenames)
def run_eval(load_model, load_sess, filename, sample_num_file, hparams, flag):
    """Evaluate the model on one cached TFRecord file and compute metrics.

    Args:
        load_model: TrainModel whose ``filenames`` placeholder feeds the input pipeline.
        load_sess: session holding the model's graph.
        filename: cached TFRecord file to evaluate on.
        sample_num_file: text file whose first line holds the true sample count.
        hparams: hyper-parameters (``logger`` plus whatever ``metric.cal_metric`` reads).
        flag: tag forwarded to ``metric.cal_metric`` (e.g. ``'eval'`` or ``'test'``).

    Returns:
        Whatever ``metric.cal_metric`` returns for the collected labels and predictions.
    """
    with open(sample_num_file, 'r') as count_fh:
        true_count = int(count_fh.readlines()[0].strip())
    label_tensor = ConfigInitializer.get_instance().train_params_config.get_target_batch(True).get("labels")
    init_op = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
    load_sess.run(init_op, feed_dict={load_model.filenames: [filename]})

    all_preds = []
    all_labels = []
    while True:
        try:
            _, _, batch_pred, batch_labels = load_model.model.eval(load_sess, label_tensor)
        except tf.errors.OutOfRangeError:
            break  # the input pipeline is exhausted
        # Flatten each batch before accumulating.
        all_preds.extend(np.reshape(batch_pred, -1))
        all_labels.extend(np.reshape(batch_labels, -1))

    # Keep only the recorded number of samples.
    # NOTE(review): presumably the final cached batch is padded — confirm.
    all_preds = all_preds[:true_count]
    all_labels = all_labels[:true_count]
    hparams.logger.info("data num:{0:d}".format(len(all_labels)))
    return metric.cal_metric(all_labels, all_preds, hparams, flag)
def run_infer(load_model, load_sess, filename, hparams, sample_num_file):
    """Run inference over a cached TFRecord file and write predictions to disk.

    Predictions are flattened, truncated to the sample count stored on the first
    line of ``sample_num_file``, and written one per line to
    ``util.convert_res_name(hparams.infer_file)``. That path is also stored on
    ``hparams.res_name``.
    """
    with open(sample_num_file, 'r') as count_fh:
        limit = int(count_fh.readlines()[0].strip())
    util.make_dir_with_lock(util.RES_DIR)
    load_sess.run(
        ConfigInitializer.get_instance().train_params_config.get_initializer(True),
        feed_dict={load_model.filenames: [filename]},
    )

    collected = []
    while True:
        try:
            batch_pred = load_model.model.infer(load_sess)
        except tf.errors.OutOfRangeError:
            break  # the input pipeline is exhausted
        collected.extend(np.reshape(batch_pred, -1))

    # NOTE(review): truncation presumably drops padding in the last batch — confirm.
    kept = collected[:limit]
    hparams.res_name = util.convert_res_name(hparams.infer_file)
    with open(hparams.res_name, 'w') as out:
        out.write('\n'.join(str(value) for value in kept))
def cache_data(hparams, filename, flag):
    """Convert one raw data file into a cached TFRecord file (once) and record its stats.

    The cache path is derived from ``hparams.<flag>_file`` and ``hparams.batch_size``
    and stored back on ``hparams.<flag>_file_cache``. If that cache file does not
    exist yet, ``filename`` is converted with ``FfmCache.write_tfrecord``. The
    sample count and impression ids are then written to the per-flag paths in
    ``util``. The existence check and the write run under an exclusive lock on
    ``util.LOCK_FILE``.

    Args:
        hparams: hyper-parameters (``data_format``, ``batch_size``,
            ``<flag>_file`` and ``logger`` are read).
        filename: raw input file to convert.
        flag: one of ``'train'``, ``'eval'``, ``'test'``, ``'infer'``.

    Raises:
        ValueError: on an unsupported ``data_format`` or an unknown ``flag``.
    """
    if hparams.data_format != 'ffm':
        # Only FfmCache is imported by this module; DinCache/CCCFNetCache never were,
        # so 'din'/'cccfnet' used to fail with NameError instead of a clear error.
        raise ValueError(
            "data format must be ffm, this format not defined {0}".format(hparams.data_format))
    # Reject bad flags before any side effects (directory creation, hparams mutation).
    if flag not in ('train', 'eval', 'test', 'infer'):
        raise ValueError("flag must be train, eval, test, infer")
    cache_obj = FfmCache()
    util.make_dir_with_lock(util.CACHE_DIR)

    # flag -> (raw-file attr, cache-file attr, sample-count path, impression-id path)
    source_attr, cache_attr, sample_num_path, impression_id_path = {
        'train': ('train_file', 'train_file_cache', util.TRAIN_NUM, util.TRAIN_IMPRESSION_ID),
        'eval': ('eval_file', 'eval_file_cache', util.EVAL_NUM, util.EVAL_IMPRESSION_ID),
        'test': ('test_file', 'test_file_cache', util.TEST_NUM, util.TEST_IMPRESSION_ID),
        'infer': ('infer_file', 'infer_file_cache', util.INFER_NUM, util.INFER_IMPRESSION_ID),
    }[flag]
    cached_name = util.convert_cached_name(getattr(hparams, source_attr), hparams.batch_size)
    setattr(hparams, cache_attr, cached_name)

    hparams.logger.info('cache filename: {}'.format(filename))
    if not os.path.exists(util.LOCK_FILE):
        open(util.LOCK_FILE, 'w').close()
    # Exclusive lock so concurrent processes do not build the same cache twice.
    with portalocker.Lock(util.LOCK_FILE, 'w', flags=portalocker.LOCK_EX):
        if os.path.isfile(cached_name):
            return
        hparams.logger.info('has not cached file, begin cached...')
        start_time = time.time()
        sample_num, impression_id_list = cache_obj.write_tfrecord(filename, cached_name, hparams)
        util.print_time("cache file used time", start_time)
        hparams.logger.info("data sample num:{0}".format(sample_num))
        with open(sample_num_path, 'w') as f:
            f.write(str(sample_num) + '\n')
        with open(impression_id_path, 'w') as f:
            f.write(''.join(str(impression_id) + '\n' for impression_id in impression_id_list))
def _format_metrics(res):
    """Render a metric dict as ``"name:value, name:value"`` sorted by metric name."""
    return ', '.join(str(key) + ':' + str(val) for key, val in sorted(res.items(), key=lambda item: item[0]))


def _restore_checkpoint(hparams, train_model, train_sess):
    """Restore weights from ``hparams.load_model_name`` when one is configured.

    Raises:
        IOError: if the saver fails to restore from the configured path.
    """
    checkpoint_path = hparams.load_model_name
    if checkpoint_path is None:
        return
    try:
        train_model.model.saver.restore(train_sess, checkpoint_path)
    except Exception as err:
        # Keep the IOError contract, but chain the real cause and stop
        # swallowing KeyboardInterrupt/SystemExit (the old bare ``except:`` did).
        raise IOError("Failed to find any matching files for {0}".format(checkpoint_path)) from err
    hparams.logger.info('load model: {}'.format(checkpoint_path))


def _run_train_epoch(hparams, train_model, train_sess, writer, global_step):
    """Train on the cached training file until the input pipeline is exhausted.

    Returns:
        (epoch_loss, steps, train_time, global_step): the summed total loss, the
        number of steps run this epoch, wall-clock seconds spent, and the updated
        summary step counter.
    """
    initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
    train_sess.run(initializer, feed_dict={train_model.filenames: [hparams.train_file_cache]})
    epoch_loss = 0
    step = 0
    train_start = time.time()
    while True:
        try:
            _, step_loss, step_data_loss, summary = train_model.model.train(train_sess)
        except tf.errors.OutOfRangeError:
            hparams.logger.info('finish one epoch!')
            break
        # ``step`` restarts at 0 every epoch; a monotonically increasing counter
        # keeps each epoch's summaries at distinct TensorBoard steps.
        writer.add_summary(summary, global_step)
        global_step += 1
        epoch_loss += step_loss
        step += 1
        if step % hparams.show_step == 0:
            hparams.logger.info('step {0:d} , total_loss: {1:.4f}, data_loss: {2:.4f}'
                                .format(step, step_loss, step_data_loss))
    return epoch_loss, step, time.time() - train_start, global_step


def _train_epochs(hparams, train_model, train_sess, writer):
    """Run up to ``hparams.epochs`` epochs with per-epoch eval (and test) runs.

    Training stops early once eval AUC falls more than 0.003 below the best eval
    AUC seen so far (the best starts at 0).
    """
    best_auc = 0
    global_step = 0
    for epoch in range(hparams.epochs):
        epoch_loss, step, train_time, global_step = _run_train_epoch(
            hparams, train_model, train_sess, writer, global_step)
        if epoch % hparams.save_epoch == 0:
            train_model.model.saver.save(
                sess=train_sess,
                save_path=util.MODEL_DIR + 'epoch_' + str(epoch))
        # An empty training file yields zero steps; report NaN instead of dividing by zero.
        train_res = {"loss": epoch_loss / step if step else float('nan')}

        eval_start = time.time()
        eval_res = run_eval(train_model, train_sess, hparams.eval_file_cache, util.EVAL_NUM, hparams, flag='eval')
        epoch_info = ('at epoch {0:d}'.format(epoch) + ' train info: ' + _format_metrics(train_res)
                      + ' eval info: ' + _format_metrics(eval_res))
        if hparams.test_file is not None:
            test_res = run_eval(train_model, train_sess, hparams.test_file_cache, util.TEST_NUM, hparams, flag='test')
            epoch_info += ' test info: ' + _format_metrics(test_res)
        eval_time = time.time() - eval_start

        hparams.logger.info(epoch_info)
        hparams.logger.info('at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}'
                            .format(epoch, train_time, eval_time))
        hparams.logger.info('\n')
        if eval_res["auc"] - best_auc < -0.003:
            break
        if eval_res["auc"] > best_auc:
            best_auc = eval_res["auc"]


def train(hparams, scope=None, target_session=""):
    """Cache data, build the model, train with per-epoch evaluation, then optionally infer.

    Args:
        hparams: hyper-parameters (data files, model type, epochs, logger, ...).
        scope: optional scope forwarded to the model creator.
        target_session: execution engine target passed to ``tf.Session``.

    Raises:
        IOError: if ``hparams.load_model_name`` is set but cannot be restored.
    """
    for key, val in hparams.values().items():
        hparams.logger.info(str(key) + ':' + str(val))
    load_and_cache_data(hparams)
    model_creator = get_model_creator(hparams.model_type, hparams.logger)
    train_model = create_train_model(model_creator, hparams, scope)
    gpuconfig = tf.ConfigProto()
    gpuconfig.gpu_options.allow_growth = True
    tf.set_random_seed(1234)
    modify_graph_and_start_emb_cache(dump_graph=True)
    train_sess = tf.Session(target=target_session, graph=train_model.graph,
                            config=npu_config_proto(config_proto=gpuconfig))
    try:
        train_sess.run(train_model.model.init_op)
        _restore_checkpoint(hparams, train_model, train_sess)
        hparams.logger.info('total_loss = data_loss+regularization_loss, data_loss = {rmse or logloss ..}')
        writer = tf.summary.FileWriter(util.SUMMARIES_DIR, train_sess.graph)
        try:
            _train_epochs(hparams, train_model, train_sess, writer)
        finally:
            writer.close()
        if hparams.infer_file is not None:
            run_infer(train_model, train_sess, hparams.infer_file_cache, hparams, util.INFER_NUM)
    finally:
        # Release the session and its device resources, also when training fails.
        train_sess.close()
def load_and_cache_data(hparams):
    """Cache every configured data split via ``cache_data``.

    Splits are visited in the fixed order train, eval, test, infer. A split whose
    ``hparams.<split>_file`` is ``None`` is skipped.
    """
    hparams.logger.info('load and cache data...')
    for split in ('train', 'eval', 'test', 'infer'):
        raw_file = getattr(hparams, split + '_file')
        if raw_file is None:
            continue
        cache_data(hparams, raw_file, flag=split)
def get_model_creator(model_type, logger):
    """Look up the model class registered under ``model_type`` in ``MODEL_MAP``.

    Args:
        model_type: key into ``MODEL_MAP`` (e.g. ``'exDeepFM'``).
        logger: logger used to report which model will run.

    Returns:
        The registered model class.

    Raises:
        ValueError: if ``model_type`` is not registered in ``MODEL_MAP``.
    """
    model_class = MODEL_MAP.get(model_type)
    if model_class is None:
        # Build the list from the registry so the message never advertises
        # model types that are not actually available.
        raise ValueError("model type should be one of: {0}".format(", ".join(sorted(MODEL_MAP))))
    logger.info(f"run {model_type} model!")
    return model_class