import collections
import itertools
import traceback
import types as tps
import unittest
from copy import deepcopy
from functools import partial
from typing import Tuple
from unittest.mock import patch
import torch
import torch_npu
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter
try:
from . import test_functions
except ImportError:
import test_functions
class BasicModule(torch.nn.Module):
    """Linear layer followed by ReLU, multiplied by a fixed random tensor."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        # Stored as a plain tensor attribute (neither Parameter nor buffer).
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        activated = F.relu(self.linear1(x))
        return activated * self.scale
class FnMember(torch.nn.Module):
    """Module that keeps a plain function (``F.relu``) as an attribute."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = F.relu

    def forward(self, x):
        out = self.linear1(x)
        # The activation is applied only when the attribute is truthy.
        if not self.activation:
            return out
        return self.activation(out)
class FnMemberCmp(torch.nn.Module):
    """Module whose forward compares a callable attribute against ``None``."""

    def __init__(self, activation):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = activation

    def forward(self, x):
        out = self.linear1(x)
        # Both identity checks are kept as separate statements on purpose:
        # one covers ``is not None`` and the other ``is None``.
        if self.activation is not None:
            out = self.activation(out)
        if self.activation is None:
            out = torch.sigmoid(out)
        return out
class SubmoduleExample(torch.nn.Module):
    """Chains two ``BasicModule`` children and scales the result."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        out = self.layer2(self.layer1(x))
        return out * self.scale
class IsTrainingCheck(torch.nn.Module):
    """Picks ``linear1`` in training mode and ``linear2`` otherwise."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.train(True)

    def forward(self, x):
        mod = self.linear1 if self.training else self.linear2
        return F.relu(mod(x))
class IsEvalCheck(IsTrainingCheck):
    """Same as ``IsTrainingCheck`` but constructed in evaluation mode."""

    def __init__(self):
        super().__init__()
        # The parent constructor leaves the module in training mode; undo that.
        self.train(mode=False)
class ModuleMethodCall(torch.nn.Module):
    """Routes two sub-layers through a shared instance-method helper."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Apply ``mod`` to ``x`` and multiply by ``self.scale``."""
        return mod(x) * self.scale

    def forward(self, x):
        first = self.call_and_scale(self.layer1, x)
        second = self.call_and_scale(self.layer2, x)
        return first + second
class UnsupportedMethodCall(torch.nn.Module):
    """Helper method that ends by calling ``unsupported``."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Scale ``mod(x)`` and hand the result to ``unsupported`` twice."""
        scaled = mod(x) * self.scale
        return unsupported(scaled, scaled)

    def forward(self, x):
        helper_out = self.call_and_scale(self.layer1, x)
        return x + helper_out
class UnsupportedModule(torch.nn.Module):
    """Module whose forward ends by calling ``unsupported``."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        scaled = self.layer1(x) * self.scale
        return unsupported(scaled, scaled)
class UnsupportedModuleCall(torch.nn.Module):
    """Wraps ``UnsupportedModule`` with arithmetic before and after the call."""

    def __init__(self):
        super().__init__()
        self.mod = UnsupportedModule()

    def forward(self, x):
        scaled_input = x * 1.5
        return 1 + self.mod(scaled_input)
class ModuleWithStaticForward(torch.nn.Module):
    """Module whose ``forward`` is declared as a ``@staticmethod``."""

    @staticmethod
    def forward(x):
        gate = torch.sigmoid(x)
        return x * gate
class ModuleCallModuleWithStaticForward(torch.nn.Module):
    """Delegates to a child module that has a static ``forward``."""

    def __init__(self):
        super().__init__()
        self.mod = ModuleWithStaticForward()

    def forward(self, x):
        result = self.mod(x)
        return result
class ModuleStaticMethodCall(torch.nn.Module):
    """Routes two sub-layers through a ``@staticmethod`` helper."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @staticmethod
    def call_and_scale(scale, mod, x):
        """Apply ``mod`` to ``x`` and multiply the result by ``scale``."""
        return mod(x) * scale

    def forward(self, x):
        first = self.call_and_scale(self.scale, self.layer1, x)
        second = self.call_and_scale(self.scale, self.layer2, x)
        return first + second
class ModuleClassMethodCall(torch.nn.Module):
    """Routes two sub-layers through a ``@classmethod`` helper."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @classmethod
    def call_and_scale(cls, scale, mod, x):
        """Apply ``mod`` to ``x`` and multiply the result by ``scale``."""
        return mod(x) * scale

    def forward(self, x):
        first = self.call_and_scale(self.scale, self.layer1, x)
        second = self.call_and_scale(self.scale, self.layer2, x)
        return first + second
class ModuleProperty(torch.nn.Module):
    """Reads a tensor attribute through a read-only property."""

    def __init__(self):
        super().__init__()
        self.scale = torch.randn(1, 10)

    @property
    def scale_alias(self):
        """Return the same tensor object as ``self.scale``."""
        return self.scale

    def forward(self, x):
        factor = self.scale_alias
        return x * factor
class NestedModuleList(torch.nn.Module):
    """``ModuleList`` holding three inner ``ModuleList`` pairs (Linear, ReLU)."""

    def __init__(self):
        super().__init__()
        pairs = [
            torch.nn.ModuleList([torch.nn.Linear(10, 10), torch.nn.ReLU()])
            for _ in range(3)
        ]
        self.layers = torch.nn.ModuleList([])
        for pair in pairs:
            self.layers.append(pair)

    def forward(self, x):
        # Each inner list unpacks into its (linear, activation) pair.
        for layer, act in self.layers:
            x = act(layer(x))
        return x
class ConstLoop(torch.nn.Module):
    """Applies linear + sigmoid a fixed number of times (``self.count``)."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.count = 3

    def forward(self, x):
        for _ in range(self.count):
            hidden = self.linear1(x)
            x = torch.sigmoid(hidden)
        return x
class ViaModuleCall(torch.nn.Module):
    """Passes its result through ``test_functions.constant3``."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        squashed = torch.sigmoid(self.linear1(x))
        return test_functions.constant3(squashed, x)
class IsNoneLayer(torch.nn.Module):
    """Has one real layer and one ``None`` layer, each guarded by a None check."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = None
        self.train(True)

    def forward(self, x):
        out = x
        if self.layer1 is not None:
            out = self.layer1(out)
        if self.layer2 is not None:
            out = self.layer2(out)
        return out
class LayerList(torch.nn.Module):
    """Keeps its sub-layers in a plain Python list rather than a ``ModuleList``."""

    def __init__(self):
        super().__init__()
        self.layers = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
        ]

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
class ModuleList(torch.nn.Module):
    """Fixture that walks an ``nn.ModuleList`` with several iteration idioms.

    Every loop in ``forward`` uses a different way of visiting the same four
    layers; the loops are applied one after another to the running value ``x``.
    """

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def forward(self, x):
        # Index-based access driven by enumerate().
        for i, _ in enumerate(self.layers):
            x = self.layers[i](x)

        # Direct iteration over the container.
        for layer in self.layers:
            x = layer(x)

        # zip() with a tuple of tensors (all four entries are the same ``x``
        # object captured when the tuple is built).
        for layer, val in zip(self.layers, (x, x, x, x)):
            x = layer(x) + val

        # zip() with a tuple of Python ints.
        for layer, val in zip(self.layers, (1, 2, 3, 4)):
            x = layer(x) + val

        # enumerate() where the index takes part in the arithmetic
        # (idx == 0 on the first iteration zeroes the value).
        for idx, layer in enumerate(self.layers):
            x = layer(x) * idx

        # enumerate() over a reversed slice of the container.
        for idx, layer in enumerate(self.layers[::-1]):
            x = layer(x) * idx

        return x
class CustomGetItemModuleList(torch.nn.Module):
    """Module exposing its inner ``ModuleList`` through a custom ``__getitem__``."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def __getitem__(self, idx: int):
        # Forward indexing to the wrapped ModuleList.
        return self.layers[idx]

    def forward(self, x):
        # Layers are fetched via ``self[i]`` (the custom __getitem__), not
        # via ``self.layers[i]``.
        for i, _ in enumerate(self.layers):
            x = self[i](x)

        return x
class ModuleDict(torch.nn.Module):
    """Looks up its only layer by string key in an ``nn.ModuleDict``."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def forward(self, x):
        return self.layers["0"](x)
class ParameterDict(torch.nn.Module):
    """Matrix-multiplies a parameter fetched from an ``nn.ParameterDict``."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def forward(self, x):
        weight = self.layers["0"]
        return weight.mm(x)
class CustomGetItemParameterDict(torch.nn.Module):
    """Exposes an inner ``ParameterDict`` through a custom ``__getitem__``."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ParameterDict(
            {
                "0": torch.nn.Parameter(torch.randn(10, 10)),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Module:
        # Forward the lookup to the wrapped ParameterDict.
        return self.layers[key]

    def forward(self, x):
        weight = self["0"]
        return weight.mm(x)
class CustomGetItemModuleDict(torch.nn.Module):
    """Exposes an inner ``ModuleDict`` through a custom ``__getitem__``."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def __getitem__(self, key: str) -> torch.nn.Module:
        # Forward the lookup to the wrapped ModuleDict.
        return self.layers[key]

    def forward(self, x):
        layer = self["0"]
        return layer(x)
class TensorList(torch.nn.Module):
    """Multiplies the input by each tensor of a tuple attribute in turn."""

    def __init__(self):
        super().__init__()
        self.layers = (
            torch.randn((1, 10)),
            torch.randn((10, 1)),
            torch.randn((1, 10)),
            torch.randn((10, 1)),
        )

    def forward(self, x):
        out = x
        for factor in self.layers:
            out = out * factor
        return out
class Children(torch.nn.Module):
    """Runs the input through every child yielded by ``self.children()``."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        out = x
        for child in self.children():
            out = child(out)
        return out
class NamedChildren(torch.nn.Module):
    """Runs the input through every child yielded by ``self.named_children()``."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        out = x
        # Only the module half of each (name, module) pair is used.
        for _, child in self.named_children():
            out = child(out)
        return out
class IntArg(torch.nn.Module):
    """Forward accepts an extra ``offset`` argument that defaults to 1."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)

    def forward(self, x, offset=1):
        activated = F.relu(self.layer1(x))
        return activated + offset
class Seq(torch.nn.Module):
    """Wraps an ``nn.Sequential`` of Linear/ReLU/Linear/ReLU."""

    def __init__(self):
        super().__init__()
        stages = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)
class Cfg:
    """Plain (non-Module) configuration holder with two constant fields."""

    def __init__(self):
        self.val, self.count = 0.5, 3
class CfgModule(torch.nn.Module):
    """Reads its loop count and additive constant from a ``Cfg`` object."""

    def __init__(self):
        super().__init__()
        self.cfg = Cfg()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        for _ in range(self.cfg.count):
            shifted = x + self.cfg.val
            x = self.layer(shifted)
        return x
class StringMember(torch.nn.Module):
    """Branches on a string attribute; yields ``None`` when it does not match."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.mode = "some_string"

    def forward(self, x):
        if self.mode == "some_string":
            hidden = self.linear1(x)
            return F.relu(hidden)
        # No other mode is handled.
        return None
class _Block(torch.nn.Module):
    """Concatenates a sequence of tensors along dim 1 and scales by 1.5."""

    def forward(self, x):
        joined = torch.cat(x, 1)
        return 1.5 * joined
class _DenseBlock(torch.nn.ModuleDict):
    """``ModuleDict`` of ``_Block`` layers that grows a list of feature tensors."""

    _version = 2

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer_name = "denselayer%d" % (i + 1)
            self.add_module(layer_name, _Block())

    def forward(self, init_features):
        features = [init_features]
        # Each layer sees every feature tensor produced so far.
        for layer in self.values():
            features.append(layer(features))
        return torch.cat(features, 1)
class DenseNetBlocks(torch.nn.Module):
    """Thin wrapper that delegates to a default ``_DenseBlock``."""

    def __init__(self):
        super().__init__()
        self.layers = _DenseBlock()

    def forward(self, x):
        dense_out = self.layers(x)
        return dense_out
class MaterializedModule(torch.nn.Module):
    """Target class of ``LazyModule``.

    ``LazyModule`` names this class as its ``cls_to_become``. The module
    registers a ``param`` slot (initially ``None``) and its forward is the
    identity.
    """

    param: Parameter

    def __init__(self):
        super().__init__()
        self.register_parameter(name="param", param=None)

    def forward(self, x):
        return x
class LazyModule(LazyModuleMixin, MaterializedModule):
    """Lazy counterpart of ``MaterializedModule``.

    ``param`` starts as an ``UninitializedParameter`` and is materialized to
    the shape of the input passed to ``initialize_parameters``.
    """

    param: UninitializedParameter
    # Class named as the conversion target for the lazy-module machinery.
    cls_to_become = MaterializedModule

    def __init__(self):
        super().__init__()
        self.param = UninitializedParameter()

    def initialize_parameters(self, x):
        # force graph break to ensure this was not inlined
        # NOTE(review): the rationale above is inferred from the explicit
        # graph_break() call preceding materialization — confirm.
        torch._dynamo.graph_break()
        self.param.materialize(x.shape)
class LazyMLP(torch.nn.Module):
    """Two ``LazyLinear`` layers (10 then 1 output features), each with ReLU."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.LazyLinear(10)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.LazyLinear(1)
        self.relu2 = torch.nn.ReLU()

    def forward(self, ipt):
        hidden = self.relu1(self.fc1(ipt))
        return self.relu2(self.fc2(hidden))
class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module):
    """Lazy layer whose input is a sequence of tensors; forward sums them."""

    def __init__(self):
        super().__init__()

    def initialize_parameters(self, ipt):
        # Build ``_param`` from the shape of the first element of the input,
        # filled with 0.5, without recording autograd history.
        with torch.no_grad():
            self._param = torch.nn.Parameter(torch.empty(ipt[0].shape).fill_(0.5))

    def forward(self, ipt):
        # The accumulator starts as the Python int 0 and elements are read
        # by index; ``_param`` is not used here.
        x = 0
        for i, _ in enumerate(ipt):
            x = x + ipt[i]
        return x
class LazyModuleWithListInput(torch.nn.Module):
    """Feeds all but the last element of its input to a lazy child layer."""

    def __init__(self):
        super().__init__()
        self.layer = LazyLayerWithListInput()

    def forward(self, ipt):
        trimmed = ipt[:-1]
        return self.layer(trimmed)
class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module):
    """Lazy module that creates its (also lazy) child during initialization."""

    def __init__(self):
        super().__init__()

    def initialize_parameters(self, ipt):
        # The child layer is only created here, not in __init__.
        with torch.no_grad():
            self.layer = LazyLayerWithListInput()

    def forward(self, x):
        summed = self.layer(x)
        return summed
class LazyParentModule(LazyModuleMixin, torch.nn.Module):
    """Lazy base class providing ``impl``; subclasses are expected to set ``_val``."""

    def __init__(self):
        super().__init__()

    def impl(self, x):
        """Return ``cos(x)`` plus the ``_val`` attribute."""
        cosine = x.cos()
        return cosine + self._val
class LazyChildModuleNoClsToBecome(LazyParentModule):
    """Lazy child that calls the parent's ``impl`` via ``super()``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        sine = x.sin()
        return super().impl(sine)

    def initialize_parameters(self, ipt):
        # The input is ignored; ``_val`` always has a fixed 2x2 shape.
        self._val = torch.nn.Parameter(torch.ones(2, 2))
def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter of ``module`` requires grad.

    ``recurse`` is forwarded positionally to ``module.parameters``.
    """
    return any(param.requires_grad for param in module.parameters(recurse))
def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter of ``module`` requires grad.

    Same body as ``requires_grad1``; kept as a separate function object.
    """
    return any(param.requires_grad for param in module.parameters(recurse))
class ParametersModule1(torch.nn.Module):
    """Branches on whether its own direct parameters require grad."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        if requires_grad1(self):
            return x + 1
        return F.relu(self.linear1(x)) * self.scale
class ParametersModule2(ParametersModule1):
    """Variant of ``ParametersModule1`` that consults ``requires_grad2``."""

    def forward(self, x):
        if requires_grad2(self):
            return x + 1
        return F.relu(self.linear1(x)) * self.scale
class ParametersModule3(ParametersModule1):
    """Takes the dtype for a ``ones`` tensor from its first parameter."""

    def forward(self, x):
        first_dtype = next(self.parameters()).dtype
        ones = torch.ones(10, dtype=first_dtype)
        return F.relu(self.linear1(x)) * self.scale + ones
class SuperModule(BasicModule):
    """Calls the parent forward through ``super()`` and adds 10."""

    def forward(self, x):
        return super().forward(x) + 10.0
class SuperModule2(BasicModule):
    """Calls the parent forward explicitly through the class name."""

    def forward(self, x):
        parent_out = BasicModule.forward(self, x)
        return parent_out
class ComplicatedSuperParent(torch.nn.Module):
    """Parent class offering a classmethod that doubles its argument."""

    @classmethod
    def custom_add(cls, x):
        doubled = x + x
        return doubled
class SuperChildCallsClassMethod(ComplicatedSuperParent):
    """Child whose classmethod reaches the parent's classmethod via ``super()``."""

    @classmethod
    def child_func(cls, x):
        return super().custom_add(x)

    def forward(self, x):
        return self.child_func(x)
class HasAttrModule(torch.nn.Module):
    """Uses ``hasattr`` for one attribute that exists and one that does not."""

    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        out = F.relu(x)
        # In-place multiplies are kept; they act on the ReLU output.
        if hasattr(self, "scale"):
            out *= self.scale
        if hasattr(self, "scale2"):
            out *= self.scale2
        return out
class EnumValues(torch.nn.ModuleDict):
    """``ModuleDict`` of ``_Block`` layers visited via ``enumerate(self.values())``."""

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        # Layers are registered as "denselayer1" .. "denselayer<num_layers>".
        for i in range(num_layers):
            self.add_module("denselayer%d" % (i + 1), _Block())

    def forward(self, init_features):
        features = [init_features]
        # ``idx`` is unused in the loop body.
        # NOTE(review): presumably enumerate() is kept because it is the
        # construct under test — confirm before simplifying.
        for idx, layer in enumerate(self.values()):
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)
class AccessByKeys(torch.nn.ModuleDict):
    """``ModuleDict`` of ``_Block`` layers visited via ``self.keys()`` + indexing."""

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        # Layers are registered as "denselayer1" .. "denselayer<num_layers>".
        for i in range(num_layers):
            self.add_module("denselayer%d" % (i + 1), _Block())

    def forward(self, init_features):
        features = [init_features]
        # Each layer is fetched with ``self[k]`` rather than by iterating values.
        for k in self.keys():
            new_features = self[k](features)
            features.append(new_features)
        return torch.cat(features, 1)
class CallForwardDirectly(torch.nn.Module):
    """Invokes ``.forward`` on its children instead of calling them."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        hidden = self.layer1.forward(x)
        return self.layer2.forward(hidden)
class ConvCallForwardDirectly(torch.nn.Module):
    """Invokes ``.forward`` directly on a bias-free ``Conv2d`` child."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)

    def forward(self, x):
        conv_out = self.layer.forward(x)
        return conv_out
class ConvTransposeCallForwardDirectly(torch.nn.Module):
    """Invokes ``.forward`` directly on a ``ConvTranspose2d`` child."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.ConvTranspose2d(4, 4, 4)

    def forward(self, x):
        deconv_out = self.layer.forward(x)
        return deconv_out
class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
    """``Conv1d`` subclass whose forward delegates to ``super().forward``."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )

    def forward(self, inputs, mask=None):
        # ``mask`` is accepted but not used.
        return super().forward(inputs)
class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
    """``ConvTranspose2d`` subclass with a separate path for zero-element input.

    Non-empty input is passed to ``super().forward``. For empty input the
    output shape is computed by hand and handed to ``_NewEmptyTensorOp``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )

    def forward(self, x):
        if x.numel() > 0:
            return super().forward(x)
        # Pair each of the last two input dims with the per-dim conv settings.
        zip_x = zip(
            x.shape[-2:],
            self.padding,
            self.dilation,
            self.kernel_size,
            self.stride,
            self.output_padding,
        )
        # Following the zip order above: i = input size, p = padding,
        # di = dilation, k = kernel size, d = stride, op = output padding.
        output_shape = [
            ((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
            for i, p, di, k, d, op in zip_x
        ]
        # Prepend batch size and the channel count taken from the bias length.
        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
        # NOTE(review): ``_NewEmptyTensorOp`` is not defined in the visible
        # part of this file; reaching this line would raise NameError unless
        # it is provided elsewhere — confirm.
        return _NewEmptyTensorOp.apply(x, output_shape)
class ModuleNameString(torch.nn.Module):
    """Branches on ``__class__.__name__`` of itself and of a child layer."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        own_name = self.__class__.__name__
        if own_name == "ABC":
            return 10
        child_name = self.linear1.__class__.__name__
        if child_name == "Linear":
            return F.relu(self.linear1(x) + 10)
        return 11
class SelfMutatingModule(torch.nn.Module):
    """Adds a call counter to the layer output and bumps it on every call."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.counter = 0

    def forward(self, x):
        # The counter value *before* the increment is the one that is added.
        pre_activation = self.layer(x) + self.counter
        self.counter += 1
        return F.relu(pre_activation)
class ModuleAttributePrecedenceBase(torch.nn.Module):
    """Base class defining a ``linear`` method that doubles its input."""

    def linear(self, x):
        doubled = x * 2.0
        return doubled
class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
    """Fixture with deliberate name collisions between attributes and methods.

    ``__init__`` assigns ``activation``, ``linear``, ``initializer`` and
    ``scale`` on the instance, while methods with the same names exist on
    this class (``activation``, ``initializer``, ``scale``) or on the base
    class (``linear``).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): two of these values are nn.Modules and two are a plain
        # tensor / float. Which binding ``self.<name>`` resolves to depends on
        # Python attribute lookup combined with nn.Module.__setattr__ —
        # confirm before reordering or renaming anything here.
        self.activation = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)
        self.initializer = torch.ones([10, 10])
        self.scale = 0.5

    def activation(self, x):
        return x * 1.2

    def initializer(self):
        return torch.zeros([10, 10])

    def scale(self):
        return 2.0

    def forward(self, x):
        # Every name used here is one of the colliding names described above.
        return self.activation(self.linear(self.initializer + x)) * self.scale
class ModuleForwardHasGraphBreak(torch.nn.Module):
    """Module whose forward collects parameter/buffer/module views, then graph-breaks."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
        self.layer4 = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )
        self.layer5 = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """
        This is used to test if the results of functions like `named_parameters`
        can be reconstructed correctly after graph break.
        See pytorch/torchdynamo/issues/1931
        """
        x = self.layer1(x)
        # Materialize each introspection API both as a dict (named_*) and as
        # a list, before the graph break.
        params1 = dict(self.named_parameters())
        params2 = list(self.parameters())
        buffers1 = dict(self.named_buffers())
        buffers2 = list(self.buffers())
        modules1 = dict(self.named_modules())
        modules2 = list(self.modules())
        torch._dynamo.graph_break()
        # Every collected value is read after the break; only the last
        # assignment (``params1``) is actually indexed below.
        y = modules2
        y = modules1
        y = buffers2
        y = buffers1
        y = params2
        y = params1
        x = (
            self.layer2(x)
            + y["layer3.1.linear1.weight"]
            + y["layer4.2.weight"]
            + y["layer5.0.weight"]
        )
        return x * self.scale
class ModuleGuardNameIsValid(torch.nn.ModuleDict):
    """``ModuleDict`` whose child names contain ``@`` and ``-`` characters."""

    def __init__(self):
        super().__init__()
        for idx in range(2):
            child_name = "l@yer-%d" % (idx + 1)
            self.add_module(child_name, BasicModule())

    def forward(self, x):
        out = x
        for child in self.values():
            out = child(out)
        return out
class SequentialWithDuplicatedModule(torch.nn.Module):
    """``Sequential`` in which the same ReLU instance appears three times."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()
        shared = self.relu
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            shared,
            torch.nn.Linear(20, 20),
            shared,
            torch.nn.Linear(20, 10),
            shared,
        )

    def forward(self, x):
        return self.layer(x)
class SequentialWithDuplicatedModule2(torch.nn.Module):
    """Named ``Sequential`` (from an OrderedDict) reusing one ReLU three times."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()
        named_stages = collections.OrderedDict()
        named_stages["linear1"] = torch.nn.Linear(10, 20)
        named_stages["relu1"] = self.relu
        named_stages["linear2"] = torch.nn.Linear(20, 20)
        named_stages["relu2"] = self.relu
        named_stages["linear3"] = torch.nn.Linear(20, 10)
        named_stages["relu3"] = self.relu
        self.layer = torch.nn.Sequential(named_stages)

    def forward(self, x):
        return self.layer(x)
class ModuleComparison(torch.nn.Module):
    """Compares each encoder layer against ``None`` and against ``layer0``."""

    def __init__(self):
        super().__init__()
        self.layer0 = torch.nn.Linear(10, 10)
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 10)

    @property
    def encoder_layers(self):
        """Return a fresh list of the three layers, in order."""
        return [self.layer0, self.layer1, self.layer2]

    def forward(self, x):
        # Every layer is applied to the original ``x``; the value returned is
        # the one produced in the final iteration.
        for layer in self.encoder_layers:
            output = layer(x)
            use_relu6 = layer is None or layer == self.layer0
            if use_relu6:
                output = F.relu6(output)
            else:
                output = F.relu(output)
        return output
class ModulePatch1(torch.nn.Module):
    """Empty ``nn.Module`` subclass that defines no ``forward`` of its own."""
class ModulePatch2(torch.nn.Module):
    """Module whose forward subtracts one from its input."""

    def forward(self, x):
        decremented = x - 1
        return decremented
class UnspecNonInlinableModule(torch.nn.Module):
    """Branches on the sign of ``x.sum()`` (data-dependent control flow)."""

    torchdynamo_force_dynamic = True  # forced to be a UnspecializedNNModule

    def forward(self, x):
        if x.sum() > 0:
            return x + 1
        return x - 1
class UnspecNonInlinableToplevelModule(torch.nn.Module):
    """Top-level wrapper around a single ``UnspecNonInlinableModule`` child."""

    def __init__(self):
        super().__init__()
        self.m = UnspecNonInlinableModule()

    def forward(self, x):
        inner_out = self.m(x)
        return inner_out
def make_test(fn, expected_ops=None):
    """Build a test method that runs ``standard_test`` on ``fn`` with one arg.

    ``fn`` is switched to eval mode once, when the test method is created.
    ``expected_ops`` is forwarded unchanged to ``standard_test``.
    """
    fn.eval()

    def test_fn(self):
        return torch._dynamo.testing.standard_test(
            self, fn=fn, nargs=1, expected_ops=expected_ops
        )

    return test_fn
class NNModuleTests(torch._dynamo.test_case.TestCase):
test_seq = make_test(Seq())
test_basicmodule1 = make_test(BasicModule())
test_basicmodule2 = make_test(BasicModule())
test_submodules1 = make_test(SubmoduleExample())
test_submodules2 = make_test(SubmoduleExample())
test_modulemethod1 = make_test(ModuleMethodCall())
test_modulemethod2 = make_test(ModuleMethodCall())
test_module_call_module_with_static_forward = make_test(
ModuleCallModuleWithStaticForward()
)
test_module_static_method = make_test(ModuleStaticMethodCall())
test_fnmember = make_test(FnMember())
test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
test_fnmembercmp2 = make_test(FnMemberCmp(None))
test_constloop = make_test(ConstLoop())
test_istraining1 = make_test(IsTrainingCheck())
test_istraining2 = make_test(IsTrainingCheck())
test_iseval1 = make_test(IsEvalCheck())
test_iseval2 = make_test(IsEvalCheck())
test_viamodulecall = make_test(ViaModuleCall())
test_isnonelayer = make_test(IsNoneLayer())
test_layerlist = make_test(LayerList())
test_tensorlist = make_test(TensorList())
test_intarg = make_test(IntArg())
test_cfgmod = make_test(CfgModule())
test_stringmember = make_test(StringMember())
test_modulelist = make_test(ModuleList())
test_modulelist_nested = make_test(NestedModuleList())
test_modulelist_custom = make_test(CustomGetItemModuleList())
test_moduledict = make_test(ModuleDict())
test_moduledict_custom = make_test(CustomGetItemModuleDict())
test_parameterdict = make_test(ParameterDict())
test_parameterdict_custom = make_test(CustomGetItemParameterDict())
test_super1 = make_test(SuperModule())
test_super2 = make_test(SuperModule2())
test_super_class_method = make_test(SuperChildCallsClassMethod())
test_children = make_test(Children())
test_named_children = make_test(NamedChildren())
test_densenet = make_test(DenseNetBlocks())
test_parameters1 = make_test(ParametersModule1())
test_parameters2 = make_test(ParametersModule2())
test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
test_hasattr = make_test(HasAttrModule())
test_enumvalues = make_test(EnumValues())
test_access_by_keys = make_test(AccessByKeys())
test_module_class_method = make_test(ModuleClassMethodCall())
test_module_property = make_test(ModuleProperty())
test_forward_directly = make_test(CallForwardDirectly())
test_module_name_string = make_test(ModuleNameString())
test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid())
test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule())
test_sequential_with_duplicated_module2 = make_test(
SequentialWithDuplicatedModule2()
)
test_module_comparison = make_test(ModuleComparison())
def test_module_forward_has_graph_break(self):
m = ModuleForwardHasGraphBreak()
x = torch.rand([10, 10])
ref = m(x)
opt_m = torch._dynamo.optimize("eager")(m)
res = opt_m(x)
self.assertTrue(torch.allclose(ref, res))
def test_unsupportedmethod(self):
m = UnsupportedMethodCall()
i = torch.randn(10)
cnt = torch._dynamo.testing.CompileCounter()
opt_m = torch._dynamo.optimize(cnt)(m)
r = opt_m(i)
self.assertTrue(torch._dynamo.testing.same(r, m(i)))
self.assertEqual(cnt.op_count, 5)
def test_unsupportedmodule(self):
m = UnsupportedModuleCall()
i = torch.randn(10)
cnt = torch._dynamo.testing.CompileCounter()
opt_m = torch._dynamo.optimize(cnt)(m)
r = opt_m(i)
self.assertTrue(torch._dynamo.testing.same(r, m(i)))
self.assertEqual(cnt.op_count, 6)
def test_self_mutating1(self):
m1 = torch.nn.Linear(10, 10)
m2 = SelfMutatingModule(m1)
m3 = SelfMutatingModule(m1)
m4 = SelfMutatingModule(m1)
i = torch.randn(10)
out2 = [m2(i), m2(i), m2(i)]
cnt = torch._dynamo.testing.CompileCounter()
opt_m3 = torch._dynamo.optimize_assert(cnt)(m3)
opt_m4 = torch._dynamo.optimize_assert(cnt)(m4)
out3 = [opt_m3(i), opt_m3(i), opt_m3(i)]
out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
self.assertTrue(torch._dynamo.testing.same(out2, out3))
self.assertTrue(torch._dynamo.testing.same(out2, out4))
self.assertEqual(cnt.frame_count, 3)
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_generation_tag(self):
cnt = torch._dynamo.testing.CompileCounter()
with torch._dynamo.optimize_assert(cnt):
pass
m1 = torch.nn.Linear(10, 10)
prev_generation = GenerationTracker.get_generation_value(m1)
cur_generation = prev_generation + 1
with torch._dynamo.optimize_assert(cnt):
m2 = torch.nn.Linear(10, 10)
self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
m3 = deepcopy(m1)
self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)
def test_simple_torch_function(self):
def foo(x):
x = F.sigmoid(x)
x = F.sigmoid(x)
x = x.sigmoid()
x = x.sigmoid()
return x
class TensorProxy(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, kwargs)
torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
try:
x = torch.randn(1).as_subclass(TensorProxy)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
out2 = opt_foo(x)
self.assertEqual(cnt.op_count, 4)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
finally:
torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
def test_torch_function_with_closure(self):
def run():
counter = 0
def foo(x):
x = F.sigmoid(x)
x = F.sigmoid(x)
x = x.sigmoid()
x = x.sigmoid()
return x
class TensorProxy(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
nonlocal counter
counter + 1
return super().__torch_function__(func, types, args, kwargs)
torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
try:
x = torch.randn(1).as_subclass(TensorProxy)
x = torch.randn(1)
cnt = torch._dynamo.testing.CompileCounter()
out1 = foo(x)
opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
out2 = opt_foo(x)
self.assertEqual(cnt.op_count, 4)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
finally:
torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
run()
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_nn_moduledict_contains(self):
class M(torch.nn.Module):
def __init__(self, module_dict):
super().__init__()
self.module_dict = module_dict
def forward(self, x):
if "foo" in self.module_dict:
x = torch.mul(x, 1.0)
x = torch.add(x, 1.0)
return x
module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
m = M(module_dict)
data = torch.randn(1)
out1 = m(data)
cnt = torch._dynamo.testing.CompileCounter()
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
out2 = opt_m(data)
self.assertEqual(cnt.op_count, 2)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
m = M(module_dict)
data = torch.randn(1)
out1 = m(data)
cnt = torch._dynamo.testing.CompileCounter()
torch._dynamo.reset()
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
out2 = opt_m(data)
self.assertEqual(cnt.op_count, 1)
self.assertTrue(torch._dynamo.testing.same(out1, out2))
module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
pre = m(data)
cnt.clear()
with torch._dynamo.optimize(cnt, nopython=False):
opt_pre = m(data)
m = M(module_dict)
data = torch.randn(1)
out1 = m(data)
out_post = m(data)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 1)
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))
@expectedFailureDynamic
def test_lazy_module1(self):
input_shape = (16, 3, 6, 7, 8)
cnt = torch._dynamo.testing.CompileCounter()
module = LazyModule()
def test_static_module():
ipt = torch.ones(*input_shape)
module(ipt)
opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)(
test_static_module
)
opt_test_static_module()
self.assertTrue(
isinstance(module, MaterializedModule),
"Module should be transformed to an instance of MaterializedModule.",
)
self.assertEqual(module.param.shape, input_shape)
module = LazyModule()
def test_unspecialized():
nonlocal module
module = LazyModule()
ipt = torch.ones(*input_shape)
module(ipt)
opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
opt_test_unspecialized()
self.assertTrue(
isinstance(module, MaterializedModule),
"Module should be transformed to an instance of MaterializedModule.",
)
self.assertEqual(module.param.shape, input_shape)
module = torch.nn.modules.LazyBatchNorm3d(
affine=False, track_running_stats=False
)
cnt = torch._dynamo.testing.CompileCounter()
torch._dynamo.reset()
def test_torch_static():
ipt = torch.ones(*input_shape)
return module(ipt)
opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)(
test_torch_static
)
opt_test_torch_static()
out = opt_test_torch_static()
self.assertTrue(same(out, module(torch.ones(*input_shape))))
self.assertTrue(
isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
"Module should be transformed to an instance of BatchNorm3d.",
)
self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")
@expectedFailureDynamic
def test_lazy_module2(self):
m = LazyMLP()
x = torch.rand([10, 10])
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
@expectedFailureDynamic
@unittest.skipIf(not torch.npu.is_available(), "requires npu")
def test_lazy_module3(self):
m = LazyMLP()
x = torch.rand([10, 10])
cnt = torch._dynamo.testing.CompileCounter()
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
m = m.to("npu:0")
x = x.to("npu:0")
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
self.assertEqual(cnt.frame_count, 2)
@expectedFailureDynamic
def test_lazy_module4(self):
m = LazyMLP()
x = torch.rand([10, 10])
cnt = torch._dynamo.testing.CompileCounter()
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
x = torch.rand([20, 20])
try:
opt_m(x)
except RuntimeError:
self.assertIn("must have same reduction dim", traceback.format_exc())
@expectedFailureDynamic
def test_lazy_module5(self):
m = LazyModuleWithListInput()
x = [torch.rand([5, 5])] * 3 + [None]
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
@expectedFailureDynamic
def test_lazy_module6(self):
m = LazyModuleWithLazySubmodule()
x = [torch.rand([5, 5])] * 3
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
def test_lazy_module_no_cls_to_become(self):
m = LazyChildModuleNoClsToBecome()
x = torch.rand(2, 2)
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
res = opt_m(x)
ref = m(x)
self.assertTrue(torch.allclose(ref, res))
def test_call_fn_with_non_const_inputs_safe(self):
class ModuleSpecialFwd(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=20, kernel_size=(5, 5)
)
def _conv_forward(self, x):
return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)
def forward(self, x):
return self._conv_forward(x)
mod = ModuleSpecialFwd()
rx = torch.randn([3, 10, 10])
real = mod(rx)
graph, _ = torch._dynamo.export(mod)(rx)
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
def test_conv_call_forward_directly(self):
m = ConvCallForwardDirectly()
x = torch.rand([4, 3, 9, 9])
ref = m(x)
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
res = opt_m(x)
self.assertTrue(torch.allclose(ref, res))
def test_conv_transpose_call_forward_directly(self):
m = ConvTransposeCallForwardDirectly()
x = torch.rand([4, 4, 4, 4])
ref = m(x)
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
res = opt_m(x)
self.assertTrue(torch.allclose(ref, res))
def test_conv_call_super_forward_directly(self):
    """ConvCallSuperForwardDirectly(4, 4, 4) compiles as a full graph and matches eager output."""
    # The input is drawn before the module is built, keeping the RNG order stable.
    sample = torch.randn(4, 4)
    module = ConvCallSuperForwardDirectly(4, 4, 4)
    eager_out = module(sample)
    compiled_module = torch.compile(module, backend="eager", fullgraph=True)
    compiled_out = compiled_module(sample)
    self.assertTrue(torch.allclose(eager_out, compiled_out))
def test_conv_transpose_call_super_forward_directly(self):
    """ConvTransposeCallSuperForwardDirectly(4, 4, 4) compiles fully and matches eager output."""
    # The input is drawn before the module is built, keeping the RNG order stable.
    sample = torch.randn(4, 4, 4)
    module = ConvTransposeCallSuperForwardDirectly(4, 4, 4)
    eager_out = module(sample)
    compiled_module = torch.compile(module, backend="eager", fullgraph=True)
    compiled_out = compiled_module(sample)
    self.assertTrue(torch.allclose(eager_out, compiled_out))
class MockModule(torch.nn.Module):
    """Toy module computing ``relu(linear(x) + buf0)``.

    Holds a ReLU, a 10->10 linear layer and a random 10x10 buffer named
    ``buf0``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (relu, linear, buf0) is kept as-is.
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(in_features=10, out_features=10)
        self.register_buffer("buf0", torch.randn(10, 10))

    def forward(self, x):
        shifted = self.linear(x) + self.buf0
        return self.relu(shifted)
class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
def test_nn_module(self):
    """Optimizing MockModule yields an OptimizedModule that matches eager in one frame."""
    eager_mod = MockModule()
    counter = torch._dynamo.testing.CompileCounter()
    compiled_mod = torch._dynamo.optimize(counter)(eager_mod)
    self.assertIsInstance(compiled_mod, torch._dynamo.OptimizedModule)

    sample = torch.randn(10, 10)
    # Eager result is computed first, then the compiled one.
    eager_out = eager_mod(sample)
    compiled_out = compiled_mod(sample)
    self.assertTrue(same(eager_out, compiled_out))
    self.assertEqual(counter.frame_count, 1)
def test_to(self):
    """``.to()`` keeps the OptimizedModule wrapper; the frame count is tracked across dtype change and reset."""
    eager_mod = MockModule()
    counter = torch._dynamo.testing.CompileCounter()
    compiled_mod = torch._dynamo.optimize(counter)(eager_mod)

    sample = torch.randn(10, 10)
    eager_out = eager_mod(sample)
    compiled_out = compiled_mod(sample)
    self.assertTrue(same(eager_out, compiled_out))
    self.assertEqual(counter.frame_count, 1)

    # A repeated call with the same input is expected not to add a frame.
    compiled_mod(sample)
    self.assertEqual(counter.frame_count, 1)

    # Move to CPU, then switch to float64; the wrapper type must survive both.
    compiled_mod = compiled_mod.to(device="cpu")
    compiled_mod = compiled_mod.to(dtype=torch.float64)
    self.assertIsInstance(compiled_mod, torch._dynamo.OptimizedModule)

    sample = torch.randn(10, 10).to(dtype=torch.float64)
    compiled_mod(sample)
    self.assertEqual(counter.frame_count, 2)
    compiled_mod(sample)
    self.assertEqual(counter.frame_count, 2)

    # After a Dynamo reset the next call is expected to compile once more.
    torch._dynamo.reset()
    compiled_mod(sample)
    self.assertEqual(counter.frame_count, 3)
def test_attr(self):
    """An optimized module exposes the same parameter/buffer objects as the module it wraps."""

    class MockModule_attr(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)
            self.register_buffer("buf0", torch.randn(10, 10))

        def forward(self, x):
            # ``self.r`` is never defined; forward() is not invoked in this test.
            return self.r(torch.sin(x)) + self.buf0

    mod = MockModule_attr()
    wrapped = torch._dynamo.optimize("eager")(mod)

    # Identity (not just equality) is required for every parameter and buffer.
    for plain, via_wrapper in zip(mod.parameters(), wrapped.parameters()):
        self.assertIs(plain, via_wrapper)
    for plain, via_wrapper in zip(mod.buffers(), wrapped.buffers()):
        self.assertIs(plain, via_wrapper)

    def get_parameter_dtype(mod: torch.nn.Module):
        parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
        return next(parameters_and_buffers).dtype

    # The dtype lookup itself is compiled and fed the plain module.
    compiled_lookup = torch._dynamo.optimize("eager")(get_parameter_dtype)
    found_dtype = compiled_lookup(mod)
    self.assertEqual(found_dtype, torch.float32)
def test_dir(self):
    """``dir()`` of an optimized module covers every name reported by the wrapped module."""

    class MockModule_dir(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)
            self.register_buffer("buf0", torch.randn(10, 10))
            self.register_parameter(
                name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
            )

        def forward(self, x):
            # ``self.r`` is never defined; forward() is not invoked in this test.
            return self.r(torch.sin(x)) + self.buf0

    mod = MockModule_dir()
    # Snapshot the plain module's names before wrapping it.
    plain_names = dir(mod)
    wrapped = torch._dynamo.optimize("eager")(mod)
    wrapped_names = dir(wrapped)

    for expected_name in ("linear", "buf0", "param0"):
        self.assertIn(expected_name, wrapped_names)

    # No name visible on the plain module may be missing from the wrapper.
    missing = set(plain_names).difference(wrapped_names)
    self.assertTrue(len(missing) == 0)
def test_no_recompile_on_nn_guarded_modules(self):
    """Separately compiled submodules each compile once, with ``error_on_recompile`` on and a cache limit of 1."""
    size = (10, 10)
    cache_size_limit = 1
    num_submodules = 4
    cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(*size)

        def forward(self, x):
            a = torch.sin(torch.cos(x))
            return self.linear(a)

    class MockModule_toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Build all submodules first, then wrap each one with torch.compile.
            # The result is a plain list attribute, not a ModuleList.
            plain_mods = [SubModule() for _ in range(num_submodules)]
            self.mods = [torch.compile(sub, backend=cnts) for sub in plain_mods]

        def forward(self, x):
            for mod in self.mods:
                x = mod(x)
            return x

    mod = MockModule_toy()
    recompile_is_error = patch("torch._dynamo.config.error_on_recompile", True)
    tiny_cache = patch("torch._dynamo.config.cache_size_limit", cache_size_limit)
    with recompile_is_error, tiny_cache:
        sample = torch.randn(*size)
        mod(sample)
        self.assertEqual(cnts.frame_count, num_submodules)
def test_cache_size_limit_on_guarded_nn_modules(self):
    """With the cache limit patched to 2 and three input ranks, expect ``2 * num_submodules`` frames."""
    cache_size_limit = 2
    num_submodules = 4
    cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            a = torch.sin(torch.cos(x))
            return self.relu(a)

    class MockModule_1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Build all submodules first, then wrap each one with torch.compile.
            # The result is a plain list attribute, not a ModuleList.
            plain_mods = [SubModule() for _ in range(num_submodules)]
            self.mods = [torch.compile(sub, backend=cnts) for sub in plain_mods]

        def forward(self, x):
            for mod in self.mods:
                x = mod(x)
            return x

    mod = MockModule_1()
    # Three distinct ranks are fed while only two cache entries are allowed.
    input_shapes = [(4,), (4, 4), (4, 4, 4)]
    with patch("torch._dynamo.config.cache_size_limit", cache_size_limit):
        for shape in input_shapes:
            mod(torch.randn(shape))

    self.assertEqual(cnts.frame_count, 2 * num_submodules)
def test_recursion(self):
    """Wrapping an already-optimized module again and again still compiles a single frame."""
    counter = torch._dynamo.testing.CompileCounter()
    # One initial wrap plus five re-wraps of the wrapper: six applications in total.
    wrapped = MockModule()
    for _ in range(6):
        wrapped = torch._dynamo.optimize(counter)(wrapped)

    wrapped(torch.randn(10, 10))
    self.assertEqual(counter.frame_count, 1)
def test_composition(self):
    """An optimized outer module holding a plain inner module matches eager in one frame."""

    class InnerModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.sin(x))

    # The inner module is NOT optimized in this test; only the outer one is.
    inner = InnerModule()

    class OuterModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod = inner

        def forward(self, x):
            return self.mod(torch.cos(x))

    outer = OuterModule()
    counter = torch._dynamo.testing.CompileCounter()
    compiled_outer = torch._dynamo.optimize(counter)(outer)

    sample = torch.randn(4)
    self.assertIsInstance(compiled_outer, torch._dynamo.OptimizedModule)
    eager_out = outer(sample)
    compiled_out = compiled_outer(sample)
    self.assertTrue(same(eager_out, compiled_out))
    self.assertEqual(counter.frame_count, 1)
def test_composition_with_opt_mod(self):
    """An optimized outer module holding an already-optimized inner module is expected to record two frames."""

    class InnerModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.sin(x))

    counter = torch._dynamo.testing.CompileCounter()
    # The inner module is optimized on its own, using the same counter.
    compiled_inner = torch._dynamo.optimize(counter)(InnerModule())

    class OuterModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod = compiled_inner

        def forward(self, x):
            return self.mod(torch.cos(x))

    outer = OuterModule()
    compiled_outer = torch._dynamo.optimize(counter)(outer)

    sample = torch.randn(4)
    self.assertIsInstance(compiled_outer, torch._dynamo.OptimizedModule)
    eager_out = outer(sample)
    compiled_out = compiled_outer(sample)
    self.assertTrue(same(eager_out, compiled_out))
    self.assertEqual(counter.frame_count, 2)
def test_module_patch(self):
    """A ``forward`` bound onto the instance from ModulePatch2 is used when compiling."""
    mod = ModulePatch1()
    # Replace the instance's forward with ModulePatch2.forward bound to ``mod``.
    mod.forward = tps.MethodType(ModulePatch2.forward, mod)

    def fn(x):
        return mod(x)

    compiled_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
    actual = compiled_fn(torch.ones(10))
    expected = torch.zeros(1)
    self.assertTrue(torch.allclose(actual, expected))
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_outer(self):
    """Forward hook on a directly-compiled module: results track hook removal without a guard failure."""

    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    m = TestModule()

    def forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 1

    handle = m.register_forward_hook(forward_hook)
    inp = torch.tensor(1.0, requires_grad=True)

    failure_reason = None

    def guard_fail_fn(failure):
        # Record the first element of whatever Dynamo reports on a guard failure.
        nonlocal failure_reason
        failure_reason = failure[0]

    compiled_m = torch._dynamo.optimize(
        guard_fail_fn=guard_fail_fn, backend="eager"
    )(m)

    # forward(1) = 3, and the hook maps 3 -> 2 * 3 + 1 = 7.
    self.assertEqual(compiled_m(inp), m(inp))
    self.assertEqual(compiled_m(inp).item(), 7)
    self.assertIsNone(failure_reason)

    # With the hook removed only forward remains: 2 * 1 + 1 = 3.
    handle.remove()
    self.assertEqual(compiled_m(inp), m(inp))
    self.assertEqual(compiled_m(inp).item(), 3)

    # Summary (original author's notes):
    # - removing a hook doesn't fail a guard, because we weren't compiling the hook
    #   (at least into the same graph) as forward in the first place! We do correctly
    #   omit calling the removed hook, but since this hook is a post forward hook,
    #   the 'RETURN' from forward is breaking the graph.
    #
    #   Open question: why is 'forward' the entrypoint to an InstructionTranslator,
    #   after I changed the eval_frame entrypoint to Module.__call__?
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_inner(self):
    """Forward hook on a module called inside a compiled function: hook changes fail guards and recompile."""

    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    # NOTE: the local name ``m`` is matched by the guard-failure regex at the
    # end of this test, so it must not be renamed.
    m = TestModule()

    def forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 1

    handle = m.register_forward_hook(forward_hook)

    def outer_func(tensor):
        x = tensor * 2 + 1
        y = m(x)
        return y

    inp = torch.tensor(1.0, requires_grad=True)

    failure_reason = None

    def guard_fail_fn(failure):
        # Record the first element of whatever Dynamo reports on a guard failure.
        nonlocal failure_reason
        failure_reason = failure[0]

    counter = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
    compiled_func = torch._dynamo.optimize(
        guard_fail_fn=guard_fail_fn,
        backend=counter,
    )(outer_func)

    # 1 -> x = 3 -> forward = 7 -> hook = 15.
    self.assertEqual(compiled_func(inp), outer_func(inp))
    self.assertEqual(compiled_func(inp).item(), 15)

    # The hooked run is expected to compile one frame holding six ops.
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 6)

    # Without the hook the result drops back to forward's output, 7.
    handle.remove()
    self.assertEqual(compiled_func(inp), outer_func(inp))
    self.assertEqual(compiled_func(inp).item(), 7)
    self.assertIn("forward_hooks.keys", failure_reason)
    # One more frame and four more ops than the hooked compilation.
    self.assertEqual(counter.frame_count, 2)
    self.assertEqual(counter.op_count, 10)

    # Start over with a fresh module; outer_func picks up the rebound ``m``.
    torch._dynamo.reset()
    m = TestModule()
    handle = m.register_forward_hook(forward_hook)
    failure_reason = None
    self.assertEqual(compiled_func(inp), outer_func(inp))
    self.assertEqual(compiled_func(inp).item(), 15)

    def new_forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 2

    # Swap the hook function in place under the same handle id: 2 * 7 + 2 = 16.
    m._forward_hooks[handle.id] = new_forward_hook
    self.assertEqual(compiled_func(inp), outer_func(inp))
    self.assertEqual(compiled_func(inp).item(), 16)
    self.assertRegex(
        failure_reason, r"^___check_obj_id\(.*\(L\['m'\]\._forward_hooks"
    )
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
def test_hooks_skip_guards(self):
    """With hook guards skipped, removing a hook is expected to leave the compiled result stale."""

    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    m = TestModule()

    def forward_hook(
        module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
    ) -> torch.Tensor:
        return 2 * output + 1

    handle = m.register_forward_hook(forward_hook)

    def outer_func(tensor):
        x = tensor * 2 + 1
        y = m(x)
        return y

    inp = torch.tensor(1.0, requires_grad=True)

    failure_reason = None

    def guard_fail_fn(failure):
        # Record the first element of whatever Dynamo reports on a guard failure.
        nonlocal failure_reason
        failure_reason = failure[0]

    counter = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
    compiled_func = torch._dynamo.optimize(
        guard_fail_fn=guard_fail_fn,
        backend=counter,
    )(outer_func)

    # Rebind ``m`` to a fresh hooked module before the first call; outer_func
    # reads ``m`` at call time, so this is the module that gets traced.
    m = TestModule()
    handle = m.register_forward_hook(forward_hook)
    failure_reason = None

    # 1 -> x = 3 -> forward = 7 -> hook = 15.
    self.assertEqual(compiled_func(inp), outer_func(inp))
    self.assertEqual(compiled_func(inp).item(), 15)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 6)

    # After removal eager and compiled are expected to diverge: the compiled
    # function still yields 15 and no new frame is compiled.
    handle.remove()
    self.assertNotEqual(compiled_func(inp), outer_func(inp))
    self.assertEqual(compiled_func(inp).item(), 15)
    self.assertEqual(counter.frame_count, 1)
def _forward_hook_test_helper(self, model):
    """Check that forward hooks fire for the same modules under ``aot_eager`` as in eager mode.

    A recording hook is registered on every entry of ``model.named_modules()``.
    The model is then run twice compiled and twice eagerly, and the sets of
    module names that recorded an input are compared.
    """
    compiled_activations = {}
    eager_activations = {}
    # Rebound before each phase; the hook looks this name up on every call.
    activations = None

    def save_activations(name, mod, inp, out):
        activations[name] = inp

    forward_handles = {
        name: module.register_forward_hook(partial(save_activations, name))
        for name, module in model.named_modules()
    }

    compiled_model = torch.compile(model, backend="aot_eager")

    def run_two_steps(runner, sink):
        """Point the hook at ``sink`` and do two forward/backward passes with ``runner``."""
        nonlocal activations
        activations = sink
        for _ in range(2):
            # Only the recordings of the last pass survive.
            sink.clear()
            batch = torch.randn((20, 10))
            prediction = runner(batch)
            total = prediction.sum()
            total.backward()

    run_two_steps(compiled_model, compiled_activations)
    run_two_steps(model, eager_activations)

    print(f"Recorded Layers: {compiled_activations.keys()}\n\n")
    print(f"Expected Layers: {eager_activations.keys()}")

    self.assertTrue(compiled_activations.keys() == eager_activations.keys())
    # ``activations`` still refers to the eager recordings at this point.
    self.assertTrue(activations.keys() == forward_handles.keys())
def test_hooks_allowed_modules(self):
    """Run the forward-hook helper on a Sequential of Linear/ReLU layers."""

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(10, 10000),
                torch.nn.ReLU(),
                torch.nn.Linear(10000, 5),
                torch.nn.ReLU(),
            )

        def forward(self, x):
            return self.net(x)

    self._forward_hook_test_helper(ToyModel())
def test_hooks_allowed_modules_compiles(self):
    """Recording hooks on every module compile in nopython mode as one frame and record six inputs."""

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(10, 10000),
                torch.nn.ReLU(),
                torch.nn.Linear(10000, 5),
                torch.nn.ReLU(),
            )

        def forward(self, x):
            return self.net(x)

    model = ToyModel()
    activations = []

    def save_activations(mod, inp, out):
        activations.append(inp)

    # Hook the model itself and every nested submodule.
    for module in model.modules():
        module.register_forward_hook(save_activations)

    counter = torch._dynamo.testing.CompileCounter()
    model = torch._dynamo.optimize(counter, nopython=True)(model)

    for _ in range(2):
        # Only the recordings of the last pass are inspected below.
        activations.clear()
        batch = torch.randn((20, 10))
        prediction = model(batch)
        total = prediction.sum()
        total.backward()

    self.assertEqual(len(activations), 6)
    self.assertEqual(counter.frame_count, 1)
def test_hooks_allowed_modules_compiles_self_contained(self):
    """Output-modifying hooks on every module: compiled gradients match eager and frame counts are as expected.

    ``Tensor.backward()`` returns ``None``, so comparing its return values
    checks nothing. The eager gradients are snapshotted instead and compared
    against the gradients produced by the compiled backward pass.
    """

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
            )

        def forward(self, x):
            return self.net(x) * self.net(x)

    model = ToyModel()
    forward_handles = {}

    def output_modifying_hook(mod, inp, out):
        return 2 * out + 1

    for name, module in model.named_modules():
        forward_handles[name] = module.register_forward_hook(output_modifying_hook)

    cnt = torch._dynamo.testing.CompileCounter()

    x = torch.randn((20, 10))
    params = list(model.parameters())

    pred_eager = model(x)
    loss_eager = pred_eager.sum()
    loss_eager.backward()
    # Snapshot the eager gradients; the compiled backward accumulates on top.
    eager_grads = [p.grad.clone() for p in params]

    model = torch._dynamo.optimize(cnt, nopython=True)(model)
    pred = model(x)
    loss = pred.sum()
    loss.backward()

    # ``.grad`` now holds eager + compiled contributions for the same input,
    # so it must equal twice the eager gradient if both passes agree.
    for param, eager_grad in zip(params, eager_grads):
        self.assertEqual(param.grad, 2 * eager_grad)
    self.assertEqual(cnt.frame_count, 2)

    # A 3-D input is expected to add two frames; repeating the same shape adds none.
    pred = model(torch.randn([10, 10, 10]))
    self.assertEqual(cnt.frame_count, 4)
    pred = model(torch.randn([10, 10, 10]))
    self.assertEqual(cnt.frame_count, 4)
def test_dunder_call_explicitly(self):
    """Run the forward-hook helper on a module that invokes ``linear.__call__`` explicitly."""

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10000)

        def forward(self, x):
            # Deliberately spelled as an explicit dunder call.
            return self.linear.__call__(x)

    self._forward_hook_test_helper(ToyModel())
def test_backward_hooks(self):
    """Full backward hooks and backward pre-hooks fire for every module under ``aot_eager``."""

    class CustomLinear(torch.nn.Module):
        def __init__(self, a, b):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(a, b))

        def forward(self, x):
            return torch.mm(x, self.weight)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                CustomLinear(10, 10),
                CustomLinear(10, 10000),
                CustomLinear(10000, 5),
            )

        def forward(self, x):
            return self.net(x)

    model = ToyModel()
    backward_hook_handles = {}
    pre_backward_hook_handles = {}

    # Both dicts map module name -> lazily-evaluated shape generators. The
    # generators are never consumed; only the recorded keys are asserted on.
    grad_sizes = {}
    pre_grad_sizes = {}

    def backward_hook(name, mod, grad_inp, grad_out):
        input_shapes = (gi.shape for gi in grad_inp)
        output_shapes = (go.shape for go in grad_out)
        grad_sizes[name] = (input_shapes, output_shapes)
        return None

    def backward_pre_hook(name, mod, grad_out):
        pre_grad_sizes[name] = (go.shape for go in grad_out)
        return None

    # Register both hook kinds on the model itself and on every submodule.
    for name, module in model.named_modules():
        full_hook = partial(backward_hook, name)
        backward_hook_handles[name] = module.register_full_backward_hook(full_hook)
        pre_hook = partial(backward_pre_hook, name)
        pre_backward_hook_handles[name] = module.register_full_backward_pre_hook(
            pre_hook
        )

    model = torch.compile(model, backend="aot_eager")

    for _ in range(2):
        batch = torch.randn((20, 10))
        prediction = model(batch)
        total = prediction.sum()
        total.backward()

    self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
    self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())
def test_module_dict_iter_name(self):
    """Iterating a ModuleDict directly (yielding names) inside forward compiles as one frame."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.activations = torch.nn.ModuleDict(
                [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
            )

        def forward(self, x):
            for activation_name in self.activations:
                x = self.activations[activation_name](x)
            return x

    counter = torch._dynamo.testing.CompileCounter()
    # Eager and compiled runs each use their own freshly built module.
    eager_out = MyModule()(torch.ones(10, 10))
    compiled_module = torch._dynamo.optimize(counter)(MyModule())
    compiled_out = compiled_module(torch.ones(10, 10))
    self.assertEqual(eager_out, compiled_out)
    self.assertEqual(counter.frame_count, 1)
def test_module_dict_iter_keys(self):
    """Iterating ``ModuleDict.keys()`` inside forward compiles as one frame."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.activations = torch.nn.ModuleDict(
                [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
            )

        def forward(self, x):
            # The explicit ``.keys()`` call is the behavior under test.
            for activation_name in self.activations.keys():
                x = self.activations[activation_name](x)
            return x

    counter = torch._dynamo.testing.CompileCounter()
    # Eager and compiled runs each use their own freshly built module.
    eager_out = MyModule()(torch.ones(10, 10))
    compiled_module = torch._dynamo.optimize(counter)(MyModule())
    compiled_out = compiled_module(torch.ones(10, 10))
    self.assertEqual(eager_out, compiled_out)
    self.assertEqual(counter.frame_count, 1)
def test_assign_does_not_exist(self):
    """Assigning a brand-new attribute inside a fullgraph-compiled forward is visible on the module.

    The attribute ``text_encoding`` does not exist before the call; after it,
    the module attribute must be the very object returned by the compiled call.
    """

    class MyModule(torch.nn.Module):
        def forward(self, x):
            self.text_encoding = x + 1
            return self.text_encoding

    mod = MyModule()
    out = torch.compile(mod, fullgraph=True)(torch.randn(10))
    # Use a unittest assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and reports both operands on failure.
    self.assertIs(mod.text_encoding, out)
def test_module_dict_iter_values(self):
    """Iterating ``ModuleDict.values()`` inside forward compiles as one frame."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.activations = torch.nn.ModuleDict(
                [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
            )

        def forward(self, x):
            # The explicit ``.values()`` call is the behavior under test.
            for activation in self.activations.values():
                x = activation(x)
            return x

    counter = torch._dynamo.testing.CompileCounter()
    # Eager and compiled runs each use their own freshly built module.
    eager_out = MyModule()(torch.ones(10, 10))
    compiled_module = torch._dynamo.optimize(counter)(MyModule())
    compiled_out = compiled_module(torch.ones(10, 10))
    self.assertEqual(eager_out, compiled_out)
    self.assertEqual(counter.frame_count, 1)
def test_unspecialized_seq(self):
    """A function that flips ``training`` on a Sequential's child matches eager when compiled."""
    models = torch.nn.Sequential(torch.nn.Linear(3, 3))

    def fn(x):
        # Mutates module state inside the function being compiled.
        models[0].training = False
        return models(x)

    compiled_fn = torch._dynamo.optimize("eager")(fn)
    sample = torch.randn(1, 3)
    # Eager reference first, compiled result second.
    eager_out = fn(sample)
    compiled_out = compiled_fn(sample)
    self.assertEqual(eager_out, compiled_out)
def test_no_op_assignment(self):
    """Count compiled graphs that carry no buffers while ``self.buffer`` is reassigned in forward."""

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attribute (not registered through register_buffer).
            self.buffer = torch.rand([4])

        def forward(self, x):
            x = x + 1
            self.buffer = self.buffer.to(x)
            return self.buffer + x

    compiles_without_buffers = 0

    def debug_compile(gm, *args, **kwargs):
        # Backend stand-in: tally graphs that have no buffers, return the graph unchanged.
        nonlocal compiles_without_buffers
        if not list(gm.buffers()):
            compiles_without_buffers += 1
        return gm

    @torch.compile(backend=debug_compile)
    def foo(mod, x):
        return mod(x)

    mod = Mod()

    # Default-dtype input: the tally is expected to stay at zero.
    foo(mod, torch.rand([4]))
    self.assertEqual(compiles_without_buffers, 0)

    # Half-precision input: exactly one buffer-less graph is expected.
    foo(mod, torch.rand([4], dtype=torch.half))
    self.assertEqual(compiles_without_buffers, 1)

    class Mod2(Mod):
        def __setattr__(self, name, value):
            return super().__setattr__(name, value)

    # A subclass with a pass-through __setattr__ must add at least one more.
    foo(Mod2(), torch.rand([4]))
    self.assertGreaterEqual(compiles_without_buffers, 2)
def test_unspec_non_inlinable_module(self):
    """Optimized UnspecNonInlinableModule matches its eager output."""
    eager_mod = UnspecNonInlinableModule()
    compiled_mod = torch._dynamo.optimize("eager")(eager_mod)
    sample = torch.randn(100)
    # The compiled module runs before the eager reference.
    compiled_out = compiled_mod(sample)
    eager_out = eager_mod(sample)
    self.assertEqual(compiled_out, eager_out)
def test_no_guard_on_torch_nn_modules(self):
    """Repeated calls compile once per constant ``c``; switching the module to eval adds one frame."""

    class MockModule2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear(x)

    mod = MockModule2()
    counter = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=counter)
    def generate(x, c):
        return mod(x) + c

    # Ten rounds alternating the two constants; two frames are expected in total.
    for _ in range(10):
        generate(torch.randn(10, 10), 0)
        generate(torch.randn(10, 10), 1)
    self.assertEqual(counter.frame_count, 2)

    # Flipping the user module to eval mode is expected to cause one more compile.
    mod.eval()
    generate(torch.randn(10, 10), 0)
    self.assertEqual(counter.frame_count, 3)
if __name__ == "__main__":
    # Hand control to Dynamo's test runner when executed as a script.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()