from dataclasses import dataclass
from enum import Enum
from typing import Union, Optional, List
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.init_ops import Initializer as InitializerV1
from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
# --- General identifiers --------------------------------------------------
MXREC = "mxrec"
GRAPH_DEF = "graph_def"

# --- String keys for rank-table / fused-op / all2all options ---------------
# NOTE(review): these look like option/config key names; confirm with callers.
USE_RANKTABLE = "use_ranktable"
USE_FUSION_OP = "use_fusion_op"
USE_LCCL_ALL2ALL_OP = "use_lccl_all2all_op"
FUSION_OP_TYPE = "fusion_op_type"
ALL2ALL_OP_TYPE = "all2all_op_type"

# --- Collection names ------------------------------------------------------
# Presumably TensorFlow graph-collection names — TODO confirm.
EMBEDDING_TABLE_COLLECTION = "EMBEDDING_TABLE_COLLECTION"
INIT_HASHTABLE_COLLECTION = "INIT_HASHTABLE_COLLECTION"
LOCAL_EMBEDDING_COLLECTION = "LOCAL_EMBEDDING_COLLECTION"
class EmbDistributionStrategy(Enum):
    """Distribution strategies selectable for an embedding table.

    Member values are the plain strings "MP" and "DP".
    """

    # Presumably "model parallel" (cf. MPLookupParams) — TODO confirm.
    MP = "MP"
    # Presumably "data parallel" (cf. DPLookupParams) — TODO confirm.
    DP = "DP"
@dataclass
class EmbTableConfig:
    """Configuration record describing a single embedding table.

    Only ``name`` and ``dim`` are required; all other fields have defaults.
    Field order defines the positional order of the generated ``__init__``.
    """

    # Table name.
    name: str
    # Embedding dimension.
    dim: int
    # Per-device vocabulary size. NOTE(review): the meaning of the default 0
    # is not visible here — confirm with the consumer of this config.
    dev_vocab_size: int = 0
    # Value initializer (TF1- or TF2-style). The default instance is built
    # once when the class is defined and is shared by every config that does
    # not pass its own initializer.
    initializer: Union[InitializerV1, InitializerV2] = (
        tf.compat.v1.random_normal_initializer()
    )
    # Dtype of the lookup keys.
    key_dtype: tf.DType = tf.int64
    # Dtype of the stored embedding values.
    value_dtype: tf.DType = tf.float32
    # Distribution strategy stored as its string value ("MP" by default),
    # not as an EmbDistributionStrategy member.
    dist_strategy: str = EmbDistributionStrategy.MP.value
@dataclass
class StaticEmbTableConfig(EmbTableConfig):
    """``EmbTableConfig`` plus two optional integer thresholds.

    Both extra fields default to ``None``. NOTE(review): their exact policy
    semantics are not visible here — confirm with the code that reads them.
    """

    # Presumably a minimum number of uses for a key — TODO confirm.
    min_used_times: Optional[int] = None
    # Presumably a "cold" (idle) time limit in seconds — TODO confirm.
    max_cold_secs: Optional[int] = None
@dataclass
class MPLookupParams:
    """Bundle of tensors presumably used by the "MP" lookup path.

    ``local_ids`` and ``local_ids_restore`` are required; the remaining
    fields are optional and default to ``None``.
    """

    # Required tensors.
    local_ids: tf.Tensor
    local_ids_restore: tf.Tensor
    # Optional tensor of indices (name suggests a sort permutation — TODO confirm).
    sorted_ids_indices: Optional[tf.Tensor] = None
    # Optional nested list of tensors; presumably per-rank send counts — TODO confirm.
    send_count_matrix: Optional[List[List[tf.Tensor]]] = None
@dataclass
class DPLookupParams:
    """Bundle of tensors presumably used by the "DP" lookup path.

    Only ``global_unique_ids`` is required; every other field is an optional
    tensor defaulting to ``None``.
    """

    # Required tensor.
    global_unique_ids: tf.Tensor
    # Optional count tensors (names suggest unique-id counts — TODO confirm).
    max_uni_ids_count: Optional[tf.Tensor] = None
    local_uni_ids_count: Optional[tf.Tensor] = None
    # Optional tensors whose roles are not visible here — TODO confirm.
    ids_sc_all: Optional[tf.Tensor] = None
    rank_idx: Optional[tf.Tensor] = None
    # Optional local/global unique-id tensors and their index tensors.
    local_unique_ids: Optional[tf.Tensor] = None
    local_unique_idx: Optional[tf.Tensor] = None
    global_unique_idx: Optional[tf.Tensor] = None
@dataclass
class CommNodeInfo:
    """Communication settings for the chief node and the worker side.

    All fields are required and positional in the order declared below.
    """

    # Chief endpoint: IP address string, port, and device number.
    cm_chief_ip: str
    cm_chief_port: int
    cm_chief_device: int
    # Worker IP address string and worker count.
    cm_worker_ip: str
    cm_worker_size: int
class CommParams(Enum):
    """String keys and integer bounds for communication-node settings.

    NOTE(review): Enum members with equal values become aliases of the first
    member defined with that value:
      * MIN_CM_WORKER_SIZE is an alias of MIN_CM_CHIEF_PORT (both 0);
      * MAX_LOCAL_ID is an alias of MAX_IPV4_LEN (both 15).
    ``.value`` is unaffected, but ``.name`` reports the first-defined name and
    aliases are skipped when iterating the enum. Keep member order stable.
    """
    CM_NODE_INFO = "cm-node-info"
    # Length bounds of a dotted-quad IPv4 string: len("0.0.0.0") == 7,
    # len("255.255.255.255") == 15.
    MIN_IPV4_LEN = 7
    MAX_IPV4_LEN = 15
    # Presumably environment-variable / config key names — TODO confirm.
    CM_CHIEF_IP = "CM_CHIEF_IP"
    CM_CHIEF_PORT = "CM_CHIEF_PORT"
    MIN_CM_CHIEF_PORT = 0
    # NOTE(review): below the 65535 port maximum — confirm this cap is intentional.
    MAX_CM_CHIEF_PORT = 65520
    CM_CHIEF_DEVICE = "CM_CHIEF_DEVICE"
    CM_WORKER_IP = "CM_WORKER_IP"
    CM_WORKER_SIZE = "CM_WORKER_SIZE"
    # Alias of MIN_CM_CHIEF_PORT (same value 0).
    MIN_CM_WORKER_SIZE = 0
    MAX_CM_WORKER_SIZE = 32768
    MAX_RANK_ID = 65535
    # Alias of MAX_IPV4_LEN (same value 15).
    MAX_LOCAL_ID = 15
class LogParams(Enum):
    """The log-level option key followed by the recognised level names."""

    # Key name for the log-level setting (presumably a config/env key — TODO confirm).
    LOG_LEVEL = "log_level"

    # Level names, ordered as in the stdlib ``logging`` severity scale.
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"
class MPIParams(Enum):
    """Names of Open MPI ``OMPI_COMM_WORLD_*`` environment variables.

    Each member's value equals its name. Open MPI's ``mpirun`` exports these
    variables to every launched process.
    """

    # Total number of processes in the job.
    OMPI_COMM_WORLD_SIZE = "OMPI_COMM_WORLD_SIZE"
    # Number of processes on this node.
    OMPI_COMM_WORLD_LOCAL_SIZE = "OMPI_COMM_WORLD_LOCAL_SIZE"
    # Global rank of this process.
    OMPI_COMM_WORLD_RANK = "OMPI_COMM_WORLD_RANK"
class ValidatorParams(Enum):
    """Numeric limits used for validating values, string lengths and files.

    NOTE(review): Enum members with equal values become aliases of the first
    member defined with that value:
      * MIN_UINT64 and STR_MIN_LENGTH are aliases of MIN_UINT32 (all 0);
      * STR_MAX_LENGTH is an alias of MAX_INT32.
    ``.value`` is unaffected, but ``.name`` reports the first-defined name and
    aliases are skipped when iterating the enum. Keep member order stable.
    """
    # Integer bounds taken from numpy's iinfo (Python ints).
    MAX_INT32 = np.iinfo(np.int32).max
    MIN_INT32 = np.iinfo(np.int32).min
    MAX_UINT32 = np.iinfo(np.uint32).max
    MIN_UINT32 = np.iinfo(np.uint32).min
    MAX_INT64 = np.iinfo(np.int64).max
    MIN_INT64 = np.iinfo(np.int64).min
    MAX_UINT64 = np.iinfo(np.uint64).max
    # Alias of MIN_UINT32 (value 0).
    MIN_UINT64 = np.iinfo(np.uint64).min
    # Largest finite float32 and its negation (finfo.min == -finfo.max).
    MAX_FLOAT32 = np.finfo(np.float32).max
    MIN_FLOAT32 = np.finfo(np.float32).min
    # In the Enum class body this name resolves to MAX_INT32's raw value,
    # so this member is an alias of MAX_INT32.
    STR_MAX_LENGTH = MAX_INT32
    # Alias of MIN_UINT32 (value 0).
    STR_MIN_LENGTH = 0
    # File size bounds; 1024**4 == 1 TiB if the unit is bytes (presumably — confirm).
    FILE_MIN_SIZE = 1
    FILE_MAX_SIZE = 1024 * 1024 * 1024 * 1024
    MAX_FILE_PATH_LENGTH = 1024
class NumCheckValueMethod(Enum):
    """Names of numeric range-check routines, one per interval kind.

    Values are routine-name strings; presumably they are looked up by name on
    a validator object — TODO confirm. The interval notes below follow the
    member names.
    """

    # Default check (presumably closed interval [min, max]).
    DEFAULT = "check_value"
    # (min, max)
    OPEN_INTERVAL = "check_value_for_open_interval"
    # (min, max]
    LEFT_OPEN_INTERVAL = "check_value_for_left_open_interval"
    # [min, max)
    RIGHT_OPEN_INTERVAL = "check_value_for_right_open_interval"