"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools
import inspect
import torch
import torch_npu
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter
# Shared fixture for the optimizer tests in this module: a two-layer linear
# model. One backward pass runs at import time, so every parameter of `model`
# already has a populated `.grad` when the optimizers built over
# `model.parameters()` are stepped.
input1 = torch.ones(10, 10)
model = torch.nn.Sequential(*(torch.nn.Linear(10, 10) for _ in range(2)))
model(input1).sum().backward()
def get_optimizer_step(opt_arg, closure=None):
torch._dynamo.eval_frame.TorchPatcher.patch()
step_fn = opt_arg.step.__wrapped__
if closure is not None:
def fn():
step_fn(opt_arg, closure)
else:
def fn():
step_fn(opt_arg)
return fn
def make_test(optim_cls, closure=None, **kwargs):
    """Create a test method that compiles and runs one step of an optimizer.

    The optimizer is built immediately, when ``make_test`` is called rather
    than when the test runs. It is constructed over the module-level
    ``model``'s parameters, with ``kwargs`` as constructor arguments.

    The returned method wraps the optimizer's step (from
    ``get_optimizer_step``) in ``torch.compile(backend="eager",
    fullgraph=True)`` and runs it once with autograd disabled. With
    ``fullgraph=True``, any graph break makes the test fail.

    Args:
        optim_cls: optimizer class to instantiate.
        closure: optional closure forwarded to ``get_optimizer_step``.
        **kwargs: keyword arguments passed to ``optim_cls``.

    Returns:
        A ``test_fn(self)`` function meant to be attached to a TestCase.
    """
    optimizer = optim_cls(model.parameters(), **kwargs)

    def test_fn(self):
        step = get_optimizer_step(optimizer, closure=closure)
        with torch.set_grad_enabled(False):
            compiled_step = torch.compile(step, backend="eager", fullgraph=True)
            compiled_step()

    return test_fn
class OptimizerTests(torch._dynamo.test_case.TestCase):
    """Compile a single optimizer step under dynamo, one test per optimizer.

    ``test_sgd`` is declared here with an explicit ``lr=0.01``. Tests for the
    other ``torch.optim`` optimizer classes are attached to this class by
    module-level code after it is defined.
    """

    test_sgd = make_test(optim_cls=torch.optim.SGD, lr=0.01)
# Names of torch.optim classes that do NOT get an auto-generated test.
exclude = {
    "SGD",  # covered explicitly by OptimizerTests.test_sgd (lr=0.01)
    "Optimizer",  # the base class itself, not a concrete optimizer
    # NOTE(review): presumably excluded because SparseAdam expects sparse
    # gradients while `model`'s gradients are dense. Confirm.
    "SparseAdam",
    # NOTE(review): LBFGS.step requires a closure, but the generated tests
    # call make_test(opt) without one. Presumably that is the reason. Confirm.
    "LBFGS",
    # NOTE(review): reason for excluding RAdam is not visible in this file.
    "RAdam",
}


def check_opt(opt_ipt):
    """Return whether *opt_ipt* should get an auto-generated optimizer test.

    Any object is accepted, because the values of ``torch.optim.__dict__``
    include modules and functions as well as classes. Only classes deriving
    from ``torch.optim.Optimizer`` whose ``__name__`` is not in ``exclude``
    qualify.

    Returns:
        ``True`` if the object qualifies, else ``False``.
    """
    # `isclass` must be checked first: `issubclass` raises TypeError for
    # non-class arguments, and `and` short-circuits before reaching it.
    return (
        inspect.isclass(opt_ipt)
        and issubclass(opt_ipt, torch.optim.Optimizer)
        and opt_ipt.__name__ not in exclude
    )
# Every eligible optimizer class exported by torch.optim (see check_opt), in
# module-dictionary order. Each one gets a generated
# OptimizerTests.test_<lowercased class name> method built with its default
# constructor arguments.
optimizers = list(filter(check_opt, torch.optim.__dict__.values()))
for opt in optimizers:
    setattr(OptimizerTests, f"test_{opt.__name__.lower()}", make_test(opt))
class MyOptimizer(torch.optim.Optimizer):
    """Minimal test optimizer used by End2EndTests.test_init_group.

    ``step`` gathers each param group's tensors through ``_init_group``, which
    appends into a list owned by ``step``. It then shifts the first gathered
    tensor in place: ``-= 1`` if any tensor in the group is complex, ``+= 1``
    otherwise.

    NOTE(review): the structure appears to mimic the ``_init_group`` pattern
    of built-in torch optimizers, so that dynamo's handling of it can be
    tested. Keep the statement shape as-is.
    """

    def __init__(self, params):
        # No per-group default hyperparameters.
        super().__init__(params, {})

    def _init_group(self, params, group) -> bool:
        """Append every tensor in ``group["params"]`` to ``params``.

        ``params`` is the caller's list and is mutated in place.

        Returns:
            True if any appended tensor is complex-valued, else False.
        """
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

    def step(self) -> None:
        """Apply the +/-1 shift to the first tensor of every param group."""
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            # Tensor augmented assignment (-=, +=) mutates the tensor in place,
            # so the tensor held by the param group itself is updated.
            # NOTE(review): params[0] raises IndexError for an empty group.
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1
class End2EndTests(torch._dynamo.test_case.TestCase):
    """End-to-end dynamo tests that use optimizers inside or around compiled code."""

    def test_optimizing_over_tensor_with_requires_grad(self):
        """Optimize a free-standing tensor with Adam inside a dynamo-optimized loop.

        ``Net`` has no parameters of its own. The optimized tensor is the
        ``requires_grad`` input ``input_2``.
        """

        class Net(torch.nn.Module):
            def forward(self, x, y):
                # (2, 1, 4) @ (2, 4, 8) -> (2, 1, 8), flattened to (2, 8).
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            # Class indices into the 8 output columns produced by Net.
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input_1 = torch.randn(2, 1, 4)
        input_2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input_2], lr=0.1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input_1, "y": input_2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        # NOTE(review): the expected count of 2 compiled frames is taken as
        # given. It presumably reflects graph breaks inside the iteration.
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
        """Create an optimizer and its loss closure inside compiled code, then step it eagerly.

        NOTE(review): despite the name, this test never calls ``state_dict()``.
        It only checks that the Adagrad optimizer and ``functools.partial``
        closure returned from the compiled function can be used by
        ``optimizer.step(fn)``.
        """

        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, ipt):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = ipt
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        optimizer.step(fn)

    def test_init_group(self):
        """Check that a fullgraph-compiled MyOptimizer.step matches eager execution.

        Covers both a real dtype (float32) and a complex dtype (cfloat), which
        take different branches in ``MyOptimizer.step``.
        """
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            # Two independent copies of the same data: one for the eager
            # reference, one for the compiled run. requires_grad=False
            # presumably lets the in-place update run with autograd enabled.
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()

            self.assertEqual(params, opt_params)
if __name__ == "__main__":
    # When executed as a script, hand off to dynamo's test runner.
    torch._dynamo.test_case.run_tests()