import unittest
from unittest import mock
from unittest.mock import patch
import tensorflow as tf
from mx_rec.util.global_env_conf import global_env
from mx_rec.util.variable import get_dense_and_sparse_variable
from core.mock_class import MockConfigInitializer
class MockTableInstance:
    """Minimal table-instance stand-in exposing only an ``is_hbm`` flag, always False."""

    def __init__(self):
        super().__init__()
        # NOTE(review): not referenced in this chunk; presumably consumed by other
        # tests that need a table object whose ``is_hbm`` is False.
        self.is_hbm = False
@patch.multiple(
    "mx_rec.graph.patch",
    ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()),
)
class VariableTest(unittest.TestCase):
    """Tests for ``mx_rec.util.variable.get_dense_and_sparse_variable``."""

    def setUp(self):
        """
        Set-up step: save the global env values this test overrides, override
        them, and start from an empty default graph.
        :return: None
        """
        self.cm_worker_size = global_env.cm_worker_size
        self.cm_chief_device = global_env.cm_chief_device
        global_env.cm_worker_size = "8"
        global_env.cm_chief_device = "0"
        # Start from a clean graph so variables/collections left behind by other
        # tests cannot leak into what get_dense_and_sparse_variable returns.
        tf.compat.v1.reset_default_graph()

    def tearDown(self):
        """
        Tear-down step: restore the saved global env values and discard the
        graph built by the test.
        :return: None
        """
        global_env.cm_worker_size = self.cm_worker_size
        global_env.cm_chief_device = self.cm_chief_device
        # Reset here rather than at the end of the test body so it also runs
        # when an assertion fails.
        tf.compat.v1.reset_default_graph()

    @mock.patch("mx_rec.util.variable.ConfigInitializer")
    def test_get_dense_and_sparse_variable(self, variable_config_initializer):
        """
        Register one trainable (dense) and one non-trainable (sparse) variable,
        then check that the dense variables returned by
        get_dense_and_sparse_variable compare equal to the dense one.
        :return: None
        """
        mock_config_initializer = MockConfigInitializer(ascend_global_hashtable_collection="sparse_hastable")
        variable_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
        dense_layer = tf.Variable([1, 2], trainable=True)
        sparse_emb = tf.Variable([4, 5], trainable=False)
        # The collection name must match ascend_global_hashtable_collection above.
        tf.compat.v1.add_to_collection("sparse_hastable", sparse_emb)
        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, dense_layer)
        dense_variables, sparse_variables = get_dense_and_sparse_variable()
        with tf.compat.v1.Session() as sess:
            result = tf.reduce_all(tf.equal(dense_layer, dense_variables))
            sess.run(tf.compat.v1.global_variables_initializer())
            # Fetch the tensor itself, not ``[result]``: a one-element list is
            # always truthy, which made the assertion below pass unconditionally.
            result_run = sess.run(result)
        self.assertTrue(result_run)
if __name__ == "__main__":
    # Allow running this module directly with the standard unittest runner.
    unittest.main()