"""Define the train, infer, eval and test processes."""
import os
import time
import collections

import numpy as np
import tensorflow as tf
import portalocker

from npu_bridge.npu_init import *
from mx_rec.util.initialize import ConfigInitializer
from mx_rec.graph.modifier import modify_graph_and_start_emb_cache

import utils.util as util
import utils.metric as metric
from IO.iterator import FfmIterator
from IO.ffm_cache import FfmCache
from src.exDeepFM import ExtremeDeepFMModel

# Registry of model names -> model classes; looked up by get_model_creator().
MODEL_MAP = dict(exDeepFM=ExtremeDeepFMModel)


class TrainModel(collections.namedtuple("TrainModel", "graph model iterator filenames")):
    """Immutable bundle of the objects that make up the training setup.

    Fields (as filled in by ``create_train_model``):
        graph: the TensorFlow graph the model was built in.
        model: the model instance returned by the model creator.
        iterator: the batch-input iterator feeding the model.
        filenames: string placeholder that selects which TFRecord file(s) to read.
    """


def create_train_model(model_creator, hparams, scope=None):
    """Build the input pipeline and the model graph.

    The pipeline reads TFRecords from whatever file names are fed into the
    returned ``filenames`` placeholder. That lets the same graph be reused for
    train, eval, test and infer files.

    Args:
        model_creator: model class, called as
            ``model_creator(hparams, iterator=..., scope=...)``.
        hparams: hyper-parameters; ``data_format`` selects the batch iterator.
        scope: optional scope forwarded to the model constructor.

    Returns:
        TrainModel(graph, model, iterator, filenames).

    Raises:
        ValueError: if ``hparams.data_format`` has no iterator available in
            this module.
    """
    # Only the FFM iterator is imported by this module, so it is the only
    # supported format. Validate before any op is added to the default graph.
    iterator_classes = {'ffm': FfmIterator}
    iterator_cls = iterator_classes.get(hparams.data_format)
    if iterator_cls is None:
        raise ValueError("not support {0} format data".format(hparams.data_format))

    # Train, eval, test or infer file names are fed in at session-run time.
    filenames = tf.placeholder(tf.string, shape=[None])
    src_dataset = tf.data.TFRecordDataset(filenames)
    batch_input = iterator_cls(src_dataset)

    model = model_creator(
        hparams,
        iterator=batch_input,
        scope=scope)

    return TrainModel(
        graph=tf.get_default_graph(),
        model=model,
        iterator=batch_input,
        filenames=filenames)


# run evaluation and return the evaluation metrics
def run_eval(load_model, load_sess, filename, sample_num_file, hparams, flag):
    """Evaluate the model on one cached TFRecord file and return its metrics.

    Args:
        load_model: TrainModel whose ``filenames`` placeholder feeds the pipeline.
        load_sess: session that holds the model's variables.
        filename: cached TFRecord file to evaluate on.
        sample_num_file: text file whose first line is the number of real
            samples. Predictions past that count are dropped (presumably
            padding added to fill the last batch; confirm against the cache writer).
        hparams: hyper-parameters (``logger`` is used here).
        flag: split label forwarded to ``metric.cal_metric``.

    Returns:
        The result of ``metric.cal_metric``. ``train()`` reads its "auc" entry.
    """
    with open(sample_num_file, 'r') as num_file:
        first_line = num_file.readlines()[0]
    sample_num = int(first_line.strip())

    eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(True).get("labels")
    init_op = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
    load_sess.run(init_op, feed_dict={load_model.filenames: [filename]})

    pred_values, label_values = [], []
    while True:
        try:
            _, _, batch_pred, batch_labels = load_model.model.eval(load_sess, eval_label)
        except tf.errors.OutOfRangeError:
            # The input pipeline is exhausted.
            break
        pred_values.extend(np.reshape(batch_pred, -1))
        label_values.extend(np.reshape(batch_labels, -1))

    kept_preds = pred_values[:sample_num]
    kept_labels = label_values[:sample_num]
    hparams.logger.info("data num:{0:d}".format(len(kept_labels)))
    return metric.cal_metric(kept_labels, kept_preds, hparams, flag)


# run infer
def run_infer(load_model, load_sess, filename, hparams, sample_num_file):
    """Predict on a cached TFRecord file and write one prediction per line.

    Args:
        load_model: TrainModel whose ``filenames`` placeholder feeds the pipeline.
        load_sess: session that holds the model's variables.
        filename: cached TFRecord file to run inference on.
        hparams: hyper-parameters. ``infer_file`` determines the output path,
            and ``res_name`` is set to that path as a side effect.
        sample_num_file: text file whose first line is the number of real
            samples. Predictions past that count are dropped.
    """
    with open(sample_num_file, 'r') as num_file:
        first_line = num_file.readlines()[0]
    sample_num = int(first_line.strip())

    util.make_dir_with_lock(util.RES_DIR)
    # Same initializer lookup as run_eval (get_initializer is called with True).
    init_op = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
    load_sess.run(init_op, feed_dict={load_model.filenames: [filename]})

    collected = []
    while True:
        try:
            batch_pred = load_model.model.infer(load_sess)
        except tf.errors.OutOfRangeError:
            # The input pipeline is exhausted.
            break
        collected.extend(np.reshape(batch_pred, -1))
    kept = collected[:sample_num]

    hparams.res_name = util.convert_res_name(hparams.infer_file)
    with open(hparams.res_name, 'w') as out:
        out.write('\n'.join(str(value) for value in kept))


# cache data
def cache_data(hparams, filename, flag):
    """Convert ``filename`` into a cached TFRecord file for one data split.

    The cache path is derived from ``hparams.<flag>_file`` and
    ``hparams.batch_size``. It is stored back on ``hparams`` as
    ``<flag>_file_cache``.

    The conversion only runs when that cache file does not exist yet. It runs
    while holding an exclusive lock on ``util.LOCK_FILE``, so the existence
    check and the write cannot interleave across processes. After a conversion,
    the sample count and the impression ids are written to the split's
    bookkeeping files (for example ``util.TRAIN_NUM`` and
    ``util.TRAIN_IMPRESSION_ID``).

    Args:
        hparams: hyper-parameters. Reads ``data_format``, ``batch_size``,
            ``<flag>_file`` and ``logger``; sets ``<flag>_file_cache``.
        filename: raw input file to convert.
        flag: one of 'train', 'eval', 'test', 'infer'.

    Raises:
        ValueError: for an unsupported ``hparams.data_format`` or ``flag``.
    """
    # Only the FFM cache writer is imported by this module.
    cache_classes = {'ffm': FfmCache}
    cache_cls = cache_classes.get(hparams.data_format)
    if cache_cls is None:
        raise ValueError(
            "data format must be {0}, this format not defined {1}".format(
                ', '.join(cache_classes), hparams.data_format))

    # Bookkeeping files per split: (sample-count file, impression-id file).
    split_files = {
        'train': (util.TRAIN_NUM, util.TRAIN_IMPRESSION_ID),
        'eval': (util.EVAL_NUM, util.EVAL_IMPRESSION_ID),
        'test': (util.TEST_NUM, util.TEST_IMPRESSION_ID),
        'infer': (util.INFER_NUM, util.INFER_IMPRESSION_ID),
    }
    if flag not in split_files:
        raise ValueError("flag must be train, eval, test, infer")
    sample_num_path, impression_id_path = split_files[flag]

    cache_obj = cache_cls()
    util.make_dir_with_lock(util.CACHE_DIR)
    # For example, hparams.train_file -> hparams.train_file_cache.
    cached_name = util.convert_cached_name(getattr(hparams, flag + '_file'), hparams.batch_size)
    setattr(hparams, flag + '_file_cache', cached_name)
    hparams.logger.info('cache filename: {}'.format(filename))

    # Make sure the lock file exists before trying to lock it.
    if not os.path.exists(util.LOCK_FILE):
        open(util.LOCK_FILE, 'w').close()

    with portalocker.Lock(util.LOCK_FILE, 'w', flags=portalocker.LOCK_EX):
        if os.path.isfile(cached_name):
            # Another run (or process) already produced this cache.
            return
        hparams.logger.info('has not cached file, begin cached...')
        start_time = time.time()
        sample_num, impression_id_list = cache_obj.write_tfrecord(filename, cached_name, hparams)
        util.print_time("cache file used time", start_time)
        hparams.logger.info("data sample num:{0}".format(sample_num))
        with open(sample_num_path, 'w') as f:
            f.write(str(sample_num) + '\n')
        with open(impression_id_path, 'w') as f:
            f.writelines(str(impression_id) + '\n' for impression_id in impression_id_list)


def _format_res(res):
    """Render a metrics dict as ``'key1:val1, key2:val2'``, sorted by key."""
    return ', '.join(
        str(key) + ':' + str(val)
        for key, val in sorted(res.items(), key=lambda item: item[0]))


def _restore_checkpoint(train_model, train_sess, checkpoint_path, logger):
    """Restore the model variables from ``checkpoint_path`` into ``train_sess``.

    Raises:
        IOError: if the saver cannot restore the checkpoint. The original
            exception is chained as the cause.
    """
    try:
        train_model.model.saver.restore(train_sess, checkpoint_path)
    except Exception as err:  # not bare: KeyboardInterrupt/SystemExit still propagate
        raise IOError("Failed to find any matching files for {0}".format(checkpoint_path)) from err
    logger.info('load model: {}'.format(checkpoint_path))


def _run_train_epoch(train_model, train_sess, writer, hparams, global_step):
    """Run training steps until the input pipeline is exhausted.

    Returns:
        (epoch_loss, steps, global_step): the summed total loss, the number of
        steps in this epoch, and the updated step counter that spans all epochs.
    """
    epoch_loss = 0
    step = 0
    while True:
        try:
            _, step_loss, step_data_loss, summary = train_model.model.train(train_sess)
        except tf.errors.OutOfRangeError:
            hparams.logger.info('finish one epoch!')
            break
        # The summary step keeps increasing across epochs, so successive
        # epochs do not reuse the same step values in TensorBoard.
        writer.add_summary(summary, global_step)
        global_step += 1
        epoch_loss += step_loss
        step += 1
        if step % hparams.show_step == 0:
            hparams.logger.info('step {0:d} , total_loss: {1:.4f}, data_loss: {2:.4f}'
                                .format(step, step_loss, step_data_loss))
    return epoch_loss, step, global_step


def train(hparams, scope=None, target_session=""):
    """Cache the data, build the model, train it, then optionally run inference.

    Each epoch:
      * trains over ``hparams.train_file_cache``;
      * saves a checkpoint every ``hparams.save_epoch`` epochs;
      * evaluates on the eval cache, and on the test cache if configured;
      * logs the metrics.

    Training stops early once the eval AUC drops more than 0.003 below the
    best AUC seen so far. Afterwards, if ``hparams.infer_file`` is set,
    predictions are written with ``run_infer``.

    Args:
        hparams: hyper-parameter object; all of its values are logged first.
        scope: optional scope forwarded to the model constructor.
        target_session: TensorFlow session target ("" = in-process).

    Raises:
        IOError: if ``hparams.load_model_name`` is set but cannot be restored.
    """
    for key, val in hparams.values().items():
        hparams.logger.info(str(key) + ':' + str(val))

    load_and_cache_data(hparams)

    model_creator = get_model_creator(hparams.model_type, hparams.logger)

    # The graph-level seed only affects random ops created after it is set
    # (e.g. variable initializers), so set it before building the model.
    tf.set_random_seed(1234)
    train_model = create_train_model(model_creator, hparams, scope)
    gpuconfig = tf.ConfigProto()
    gpuconfig.gpu_options.allow_growth = True

    modify_graph_and_start_emb_cache(dump_graph=True)

    train_sess = tf.Session(target=target_session, graph=train_model.graph,
                            config=npu_config_proto(config_proto=gpuconfig))
    train_sess.run(train_model.model.init_op)
    if hparams.load_model_name is not None:
        _restore_checkpoint(train_model, train_sess, hparams.load_model_name, hparams.logger)
    hparams.logger.info('total_loss = data_loss+regularization_loss, data_loss = {rmse or logloss ..}')

    writer = tf.summary.FileWriter(util.SUMMARIES_DIR, train_sess.graph)
    best_auc = 0
    global_step = 0
    try:
        for epoch in range(hparams.epochs):
            initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
            train_sess.run(initializer, feed_dict={train_model.filenames: [hparams.train_file_cache]})

            train_start = time.time()
            epoch_loss, steps, global_step = _run_train_epoch(
                train_model, train_sess, writer, hparams, global_step)
            train_time = time.time() - train_start

            if epoch % hparams.save_epoch == 0:
                train_model.model.saver.save(
                    sess=train_sess,
                    save_path=util.MODEL_DIR + 'epoch_' + str(epoch))

            # An empty training file yields zero steps. Report NaN instead of
            # raising ZeroDivisionError.
            train_res = {"loss": epoch_loss / steps if steps else float('nan')}

            eval_start = time.time()
            eval_res = run_eval(train_model, train_sess, hparams.eval_file_cache, util.EVAL_NUM, hparams, flag='eval')
            epoch_info = ('at epoch {0:d}'.format(epoch) + ' train info: ' + _format_res(train_res)
                          + ' eval info: ' + _format_res(eval_res))
            if hparams.test_file is not None:
                test_res = run_eval(train_model, train_sess, hparams.test_file_cache, util.TEST_NUM, hparams,
                                    flag='test')
                epoch_info += ' test info: ' + _format_res(test_res)
            eval_time = time.time() - eval_start

            hparams.logger.info(epoch_info)
            hparams.logger.info('at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}'
                                .format(epoch, train_time, eval_time))
            hparams.logger.info('\n')

            # Early stopping: quit once eval AUC falls more than 0.003 below the best so far.
            if eval_res["auc"] - best_auc < -0.003:
                break
            if eval_res["auc"] > best_auc:
                best_auc = eval_res["auc"]
    finally:
        writer.close()

    # After training, run inference if an infer file is configured.
    if hparams.infer_file is not None:
        run_infer(train_model, train_sess, hparams.infer_file_cache, hparams, util.INFER_NUM)


def load_and_cache_data(hparams):
    """Cache every configured data split as TFRecords.

    The splits are processed in the order train, eval, test, infer. A split is
    skipped when its ``<split>_file`` hyper-parameter is None.
    """
    hparams.logger.info('load and cache data...')
    for split in ('train', 'eval', 'test', 'infer'):
        source_file = getattr(hparams, split + '_file')
        if source_file is not None:
            cache_data(hparams, source_file, flag=split)


def get_model_creator(model_type, logger, model_map=None):
    """Return the model class registered under ``model_type``.

    Args:
        model_type: registry key, e.g. ``'exDeepFM'``.
        logger: logger used to announce the selected model.
        model_map: optional mapping of model names to model classes.
            Defaults to the module-level ``MODEL_MAP``.

    Returns:
        The registered model class (later called as
        ``model_creator(hparams, iterator=..., scope=...)``).

    Raises:
        ValueError: if ``model_type`` is not registered. The message lists the
            names that actually are registered.
    """
    registry = MODEL_MAP if model_map is None else model_map
    model_class = registry.get(model_type)
    if model_class is None:
        raise ValueError("model type should be one of: {0}, got {1!r}".format(
            ", ".join(sorted(registry)), model_type))

    logger.info(f"run {model_type} model!")
    return model_class