import functools
import os
import re
from typing import List

import six
import tensorflow as tf
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.training import warm_starting_util

from rec_sdk_common.log.log import LoggingProxy as logger
from mx_rec.saver.saver import Saver
class WarmStartController:
    """Process-wide singleton recording which sparse tables to restore from which checkpoint path.

    ``warm_settings_filter`` and its helpers register entries through
    ``add_element``; ``SparseRestoreHook`` reads them back through ``get_elements``.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(WarmStartController, cls).__new__(cls)
            # checkpoint path -> list of sparse table names to restore from it
            cls._instance._warm_start_dict = {}
            # table name -> previous table name (filled by add_table_to_prev_table)
            cls._instance.table_name_to_prev_table_name = {}
        return cls._instance

    def __init__(self):
        # __init__ runs on every WarmStartController() call because __new__ keeps
        # returning the same instance; only log the first time.
        if not getattr(self, "_init_logged", False):
            logger.info("start to build WarmStartController.")
            self._init_logged = True

    def add_element(self, path: str, table_list: List[str]):
        """Append ``table_list`` to the tables registered for ``path``.

        The stored list is owned by the controller, so later calls for the same
        path never mutate the list object the caller passed in.
        """
        self._warm_start_dict.setdefault(path, []).extend(table_list)

    def add_table_to_prev_table(self, table: str, prev_table: str):
        """Map ``table`` to ``prev_table`` in ``table_name_to_prev_table_name``."""
        self.table_name_to_prev_table_name[table] = prev_table

    def get_elements(self):
        """Return the live ``{path: [table, ...]}`` mapping (not a copy)."""
        return self._warm_start_dict
def patch_for_warm_start():
    """Monkey-patch the TF Estimator warm-start entry points to support sparse tables.

    Wraps ``Estimator.__init__`` (filters ``warm_start_from``),
    ``warm_starting_util.warm_start`` (handles list-valued settings) and
    ``Estimator.train`` (adds ``SparseRestoreHook``).

    Idempotent: repeated calls are no-ops. Without this guard the wrappers would
    stack, filtering ``warm_start_from`` twice and appending duplicate hooks.
    """
    if getattr(patch_for_warm_start, "_applied", False):
        return
    estimator_lib.Estimator.__init__ = patch_estimator_init(estimator_lib.Estimator.__init__)
    warm_starting_util.warm_start = patch_for_func_warm_start(warm_starting_util.warm_start)
    estimator_lib.Estimator.train = patch_for_estimator_train(estimator_lib.Estimator.train)
    patch_for_warm_start._applied = True
def patch_estimator_init(func):
    """Wrap ``Estimator.__init__`` so a truthy ``warm_start_from`` goes through ``warm_settings_filter``.

    ``warm_start_from`` is handled whether it is passed by keyword or
    positionally.

    Args:
        func: the original ``Estimator.__init__``.

    Returns:
        The wrapping function, which keeps ``func``'s metadata via ``functools.wraps``.
    """
    # Position of ``warm_start_from`` in
    # Estimator.__init__(self, model_fn, model_dir, config, params, warm_start_from).
    warm_start_from_index = 5

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if len(args) > warm_start_from_index:
            if args[warm_start_from_index]:
                args = list(args)
                args[warm_start_from_index] = warm_settings_filter(args[warm_start_from_index])
        elif kwargs.get('warm_start_from', None):
            kwargs['warm_start_from'] = warm_settings_filter(kwargs['warm_start_from'])
        return func(*args, **kwargs)

    return wrapper
def patch_for_func_warm_start(func):
    """Wrap ``warm_starting_util.warm_start`` to accept list-valued settings.

    When ``ckpt_to_initialize_from`` is a list/tuple, ``vars_to_warm_start``
    and ``var_name_to_prev_var_name`` must be parallel sequences. ``func`` is
    then called once per checkpoint with the i-th entries of those sequences,
    all calls sharing the same ``var_name_to_vocab_info``. Any other
    ``ckpt_to_initialize_from`` is forwarded to ``func`` unchanged.

    Arguments may be passed positionally or by keyword.

    Returns:
        The wrapping function. In the list case it returns the result of the
        last ``func`` call, or None for an empty checkpoint list.
    """
    def _resolve(args, kwargs, index, name):
        # Fetch a parameter of warm_start(ckpt, vars, vocab_info, prev_names) by position or name.
        if len(args) > index:
            return args[index]
        return kwargs.get(name)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        ckpt_to_initialize_from = _resolve(args, kwargs, 0, "ckpt_to_initialize_from")
        if not isinstance(ckpt_to_initialize_from, (list, tuple)):
            return func(*args, **kwargs)
        vars_to_warm_start_list = _resolve(args, kwargs, 1, "vars_to_warm_start")
        var_name_to_vocab_info = _resolve(args, kwargs, 2, "var_name_to_vocab_info")
        var_name_to_prev_var_name_list = _resolve(args, kwargs, 3, "var_name_to_prev_var_name")
        result = None
        for i, ckpt in enumerate(ckpt_to_initialize_from):
            result = func(ckpt, vars_to_warm_start_list[i], var_name_to_vocab_info,
                          var_name_to_prev_var_name_list[i])
        return result

    return wrapper
def patch_for_estimator_train(func):
    """Wrap ``Estimator.train`` to append a ``SparseRestoreHook`` when sparse tables are registered.

    The hook is added whenever ``WarmStartController().get_elements()`` is
    non-empty. ``hooks`` may be omitted, None, passed by keyword, or passed
    positionally. A new hooks list is always built, so the caller's list is
    never mutated.

    NOTE(review): every train() call gets a fresh hook, which restores the
    registered sparse tables again after session creation. Confirm this is
    intended when training resumes from an existing model_dir.
    """
    # Position of ``hooks`` in Estimator.train(self, input_fn, hooks, steps, ...).
    hooks_index = 2

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if WarmStartController().get_elements():
            if len(args) > hooks_index:
                args = list(args)
                args[hooks_index] = list(args[hooks_index] or []) + [SparseRestoreHook()]
            else:
                kwargs['hooks'] = list(kwargs.get('hooks') or []) + [SparseRestoreHook()]
        return func(*args, **kwargs)

    return wrapper
def warm_settings_filter(warm_start_from):
    """Split ``warm_start_from`` into sparse-table restores and the part TensorFlow should warm start.

    Sparse tables found under the checkpoint path(s) are registered in the
    ``WarmStartController`` singleton, whose entries ``SparseRestoreHook``
    restores after session creation.

    Args:
        warm_start_from: a checkpoint path (str/bytes), or a ``WarmStartSettings``
            whose ``ckpt_to_initialize_from`` is a single path or a list/tuple of paths.

    Returns:
        The value to hand to TensorFlow: the original path, a filtered
        ``WarmStartSettings``, or None when nothing is left for TensorFlow.

    Raises:
        ValueError: if ``warm_start_from`` (or its ``ckpt_to_initialize_from``)
            has an unsupported type.
    """
    if isinstance(warm_start_from, estimator_lib.WarmStartSettings):
        ckpt_to_initialize_from = warm_start_from.ckpt_to_initialize_from
        if isinstance(ckpt_to_initialize_from, (list, tuple)):
            logger.info("According to warm_start_settings, warm start will load from more than one checkpoint path.")
            out_setting_list = []
            for setting in _build_warm_settings_list(warm_start_from):
                filter_setting = _warm_settings_filter(setting)
                if filter_setting:
                    out_setting_list.append(filter_setting)
            return recover_warm_settings(out_setting_list) if out_setting_list else None
        if isinstance(ckpt_to_initialize_from, (six.string_types, six.binary_type)):
            logger.info("According to warm_start_settings, warm start will load from only one checkpoint path.")
            filter_setting = _warm_settings_filter(warm_start_from)
            return filter_setting if filter_setting else None
        # Previously this case silently returned None, disabling warm start without notice.
        raise ValueError("Invalid parameter: warm_start_from.ckpt_to_initialize_from. ")
    if isinstance(warm_start_from, (six.string_types, six.binary_type)):
        table_name_list = get_table_name_set_by_ckpt_path(warm_start_from)
        # Only register paths that actually have sparse tables, so no restore hook
        # is installed for nothing.
        if table_name_list:
            WarmStartController().add_element(warm_start_from, table_name_list)
        return warm_start_from
    raise ValueError("Invalid parameter: warm_start_from. ")
def recover_warm_settings(setting_list: List[tf.estimator.WarmStartSettings]) -> tf.estimator.WarmStartSettings:
    """
    Recover WarmStartSettings from a list of custom-defined WarmStartSettings.

    The per-checkpoint settings are merged into one ``WarmStartSettings`` whose
    three fields are parallel lists: index i holds the values of
    ``setting_list[i]``.

    NOTE(review): ``var_name_to_vocab_info`` is not carried into the merged settings.
    """
    ckpts = [item.ckpt_to_initialize_from for item in setting_list]
    var_patterns = [item.vars_to_warm_start for item in setting_list]
    prev_name_maps = [item.var_name_to_prev_var_name for item in setting_list]
    merged = estimator_lib.WarmStartSettings(
        ckpt_to_initialize_from=ckpts,
        vars_to_warm_start=var_patterns,
        var_name_to_prev_var_name=prev_name_maps,
    )
    return merged
def _build_warm_settings_list(warm_start_from: tf.estimator.WarmStartSettings) -> List[tf.estimator.WarmStartSettings]:
    """
    Converts custom-defined WarmStartSettings into a list of TensorFlow-native WarmStartSettings.

    ``warm_start_from`` holds parallel sequences: entry i of ``vars_to_warm_start``
    and ``var_name_to_prev_var_name`` belongs to ``ckpt_to_initialize_from[i]``.
    One single-checkpoint setting is produced per index.

    NOTE(review): ``var_name_to_vocab_info`` is not copied into the per-checkpoint settings.

    Raises:
        ValueError: if ``vars_to_warm_start`` or ``var_name_to_prev_var_name`` is not a
            list/tuple, or if the three sequences differ in length.
    """
    ckpt_list = warm_start_from.ckpt_to_initialize_from
    vars_list = warm_start_from.vars_to_warm_start
    prev_name_list = warm_start_from.var_name_to_prev_var_name
    if not all(isinstance(params, (list, tuple)) for params in (vars_list, prev_name_list)):
        raise ValueError("If you choose to load from multiple model paths through the warm start option, "
                         "then the parameter type in the warm settings should be a list.")
    if not (len(ckpt_list) == len(vars_list) == len(prev_name_list)):
        raise ValueError("If you choose to load from multiple model paths through the warm start option, "
                         "then the parameter lists should be the same length. ")
    # Lengths were checked above, so zip cannot silently truncate.
    return [
        estimator_lib.WarmStartSettings(
            ckpt_to_initialize_from=ckpt,
            vars_to_warm_start=var_patterns,
            var_name_to_prev_var_name=prev_names)
        for ckpt, var_patterns, prev_names in zip(ckpt_list, vars_list, prev_name_list)
    ]
def _warm_settings_filter(warm_start_setting: tf.estimator.WarmStartSettings) -> tf.estimator.WarmStartSettings:
    """
    Filter the vars_to_warm_start parameter to remove sparse table parameters.

    Sparse tables under ``ckpt_to_initialize_from`` whose names ``re.match`` a
    ``vars_to_warm_start`` pattern are registered in the ``WarmStartController``
    singleton. The return value depends on the type of ``vars_to_warm_start``:

    - ``None``: the setting is returned unchanged and no sparse table is registered.
    - ``str``: matching tables are registered and the setting is returned unchanged.
    - sequence of ``str``: patterns that match at least one sparse table are removed.
      A new setting with the remaining patterns is returned, or None if no pattern
      remains. ``var_name_to_vocab_info`` and ``var_name_to_prev_var_name`` are kept.

    Raises:
        ValueError: if ``vars_to_warm_start`` is an iterable containing non-str entries.
    """
    ckpt_path = warm_start_setting.ckpt_to_initialize_from
    vars_to_warm_start = warm_start_setting.vars_to_warm_start
    if vars_to_warm_start is None:
        # NOTE(review): TF treats None as "warm start only var_name_to_vocab_info variables";
        # no sparse table is requested. Previously this crashed while iterating None.
        return warm_start_setting
    table_name_list = get_table_name_set_by_ckpt_path(ckpt_path)
    if isinstance(vars_to_warm_start, str):
        matching_tables = [table for table in table_name_list if re.match(vars_to_warm_start, table)]
        if matching_tables:
            WarmStartController().add_element(ckpt_path, matching_tables)
        return warm_start_setting
    if not all(isinstance(v, str) for v in vars_to_warm_start):
        raise ValueError("vars_to_warm_start must be list or str!")
    sparse_patterns = set()
    matched_tables = []
    for pattern in vars_to_warm_start:
        matches = [table for table in table_name_list if re.match(pattern, table)]
        if matches:
            sparse_patterns.add(pattern)
            matched_tables.extend(matches)
    if matched_tables:
        # Several patterns may match the same table; register each table once, in first-seen order.
        WarmStartController().add_element(ckpt_path, list(dict.fromkeys(matched_tables)))
    dense_patterns = [v for v in vars_to_warm_start if v not in sparse_patterns]
    if not dense_patterns:
        return None
    return estimator_lib.WarmStartSettings(
        ckpt_to_initialize_from=ckpt_path,
        vars_to_warm_start=dense_patterns,
        var_name_to_vocab_info=warm_start_setting.var_name_to_vocab_info,
        var_name_to_prev_var_name=warm_start_setting.var_name_to_prev_var_name)
def get_table_name_set_by_ckpt_path(warm_start_path: str) -> List[str]:
    '''
    Get the list of sparse table names saved under the path 'warm_start_path'.

    ``warm_start_path`` is either a model directory, in which case its latest
    checkpoint is used, or a checkpoint prefix. Str and bytes are both accepted.
    The sparse data is looked up in the sibling directory
    ``sparse-<checkpoint basename>``, and every entry in it is returned. An
    empty list is returned when that directory does not exist.
    '''
    # Callers accept bytes paths; decode them so the f-string/os.path calls below
    # do not mix str and bytes.
    warm_start_path = tf.compat.as_str_any(warm_start_path)
    if tf.io.gfile.isdir(warm_start_path):
        restore_path = get_latest_ckpt(warm_start_path)
    else:
        restore_path = warm_start_path
    directory, base_name = os.path.split(restore_path)
    sparse_path = os.path.join(directory, f"sparse-{base_name}")
    if not tf.io.gfile.isdir(sparse_path):
        logger.info("under the warm start path %s, sparse directory %s not exists.", warm_start_path, sparse_path)
        return []
    return list(tf.io.gfile.listdir(sparse_path))
def get_latest_ckpt(warm_start_path: str) -> str:
    """Return the latest checkpoint prefix recorded in ``<warm_start_path>/checkpoint``.

    Only the first line of the ``checkpoint`` file is read. Its value is reduced
    to a basename and re-joined with ``warm_start_path``, so the result always
    points inside ``warm_start_path``.

    Raises:
        FileNotFoundError: if the ``checkpoint`` file does not exist.
        ValueError: if the first line has no ``key: value`` form or the value is empty.
    """
    warm_start_path = tf.compat.as_str_any(warm_start_path)
    ckpt_path = os.path.join(warm_start_path, "checkpoint")
    if not tf.io.gfile.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint file is missing under the warm start model path {warm_start_path}")
    with tf.io.gfile.GFile(ckpt_path, "r") as f:
        first_line = f.readline().rstrip()
    # Split on the FIRST colon only: absolute paths such as "hdfs://..." contain ':'
    # and would otherwise be truncated to the scheme.
    _, sep, value = first_line.partition(":")
    latest_ckpt = value.strip().replace('"', '')
    if not sep or not latest_ckpt:
        raise ValueError(f"Cannot parse the latest checkpoint from {ckpt_path}: {first_line!r}")
    # Keep only the basename so the checkpoint resolves relative to warm_start_path.
    latest_ckpt = latest_ckpt.split("/")[-1]
    return os.path.join(warm_start_path, latest_ckpt)
class SparseRestoreHook(tf.estimator.SessionRunHook):
    """Restores the sparse tables registered in ``WarmStartController`` once, after session creation."""

    def __init__(self):
        logger.info("In warm start mode, SparseRestoreHook has been initialized.")
        self._is_warm_start = False
        self._saver = None
        self._warm_start_dict = {}

    def begin(self):
        self._saver = Saver()
        logger.info("In warm start mode, begin SparseRestoreHook.")

    def after_create_session(self, session, coord):
        """Restore every registered ``{path: tables}`` entry on the first session only."""
        if self._is_warm_start:
            return
        self._warm_start_dict = WarmStartController().get_elements()
        for path, restore_tables in self._warm_start_dict.items():
            if not restore_tables:
                logger.info("In warm start mode, no sparse table to restore from %s, skip it.", path)
                continue
            path = tf.compat.as_str_any(path)
            # Registered paths may be a model directory or a specific checkpoint prefix
            # (both are accepted by get_table_name_set_by_ckpt_path).
            restore_path = get_latest_ckpt(path) if tf.io.gfile.isdir(path) else path
            self._saver.restore(session, restore_path, restore_tables)
        self._is_warm_start = True