import os
import pathlib
import shutil
import tempfile
import unittest
from unittest import TestCase

import tensorflow as tf
from tensorflow import Tensor, TensorSpec

from mx_rec.constants.constants import ASCAnchorAttr
from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
from mx_rec.graph.utils import (
    find_trans_dataset,
    find_parent_op,
    find_make_iterator_op,
    find_target_instance_dataset,
    upward_bfs_op,
    check_and_force_list,
    check_cutting_points,
    export_pb_graph,
    make_sorted_key_to_tensor_list,
    replace_anchor_vec,
)

from graph.mock_dataset import gen_mock_dataset
class FindTransDatasetTest(TestCase):
    """Tests for ``find_trans_dataset``."""

    def setUp(self) -> None:
        # All ops in these tests are built into the default graph.
        self._graph = tf.compat.v1.get_default_graph()

    def tearDown(self) -> None:
        # Give the next test a fresh default graph.
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        # dataset -> initializable iterator -> get_next; search from the op
        # that produces the "mock_ids" feature.
        iterator = gen_mock_dataset().make_initializable_iterator()
        ids_tensor = iterator.get_next().get("mock_ids")
        dataset_op = find_trans_dataset(self._graph, ids_tensor.op)
        self.assertEqual(dataset_op.type, "OptimizeDataset")

    def test_err_invalid_op_type(self):
        # Starting from an op of the wrong type (the op behind tf.zeros)
        # is expected to raise TypeError.
        wrong_op = tf.zeros(shape=(4096, 8)).op
        with self.assertRaises(TypeError):
            find_trans_dataset(self._graph, wrong_op)
class FindParentOpTest(TestCase):
    """Tests for ``find_parent_op``."""

    def tearDown(self):
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        # const -> identity: the identity op's parent list is just the const op.
        source = tf.constant([1, 2, 3], dtype=tf.int64)
        child = tf.identity(source)
        parents = find_parent_op(child.op)
        self.assertEqual([source.op], parents)
class FindMakeIteratorOpTest(TestCase):
    """Tests for ``find_make_iterator_op``."""

    def setUp(self) -> None:
        self._graph = tf.compat.v1.get_default_graph()

    def tearDown(self) -> None:
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        # Starting from a tensor fetched via get_next, the MakeIterator op of
        # the initializable iterator should be found.
        batch = gen_mock_dataset().make_initializable_iterator().get_next()
        iter_op = find_make_iterator_op(self._graph, batch.get("mock_ids"))
        self.assertEqual(iter_op.type, "MakeIterator")

    def test_err_no_tgt_dataset_op(self):
        # A tensor with no dataset upstream is expected to raise ValueError.
        unrelated = tf.zeros(shape=(4096, 8))
        with self.assertRaises(ValueError):
            find_make_iterator_op(self._graph, unrelated)
class FindTargetInstanceDatasetTest(TestCase):
    """Tests for ``find_target_instance_dataset``."""

    def setUp(self) -> None:
        self._graph = tf.compat.v1.get_default_graph()

    def tearDown(self) -> None:
        tf.compat.v1.reset_default_graph()

    def test_err_no_target_dataset_instance(self):
        # Passing None as the target is expected to raise LookupError.
        missing_target = None
        with self.assertRaises(LookupError):
            find_target_instance_dataset(self._graph, missing_target)
class UpwardBFSOpTest(TestCase):
    """Tests for ``upward_bfs_op`` (upward search for an op of a given type)."""

    def setUp(self) -> None:
        self._graph = tf.compat.v1.get_default_graph()

    def tearDown(self) -> None:
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        # identity(get_next()["mock_ids"]): searching upward from the identity
        # op for "IteratorGetNext" must land on the op producing mock_ids.
        ids = gen_mock_dataset().make_initializable_iterator().get_next().get("mock_ids")
        start_op = tf.identity(ids).op
        get_next_op = upward_bfs_op(base_ops=start_op, tgt_op_type="IteratorGetNext")
        self.assertEqual(get_next_op, ids.op)

    def test_err_no_tgt_op_type(self):
        # No IteratorGetNext op above a tf.zeros op -> ValueError expected.
        start_op = tf.zeros(shape=(4096, 8)).op
        with self.assertRaises(ValueError):
            upward_bfs_op(base_ops=start_op, tgt_op_type="IteratorGetNext")
class CheckCuttingPointsTest(TestCase):
    """Tests for ``check_cutting_points``."""

    def setUp(self):
        # How many cutting points each test feeds to the checker.
        self._generator_iter_times = 3

    def tearDown(self):
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        # Outputs of tf.identity are accepted; passing means "does not raise".
        points = [tf.identity(tf.zeros(shape=(1,))) for _ in range(self._generator_iter_times)]
        check_cutting_points(points)

    def test_err_invalid_cutting_point_list(self):
        # Plain strings instead of tensors -> TypeError expected.
        points = ["point"] * self._generator_iter_times
        with self.assertRaises(TypeError):
            check_cutting_points(points)

    def test_err_invalid_cutting_point_operation(self):
        # Tensors not produced by tf.identity (here tf.zeros) -> ValueError
        # expected; presumably the checker requires Identity ops.
        points = [tf.zeros(shape=(1,)) for _ in range(self._generator_iter_times)]
        with self.assertRaises(ValueError):
            check_cutting_points(points)
class CheckAndForceListTest(TestCase):
    """Tests for ``check_and_force_list``."""

    def tearDown(self):
        tf.compat.v1.reset_default_graph()

    def test_ok_single_object(self):
        # A lone object of the expected type gets wrapped in a list.
        single = "obj"
        checked_objs = check_and_force_list(single, str)
        self.assertEqual([single], checked_objs)

    def test_ok_object_list(self):
        # A list whose items all match the type comes back unchanged.
        items = ["obj1", "obj2", "ojb3"]
        checked_objs = check_and_force_list(items, str)
        self.assertEqual(items, checked_objs)

    def test_err_inconsistent_object_and_type(self):
        # Strings checked against Tensor -> ValueError expected.
        items = ["obj1", "obj2", "ojb3"]
        with self.assertRaises(ValueError):
            check_and_force_list(items, Tensor)
class ExportPBGraphTest(TestCase):
    """Tests for ``export_pb_graph``: the graph file must appear on disk."""

    def setUp(self) -> None:
        # Export into a fresh per-test temporary root instead of the CWD.
        # This avoids collisions between parallel runs. It also ensures a
        # file left behind by an earlier killed run cannot make the
        # assertion pass spuriously.
        # The "export_graph" subdirectory itself does not exist yet, mirroring
        # the original "./export_graph" setup.
        self._tmp_root = tempfile.mkdtemp()
        self._dir_name = os.path.join(self._tmp_root, "export_graph")

    def tearDown(self) -> None:
        tf.compat.v1.reset_default_graph()
        shutil.rmtree(self._tmp_root, ignore_errors=True)

    def test_ok(self):
        mock_file_name = "test_graph.pbtxt"
        dump_graph = True
        mock_graph_def = tf.Graph().as_graph_def()
        as_text = True
        export_pb_graph(mock_file_name, dump_graph, mock_graph_def, self._dir_name, as_text)
        path = pathlib.Path(self._dir_name) / mock_file_name
        self.assertTrue(path.is_file())
class MakeSortedKeyToTensorListTest(TestCase):
    """Tests for ``make_sorted_key_to_tensor_list``."""

    def tearDown(self) -> None:
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        def int64_spec(*shape):
            # All mock features share dtype int64; only the shape varies.
            return TensorSpec(shape=shape, dtype=tf.int64)

        batch_spec = {
            "item_ids": int64_spec(4096, 16),
            "user_ids": int64_spec(4096, 8),
            "category_ids": int64_spec(4096, 3),
            "label_0": int64_spec(4096),
            "label_1": int64_spec(4096),
            "user_ids_last_key": int64_spec(4096, 16),
            "user_ids_last_key_last_key": int64_spec(4096, 8),
        }
        # Each expected name extends the previous one with the next key, in
        # dict insertion order, after "mock_prefix_0" (the "0" is presumably
        # the element index within element_spec).
        expected = [
            "mock_prefix_0_item_ids",
            "mock_prefix_0_item_ids_user_ids",
            "mock_prefix_0_item_ids_user_ids_category_ids",
            "mock_prefix_0_item_ids_user_ids_category_ids_label_0",
            "mock_prefix_0_item_ids_user_ids_category_ids_label_0_label_1",
            "mock_prefix_0_item_ids_user_ids_category_ids_label_0_label_1_user_ids_last_key",
            "mock_prefix_0_item_ids_user_ids_category_ids_label_0_label_1_user_ids_last_key_user_ids_last_key_last_key",
        ]
        element_spec = [batch_spec]
        sorted_keys = []
        prefix = "mock_prefix"
        result = make_sorted_key_to_tensor_list(element_spec, sorted_keys, prefix)
        self.assertEqual(result, expected)
class ReplaceAnchorVecTest(TestCase):
    """Tests for ``replace_anchor_vec``: consumers of the registered anchor
    tensor should be rewired to read the replacement tensor instead."""

    def tearDown(self):
        tf.compat.v1.reset_default_graph()

    def test_ok(self):
        mock_cutting_point = tf.zeros(shape=(4096, 8), dtype=tf.int64, name="ids")
        mock_attribute = ASCAnchorAttr.MOCK_LOOKUP_RESULT
        mock_anchor = tf.zeros(shape=(4096, 8), dtype=tf.float32, name="anchor")
        anchor_vec = tf.identity(mock_cutting_point, name="anchor_vec")
        anchor_vec_output = tf.identity(anchor_vec, name="anchor_vec_output")
        # Register anchor_vec under (cutting point, attribute) in the
        # class-level registry. Indexing a fresh tensor key works, so the
        # registry presumably auto-creates inner dicts (defaultdict-like).
        BaseSparseEmbedding.anchor_tensor_specs[mock_cutting_point][mock_attribute] = anchor_vec
        # The registry is shared class-level state. Remove this test's entry
        # afterwards so it neither leaks into other tests nor keeps the
        # (reset) graph's tensors alive.
        self.addCleanup(BaseSparseEmbedding.anchor_tensor_specs.pop, mock_cutting_point, None)
        replace_anchor_vec(tf.compat.v1.get_default_graph(), mock_cutting_point, mock_attribute, mock_anchor)
        # anchor_vec's consumer must now take mock_anchor as its input.
        self.assertEqual(anchor_vec_output.op.inputs[0], mock_anchor)
if __name__ == "__main__":
    # Run this module's tests when executed directly; "__main__" is also
    # unittest.main's default module, written out here for clarity.
    unittest.main(module="__main__")