import copy
import functools
import unittest
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn as nn
from torch._dynamo.test_case import TestCase
from torch.overrides import TorchFunctionMode
FilterFn = Callable[[list], list[bool]]
_DICT_GUARD_TYPES = frozenset(
{
"DICT_VERSION",
"DICT_KEYS",
"DICT_KEYS_MATCH",
"DICT_CONTAINS",
}
)
_OPTIONAL_TYPE_GUARD_TYPES = frozenset(
{
"TYPE_MATCH",
"OPTIONAL_TENSOR",
}
)
_HASATTR_GUARD_TYPES = frozenset(
{
"HASATTR",
"NOT_PRESENT_IN_GENERIC_DICT",
}
)
_RUNTIME_STATE_GUARD_TYPES = frozenset(
{
"GRAD_MODE",
"TORCH_FUNCTION_STATE",
"GLOBAL_STATE",
"DEFAULT_DEVICE",
"DETERMINISTIC_ALGORITHMS",
"AUTOCAST_STATE",
"FSDP_TRAINING_STATE",
}
)
_ORIGINAL_TORCH_COMPILE = None
_TORCH_COMPILE_STACK = []
def _entry_type(entry) -> str:
return getattr(entry, "guard_type", "") or ""
def make_npu_guard_filter(
*,
disable_dict_version: bool = False,
disable_optional_type: bool = False,
disable_hasattr: bool = False,
disable_runtime_state: bool = False,
) -> FilterFn:
def _filter(entries) -> list[bool]:
keep = [True] * len(entries)
for i, e in enumerate(entries):
t = _entry_type(e)
if disable_dict_version and t in _DICT_GUARD_TYPES:
keep[i] = False
continue
if disable_optional_type and t in _OPTIONAL_TYPE_GUARD_TYPES:
keep[i] = False
continue
if disable_hasattr and t in _HASATTR_GUARD_TYPES:
keep[i] = False
continue
if disable_runtime_state and t in _RUNTIME_STATE_GUARD_TYPES:
keep[i] = False
continue
return keep
return _filter
def _supports_dynamo_guard_filter_config() -> bool:
import torch._dynamo.config as cfg
return hasattr(cfg, "_config") and "guard_filter_fn" in cfg._config
def _set_wrapped_torch_compile(filter_fn: FilterFn | None) -> None:
global _ORIGINAL_TORCH_COMPILE
if _ORIGINAL_TORCH_COMPILE is None:
_ORIGINAL_TORCH_COMPILE = torch.compile
if filter_fn is None:
torch.compile = _ORIGINAL_TORCH_COMPILE
return
@functools.wraps(_ORIGINAL_TORCH_COMPILE)
def wrapped_compile(model=None, *args, **kwargs):
options = kwargs.get("options")
if options is None:
options = {}
else:
options = dict(options)
options.setdefault("guard_filter_fn", filter_fn)
kwargs["options"] = options
return _ORIGINAL_TORCH_COMPILE(model, *args, **kwargs)
torch.compile = wrapped_compile
def _install_global_guard_filter(filter_fn: FilterFn):
if _supports_dynamo_guard_filter_config():
import torch._dynamo.config as cfg
prev = getattr(cfg, "guard_filter_fn", None)
cfg.guard_filter_fn = filter_fn
return ("dynamo_config", prev)
previous = _TORCH_COMPILE_STACK[-1] if _TORCH_COMPILE_STACK else None
_TORCH_COMPILE_STACK.append(filter_fn)
_set_wrapped_torch_compile(filter_fn)
return ("torch_compile", previous)
def _restore_global_guard_filter(state) -> None:
kind, prev = state
if kind == "dynamo_config":
import torch._dynamo.config as cfg
cfg.guard_filter_fn = prev
return
if _TORCH_COMPILE_STACK:
_TORCH_COMPILE_STACK.pop()
if _TORCH_COMPILE_STACK:
_set_wrapped_torch_compile(_TORCH_COMPILE_STACK[-1])
else:
_set_wrapped_torch_compile(None)
def _get_installed_guard_filter():
if _supports_dynamo_guard_filter_config():
import torch._dynamo.config as cfg
return getattr(cfg, "guard_filter_fn", None)
compile_fn = torch.compile
closure = getattr(compile_fn, "__closure__", None) or ()
for cell in closure:
try:
value = cell.cell_contents
except ValueError:
continue
if callable(value) and getattr(value, "__name__", "") == "_filter":
return value
return None
class NpuGuardPolicy:
def __init__(
self,
*,
filter_fn: FilterFn | None = None,
disable_dict_version: bool = False,
disable_optional_type: bool = False,
disable_hasattr: bool = False,
disable_runtime_state: bool = False,
):
switches_used = any(
[
disable_dict_version,
disable_optional_type,
disable_hasattr,
disable_runtime_state,
]
)
if filter_fn is not None and switches_used:
raise ValueError("filter_fn and disable_* switches are mutually exclusive")
if filter_fn is not None:
self._fn = filter_fn
else:
self._fn = make_npu_guard_filter(
disable_dict_version=disable_dict_version,
disable_optional_type=disable_optional_type,
disable_hasattr=disable_hasattr,
disable_runtime_state=disable_runtime_state,
)
self._prev = None
self._active = False
def __enter__(self) -> "NpuGuardPolicy":
self._prev = _install_global_guard_filter(self._fn)
self._active = True
return self
def __exit__(self, exc_type, exc, tb) -> None:
if not self._active:
return
_restore_global_guard_filter(self._prev)
self._prev = None
self._active = False
@dataclass
class FakeEntry:
guard_type: str
name: str = ""
is_global: bool = False
has_value: bool = False
value: Any = None
def _entries(*types: str):
return [FakeEntry(guard_type=t, name=f"L[{i}]") for i, t in enumerate(types)]
def _make_mlp(seed: int = 0) -> nn.Module:
torch.manual_seed(seed)
return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
def _recompile_count() -> int:
from torch._dynamo.utils import counters
return sum(counters["recompiles"].values())
def _reset_dynamo() -> None:
from torch._dynamo.utils import counters
torch._dynamo.reset()
counters.clear()
class GlobalTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
class TestDictDimension(TestCase):
def test_drops_dict_family(self):
f = make_npu_guard_filter(disable_dict_version=True)
es = _entries(
"DICT_VERSION", "DICT_KEYS", "DICT_CONTAINS", "TENSOR_MATCH", "TYPE_MATCH"
)
self.assertEqual(f(es), [False, False, False, True, True])
def test_inactive_keeps_all(self):
f = make_npu_guard_filter()
es = _entries("DICT_VERSION", "TENSOR_MATCH")
self.assertEqual(f(es), [True, True])
class TestOptionalTypeDimension(TestCase):
def test_drops_type_family(self):
f = make_npu_guard_filter(disable_optional_type=True)
es = _entries("TYPE_MATCH", "OPTIONAL_TENSOR", "TENSOR_MATCH", "DICT_VERSION")
self.assertEqual(f(es), [False, False, True, True])
class TestHasattrDimension(TestCase):
def test_drops_hasattr_family(self):
f = make_npu_guard_filter(disable_hasattr=True)
es = _entries("HASATTR", "NOT_PRESENT_IN_GENERIC_DICT", "TENSOR_MATCH")
self.assertEqual(f(es), [False, False, True])
class TestRuntimeStateDimension(TestCase):
def test_drops_runtime_state_family(self):
f = make_npu_guard_filter(disable_runtime_state=True)
es = _entries(
"GRAD_MODE",
"TORCH_FUNCTION_STATE",
"GLOBAL_STATE",
"DEFAULT_DEVICE",
"DETERMINISTIC_ALGORITHMS",
"AUTOCAST_STATE",
"FSDP_TRAINING_STATE",
"TENSOR_MATCH",
)
self.assertEqual(f(es), [False] * 7 + [True])
class TestUpstreamGuardFilterCoverage(TestCase):
def setUp(self):
super().setUp()
_reset_dynamo()
def test_guard_filter_fn_by_id(self):
def guard_filter_fn(entries):
return [entry.guard_type != "ID_MATCH" for entry in entries]
@torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
def fn(x):
return id(x)
inputs = (torch.randn(3, 2),)
fn(*inputs)
inputs_1 = (torch.randn(3, 2),)
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(fn(*inputs_1), id(inputs[0]))
def test_guard_filter_fn_by_is_global(self):
def guard_filter_fn(entries):
return [not entry.is_global for entry in entries]
global GLOBAL_INT
@torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
def fn(x):
return x + GLOBAL_INT
GLOBAL_INT = 1
fn(torch.randn(3, 2))
GLOBAL_INT = 2
inputs = (torch.randn(3, 2),)
with torch.compiler.set_stance("fail_on_recompile"):
self.assertTrue(torch.equal(fn(*inputs), inputs[0] + 1))
def test_guard_filter_fn_by_name_and_value(self):
def guard_filter_fn(entries):
return [
not (entry.name == "y" and entry.value is None) for entry in entries
]
@torch.compile(fullgraph=True, options={"guard_filter_fn": guard_filter_fn})
def fn(x, y):
if y is not None:
x += y
return x
fn(torch.randn(3, 2), None)
inputs = (torch.randn(3, 2), torch.tensor(1))
with torch.compiler.set_stance("fail_on_recompile"):
self.assertTrue(torch.equal(fn(*inputs), inputs[0]))
def test_guard_filter_inbuilt_nn_modules(self):
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.norm = torch.nn.LayerNorm(8)
def forward(self, x):
return self.norm(x)
mod = Mod()
opt_mod = torch.compile(
mod,
options={
"guard_filter_fn": torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe
},
)
x = torch.rand(4, 8)
opt_mod(x)
mod.norm.eps = 1e-02
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
opt_mod(x)
class TestCombinedSwitches(TestCase):
def test_union_drop(self):
f = make_npu_guard_filter(
disable_dict_version=True,
disable_optional_type=True,
disable_hasattr=True,
disable_runtime_state=True,
)
es = _entries(
"DICT_VERSION",
"TYPE_MATCH",
"HASATTR",
"GRAD_MODE",
"TENSOR_MATCH",
"ID_MATCH",
)
self.assertEqual(f(es), [False, False, False, False, True, True])
def test_returns_list_of_correct_length(self):
f = make_npu_guard_filter(disable_dict_version=True)
es = _entries("DICT_VERSION", "TENSOR_MATCH", "TYPE_MATCH")
keep = f(es)
self.assertIsInstance(keep, list)
self.assertEqual(len(keep), len(es))
class TestPolicy(TestCase):
def setUp(self):
super().setUp()
_reset_dynamo()
def test_installs_and_restores(self):
prev = _get_installed_guard_filter()
with NpuGuardPolicy(disable_dict_version=True):
self.assertIsNotNone(_get_installed_guard_filter())
self.assertNotEqual(_get_installed_guard_filter(), prev)
self.assertEqual(_get_installed_guard_filter(), prev)
def test_restores_on_exception(self):
prev = _get_installed_guard_filter()
with self.assertRaises(RuntimeError):
with NpuGuardPolicy(disable_runtime_state=True):
raise RuntimeError("boom")
self.assertEqual(_get_installed_guard_filter(), prev)
def test_nesting(self):
prev = _get_installed_guard_filter()
with NpuGuardPolicy(disable_dict_version=True):
outer_fn = _get_installed_guard_filter()
with NpuGuardPolicy(disable_runtime_state=True):
self.assertNotEqual(_get_installed_guard_filter(), outer_fn)
self.assertEqual(_get_installed_guard_filter(), outer_fn)
self.assertEqual(_get_installed_guard_filter(), prev)
def test_with_custom_filter_fn(self):
called = {"n": 0}
def my_filter(entries):
called["n"] += 1
return [True] * len(entries)
with NpuGuardPolicy(filter_fn=my_filter):
c = torch.compile(_make_mlp())
c(torch.randn(2, 8))
self.assertGreaterEqual(called["n"], 1)
def test_filter_fn_and_switches_are_exclusive(self):
with self.assertRaises(ValueError):
NpuGuardPolicy(
filter_fn=lambda e: [True] * len(e), disable_dict_version=True
)
class TestParity(TestCase):
def setUp(self):
super().setUp()
_reset_dynamo()
def _parity(self, **switches):
torch.manual_seed(0)
m = _make_mlp().eval()
x = torch.randn(2, 8)
with torch.no_grad():
ref = m(x)
m2 = copy.deepcopy(m)
c = torch.compile(
m2, options={"guard_filter_fn": make_npu_guard_filter(**switches)}
)
with torch.no_grad():
got = c(x)
self.assertTrue(torch.allclose(got, ref, atol=1e-5, rtol=1e-5))
def test_parity_dict(self):
self._parity(disable_dict_version=True)
def test_parity_optional_type(self):
self._parity(disable_optional_type=True)
def test_parity_hasattr(self):
self._parity(disable_hasattr=True)
def test_parity_runtime_state(self):
self._parity(disable_runtime_state=True)
def test_parity_all(self):
self._parity(
disable_dict_version=True,
disable_optional_type=True,
disable_hasattr=True,
disable_runtime_state=True,
)
class TestRuntimeStateRecompileBehavior(TestCase):
def setUp(self):
super().setUp()
_reset_dynamo()
def test_no_recompile_across_grad_mode_change(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
compiled_fn = torch.compile(
foo,
options={
"guard_filter_fn": make_npu_guard_filter(disable_runtime_state=True)
},
)
with torch.no_grad():
compiled_fn(x)
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
self.assertTrue(torch.equal(compiled_fn(x), foo(x)))
def test_no_recompile_across_torch_function_mode_change(self):
def foo(x):
return x + 1
x = torch.randn(3, 2)
with GlobalTorchFunctionMode():
compiled_fn = torch.compile(
foo,
options={
"guard_filter_fn": make_npu_guard_filter(disable_runtime_state=True)
},
)
compiled_fn(x)
with torch.compiler.set_stance("fail_on_recompile"):
self.assertTrue(torch.equal(compiled_fn(x), foo(x)))
class TestRecompileCount(TestCase):
def setUp(self):
super().setUp()
_reset_dynamo()
def test_no_recompile_on_repeated_same_shape(self):
m = _make_mlp()
c = torch.compile(
m,
options={
"guard_filter_fn": make_npu_guard_filter(
disable_dict_version=True,
disable_optional_type=True,
disable_hasattr=True,
disable_runtime_state=True,
)
},
)
for _ in range(5):
c(torch.randn(2, 8))
self.assertEqual(_recompile_count(), 0)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()