import os
import abc
from typing import List, Dict, Set, Tuple, Union
import pandas as pd
import tensorflow as tf
from tensorflow import Operation, Tensor, SparseTensor, Graph, variant, resource
from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
from rec_sdk_common.log.log import LoggingProxy as logger
from rec_sdk_common.validator.validator import ClassValidator, para_checker_decorator
from mx_rec.graph import utils
from mx_rec.constants.constants import (
ASCAnchorAttr,
ASCEND_TIMESTAMP,
MAX_WHILE_SIZE,
ASCEND_SPARSE_LOOKUP_ENTRANCE,
ORPHAN_LOOKUP_KEY_PREFIX
)
from mx_rec.graph.constants import DeprecatedOp, AnchorDatasetOp, AnchorIteratorOp
from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
class NoGradSubgraphSlicer(metaclass=abc.ABCMeta):
    """Abstract base for slicers that move a subgraph fed by 'IteratorGetNext' into the dataset.

    The shared pipeline (see ``_slice_ops``) is:
      1. collect the target operations plus all of their upstream operations, stopping at 'IteratorGetNext';
      2. record which outside operations feed the subgraph and which outside operations consume its results;
      3. clone the subgraph inside a ``Dataset.map`` function of the old dataset and append the subgraph
         results to the batch;
      4. create a new 'IteratorGetNext' from the mapped dataset and rewire the consumers in the full graph.

    Subclasses choose which operations are sliceable and implement ``summarize`` and ``slice``.
    """

    # Name scope prepended to every operation cloned into the map function and to the new get-next op.
    _SLICED_OP_NAME_PREFIX = "sliced"
    # File names written under the info directory.
    _SLICING_SUMMARY_NAME = "slicing_summary.csv"
    _UNSLICED_FULL_GRAPH_NAME = "unsliced_full_graph.pbtxt"
    _SLICED_SUB_GRAPH_NAME = "sliced_sub_graph.pbtxt"
    _SLICED_FULL_GRAPH_NAME = "sliced_full_graph.pbtxt"
    # `_validate_op` rejects operations whose type / name contains one of these substrings.
    _INVALID_STR_IN_OP_TYPE = ("Dataset", "Summary")
    _INVALID_STR_IN_OP_NAME = ("save", "report_", "loss")
    # Consumers of these types are left untouched when sliced outputs are rewired.
    _INVALID_CONSUMER_OP_TYPE = ("Assign", "SaveV2")
    # Only these tensor classes are accepted in a dataset batch.
    _VALID_TENSOR_CLASS = (Tensor, SparseTensor)
    # `_validate_op` rejects operations with an input or output of these dtypes.
    _INVALID_TENSOR_DTYPE = (variant, resource)

    def __init__(self, full_graph: Graph = None, info_dir: str = "slicing") -> None:
        """Initialize the slicer.

        Args:
            full_graph: The graph to be sliced. If None, the current default graph is used.
            info_dir: Directory in which slicing artifacts are written; created if missing.
        """
        if full_graph is None:
            full_graph = tf.compat.v1.get_default_graph()
        self._full_graph = full_graph
        # `exist_ok=True` avoids the check-then-create race when several processes start together.
        os.makedirs(info_dir, exist_ok=True)
        self._info_dir = info_dir

    @staticmethod
    def _find_min_dep_ops(
        tgt_ops: Set[Operation],
    ) -> Set[Operation]:
        """Collect the given operations and every operation they depend on through data inputs.

        The search walks upwards level by level and does not include (or go past) 'IteratorGetNext'.

        Args:
            tgt_ops: The operations to start from. The set itself is not modified.

        Returns:
            A new set containing `tgt_ops` and all of their upstream operations.

        Raises:
            ValueError: If any visited operation has control inputs.
            RuntimeError: If more than `MAX_WHILE_SIZE` levels are needed.
        """
        logger.debug("Search from base nodes: %s.", tgt_ops)
        visited_ops = set(tgt_ops)
        frontier = set(tgt_ops)
        loop_cnt = 0
        while frontier:
            loop_cnt += 1
            if loop_cnt > MAX_WHILE_SIZE:
                raise RuntimeError(f"maximum loop times exceed limit: {MAX_WHILE_SIZE}.")
            parent_ops = set()
            for base_node in frontier:
                if len(base_node.control_inputs) != 0:
                    raise ValueError("control dependencies are not supported.")
                parent_ops.update(
                    tensor_in.op
                    for tensor_in in base_node.inputs
                    if tensor_in.op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value
                )
            # Only expand operations that have not been seen yet; re-expanding visited parents
            # yields the same closure but burns iterations against MAX_WHILE_SIZE.
            frontier = parent_ops - visited_ops
            visited_ops.update(frontier)
        logger.debug("Found minimum dependency graph nodes: %s.", visited_ops)
        return visited_ops

    @staticmethod
    def _validate_op(op: Operation) -> bool:
        """Check whether an operation is acceptable as a slicing target.

        Args:
            op: The operation to check.

        Returns:
            False (after logging a warning) if the operation type or name contains a forbidden substring, or if
            any input or output tensor has a forbidden dtype; True otherwise.
        """
        op_type = op.type
        op_name = op.name
        for s in NoGradSubgraphSlicer._INVALID_STR_IN_OP_TYPE:
            if s in op_type:
                logger.warning("Invalid operation type: %s which contains str: %s.", op_type, s)
                return False
        for s in NoGradSubgraphSlicer._INVALID_STR_IN_OP_NAME:
            if s in op_name:
                logger.warning("Invalid operation name: %s which contains str: %s.", op_name, s)
                return False
        for t in op.inputs:
            if t.dtype in NoGradSubgraphSlicer._INVALID_TENSOR_DTYPE:
                logger.warning("Invalid operation input tensor of operation: %s whose type is %s.", t, t.dtype)
                return False
        for t in op.outputs:
            if t.dtype in NoGradSubgraphSlicer._INVALID_TENSOR_DTYPE:
                logger.warning("Invalid operation output tensor of operation: %s whose type is %s.", t, t.dtype)
                return False
        return True

    @staticmethod
    def _update_subgraph_in(
        base_ops: Operation,
        input_to_edge_ops: Dict[Operation, Set[Operation]],
        sub_graph_ops: Set[Operation],
    ) -> None:
        """Record every producer outside the subgraph that feeds `base_ops`.

        Args:
            base_ops: A single operation of the subgraph.
            input_to_edge_ops: Updated in place; maps an outside producer to the subgraph ops it feeds.
            sub_graph_ops: All operations of the subgraph.
        """
        for input_tensor in base_ops.inputs:
            input_node = input_tensor.op
            if input_node not in sub_graph_ops:
                input_to_edge_ops.setdefault(input_node, set()).add(base_ops)

    @staticmethod
    def _update_subgraph_out(
        base_ops: Operation,
        out_op_to_edge_ops: Dict[Operation, Set[Operation]],
        sub_graph_ops: Set[Operation],
    ) -> None:
        """Record every consumer outside the subgraph that reads an output of `base_ops`.

        Args:
            base_ops: A single operation of the subgraph.
            out_op_to_edge_ops: Updated in place; maps an outside consumer to the subgraph ops it reads from.
            sub_graph_ops: All operations of the subgraph.
        """
        for output_tensor in base_ops.outputs:
            for output_consumer in output_tensor.consumers():
                if output_consumer not in sub_graph_ops:
                    out_op_to_edge_ops.setdefault(output_consumer, set()).add(base_ops)

    @staticmethod
    def _topo_sort_sliced_ops(sliced_ops: Set[Operation]) -> List[Operation]:
        """Order the sliced operations so that every producer precedes its consumers.

        Args:
            sliced_ops: The operations of the sliced subgraph.

        Returns:
            The operations in a topological order (depth-first post-order); 'IteratorGetNext' is excluded.

        Raises:
            RuntimeError: If a visited node is not an `Operation`, has control inputs, or appears twice.
        """
        topo_subgraph_list = []
        topo_subgraph_set = set()
        start_nodes = set(sliced_ops)
        logger.info("Got topo_subgraph start nodes: %s", start_nodes)

        def topo_sort_helper(curr_op, output_list, output_set):
            if not isinstance(curr_op, Operation):
                raise RuntimeError(f"topo_subgraph_dfs input should be node(aka. tf.Operator). {curr_op}")
            curr_inputs = curr_op.inputs
            logger.debug("Got topo_dfs: %s <- %s", curr_op.name, [x.name for x in curr_inputs])
            current_control_inputs = curr_op.control_inputs
            if len(current_control_inputs) > 0:
                raise RuntimeError(
                    f"control input are not supported: {curr_op.name}, control_inputs: {current_control_inputs}"
                )
            if curr_op in output_set:
                return
            output_set.add(curr_op)
            for tensor in curr_inputs:
                node = tensor.op
                if node.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value and node not in output_set:
                    topo_sort_helper(node, output_list, output_set)
            # Appended only after all producers, which gives the topological order.
            output_list.append(curr_op)

        for start_node in start_nodes:
            topo_sort_helper(start_node, topo_subgraph_list, topo_subgraph_set)
        if len(topo_subgraph_list) != len(topo_subgraph_set):
            raise RuntimeError(f"got duplicated topo node: {sorted(topo_subgraph_list, key=lambda x: x.name)}.")
        logger.info("Got topo_subgraph: %s", topo_subgraph_list)
        return topo_subgraph_list

    @staticmethod
    def _get_mapping_for_subgraph_in(
        from_op: Operation,
        to_ops: Set[Operation],
        tensor_mapping: Union[Dict[Tensor, Tensor], Dict[SparseTensor, SparseTensor]],
    ) -> None:
        """Map 'IteratorGetNext' outputs consumed by the subgraph onto the current graph's argument tensors.

        Output ``i`` of 'IteratorGetNext' is mapped to the tensor named ``args_<i>:0`` of the current default
        graph (presumably the `Dataset.map` function graph -- this must be called inside the map function).

        Args:
            from_op: The outside producer; must be an 'IteratorGetNext' operation.
            to_ops: The subgraph operations fed by `from_op`.
            tensor_mapping: Updated in place; maps old tensors to their replacements.

        Raises:
            RuntimeError: If `from_op` is not an 'IteratorGetNext' operation.
        """
        get_next_type = AnchorIteratorOp.ITERATOR_GET_NEXT.value
        if from_op.type != get_next_type:
            raise RuntimeError(f"expect IteratorGetNext for input tensor of subgraph, but got {from_op}")
        g = tf.compat.v1.get_default_graph()
        for node in to_ops:
            for each_tensor in node.inputs:
                if each_tensor.op.type != get_next_type:
                    continue
                # Tensor names look like "<op_name>:<output_index>".
                x_index = int(each_tensor.name.split(":")[-1])
                tensor_mapping[each_tensor] = g.get_tensor_by_name("args_%d:0" % x_index)

    @staticmethod
    def _get_mapping_for_subgraph(
        old_op: Operation,
        node_mapping: Dict[Operation, Operation],
        tensor_mapping: Dict[Tensor, Tensor],
    ) -> None:
        """Clone one operation into the current default graph and register the old-to-new mappings.

        Args:
            old_op: The operation to clone; all of its inputs must already be in `tensor_mapping`.
            node_mapping: Updated in place; maps old operations to their clones.
            tensor_mapping: Updated in place; maps old output tensors to the clone's output tensors.

        Raises:
            RuntimeError: If an input of `old_op` has no entry in `tensor_mapping`.
        """
        logger.debug("old operation name: %s\nold operation inputs: %s\n", old_op.name, list(old_op.inputs))
        for each_tensor in old_op.inputs:
            if each_tensor not in tensor_mapping:
                raise RuntimeError(
                    f"each_tensor(input) {each_tensor} need by {old_op.name} not in tensor_mapping.{tensor_mapping}"
                )
        new_inputs = NoGradSubgraphSlicer._get_mapped_tensor(tensor_mapping, old_op.inputs)
        node_def = old_op.node_def
        node_def.name = "{}/{}".format(NoGradSubgraphSlicer._SLICED_OP_NAME_PREFIX, node_def.name)
        new_node = tf.Operation(node_def=node_def, g=tf.compat.v1.get_default_graph(), inputs=new_inputs)
        node_mapping[old_op] = new_node
        for old_out_tensor, new_out_tensor in zip(old_op.outputs, new_node.outputs):
            tensor_mapping[old_out_tensor] = new_out_tensor

    @staticmethod
    def _get_mapped_tensor(tensor2tensor: Dict[Tensor, Tensor], keys: List[Tensor]) -> List[Tensor]:
        """Translate a sequence of tensors through a mapping, preserving order.

        Args:
            tensor2tensor: Mapping from old tensors to new tensors.
            keys: The old tensors to translate.

        Returns:
            The mapped tensors, in the order of `keys`.

        Raises:
            KeyError: If a key is missing from the mapping.
        """
        tensors = []
        for k in keys:
            if k not in tensor2tensor:
                raise KeyError(f"failed to find key tensor: {k} from tensor map: {tensor2tensor}.")
            tensors.append(tensor2tensor[k])
        return tensors

    @staticmethod
    def _sort_sliced_graph_outputs(subgraph_out: Dict[Operation, Set[Operation]]) -> List[Tensor]:
        """List the subgraph output tensors read by outside consumers, in a deterministic order.

        The same ordering is used both when the outputs are appended to the batch and when the consumers are
        rewired to the new 'IteratorGetNext', so it must be reproducible.

        Args:
            subgraph_out: Maps an outside consumer to the subgraph operations it reads from.

        Returns:
            The de-duplicated output tensors, ordered by consumer name, then edge op name, then tensor name.
        """
        extra_outputs = []
        sorted_outputs = sorted(subgraph_out.items(), key=lambda x: x[0].name)
        for outside_op, edge_ops in sorted_outputs:
            outside_op_inputs = set(outside_op.inputs)
            # `edge_ops` is a set; sort it so the order does not depend on hash ordering.
            for edge_op in sorted(edge_ops, key=lambda op: op.name):
                NoGradSubgraphSlicer._add_sorted_additional_tensors(extra_outputs, outside_op_inputs, edge_op)
        return extra_outputs

    @staticmethod
    def _add_sorted_additional_tensors(extra_outputs, outside_op_inputs, edge_op) -> None:
        """Append the outputs of `edge_op` that the outside consumer reads and that are not yet listed.

        Args:
            extra_outputs: Updated in place; the ordered list of collected output tensors.
            outside_op_inputs: The input tensors of the outside consumer.
            edge_op: A subgraph operation that feeds the outside consumer.
        """
        for each_tensor in sorted(edge_op.outputs, key=lambda x: x.name):
            if each_tensor not in outside_op_inputs:
                continue
            if each_tensor in extra_outputs:
                continue
            extra_outputs.append(each_tensor)

    @staticmethod
    def _get_tensor_consumers(tensor: Tensor) -> List[Operation]:
        """Return the consumers of a tensor as the operation objects registered in its graph.

        Consumer names are looked up in the graph's private name registry while holding the graph lock; names
        that are not registered are skipped.

        Args:
            tensor: A `tf.Tensor` or `tf.SparseTensor`.

        Returns:
            The registered consumer operations.

        Raises:
            RuntimeError: If `tensor` is neither a `tf.Tensor` nor a `tf.SparseTensor`.
        """
        if not isinstance(tensor, NoGradSubgraphSlicer._VALID_TENSOR_CLASS):
            raise RuntimeError(f"expected 'tf.Tensor' or 'tf.SparseTensor', but got: {tensor}")
        graph = tensor.graph
        consumers = []
        consumer_names = [op.name for op in tensor.consumers()]
        with graph._lock:
            for name in consumer_names:
                if name not in graph._nodes_by_name:
                    continue
                consumers.append(graph._nodes_by_name[name])
        return consumers

    @abc.abstractmethod
    def summarize(self) -> None:
        """Report what would be sliced; implemented by subclasses."""

    @abc.abstractmethod
    def slice(self) -> None:
        """Perform the slicing on the full graph; implemented by subclasses."""

    def _slice_ops(self, sliceable_ops: Set[Operation], is_training: bool) -> None:
        """Slice the minimum dependency graph of given operation set.

        Args:
            sliceable_ops (Set[Operation]): The operation set that can be sliced.
            is_training (bool): Whether the slicing is for training graph or not.
        """
        sliced_ops = self._find_min_dep_ops(sliceable_ops)
        in_op_to_edge_ops, out_op_to_edge_ops = self._find_subgraph_in_and_out(sliced_ops)
        old_get_next = self._find_old_get_next(sliceable_ops)
        old_dataset = self._find_old_dataset(old_get_next, is_training)
        new_dataset = self._make_new_dataset(old_dataset, sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops)
        new_dataset = new_dataset.prefetch(0)
        new_get_next = self._make_new_get_next(old_get_next, new_dataset)
        self._replace_get_next(old_get_next, new_get_next, out_op_to_edge_ops, sliced_ops)

    def _make_new_dataset(
        self,
        old_dataset: DatasetV1Adapter,
        sliced_ops: Set[Operation],
        in_op_to_edge_ops: Dict[Operation, Set[Operation]],
        out_op_to_edge_ops: Dict[Operation, Set[Operation]],
    ) -> DatasetV1Adapter:
        """Make a new dataset which clones the sliced subgraph by mapfunc.

        Args:
            old_dataset: The old dataset that needs to be mapped.
            sliced_ops: The operation set that has been sliced.
            in_op_to_edge_ops: The input relationship of sliced subgraph.
            out_op_to_edge_ops: The output relationship of sliced subgraph.

        Returns:
            DatasetV1Adapter: The new dataset that has cloned the sliced subgraph.
        """

        def slice_map_func(*batch):
            logger.debug("The layout of old batch: %s.", batch)
            funcgraph = tf.compat.v1.get_default_graph()
            for t in tf.nest.flatten(batch):
                if not isinstance(t, NoGradSubgraphSlicer._VALID_TENSOR_CLASS):
                    # The offending value is interpolated directly; the previous "%s" placeholder
                    # inside an f-string was never substituted.
                    raise RuntimeError(f"expected 'tf.Tensor' or 'tf.SparseTensor' in batch, but got {t}.")
            new_batch = self._clone_subgraph_into_funcgraph(sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops, batch)
            utils.export_pb_graph(
                file_name=NoGradSubgraphSlicer._SLICED_SUB_GRAPH_NAME,
                dump_graph=True,
                graph_def=funcgraph.as_graph_def(),
                export_path=self._info_dir,
            )
            return new_batch

        return old_dataset.map(slice_map_func)

    def _find_subgraph_in_and_out(
        self,
        sub_graph_ops: Set[Operation],
    ) -> Tuple[Dict[Operation, Set[Operation]], Dict[Operation, Set[Operation]]]:
        """Find the input and output relationship of sliced subgraph.

        Args:
            sub_graph_ops: The operation set that has been sliced.

        Returns:
            in_op_to_edge_ops: The input relationship of sliced subgraph.
            out_op_to_edge_ops: The output relationship of sliced subgraph.
        """
        in_op_to_edge_ops = dict()
        out_op_to_edge_ops = dict()
        for base_node in sub_graph_ops:
            self._update_subgraph_in(base_node, in_op_to_edge_ops, sub_graph_ops)
            self._update_subgraph_out(base_node, out_op_to_edge_ops, sub_graph_ops)
        logger.info("Got input relationship of extracted subgraph: %s", in_op_to_edge_ops)
        logger.info("Got output relationship of extracted subgraph: %s", out_op_to_edge_ops)
        return in_op_to_edge_ops, out_op_to_edge_ops

    def _find_old_get_next(self, sliceable_ops: Set[Operation]) -> Operation:
        """Find the old 'IteratorGetNext' operation.

        The operation found is also registered in the deprecated 'IteratorGetNext' collection of the full graph.

        Args:
            sliceable_ops: The operation set that can be sliced.

        Returns:
            old_get_next: The old 'IteratorGetNext' operation.
        """
        old_get_next = utils.upward_bfs_op(sliceable_ops, AnchorIteratorOp.ITERATOR_GET_NEXT.value)
        self._full_graph.add_to_collection(DeprecatedOp.DEPRECATED_ITERATOR_GET_NEXT, old_get_next)
        logger.info("Old 'IteratorGetNext' operation has been deprecated now.")
        return old_get_next

    def _find_old_dataset(self, get_next: Operation, is_training: bool) -> DatasetV1Adapter:
        """Find the old dataset that needs to be mapped.

        Due to the different iterator types, the search method is different.
        1. If the iterator type is 'MakeIterator', this func will exec upward bfs search through get_next.
        2. If the iterator type is 'OneShotIterator', this func will fetch all operation in 'self._full_graph', then
        filter out the 'PrefetchDataset' operation. This diff is caused by the isolation of 'OneShotIterator' and the
        'PrefetchDataset'.

        Args:
            get_next: The old 'IteratorGetNext' operation.
            is_training: Whether the slicing is for training graph or not.

        Returns:
            old_dataset: The old dataset that needs to be mapped.

        Raises:
            RuntimeError: If no target dataset operation can be determined.
            ValueError: If the target dataset operation has no outputs.
        """
        tgt_trans_dataset = None
        try:
            tgt_trans_dataset = utils.find_trans_dataset(self._full_graph, get_next)
        except (ValueError, TypeError, RuntimeError) as err:
            # Read the collection from the graph it is written to below, once, rather than from the
            # default graph for every candidate operation.
            deprecated_datasets = self._full_graph.get_collection(DeprecatedOp.DEPRECATED_PREFETCH_DATASET)
            trans_datasets = [
                op
                for op in self._full_graph.get_operations()
                if AnchorDatasetOp.PREFETCH_DATASET.value in op.name and op not in deprecated_datasets
            ]
            sorted_datasets = sorted(trans_datasets, key=lambda op: op.name)
            # NOTE(review): with one or two live candidates the first one by name is chosen for both training
            # and evaluation, exactly as before; `is_training` therefore does not influence the choice here.
            if len(sorted_datasets) in (1, 2):
                tgt_trans_dataset = sorted_datasets[0]
            else:
                raise RuntimeError(f"target transformation dataset not found, got datasets: {trans_datasets}.") from err
        except Exception as err:
            raise RuntimeError(f"the dataset was not found, the error is `{err}`.") from err
        if not tgt_trans_dataset.outputs:
            raise ValueError(f"the length of the outputs of target op `{tgt_trans_dataset}` is 0.")
        logger.info("Find target op `%s`, and output is `%s`.", tgt_trans_dataset.name, tgt_trans_dataset.outputs)
        self._full_graph.add_to_collection(DeprecatedOp.DEPRECATED_PREFETCH_DATASET, tgt_trans_dataset)
        old_dataset = utils.find_target_instance_dataset(self._full_graph, tgt_trans_dataset.outputs[0])
        return old_dataset

    def _clone_subgraph_into_funcgraph(
        self,
        sliced_ops: Set[Operation],
        in_op_to_edge_ops: Dict[Operation, Set[Operation]],
        out_op_to_edge_ops: Dict[Operation, Set[Operation]],
        batch: Tuple[Dict[str, Union[Tensor, SparseTensor, Dict]]],
    ) -> Dict[str, Union[Tensor, SparseTensor, Dict]]:
        """Clone the sliced subgraph into a new funcgraph.

        Only the first element of `batch` is kept; it is updated in place with one extra entry per subgraph
        output. Each extra key is the currently largest key plus "_", so it always sorts last.

        Args:
            sliced_ops: The operation set that has been sliced.
            in_op_to_edge_ops: The input relationship of sliced subgraph.
            out_op_to_edge_ops: The output relationship of sliced subgraph.
            batch: The original batch layout of old dataset.

        Returns:
            new_batch: The new batch layout of new dataset.
        """
        topo_subgraph_list = self._topo_sort_sliced_ops(sliced_ops)
        node_mapping = {}
        tensor_mapping = {}
        for in_op, edge_ops in in_op_to_edge_ops.items():
            self._get_mapping_for_subgraph_in(in_op, edge_ops, tensor_mapping)
        for old_op in topo_subgraph_list:
            self._get_mapping_for_subgraph(old_op, node_mapping, tensor_mapping)
        logger.info("Got node_mapping: %s", node_mapping)
        logger.info("Got tensor_mapping: %s", tensor_mapping)
        ordered_output_tensors = self._sort_sliced_graph_outputs(out_op_to_edge_ops)
        extra_output_tensor = self._get_mapped_tensor(tensor_mapping, ordered_output_tensors)
        if not isinstance(batch, tuple):
            batch = (batch,)
        new_batch = batch[0]
        for tensor in extra_output_tensor:
            next_last_key = f"{sorted(new_batch)[-1]}_"
            new_batch[next_last_key] = tensor
        logger.debug("Got new batch layout: %s.", new_batch)
        return new_batch

    def _make_new_get_next(
        self,
        old_get_next: Operation,
        new_dataset: DatasetV1Adapter,
    ) -> Operation:
        """Make new 'IteratorGetNext' operation.

        1. This func will automatically detect the iterator type of the old dataset, and then make 'IteratorGetNext'
        from the corresponding iterator.
        2. Only 'MakeIterator' and 'OneShotIterator' are available now.

        Args:
            old_get_next: The old 'IteratorGetNext' operation.
            new_dataset: The new dataset which contains sliced subgraph and corresponding additional outputs.

        Returns:
            new_get_next: The new 'IteratorGetNext' operation.

        Raises:
            RuntimeError: If the old operation has no outputs, the iterator type is unsupported, or the new
                batch contains a value of an unsupported type.
            IndexError: If no tensor can be found in the new batch.
        """
        if not old_get_next.outputs:
            raise RuntimeError("no available tensor in the dataset. Please check the dataset and data processing.")
        iter_type = None
        if old_get_next.inputs:
            iter_type = old_get_next.inputs[0].op.type
        if iter_type == AnchorIteratorOp.ITERATOR_V2.value:
            iter_type = utils.find_make_iterator_op(self._full_graph, old_get_next.outputs[0]).type
        if iter_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value):
            raise RuntimeError(
                f"only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, "
                f"but the current iterator is `{iter_type}`."
            )
        logger.info("The iterator type of old dataset is %s.", iter_type)
        if iter_type == AnchorIteratorOp.MAKE_ITERATOR.value:
            new_iterator = tf.compat.v1.data.make_initializable_iterator(new_dataset)
        else:
            new_iterator = tf.compat.v1.data.make_one_shot_iterator(new_dataset)
        logger.info("Got new iterator: %s from dataset %s.", new_iterator, new_dataset)
        new_batch_name = "{}/{}".format(
            NoGradSubgraphSlicer._SLICED_OP_NAME_PREFIX, AnchorIteratorOp.ITERATOR_GET_NEXT.value
        )
        new_batch = new_iterator.get_next(name=new_batch_name)
        if "timestamp" in new_batch.keys():
            tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, new_batch["timestamp"])
        # Descend into the (possibly nested) batch until any tensor is reached; its producer chain leads
        # back to the new 'IteratorGetNext' operation.
        try:
            new_batch_tensor = new_batch
            while not isinstance(new_batch_tensor, NoGradSubgraphSlicer._VALID_TENSOR_CLASS):
                if isinstance(new_batch_tensor, (tuple, list)):
                    new_batch_tensor = new_batch_tensor[0]
                elif isinstance(new_batch_tensor, dict):
                    new_batch_tensor = list(new_batch_tensor.values())
                else:
                    raise RuntimeError(f"batch value {new_batch_tensor} of {type(new_batch_tensor)} is not supported.")
        except IndexError as err:
            raise IndexError("cannot find a tensor from given batch.") from err
        new_get_next = utils.upward_bfs_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value)
        logger.info("Got old_new_get_next: %s.", new_get_next)
        return new_get_next

    def _replace_get_next(
        self,
        old_get_next: Operation,
        new_get_next: Operation,
        out_op_to_edge_ops: Dict[Operation, Set[Operation]],
        sliced_ops: Set[Operation],
    ) -> None:
        """Replace the old 'IteratorGetNext' operation with the new one.

        1. This func will update the consumer of the old 'IteratorGetNext' operation to the new one.
        2. This func will update the consumer of the output tensors of the sliced subgraph to the new one.

        Args:
            old_get_next: The old 'IteratorGetNext' operation.
            new_get_next: The new 'IteratorGetNext' operation.
            out_op_to_edge_ops: The output relationship of sliced subgraph.
            sliced_ops: The operation set that has been sliced.
        """
        for t in old_get_next.outputs:
            self._update_old_get_next_consumer(t, new_get_next, sliced_ops)
        # The extra outputs follow the original ones, so their indices start right after the old output count.
        first_extra_offset = len(old_get_next.outputs)
        sorted_outputs = self._sort_sliced_graph_outputs(out_op_to_edge_ops)
        for next_offset, t in enumerate(sorted_outputs, start=first_extra_offset):
            self._update_sliced_graph_consumer(t, new_get_next, next_offset)

    def _update_old_get_next_consumer(
        self, old_get_next_output: Tensor, new_get_next: Operation, sliced_ops: Set[Operation]
    ) -> None:
        """Update the consumer of the old 'IteratorGetNext' operation to the new one.

        Consumers that belong to the sliced subgraph are left untouched.

        Args:
            old_get_next_output: The output tensor of the old 'IteratorGetNext' operation.
            new_get_next: The new 'IteratorGetNext' operation.
            sliced_ops: The operation set that has been sliced.
        """
        old_tensor_name = old_get_next_output.name
        output_index = old_tensor_name.split(":")[-1]
        # The new operation exposes the original outputs at the same indices.
        new_tensor_name = f"{new_get_next.name}:{output_index}"
        new_tensor = self._full_graph.get_tensor_by_name(new_tensor_name)
        old_tensor_consumers = self._get_tensor_consumers(old_get_next_output)
        for consumer in old_tensor_consumers:
            if consumer in sliced_ops:
                logger.debug("Ignore consumer: %s in sliced operations.", consumer.name)
                continue
            for i, t in enumerate(consumer.inputs):
                if t != old_get_next_output:
                    logger.debug(
                        "Ignore input %s of consumer %s, cause it not output of 'IteratorGetNext'.",
                        t.name,
                        consumer.name,
                    )
                    continue
                consumer._update_input(i, new_tensor)
                logger.debug(
                    "Succeed replace old input %s of consumer %s to new input %s.",
                    old_tensor_name,
                    consumer.name,
                    new_tensor,
                )

    def _update_sliced_graph_consumer(
        self, sliced_graph_output: Tensor, new_get_next: Operation, next_offset: int
    ) -> None:
        """Update the consumer of the output tensors of the sliced subgraph to the new one.

        The outputs of the sliced subgraph are not the original outputs of 'IteratorGetNext'. Thus, next offset should
        trace the last index of outputs of new 'IteratorGetNext'.

        Args:
            sliced_graph_output: The output tensor of the sliced subgraph.
            new_get_next: The new 'IteratorGetNext' operation.
            next_offset: The last offset of the new 'IteratorGetNext' operation.
        """
        new_tensor_name = f"{new_get_next.name}:{next_offset}"
        new_tensor = self._full_graph.get_tensor_by_name(new_tensor_name)
        old_tensor_consumers = self._get_tensor_consumers(sliced_graph_output)
        for consumer in old_tensor_consumers:
            if consumer.type in NoGradSubgraphSlicer._INVALID_CONSUMER_OP_TYPE:
                logger.debug("Ignore invalid consumer: %s.", consumer.name)
                continue
            for i, t in enumerate(consumer.inputs):
                if t != sliced_graph_output:
                    logger.debug(
                        "Ignore input %s of consumer %s, cause it not output of sliced graph.",
                        t.name,
                        consumer.name,
                    )
                    continue
                consumer._update_input(i, new_tensor)
                logger.debug(
                    "Succeed replace old input %s of consumer %s to new input %s.",
                    sliced_graph_output,
                    consumer.name,
                    new_tensor,
                )
@para_checker_decorator(
    check_option_list=[
        ("op_types", ClassValidator, {"classes": (list,)}),
        ("full_graph", ClassValidator, {"classes": (Graph, type(None))}),
        ("info_dir", ClassValidator, {"classes": (str,)}),
    ]
)
class LookupSubgraphSlicer(NoGradSubgraphSlicer):
    """Slicer that targets operations of given types found upstream of the sparse lookup entrances."""

    def __init__(self, op_types: List[str], full_graph: Graph = None, info_dir: str = "lookup_slicing") -> None:
        """Initialize LookupSubgraphSlicer.

        Args:
            op_types: The list of operation types to be sliced in lookup subgraph.
            full_graph: The full graph to be sliced. If None, the default graph will be used.
            info_dir: The directory to save the slicing information. Defaults to "lookup_slicing".

        Raises:
            ValueError: If `op_types` is empty.
        """
        super().__init__(full_graph, info_dir)
        if not op_types:
            raise ValueError("no slicing operation types specified!")
        self._op_types = set(op_types)

    def summarize(self) -> None:
        """Write a per-type summary (total, sliceable, ratio) to CSV in the info directory and log it.

        A type with no operation in the graph is reported with a ratio of 0.0 and a warning.
        """
        all_tgt_ops = self._find_all_tgt_ops()
        (train_sliceable_tgt_ops, eval_sliceable_tgt_ops) = self._find_sliceable_tgt_ops()
        all_sliceable_tgt_ops = train_sliceable_tgt_ops | eval_sliceable_tgt_ops
        result = {"Operation Type": [], "Total Num": [], "Sliceable Num": [], "Sliceable Ratio": []}
        # Sorted so that the row order of the summary does not depend on set iteration order.
        for op_type in sorted(self._op_types):
            total_num = sum(1 for op in all_tgt_ops if op.type == op_type)
            sliceable_num = sum(1 for op in all_sliceable_tgt_ops if op.type == op_type)
            if total_num:
                sliceable_ratio = sliceable_num / total_num
            else:
                # Previously the ratio was left unassigned here, which raised on the first type or silently
                # reused the ratio of the previous type.
                sliceable_ratio = 0.0
                logger.warning("No target operaiton types '%s' found in given graph.", op_type)
            result["Operation Type"].append(op_type)
            result["Total Num"].append(total_num)
            result["Sliceable Num"].append(sliceable_num)
            result["Sliceable Ratio"].append(sliceable_ratio)
        result_df = pd.DataFrame(data=result)
        summary_path = os.path.join(self._info_dir, NoGradSubgraphSlicer._SLICING_SUMMARY_NAME)
        result_df.to_csv(summary_path, sep=",")
        logger.info("Summary of slicing:\n%s", result_df)

    def slice(self) -> None:
        """Slice the training and evaluation lookup subgraphs, dumping the full graph before and after."""
        utils.export_pb_graph(
            file_name=NoGradSubgraphSlicer._UNSLICED_FULL_GRAPH_NAME,
            dump_graph=True,
            graph_def=self._full_graph.as_graph_def(),
            export_path=self._info_dir,
        )
        (train_sliceable_ops, eval_sliceable_ops) = self._find_sliceable_tgt_ops()
        if train_sliceable_ops:
            logger.info("Start to slice training lookup subgraph.")
            self._slice_ops(train_sliceable_ops, is_training=True)
        if eval_sliceable_ops:
            logger.info("Start to slice evaluation lookup subgraph.")
            self._slice_ops(eval_sliceable_ops, is_training=False)
        utils.export_pb_graph(
            file_name=NoGradSubgraphSlicer._SLICED_FULL_GRAPH_NAME,
            dump_graph=True,
            graph_def=self._full_graph.as_graph_def(),
            export_path=self._info_dir,
        )

    def _find_all_tgt_ops(self) -> Set[Operation]:
        """Find all operations of the configured types in the full graph."""
        return {op for op in self._full_graph.get_operations() if op.type in self._op_types}

    def _find_sliceable_tgt_ops(self) -> Tuple[Set[Operation], Set[Operation]]:
        """Find sliceable operations of the configured types in the lookup subgraphs.

        The lookup entrance tensors are split into training and evaluation by their anchor attribute; for each
        group, the upstream operations that pass `_validate_op` and have a configured type are kept.

        Returns:
            A tuple `(train_sliceable_ops, eval_sliceable_ops)`.
        """
        lookup_keys = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
        train_base_ops = set()
        eval_base_ops = set()
        for t in lookup_keys:
            if BaseSparseEmbedding.get_anchor_attribute(t, ASCAnchorAttr.IS_TRAINING):
                train_base_ops.add(t.op)
            else:
                eval_base_ops.add(t.op)

        def find_sliceable_ops(base_ops):
            sliceable_ops = set()
            for op in self._find_min_dep_ops(base_ops):
                # Validation runs first so that its warnings are emitted for every upstream operation.
                if not self._validate_op(op):
                    continue
                if op.type not in self._op_types:
                    continue
                sliceable_ops.add(op)
            return sliceable_ops

        train_sliceable_ops = find_sliceable_ops(train_base_ops)
        eval_sliceable_ops = find_sliceable_ops(eval_base_ops)
        logger.debug("Found sliceable operations in training lookup subgraph: %s.", train_sliceable_ops)
        logger.debug("Found sliceable operations in evaluation lookup subgraph: %s.", eval_sliceable_ops)
        return (train_sliceable_ops, eval_sliceable_ops)
@para_checker_decorator(
    check_option_list=[
        ("full_graph", ClassValidator, {"classes": (Graph, type(None))}),
        ("info_dir", ClassValidator, {"classes": (str,)}),
    ]
)
class OrphanLookupKeySlicer(NoGradSubgraphSlicer):
    """Slicer that targets operations whose name contains the orphan lookup key prefix."""

    def __init__(self, full_graph: Graph = None, info_dir: str = "orphan_slicing") -> None:
        """Initialize OrphanLookupKeySlicer.

        Args:
            full_graph: The full graph to be sliced. If None, the default graph will be used.
            info_dir: The directory to save the slicing information. Defaults to "orphan_slicing".
        """
        super().__init__(full_graph, info_dir)

    def summarize(self) -> None:
        """Write the type and name of every sliceable training operation to CSV and log the table.

        Nothing is written when there is no sliceable training operation.
        """
        (train_sliceable_ops, _) = self._find_sliceable_tgt_ops()
        if not train_sliceable_ops:
            return
        result = {"Operation Type": [], "Operation Name": []}
        for op in train_sliceable_ops:
            result["Operation Type"].append(op.type)
            result["Operation Name"].append(op.name)
        result_df = pd.DataFrame(data=result)
        summary_path = os.path.join(self._info_dir, NoGradSubgraphSlicer._SLICING_SUMMARY_NAME)
        result_df.to_csv(summary_path, sep=",")
        logger.info("Summary of slicing:\n%s", result_df)

    def slice(self) -> None:
        """Slice the training and evaluation lookup subgraphs, dumping the full graph before and after."""
        utils.export_pb_graph(
            file_name=NoGradSubgraphSlicer._UNSLICED_FULL_GRAPH_NAME,
            dump_graph=True,
            graph_def=self._full_graph.as_graph_def(),
            export_path=self._info_dir,
        )
        (train_sliceable_ops, eval_sliceable_ops) = self._find_sliceable_tgt_ops()
        if train_sliceable_ops:
            logger.info("Start to slice training lookup subgraph.")
            self._slice_ops(train_sliceable_ops, is_training=True)
        if eval_sliceable_ops:
            logger.info("Start to slice evaluation lookup subgraph.")
            self._slice_ops(eval_sliceable_ops, is_training=False)
        utils.export_pb_graph(
            file_name=NoGradSubgraphSlicer._SLICED_FULL_GRAPH_NAME,
            dump_graph=True,
            graph_def=self._full_graph.as_graph_def(),
            export_path=self._info_dir,
        )

    def _slice_ops(self, sliceable_ops: Set[Operation], is_training: bool) -> None:
        """Override the '_slice_ops' protected method of super class.

        Instead of searching upwards from the sliceable operations, the old 'IteratorGetNext' is picked from the
        'IteratorGetNext' operations of the full graph that are not in the deprecated collection, sorted by name:
        the first one for training (or when only one exists), the second one for evaluation.

        Args:
            sliceable_ops: The operation set that can be sliced.
            is_training: Whether the slicing is for training graph or not.

        Raises:
            RuntimeError: If the full graph has no alive 'IteratorGetNext' operation.
        """
        sliced_ops = self._find_min_dep_ops(sliceable_ops)
        in_op_to_edge_ops, out_op_to_edge_ops = self._find_subgraph_in_and_out(sliced_ops)
        # Fetch the collection once instead of once per candidate operation.
        deprecated_get_nexts = self._full_graph.get_collection(DeprecatedOp.DEPRECATED_ITERATOR_GET_NEXT)
        alive_get_nexts = sorted(
            (
                op
                for op in self._full_graph.get_operations()
                if op.type == AnchorIteratorOp.ITERATOR_GET_NEXT.value and op not in deprecated_get_nexts
            ),
            key=lambda op: op.name,
        )
        if not alive_get_nexts:
            # Fail with an explicit message rather than an IndexError from the selection below.
            raise RuntimeError("no alive 'IteratorGetNext' operation found in the full graph.")
        if len(alive_get_nexts) == 1 or is_training:
            old_get_next = alive_get_nexts[0]
        else:
            old_get_next = alive_get_nexts[1]
        old_dataset = self._find_old_dataset(old_get_next, is_training)
        new_dataset = self._make_new_dataset(old_dataset, sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops)
        new_dataset = new_dataset.prefetch(0)
        new_get_next = self._make_new_get_next(old_get_next, new_dataset)
        self._replace_get_next(old_get_next, new_get_next, out_op_to_edge_ops, sliced_ops)

    def _find_sliceable_tgt_ops(self) -> Tuple[Set[Operation], Set[Operation]]:
        """Find the orphan-key operations in the lookup subgraphs.

        The lookup entrance tensors are split into training and evaluation by their anchor attribute; for each
        group, the upstream operations that pass `_validate_op` and whose name contains
        `ORPHAN_LOOKUP_KEY_PREFIX` are kept.

        Returns:
            A tuple `(train_sliceable_ops, eval_sliceable_ops)`.
        """
        lookup_keys = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
        train_base_ops = set()
        eval_base_ops = set()
        for t in lookup_keys:
            if BaseSparseEmbedding.get_anchor_attribute(t, ASCAnchorAttr.IS_TRAINING):
                train_base_ops.add(t.op)
            else:
                eval_base_ops.add(t.op)

        def find_sliceable_ops(base_ops):
            sliceable_ops = set()
            for op in self._find_min_dep_ops(base_ops):
                # Validation runs first so that its warnings are emitted for every upstream operation.
                if not self._validate_op(op):
                    continue
                if ORPHAN_LOOKUP_KEY_PREFIX not in op.name:
                    continue
                sliceable_ops.add(op)
            return sliceable_ops

        train_sliceable_ops = find_sliceable_ops(train_base_ops)
        eval_sliceable_ops = find_sliceable_ops(eval_base_ops)
        logger.debug("Found sliceable operations in training lookup subgraph: %s.", train_sliceable_ops)
        logger.debug("Found sliceable operations in evaluation lookup subgraph: %s.", eval_sliceable_ops)
        return (train_sliceable_ops, eval_sliceable_ops)