import unittest
import torch
from tensor_cast.device import TEST_DEVICE
from tensor_cast.performance_model.base import PerformanceModel
from tensor_cast.performance_model.memory_tracker import MemoryTracker
from tensor_cast.performance_model.op_invoke_info import OpInvokeInfo
from tensor_cast.runtime import Runtime, RuntimeEvent
class TestMemoryTracker(unittest.TestCase):
    """Checks the per-op memory profile reported by ``MemoryTracker``.

    Every expected profile is a list of
    ``(usage_before_call_bytes, usage_after_call_bytes)`` pairs, one per
    profile entry, in order. The tests use 100-element tensors; the 400-byte
    steps in the expected values correspond to one such tensor (assuming
    torch's default float32 dtype).
    """

    def _assert_profile(self, profile, expected_profile_bytes):
        """Assert that ``profile`` matches ``expected_profile_bytes`` entry by entry.

        Shared by every test so that a mismatch always reports the index and
        name of the offending op and whether the 'before' or 'after' value
        differs.

        Args:
            profile: Sequence of profile entries, as returned by
                ``get_profile()``. Each entry exposes ``op_invoke_info``,
                ``usage_before_call_bytes`` and ``usage_after_call_bytes``.
            expected_profile_bytes: Sequence of ``(before, after)`` byte pairs.
        """
        self.assertEqual(
            len(profile),
            len(expected_profile_bytes),
            "Mismatch in the number of ops tracked.",
        )
        # Lengths are equal at this point, so zip() cannot drop any entry.
        for i, (op_profile, expected) in enumerate(
            zip(profile, expected_profile_bytes)
        ):
            expected_before, expected_after = expected
            info = op_profile.op_invoke_info
            if info is None:
                # Entries without an op are labelled "output" in messages.
                op_name = "output"
            else:
                # Fall back to repr() so that building the label can never
                # raise for a callable that lacks ``__name__``.
                op_name = getattr(info.func, "__name__", repr(info.func))
            self.assertEqual(
                op_profile.usage_before_call_bytes,
                expected_before,
                f"Op {i} ({op_name}): 'before' memory mismatch.",
            )
            self.assertEqual(
                op_profile.usage_after_call_bytes,
                expected_after,
                f"Op {i} ({op_name}): 'after' memory mismatch.",
            )

    def _run_and_check(self, test_func, input_tensors, expected_profile_bytes):
        """Run ``test_func`` under a tracked ``Runtime`` and check its profile.

        Args:
            test_func: Callable invoked as ``test_func(*input_tensors)``.
            input_tensors: Positional tensor arguments for ``test_func``.
            expected_profile_bytes: Expected ``(before, after)`` byte pairs,
                one per profile entry.
        """
        with (
            Runtime([], TEST_DEVICE, MemoryTracker(TEST_DEVICE)) as runtime,
            torch.no_grad(),
        ):
            _ = test_func(*input_tensors)

        # The profile is read once the runtime context has been exited.
        profile = runtime.memory_tracker.get_profile()
        self._assert_profile(profile, expected_profile_bytes)

    def test_simple_allocation(self):
        """Tests a basic op with a new memory allocation."""

        def func(x, y):
            return torch.add(x, y)

        x = torch.randn(100)
        y = torch.randn(100)
        expected_profile = [(800, 1200), (1200, 1200)]
        self._run_and_check(func, [x, y], expected_profile)

    def test_inplace_mutation(self):
        """Tests an in-place op that does not allocate memory."""

        def func(x, y):
            x.add_(y)
            return x

        x = torch.randn(100)
        y = torch.randn(100)
        expected_profile = [(800, 800), (800, 800)]
        self._run_and_check(func, [x, y], expected_profile)

    def test_view_op_alias(self):
        """Tests a view operation where the output aliases the input."""

        def func(x, y):
            v = x.view(-1)
            return torch.add(v, y)

        x = torch.randn(10, 10)
        y = torch.randn(100)
        expected_profile = [(800, 800), (800, 1200), (1200, 1200)]
        self._run_and_check(func, [x, y], expected_profile)

    def test_multi_output_alias(self):
        """Tests an op like `split` where multiple outputs alias one input."""

        def func(x):
            y, z = torch.split(x, 50)
            return torch.mul(y, 2), torch.sin(z)

        x = torch.randn(100)
        expected_profile = [(400, 400), (400, 600), (600, 800), (800, 800)]
        self._run_and_check(func, [x], expected_profile)

    def test_alias_chain(self):
        """Tests when a view is created from another view."""

        def func(x):
            y = x.transpose(0, 1)
            z = torch.select(y, 0, 0)
            return z * 2.0

        x = torch.randn(10, 10)
        expected_profile = [(400, 400), (400, 400), (400, 440), (440, 440)]
        self._run_and_check(func, [x], expected_profile)

    def test_mixed_lifetimes_and_aliases(self):
        """Tests a complex sequence with allocations, views, and varied tensor lifetimes."""

        def func(x, y, z):
            a = x + y
            b = a.view(-1)
            c = b * 2.0
            d = c + z
            return d

        x = torch.randn(100)
        y = torch.randn(100)
        z = torch.randn(100)
        expected_profile = [
            (1200, 1600),
            (1600, 1600),
            (1600, 2000),
            (1600, 2000),
            (1600, 1600),
        ]
        self._run_and_check(func, [x, y, z], expected_profile)

    def test_model_output_alias(self):
        """Ensures that an aliased tensor that is a model output is not freed."""

        def func(x):
            return x.view(-1)

        x = torch.randn(100)
        expected_profile = [(400, 400), (400, 400)]
        self._run_and_check(func, [x], expected_profile)

    def test_unused_alias_is_not_model_output_when_source_is_used_later(self):
        """Ensures dead alias outputs do not keep their source alive."""
        x = torch.randn(100)
        a = x + 1.0
        view_out = a.view(-1)
        mul_out = a * 2.0

        # The invocations are recorded by hand instead of being captured by a
        # Runtime, so the tracker only sees these three ops, in this order.
        add_info = OpInvokeInfo(torch.ops.aten.add.Tensor, (x, 1.0), {}, a)
        view_info = OpInvokeInfo(torch.ops.aten.view.default, (a, [-1]), {}, view_out)
        mul_info = OpInvokeInfo(torch.ops.aten.mul.Tensor, (a, 2.0), {}, mul_out)

        mt = MemoryTracker(TEST_DEVICE)
        mt.record_single_op_invocation(add_info)
        mt.record_single_op_invocation(view_info)
        mt.record_single_op_invocation(mul_info)
        mt.analyze()

        expected_profile = [(400, 800), (800, 800), (800, 1200), (800, 800)]
        self._assert_profile(mt.get_profile(), expected_profile)

    def test_alias_from_kwargs_input(self):
        """Ensures alias tracking works when tensor inputs are passed by kwargs."""

        def func(x):
            y = torch.ops.aten.view.default(self=x, size=[-1])
            return torch.add(input=y, other=1.0)

        x = torch.randn(100)
        expected_profile = [(400, 400), (400, 800), (800, 800)]
        self._run_and_check(func, [x], expected_profile)

    def test_multistream_wait_anchor_does_not_add_model_input_memory(self):
        """Ensures wait anchors do not turn their output into extra model input memory."""
        x = torch.randn(100)
        wait_out = x.clone()
        relu_out = torch.relu(x)
        neg_out = torch.neg(wait_out)

        wait_info = OpInvokeInfo(
            torch.ops.tensor_cast._internal_wait_and_bind.default,
            (x, 1, []),
            {},
            wait_out,
        )
        relu_info = OpInvokeInfo(torch.ops.aten.relu.default, (x,), {}, relu_out)
        neg_info = OpInvokeInfo(torch.ops.aten.neg.default, (wait_out,), {}, neg_out)

        runtime = Runtime([], TEST_DEVICE, MemoryTracker(TEST_DEVICE))
        # Replay a hand-built op list in which the wait anchor sits between
        # the relu and the neg that consumes the anchor's output.
        runtime.op_info_group = [relu_info, wait_info, neg_info]
        runtime.replay_op_invoke_infos()
        runtime.memory_tracker.analyze()

        expected_profile = [(400, 800), (800, 1200), (1200, 1200)]
        self._assert_profile(runtime.memory_tracker.get_profile(), expected_profile)

    def test_multistream_memory_keeps_def_use_order(self):
        """Ensures memory replay does not classify intermediates as inputs."""
        produced = torch.empty(100)
        consumed = torch.neg(produced)

        producer_info = OpInvokeInfo(
            torch.ops.aten.empty.memory_format,
            ([100],),
            {},
            produced,
        )
        consumer_info = OpInvokeInfo(
            torch.ops.aten.neg.default,
            (produced,),
            {},
            consumed,
        )

        runtime = Runtime([], TEST_DEVICE, MemoryTracker(TEST_DEVICE))
        # The producer and consumer events are placed on different streams
        # (1 and 0); the producer is listed first.
        runtime.event_list = [
            RuntimeEvent(
                op_invoke_info=producer_info,
                perf_results={"analytic": PerformanceModel.Result(10.0)},
                stream_id=1,
            ),
            RuntimeEvent(
                op_invoke_info=consumer_info,
                perf_results={"analytic": PerformanceModel.Result(1.0)},
                stream_id=0,
            ),
        ]
        runtime._event_reference_ids = [0, 0]
        runtime._record_memory_invocations()
        runtime.memory_tracker.analyze()

        expected_profile = [(0, 400), (400, 800), (400, 400)]
        self._assert_profile(runtime.memory_tracker.get_profile(), expected_profile)