import numbers
from typing import Union, Optional, List

import tensorflow as tf
from rec_sdk_common.communication.hccl.hccl_info import get_rank_size
from rec_sdk_common.constants.constants import ValidatorParams
from mx_rec.constants.constants import MAX_VOCABULARY_SIZE, MULTI_LOOKUP_TIMES
from mx_rec.core.asc.feature_spec import FeatureSpec
from mx_rec.util.initialize import ConfigInitializer
def check_emb_init_params(is_hbm: bool, embedding_size: tf.TensorShape):
    """
    Validate the initialization parameters of a sparse table.

    Args:
        is_hbm: Whether the table is created in HBM mode.
        embedding_size: The sparse table's embedding dimension as a TensorShape; it must be one-dimensional.

    Returns: None

    Raises:
        EnvironmentError: If the hybrid manager config is already frozen, so no new table may be built.
        ValueError: If dynamic expansion is enabled for a non-HBM (DDR/SSD) table, or if
            `embedding_size` is not a one-dimensional shape.
    """
    config = ConfigInitializer.get_instance()

    # Once emb cache management has been frozen, new tables can no longer be registered.
    if config.hybrid_manager_config.freeze:
        raise EnvironmentError("Emb cache management has been established, you cannot build new hash table.")

    # Dynamic expansion is only supported for HBM tables; `use_dynamic_expansion` is only read for non-HBM ones.
    is_ddr_or_ssd = not is_hbm
    if is_ddr_or_ssd and config.use_dynamic_expansion:
        raise ValueError("DDR/SSD mode do not support embedding dynamic expansion for now.")

    rank = embedding_size.ndims
    if rank != 1:
        raise ValueError("Parameter 'embedding_size' can only be one dim shape.")
def check_emb_lookup_params(
    table_params: dict, feature_spec: Union[tf.Tensor, FeatureSpec], send_count: Optional[int], is_training: bool
):
    """
    Validate the parameters of one lookup on a sparse table.

    Args:
        table_params: Dictionary of the sparse table's parameters. The keys read here are
            "slice_device_vocabulary_size", "slice_host_vocabulary_size", "table_name" and "is_hbm".
        feature_spec: The tensor looked up in the sparse table this time, or its FeatureSpec wrapper.
        send_count: The all2all communication parameter.
        is_training: Whether the current pipeline is training (True) or inference (False).

    Returns: None

    Raises:
        RuntimeError: If `feature_spec` is a FeatureSpec that is not initialized, or that has not been
            configured for the given `is_training` mode.
        ValueError: If the host vocabulary slice exceeds MAX_VOCABULARY_SIZE; or, in static-shape mode,
            if `send_count` is not a positive int; or, when neither HBM nor dynamic expansion is used,
            if either vocabulary slice is smaller than `send_count * rank_size`.
    """
    spec_is_wrapper = isinstance(feature_spec, FeatureSpec)
    if spec_is_wrapper and not feature_spec.initialized:
        raise RuntimeError("Feature Spec has not been initialized.")
    if spec_is_wrapper and is_training not in feature_spec.pipeline_mode:
        raise RuntimeError(
            f"You have not config feature for is training mode '{is_training}', please config "
            f"feature with func sparse_lookup at first."
        )

    device_slice = table_params.get("slice_device_vocabulary_size")
    host_slice = table_params.get("slice_host_vocabulary_size")
    table_name = table_params.get("table_name")

    # The host-side slice may never exceed the global vocabulary limit.
    if host_slice > MAX_VOCABULARY_SIZE:
        raise ValueError(
            f"Given host_vocabulary_size was too big for table "
            f"'{table_name}', in which slice_device_vocabulary_size was "
            f"{device_slice} and slice_host_vocabulary_size was "
            f"{host_slice}."
        )

    config = ConfigInitializer.get_instance()
    # Every remaining check only applies in static-shape mode.
    if not config.use_static:
        return

    if not isinstance(send_count, int) or send_count <= 0:
        raise ValueError("Send count must be a integer which is larger than 0.")

    # HBM tables and dynamic expansion skip the capacity checks below.
    if table_params.get("is_hbm") or config.use_dynamic_expansion:
        return

    rank_size = get_rank_size()
    min_slice_size = send_count * rank_size
    if device_slice < min_slice_size:
        raise ValueError(
            f"Given device_vocabulary_size was too small for table '{table_name}', "
            f"in which slice_device_vocabulary_size was {device_slice} "
            f"and it must be bigger than send_count({send_count}) * rank_size({rank_size}): "
            f"{min_slice_size}, please increase [device vocabSize] in [create_table] interface."
        )
    if host_slice < min_slice_size:
        raise ValueError(
            f"Given host_vocabulary_size was too small for table '{table_name}', "
            f"in which slice_host_vocabulary_size was {host_slice} "
            f"and it must be bigger than send_count({send_count}) * rank_size({rank_size}): "
            f"{min_slice_size}, please increase [host vocabSize] in [create_table] interface."
        )
def check_emb_multi_lookup_times(lookup_times: int, table_name: str):
    """
    Validate how many times a single sparse table is looked up ("one table, multiple lookups").

    Args:
        lookup_times: Number of lookups performed on the sparse table.
        table_name: Name of the sparse table.

    Returns: None

    Raises:
        RuntimeError: If `lookup_times` exceeds MULTI_LOOKUP_TIMES.
    """
    if lookup_times > MULTI_LOOKUP_TIMES:
        # MULTI_LOOKUP_TIMES is the upper bound, so report it as the maximum rather than the
        # current count (the previous wording stated the limit as if it were the actual count).
        raise RuntimeError(
            f"The maximum number of multiple sparse lookups for a table ({table_name}) is "
            f"{MULTI_LOOKUP_TIMES}, but current times is {lookup_times}."
        )
def check_and_format_emb_padding_keys(
    padding_keys: Optional[Union[int, List[int]]],
    padding_keys_mask: bool,
    padding_keys_len: Optional[int],
) -> List[int]:
    """
    Check and format the padding keys parameters.

    Args:
        padding_keys: None, a single integer key, or a non-empty list of integer keys.
            Upper-layer services must ensure that the padding keys are included in the incoming ids.
        padding_keys_mask: Whether the embedding value corresponding to the padding key is updated;
            `True` indicates that the embedding value is not updated. It must be False when
            `padding_keys` is None and True otherwise.
        padding_keys_len: Indicates the feature length of the corresponding ids. This parameter is mandatory
            if padding keys are specified.

    Returns:
        The padding keys after formatting: `[]` when `padding_keys` is None, otherwise a list of the
        keys with duplicates removed (first-occurrence order preserved).

    Raises:
        TypeError: If `padding_keys` is not None/int/list, or a list element is not an integer.
        ValueError: If a key is outside the int64 range, the list length is out of range, or
            `padding_keys_mask` / `padding_keys_len` are inconsistent with `padding_keys`.
    """
    if padding_keys is None:
        if padding_keys_mask:
            raise ValueError("the padding keys mask be False when padding keys is None")
        return []

    if isinstance(padding_keys, int):
        candidate_keys = [padding_keys]
    elif isinstance(padding_keys, list):
        # Check element types before any numeric comparison: floats would otherwise pass the
        # range check silently, and strings would fail with an opaque comparison error.
        # numbers.Integral also accepts numpy integer scalars.
        for padding_key in padding_keys:
            if not isinstance(padding_key, numbers.Integral):
                raise TypeError(f"the padding keys must be integers, but got {padding_key!r}")
        if len(padding_keys) < 1 or len(padding_keys) > ValidatorParams.MAX_INT32.value:
            raise ValueError(f"the length of the padding keys should be between 1 and {ValidatorParams.MAX_INT32.value}")
        candidate_keys = padding_keys
    else:
        raise TypeError("the padding keys must be None/int/list type")

    # Every key must fit into int64.
    min_key = ValidatorParams.MIN_INT64.value
    max_key = ValidatorParams.MAX_INT64.value
    for padding_key in candidate_keys:
        if padding_key < min_key or padding_key > max_key:
            raise ValueError(f"the padding keys should be between {min_key} and {max_key}")

    if not padding_keys_mask:
        raise ValueError("Padding keys mask be True when padding keys is not None.")
    if padding_keys_len is None:
        raise ValueError("Padding keys length cannot be None when padding keys is not None.")

    # De-duplicate while keeping first-occurrence order, so the result is predictable for callers.
    return list(dict.fromkeys(candidate_keys))
def check_padding_keys_global_params():
    """
    Check the global parameters of the padding keys across all registered sparse tables.

    Returns: None

    Raises:
        RuntimeError: If a table with the padding keys mask enabled is in dp mode or has `is_grad`
            disabled, or if at least one table is registered, every registered table enables the
            padding keys mask, and static shape mode is off.
    """
    config = ConfigInitializer.get_instance()
    padding_keys_mask_list = []
    for table_instance in config.sparse_embed_config.table_instance_dict.values():
        if table_instance.padding_keys_mask and table_instance.is_dp:
            raise RuntimeError("The padding keys mode does not yet support dp mode.")
        if table_instance.padding_keys_mask and not table_instance.is_grad:
            raise RuntimeError("The padding keys mode should not be used together with the no grad mode.")
        padding_keys_mask_list.append(table_instance.padding_keys_mask)

    # Require a non-empty list: all([]) is True, which previously made this check fire
    # (demanding static shape mode) even when no table had been registered.
    if padding_keys_mask_list and all(padding_keys_mask_list) and not config.use_static:
        raise RuntimeError("When the padding keys mask of all tables is True, it should be set to static shape mode.")