import dataclasses
from typing import List, Dict, Tuple, DefaultDict, Union
from collections import defaultdict
from collections.abc import Callable
import tensorflow as tf
from tensorflow import Operation, Tensor, Graph
from tensorflow.core.framework.graph_pb2 import GraphDef
from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
from tensorflow.python.framework.errors_impl import InvalidArgumentError
from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.communication.hccl.hccl_info import get_rank_id, get_device_id
from rec_sdk_common.util.tf_adapter import npu_ops
from rec_sdk_common.validator.validator import para_checker_decorator, ClassValidator
import mxrec_pybind
from mx_rec.core.embedding_proxy import MergeableEmbeddingTableProxy
from mx_rec.graph import utils
from mx_rec.constants.constants import (
ASCEND_CUTTING_POINT_INITIALIZER,
ASCEND_SPARSE_LOOKUP_ENTRANCE,
ASCAnchorAttr,
ASCEND_TIMESTAMP,
Flag,
MAX_WHILE_SIZE,
LIBREC_EOS_OPS_SO,
TRAIN_CHANNEL_ID,
EVAL_CHANNEL_ID,
)
from mx_rec.core.asc.feature_spec import FeatureSpec
from mx_rec.core.asc.helper import get_asc_insert_func
from mx_rec.core.asc.manager import start_asc_pipeline
from mx_rec.core.asc.swap_args import SwapArgs, SwapDataType
from mx_rec.core.asc.build_graph import SwapInfo
from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
from mx_rec.graph.merge_lookup import do_merge_lookup
from mx_rec.graph.utils import check_and_force_list, export_pb_graph
from mx_rec.graph.constants import DeprecatedOp, AnchorDatasetOp, AnchorIteratorOp
from mx_rec.util.global_env_conf import global_env
from mx_rec.util.initialize import ConfigInitializer
from mx_rec.util.ops import import_host_pipeline_ops
from mx_rec.util.perf import performance
# Handle to the host-pipeline custom ops, created once when this module is imported.
# `shm_swap` calls `rma_swap_multi_tables` on it.
host_pipeline_ops = import_host_pipeline_ops()
class GraphModifierHook(tf.estimator.SessionRunHook):
    """Estimator session hook that triggers graph modification before the session is created.

    When ``modify_graph`` is True the hook rewrites the default graph and starts the
    pipeline via ``modify_graph_and_start_emb_cache``; otherwise it only starts the
    ASC pipeline.
    """

    @para_checker_decorator(
        check_option_list=[
            ("dump_graph", ClassValidator, {"classes": (bool,)}),
            ("modify_graph", ClassValidator, {"classes": (bool,)}),
        ]
    )
    def __init__(self, dump_graph: bool = False, modify_graph: bool = True):
        """Store the options and flag in the global config that the hook is in use.

        Args:
            dump_graph: Forwarded to the graph modifier to control graph dumping.
            modify_graph: Whether the graph should be rewritten in ``begin``.
        """
        self._dump_graph = dump_graph
        self._modify_graph = modify_graph
        self._iterator_type = None
        train_params = ConfigInitializer.get_instance().train_params_config
        train_params.is_graph_modify_hook_running = True

    def begin(self):
        """Rewrite the graph (or just start the pipeline) and validate the iterator type.

        Raises:
            ValueError: If the graph was modified and the recorded iterator type is
                neither `MakeIterator` nor `OneShotIterator`.
        """
        if not self._modify_graph:
            start_asc_pipeline()
        else:
            modify_graph_and_start_emb_cache(dump_graph=self._dump_graph)

        train_params = ConfigInitializer.get_instance().train_params_config
        self._iterator_type = train_params.iterator_type

        supported_types = (
            AnchorIteratorOp.MAKE_ITERATOR.value,
            AnchorIteratorOp.ONE_SHOT_ITERATOR.value,
        )
        if self._modify_graph and self._iterator_type not in supported_types:
            raise ValueError("the value of iterator type should be like `MakeIterator` or `OneShotIterator`.")
        logger.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type)

    def after_create_session(self, session, coord):
        """Run the collected iterator initializers when an initializable iterator is used."""
        if not self._modify_graph:
            return
        if self._iterator_type != AnchorIteratorOp.MAKE_ITERATOR.value:
            return
        initializers = tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)
        session.run(initializers)
@dataclasses.dataclass
class _AnchorRecord:
    """Everything recorded about one `IteratorGetNext` op while the graph is being modified.

    Instances are built positionally in `_GraphModifier._generate_get_next_op_specs`,
    so the field order below is part of the constructor interface.
    """

    # Result of `utils.record_ops_to_replace`: old get_next output tensor ->
    # [(input index, consuming operation)]; consumed by `operator._update_input(idx, ...)`.
    replacement_spec: DefaultDict[Tensor, List[Tuple[int, Operation]]]
    # Outputs of the get_next op that are reached when walking upwards from the cutting points.
    passing_tensors: List[Tensor]
    # Output index (the part after ":" in the tensor name) of each passing tensor.
    batch_tensor_indexs: List[int]
    # Cutting points for which at least one passing tensor was found.
    sub_cutting_points: List[Tensor]
    # Sub-graph returned by `_GraphModifier._get_sub_graph` together with its
    # input/output tensor names.
    sub_graph_def: GraphDef
    input_names: List[str]
    output_names: List[str]
    # IS_TRAINING anchor attribute of the cutting point that first reached this get_next op.
    is_training: bool
    # Filled in later by `modify_graph_for_asc` with the result of `_get_input_index_list`.
    input_indexs: Union[List[int], None] = None
class _GraphModifier:
    """Rewrite a graph so that the sparse-lookup preprocessing runs inside the input dataset.

    For every `IteratorGetNext` op that feeds a sparse-lookup cutting point, the modifier
    builds a new dataset (source dataset + preprocessing sub-graph + ASC insert function),
    creates a new iterator on it and rewires the consumers of the old batch to the new one.
    """

    # The validator previously listed a `modify_graph` option although `__init__` has no
    # such parameter, and `full_graph` was not validated at all. The check now mirrors the
    # one on `modify_graph_and_start_emb_cache`.
    @para_checker_decorator(
        check_option_list=[
            ("full_graph", ClassValidator, {"classes": (Graph, type(None))}),
            ("dump_graph", ClassValidator, {"classes": (bool,)}),
        ]
    )
    def __init__(self, full_graph: Graph = None, dump_graph: bool = False):
        """Remember the graph to modify.

        Args:
            full_graph: Graph to modify; the current default graph is used when None.
            dump_graph: Passed to `export_pb_graph` and `get_asc_insert_func` to control dumping.
        """
        if full_graph is None:
            full_graph = tf.compat.v1.get_default_graph()
        self._full_graph = full_graph
        self._dump_graph = dump_graph

    @staticmethod
    def _get_preprocessing_map_func(
        graph_def: GraphDef,
        input_names: List[str],
        output_names: List[str],
        pipeline_input_indexes: List[int] = None,
    ) -> Callable:
        """Build a dataset map function that replays `graph_def` on the dataset batch.

        Args:
            graph_def: Sub-graph to import inside the map function.
            input_names: Tensor names in `graph_def` to be fed from the batch.
            output_names: Tensor names in `graph_def` to return.
            pipeline_input_indexes: Indices `i` of the map-function arguments (`args_<i>:0`)
                that are bound, in order, to `input_names`.

        Returns:
            A function mapping a batch to `batch + (tuple(outputs),)`.
        """
        input_names = check_and_force_list(input_names, str)
        output_names = check_and_force_list(output_names, str)
        pipeline_input_indexes = check_and_force_list(pipeline_input_indexes, int)

        def map_func(*args) -> tuple:
            batch = args
            # `*args` is always a tuple, so only emptiness needs to be checked.
            if not batch:
                raise ValueError(f"The dataset batch is invalid, and the batch is: {batch}.")
            logger.debug("In get_preprocessing_map_func, the parse batch is: %s.", batch)

            # Inside the map function the batch elements are looked up by the
            # placeholder names `args_<index>:0` in the current default graph.
            graph = tf.compat.v1.get_default_graph()
            input_tensors = [graph.get_tensor_by_name("args_%d:0" % index) for index in pipeline_input_indexes]

            output_list = tf.import_graph_def(
                graph_def, input_map=dict(zip(input_names, input_tensors)), return_elements=output_names
            )

            output_batch = list(batch)
            output_batch.append(tuple(output_list))
            logger.debug("In get_preprocessing_map_func, the output batch is: %s.", output_batch)
            return tuple(output_batch)

        return map_func

    @performance("graph_modifier")
    def modify_graph_for_asc(self, prefetch: int = 10):
        """Replace the dataset behind every relevant `IteratorGetNext` op and rewire the graph.

        Args:
            prefetch: Prefetch size applied to each newly built dataset.

        Raises:
            RuntimeError: In training mode when the merged-lookup flag has not been set.
        """
        cutting_point_list = self._full_graph.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
        utils.check_cutting_points(cutting_point_list)
        if not cutting_point_list:
            logger.warning("Nothing to revise.")
            return

        export_pb_graph("old_graph.pbtxt", self._dump_graph, graph_def=self._full_graph.as_graph_def())
        get_next_op_map = self._generate_get_next_op_specs(cutting_point_list)
        logger.debug(
            "In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.",
            len(get_next_op_map),
            get_next_op_map.keys(),
        )

        train_params = ConfigInitializer.get_instance().train_params_config
        # NOTE(review): loop extent reconstructed from data flow (`is_training` is bound per
        # record); confirm against the original layout.
        for get_next_op, record in get_next_op_map.items():
            is_training = record.is_training
            src_dataset = self._get_src_dataset(get_next_op, is_training)
            timestamp_index = _get_timestamp_index(self._full_graph, get_next_op, is_training)
            original_batch_tensor_count = _get_dataset_tensor_count(src_dataset)
            sub_cutting_points = record.sub_cutting_points
            record.input_indexs = _get_input_index_list(
                sub_cutting_points,
                record.replacement_spec,
                record.output_names,
                original_batch_tensor_count,
                timestamp_index=timestamp_index,
            )

            with self._full_graph.as_default():
                tgt_dataset = self._get_tgt_dataset(src_dataset, sub_cutting_points, record, prefetch=prefetch)
                self._update_iterator_getnext(get_next_op, tgt_dataset, is_training, record)

            if not is_training:
                # For eval the lookups are merged here, during graph modification.
                with self._full_graph.as_default():
                    do_merge_lookup(is_train=False)
                if "evaluate" in train_params.bool_gauge_set:
                    logger.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.")
                    train_params.insert_merged_multi_lookup(is_training, False)

            if is_training and not train_params.get_merged_multi_lookup(True):
                raise RuntimeError(
                    "In training mode, `do_merge_lookup` should have been executed in compute gradients "
                    "phase. Please check whether compute gradients is performed."
                )

        self._modify_graph_for_ddr(get_next_op_map)
        logger.info("Graph has been revised.")
        export_pb_graph("new_graph.pbtxt", self._dump_graph, graph_def=self._full_graph.as_graph_def())

    def _modify_graph_for_ddr(self, get_next_op_map: Dict[Tensor, _AnchorRecord]):
        """Replace the swap anchors once per record, using the optimizer's slot count.

        Args:
            get_next_op_map: Records produced by `_generate_get_next_op_specs`.
        """
        optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
        if optimizer_instance is None:
            slot_num = 0
        else:
            change_ext_emb_size_by_opt(optimizer_instance)
            slot_num = optimizer_instance.slot_num

        for record in get_next_op_map.values():
            channel_id = 0 if record.is_training else 1
            replace_anchor_for_ddr_ssd(self._full_graph, slot_num, channel_id)

    def _generate_get_next_op_specs(self, cutting_point_list: List[Tensor]) -> Dict[Tensor, _AnchorRecord]:
        """Group the cutting points by the `IteratorGetNext` op that feeds them.

        Args:
            cutting_point_list: All sparse-lookup cutting-point tensors of the graph.

        Returns:
            Mapping from each distinct get_next op to its `_AnchorRecord`. The record is
            built from the first cutting point that reaches the op.
        """
        get_next_op_map = {}
        for input_tensor in cutting_point_list:
            get_next_op = utils.upward_bfs_op(input_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value)
            if get_next_op in get_next_op_map:
                continue

            logger.debug("find a new get_next_op named '%s'", get_next_op.name)
            replacement_specs = utils.record_ops_to_replace(self._full_graph, get_next_op)
            passing_tensors, batch_tensor_indexs, sub_cutting_points = _get_passing_tensor_list(
                cutting_point_list, get_next_op
            )
            sub_graph_def, input_names, output_names = self._get_sub_graph(passing_tensors, sub_cutting_points)
            is_training = BaseSparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING)

            get_next_op_map[get_next_op] = _AnchorRecord(
                replacement_specs,
                passing_tensors,
                batch_tensor_indexs,
                sub_cutting_points,
                sub_graph_def,
                input_names,
                output_names,
                is_training,
            )
            export_pb_graph(f"cut_graph_{get_next_op.name}.pbtxt", self._dump_graph, graph_def=sub_graph_def)

        return get_next_op_map

    def _get_sub_graph(
        self, input_tensors: List[Tensor], output_tensors: List[Tensor]
    ) -> Tuple[GraphDef, List[str], List[str]]:
        """Extract the part of the graph between `input_tensors` and `output_tensors`.

        The result contains the nodes needed to compute the outputs, minus the nodes
        needed to compute the inputs (the input-producing ops themselves are excluded).

        Args:
            input_tensors: Tensors at which the sub-graph is cut off.
            output_tensors: Tensors the sub-graph has to produce.

        Returns:
            `(sub_graph_def, input tensor names, output tensor names)`.
        """
        input_tensors = check_and_force_list(input_tensors, tf.Tensor)
        output_tensors = check_and_force_list(output_tensors, tf.Tensor)
        input_op_name_list = [tensor.op.name for tensor in input_tensors]
        output_op_name_list = [tensor.op.name for tensor in output_tensors]

        graph_def = self._full_graph.as_graph_def()
        cut_graph_input = tf.compat.v1.graph_util.extract_sub_graph(graph_def, input_op_name_list)
        cut_graph_output = tf.compat.v1.graph_util.extract_sub_graph(graph_def, output_op_name_list)

        # Both extractions copy nodes out of the same `graph_def`, where node names are
        # unique, so comparing names is equivalent to comparing whole node messages
        # and avoids a quadratic number of protobuf comparisons.
        upstream_node_names = {node.name for node in cut_graph_input.node}
        node_list = [node for node in cut_graph_output.node if node.name not in upstream_node_names]

        sub_graph_def = tf.compat.v1.GraphDef()
        sub_graph_def.node.extend(node_list)

        input_name_list = [tensor.name for tensor in input_tensors]
        output_name_list = [tensor.name for tensor in output_tensors]
        return sub_graph_def, input_name_list, output_name_list

    def _find_prefetch_dataset_fallback(self, is_training: bool, cause: Exception) -> Operation:
        """Pick the source dataset op among the prefetch dataset ops of the graph.

        Used when `utils.find_trans_dataset` fails. Ops registered in the
        deprecated-prefetch collection are ignored and the rest are sorted by name.

        Args:
            is_training: True in training mode, False otherwise.
            cause: The exception that triggered the fallback; chained to a raised error.

        Raises:
            RuntimeError: If the number of candidates matches none of the supported layouts.
        """
        prefetch_marker = AnchorDatasetOp.PREFETCH_DATASET.value
        # Fetch the collection once instead of once per candidate op.
        deprecated_ops = self._full_graph.get_collection(DeprecatedOp.DEPRECATED_PREFETCH_DATASET)
        dataset_op_list = sorted(
            (
                op
                for op in self._full_graph.get_operations()
                if prefetch_marker in op.name and op not in deprecated_ops
            ),
            key=lambda op: op.name,
        )
        logger.debug(
            "In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.",
            is_training,
            dataset_op_list,
        )

        candidate_count = len(dataset_op_list)
        if candidate_count == 1:
            return dataset_op_list[0]
        if is_training and candidate_count == 2:
            return dataset_op_list[0]
        if not is_training and candidate_count == 3:
            return dataset_op_list[1]
        raise RuntimeError(
            f"'{AnchorDatasetOp.PREFETCH_DATASET.value}' not found, got transformation datasets: "
            f"{dataset_op_list}."
        ) from cause

    def _get_src_dataset(self, get_next_op: Operation, is_training: bool) -> DatasetV1Adapter:
        """Locate the original dataset behind an `IteratorGetNext` op.

        Args:
            get_next_op: The `IteratorGetNext` operation.
            is_training: True in training mode, False otherwise.

        Returns:
            The original dataset instance. Its `element_spec` is also stored in the
            train-params config.

        Raises:
            RuntimeError: If no dataset op can be determined.
            ValueError: If the dataset op that was found has no outputs.
        """
        try:
            target_op = utils.find_trans_dataset(self._full_graph, get_next_op)
        except (ValueError, TypeError, RuntimeError) as err:
            logger.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err)
            target_op = self._find_prefetch_dataset_fallback(is_training, err)
        except Exception as err:
            raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err

        if not target_op.outputs:
            raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.")
        logger.debug("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs)

        src_dataset = utils.find_target_instance_dataset(self._full_graph, target_op.outputs[0])
        ConfigInitializer.get_instance().train_params_config.dataset_element_spec = src_dataset.element_spec
        return src_dataset

    def _get_tgt_dataset(
        self,
        src_dataset: DatasetV1Adapter,
        sub_cutting_point_list: List[Tensor],
        record: _AnchorRecord,
        prefetch: int = 10,
    ) -> DatasetV1Adapter:
        """Build the new dataset on top of the original one.

        The pipeline is: `eos_map` -> preprocessing sub-graph map -> ASC insert map -> prefetch.

        Args:
            src_dataset: The original dataset instance.
            sub_cutting_point_list: Cutting-point (lookup ids) tensors fed by this dataset.
            record: Anchor record holding the sub-graph, its input/output names and indices.
            prefetch: Prefetch size of the new dataset, 10 by default.

        Returns:
            The new dataset instance.
        """
        config = ConfigInitializer.get_instance()
        librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO)
        channel_id = config.train_params_config.get_training_mode_channel_id(record.is_training)
        max_train_steps = config.max_steps
        max_eval_steps = config.eval_steps
        src_dataset = src_dataset.eos_map(librec, channel_id, max_train_steps, max_eval_steps)

        tgt_dataset = src_dataset.map(
            self._get_preprocessing_map_func(
                record.sub_graph_def,
                record.input_names,
                record.output_names,
                pipeline_input_indexes=record.batch_tensor_indexs,
            )
        )

        # Fetch each feature spec once and read both attributes from it.
        feature_specs = [
            BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
            for cutting_point in sub_cutting_point_list
        ]
        feature_numbers = [spec.feat_cnt for spec in feature_specs]
        table_names = [spec.table_name for spec in feature_specs]

        tgt_dataset = tgt_dataset.map(
            get_asc_insert_func(
                feature_numbers=feature_numbers,
                table_names=table_names,
                args_index_list=record.input_indexs,
                is_training=record.is_training,
                dump_graph=self._dump_graph,
            )
        )
        return tgt_dataset.prefetch(prefetch)

    def _update_iterator_getnext(
        self, get_next_op: Operation, tgt_dataset: DatasetV1Adapter, is_training: bool, record: _AnchorRecord
    ) -> None:
        """Create an iterator on the new dataset and replace the old batch with the new batch.

        Args:
            get_next_op: The original `IteratorGetNext` operation.
            tgt_dataset: The new dataset.
            is_training: True in training mode, False otherwise.
            record: Anchor record whose `replacement_spec` lists the consumers to rewire.

        Raises:
            RuntimeError: If the get_next op has no outputs or the iterator type is unsupported.
        """
        if not get_next_op.outputs:
            raise RuntimeError("there is no tensor in the dataset. Please check the dataset and data processing.")

        iterator_type = ""
        first_output_op = get_next_op.outputs[0].op
        if first_output_op.inputs:
            iterator_type = first_output_op.inputs[0].op.type
        if iterator_type == "IteratorV2":
            # `IteratorV2` is resolved to the type of the op returned by `find_make_iterator_op`.
            iterator_type = utils.find_make_iterator_op(self._full_graph, get_next_op.outputs[0]).type
        if iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value):
            raise RuntimeError(
                f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, "
                f"but the current iterator is `{iterator_type}`."
            )

        train_params = ConfigInitializer.get_instance().train_params_config
        train_params.iterator_type = iterator_type
        logger.info("The iterator type of dataset is `%s`.", iterator_type)

        if iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value:
            new_iterator = tgt_dataset.make_initializable_iterator()
            tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer)
            train_params.set_initializer(is_training, new_iterator.initializer)
        else:
            new_iterator = tgt_dataset.make_one_shot_iterator()

        new_batch = new_iterator.get_next()
        train_params.set_target_batch(is_training, new_batch)

        new_batch_tensor = _get_new_batch_tensor(new_batch)
        logger.debug("New dataset batch tensor is : %s.", new_batch_tensor)
        new_get_next_op_name = utils.upward_bfs_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value).name
        self._update_input_tensor_with_new_batch(record.replacement_spec, new_get_next_op_name, new_batch)

    def _update_input_tensor_with_new_batch(
        self,
        replacement_specs: DefaultDict[Tensor, List[Tuple[int, Operation]]],
        new_get_next_op_name: str,
        new_batch: Dict[str, Tensor],
    ) -> None:
        """Point every consumer of an old batch tensor to the matching new batch tensor.

        The new tensor is the output of the new get_next op with the same output index
        as the old tensor.

        Args:
            replacement_specs: Old batch tensor -> [(input index, consuming operation)].
            new_get_next_op_name: Name of the get_next op of the new dataset.
            new_batch: Batch of the new dataset; only used for logging on failure.

        Raises:
            RuntimeError: If an edge cannot be updated.
        """
        for old_tensor, consumers in replacement_specs.items():
            if not consumers:
                continue
            # The replacement depends only on the old tensor, so resolve it once per tensor.
            output_index = old_tensor.name.split(":")[-1]
            new_tensor = self._full_graph.get_tensor_by_name(f"{new_get_next_op_name}:{output_index}")
            for idx, operator in consumers:
                try:
                    operator._update_input(idx, new_tensor)
                except InvalidArgumentError as err:
                    logger.info(
                        "The replacement specs keys (old batch) is: %s. \n\t\t The new batch is: %s.",
                        replacement_specs.keys(),
                        new_batch,
                    )
                    raise RuntimeError(
                        f"Cannot update edge, old tensor: {old_tensor}, new tensor: {new_tensor}."
                    ) from err
@para_checker_decorator(
    check_option_list=[
        ("full_graph", ClassValidator, {"classes": (Graph, type(None))}),
        ("dump_graph", ClassValidator, {"classes": (bool,)}),
    ]
)
def modify_graph_and_start_emb_cache(full_graph: Graph = None, dump_graph: bool = False):
    """Modify the graph for ASC, reset the mergeable table proxy and start the ASC pipeline.

    Args:
        full_graph: Graph to modify, or None to let the modifier use the default graph.
        dump_graph: Whether the graphs should be dumped during modification.
    """
    _GraphModifier(full_graph=full_graph, dump_graph=dump_graph).modify_graph_for_asc()
    MergeableEmbeddingTableProxy().reset()
    start_asc_pipeline()
def _get_input_index_list(
cutting_point_list: List[Tensor],
replacement_specs: DefaultDict[Tensor, List[Tuple[int, Operation]]],
mapping_name_list: List[str],
base_count: int,
timestamp_index: int = None,
) -> List[int]:
input_index_list = []
for cutting_point in cutting_point_list:
if cutting_point in replacement_specs:
index = int(cutting_point.name.split(":")[1])
elif cutting_point.name in mapping_name_list:
index = base_count + mapping_name_list.index(cutting_point.name)
else:
raise ValueError(f"Cannot find a matching output for cutting point tensor named '{cutting_point.name}'.")
input_index_list.append(index)
if timestamp_index is not None:
input_index_list = [timestamp_index] + input_index_list
return input_index_list
def _get_passing_tensor_list(
    src_tensors: List[Tensor], target_op: Operation
) -> Tuple[List[Tensor], List[int], List[Tensor]]:
    """Find the outputs of `target_op` that the given tensors depend on.

    Args:
        src_tensors: Tensors from which the graph is walked upwards.
        target_op: Operation whose outputs are searched for.

    Returns:
        A tuple `(passing_tensor_list, output_index_list, sub_src_tensors)`:
        the distinct outputs of `target_op` that were reached, the output index of each
        of them (taken from the tensor name), and the source tensors that reached
        at least one of them.

    Raises:
        RuntimeError: If a single walk expands more than `MAX_WHILE_SIZE` tensors.
    """

    def get_passing_tensors(src_tensor):
        passing_tensors = []
        pending = [src_tensor]
        # Tensor names are unique within a graph. Marking a tensor when it is popped
        # (not when it is pushed) keeps the discovery order of the plain stack walk
        # while expanding every tensor only once, so shared ancestors do not blow up
        # the walk and cycles terminate.
        visited_names = set()
        while_num = 0
        while pending:
            last_tensor = pending.pop()
            if last_tensor.name in visited_names:
                continue
            visited_names.add(last_tensor.name)

            while_num += 1
            if while_num > MAX_WHILE_SIZE:
                raise RuntimeError(
                    f"In get_passing_tensors function, the maximum cycle depth is greater " f"than {MAX_WHILE_SIZE}."
                )

            if last_tensor.op is target_op:
                passing_tensors.append(last_tensor)
            else:
                pending.extend(last_tensor.op.inputs)
        return passing_tensors

    src_tensors = check_and_force_list(src_tensors, Tensor)
    passing_tensor_list = []
    sub_src_tensors = []
    for tensor in src_tensors:
        passing_tensors = get_passing_tensors(tensor)
        for passing_tensor in passing_tensors:
            if passing_tensor not in passing_tensor_list:
                passing_tensor_list.append(passing_tensor)
        if passing_tensors:
            logger.info("passing_tensors: %s", passing_tensors)
            sub_src_tensors.append(tensor)
        else:
            logger.info("Cannot find passing tensor for given tensor '%s'.", tensor)

    output_index_list = [int(tensor.name.split(":")[1]) for tensor in passing_tensor_list]
    return passing_tensor_list, output_index_list, sub_src_tensors
def _get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int:
    """Return the number of tensors in one batch of the dataset.

    Args:
        dataset: The dataset instance.

    Returns:
        Length of the list produced by `utils.make_sorted_key_to_tensor_list` for the
        dataset's element spec.
    """
    element_spec = dataset.element_spec
    # A single (non list/tuple) spec is wrapped so the helper always gets a sequence.
    spec_sequence = element_spec if isinstance(element_spec, (list, tuple)) else [element_spec]
    return len(utils.make_sorted_key_to_tensor_list(spec_sequence, []))
def _get_timestamp_index(graph: Graph, get_next_op: Operation, is_training: bool) -> int:
    """Find the output index of the timestamp tensor produced by `get_next_op`.

    The first tensor of the `ASCEND_TIMESTAMP` collection that is an output of
    `get_next_op` decides the result. The "timestamp" feature spec is created and
    registered when it does not exist yet, otherwise its `index_key` is checked.

    Args:
        graph: Graph holding the timestamp collection.
        get_next_op: The `IteratorGetNext` operation to inspect.
        is_training: Mode passed on to the feature spec.

    Returns:
        The timestamp output index, or None when no timestamp tensor belongs to `get_next_op`.

    Raises:
        ValueError: If an existing timestamp feature spec has a different `index_key`.
    """
    timestamp_index = None
    for timestamp in graph.get_collection(ASCEND_TIMESTAMP):
        if timestamp not in get_next_op.outputs:
            continue

        timestamp_index = int(timestamp.name.split(":")[1])
        feature_spec_config = ConfigInitializer.get_instance().feature_spec_config
        timestamp_feature_spec = feature_spec_config.get_feature_spec("timestamp")
        if timestamp_feature_spec is None:
            timestamp_feature_spec = FeatureSpec("timestamp", index_key=timestamp_index, is_timestamp=True)
            timestamp_feature_spec.include_timestamp(is_training)
            ConfigInitializer.get_instance().feature_spec_config.insert_feature_spec(
                timestamp_feature_spec, is_training
            )
        elif timestamp_feature_spec.index_key != timestamp_index:
            raise ValueError(
                f"Given timestamp_index, which is {timestamp_index}, does not match index "
                f"key. Please double check."
            )
        else:
            timestamp_feature_spec.include_timestamp(is_training)
        # Only the first matching timestamp tensor is considered.
        break

    return timestamp_index
def change_ext_emb_size_by_opt(optimizer: tf.compat.v1.train.Optimizer):
    """Resize `ext_emb_size` of the affected tables to hold the embedding plus optimizer slots.

    A table is updated when dynamic expansion is enabled or the table is not HBM.
    Its `ext_emb_size` becomes `emb_size * (1 + optimizer.slot_num)`.

    Args:
        optimizer: Optimizer whose `slot_num` is used.
    """
    table_instances = ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.values()
    for table_instance in table_instances:
        needs_resize = ConfigInitializer.get_instance().use_dynamic_expansion or not table_instance.is_hbm
        if not needs_resize:
            continue
        table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num)
        logger.info("ext_emb_size is reset to be %s in change_ext_emb_size_by_opt", table_instance.ext_emb_size)
def _get_variable_and_slot_list(each_var, slot_num, table_name, channel_id):
    """Return `[each_var]` followed by its optimizer slots (or placeholders for them).

    Args:
        each_var: The sparse table variable.
        slot_num: Number of optimizer slots; 0 returns just `[each_var]`.
        table_name: Name used to look up the table's optimizer.
        channel_id: Channel id; compared against `TRAIN_CHANNEL_ID` / `EVAL_CHANNEL_ID`.

    Returns:
        The variable followed by its slot tensors. On the eval channel without an
        optimizer, `slot_num` placeholder tensors shaped like the variable are used.

    Raises:
        RuntimeError: On the train channel when no optimizer is registered for the table.
    """
    if slot_num == 0:
        return [each_var]

    on_train_channel = channel_id == TRAIN_CHANNEL_ID
    optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(
        table_name, is_training=on_train_channel
    )
    if optimizer is None and on_train_channel:
        raise RuntimeError(
            "In training mode, table_instance should have been set_optimizer_for_table "
            "before modify_graph, please check whether apply_gradients is performed"
        )

    if optimizer is None and channel_id == EVAL_CHANNEL_ID:
        # Zeros when `bool_gauge_set` is empty, ones otherwise.
        if ConfigInitializer.get_instance().train_params_config.bool_gauge_set:
            slot_place_holder = tf.ones_like(each_var)
        else:
            slot_place_holder = tf.zeros_like(each_var)
        return [each_var] + [slot_place_holder] * slot_num

    result = [each_var]
    for slot_dict in optimizer.values():
        result.extend(slot_dict.values())
    return result
def shm_swap(tables, swap_in_index, swap_out_index, h2d_name, d2h_name) -> tf.Operation:
    """Build the shared-memory swap op for up to six tables.

    Args:
        tables: The tables to swap (between 1 and 6 entries).
        swap_in_index: Positions to swap in; cast to int64.
        swap_out_index: Positions to swap out; cast to int64.
        h2d_name: Base name of the swap-in shared memory; the device id is appended.
        d2h_name: Base name of the swap-out shared memory; the device id is appended.

    Returns:
        The op returned by `rma_swap_multi_tables`.

    Raises:
        ValueError: If `tables` is empty or holds more than six tables.
    """
    max_table_num = 6
    table_num = len(tables)
    # The op exposes exactly six table inputs (`table_a` .. `table_f`). Without this check,
    # extra tables were silently dropped while `table_num` still reported the full count,
    # and an empty list failed with a bare IndexError.
    if not 0 < table_num <= max_table_num:
        raise ValueError(f"shm_swap supports 1 to {max_table_num} tables, but got {table_num}.")

    # Unused table inputs are filled with the first table.
    table_list = list(tables) + [tables[0]] * (max_table_num - table_num)

    swap_in_index = tf.cast(swap_in_index, dtype=tf.int64)
    swap_out_index = tf.cast(swap_out_index, dtype=tf.int64)

    device_id = get_device_id()
    h2d_name_id = f'{h2d_name}_{device_id}'
    d2h_name_id = f'{d2h_name}_{device_id}'
    capacity = 50
    # The shared-memory handles are passed to the op as strings.
    shm_swap_in = str(mxrec_pybind.get_shm_mem(h2d_name_id, device_id, capacity))
    shm_swap_out = str(mxrec_pybind.get_shm_mem(d2h_name_id, device_id, capacity))

    return host_pipeline_ops.rma_swap_multi_tables(
        swap_in_index=swap_in_index,
        swap_out_index=swap_out_index,
        table_a=table_list[0],
        table_b=table_list[1],
        table_c=table_list[2],
        table_d=table_list[3],
        table_e=table_list[4],
        table_f=table_list[5],
        table_num=table_num,
        shm_swap_in=shm_swap_in,
        shm_swap_out=shm_swap_out,
    )
def _get_swap_info(
    table_instance: BaseSparseEmbedding,
    variable_and_slot_list: List[tf.Variable],
    swap_info: SwapInfo,
    channel_id: int,
) -> List[tf.Operation]:
    """
    Build the swap op(s) for one table.

    HBM tables get a single no-op. Otherwise either the shared-memory swap op is built
    (when `global_env.use_shm_swap` is enabled) or a get_next / outfeed / scatter-update
    chain is built.

    :param table_instance: BaseSparseEmbedding
    :param variable_and_slot_list: [var + slots]
    :param swap_info: swap in/out length and position
    :param channel_id: train or predict
    :return: list of swap ops
    :raises RuntimeError: if `variable_and_slot_list` is empty for a non-HBM table
    """
    if table_instance.is_hbm:
        return [tf.no_op()]

    if len(variable_and_slot_list) == 0:
        raise RuntimeError("When enable emb_transfer, optimizer should have slots")

    use_static = ConfigInitializer.get_instance().use_static
    # Only set in static mode; used as the first dimension of the get_next output shape.
    max_lookup_vec_size = None
    if use_static:
        max_lookup_vec_size = (
            table_instance.send_count * table_instance.rank_size
            if not table_instance.is_dp else table_instance.send_count
        )

    swap_out_pos = swap_info.swap_out_pos
    swap_in_pos = swap_info.swap_in_pos
    if global_env.use_shm_swap == Flag.TRUE.value:
        # Shared-memory path: a single op performs swap-out and swap-in.
        if use_static:
            # In static mode only the first `swap_*_len` positions are used.
            length_out = tf.cast(swap_info.swap_out_len, dtype=tf.int64)
            swap_out_pos = swap_out_pos[: length_out]
            length_in = tf.cast(swap_info.swap_in_len, dtype=tf.int64)
            swap_in_pos = swap_in_pos[: length_in]
        # NOTE(review): unlike the non-shm path below, `is_training` is not passed to
        # this lookup - confirm that relying on the default is intended.
        optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(
            table_instance.table_name)
        h2d_name = f'{table_instance.table_name}_h2d_{channel_id}'
        d2h_name = f'{table_instance.table_name}_d2h_{channel_id}'
        if optimizer is None and channel_id == EVAL_CHANNEL_ID:
            # Eval without an optimizer: only the variable itself is swapped, not the slots.
            swap_op = [shm_swap([variable_and_slot_list[0]], swap_in_index=swap_in_pos,
                                swap_out_index=swap_out_pos, h2d_name=h2d_name, d2h_name=d2h_name)]
        else:
            swap_op = [shm_swap(variable_and_slot_list, swap_in_index=swap_in_pos, swap_out_index=swap_out_pos,
                                h2d_name=h2d_name, d2h_name=d2h_name)]
        return swap_op

    if use_static:
        swap_out_pos = swap_out_pos[: swap_info.swap_out_len]
        swap_in_pos = swap_in_pos[: swap_info.swap_in_len]

    # NOTE(review): the extents of this `variable_scope` and of the `control_dependencies`
    # block below were reconstructed from a source without indentation - confirm.
    with tf.compat.v1.variable_scope("h2d_emb"):
        logger.debug("Channel %s_h2d_%s was built for getnext.", table_instance.table_name, channel_id)
        h2d_emb = npu_ops.gen_npu_ops.get_next(
            output_types=[tf.float32],
            output_shapes=[[max_lookup_vec_size, table_instance.ext_emb_size]],
            channel_name=f"{table_instance.table_name}_h2d_{channel_id}",
        )[0]
    logger.debug("h2d_emb shape: %s", h2d_emb)
    if use_static:
        h2d_emb = h2d_emb[:swap_info.swap_in_len, :]

    # Swap-out: gather the rows of every table and concatenate them column-wise.
    swap_outs = [tf.gather(one_table, swap_out_pos) for one_table in variable_and_slot_list]
    swap_out = tf.concat(swap_outs, axis=1)
    logger.debug("Channel %s_d2h_%s was built for op outfeed.", table_instance.table_name, channel_id)
    swap_out_op = npu_ops.outfeed_enqueue_op(
        channel_name=f"{table_instance.table_name}_d2h_{channel_id}", inputs=[swap_out]
    )
    # Swap-in ops are created under a control dependency on the swap-out op.
    with tf.control_dependencies([swap_out_op]):
        nd_swap_pos = tf.expand_dims(swap_in_pos, 1)
        var_num = len(variable_and_slot_list)
        # Split the incoming rows column-wise into one piece per table.
        h2d_emb_split = tf.split(h2d_emb, var_num, axis=1)
        is_training = True if channel_id == TRAIN_CHANNEL_ID else False
        optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(
            table_instance.table_name, is_training=is_training
        )
        if optimizer is None and channel_id == EVAL_CHANNEL_ID:
            # Eval without an optimizer: only the variable itself is updated.
            swap_in_op = [tf.compat.v1.scatter_nd_update(variable_and_slot_list[0], nd_swap_pos, h2d_emb_split[0])]
        else:
            swap_in_op = [
                tf.compat.v1.scatter_nd_update(variable_and_slot_list[i], nd_swap_pos, h2d_emb_split[i])
                for i in range(var_num)]
    return swap_in_op
def _get_new_batch_tensor(new_batch: Union[List, Tuple, Dict, tf.Tensor]) -> tf.Tensor:
"""
Get a tensor from the new batch.
Args:
new_batch: New dataset batch.
Returns: A tensor in the batch.
"""
if isinstance(new_batch, list):
batch_tensor = new_batch.pop()
return _get_new_batch_tensor(batch_tensor)
elif isinstance(new_batch, tuple):
new_batch = list(new_batch)
batch_tensor = new_batch.pop()
return _get_new_batch_tensor(batch_tensor)
elif isinstance(new_batch, dict):
for _, value in new_batch.items():
return _get_new_batch_tensor(value)
if not isinstance(new_batch, tf.Tensor):
raise TypeError(f"Cannot find a tensor from give batch: {new_batch}.")
if AnchorIteratorOp.ITERATOR_GET_NEXT.value not in new_batch.name:
raise ValueError(f"{new_batch} is not {AnchorIteratorOp.ITERATOR_GET_NEXT.value} tensor.")
return new_batch
def replace_anchor_for_ddr_ssd(graph: tf.Graph, slot_num: int, channel_id: int):
    """Replace the swap control anchors of every non-HBM sparse table with real swap ops.

    Args:
        graph: Graph holding the sparse table variables and the anchors.
        slot_num: Number of optimizer slots.
        channel_id: Channel id; slot anchors are only replaced on `TRAIN_CHANNEL_ID`
            when `slot_num > 1`.

    Raises:
        ValueError: If a control dict lacks the control-ops entry.
    """
    control_ops_key = SwapDataType.CONTROL_OPS.value

    def control_ops_from(control_dict, missing_message):
        # Fetch the anchor ops, failing with the given message when the entry is absent.
        if control_ops_key not in control_dict:
            raise ValueError(missing_message)
        return control_dict[control_ops_key]

    swap_args = SwapArgs()
    collection_name = ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection
    for each_var in graph.get_collection(collection_name):
        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(each_var)
        if table_instance.is_hbm:
            continue

        table_name = table_instance.table_name
        variable_and_slot_list = _get_variable_and_slot_list(each_var, slot_num, table_name, channel_id)
        swap_config = swap_args.swap_config_dict[table_name][channel_id]
        swap_op = _get_swap_info(table_instance, variable_and_slot_list, swap_config["swap_info"], channel_id)

        control_ops = control_ops_from(
            swap_args.swap_control_dict[table_name][channel_id],
            "swap control missing key [control_ops] in modify_graph_for_asc",
        )
        utils.replace_anchor_control(graph, control_ops, swap_op)

        if not (channel_id == TRAIN_CHANNEL_ID and slot_num > 1):
            continue
        slot_control_ops = control_ops_from(
            swap_args.slot_control_dict[table_instance.variable],
            "slot control missing key [control_ops] in modify_graph_for_asc",
        )
        utils.replace_anchor_control(graph, slot_control_ops, swap_op)