from typing import Optional, Any
from tensorflow.python.framework import ops
from tensorflow import Variable
from rec_sdk_common.log.log import LoggingProxy as logger
class SparseEmbedConfig:
    """
    Sparse table related configurations.

    Holds the bookkeeping for sparse (hash) tables:

    * ``table_instance_dict``: key -> table instance, where the key is the
      table's ``tf.Variable`` in normal mode (HBM/DDR/SSD);
    * ``_tensor_to_table_instance_dict``: ``tf.Tensor`` -> table instance,
      consulted by ``get_table_instance`` in dynamic expansion mode;
    * ``name_to_var_dict``: table name -> key of ``table_instance_dict``;
    * ``table_name_set``: names registered through the non-eval path of
      ``insert_table_instance`` (used to reject duplicates);
    * ``dangling_table`` / ``removing_var_list``: de-duplicated,
      insertion-ordered lists of table names / variable names.
    """

    def __init__(self):
        self._table_instance_dict = dict()
        self._dangling_table = []
        self._table_name_set = set()
        self._removing_var_list = []
        self._name_to_var_dict = dict()
        self._tensor_to_table_instance_dict = dict()

    @property
    def table_instance_dict(self):
        return self._table_instance_dict

    @property
    def dangling_table(self):
        return self._dangling_table

    @property
    def table_name_set(self):
        return self._table_name_set

    @property
    def name_to_var_dict(self):
        return self._name_to_var_dict

    @property
    def removing_var_list(self):
        return self._removing_var_list

    def get_table_instance(self, key) -> object:
        """
        Get table instance by key.

        Args:
            key: It's tf.Tensor in dynamic expansion mode and tf.Variable in normal mode(HBM/DDR/SSD).

        Returns: Table instance.

        Raises:
            KeyError: if the key has not been registered.
        """
        # Imported lazily, presumably to avoid a circular import with
        # mx_rec.util.initialize -- NOTE(review): confirm.
        from mx_rec.util.initialize import ConfigInitializer

        if ConfigInitializer.get_instance().use_dynamic_expansion and isinstance(key, ops.Tensor):
            return self.get_table_instance_by_tensor(key)
        if key not in self._table_instance_dict:
            raise KeyError("given key => '{}' does not exist".format(key))
        return self._table_instance_dict[key]

    def get_table_instance_by_tensor(self, tensor) -> object:
        """
        Get the table instance registered for ``tensor`` (dynamic expansion mode).

        Raises:
            KeyError: if ``tensor`` has not been registered.
        """
        if tensor not in self._tensor_to_table_instance_dict:
            raise KeyError("given tensor => '{}' does not exist".format(tensor))
        return self._tensor_to_table_instance_dict[tensor]

    def get_table_instance_by_name(self, table_name: Optional[str]) -> object:
        """
        Get the table instance whose name was recorded in ``name_to_var_dict``.

        Returns ``None`` if the name maps to a key that has no table instance.

        Raises:
            KeyError: if ``table_name`` is unknown.
        """
        if table_name not in self._name_to_var_dict:
            raise KeyError("given table name => '{}' does not exist".format(table_name))
        key = self._name_to_var_dict[table_name]
        return self._table_instance_dict.get(key)

    def insert_dangling_table(self, table_name: Optional[str]) -> None:
        """Append ``table_name`` to ``dangling_table`` unless it is already present."""
        if table_name not in self._dangling_table:
            self._dangling_table.append(table_name)

    def insert_removing_var_list(self, var_name) -> None:
        """Append ``var_name`` to ``removing_var_list`` unless it is already present."""
        if var_name not in self._removing_var_list:
            self._removing_var_list.append(var_name)

    def insert_table_instance(self, name: str, key: "Variable", instance: object, eval_flag: bool) -> None:
        """
        Register a table instance under ``key``.

        With ``eval_flag`` set, only the key -> instance mapping is stored; the
        name is neither checked for duplicates nor recorded.

        Raises:
            KeyError: if ``key`` is already registered.
            ValueError: if ``name`` was already recorded (non-eval path only).
        """
        if key in self._table_instance_dict:
            raise KeyError(f"Given key {key} has been used.")
        if eval_flag:
            self._table_instance_dict[key] = instance
            return
        if name in self._table_name_set:
            raise ValueError(f"Duplicated hashtable name '{name}' was used.")
        logger.debug("Record one hash table, with name: %s, key: %s.", name, key)
        self._table_name_set.add(name)
        self._name_to_var_dict[name] = key
        self._table_instance_dict[key] = instance

    def insert_table_instance_to_tensor_dict(self, tensor: "ops.Tensor", instance: object) -> None:
        """
        Register a table instance under ``tensor`` for dynamic expansion mode.

        Raises:
            KeyError: if ``tensor`` is already registered.
        """
        if tensor in self._tensor_to_table_instance_dict:
            raise KeyError(f"Given tensor {tensor} has been used.")
        logger.debug("Record one hash table for expansion mode, with tensor: %s.", tensor)
        self._tensor_to_table_instance_dict[tensor] = instance

    def update_table_instance(self, table_name: str, emb_table: Any, old_var: "Variable", new_var: "Variable") -> None:
        """
        Re-key table ``table_name`` from ``old_var`` to ``new_var``.

        Raises:
            KeyError: if ``old_var`` is not registered; no mapping is modified.
        """
        # Drop the stale entry first: if ``old_var`` is unknown this raises
        # before anything is touched, so ``table_name`` can never be left
        # pointing at a variable that has no table instance.
        del self._table_instance_dict[old_var]
        self._name_to_var_dict[table_name] = new_var
        self._table_instance_dict[new_var] = emb_table

    def export_table_num(self) -> int:
        """Return the number of registered table instances (0 when empty)."""
        return len(self._table_instance_dict)