import torch
import torch_npu
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.utils import disable_cache_limit
class ConfigTests(torch._dynamo.test_case.TestCase):
@disable_cache_limit()
def test_no_automatic_dynamic(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_static = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn = torch._dynamo.optimize(cnt_static)(fn)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_static.frame_count, 10)
@disable_cache_limit()
def test_automatic_dynamic(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=True
):
opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic.frame_count, 2)
@disable_cache_limit()
def test_no_assume_static_by_default(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic.frame_count, 1)
def test_config_compile_ignored(self):
dynamo_guarded_config_ignorelist = {
"log_file_name",
"verbose",
"verify_correctness",
"cache_size_limit",
"accumulated_cache_size_limit",
"replay_record_enabled",
"cprofile",
"repro_after",
"repro_level",
"repro_forward_only",
"repro_tolerance",
"same_two_models_use_fp64",
"error_on_recompile",
"report_guard_failures",
"base_dir",
"DEBUG_DIR_VAR_NAME",
"debug_dir_root",
}
for k in dynamo_guarded_config_ignorelist:
assert k in torch._dynamo.config._compile_ignored_keys
def test_config_hash(self):
config = torch._dynamo.config
starting_hash = config.get_hash()
with config.patch({"verbose": not config.verbose}):
new_hash = config.get_hash()
assert "verbose" in config._compile_ignored_keys
assert new_hash == starting_hash
new_hash = config.get_hash()
assert new_hash == starting_hash
with config.patch({"dead_code_elimination": not config.dead_code_elimination}):
changed_hash = config.get_hash()
assert "dead_code_elimination" not in config._compile_ignored_keys
assert changed_hash != starting_hash
with config.patch({"verbose": not config.verbose}):
inner_changed_hash = config.get_hash()
assert inner_changed_hash == changed_hash
assert inner_changed_hash != starting_hash
newest_hash = config.get_hash()
assert changed_hash != newest_hash
assert newest_hash == starting_hash
@disable_cache_limit()
def test_no_saved_config(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn_static_shape = torch._dynamo.optimize(
cnt_dynamic, save_config=False
)(fn)
opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 2)
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
for i in range(2, 12):
opt_fn_static_shape(
torch.randn(i), torch.randn(i)
)
self.assertEqual(cnt_dynamic.frame_count, 3)
@disable_cache_limit()
def test_no_saved_config_nested(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
cnt_dynamic_1 = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(fn)
opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 2)
opt_fn_try_dynamic = torch._dynamo.optimize(
cnt_dynamic_1, save_config=False
)(opt_fn_static_shape)
for i in range(2, 6):
opt_fn_try_dynamic(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic_1.frame_count, 1)
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
for i in range(6, 12):
opt_fn_try_dynamic(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic_1.frame_count, 7)
@disable_cache_limit()
def test_config_changed_from_guarded_config_1(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn)
res = opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 2)
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
for i in range(2, 12):
opt_fn_static_shape(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_dynamic.frame_count, 10)
@disable_cache_limit()
def test_config_changed_from_guarded_config_2(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn_dynamic_shape = torch._dynamo.optimize(cnt_dynamic)(fn)
opt_fn_dynamic_shape(torch.randn(2), torch.randn(2))
opt_fn_dynamic_shape(torch.randn(3), torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 1)
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
for i in range(2, 12):
opt_fn_dynamic_shape(
torch.randn(i), torch.randn(i)
)
self.assertEqual(cnt_dynamic.frame_count, 1)
@disable_cache_limit()
def test_nested_compile_outer_wins(self):
def fn(a, b):
return a - b * 10
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
cnt_dynamic_1 = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic)(fn)
opt_fn_static_shape(torch.randn(2), torch.randn(2))
opt_fn_static_shape(torch.randn(3), torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 2)
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1)(
lambda x, y: opt_fn_static_shape(x, y)
)
for i in range(2, 12):
opt_fn_dynamic(
torch.randn(i), torch.randn(i)
)
self.assertEqual(cnt_dynamic.frame_count, 2)
self.assertEqual(cnt_dynamic_1.frame_count, 1)
@disable_cache_limit()
def test_nested_fn_does_not_inherit_outer_config(self):
def g1(x):
return x + 1
def g2(x):
return x * 2
def f(x):
x = g1(x)
torch._dynamo.graph_break()
return g2(x)
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
cnt_dynamic_1 = torch._dynamo.testing.CompileCounter()
opt_fn_static_shape = torch._dynamo.optimize(cnt_dynamic, dynamic=False)(f)
opt_fn_static_shape(torch.randn(2))
opt_fn_static_shape(torch.randn(3))
self.assertEqual(cnt_dynamic.frame_count, 4)
opt_fn_dynamic = torch._dynamo.optimize(cnt_dynamic_1, dynamic=True)(g2)
for i in range(2, 12):
opt_fn_dynamic(
torch.randn(i),
)
self.assertEqual(cnt_dynamic_1.frame_count, 1)
@disable_cache_limit()
def test_multiple_compile_recompiles(self):
cnt_dynamic = torch._dynamo.testing.CompileCounter()
def f(dynamic, compile_count):
@torch._dynamo.optimize(cnt_dynamic, dynamic=dynamic)
def g(x):
return x + 1
for i in range(2, 12):
g(torch.randn(i))
self.assertEqual(cnt_dynamic.frame_count, compile_count)
cnt_dynamic.clear()
f(dynamic=True, compile_count=1)
f(dynamic=False, compile_count=10)
f(dynamic=True, compile_count=0)
def test_cache_size_limit(self):
cnt = torch._dynamo.testing.CompileCounter()
key = "_ConfigTests___test_cache_size_limit_key"
try:
torch._dynamo.config._allowed_keys.add(key)
torch._dynamo.config._ConfigTests___test_cache_size_limit_key = -1
with torch._dynamo.config.patch(
{"cache_size_limit": 1, "accumulated_cache_size_limit": 10}
):
def g(x):
return x + 1
for i in range(12):
with torch._dynamo.config.patch(
{key: i % 6}
):
opt_g = torch._dynamo.optimize(cnt)(g)
opt_g(torch.randn(1))
self.assertEqual(cnt.frame_count, 6)
for i in range(6, 12):
with torch._dynamo.config.patch({key: i}):
opt_g = torch._dynamo.optimize(cnt)(g)
opt_g(torch.randn(1))
self.assertEqual(
cnt.frame_count, 10
)
finally:
if key in torch._dynamo.config._allowed_keys:
torch._dynamo.config._allowed_keys.remove(key)
if __name__ == "__main__":
    # Hand control to dynamo's unittest runner (module already imported above).
    torch._dynamo.test_case.run_tests()