#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.validator.validator import para_checker_decorator, ClassValidator
from mx_rec.constants.constants import ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ENTRANCE
from mx_rec.graph.utils import check_cutting_points, replace_anchor_vec
from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
from mx_rec.util.initialize import ConfigInitializer



def _collect_mode_cutting_points(cutting_point_list, is_train: bool, use_static: bool):
    """
    Select the lookup cutting points that belong to the requested mode.

    Args:
        cutting_point_list: cutting points read from the ASCEND_SPARSE_LOOKUP_ENTRANCE collection.
        is_train: the mode to keep (True for train, False for eval).
        use_static: value of ConfigInitializer's `use_static` flag.

    Returns:
        A tuple (mode_cutting_points, feature_spec_name_ids_dict):
            mode_cutting_points: cutting points whose IS_TRAINING attribute equals `is_train`,
                in collection order.
            feature_spec_name_ids_dict: feature spec name -> cutting point. Filled only when
                `use_static` is False, and only for tables looked up more than once in this mode.
    """
    mode_cutting_points = []
    feature_spec_name_ids_dict = {}
    for cutting_point in cutting_point_list:
        is_training = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_TRAINING)
        if is_training != is_train:
            logger.debug("Skip! The current mode(train: True, eval: False) is %s, but the mode of %s is %s.",
                         is_train, cutting_point, is_training)
            continue

        table_instance = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
        # `not use_static` is checked first so multi_lookup_times is only consulted in dynamic-shape mode.
        if not use_static and table_instance.multi_lookup_times.get(is_train) > 1:
            feature_spec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
            feature_spec_name_ids_dict[feature_spec.name] = cutting_point
        mode_cutting_points.append(cutting_point)
    return mode_cutting_points, feature_spec_name_ids_dict


@para_checker_decorator(check_option_list=[("is_train", ClassValidator, {"classes": bool})])
def do_merge_lookup(is_train: bool = True):
    """
    Modify the graph for tables looked up once or multiple times, adding forward and backward nodes.

    Behavior:
        1. For every table looked up more than once in the current mode, run the merged lookup
           (`lookup_for_feat_spec(..., multi_lookup=True)`). Its result replaces the stubbed mock
           lookup result (ASCAnchorAttr.MOCK_LOOKUP_RESULT).
        2. Tables looked up exactly once are not touched here; their original lookup result is kept.
        3. Only runs in `modify graph` mode. Otherwise (e.g. feature spec mode) it returns immediately.
        4. Runs at most once per mode. Completion is recorded via
           `train_params_config.insert_merged_multi_lookup` and checked on entry.

    Per the original design notes, during training this is invoked through a patch on
    Optimizer.compute_gradients() so that the graph and gradients are correct. During eval it runs
    in the graph-modification stage.

    Args:
        is_train: True for train mode, False for eval mode.

    Returns: None

    Raises:
        RuntimeError: if the ASCEND_SPARSE_LOOKUP_ENTRANCE collection is empty, or if none of its
            cutting points belong to the current mode.
    """
    config = ConfigInitializer.get_instance()
    if not config.modify_graph:
        logger.debug("The `do_merge_lookup` function is called only for `modify graph` mode.")
        return
    if config.train_params_config.get_merged_multi_lookup(is_train):
        logger.debug("The merge multi lookup has been executed once and does not need to be executed again.")
        return
    logger.info("start to merge multi lookup, mode(train: True, eval: False): %s.", is_train)

    # get anchor ids
    cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
    if not cutting_point_list:
        raise RuntimeError("the sparse table does not have sparse lookup.")
    check_cutting_points(cutting_point_list)

    # get lookup info: all of feature_spec_name_ids_dict must be built before any merged lookup runs,
    # because the same dict is handed to every lookup_for_feat_spec call below.
    use_static = config.use_static
    sub_cutting_point_list, feature_spec_name_ids_dict = _collect_mode_cutting_points(
        cutting_point_list, is_train, use_static)
    if not sub_cutting_point_list:
        raise RuntimeError(f"The current mode(train: True, eval: False) is {is_train}, and the sparse table does not "
                           f"have anchor ids.")

    # merge or restore lookup
    graph = tf.compat.v1.get_default_graph()
    for cutting_point in sub_cutting_point_list:
        table_instance = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
        feature_spec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
        is_grad = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_GRAD)
        if table_instance.multi_lookup_times.get(is_train) == 1:
            logger.debug("The origin lookup result of %s for %s does not need to be replaced.",
                         feature_spec.name, table_instance.table_name)
            continue

        kwargs = dict(is_train=is_train, lookup_ids=cutting_point, multi_lookup=True, is_grad=is_grad)
        if not use_static:
            kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict
        lookup_result = table_instance.lookup_for_feat_spec(feature_spec, table_instance.send_count, **kwargs)
        replace_anchor_vec(graph, cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result)
        logger.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name, table_instance.table_name)

    # records whether the current mode has been merged or restored lookup
    config.train_params_config.insert_merged_multi_lookup(is_train, True)
    logger.info("finish to merge multi lookup, mode(train: True, eval: False): %s.", is_train)