from typing import Dict, Optional

import tensorflow as tf
class OptimizerConfig:
    """Hold an optimizer instance plus per-table optimizer state.

    Per-table state lives in two namespaces keyed by ``is_training``
    (``True`` and ``False``). Each namespace maps a table name to a
    single-entry dict ``{optimizer_name: optimizer_dict}``.
    """

    def __init__(self):
        # Object assigned through the ``optimizer_instance`` property; None until set.
        self._optimizer_instance = None
        # Layout: {is_training: {table_name: {optimizer_name: optimizer_dict}}}
        self._table_optimizer_dict = {True: {}, False: {}}

    @property
    def optim_params_list(self):
        """Return the optimizer's ``optim_param_list``, or ``[]`` if no optimizer is set.

        The check is truthiness-based, so a falsy optimizer object also yields ``[]``.
        What ``optim_param_list`` contains is defined by the optimizer class
        (presumably its parameter names -- confirm against the optimizer implementation).
        """
        if not self._optimizer_instance:
            return []
        return self._optimizer_instance.optim_param_list

    @property
    def optimizer_instance(self):
        """The optimizer object registered on this config (``None`` until assigned)."""
        return self._optimizer_instance

    @optimizer_instance.setter
    def optimizer_instance(self, optimizer):
        self._optimizer_instance = optimizer

    def _tables_for(self, is_training):
        """Return the ``{table_name: {optimizer_name: optimizer_dict}}`` map for one mode.

        Raises:
            KeyError: if ``is_training`` is not one of the registered modes
                (only ``True``/``False`` are created by ``__init__``).
        """
        tables = self._table_optimizer_dict.get(is_training)
        if tables is None:
            raise KeyError(f"key `{is_training}` does not exist")
        return tables

    def set_optimizer_for_table(
        self,
        table_name: str,
        optimizer_name: str,
        optimizer_dict: Dict[str, tf.Variable],
        is_training: bool = True,
    ) -> None:
        """Register ``optimizer_dict`` under ``optimizer_name`` for ``table_name``.

        Any entry previously stored for the same table in the same mode is
        replaced wholesale, not merged.

        Args:
            table_name: Table to associate the optimizer state with.
            optimizer_name: Key under which ``optimizer_dict`` is stored.
            optimizer_dict: Mapping of names to ``tf.Variable`` objects.
            is_training: Which namespace (``True``/``False``) to write into.

        Raises:
            KeyError: if ``is_training`` is not a registered mode.
        """
        self._tables_for(is_training)[table_name] = {optimizer_name: optimizer_dict}

    def get_optimizer_by_table_name(
        self, table_name: str, is_training: bool = True
    ) -> Optional[Dict[str, Dict[str, tf.Variable]]]:
        """Return the ``{optimizer_name: optimizer_dict}`` entry stored for ``table_name``.

        Args:
            table_name: Table whose optimizer entry is requested.
            is_training: Which namespace (``True``/``False``) to read from.

        Returns:
            The stored single-entry dict (the internal object, not a copy), or
            ``None`` if no optimizer was registered for ``table_name`` in that mode.

        Raises:
            KeyError: if ``is_training`` is not a registered mode.
        """
        return self._tables_for(is_training).get(table_name)