from typing import Tuple
import tensorflow as tf
from tensorflow import Tensor
from rec_sdk_common.log.log import LoggingProxy as logger
from mx_rec.constants import constants
from mx_rec.graph.constants import AnchorIteratorOp
from mx_rec.util.initialize import ConfigInitializer
def mark_orphan_lookup_key(lookup_key: Tensor) -> Tensor:
    """Tag a lookup key that is not fed by a dataset iterator as an orphan.

    The default graph is serialized and the sub-graph of everything that
    `lookup_key` depends on is extracted. If any node in that sub-graph has
    the op type `AnchorIteratorOp.ITERATOR_GET_NEXT`, the key is returned
    untouched. Otherwise the key is wrapped in an identity op whose name is
    prefixed with `constants.ORPHAN_LOOKUP_KEY_PREFIX`.

    Args:
        lookup_key: A `tf.Tensor` representing ID-like keys.

    Returns:
        `lookup_key` itself when an iterator get-next op is found upstream,
        otherwise an identity tensor of `lookup_key` named
        "<orphan prefix>/<original op name>".
    """
    default_graph_def = tf.compat.v1.get_default_graph().as_graph_def()
    # Only the ancestors of the key's op matter, so restrict the search to them.
    upstream_graph = tf.compat.v1.graph_util.extract_sub_graph(default_graph_def, [lookup_key.op.name])

    fed_by_iterator = any(
        upstream_node.op == AnchorIteratorOp.ITERATOR_GET_NEXT.value
        for upstream_node in upstream_graph.node
    )
    if fed_by_iterator:
        return lookup_key

    orphan_name = "{}/{}".format(constants.ORPHAN_LOOKUP_KEY_PREFIX, lookup_key.op.name)
    marked_lookup_key = tf.identity(lookup_key, name=orphan_name)
    logger.info("Mark orphan lookup key %s as %s.", lookup_key, marked_lookup_key)
    return marked_lookup_key
def check_and_set_vocab_size(device_vocab_size: int, host_vocab_size: int, ssd_vocab_size: int) -> Tuple[int, int, int]:
    """Validate the three vocabulary sizes and return the effective values.

    Args:
        device_vocab_size: Vocabulary size on the device; always returned unchanged.
        host_vocab_size: Vocabulary size on the host (called DDR in the messages).
        ssd_vocab_size: Vocabulary size on SSD.

    Returns:
        A `(device, host, ssd)` tuple. When dynamic expansion is enabled in the
        `ConfigInitializer` instance, the host and SSD sizes are forced to 0.
        Otherwise the inputs are returned as given.

    Raises:
        ValueError: If dynamic expansion is off and an SSD size greater than 0
            is given while the host size is 0.
    """
    if ConfigInitializer.get_instance().use_dynamic_expansion:
        # Dynamic expansion overrides any host/SSD sizes supplied by the caller,
        # so the SSD-requires-DDR check below is intentionally skipped.
        logger.info("In dynamic expansion mode, DDR and SSD vocabulary size will be reset to 0 automatically!")
        return (device_vocab_size, 0, 0)

    # An SSD tier is only accepted when a host (DDR) tier is configured as well.
    if host_vocab_size == 0 and ssd_vocab_size > 0:
        raise ValueError("set SSD vocabulary size must set DDR vocabulary size first")

    return (device_vocab_size, host_vocab_size, ssd_vocab_size)