import unittest
from unittest import mock
from mx_rec.graph.slicers import NoGradSubgraphSlicer
from mx_rec.graph.hooks import LookupSubgraphSlicerHook, OrphanLookupKeySlicerHook
class MockLookupSubgraphSlicer(NoGradSubgraphSlicer):
    """Inert ``NoGradSubgraphSlicer`` subclass used as a test double.

    Instances are handed out by the patched ``LookupSubgraphSlicer`` factory
    in the hook tests; ``summarize`` and ``slice`` are overridden as no-ops.
    """

    def __init__(self, op_types) -> None:
        # ``op_types`` is accepted and intentionally ignored: it is neither
        # stored on the instance nor forwarded to the base initializer.
        super().__init__()

    def summarize(self) -> None:
        """Do nothing and return ``None``."""

    def slice(self) -> None:
        """Do nothing and return ``None``."""
class MockOrphanLookupKeySlicer(NoGradSubgraphSlicer):
    """Inert ``NoGradSubgraphSlicer`` subclass used as a test double.

    Instances are handed out by the patched ``OrphanLookupKeySlicer`` factory
    in the hook tests; ``summarize`` and ``slice`` are overridden as no-ops.
    """

    def __init__(self) -> None:
        # Only the base-class initialization runs; no extra state is added.
        super().__init__()

    def summarize(self) -> None:
        """Do nothing and return ``None``."""

    def slice(self) -> None:
        """Do nothing and return ``None``."""
class TestLookupSubgraphSlicerHook(unittest.TestCase):
    """Smoke test for ``LookupSubgraphSlicerHook`` with its slicer patched."""

    # The replacement MagicMock (and the mock slicer it returns) is built
    # once, when this decorator expression is evaluated, and swapped in for
    # ``mx_rec.graph.hooks.LookupSubgraphSlicer`` for the duration of the test.
    @mock.patch(
        "mx_rec.graph.hooks.LookupSubgraphSlicer",
        new=mock.MagicMock(return_value=MockLookupSubgraphSlicer(["xxx"])),
    )
    def test_ok(self):
        """Build the hook and call ``begin()``; the test fails if either raises."""
        slicer_hook = LookupSubgraphSlicerHook(["xxx"])
        slicer_hook.begin()
        self.assertIsNotNone(slicer_hook)
class TestOrphanLookupKeySlicerHook(unittest.TestCase):
    """Smoke test for ``OrphanLookupKeySlicerHook`` with its slicer patched."""

    # The replacement MagicMock (and the mock slicer it returns) is built
    # once, when this decorator expression is evaluated, and swapped in for
    # ``mx_rec.graph.hooks.OrphanLookupKeySlicer`` for the duration of the test.
    @mock.patch(
        "mx_rec.graph.hooks.OrphanLookupKeySlicer",
        new=mock.MagicMock(return_value=MockOrphanLookupKeySlicer()),
    )
    def test_ok(self):
        """Build the hook and call ``begin()``; the test fails if either raises."""
        slicer_hook = OrphanLookupKeySlicerHook()
        slicer_hook.begin()
        self.assertIsNotNone(slicer_hook)