import tensorflow as tf
from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.validator.validator import para_checker_decorator, ClassValidator
from mx_rec.constants.constants import ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ENTRANCE
from mx_rec.graph.utils import check_cutting_points, replace_anchor_vec
from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
from mx_rec.util.initialize import ConfigInitializer
@para_checker_decorator(check_option_list=[("is_train", ClassValidator, {"classes": bool})])
def do_merge_lookup(is_train: bool = True):
    """
    Rewrite the graph for tables looked up once or several times, adding forward and backward nodes.

    1. For every cutting point whose table is looked up more than once in the current mode, a merged
       lookup is issued via ``table_instance.lookup_for_feat_spec(..., multi_lookup=True)`` and the
       stubbed mock lookup result attached to the cutting point is replaced with that lookup result.
    2. Cutting points whose table is looked up exactly once in the current mode are skipped by this
       function; their original lookup result is kept as-is.
    3. Only the automatic graph-modification (``modify_graph``) mode needs this function; in any other
       mode (e.g. feature spec mode) it returns immediately.
    4. For training it is meant to be run through a patch in Optimizer.compute_gradients() so that the
       graph and gradients are correct; for evaluation it runs during the graph-modification phase.

    The work is performed at most once per mode: after a successful run a flag is recorded through
    ``train_params_config.insert_merged_multi_lookup`` and later calls for the same mode return early.

    Args:
        is_train: True for training mode, False for evaluation mode.

    Returns: None

    Raises:
        RuntimeError: if the ``ASCEND_SPARSE_LOOKUP_ENTRANCE`` collection is empty, or if none of its
            cutting points belongs to the requested mode.
    """
    config = ConfigInitializer.get_instance()
    if not config.modify_graph:
        logger.debug("The `do_merge_lookup` function is called only for `modify graph` mode.")
        return
    if config.train_params_config.get_merged_multi_lookup(is_train):
        logger.debug("The merge multi lookup has been executed once and does not need to be executed again.")
        return

    logger.info("start to merge multi lookup, mode(train: True, eval: False): %s.", is_train)
    cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
    if not cutting_point_list:
        raise RuntimeError("the sparse table does not have sparse lookup.")
    check_cutting_points(cutting_point_list)

    # Cutting points that belong to the requested mode, in collection order. Points of the other
    # mode are filtered out below, so a single list is enough (no need to group by mode).
    sub_cutting_point_list = []
    # feature spec name -> cutting point, collected only when `use_static` is off and only for tables
    # looked up more than once in this mode; handed to every merged lookup in the second pass.
    feature_spec_name_ids_dict = {}
    for cutting_point in cutting_point_list:
        is_training = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_TRAINING)
        if is_training != is_train:
            logger.debug("Skip! The current mode(train: True, eval: False) is %s, but the mode of %s is %s.",
                         is_train, cutting_point, is_training)
            continue
        table_instance = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
        if not config.use_static and table_instance.multi_lookup_times.get(is_train) > 1:
            feature_spec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
            feature_spec_name_ids_dict[feature_spec.name] = cutting_point
        sub_cutting_point_list.append(cutting_point)

    if not sub_cutting_point_list:
        raise RuntimeError(f"The current mode(train: True, eval: False) is {is_train}, and the sparse table does not "
                           f"have anchor ids.")

    # The mock lookup results are replaced in the default graph; fetch it once instead of per iteration.
    graph = tf.compat.v1.get_default_graph()
    for cutting_point in sub_cutting_point_list:
        table_instance = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
        feature_spec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
        is_grad = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_GRAD)
        if table_instance.multi_lookup_times.get(is_train) == 1:
            # Single lookup in this mode: nothing to merge, keep the original lookup result.
            logger.debug("The origin lookup result of %s for %s does not need to be replaced.",
                         feature_spec.name, table_instance.table_name)
            continue

        kwargs = dict(is_train=is_train, lookup_ids=cutting_point, multi_lookup=True, is_grad=is_grad)
        if not config.use_static:
            kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict
        lookup_result = table_instance.lookup_for_feat_spec(feature_spec, table_instance.send_count, **kwargs)
        replace_anchor_vec(graph, cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result)
        logger.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name, table_instance.table_name)

    # Record completion so repeated calls for this mode become no-ops.
    config.train_params_config.insert_merged_multi_lookup(is_train, True)
    logger.info("finish to merge multi lookup, mode(train: True, eval: False): %s.", is_train)