from typing import List
import tensorflow as tf
from tensorflow import Operation, Graph
from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.validator.validator import ClassValidator, para_checker_decorator
from mx_rec.graph.slicers import LookupSubgraphSlicer, OrphanLookupKeySlicer
class LookupSubgraphSlicerHook(tf.estimator.SessionRunHook):
    """Estimator hook that slices configured operations out of the lookup subgraph.

    In ``begin`` it builds a ``LookupSubgraphSlicer`` for the op types given at
    construction time. It asks that slicer to ``summarize()`` and then to
    ``slice()``, logging before each step.
    """

    @para_checker_decorator(
        check_option_list=[
            # The original spelled this ``(list)``; without a trailing comma that
            # is just ``list`` itself, so the class object is passed directly.
            ("op_types", ClassValidator, {"classes": list}),
        ]
    )
    def __init__(self, op_types: List[Operation]) -> None:
        """Remember the op types that will be handed to the slicer.

        Args:
            op_types: Passed unchanged to ``LookupSubgraphSlicer`` in ``begin``.
                The decorator validates it with ``ClassValidator`` against
                ``list``. NOTE(review): annotated as ``List[Operation]``, but the
                name suggests op *type* identifiers; confirm against
                ``LookupSubgraphSlicer``.
        """
        super().__init__()
        self._op_types = op_types

    def begin(self) -> None:
        """Summarize, then slice, the configured operations in the lookup subgraph."""
        lookup_slicer = LookupSubgraphSlicer(self._op_types)

        # Each stage logs its announcement first and then runs; order matters
        # because slicing is expected to follow the summary.
        stages = (
            (
                "Starts to summarize sliceable specific operations in lookup subgraph!",
                lookup_slicer.summarize,
            ),
            (
                "Starts to slice specific operations and their corresponding minimum dependency graphs!",
                lookup_slicer.slice,
            ),
        )
        for announcement, run_stage in stages:
            logger.info(announcement)
            run_stage()
class OrphanLookupKeySlicerHook(tf.estimator.SessionRunHook):
    """Estimator hook that slices orphan lookup keys and their dependency graphs.

    In ``begin`` it builds an ``OrphanLookupKeySlicer``. It asks that slicer to
    ``summarize()`` and then to ``slice()``, logging before each step.
    """

    def __init__(self) -> None:
        """Initialize the base ``SessionRunHook``; this hook takes no configuration."""
        super().__init__()

    def begin(self) -> None:
        """Summarize, then slice, the orphan lookup keys."""
        orphan_slicer = OrphanLookupKeySlicer()

        # Announce each stage before running it, preserving summarize -> slice order.
        stages = (
            (
                "Starts to summarize sliceable orphan lookup keys!",
                orphan_slicer.summarize,
            ),
            (
                "Starts to slice orphan lookup keys and their corresponding minimum dependency graphs!",
                orphan_slicer.slice,
            ),
        )
        for announcement, run_stage in stages:
            logger.info(announcement)
            run_stage()