import json
import os
import threading
import glob
import struct
import subprocess
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Union, Generator, Tuple
import numpy as np
import tensorflow as tf
from tensorflow.python.util import compat
from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.communication.hccl.hccl_info import (
get_rank_id,
get_rank_size,
get_local_rank_size,
)
from rec_sdk_common.constants.constants import (
FileParams,
DeviceType,
ValidatorParams,
)
from rec_sdk_common.util.tf_adapter import npu_ops
from rec_sdk_common.validator.validator import (
DirectoryValidator,
para_checker_decorator,
ClassValidator,
IntValidator,
OptionalStringValidator,
)
from mx_rec.validator.validator import FileValidator
from mx_rec.constants.constants import (
DataName,
DataAttr,
HDFS_FILE_PREFIX,
TRAIN_CHANNEL_ID,
BASE_MODEL,
DELTA_MODEL,
SAVE_DIR_MODE,
SAVE_FILE_MODE,
SAVE_FILE_FLAG,
FLOAT32_BYTES,
UINT64_BYTES,
UINT32_BYTES,
)
from mx_rec.saver.constants import FILE_BUFFER_SIZE
from mx_rec.saver.utils import check_files_in_directories, get_optimizer_dict_by_table_name
from mx_rec.util.initialize import ConfigInitializer
from mx_rec.util.perf import performance
from mx_rec.util.global_env_conf import global_env
from mx_rec.optimizers.base import CustomizedOptimizer
from mx_rec.graph.merge_lookup import do_merge_lookup
from mx_rec.graph.modifier import replace_anchor_for_ddr_ssd, change_ext_emb_size_by_opt
# Checkpoint directory prefixes: full sparse checkpoints vs. delta (incremental) ones.
SAVE_SPARSE_PATH_PREFIX = "sparse"
SAVE_DELTA_SPARSE_PATH_PREFIX = "delta-sparse"

# Name of the global-step value.
GLOBAL_STEP_STR = "global_step"

# SSD (L3 storage) artefacts: directory prefix, file-name glob patterns (a global step is
# appended to each pattern before matching) and minimum data-file size.
SSD_SAVE_PATH_PREFIX = "ssd_sparse_model_rank_"
SSD_SAVE_FILE_PATTERNS = ["*.meta.*"]
SSD_DATA_FILE_MIN_SIZE = 0
@dataclass
class KeyInfo:
    """Plain record bundling a key's offset, embedding size and embedding values."""

    offset: int  # presumably the key's row offset inside the table — TODO confirm
    emb_size: int  # presumably len(embedding) — TODO confirm
    embedding: List[float]
class SaveModelThread(threading.Thread):
    """Thread that saves a single sparse table by calling ``saver.save_table_name_data``."""

    def __init__(self, saver, sess, result, root_dir, table_name):
        super(SaveModelThread, self).__init__()
        # Capture everything the deferred call in run() needs.
        self.saver, self.sess = saver, sess
        self.result, self.root_dir, self.table_name = result, root_dir, table_name

    def run(self):
        """Forward the captured arguments to the saver."""
        saver = self.saver
        saver.save_table_name_data(self.sess, self.result, self.root_dir, self.table_name)
class Saver(object):
    """Save and restore sparse embedding tables together with their optimizer states.

    Construction resolves the hash-table variables to handle and builds, in the default graph,
    the fetch dict used by ``save`` and the placeholders + assign ops used by ``restore``.
    """

    @para_checker_decorator(check_option_list=[
        ("var_list", ClassValidator, {"classes": (list, type(None))}),
        ("max_to_keep", IntValidator, {"min_value": 0, "max_value": ValidatorParams.MAX_INT32.value}, ["check_value"]),
        ("prefix_name", ClassValidator, {"classes": (str, type(None))}),
        ("prefix_name", OptionalStringValidator, {"min_len": 1, "max_len": 50}, ["check_string_length"]),
    ])
    def __init__(self, var_list=None, max_to_keep=3, prefix_name="checkpoint", warm_start_tables=None):
        """
        :param var_list: table variables to save; None means every variable of the hashtable
            collection whose table instance has ``is_save`` set
        :param max_to_keep: number of most recent sparse checkpoints kept on disk; 0 disables deletion
        :param prefix_name: base name used by ``save`` when ``save_path`` is empty
        :param warm_start_tables: stored as an attribute (not read by the methods of this class)
        """
        self.max_to_keep = max_to_keep
        self._prefix_name = prefix_name
        self.var_list = var_list
        self.rank_id = get_rank_id()
        self.local_rank_size = get_local_rank_size()
        self.local_rank_id = self.rank_id % self.local_rank_size
        self.rank_size = get_rank_size()
        # table_name -> {DataName.EMBEDDING.value: variable, "optimizer": {...}}; fetched when saving.
        self.save_op_dict = defaultdict(dict)
        # table_name -> list of assign ops run when restoring.
        self.restore_fetch_dict = defaultdict()
        # table_name -> {DataName.EMBEDDING.value: placeholder, "optimizer": {...}}; fed when restoring.
        self.placeholder_dict = defaultdict(dict)
        self._last_checkpoints = []
        self.config_instance = ConfigInitializer.get_instance()
        self.build()
        self.warm_start_tables = warm_start_tables

    @staticmethod
    def _check_file_system_is_valid(save_path):
        """Raise ValueError unless ``save_path`` is a local path or an HDFS path."""
        if not check_file_system_is_valid(save_path):
            raise ValueError("the path to save sparse embedding table data belong to invalid file system, "
                             "only local file system and hdfs file system supported. ")

    def build(self):
        """Adjust the graph for export if needed, resolve ``var_list`` and build save/restore ops."""
        self._modify_graph_for_export_model()
        if self.var_list is None:
            self.var_list = []
            logger.debug("optimizer collection name: %s",
                         self.config_instance.train_params_config.ascend_global_hashtable_collection)
            temp_var_list = tf.compat.v1.get_collection(
                self.config_instance.train_params_config.ascend_global_hashtable_collection)
            for var in temp_var_list:
                table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
                if table_instance.is_save:
                    self.var_list.append(var)
        with tf.compat.v1.variable_scope("mx_rec_save"):
            self._build_save()
        with tf.compat.v1.variable_scope("mx_rec_restore"):
            self._build_restore()
        logger.debug("Save & Restore graph was built.")

    @performance("Save")
    def save(self, sess, save_path="model", global_step=None, save_delta=False):
        """
        Save sparse tables. checkpoint is saved in under format:
        ./rank_id/HashTable/HBM/embed_table_name/key/xxx.data
        ./rank_id/HashTable/HBM/embed_table_name/key/xxx.attribute
        ./rank_id/HashTable/HBM/embed_table_name/embedding/xxx.data
        ./rank_id/HashTable/HBM/embed_table_name/embedding/xxx.attribute
        NOTE(review): the layout above may be outdated; after merging, merge_sparse_file writes
        <saving_path>/<table_name>/<data_type>/slice.data and slice.attribute.

        :param sess: A Session to use to save the sparse table variables
        :param save_path: Only absolute path supported; an empty value falls back to ``prefix_name``
        :param global_step: If provided the global step number is appended to save_path to create
            the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
        :param save_delta: check if save delta model in incremental checkpoint pattern
        :return: None
        """
        logger.debug("======== Start saving for rank id %s ========", self.rank_id)
        # Apply the fallback *before* validating: check_file_system_is_valid calls .find() on the
        # path, which crashed for save_path=None.
        save_path = save_path if save_path else self._prefix_name
        self._check_file_system_is_valid(save_path)
        directory, base_name = os.path.split(save_path)
        save_path_prefix = SAVE_SPARSE_PATH_PREFIX if not save_delta else SAVE_DELTA_SPARSE_PATH_PREFIX
        ckpt_name = self._build_checkpoint_name(save_path_prefix, base_name, global_step, sess)
        saving_path = os.path.join(directory, ckpt_name)
        self.config_instance.train_params_config.sparse_dir = saving_path
        try:
            if not check_file_system_is_hdfs(saving_path):
                directory_validator = DirectoryValidator("saving_path", saving_path)
                directory_validator.check_not_soft_link()
                directory_validator.with_blacklist(exact_compare=False)
                directory_validator.check()
        except ValueError as err:
            raise ValueError(f"The saving path {saving_path} cannot be a system directory "
                             f"and cannot be soft link.") from err
        if not tf.io.gfile.exists(saving_path):
            try:
                if check_file_system_is_hdfs(saving_path):
                    tf.io.gfile.makedirs(saving_path)
                else:
                    os.makedirs(saving_path, SAVE_DIR_MODE, exist_ok=True)
            except Exception as err:
                raise RuntimeError(f"make dir {saving_path} for saving sparse table failed!") from err
            logger.info("Saving_path '%s' has been made.", saving_path)
        self._save(sess, saving_path, save_delta)
        if self.max_to_keep:
            self._last_checkpoints.append(saving_path)
            if len(self._last_checkpoints) > self.max_to_keep:
                logger.info("checkpoints num %d > max_to_keep %d delete %s",
                            len(self._last_checkpoints), self.max_to_keep,
                            self._last_checkpoints[0])
                checkpoint_path = self._last_checkpoints.pop(0)
                file_validator = FileValidator("checkpoint_path", checkpoint_path)
                if not check_file_system_is_hdfs(checkpoint_path):
                    file_validator.check_not_soft_link()
                file_validator.check()
                try:
                    tf.io.gfile.rmtree(checkpoint_path)
                except tf.errors.NotFoundError:
                    logger.warning("oldest checkpoint file is not exist, maybe it has been deleted.")
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        comm.Barrier()
        merge_success = 1
        try:
            if should_write_data(rank, saving_path):
                table_list = self.save_op_dict.keys()
                for table_name in table_list:
                    self.merge_sparse_file(saving_path, table_name)
        except Exception as err:
            merge_success = 0
            err_msg = f"[rank {rank}] merge_sparse_file failed: {err}\n{traceback.format_exc()}"
            logger.error(err_msg)
        # Several ranks may merge (one per node on local file systems), so reduce the flag over all
        # ranks: a bcast from rank 0 would hide a failure that happened on any other rank.
        merge_success = comm.allreduce(merge_success, op=MPI.MIN)
        if not merge_success:
            logger.error("MPI job aborted due to merge_sparse_file failed")
            comm.Abort(1)
        comm.Barrier()
        logger.info("sparse model was saved in dir '%s' .", saving_path)
        logger.info("======== Saving finished for rank id %s ========", self.rank_id)

    @performance("Restore")
    def restore(self, sess, reading_path, warm_start_tables=None, model_type="base"):
        """
        Restore sparse tables from ``<dir>/sparse-<base>`` (BASE_MODEL) or ``<dir>/tmp-sparse-<base>``.

        For DELTA_MODEL the read directory is removed after a successful restore.

        :raises ValueError: if the path is on an unsupported file system
        :raises FileExistsError: if the checkpoint directory does not exist
        """
        logger.debug("======== Start restoring ========")
        if not check_file_system_is_valid(reading_path):
            raise ValueError("the path to save sparse embedding table data belong to invalid file system, "
                             "only local file system and hdfs file system supported. ")
        directory, base_name = os.path.split(reading_path)
        if model_type == BASE_MODEL:
            ckpt_name = f"{SAVE_SPARSE_PATH_PREFIX}-{base_name}"
        else:
            ckpt_name = f"tmp-{SAVE_SPARSE_PATH_PREFIX}-{base_name}"
        reading_path = os.path.join(directory, ckpt_name)
        if not tf.io.gfile.exists(reading_path):
            raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.")
        file_validator = FileValidator("reading_path", reading_path)
        if not check_file_system_is_hdfs(reading_path):
            file_validator.check_not_soft_link()
        file_validator.check()
        self._restore(sess, reading_path, warm_start_tables)
        if model_type == DELTA_MODEL:
            try:
                tf.io.gfile.rmtree(reading_path)
            except tf.errors.NotFoundError:
                logger.warning("%s is not exists, maybe it has been deleted.", reading_path)
        logger.info("sparse model was restored from dir '%s' .", reading_path)
        logger.debug("======== Restoring finished ========")

    @performance("save_table_name_data")
    def save_table_name_data(self, sess, result, root_dir, table_name):
        """Fetch one table's embedding/optimizer arrays, keep only the valid rows and write them."""
        dump_data_dict = sess.run(result.get(table_name))
        self._get_valid_dict_data(dump_data_dict, table_name)
        save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id)
        if "optimizer" in dump_data_dict:
            dump_optimizer_data_dict = dump_data_dict.get("optimizer")
            for optimizer_name, dump_optimizer_data in dump_optimizer_data_dict.items():
                save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, self.rank_id)

    def _build_checkpoint_name(self, save_path_prefix, base_name, global_step, sess):
        """Return ``<prefix>-<base>`` or ``<prefix>-<base>-<step>`` when a non-zero step is given."""
        # A tf.Tensor / tf.Variable cannot be used as a Python bool in graph mode, so evaluate it
        # before the truthiness test below (integers and tensor names behave as before).
        if isinstance(global_step, (tf.Tensor, tf.Variable)):
            global_step = int(sess.run(global_step))
        if not global_step:
            return f"{save_path_prefix}-{base_name}"
        if not isinstance(global_step, compat.integral_types):
            global_step = int(sess.run(global_step))
        return f"{save_path_prefix}-{base_name}-{global_step}"

    def merge_sparse_file(self, root_dir: str, table_name: str):
        """
        Merge the per-rank binary files of one table into a single ``slice.data`` per data type,
        and write a matching ``slice.attribute``.

        Args:
            root_dir: checkpoint directory to merge in
            table_name: name of the table to merge
        Returns: None
        """
        logger.info("Start merge sparse file, merge dir:%s, table_name:%s.", root_dir, table_name)
        table_dir = os.path.join(root_dir, table_name)
        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(table_name)
        merge_type_list = get_merge_type_list(table_dir)
        if not check_file_system_is_hdfs(root_dir):
            dir_validator = DirectoryValidator("root_dir", root_dir)
            dir_validator.check_not_soft_link()
            try:
                dir_validator.check()
            except ValueError as e:
                raise ValueError(f"root_dir:{root_dir} can't be soft link") from e
        for data_type in merge_type_list:
            upper_dir = os.path.join(table_dir, data_type)
            if table_instance.is_dp:
                rename_file_and_remove_others(upper_dir)
            else:
                merge_multi_files(upper_dir)
            outfile_path = os.path.join(upper_dir, "slice.data")
            file_size = tf.io.gfile.stat(outfile_path).length
            # Integer division: the attribute stores row counts, not fractional sizes.
            if data_type == "key":
                attribute = np.array([file_size // 8, 8])
            else:
                attribute = np.array([file_size // 4 // table_instance.emb_size, table_instance.emb_size, 4])
            attribute_bytes = attribute.astype(np.int64).tobytes()  # tostring() is a deprecated alias
            attribute_dir = os.path.join(upper_dir, "slice.attribute")
            if check_file_system_is_hdfs(attribute_dir):
                with tf.io.gfile.GFile(attribute_dir, "wb") as file:
                    file.write(attribute_bytes)
            else:
                with os.fdopen(os.open(attribute_dir, SAVE_FILE_FLAG, SAVE_FILE_MODE), "wb") as file:
                    file.write(attribute_bytes)

    def get_warm_start_dict(self, table_list):
        """Return the placeholders and restore fetches limited to the tables in ``table_list``."""
        placeholder_dict = defaultdict(dict)
        restore_fetch_list = []
        for table_name, v in self.placeholder_dict.items():
            if table_name in table_list:
                placeholder_dict[table_name] = v
                restore_fetch_list.append(self.restore_fetch_dict.get(table_name))
        if not restore_fetch_list:
            logger.warning("no tables can be warm start restored.")
        return placeholder_dict, restore_fetch_list

    @performance("_save")
    def _save(self, sess, root_dir, save_delta):
        """Pass optimizer info to the host, then save via the HBM or DDR path of the first table."""
        optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
        if optimizer_instance:
            for table_name in self.save_op_dict:
                set_optimizer_info(optimizer_instance, table_name)
        # NOTE(review): raises IndexError when var_list is empty.
        table_instance0 = self.config_instance.sparse_embed_config.get_table_instance(self.var_list[0])
        if table_instance0.is_hbm:
            self._save_hbm(sess, root_dir, save_delta)
        else:
            self._save_ddr(sess, root_dir, save_delta)
        logger.debug("Host data was saved.")

    def _save_hbm(self, sess, root_dir, save_delta):
        """Save host data, then dump each eligible table's device data in parallel threads."""
        self.config_instance.hybrid_manager_config.save_host_data(root_dir, save_delta)
        if self.config_instance.use_dynamic_expansion:
            return
        result = self.save_op_dict
        threads = []
        for table_name in result.keys():
            table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(table_name)
            if not should_save_sparse_embedding(table_instance.is_dp, root_dir):
                continue
            thread = SaveModelThread(self, sess, result, root_dir, table_name)
            threads.append(thread)
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def _save_ddr(self, sess, root_dir, save_delta):
        """Stream device-side embeddings (and optimizer slots) to the host, then save host data."""
        self.config_instance.hybrid_manager_config.start_sync_thread()
        self.config_instance.hybrid_manager_config.fetch_device_emb()
        sess.graph._unsafe_unfinalize()
        for var in self.var_list:
            table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
            table_name = table_instance.table_name
            use_static = ConfigInitializer.get_instance().use_static
            max_lookup_vec_size = None
            if use_static:
                max_lookup_vec_size = table_instance.send_count * self.rank_size
            swap_out_pos, swap_out_len, sync_remain_flag = npu_ops.gen_npu_ops.get_next(
                output_types=[tf.int32, tf.int32, tf.bool],
                output_shapes=[[max_lookup_vec_size], [], []],
                channel_name=f"{table_name}_save_h2d_{TRAIN_CHANNEL_ID}")
            if use_static:
                swap_out_pos = swap_out_pos[:swap_out_len]
            table = [var]
            optimizer = get_optimizer_dict_by_table_name(table_name)
            if optimizer is not None:
                for slots in optimizer.values():
                    table += list(slots.values())
            swap_outs = [tf.gather(one_table, swap_out_pos) for one_table in table]
            swap_out = tf.concat(swap_outs, axis=1)
            channel_name = f"{table_name}_save_d2h_{TRAIN_CHANNEL_ID}"
            logger.info("Channel %s was built for op swap_out_op.", channel_name)
            swap_out_op = npu_ops.outfeed_enqueue_op(channel_name=channel_name, inputs=[swap_out])
            sync_cnt = 0
            is_sync_remain = True
            # Keep sending batches until the device reports nothing is left to sync.
            while is_sync_remain:
                _, is_sync_remain = sess.run([swap_out_op, sync_remain_flag])
                sync_cnt += 1
                logger.info("Sending embedding to host, table:%s, sync_cnt:%d, is_sync_remain:%d.",
                            table_name, sync_cnt, is_sync_remain)
            logger.info("Finish sending embedding to host, table:%s.", table_name)
        self._save_host_data(root_dir, save_delta, sess)

    def _get_valid_dict_data(self, dump_data_dict, table_name):
        """Keep only the rows whose offsets the host reports for ``table_name`` (in place)."""
        host_data = self.config_instance.hybrid_manager_config.get_host_data(table_name)
        offset = list(host_data)
        get_valid_dict_data_from_host_offset(dump_data_dict, offset)

    def _build_save(self):
        """Register each table's variable and optimizer states in ``save_op_dict``."""
        for var in self.var_list:
            if global_env.tf_device == DeviceType.NPU.value and "merged" not in var.name:
                continue
            table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
            table_name = table_instance.table_name
            with tf.compat.v1.variable_scope(table_name):
                sub_dict = self.save_op_dict[table_name]
                sub_dict[DataName.EMBEDDING.value] = var
                optimizer = get_optimizer_dict_by_table_name(table_name)
                if optimizer:
                    sub_dict["optimizer"] = optimizer

    def _build_restore(self):
        """Create a placeholder + assign op per table (and per optimizer state) for restoring."""
        for var in self.var_list:
            if global_env.tf_device == DeviceType.NPU.value and "merged" not in var.name:
                continue
            table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
            sub_placeholder_dict = self.placeholder_dict[table_instance.table_name]
            with tf.compat.v1.variable_scope(table_instance.table_name):
                sub_placeholder_dict[DataName.EMBEDDING.value] = variable = \
                    tf.compat.v1.placeholder(dtype=tf.float32, shape=[table_instance.slice_device_vocabulary_size,
                                                                      table_instance.emb_size],
                                             name=DataName.EMBEDDING.value)
                assign_op = var.assign(variable)
                self.restore_fetch_dict[table_instance.table_name] = [assign_op]
                optimizer = get_optimizer_dict_by_table_name(table_instance.table_name)
                if optimizer:
                    self._build_optimizer_restore(sub_placeholder_dict, table_instance, optimizer)

    def _build_optimizer_restore(self, sub_placeholder_dict, table_instance, optimizer):
        """Create a placeholder per optimizer state and append its assign op to the table's fetches."""
        sub_placeholder_dict["optimizer"] = optimizer_placeholder_dict = dict()
        placeholder_shape = [table_instance.slice_device_vocabulary_size, table_instance.emb_size]
        for optimizer_name, optimizer_state_dict in optimizer.items():
            optimizer_placeholder_dict[optimizer_name] = sub_optimizer_placeholder_dict = {
                state_key: tf.compat.v1.placeholder(dtype=tf.float32, shape=placeholder_shape, name=state_key)
                for state_key in optimizer_state_dict
            }
            for key_state, state in optimizer_state_dict.items():
                # Only states living in the same graph as the placeholder can be assigned from it.
                if sub_optimizer_placeholder_dict.get(key_state).graph is not state.graph:
                    continue
                assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state))
                self.restore_fetch_dict[table_instance.table_name].append(assign_op)

    def _restore(self, sess, reading_path, warm_start_tables=None):
        """Restore host data, then (HBM, static sizing only) feed the device tables from files."""
        if warm_start_tables:
            placeholder_dict, restore_fetch_list = self.get_warm_start_dict(warm_start_tables)
        else:
            placeholder_dict, restore_fetch_list = self.placeholder_dict, self.restore_fetch_dict
        optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
        if optimizer_instance:
            for table_name in placeholder_dict:
                set_optimizer_info(optimizer_instance, table_name)
        if self.config_instance.hybrid_manager_config.asc_manager:
            self.config_instance.hybrid_manager_config.restore_host_data(reading_path, warm_start_tables)
            logger.info("host data was restored.")
        table_instance0 = self.config_instance.sparse_embed_config.get_table_instance(self.var_list[0])
        if not table_instance0.is_hbm:
            return
        if self.config_instance.use_dynamic_expansion:
            return
        restore_feed_dict = defaultdict(dict)
        for table_name, sub_placeholder_dict in placeholder_dict.items():
            load_offset = self.config_instance.hybrid_manager_config.get_load_offset(table_name)
            fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict,
                             NameDescriptor(table_name, DataName.EMBEDDING.value), load_offset)
            if "optimizer" in sub_placeholder_dict:
                optimizer_state_placeholder_dict_group = sub_placeholder_dict.get("optimizer")
                _fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path,
                                                restore_feed_dict, table_name, load_offset)
        sess.run(restore_fetch_list, feed_dict=restore_feed_dict)

    def _modify_graph_for_export_model(self):
        """Merge lookups and re-anchor DDR/SSD ops when graph modification is enabled."""
        experimental_mode = self.config_instance.train_params_config.experimental_mode
        if experimental_mode is None or not self.config_instance.modify_graph:
            return
        is_training = experimental_mode == tf.compat.v1.estimator.ModeKeys.TRAIN
        do_merge_lookup(is_train=is_training)
        slot_num = 0
        optimizer_ins = self.config_instance.optimizer_config.optimizer_instance
        if optimizer_ins is not None:
            change_ext_emb_size_by_opt(optimizer_ins)
            slot_num = optimizer_ins.slot_num
        channel_id = 0 if is_training else 1
        replace_anchor_for_ddr_ssd(tf.compat.v1.get_default_graph(), slot_num, channel_id)

    def _save_host_data(self, root_dir: str, save_delta: bool, sess: tf.compat.v1.Session):
        """Save host data, asking for L3 (SSD) storage too unless every SSD file for this step exists."""
        if ConfigInitializer.get_instance().train_params_config.experimental_mode is None:
            self.config_instance.hybrid_manager_config.save_host_data(root_dir, save_delta)
            return
        all_saved = True
        # Check the tensor itself: sess.run(None) would fail before the None check could fire.
        global_step_tensor = tf.compat.v1.train.get_global_step()
        if global_step_tensor is None:
            raise ValueError("the global step cannot be None")
        global_step = sess.run(global_step_tensor)
        ssd_save_file_patterns = [pattern + str(global_step) for pattern in SSD_SAVE_FILE_PATTERNS]
        logger.info("The patterns of the ssd file is: %s.", ssd_save_file_patterns)
        for var in self.var_list:
            table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
            if table_instance.ssd_vocabulary_size == 0:
                continue
            for ssd_path in table_instance.ssd_data_path:
                data_path = os.path.join(ssd_path, SSD_SAVE_PATH_PREFIX + "*")
                is_exists = check_files_in_directories(data_path, ssd_save_file_patterns)
                all_saved &= is_exists
        is_save_l3_storage = not all_saved
        logger.info("The `is_save_l3_storage` is %s.", is_save_l3_storage)
        self.config_instance.hybrid_manager_config.save_host_data(root_dir, save_delta, is_save_l3_storage)
class NameDescriptor:
    """Identifies one stored array: its table, its data name and (optionally) its optimizer."""

    def __init__(self, table_name, data_name, optimizer_name=None):
        self.table_name, self.data_name = table_name, data_name
        # None for embedding data; set for optimizer-state data.
        self.optimizer_name = optimizer_name
def get_valid_dict_data_from_host_offset(dump_data_dict: dict, offset: list):
    """
    Keep, in place, only the rows listed in ``offset`` for the embedding and every optimizer state.

    :param dump_data_dict: sparse data dict to be saved; modified in place
    :param offset: row indices of the sparse table to keep
    """
    emb_key = DataName.EMBEDDING.value
    dump_data_dict[emb_key] = dump_data_dict.get(emb_key)[offset, :]
    if "optimizer" not in dump_data_dict:
        return
    optimizer_group = dump_data_dict.get("optimizer")
    for state_dict in optimizer_group.values():
        # Overwrite each state with its selected rows; the nested dicts are updated in place.
        state_dict.update({key: value[offset, :] for key, value in state_dict.items()})
def fill_placeholder(reading_path: str, placeholder_dict: Dict[str, tf.Tensor],
                     feed_dict: Dict[str, Dict[str, tf.Tensor]],
                     name_descriptor: NameDescriptor, load_offset: List[int]):
    """
    Read one stored array and register it in ``feed_dict`` under its placeholder.

    Data lives in ``<reading_path>/<table>/<data_name>`` for embeddings and in
    ``<reading_path>/<table>/<optimizer>_<data_name>`` for optimizer states.
    NOTE(review): despite the annotation, feed_dict maps placeholders to arrays.
    """
    sub_dir = name_descriptor.data_name
    if name_descriptor.optimizer_name:
        sub_dir = name_descriptor.optimizer_name + "_" + sub_dir
    target_path = generate_path(reading_path, name_descriptor.table_name, sub_dir)
    restored = read_binary_data(target_path, name_descriptor.data_name, name_descriptor.table_name,
                                load_offset)
    feed_dict.update({placeholder_dict.get(key): data for key, data in restored.items()})
@performance("save_embedding_data")
def save_embedding_data(root_dir, table_name, dump_data_dict, suffix):
    """
    Write a table's embedding array to ``<root_dir>/<table_name>/<embedding>/slice_<suffix>.data``.

    :param root_dir: checkpoint directory
    :param table_name: sparse table name
    :param dump_data_dict: dict holding the embedding array under ``DataName.EMBEDDING.value``
    :param suffix: rank id used in the slice file name
    """
    # The old body also built an `attribute` dict (dtype/shape) that was never used; it is removed.
    target_path = generate_path(root_dir, table_name, DataName.EMBEDDING.value)
    write_binary_data(target_path, suffix, dump_data_dict.get(DataName.EMBEDDING.value))
def save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, suffix):
    """
    Write each optimizer state to ``<root_dir>/<table_name>/<optimizer_name>_<state_key>/slice_<suffix>.data``.

    :param dump_optimizer_data: mapping state_key -> array for one optimizer
    :param suffix: rank id used in the slice file name
    """
    # The old body built an unused `attribute` dict per state; only the write is needed.
    for state_key, state in dump_optimizer_data.items():
        target_path = generate_path(root_dir, table_name, optimizer_name + "_" + state_key)
        write_binary_data(target_path, suffix, state)
def generate_path(*args):
    """Join the given path components (thin wrapper over ``os.path.join``)."""
    joined = os.path.join(*args)
    return joined
def generate_file_name(suffix):
    """Return ``(data_file_name, attribute_file_name)`` for the slice numbered ``suffix``."""
    stem = "slice_%d" % suffix
    return stem + ".data", stem + ".attribute"
def write_binary_data(writing_path: str, suffix: int, data: np.ndarray):
    """
    Write ``data``'s raw bytes to ``<writing_path>/slice_<suffix>.data``, creating the directory.

    The file is opened with "wb" when absent and "ab" when it already exists.
    NOTE(review): on the local branch the effective behaviour also depends on SAVE_FILE_FLAG
    (e.g. whether it contains O_TRUNC) — confirm appends really accumulate.

    :raises RuntimeError: if the directory cannot be created
    """
    try:
        if check_file_system_is_hdfs(writing_path):
            tf.io.gfile.makedirs(writing_path)
        else:
            os.makedirs(writing_path, SAVE_DIR_MODE, exist_ok=True)
    except Exception as err:
        raise RuntimeError(f"make dir {writing_path} for writing data failed!") from err
    data_file, _ = generate_file_name(suffix)
    target_data_dir = os.path.join(writing_path, data_file)
    write_mode = "wb" if not tf.io.gfile.exists(target_data_dir) else "ab"
    # ndarray.tostring() is a deprecated alias of tobytes(); the produced bytes are identical.
    payload = data.tobytes()
    if check_file_system_is_hdfs(target_data_dir):
        with tf.io.gfile.GFile(target_data_dir, write_mode) as file:
            file.write(payload)
    else:
        with os.fdopen(os.open(target_data_dir, SAVE_FILE_FLAG, SAVE_FILE_MODE), write_mode) as file:
            file.write(payload)
def read_binary_data(reading_path: str, data_name: str, table_name: str, load_offset) -> dict:
    """
    Read one merged sparse array (``slice.data`` + ``slice.attribute``) from ``reading_path``.

    :param reading_path: directory holding ``slice.data`` and ``slice.attribute``
    :param data_name: key under which the array is returned (embedding or an optimizer state name)
    :param table_name: the sparse table name, used to look up the currently configured shape
    :param load_offset: row indices to select after reshaping; a falsy value keeps every row
    :return: ``{data_name: array}``, resized via process_embedding_data when its shape differs from
        ``[slice_device_vocabulary_size, emb_size]`` of the current table
    :raises FileExistsError: if either file is missing
    :raises RuntimeError: if the embedding size cannot be taken from the attribute file
    """
    target_data_dir = os.path.join(reading_path, "slice.data")
    target_attribute_dir = os.path.join(reading_path, "slice.attribute")
    for label, path in (("Target_data_dir", target_data_dir), ("Target_attribute_dir", target_attribute_dir)):
        if not tf.io.gfile.exists(path):
            raise FileExistsError(f"{label} {path} does not exist when reading.")
    attributes = read_attribute_file(target_attribute_dir)
    data_to_restore = read_data_file(target_data_dir)
    try:
        # The second attribute entry is the row width.
        embedding_size = list(attributes)[1]
    except Exception as err:
        raise RuntimeError(f"get embedding size from attribute file {target_attribute_dir} failed.") from err
    data_to_restore = data_to_restore.reshape(-1, embedding_size)
    if load_offset:
        data_to_restore = data_to_restore[load_offset, :]
    actual_shape = list(data_to_restore.shape)
    table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(table_name)
    expected_shape = [table_instance.slice_device_vocabulary_size, table_instance.emb_size]
    if actual_shape != expected_shape:
        data_to_restore = process_embedding_data(data_to_restore, expected_shape, actual_shape)
    result = {data_name: data_to_restore}
    logger.debug("Attribute: '%s' and data file: '%s' have been read.", target_attribute_dir, target_data_dir)
    logger.debug("Reading shape is %s.", data_to_restore.shape)
    return result
def validate_read_file(read_file_path):
    """
    Validate a file before reading it: size bounds always; on non-HDFS paths also no soft link,
    owner/group and file mode. Raises through ``FileValidator.check`` on failure.

    :param read_file_path: the file path to be validated
    """
    validator = FileValidator("read_file_path", read_file_path)
    validator.check_file_size(FileParams.MAX_FILE_SIZE.value, FileParams.MIN_SIZE.value)
    is_local = not check_file_system_is_hdfs(read_file_path)
    if is_local:
        validator.check_not_soft_link()
        validator.check_user_group()
        validator.check_file_mode()
    validator.check()
def process_embedding_data(data_to_restore: np.ndarray, current_data_shape: list, data_shape: list) -> np.ndarray:
    """
    Fit restored 2-D table data to the currently configured table shape.

    Rows are zero-padded at the end when the configured vocabulary is larger than the saved one,
    and truncated when it is smaller. The dtype of ``data_to_restore`` is preserved.

    :param data_to_restore: the 2-D array read from the binary file
    :param current_data_shape: ``[vocab_size, emb_size]`` configured for the current table
    :param data_shape: ``[vocab_size, emb_size]`` of ``data_to_restore``
    :return: the resized array (shape ``current_data_shape`` assuming ``data_shape`` matches the input)
    :raises ValueError: if either shape is not 2-D, or the embedding sizes differ
    """
    try:
        restore_vocab_size, restore_emb_size = current_data_shape
        vocab_size, emb_size = data_shape
    except ValueError as err:
        raise ValueError(f"The shape dimension of a sparse table cannot exceed two dimensions. ") from err
    # Only the vocabulary dimension can be adapted; a width mismatch would otherwise surface later
    # as an obscure concatenate / feed error (or not at all when vocab sizes happen to match).
    if restore_emb_size != emb_size:
        raise ValueError(f"Embedding size mismatch when restoring sparse table: saved {emb_size}, "
                         f"current {restore_emb_size}.")
    if restore_vocab_size > vocab_size:
        # Pad with the same dtype; np.zeros defaults to float64 and would upcast the whole table.
        pad_matrix = np.zeros((restore_vocab_size - vocab_size, restore_emb_size), dtype=data_to_restore.dtype)
        data_to_restore = np.concatenate((data_to_restore, pad_matrix), axis=0)
    elif restore_vocab_size < vocab_size:
        data_to_restore = data_to_restore[:restore_vocab_size, :]
    return data_to_restore
def check_file_system_is_valid(file_path):
    """Return True for scheme-less (local) paths and for HDFS paths, False for other schemes."""
    has_scheme = file_path.find("://") != -1
    return not has_scheme or check_file_system_is_hdfs(file_path)
def check_file_system_is_hdfs(file_path):
    """Return True when ``file_path`` starts with any of the HDFS_FILE_PREFIX prefixes."""
    return any(file_path.startswith(candidate) for candidate in HDFS_FILE_PREFIX)
def get_hdfs_safemode_status(hdfs_cmd: str = "/usr/local/hadoop-3.3.6/bin/hdfs"):
    """
    Query the HDFS safe-mode state with ``<hdfs_cmd> dfsadmin -safemode get``.

    :param hdfs_cmd: path of the ``hdfs`` executable; defaults to the previously hard-coded location
    :return: the command's stripped stdout, or "" when the command is missing or fails
        (failures are only logged, since this also runs in non-HDFS setups)
    """
    try:
        result = subprocess.run([hdfs_cmd, "dfsadmin", "-safemode", "get"],
                                capture_output=True, text=True, check=True, shell=False)
    except FileNotFoundError:
        # The command is run by absolute path, so PATH is not consulted; point at the real location.
        logger.warning("Command '%s' not found. Ignore this exception in non-HDFS scenario. Please ensure Hadoop "
                       "is installed and the 'hdfs' executable exists at that path in HDFS scenario.", hdfs_cmd)
        return ""
    except Exception as err:
        logger.warning("Failed to get HDFS safemode status:%s. Ignore this exception in non-HDFS scenario.", err)
        return ""
    output = result.stdout.strip()
    logger.info("HDFS safemode status:%s.", output)
    return output
def check_hdfs_safemode_status():
    """Raise RuntimeError when the HDFS safemode query reports that safe mode is on."""
    if "Safe mode is ON" not in get_hdfs_safemode_status():
        return
    raise RuntimeError(
        "The current HDFS is in safe mode. It is recommended to check the server disk space and the usage of HDFS "
        "resources. Use 'hdfs dfsadmin -safemode leave' to set Safe mode is OFF, and then run again."
    )
def _fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group: dict, reading_path: str,
                                    restore_feed_dict: dict, table_name: str, load_offset: list):
    """
    Feed the restored data of every optimizer state placeholder of a table.

    Args:
        optimizer_state_placeholder_dict_group: optimizer_name -> {state_key: placeholder}
        reading_path: checkpoint directory to read from
        restore_feed_dict: the session-run feed dict to fill in place
        table_name: table name
        load_offset: row indices to select from the stored data (falsy keeps every row)
    Returns: None
    """
    for optimizer_name, state_placeholders in optimizer_state_placeholder_dict_group.items():
        for state_key in state_placeholders:
            descriptor = NameDescriptor(table_name, state_key, optimizer_name=optimizer_name)
            fill_placeholder(reading_path=reading_path,
                             placeholder_dict=state_placeholders,
                             feed_dict=restore_feed_dict,
                             name_descriptor=descriptor,
                             load_offset=load_offset)
def get_merge_type_list(table_dir: str):
    """
    List the data-type sub-directories of a saved table that need merging.

    Args:
        table_dir: directory where the sparse table is stored
    Returns: the names of the immediate sub-directories of ``table_dir``
    """
    return [
        entry
        for entry in tf.io.gfile.listdir(table_dir)
        if tf.io.gfile.isdir(os.path.join(table_dir, entry))
    ]
def merge_multi_files(upper_dir: str):
    """
    Merge the per-rank binary files of ``upper_dir`` into one file, dispatching on the file system.

    Args:
        upper_dir: directory to merge in
    Returns: None
    """
    merger = merge_hdfs_file if check_file_system_is_hdfs(upper_dir) else merge_local_file
    merger(upper_dir)
def merge_hdfs_file(upper_dir: str):
    """
    Concatenate every ``slice_*`` file of an HDFS directory into ``slice.data``.

    Sources are appended in lexicographic file-name order and each one is removed right after it
    has been copied.
    """
    data_files = sorted(
        (name for name in tf.io.gfile.listdir(upper_dir) if name.startswith("slice_")),
        key=os.path.basename,
    )
    outfile_path = os.path.join(upper_dir, "slice.data")
    # `with` closes (and flushes) the output even if a read or remove fails midway.
    with tf.io.gfile.GFile(outfile_path, "wb") as outfile:
        for name in data_files:
            file_dir = os.path.join(upper_dir, name)
            with tf.io.gfile.GFile(file_dir, "rb") as infile:
                # Copy in bounded chunks instead of loading a whole slice into memory.
                while True:
                    chunk = infile.read(FILE_BUFFER_SIZE)
                    if not chunk:
                        break
                    outfile.write(chunk)
            tf.io.gfile.remove(file_dir)
def merge_local_file(upper_dir: str) -> None:
    """
    Concatenate every ``slice_*`` file of a local directory into ``slice.data``.

    Sources are appended in lexicographic file-name order; empty ones are skipped. Every source is
    removed after it has been handled.
    """
    import shutil

    data_files = sorted(
        (name for name in os.listdir(upper_dir) if name.startswith("slice_")),
        key=os.path.basename,
    )
    outfile_path = os.path.join(upper_dir, "slice.data")
    # Context managers close both files even if a copy or remove raises.
    with os.fdopen(os.open(outfile_path, SAVE_FILE_FLAG, SAVE_FILE_MODE), "wb",
                   buffering=FILE_BUFFER_SIZE) as outfile:
        for name in data_files:
            file_dir = os.path.join(upper_dir, name)
            if os.path.getsize(file_dir) > 0:
                with open(file_dir, "rb", buffering=FILE_BUFFER_SIZE) as infile:
                    shutil.copyfileobj(infile, outfile, FILE_BUFFER_SIZE)
            os.remove(file_dir)
def rename_file_and_remove_others(upper_dir: str):
    """
    In DP mode every card holds the same embeddings, so instead of merging, the first ``slice_*``
    file (lexicographic order) is renamed to ``slice.data`` and the remaining ones are deleted.

    Args:
        upper_dir: model save path.
    Returns: None
    Raises:
        RuntimeError: if ``upper_dir`` contains no ``slice_*`` file.
    """
    candidates = sorted(
        (name for name in tf.io.gfile.listdir(upper_dir) if name.startswith("slice_")),
        key=os.path.basename,
    )
    if not candidates:
        raise RuntimeError(
            f"rename and remove file failed, slice*.data do not exist in {upper_dir}."
        )
    keep, *redundant = candidates
    tf.io.gfile.rename(os.path.join(upper_dir, keep), os.path.join(upper_dir, "slice.data"), overwrite=True)
    for name in redundant:
        stale_path = os.path.join(upper_dir, name)
        if tf.io.gfile.exists(stale_path):
            tf.io.gfile.remove(stale_path)
def set_optimizer_info(optimizer: CustomizedOptimizer, table_name: str):
    """
    Pass the optimizer type and parameters of a sparse table to the host side.

    Args:
        optimizer: the customized optimizer whose type/params are forwarded
        table_name: table name
    Returns: None
    """
    from mxrec_pybind import OptimizerInfo

    hybrid_manager = ConfigInitializer.get_instance().hybrid_manager_config
    hybrid_manager.set_optim_info(table_name, OptimizerInfo(optimizer.optimizer_type, optimizer.optim_param_list))
def should_write_data(rank_id: int, save_path: str) -> bool:
    """Only rank 0 writes to HDFS; on local file systems the first rank of every node writes."""
    on_hdfs = check_file_system_is_hdfs(save_path)
    ranks_per_node = get_local_rank_size()
    if on_hdfs:
        return rank_id == 0
    return rank_id % ranks_per_node == 0
def update_model_index(save_dir: str, model_index: Dict[str, Union[str, int]]):
    """
    Append ``model_index`` to ``<save_dir>/model_index.json`` (created if missing).

    The whole list is rewritten as indented JSON. HDFS paths go through
    ``tf.io.gfile``; local paths are written with ``SAVE_FILE_FLAG``/``SAVE_FILE_MODE``
    after checking that ``save_dir`` is not a soft link.
    Raises:
        ValueError: if the local ``save_dir`` fails directory validation.
    """
    index_path = os.path.join(save_dir, "model_index.json")
    entries = []
    if tf.io.gfile.exists(index_path):
        with tf.io.gfile.GFile(index_path, "r") as reader:
            entries = json.load(reader)
    entries.append(model_index)

    dump_kwargs = {"ensure_ascii": False, "separators": (",", ": "), "indent": 4}
    if not check_file_system_is_hdfs(index_path):
        validator = DirectoryValidator("save_dir", save_dir)
        validator.check_not_soft_link()
        try:
            validator.check()
        except ValueError as e:
            raise ValueError(f"save_dir:{save_dir} can't be soft link") from e
        with os.fdopen(os.open(index_path, SAVE_FILE_FLAG, SAVE_FILE_MODE), "w") as writer:
            json.dump(entries, writer, **dump_kwargs)
        return
    with tf.io.gfile.GFile(index_path, "w") as writer:
        json.dump(entries, writer, **dump_kwargs)
def write_delta_export_time_ms(save_dir: str, delta_export_time_ms: dict):
    """
    Dump ``delta_export_time_ms`` as indented JSON to ``<save_dir>/delta_export_time_ms.json``.

    HDFS paths go through ``tf.io.gfile``; local paths are written with
    ``SAVE_FILE_FLAG``/``SAVE_FILE_MODE`` after checking ``save_dir`` is not a soft link.
    Raises:
        ValueError: if the local ``save_dir`` fails directory validation.
    """
    target = os.path.join(save_dir, "delta_export_time_ms.json")
    if check_file_system_is_hdfs(target):
        sink = tf.io.gfile.GFile(target, "w")
    else:
        validator = DirectoryValidator("save_dir", save_dir)
        validator.check_not_soft_link()
        try:
            validator.check()
        except ValueError as e:
            raise ValueError(f"save_dir:{save_dir} can't be soft link") from e
        sink = os.fdopen(os.open(target, SAVE_FILE_FLAG, SAVE_FILE_MODE), "w")
    with sink as f:
        json.dump(delta_export_time_ms, f, indent=4)
def get_model_type_by_version(save_dir: str, model_version: str):
    """
    Return the ``type`` recorded for ``model_version`` in ``<save_dir>/model_index.json``.
    Args:
        save_dir: directory holding ``model_index.json``.
        model_version: global step of the model, as a string.
    Returns:
        The ``type`` of the first index entry whose global step equals
        ``int(model_version)``, or None if there is no such entry.
    Raises:
        ValueError: if ``model_version`` cannot be converted to an integer.
    """
    model_index_file = os.path.join(save_dir, "model_index.json")
    validate_read_file(model_index_file)
    with tf.io.gfile.GFile(model_index_file, "r") as f:
        model_index_list = json.load(f)
    # Convert once, before scanning: previously the conversion ran per entry (and never
    # ran for an empty index), and the message was passed as a %-style tuple that was
    # never formatted.
    try:
        model_version_int = int(model_version)
    except ValueError as err:
        raise ValueError(f"Can not transfer {model_version} to integer.") from err
    for model_index in model_index_list:
        if model_index[GLOBAL_STEP_STR] == model_version_int:
            return model_index["type"]
    return None
def get_base_and_delta_models(save_dir: str, model_version: str):
    """
    Resolve the base model and the delta chain needed to restore ``model_version``.

    ``model_index.json`` is scanned from its last entry backwards. Once the entry whose
    global step equals ``int(model_version)`` is hit, the preceding delta entries are
    collected until the first non-delta entry, which becomes the base model. The
    requested version itself is always placed in the delta list, whatever its type.
    Args:
        save_dir: directory holding ``model_index.json``.
        model_version: global step of the requested model, as a string.
    Returns:
        (base_model, delta_models): the base step as a string ("" if none precedes the
        requested entry) and the delta steps as strings, in index-file order. Both are
        empty when ``model_version`` is not found.
    """
    index_path = os.path.join(save_dir, "model_index.json")
    validate_read_file(index_path)
    with tf.io.gfile.GFile(index_path, "r") as reader:
        entries = json.load(reader)
    entries.reverse()  # walk from the last entry backwards

    base_step = ""
    chain = []
    target_reached = False
    for entry in entries:
        step = entry[GLOBAL_STEP_STR]
        if step == int(model_version):
            chain.append(model_version)
            target_reached = True
        elif target_reached:
            if entry["type"] != DELTA_MODEL:
                base_step = str(step)
                break
            chain.append(str(step))
    chain.reverse()
    return base_step, chain
def read_base_delta_and_write(save_dir: str, base_model: str, delta_models: list):
    """
    Merge a base sparse model with its delta models and write the result to a temp checkpoint.

    Base keys/embeddings (plus optimizer state when an optimizer is configured) are loaded,
    then each delta model is applied in ``delta_models`` order. Rows of keys already present
    are overwritten, and unseen keys are appended. The merged tables are written to
    ``<save_dir>/tmp-sparse-model.ckpt-<delta_models[-1]>``.
    Args:
        save_dir: directory containing the base and delta sparse checkpoints.
        base_model: global step of the base model.
        delta_models: global steps of the delta models, in application order (as returned by
            get_base_and_delta_models). Must be non-empty: the last entry names the output.
    Returns:
        Path of the written temporary checkpoint.
    """
    config = ConfigInitializer.get_instance()
    table_name_set = config.sparse_embed_config.table_name_set
    optimizer = config.optimizer_config.optimizer_instance
    optimizer_param_name_list = []
    if optimizer:
        optimizer_param_name_list = [
            f"{optimizer.optimizer_type}_{optim_param}" for optim_param in optimizer.optim_param_list
        ]
    base_optimizer = get_base_optimizer(save_dir, table_name_set, base_model) if optimizer else None
    base_table = get_base_key_embedding(save_dir, table_name_set, base_model)

    for delta_model in delta_models:
        delta_model_path = os.path.join(save_dir, f"{SAVE_DELTA_SPARSE_PATH_PREFIX}-model.ckpt-{delta_model}")
        for table_name in table_name_set:
            delta_keys, delta_embedding = get_table_key_emb(delta_model_path, table_name)
            delta_params = {
                name: get_table_optimizer_param(delta_model_path, table_name, name)
                for name in optimizer_param_name_list
            }
            # base_optimizer only needs to be consulted when optimizer state is being merged.
            table_optimizer = base_optimizer[table_name] if delta_params else {}
            _merge_delta_table(base_table[table_name], table_optimizer,
                               delta_keys, delta_embedding, delta_params)

    tmp_path = f"{save_dir}/tmp-{SAVE_SPARSE_PATH_PREFIX}-model.ckpt-{delta_models[-1]}"
    write_base_table_to_file(tmp_path, base_table)
    if optimizer:
        write_base_table_to_file(tmp_path, base_optimizer)
    return tmp_path


def _merge_delta_table(table: dict, table_optimizer: dict, delta_keys, delta_embedding,
                       delta_params: dict) -> None:
    """
    Apply one delta table to ``table`` (and ``table_optimizer``) in place.

    Existing keys get their embedding row and optimizer rows overwritten. New keys are
    appended in first-seen order, and a key repeated inside the delta keeps its last value.
    A key->row index replaces the former per-key ``in``/``np.where`` scans, and the appends
    are batched instead of copying the whole array once per new key.
    """
    base_keys, base_embedding = table["key"], table["embedding"]
    base_count = len(base_keys)
    # setdefault keeps the first occurrence, matching np.where(...)[0][0].
    key_to_row = {}
    for row, key in enumerate(base_keys.tolist()):
        key_to_row.setdefault(key, row)

    new_keys, new_embedding_rows = [], []
    new_param_rows = {name: [] for name in delta_params}
    for i, key in enumerate(delta_keys.tolist()):
        row = key_to_row.get(key)
        if row is None:
            key_to_row[key] = base_count + len(new_keys)
            new_keys.append(key)
            new_embedding_rows.append(delta_embedding[i])
            for name, values in delta_params.items():
                new_param_rows[name].append(values[i])
        elif row < base_count:
            base_embedding[row] = delta_embedding[i]
            for name, values in delta_params.items():
                table_optimizer[name][row] = values[i]
        else:
            # Key already appended earlier in this delta: overwrite the pending row.
            pending = row - base_count
            new_embedding_rows[pending] = delta_embedding[i]
            for name, values in delta_params.items():
                new_param_rows[name][pending] = values[i]

    if not new_keys:
        return
    table["key"] = np.concatenate([base_keys, np.asarray(new_keys, dtype=base_keys.dtype)])
    table["embedding"] = np.vstack([base_embedding, *new_embedding_rows])
    for name, rows in new_param_rows.items():
        table_optimizer[name] = np.vstack([table_optimizer[name], *rows])
def get_table_key_emb(model_path: str, table_name: str):
    """
    Load the keys and embeddings of ``table_name`` from a saved sparse model.
    Args:
        model_path: checkpoint directory containing one sub-directory per table.
        table_name: name of the table to load.
    Returns:
        (key_data, embedding_data): the keys decoded as int64 by read_attribute_file, and
        the float32 embeddings reshaped to ``embedding/slice.attribute`` minus its last
        element (the element byte width, as written by write_base_table_to_file).
    """
    table_dir = os.path.join(model_path, table_name)
    keys = read_attribute_file(os.path.join(table_dir, "key", "slice.data"))
    emb_dir = os.path.join(table_dir, "embedding")
    emb_shape = read_attribute_file(os.path.join(emb_dir, "slice.attribute"))[:-1]
    embeddings = read_data_file(os.path.join(emb_dir, "slice.data")).reshape(emb_shape)
    return keys, embeddings
def write_base_table_to_file(save_dir: str, base_table: dict):
    """
    Write every array of ``base_table`` under ``<save_dir>/<table_name>/<array_name>/``.

    Each array produces ``slice.data`` (its raw bytes) and ``slice.attribute`` (an integer
    array holding the shape followed by the element byte width: 8 for ``key``, 4 otherwise).
    Args:
        save_dir: output directory (local or HDFS).
        base_table: ``{table_name: {array_name: np.ndarray}}``.
    Returns: None
    Raises:
        ValueError: if the local ``save_dir`` fails directory validation.
        RuntimeError: if an output directory cannot be created.
    """
    def _write_payload(path: str, payload: bytes, is_hdfs: bool) -> None:
        # Single write path for both file systems (was duplicated per branch).
        if is_hdfs:
            with tf.io.gfile.GFile(path, "wb") as file:
                file.write(payload)
        else:
            with os.fdopen(os.open(path, SAVE_FILE_FLAG, SAVE_FILE_MODE), "wb") as file:
                file.write(payload)

    if not check_file_system_is_hdfs(save_dir):
        dir_validator = DirectoryValidator("save_dir", save_dir)
        dir_validator.check_not_soft_link()
        try:
            dir_validator.check()
        except ValueError as e:
            raise ValueError(f"save_dir:{save_dir} can't be soft link") from e
    for table_name, table in base_table.items():
        for k, v in table.items():
            writing_path = os.path.join(save_dir, table_name, k)
            is_hdfs = check_file_system_is_hdfs(writing_path)
            try:
                if is_hdfs:
                    tf.io.gfile.makedirs(writing_path)
                else:
                    os.makedirs(writing_path, SAVE_DIR_MODE, exist_ok=True)
            except Exception as err:
                raise RuntimeError(f"Create dir {writing_path} for writing data failed!") from err
            write_bytes = 8 if k == "key" else 4
            attribute = np.append(v.shape, write_bytes)
            # tobytes() is the byte-identical, non-deprecated replacement for tostring().
            _write_payload(os.path.join(writing_path, "slice.attribute"), attribute.tobytes(), is_hdfs)
            _write_payload(os.path.join(writing_path, "slice.data"), v.tobytes(), is_hdfs)
def clear_delta_models(save_dir: str):
    """
    Remove every ``delta-sparse*`` directory under ``save_dir``.

    Matching uses ``tf.io.gfile.glob`` so HDFS save dirs are searched too; the builtin
    ``glob`` module only sees the local file system, which left HDFS delta models in place.
    Local matches must not be soft links. Directories that disappear before removal are
    logged and skipped.
    Args:
        save_dir: model save directory (local or HDFS).
    Returns: None
    """
    pattern = os.path.join(save_dir, f"{SAVE_DELTA_SPARSE_PATH_PREFIX}*")
    for delta_dir in tf.io.gfile.glob(pattern):
        file_validator = FileValidator("delta_dir", delta_dir)
        if not check_file_system_is_hdfs(delta_dir):
            file_validator.check_not_soft_link()
        file_validator.check()
        try:
            tf.io.gfile.rmtree(delta_dir)
        except tf.errors.NotFoundError:
            logger.warning("%s is not exists, maybe it has been deleted.", delta_dir)
def get_table_optimizer_param(model_path: str, table_name: str, optimizer_param_name: str):
    """
    Load one optimizer state array of ``table_name`` from a saved sparse model.

    The float32 data in ``<model_path>/<table_name>/<optimizer_param_name>/slice.data`` is
    reshaped to the shape stored in the sibling ``slice.attribute`` (its last element,
    the element byte width, is dropped).
    """
    param_dir = os.path.join(model_path, table_name, optimizer_param_name)
    attribute_path = os.path.join(param_dir, "slice.attribute")
    data_path = os.path.join(param_dir, "slice.data")
    shape = read_attribute_file(attribute_path)[:-1]
    return read_data_file(data_path).reshape(shape)
def read_attribute_file(target_attribute_dir: str):
    """
    Read a binary file and decode its whole content as an int64 array.
    Args:
        target_attribute_dir: path of the file to read (local or HDFS).
    Returns:
        A writable 1-D int64 numpy array.
    Raises:
        RuntimeError: if the file size is not a multiple of 8 bytes.
    """
    validate_read_file(target_attribute_dir)
    with tf.io.gfile.GFile(target_attribute_dir, "rb") as fin:
        raw = fin.read()
    try:
        # Binary-mode np.fromstring is deprecated. frombuffer returns a read-only view,
        # so copy() keeps the writable result callers previously got.
        attributes = np.frombuffer(raw, dtype=np.int64).copy()
    except ValueError as err:
        raise RuntimeError(f"get attributes from file {target_attribute_dir} failed.") from err
    return attributes
def read_data_file(target_data_dir: str):
    """
    Read a binary file as a flat float32 array.

    HDFS files are read through ``tf.io.gfile``; local files are read with ``np.fromfile``.
    Args:
        target_data_dir: path of the file to read.
    Returns:
        A writable 1-D float32 numpy array.
    Raises:
        ValueError: if an HDFS file's size is not a multiple of 4 bytes.
    """
    validate_read_file(target_data_dir)
    if not check_file_system_is_hdfs(target_data_dir):
        return np.fromfile(target_data_dir, dtype=np.float32)
    with tf.io.gfile.GFile(target_data_dir, "rb") as file:
        raw = file.read()
    # Binary-mode np.fromstring is deprecated. copy() keeps the array writable, which
    # matters because callers such as read_base_delta_and_write assign rows in place.
    return np.frombuffer(raw, dtype=np.float32).copy()
def get_base_optimizer(save_dir: str, table_name_set: set, base_model: str):
    """
    Load the optimizer state of every table from the base sparse model.
    Args:
        save_dir: directory containing the base checkpoint.
        table_name_set: names of the tables to load.
        base_model: global step of the base model.
    Returns:
        ``{table_name: {"<optimizer_type>_<param>": np.ndarray}}``, or an empty dict when
        the configured optimizer has an empty ``optim_param_list``.
    """
    optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
    optimizer_type = optimizer.optimizer_type
    optim_param_list = optimizer.optim_param_list
    if not optim_param_list:
        return {}
    optimizer_status_name_list = [f"{optimizer_type}_{optim_param}" for optim_param in optim_param_list]
    base_model_path = os.path.join(save_dir, f"{SAVE_SPARSE_PATH_PREFIX}-model.ckpt-{base_model}")
    # Built directly. The former None-placeholder dict kept only the last status name
    # per table and was overwritten entry by entry anyway.
    return {
        table_name: {
            status_name: get_table_optimizer_param(base_model_path, table_name, status_name)
            for status_name in optimizer_status_name_list
        }
        for table_name in table_name_set
    }
def get_base_key_embedding(save_dir: str, table_name_set: set, base_model: str):
    """
    Load keys and embeddings of every table from the base sparse model.
    Returns:
        ``{table_name: {"key": keys, "embedding": embeddings}}`` as returned by get_table_key_emb.
    """
    model_dir = os.path.join(save_dir, f"{SAVE_SPARSE_PATH_PREFIX}-model.ckpt-{base_model}")
    tables = {}
    for name in table_name_set:
        keys, embeddings = get_table_key_emb(model_dir, name)
        tables[name] = {"key": keys, "embedding": embeddings}
    return tables
def should_save_sparse_embedding(is_dp: bool, save_path: str) -> bool:
    """
    Whether the current card must save its sparse embeddings.

    Without DP every card saves. With DP, a card saves when its rank id is a multiple of
    the global rank size (HDFS) or of the local rank size (local file system).
    Args:
        is_dp: switch whether to enable dp.
        save_path: model save path.
    Returns:
        bool: whether to save.
    """
    if not is_dp:
        return True
    on_hdfs = check_file_system_is_hdfs(save_path)
    rank_id = get_rank_id()
    group_size = get_rank_size() if on_hdfs else get_local_rank_size()
    return bool(rank_id % group_size == 0)
def read_base_delta_and_write_for_ssd(save_dir: str, base_model: str, delta_models: List[str], rank: int) -> None:
    """
    Merge SSD base and delta models of this rank and write the merged table files.

    For each table, the base model's records are loaded first and then every delta model's
    records in order; later records replace earlier ones with the same key. The result is
    written with the first file ID of the last model read and the step ``delta_models[-1]``.
    :param save_dir: model save dir
    :param base_model: full model step
    :param delta_models: incremental models (must be non-empty)
    :param rank: process id
    :return: None
    """
    current_ssd_dir = os.path.join(os.path.dirname(save_dir), SSD_SAVE_PATH_PREFIX + str(rank))
    validator = FileValidator("current_ssd_dir", current_ssd_dir)
    if not check_file_system_is_hdfs(current_ssd_dir):
        validator.check_not_soft_link()
    validator.check()

    models_in_order = [(base_model, False)] + [(delta, True) for delta in delta_models]
    for table_name in ConfigInitializer.get_instance().sparse_embed_config.table_name_set:
        table_dir = os.path.join(current_ssd_dir, table_name)
        merged = {}
        file_ids = []
        for model, is_delta in models_in_order:
            file_ids = _read_table_meta_data(current_ssd_dir, table_name, model)
            for fid in file_ids:
                _read_key_offset_and_embedding(table_dir, model, fid, is_delta, merged)
        _write_ssd_meta_and_data(current_ssd_dir, table_name, file_ids[0], delta_models[-1], merged)
def _read_table_meta_data(current_ssd_dir: str, table_name: str, model: str) -> List[int]:
    """
    Read the file IDs listed in a table's SSD meta file.

    ``<current_ssd_dir>/<table_name>/<table_name>.meta.<model>`` is parsed with native
    struct layout: uint32 name length, the name bytes (skipped), uint64 file count, then
    that many uint64 file IDs.
    :param current_ssd_dir: ssd model saved dir
    :param table_name: table name
    :param model: step for saving model
    :return: the file IDs in file order
    :raises EOFError: if the file ends before a field is fully read
    """
    table_meta_file = os.path.join(current_ssd_dir, table_name, table_name + ".meta." + model)
    with tf.io.gfile.GFile(table_meta_file, 'rb') as file:
        validate_read_file(table_meta_file)

        def take(size: int, eof_message: str) -> bytes:
            # Fail loudly on a short read instead of unpacking a truncated field.
            chunk = file.read(size)
            if len(chunk) < size:
                raise EOFError(eof_message)
            return chunk

        name_size, = struct.unpack(
            'I', take(UINT32_BYTES, "End of file reached before reading name size, file maybe broken."))
        take(name_size, "End of file reached before reading name, file maybe broken.")
        file_cnt, = struct.unpack(
            'Q', take(UINT64_BYTES, "End of file reached before reading file count, file maybe broken."))
        return [
            struct.unpack(
                'Q', take(UINT64_BYTES, "End of file reached before reading all file IDs, file maybe broken."))[0]
            for _ in range(file_cnt)
        ]
def _read_key_offset_and_embedding(current_dir: str, model: str, fid: int, is_delta: bool, key_info_map: dict) -> None:
    """
    Load one SSD file pair into ``key_info_map``, replacing entries for keys already present.

    Delta models use the ``delta-`` file-name prefix. Meta and data records are paired in
    order; ``zip`` stops at the shorter stream (NOTE(review): a count mismatch between the
    two files is not reported).
    :param current_dir: save dir
    :param model: step for saving model
    :param fid: file ID for SSD
    :param is_delta: the model is whether full or incremental model
    :param key_info_map: key -> KeyInfo(offset, emb_size, embedding), updated in place
    :return: None
    """
    prefix = "delta-" if is_delta else ""
    meta_path = os.path.join(current_dir, prefix + str(fid) + ".meta." + model)
    data_path = os.path.join(current_dir, prefix + str(fid) + ".data." + model)
    records = zip(_read_key_offset(meta_path), _read_embedding_data(data_path))
    key_info_map.update(
        (key, KeyInfo(offset=offset, emb_size=emb_size, embedding=embedding))
        for (key, offset), (emb_size, embedding) in records
    )
def _read_key_offset(file_path: str) -> Generator[Tuple[int, int], None, None]:
    """
    Yield ``(key, offset)`` pairs from an SSD meta file.

    Each record is a signed 64-bit key followed by an unsigned 32-bit offset (native struct
    layout). A file whose size equals ``SSD_DATA_FILE_MIN_SIZE`` yields nothing.
    :param file_path: meta file dir
    :return: generator of (key, offset)
    :raises EOFError: if the file ends inside a record
    """
    record_size = UINT64_BYTES + UINT32_BYTES
    with tf.io.gfile.GFile(file_path, 'rb') as file:
        if tf.io.gfile.stat(file_path).length != SSD_DATA_FILE_MIN_SIZE:
            validate_read_file(file_path)
            record = file.read(record_size)
            while record:
                if len(record) < record_size:
                    raise EOFError("End of file reached before reading key_offset, meta file maybe broken.")
                key, = struct.unpack('q', record[:UINT64_BYTES])
                offset, = struct.unpack('I', record[UINT64_BYTES:record_size])
                yield key, offset
                record = file.read(record_size)
def _read_embedding_data(file_path: str) -> Generator[Tuple[int, List[float]], None, None]:
    """
    Yield ``(emb_size, embedding)`` records from an SSD data file.

    Each record is an unsigned 64-bit ``emb_size`` followed by ``emb_size`` C floats
    (native struct layout). A file whose size equals ``SSD_DATA_FILE_MIN_SIZE`` yields nothing.
    :param file_path: data file dir
    :return: generator of (emb_size, embedding as list of floats)
    :raises EOFError: if the file ends inside a record
    """
    with tf.io.gfile.GFile(file_path, 'rb') as file:
        if tf.io.gfile.stat(file_path).length == SSD_DATA_FILE_MIN_SIZE:
            return
        validate_read_file(file_path)
        while True:
            emb_size_data = file.read(UINT64_BYTES)
            if len(emb_size_data) == 0:
                break  # clean end of file between records
            if len(emb_size_data) < UINT64_BYTES:
                raise EOFError("End of file reached before reading embedding size, data file maybe broken.")
            emb_size, = struct.unpack('Q', emb_size_data)
            emb_bytes = emb_size * FLOAT32_BYTES
            embeddings_data = file.read(emb_bytes)
            # The former `embeddings_data == 0` test compared bytes with an int and was always
            # False. Truncation is caught by the length check below, and a zero-size record
            # (empty read) is legitimate, so no extra break is needed.
            if len(embeddings_data) < emb_bytes:
                raise EOFError("End of file reached before reading embedding file, data file maybe broken.")
            embedding = list(struct.unpack(f'{emb_size}f', embeddings_data))
            yield emb_size, embedding
def _write_ssd_meta_and_data(current_ssd_dir: str, table_name: str, fid: int, step: str, key_info_map: dict) -> None:
    """
    Write merged SSD records to ``<fid>.meta.<step>`` and ``<fid>.data.<step>``.

    Per key, the meta file receives the key and offset packed as ``'qI'``, and the data file
    receives ``emb_size`` packed as ``'q'`` followed by the ``emb_size`` floats.
    :param current_ssd_dir: current dir
    :param table_name: table name
    :param fid: file ID
    :param step: the step for saving model
    :param key_info_map: key -> KeyInfo(offset, emb_size, embedding)
    :return: None
    """
    table_dir = os.path.join(current_ssd_dir, table_name)
    meta_file_path = os.path.join(table_dir, str(fid) + ".meta." + step)
    data_file_path = os.path.join(table_dir, str(fid) + ".data." + step)

    def dump_records(meta_out, data_out) -> None:
        for key, info in key_info_map.items():
            meta_out.write(struct.pack('qI', key, info.offset))
            data_out.write(struct.pack('q', info.emb_size))
            data_out.write(struct.pack(f'{info.emb_size}f', *info.embedding))

    if check_file_system_is_hdfs(meta_file_path) and check_file_system_is_hdfs(data_file_path):
        with tf.io.gfile.GFile(meta_file_path, "wb") as meta_out, \
                tf.io.gfile.GFile(data_file_path, "wb") as data_out:
            dump_records(meta_out, data_out)
        return
    with os.fdopen(os.open(meta_file_path, SAVE_FILE_FLAG, SAVE_FILE_MODE), "wb") as meta_out, \
            os.fdopen(os.open(data_file_path, SAVE_FILE_FLAG, SAVE_FILE_MODE), "wb") as data_out:
        dump_records(meta_out, data_out)