import os
from typing import Optional

import tensorflow as tf

from rec_sdk_common.util.tf_adapter import npu_ops
from mxrec.python.binding.runtime_manager import RuntimeManager
class TimeEvictor:
    """Time-based eviction helper bound to a single table.

    On construction it starts a host-side time evictor via
    ``RuntimeManager.start_time_evictor``. Keys seen on device are sent to the
    host over a per-table outfeed channel (see ``update_last_timestamp``). The
    host-side evictor state can be saved, loaded and queried for evicted keys.
    """

    # Appended to the table name to form the device-to-host outfeed channel name.
    _D2H_CHANNEL_SUFFIX = "time_evictor_d2h"
    # Appended to the table name to form the evictor's checkpoint file name.
    _BIN_FILE_SUFFIX = "time_evictor.bin"

    def __init__(self, table_name: str, max_cold_secs: int):
        """Create the evictor and start its host-side counterpart.

        Args:
            table_name: Table this evictor manages. It is also used to derive
                the outfeed channel name and the checkpoint file name.
            max_cold_secs: Forwarded unchanged to
                ``RuntimeManager.start_time_evictor``. Presumably this is the
                idle time in seconds after which a key becomes evictable
                (TODO confirm against the runtime implementation).
        """
        self._table_name = table_name
        self._max_cold_secs = max_cold_secs
        self._d2h_channel_name = self._get_device_to_host_channel_name()
        self._host_runtime_manager = RuntimeManager()
        self._host_runtime_manager.start_time_evictor(table_name, max_cold_secs)

    def update_last_timestamp(self, keys: tf.Tensor) -> tf.Tensor:
        """Send ``keys`` to the host evictor and return them unchanged.

        Args:
            keys: Tensor of keys to report to the host over the outfeed channel.

        Returns:
            An identity of ``keys`` that carries a control dependency on the
            outfeed enqueue op.

        Raises:
            TypeError: If ``keys`` is not a ``tf.Tensor``.
        """
        if not isinstance(keys, tf.Tensor):
            raise TypeError(f"expected keys to be a tf.Tensor, but got {type(keys).__name__}")
        send_op = npu_ops.outfeed_enqueue_op(
            channel_name=self._d2h_channel_name,
            inputs=[keys],
            name="{}_op".format(self._d2h_channel_name),
        )
        # Returning an identity that depends on send_op ties the enqueue to any
        # downstream consumer of the returned tensor (graph-mode ordering).
        with tf.control_dependencies(control_inputs=[send_op]):
            return tf.identity(keys)

    def save(self, save_path: str):
        """Save the host-side evictor state into directory ``save_path``."""
        file_path = self._get_bin_file_path(save_path)
        self._host_runtime_manager.save_time_evictor(self._table_name, file_path)

    def load(self, load_path: str):
        """Load the host-side evictor state from directory ``load_path``."""
        file_path = self._get_bin_file_path(load_path)
        self._host_runtime_manager.load_time_evictor(self._table_name, file_path)

    def get_evicted_keys(self, table_name: Optional[str] = None) -> tf.Tensor:
        """Fetch the evicted keys from the host runtime as an int64 tensor.

        Args:
            table_name: Table to query. Defaults to the table this evictor was
                created for. Explicit names are still accepted for backward
                compatibility.

        Returns:
            ``tf.constant`` of dtype ``tf.int64`` built from whatever
            ``RuntimeManager.get_evicted_keys`` returns (assumed to be a
            sequence of integer keys; TODO confirm).
        """
        if table_name is None:
            table_name = self._table_name
        evicted_keys = self._host_runtime_manager.get_evicted_keys(table_name)
        return tf.constant(evicted_keys, dtype=tf.int64)

    def _get_device_to_host_channel_name(self) -> str:
        """Return the per-table outfeed channel name."""
        return "{}_{}".format(self._table_name, self._D2H_CHANNEL_SUFFIX)

    def _get_bin_file_path(self, directory: str) -> str:
        """Return the checkpoint file path for this table inside ``directory``."""
        file_name = "{}_{}".format(self._table_name, self._BIN_FILE_SUFFIX)
        return os.path.join(directory, file_name)