from copy import deepcopy
from itertools import product
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch.nn import Parameter
import torch_npu
import torch_npu.testing
from torch.testing._internal.common_utils import run_tests, skipIfNoLapack, \
TemporaryFileName, instantiate_parametrized_tests, set_default_dtype
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import gradcheck, TEST_PRIVATEUSE1
# True only when TEST_PRIVATEUSE1 is set and torch_npu reports at least two
# devices; `and` short-circuits, so the device count is not queried otherwise.
TEST_MULTINPU = TEST_PRIVATEUSE1 and torch_npu.npu.device_count() >= 2
class TestNNParametrization(NNTestCase):
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
@skipIfNoLapack
def test_register_and_remove_parametrization(self):
    r"""Test that it is possible to add a few parametrizations
    on a parameter or a buffer and that removing them restores the initial state.
    It also tests that backpropagating through them works as expected.
    """
    # Maps a square matrix to a skew-symmetric one (checked below via A == -A.T).
    class Skew(nn.Module):
        def forward(self, X):
            # Keep the strictly lower triangle, then antisymmetrize it.
            X = X.tril(-1)
            return X - X.T

    # Solves (Id + X) Y = (Id - X) for Y; stacked on Skew, the test below
    # checks that the result satisfies Y.T @ Y == Id.
    class Orthogonal(nn.Module):
        def forward(self, X):
            Id = torch.eye(X.size(0), device=X.device)
            # .contiguous() hands back the solution in contiguous memory layout.
            return torch.linalg.solve(Id + X, Id - X).contiguous()

    # Shape-changing parametrization: keeps only the first row.
    class Resize(nn.Module):
        def forward(self, X):
            return X[[0]]

    # Shape-preserving identity parametrization.
    class NoResize(nn.Module):
        def forward(self, X):
            return X

    # Replaces the first entry with zero (applied to the bias below).
    class FirstZero(nn.Module):
        def forward(self, x):
            return torch.cat([x.new_zeros(1), x[1:]])

    # Replaces the last entry with zero (applied to the bias below).
    class LastZero(nn.Module):
        def forward(self, x):
            return torch.cat([x[:-1], x.new_zeros(1)])

    model = nn.Linear(8, 8)
    # Record object identities so we can later check that removing a
    # parametrization gives back the very same tensor objects.
    initial_weight_id = id(model.weight)
    initial_bias_id = id(model.bias)
    initial_model = deepcopy(model)

    # With the default `unsafe` setting, a shape-changing parametrization is rejected.
    with self.assertRaisesRegex(ValueError, "Registering a parametrization may not change the shape of the tensor"):
        parametrize.register_parametrization(model, "weight", Resize())
        # Not reached when the registration above raises as expected.
        model(torch.ones(8, 8))

    # --- unsafe=True allows the shape-changing parametrization ---
    parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
    self.assertTrue(hasattr(model, "parametrizations"))
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertTrue(parametrize.is_parametrized(model, "weight"))
    self.assertFalse(parametrize.is_parametrized(model, "bias"))
    self.assertNotIn("weight", model._parameters)
    A = model.weight
    self.assertTrue(A.shape[0] == 1)
    # Removing with leave_parametrized=False must restore the original tensor object.
    parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
    self.assertFalse(hasattr(model, "parametrizations"))
    self.assertEqual(model.weight, initial_model.weight)
    self.assertEqual(id(model.weight), initial_weight_id)
    self.assertEqual(model.__class__, nn.Linear)

    # --- an unsafe parametrization followed by a safe one on the same tensor ---
    parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
    parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
    self.assertTrue(hasattr(model, "parametrizations"))
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertTrue(parametrize.is_parametrized(model, "weight"))
    self.assertFalse(parametrize.is_parametrized(model, "bias"))
    self.assertNotIn("weight", model._parameters)
    A = model.weight
    self.assertTrue(A.shape[0] == 1)
    parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
    self.assertFalse(hasattr(model, "parametrizations"))
    self.assertEqual(model.weight, initial_model.weight)
    self.assertEqual(id(model.weight), initial_weight_id)
    self.assertEqual(model.__class__, nn.Linear)

    # --- unsafe=True with a shape-preserving parametrization behaves as usual ---
    parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
    self.assertTrue(hasattr(model, "parametrizations"))
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertTrue(parametrize.is_parametrized(model, "weight"))
    self.assertFalse(parametrize.is_parametrized(model, "bias"))
    self.assertNotIn("weight", model._parameters)
    # The parametrized weight is skew-symmetric.
    A = model.weight
    self.assertEqual(A, -A.T)
    parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
    self.assertFalse(hasattr(model, "parametrizations"))
    self.assertEqual(model.weight, initial_model.weight)
    self.assertEqual(id(model.weight), initial_weight_id)
    self.assertEqual(model.__class__, nn.Linear)

    # --- a single parametrization with default arguments ---
    parametrize.register_parametrization(model, "weight", Skew())
    self.assertTrue(hasattr(model, "parametrizations"))
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertTrue(parametrize.is_parametrized(model, "weight"))
    self.assertFalse(parametrize.is_parametrized(model, "bias"))
    self.assertNotIn("weight", model._parameters)
    A = model.weight
    self.assertEqual(A, -A.T)
    parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
    self.assertFalse(hasattr(model, "parametrizations"))
    self.assertEqual(model.weight, initial_model.weight)
    self.assertEqual(id(model.weight), initial_weight_id)
    self.assertEqual(model.__class__, nn.Linear)

    # --- two parametrizations chained on the same tensor ---
    parametrize.register_parametrization(model, "weight", Skew())
    parametrize.register_parametrization(model, "weight", Orthogonal())
    # Skew followed by Orthogonal yields an orthogonal weight.
    X = model.weight
    Id = torch.eye(X.size(0), device=X.device)
    self.assertEqual(X.T @ X, Id)
    self.assertTrue(hasattr(model, "parametrizations"))
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertTrue(parametrize.is_parametrized(model, "weight"))
    self.assertFalse(parametrize.is_parametrized(model, "bias"))
    self.assertIn("weight", model.parametrizations)
    self.assertNotIn("weight", model._parameters)
    parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
    self.assertEqual(model.weight, initial_model.weight)
    self.assertEqual(id(model.weight), initial_weight_id)
    self.assertFalse(hasattr(model, "parametrizations"))
    self.assertEqual(model.__class__, nn.Linear)

    # --- parametrizations on both weight and bias, plus an optimizer step ---
    parametrize.register_parametrization(model, "weight", Skew())
    parametrize.register_parametrization(model, "weight", Orthogonal())
    parametrize.register_parametrization(model, "bias", FirstZero())
    parametrize.register_parametrization(model, "bias", LastZero())
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertTrue(parametrize.is_parametrized(model, "weight"))
    self.assertTrue(parametrize.is_parametrized(model, "bias"))
    self.assertEqual(model.bias[0].item(), 0.)
    self.assertEqual(model.bias[-1].item(), 0.)
    # Still exactly two trainable tensors (the two `original` parameters).
    self.assertEqual(len(list(model.parameters())), 2)
    # The same optimizer instance is reused for every step below.
    sgd = torch.optim.SGD(model.parameters(), lr=0.01)

    # Backpropagation through the parametrizations must update both tensors.
    weight_copy = model.weight.clone()
    bias_copy = model.bias.clone()
    sgd.zero_grad()
    (model.weight.T @ model.bias).sum().backward()
    sgd.step()
    self.assertNotEqual(model.weight, weight_copy)
    self.assertNotEqual(model.bias, bias_copy)

    # --- remove only the weight parametrizations; the bias stays parametrized ---
    parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
    self.assertTrue(parametrize.is_parametrized(model))
    self.assertFalse(parametrize.is_parametrized(model, "weight"))
    self.assertTrue(parametrize.is_parametrized(model, "bias"))
    self.assertEqual(model.bias[0].item(), 0.)
    self.assertEqual(model.bias[-1].item(), 0.)
    # The weight was trained in the meantime, so its value differs from the
    # initial one while the tensor object is still the original.
    self.assertNotEqual(model.weight, initial_model.weight)
    self.assertEqual(id(model.weight), initial_weight_id)
    self.assertEqual(len(list(model.parameters())), 2)

    # The existing optimizer must keep updating both tensors.
    weight_copy = model.weight.clone()
    bias_copy = model.bias.clone()
    sgd.zero_grad()
    (model.weight.T @ model.bias).sum().backward()
    sgd.step()
    self.assertNotEqual(model.weight, weight_copy)
    self.assertNotEqual(model.bias, bias_copy)

    # --- remove the bias parametrizations as well; the module is plain again ---
    parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
    self.assertFalse(parametrize.is_parametrized(model))
    self.assertNotEqual(model.bias, initial_model.bias)
    # Without FirstZero/LastZero the end entries are no longer forced to zero.
    self.assertNotEqual(model.bias[0].item(), 0.)
    self.assertNotEqual(model.bias[-1].item(), 0.)
    self.assertEqual(id(model.bias), initial_bias_id)
    self.assertFalse(hasattr(model, "parametrizations"))
    self.assertEqual(model.__class__, nn.Linear)
    self.assertEqual(len(list(model.parameters())), 2)

    # The existing optimizer must still update the unparametrized tensors.
    weight_copy = model.weight.clone()
    bias_copy = model.bias.clone()
    sgd.zero_grad()
    (model.weight.T @ model.bias).sum().backward()
    sgd.step()
    self.assertNotEqual(model.weight, weight_copy)
    self.assertNotEqual(model.bias, bias_copy)

    # --- leave_parametrized=True, exercised twice in a row ---
    for _ in range(2):
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        # The tensor objects are expected to be the original ones.
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(id(model.bias), initial_bias_id)

        # The optimizer created earlier must still update both tensors.
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
def test_register_and_remove_nested_parametrization(self):
r"""Test that it is possible to nest the parametrizations
meaning that the original param is parametrized again
"""
class Skew(nn.Module):
def forward(self, X):
X = X.tril(-1)
return X - X.T
model = nn.Linear(8, 8)
parametrize.register_parametrization(model, "weight", Skew())
self.assertTrue(hasattr(model, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertNotIn("weight", model._parameters)
A = model.weight
self.assertEqual(A, -A.T)
param_mod = model.parametrizations.weight
self.assertFalse(hasattr(param_mod, "parametrizations"))
self.assertFalse(parametrize.is_parametrized(param_mod))
self.assertFalse(parametrize.is_parametrized(param_mod, "original"))
parametrize.register_parametrization(param_mod, "original", Skew())
self.assertTrue(hasattr(param_mod, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(param_mod))
self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
self.assertNotIn("original", param_mod._parameters)
A = param_mod.original
self.assertEqual(A, -A.T)
parametrize.remove_parametrizations(param_mod, "original", leave_parametrized=False)
self.assertFalse(hasattr(param_mod, "parametrizations"))
self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)
def test_register_and_remove_buffer_parametrization(self):
r"""Test that it is possible to add and remove parametrizations on buffers"""
class FirstZero(nn.Module):
def forward(self, x):
return torch.cat([x.new_zeros(1), x[1:]])
class LastZero(nn.Module):
def forward(self, x):
return torch.cat([x[:-1], x.new_zeros(1)])
model = nn.Linear(8, 8)
delattr(model, "bias")
model.register_buffer("bias", torch.ones(8))
parametrize.register_parametrization(model, "bias", FirstZero())
parametrize.register_parametrization(model, "bias", LastZero())
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "bias"))
self.assertEqual(model.bias[0].item(), 0.)
self.assertEqual(model.bias[-1].item(), 0.)
self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
self.assertEqual(len(list(model.parameters())), 1)
parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
self.assertFalse(parametrize.is_parametrized(model))
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertEqual(model.bias[0].item(), 0.)
self.assertEqual(model.bias[-1].item(), 0.)
self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
self.assertEqual(len(list(model.parameters())), 1)
@skipIfNoLapack
def test_serialization_parametrization(self):
r"""Test that it is possible to serialize a parametrized model via state_dict"""
class Orthogonal(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("id", torch.eye(n))
self.register_buffer("B", torch.empty(n, n))
init.orthogonal_(self.B)
def forward(self, X):
A = X.triu(1)
A = A - A.T
return self.B @ torch.linalg.solve(self.id + A, self.id - A)
def get_model():
model = torch.nn.Sequential(
torch.nn.Linear(5, 5),
torch.nn.ReLU(),
torch.nn.Linear(5, 1),
)
parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
return model
model = get_model()
prev_weight = model[0].weight
prev_B = model[0].parametrizations.weight[0].B
new_model = get_model()
with TemporaryFileName() as fname:
torch.save(model.state_dict(), fname)
new_model.load_state_dict(torch.load(fname))
self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
self.assertEqual(prev_weight, new_model[0].weight)
self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)
with self.assertRaisesRegex(RuntimeError, "state_dict"):
with TemporaryFileName() as fname:
torch.save(model, fname)
@skipIfNoLapack
def test_initialization_parametrization(self):
r"""Test that it is possible to initialize a parametrization when it
implements a `right_inverse` method
"""
class Skew(nn.Module):
def forward(self, X):
A = X.triu(1)
return A - A.T
def is_skew(self, A):
return torch.allclose(A, -A.T, atol=1e-6)
def right_inverse(self, X):
if not self.is_skew(X):
raise ValueError("The matrix is not skew-symmetric.")
return X.triu(1)
class Orthogonal(nn.Module):
def __init__(self, n):
super().__init__()
self.register_buffer("B", torch.eye(n))
def forward(self, X):
Id = torch.eye(X.size(0))
return self.B @ torch.linalg.solve(Id + X, Id - X)
def is_orthogonal(self, X):
Id = torch.eye(X.size(0))
return torch.allclose(X.T @ X, Id, atol=1e-4)
def right_inverse(self, X):
if not self.is_orthogonal(X):
raise ValueError("The input is not orthogonal.")
self.B = X
return torch.zeros_like(X)
N = 5
model = nn.Linear(N, N)
skew = Skew()
with torch.no_grad():
model.weight.set_(skew(model.weight))
parametrize.register_parametrization(model, "weight", skew)
X = torch.rand(N, N)
with self.assertRaises(ValueError):
model.weight = X
X = X - X.T
model.weight = X
self.assertEqual(model.parametrizations.weight.original, X.triu(1))
self.assertEqual(model.weight, X)
parametrize.register_parametrization(model, "weight", Orthogonal(N))
X = torch.rand(N, N)
with self.assertRaises(ValueError):
model.weight = X
init.orthogonal_(X)
model.weight = X
self.assertEqual(model.weight, X)
self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))
def test_errors_unparametrized_tensor_parametrization(self):
    r"""Test the errors raised when registering invalid parametrizations on a
    tensor that is not parametrized yet, and that the failed attempts leave
    the module untouched.
    """
    module = nn.Linear(3, 4)
    # Snapshot used at the end to check that the weight value was never altered.
    weight_init = module.weight.clone()

    class Identity(nn.Module):
        def forward(self, x):
            return x

    # Registering on a tensor name the module does not have raises.
    with self.assertRaisesRegex(ValueError, "does not have a parameter"):
        parametrize.register_parametrization(module, "foo", Identity())
    self.assertFalse(parametrize.is_parametrized(module))

    # Removing parametrizations from an unparametrized tensor raises.
    with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
        parametrize.remove_parametrizations(module, "bias")
    self.assertFalse(parametrize.is_parametrized(module))

    # A valid parametrization whose right_inverse returns two tensors.
    class Sum(nn.Module):
        def forward(self, x, y):
            return x + y

        def right_inverse(self, z):
            return z, torch.zeros_like(z)

    parametrize.register_parametrization(module, "weight", Sum())
    # With several original tensors, leave_parametrized=False is rejected.
    with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=False)
    parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

    # forward takes three tensors but right_inverse only provides two.
    class WrongNumberParams(nn.Module):
        def forward(self, x, y, z):
            return x + y + z

        def right_inverse(self, w):
            return w, torch.zeros_like(w)

    with self.assertRaisesRegex(TypeError, "positional argument"):
        parametrize.register_parametrization(module, "weight", WrongNumberParams())
    self.assertFalse(parametrize.is_parametrized(module))

    # right_inverse returns neither a tensor nor a sequence.
    class WrongRightInverse(Identity):
        def right_inverse(self, z):
            return None

    with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
        parametrize.register_parametrization(module, "weight", WrongRightInverse())
    self.assertFalse(parametrize.is_parametrized(module))

    # right_inverse returns a sequence containing a non-tensor element.
    class WrongRightInverseSequence(nn.Module):
        def forward(self, x, y):
            return x

        def right_inverse(self, z):
            return None, z

    with self.assertRaisesRegex(ValueError, "of the sequence with type"):
        parametrize.register_parametrization(module, "weight", WrongRightInverseSequence())
    self.assertFalse(parametrize.is_parametrized(module))

    # One-tensor right_inverse that changes the dtype.
    class ChangeDtypeInverse(nn.Module):
        def forward(self, x):
            return x.float()

        def right_inverse(self, w):
            return w.bool()

    with self.assertRaisesRegex(ValueError, "outputs one tensor, it may not change the dtype"):
        parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
    self.assertFalse(parametrize.is_parametrized(module))

    # forward does not return a tensor.
    class NotTensor(nn.Module):
        def forward(self, x):
            return 2

    with self.assertRaisesRegex(ValueError, "must return a tensor"):
        parametrize.register_parametrization(module, "weight", NotTensor())
    self.assertFalse(parametrize.is_parametrized(module))

    # forward changes the dtype.
    class ChangeDtype(nn.Module):
        def forward(self, x):
            return x.bool()

    with self.assertRaisesRegex(ValueError, "may not change the dtype"):
        parametrize.register_parametrization(module, "weight", ChangeDtype())
    self.assertFalse(parametrize.is_parametrized(module))

    # forward changes the shape.
    class ChangeShape(nn.Module):
        def forward(self, x):
            return x[:-1]

    with self.assertRaisesRegex(ValueError, "may not change the shape"):
        parametrize.register_parametrization(module, "weight", ChangeShape())
    self.assertFalse(parametrize.is_parametrized(module))

    # Two-input forward that changes the dtype.
    class ChangeDtypeMulti(nn.Module):
        def forward(self, x, y):
            return (x + y).bool()

        def right_inverse(self, w):
            return w, w + 1

    with self.assertRaisesRegex(ValueError, "may not change the dtype"):
        parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
    self.assertFalse(parametrize.is_parametrized(module))

    # A right_inverse returning a length-one sequence is accepted.
    class SequenceLen1(nn.Module):
        def forward(self, x):
            return x

        def right_inverse(self, w):
            return (w,)

    parametrize.register_parametrization(module, "weight", SequenceLen1())
    # The single tensor is stored under the indexed name `original0`.
    self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
    self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
    # Evaluating the parametrized weight must not raise.
    _ = module.weight
    self.assertTrue(parametrize.is_parametrized(module))
    parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

    # None of the operations above may have changed the weight value.
    self.assertFalse(parametrize.is_parametrized(module))
    self.assertEqual(module.weight, weight_init)
def test_errors_parametrized_tensor_parametrization(self):
class Identity(nn.Module):
def forward(self, x):
return x
module = nn.Linear(3, 4)
parametrize.register_parametrization(module, "weight", Identity())
class WrongReturn(nn.Module):
def forward(self, x):
return x, x
with self.assertRaisesRegex(ValueError, "must return a tensor"):
parametrize.register_parametrization(module, "weight", WrongReturn())
self.assertTrue(parametrize.is_parametrized(module))
self.assertEqual(len(module.parametrizations.weight), 1)
self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
class ChangeDtype(nn.Module):
def forward(self, x):
return x.bool()
with self.assertRaisesRegex(ValueError, "may not change the dtype"):
parametrize.register_parametrization(module, "weight", ChangeDtype())
self.assertTrue(parametrize.is_parametrized(module))
self.assertEqual(len(module.parametrizations.weight), 1)
self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
class ChangeShape(nn.Module):
def forward(self, x):
return x[:-1]
with self.assertRaisesRegex(ValueError, "may not change the shape"):
parametrize.register_parametrization(module, "weight", ChangeShape())
self.assertTrue(parametrize.is_parametrized(module))
self.assertEqual(len(module.parametrizations.weight), 1)
self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
class WrongReturnInverse(Identity):
def right_inverse(self, x):
return x, x
with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
parametrize.register_parametrization(module, "weight", WrongReturnInverse())
self.assertTrue(parametrize.is_parametrized(module))
self.assertEqual(len(module.parametrizations.weight), 1)
self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
class ChangeDtypeInverse(Identity):
def right_inverse(self, x):
return x.bool()
with self.assertRaisesRegex(ValueError, "must have the same dtype"):
parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
self.assertTrue(parametrize.is_parametrized(module))
self.assertEqual(len(module.parametrizations.weight), 1)
self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
class ChangeShapeInverse(Identity):
def right_inverse(self, x):
return x[:-1]
with self.assertRaisesRegex(ValueError, "must have the same shape"):
parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
self.assertTrue(parametrize.is_parametrized(module))
self.assertEqual(len(module.parametrizations.weight), 1)
self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))
@skipIfNoLapack
def test_multiple_inputs_parametrization(self):
class RankOne(nn.Module):
def forward(self, x, y):
return x.unsqueeze(-1) @ y.unsqueeze(-2)
def right_inverse(self, Y):
U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
s0_sqrt = S[0].sqrt().unsqueeze(-1)
return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
class Double(nn.Module):
def forward(self, x):
return 2.0 * x
def right_inverse(self, w):
return 0.5 * w
model = nn.Linear(3, 3)
parametrize.register_parametrization(model, "weight", RankOne())
self.assertTrue(hasattr(model, "parametrizations"))
self.assertTrue(parametrize.is_parametrized(model))
self.assertTrue(parametrize.is_parametrized(model, "weight"))
self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
self.assertIn("original0", model.parametrizations.weight._parameters)
self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
self.assertIn("original1", model.parametrizations.weight._parameters)
self.assertFalse(parametrize.is_parametrized(model, "bias"))
self.assertNotIn("weight", model._parameters)
self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)
self.assertFalse(parametrize.is_parametrized(model))
self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
self.assertIn("weight", model._parameters)
init_weight = model.weight.clone()
parametrize.register_parametrization(model, "weight", RankOne())
self.assertEqual(init_weight, model.weight)
parametrize.register_parametrization(model, "weight", Double())
self.assertEqual(2.0 * init_weight, model.weight)
self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
self.assertEqual(len(list(model.parameters())), 3)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(2):
sgd.zero_grad()
loss = (model.weight.T @ model.bias).sum()
loss.backward()
sgd.step()
with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
self.assertFalse(hasattr(model, "parametrizations"))
self.assertEqual(model.__class__, nn.Linear)
self.assertFalse(parametrize.is_parametrized(model))
self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
self.assertIn("weight", model._parameters)
self.assertEqual(len(list(model.parameters())), 2)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(2):
sgd.zero_grad()
loss = (model.weight.T @ model.bias).sum()
loss.backward()
sgd.step()
@skipIfNoLapack
def test_caching_parametrization(self):
r"""Test the caching system of a parametrization"""
class Skew(nn.Module):
def forward(self, X):
X = X.tril(-1)
return X - X.T
class Orthogonal(nn.Module):
def forward(self, X):
Id = torch.eye(X.size(0), device=X.device)
return torch.linalg.solve(Id + X, Id - X)
model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Skew())
parametrize.register_parametrization(model, "weight", Orthogonal())
with parametrize.cached():
X = model.weight
Y = model.weight
self.assertEqual(id(X), id(Y))
@skipIfNoLapack
def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
r"""Test that transferring parametrizations doesn't cause issues with caching"""
class Skew(nn.Module):
def forward(self, X):
X = X.tril(-1)
return X - X.T
class Orthogonal(nn.Module):
def forward(self, X):
Id = torch.eye(X.size(0), device=X.device)
return torch.linalg.solve(Id + X, Id - X)
model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Skew())
parametrize.register_parametrization(model, "weight", Orthogonal())
to_model = nn.Linear(5, 5)
parametrize.transfer_parametrizations_and_params(model, to_model)
with parametrize.cached():
X = model.weight
Y = model.weight
self.assertEqual(id(X), id(Y))
A = to_model.weight
B = to_model.weight
self.assertEqual(id(A), id(B))
self.assertNotEqual(id(A), id(X))
def test_parametrization_same_training_mode(self):
r"""Test training mode updated on parametrization registration"""
class Identity(nn.Module):
def forward(self, X):
return X
module = nn.Linear(4, 4)
module.eval()
parametrize.register_parametrization(module, "weight", Identity())
self.assertFalse(module.parametrizations.weight[0].training)
module.train()
parametrize.register_parametrization(module, "weight", Identity().eval())
self.assertTrue(module.parametrizations.weight[0].training)
self.assertTrue(module.parametrizations.weight[1].training)
def test_type_before_parametrizations(self):
r"""Test that type_before_parametrizations always retrieves original type"""
class Identity(nn.Module):
def forward(self, X):
return X
model = nn.Linear(5, 5)
original_type = type(model)
self.assertTrue(
parametrize.type_before_parametrizations(model) == original_type
)
parametrize.register_parametrization(model, "weight", Identity())
self.assertTrue(
parametrize.type_before_parametrizations(model) == original_type
)
def test_deepcopy_after_parametrization(self):
r"""Test that we are able to create a deepcopy of the module when it's parametrized."""
class AddOne(nn.Module):
def forward(self, x):
return x + 1.0
class ModelWithoutDeepcopy(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True)
self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True)
self.attr = [1.0, 2.0, 3.0, 4.0]
class ActualModel(ModelWithoutDeepcopy):
def __deepcopy__(self, memo):
result = self.__new__(self.__class__)
memo[id(self)] = result
result.__dict__ = deepcopy(self.__dict__, memo)
return result
def check_deepcopy(m1: nn.Module, m2: nn.Module):
w1 = m1.parametrizations.weight.original
w2 = m2.parametrizations.weight.original
b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias
b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias
self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
self.assertIsNot(m1, m2)
self.assertEqual(w1, w2)
self.assertIsNot(w1, w2)
self.assertEqual(b1, b2)
self.assertIsNot(b1, b2)
self.assertEqual(m1.attr, m2.attr)
self.assertIsNot(m1.attr, m2.attr)
for model in (ModelWithoutDeepcopy(), ActualModel()):
parametrize.register_parametrization(model, "weight", AddOne())
check_deepcopy(model, deepcopy(model))
parametrize.register_parametrization(model, "bias", AddOne())
check_deepcopy(model, deepcopy(model))
parametrize.register_parametrization(model, "weight", AddOne())
check_deepcopy(model, deepcopy(model))
def test_transfer_parametrizations_and_params(self):
r"""Test that all parametrizations and their associated parameters are transferred."""
class AddOne(nn.Module):
def forward(self, x):
return x + 1.0
class Double(nn.Module):
def forward(self, x):
return 2.0 * x
def right_inverse(self, x):
return 0.5 * x
class MinusOne(nn.Module):
def forward(self, x):
return x - 1.0
model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", AddOne())
parametrize.register_parametrization(model, "weight", Double())
parametrize.register_parametrization(model, "weight", MinusOne())
hold_weight = model.weight
to_model = torch.ao.nn.qat.Linear(
5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
)
parametrize.transfer_parametrizations_and_params(model, to_model)
self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
self.assertEqual(model.weight, to_model.weight)
self.assertEqual(
model.parametrizations.weight.original,
to_model.parametrizations.weight.original,
)
self.assertEqual(hold_weight, model.weight)
parametrize.remove_parametrizations(to_model, "weight")
self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight"))
model.test_param = Parameter(torch.randn(5, 5))
self.assertTrue(not hasattr(to_model, "test_param"))
parametrize.register_parametrization(model, "test_param", Double())
hold_test_param = model.test_param
parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")
self.assertEqual(model.test_param, to_model.test_param)
self.assertEqual(
model.parametrizations.test_param.original,
to_model.parametrizations.test_param.original,
)
self.assertEqual(hold_test_param, model.test_param)
def test_transfer_parametrizations_and_params_right_inverse(self):
r"""Test that all parametrizations and their associated parameters are transferred."""
class Double(nn.Module):
def forward(self, x):
return 2.0 * x
def right_inverse(self, x):
return 0.5 * x
model = nn.Linear(5, 5)
parametrize.register_parametrization(model, "weight", Double())
hold_weight = model.weight
to_model = torch.ao.nn.qat.Linear(
5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
)
parametrize.transfer_parametrizations_and_params(model, to_model)
self.assertEqual(model.weight, to_model.weight)
self.assertEqual(
model.parametrizations.weight.original,
to_model.parametrizations.weight.original,
)
self.assertEqual(hold_weight, model.weight)
def test_transfer_parametrizations_and_params_single_param(self):
r"""Test that all parametrizations and their associated parameters are transferred."""
class AddOne(nn.Module):
def forward(self, x):
return x + 1.0
class Double(nn.Module):
def forward(self, x):
return 2.0 * x
class MinusOne(nn.Module):
def forward(self, x):
return x - 1.0
model = nn.Linear(5, 5, bias=True)
parametrize.register_parametrization(model, "weight", AddOne())
parametrize.register_parametrization(model, "weight", Double())
parametrize.register_parametrization(model, "weight", MinusOne())
parametrize.register_parametrization(model, "bias", AddOne())
parametrize.register_parametrization(model, "bias", Double())
parametrize.register_parametrization(model, "bias", MinusOne())
to_model = torch.ao.nn.qat.Linear(
5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
)
parametrize.transfer_parametrizations_and_params(model, to_model, "weight")
self.assertEqual(model.weight, to_model.weight)
self.assertEqual(
model.parametrizations.weight.original,
to_model.parametrizations.weight.original,
)
self.assertTrue("bias" not in to_model.parametrizations)
@skipIfNoLapack
def test_transfer_parametrizations_and_params_many_to_one(self):
    r"""Check the transfer of a parametrization backed by two original tensors.

    ``RankOne.forward`` takes two inputs, so the parametrized tensor is stored
    as ``original0`` and ``original1``. The test transfers such a
    parametrization for ``weight`` (stacked with ``Double``) and then for a
    tensor, ``test_param``, that does not exist on the destination module
    beforehand.
    """

    class RankOne(nn.Module):
        def forward(self, x, y):
            # Outer product of the two inputs along the last dimensions.
            return x.unsqueeze(-1) @ y.unsqueeze(-2)

        def right_inverse(self, Y):
            # Build the two factors from the first singular triple of Y,
            # giving each of them the square root of the singular value.
            U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
            s0_sqrt = S[0].sqrt().unsqueeze(-1)
            return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

    class Double(nn.Module):
        def forward(self, x):
            return 2.0 * x

    def assert_originals_match(tensor_name):
        # Compare both stored originals of ``tensor_name`` across the two modules.
        source = getattr(model.parametrizations, tensor_name)
        target = getattr(to_model.parametrizations, tensor_name)
        self.assertEqual(source.original0, target.original0)
        self.assertEqual(source.original1, target.original1)

    model = nn.Linear(3, 3)
    parametrize.register_parametrization(model, "weight", RankOne())
    parametrize.register_parametrization(model, "weight", Double())
    weight_before_transfer = model.weight

    qconfig = torch.ao.quantization.get_default_qconfig()
    to_model = torch.ao.nn.qat.Linear(3, 3, qconfig=qconfig)
    parametrize.transfer_parametrizations_and_params(model, to_model)

    self.assertTrue(parametrize.is_parametrized(to_model, "weight"))
    self.assertEqual(model.weight, to_model.weight)
    assert_originals_match("weight")
    # The source module's weight must be unchanged by the transfer.
    self.assertEqual(weight_before_transfer, model.weight)

    # Same scenario for a tensor the destination module does not have yet.
    model.test_param = Parameter(torch.randn(3, 3))
    self.assertTrue(not hasattr(to_model, "test_param"))
    parametrize.register_parametrization(model, "test_param", RankOne())
    test_param_before_transfer = model.test_param
    parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

    self.assertEqual(model.test_param, to_model.test_param)
    assert_originals_match("test_param")
    self.assertEqual(test_param_before_transfer, model.test_param)
def test_new_spectral_norm(self):
    """Exercise ``torch.nn.utils.parametrizations.spectral_norm`` end to end.

    Covers, in order: the attributes/buffers created on registration, removal
    and repeated registration on ``weight`` and ``bias``, and then, for CPU
    and (only when ``TEST_MULTINPU`` is true) ``DataParallel`` on NPUs 0 and 1,
    training behaviour, removal, and eval behaviour, including gradient checks.
    """
    # Everything below runs with double as the default dtype
    # (presumably so that the gradcheck calls operate in double precision).
    with set_default_dtype(torch.double):
        input1 = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        spectral_norm_m = m.parametrizations.weight[0]
        # `_u` has one entry per row of the weight.
        self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))
        # The unparametrized tensor is kept as a parameter named `original`.
        self.assertTrue(hasattr(m.parametrizations.weight, 'original'))
        self.assertTrue('original' in m.parametrizations.weight._parameters)
        # `_u` and `_v` are stored as buffers of the parametrization module.
        self.assertTrue(hasattr(spectral_norm_m, '_u'))
        self.assertTrue('_u' in spectral_norm_m._buffers)
        self.assertTrue('_v' in spectral_norm_m._buffers)
        # `weight` is still accessible but is neither a buffer nor a parameter of `m`.
        self.assertIsNotNone(m.weight)
        self.assertFalse('weight' in m._buffers)
        self.assertFalse('weight' in m._parameters)
        # The original and the parametrized weight share size and stride.
        self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
        self.assertEqual(m.parametrizations.weight.original.stride(), m.weight.stride())
        # Removing the only parametrization restores a plain `weight` parameter.
        m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight')
        self.assertFalse(hasattr(m, 'parametrizations'))
        self.assertTrue('weight' in m._parameters)
        # Register twice on `weight` and once on `bias` of the same module.
        m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight')
        m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight')
        m = torch.nn.utils.parametrizations.spectral_norm(m, 'bias')
        # Dropping the bias parametrization must leave `weight` parametrized.
        m = torch.nn.utils.parametrize.remove_parametrizations(m, 'bias')
        self.assertTrue('bias' in m._parameters)
        self.assertTrue(hasattr(m, 'parametrizations'))
        self.assertFalse('weight' in m._parameters)
        # After also dropping `weight`, nothing is parametrized any more.
        m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight')
        self.assertFalse(hasattr(m, 'parametrizations'))
        self.assertTrue('weight' in m._parameters)
        self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))
        # Training/eval behaviour, with and without DataParallel.
        for apply_dp in (True, False):
            if apply_dp:
                # The DataParallel variant needs at least two NPUs; skip it otherwise.
                if not TEST_MULTINPU:
                    continue
                device = torch.device('npu:0')

                def maybe_wrap(m):
                    return torch.nn.DataParallel(m, [0, 1])
            else:
                device = torch.device('cpu')

                def maybe_wrap(m):
                    return m
            for requires_grad in (True, False):
                def get_modules():
                    # Fresh parametrized Linear on `device`, its (possibly wrapped)
                    # version, and the spectral-norm parametrization module itself.
                    m = nn.Linear(3, 4).to(device)
                    m.weight.requires_grad_(requires_grad)
                    m = torch.nn.utils.parametrizations.spectral_norm(m)
                    wrapped_m = maybe_wrap(m)
                    spectral_norm_m = m.parametrizations.weight[0]
                    return m, wrapped_m, spectral_norm_m
                input1 = torch.randn(2, 3, device=device)
                m, wrapped_m, spectral_norm_m = get_modules()
                self.assertTrue(hasattr(spectral_norm_m, '_u'))
                u0 = spectral_norm_m._u.clone()
                v0 = spectral_norm_m._v.clone()
                # --- training behaviour ---
                # One SGD step first, so the weight differs from its initial value.
                opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)
                opt.zero_grad()
                wrapped_m(input1).sum().backward()
                opt.step()
                out = wrapped_m(input1)
                if requires_grad:
                    # The buffers must have moved away from their initial values.
                    self.assertNotEqual(u0, spectral_norm_m._u)
                    self.assertNotEqual(v0, spectral_norm_m._v)
                if requires_grad:
                    # Gradients must reach the original (unparametrized) weight.
                    torch.autograd.grad(out.sum(), m.parametrizations.weight.original)
                saved_u = spectral_norm_m._u.clone()
                saved_v = spectral_norm_m._v.clone()

                def fn(input1):
                    # Restore `_u`/`_v` so every evaluation starts from the same
                    # buffer state, then run two forwards in a row.
                    spectral_norm_m._u.data.copy_(saved_u)
                    spectral_norm_m._v.data.copy_(saved_v)
                    out0 = wrapped_m(input1)
                    out1 = wrapped_m(input1)
                    return out0 + out1
                # Backward through a double forward must work, and pass gradcheck.
                fn(input1.clone().requires_grad_()).sum().backward()
                gradcheck(fn, (input1.clone().requires_grad_(),), check_batched_grad=False)
                # --- removal ---
                # The module is switched to eval before removal; the output after
                # removal must equal the output computed just before it.
                m, wrapped_m, _ = get_modules()
                pre_remove_out = wrapped_m(input1)
                m.eval()
                m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight')
                self.assertEqual(wrapped_m(input1), pre_remove_out)
                # Same check after re-registering and running three forwards.
                torch.nn.utils.parametrizations.spectral_norm(m)
                for _ in range(3):
                    pre_remove_out = wrapped_m(input1)
                m.eval()
                m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight')
                self.assertEqual(wrapped_m(input1), pre_remove_out)
                # --- eval behaviour ---
                m, wrapped_m, spectral_norm_m = get_modules()
                wrapped_m(input1)
                last_train_out = wrapped_m(input1)
                last_train_u = spectral_norm_m._u.clone()
                last_train_v = spectral_norm_m._v.clone()
                wrapped_m.zero_grad()
                wrapped_m.eval()
                eval_out0 = wrapped_m(input1)
                # Eval reproduces the last training output, is repeatable, and
                # leaves `_u`/`_v` untouched.
                self.assertEqual(eval_out0, last_train_out)
                self.assertEqual(eval_out0, wrapped_m(input1))
                self.assertEqual(last_train_u, spectral_norm_m._u)
                self.assertEqual(last_train_v, spectral_norm_m._v)
                # The remaining gradient checks are not run for the DataParallel variant.
                if apply_dp:
                    continue
                saved_u = spectral_norm_m._u.clone()
                saved_v = spectral_norm_m._v.clone()

                def fn(input1):
                    # Restore `_u`/`_v`, then alternate train and eval forwards.
                    spectral_norm_m._u.data.copy_(saved_u)
                    spectral_norm_m._v.data.copy_(saved_v)
                    wrapped_m.train()
                    out0 = wrapped_m(input1)
                    wrapped_m.eval()
                    out1 = wrapped_m(input1)
                    wrapped_m.train()
                    out2 = wrapped_m(input1)
                    wrapped_m.eval()
                    out3 = wrapped_m(input1)
                    return out0 + out1 + out2 + out3
                gradcheck(fn, (input1.clone().requires_grad_(),))
                if requires_grad:
                    # `fn` leaves the module in eval mode; check the gradient
                    # with respect to the original weight there.
                    def fn(weight):
                        return wrapped_m(input1)
                    gradcheck(fn, (m.parametrizations.weight.original,))
def test_new_spectral_norm_load_state_dict(self):
    """Check ``state_dict`` round-trips of a spectral-norm parametrized Linear.

    For a module that has run zero or three training forwards, the test
    verifies the state-dict keys, that non-strict loading tolerates extra and
    missing entries, and that reloading the saved state into a re-wrapped
    module reproduces the same eval/train/train/eval outputs.
    """
    spectral_norm = torch.nn.utils.parametrizations.spectral_norm
    remove_parametrizations = torch.nn.utils.parametrize.remove_parametrizations
    # Training flag applied before each of the four recorded forwards.
    mode_sequence = (False, True, True, False)

    for activate_times in (0, 3):
        inp = torch.randn(2, 3)
        snm = spectral_norm(nn.Linear(3, 5))
        snm.train()
        for _ in range(activate_times):
            snm(inp)

        state_dict = deepcopy(snm.state_dict())
        self.assertEqual({
            'parametrizations.weight.original',
            'bias',
            'parametrizations.weight.0._v',
            'parametrizations.weight.0._u'
        }, set(state_dict.keys()))

        # An unexpected key is rejected in strict mode and accepted otherwise.
        non_strict_state_dict = deepcopy(state_dict)
        non_strict_state_dict['nonsense'] = 'nonsense'
        with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
            snm.load_state_dict(non_strict_state_dict, strict=True)
        snm.load_state_dict(non_strict_state_dict, strict=False)

        # Non-strict loading keeps working as entries are dropped one by one.
        for missing_key in ('parametrizations.weight.original',
                            'parametrizations.weight.0._u',
                            'parametrizations.weight.0._v'):
            del non_strict_state_dict[missing_key]
            snm.load_state_dict(non_strict_state_dict, strict=False)

        # A plain 'weight' entry, then missing metadata for the parametrization.
        non_strict_state_dict['weight'] = snm.weight.detach().clone()
        snm.load_state_dict(non_strict_state_dict, strict=False)
        del non_strict_state_dict._metadata['parametrizations.weight.0']
        snm.load_state_dict(non_strict_state_dict, strict=False)

        for missing_key in ('weight', 'bias'):
            del non_strict_state_dict[missing_key]
            snm.load_state_dict(non_strict_state_dict, strict=False)

        # Re-wrap, load the saved state and record the reference outputs.
        snm = spectral_norm(remove_parametrizations(snm, 'weight'))
        snm.load_state_dict(state_dict)
        reference_outputs = []
        with torch.no_grad():
            for training in mode_sequence:
                snm.train(training)
                reference_outputs.append(snm(inp))

        # Re-wrap and load again: the same sequence must give the same outputs.
        snm = spectral_norm(remove_parametrizations(snm, 'weight'))
        snm.load_state_dict(state_dict)
        with torch.no_grad():
            for training, expected in zip(mode_sequence, reference_outputs):
                snm.train(training)
                self.assertEqual(expected, snm(inp))
def test_new_spectral_norm_dim(self):
inp = torch.randn(2, 3, 10, 12)
m = nn.ConvTranspose2d(3, 4, (5, 6))
m = torch.nn.utils.parametrizations.spectral_norm(m)
snm = m.parametrizations.weight[0]
x = m(inp)
self.assertEqual(snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape)
def test_new_spectral_norm_forward(self):
    """Compare the parametrized forward with a hand-written normalization.

    The stored original weight is divided in place by ``u . (W v)``, where
    ``u`` and ``v`` are obtained from the parametrization's ``_v`` buffer with
    one normalize/multiply round. A plain ``linear`` on that weight must then
    match the module's own output.
    """
    inp = torch.randn(3, 5)
    layer = torch.nn.utils.parametrizations.spectral_norm(nn.Linear(5, 7))
    sn = layer.parametrizations.weight[0]

    original_weight = layer.parametrizations.weight.original
    bias = layer.bias
    weight_mat = original_weight.view(original_weight.size(0), -1)

    # One update of u from the stored v, then of v from the new u.
    u = F.normalize(torch.mv(weight_mat, sn._v), dim=0, eps=1e-12)
    v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=1e-12)
    sigma = torch.dot(u, torch.matmul(weight_mat, v))

    # NOTE: this rescales the module's stored parameter in place, so the
    # module forward below runs on the already rescaled weight.
    original_weight.data /= sigma

    out_hat = F.linear(inp, original_weight, bias)
    expect_out = layer(inp)
    self.assertEqual(expect_out, out_hat)
@skipIfNoLapack
def test_orthogonal_parametrization(self):
    """Check ``torch.nn.utils.parametrizations.orthogonal`` on ``nn.Linear`` weights.

    For square, tall and wide shapes, real and complex dtypes, and every
    combination of parametrization name and ``use_trivialization``, the test
    checks that the weight is orthogonal after registration, that assignment
    either works or raises as expected, and that two SGD steps keep the weight
    orthogonal with the expected zero pattern in the gradient.
    """
    orthogonal = torch.nn.utils.parametrizations.orthogonal

    def assert_is_orthogonal(X):
        # Verify X^H X == I, working on the transpose when X is wide.
        rows, cols = X.size(-2), X.size(-1)
        if rows < cols:
            X = X.mT
            rows, cols = cols, rows
        Id = torch.eye(cols, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), cols, cols)
        eps = 10 * rows * torch.finfo(X.dtype).eps
        torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.)

    def assert_weight_allclose_Q(weight, W):
        # Verify that `weight` is the Q factor of W, with Q's columns scaled
        # by the sign of the diagonal of R.
        is_wide = W.size(-2) < W.size(-1)
        if is_wide:
            W = W.mT
        Q, R = torch.linalg.qr(W)
        Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
        if is_wide:
            Q = Q.mT
        torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.)

    def assert_assignment_not_supported(module, value):
        # Assigning to the weight must raise for this configuration.
        msg = "assign to the matrix exponential or the Cayley parametrization"
        with self.assertRaisesRegex(NotImplementedError, msg):
            module.weight = value

    shapes = ((4, 4), (5, 3), (3, 5))
    dtypes = (torch.float32, torch.complex64)
    for shape, dtype, use_linear in product(shapes, dtypes, (True, False)):
        # NOTE(review): the non-Linear (Conv2d) variant is skipped unconditionally,
        # so the `else` branches below are currently never taken; reason not
        # visible here, confirm before re-enabling.
        if not use_linear:
            continue
        if use_linear:
            input1 = torch.randn(3, shape[0], dtype=dtype)
        else:
            input1 = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)

        for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"),
                                                           (False, True)):
            can_initialize = use_trivialization or parametrization == "householder"

            # A fresh module for every configuration.
            if use_linear:
                m = nn.Linear(*shape, dtype=dtype)
            else:
                m = nn.Conv2d(2, 3, shape, dtype=dtype)
            w_init = m.weight.clone()

            if parametrization == "householder" and m.weight.is_complex():
                msg = "householder parametrization does not support complex tensors"
                with self.assertRaisesRegex(ValueError, msg):
                    orthogonal(m, "weight", parametrization, use_trivialization=use_trivialization)
                continue

            wide_matrix = w_init.size(-2) < w_init.size(-1)
            orthogonal(m, "weight", parametrization, use_trivialization=use_trivialization)

            # Registration keeps the shape and yields an orthogonal weight.
            self.assertEqual(w_init.shape, m.weight.shape)
            assert_is_orthogonal(m.weight)
            if can_initialize:
                assert_weight_allclose_Q(m.weight, w_init)

            # Assigning an orthogonal matrix (Q factor of a random matrix).
            X = torch.randn_like(m.weight)
            if wide_matrix:
                X = X.mT
            w_new = torch.linalg.qr(X).Q
            if wide_matrix:
                w_new = w_new.mT
            if can_initialize:
                m.weight = w_new
                torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.)
            else:
                assert_assignment_not_supported(m, w_new)

            # Assigning an arbitrary matrix must give its Q factor.
            w_new = torch.randn_like(m.weight)
            if can_initialize:
                m.weight = w_new
                assert_weight_allclose_Q(m.weight, w_new)
            else:
                assert_assignment_not_supported(m, w_new)

            opt = torch.optim.SGD(m.parameters(), lr=0.1)
            for _ in range(2):
                opt.zero_grad()
                m(input1).norm().backward()
                grad = m.parametrizations.weight.original.grad
                self.assertIsNotNone(grad)

                # Strictly upper part (tall/square) or strictly lower part
                # (wide) of the gradient must be zero.
                if grad.size(-2) >= grad.size(-1):
                    untouched = grad.triu(1)
                else:
                    untouched = grad.tril(-1)
                self.assertEqual(untouched, torch.zeros_like(untouched))

                # The diagonal of the gradient has zero real part.
                diag_grad = grad.diagonal(dim1=-2, dim2=-1)
                if grad.is_complex():
                    diag_grad = diag_grad.real
                self.assertEqual(diag_grad, torch.zeros_like(diag_grad))

                opt.step()
                assert_is_orthogonal(m.weight)
@skipIfNoLapack
def test_orthogonal_errors(self):
    """Check the ``ValueError`` cases of the orthogonal parametrization.

    Covers an unknown parametrization name, a tensor that is not a matrix
    (``bias``), and assigning a weight of the wrong shape after registration.
    """
    orthogonal = torch.nn.utils.parametrizations.orthogonal
    layer = nn.Linear(3, 4)

    with self.assertRaisesRegex(ValueError, "has to be one of"):
        orthogonal(layer, "weight", "foo")

    with self.assertRaisesRegex(ValueError, "Expected a matrix"):
        orthogonal(layer, "bias")

    # Register successfully, then try to assign a matrix of the wrong shape.
    orthogonal(layer, "weight")
    with self.assertRaisesRegex(ValueError, "matrices of shape"):
        layer.weight = torch.randn(5, 5)

    torch.nn.utils.parametrize.remove_parametrizations(layer, "weight")
def test_weight_norm_state_dict_compat(self):
    """Check that a ``torch.nn.utils.weight_norm`` state dict loads into the
    parametrization-based ``weight_norm`` and gives the same outputs."""
    legacy = torch.nn.utils.weight_norm(nn.Linear(4, 5))
    legacy_state = legacy.state_dict()

    parametrized = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))
    parametrized.load_state_dict(legacy_state)

    sample = torch.randn(3, 4)
    self.assertEqual(legacy(sample), parametrized(sample))
def test_weight_norm_pickle(self):
m = nn.Linear(4, 5)
m = torch.nn.utils.parametrizations.weight_norm(m)
with self.assertRaisesRegex(RuntimeError, 'state_dict'):
pickle.dumps(m)
def test_weight_norm_deepcopy(self):
    """Check that a deep copy of a weight-norm parametrized module produces
    the same output as the module it was copied from."""
    source = torch.nn.utils.parametrizations.weight_norm(nn.Linear(4, 5))
    clone = deepcopy(source)
    sample = torch.randn(3, 4)
    self.assertEqual(source(sample), clone(sample))
class TestNNParametrizationDevice(NNTestCase):
    """Parametrization tests that take the target ``device`` as an argument."""

    def test_weight_norm_parametrization(self, device):
        """Check ``weight_norm`` with the default ``dim``, ``dim=1`` and ``dim=None``.

        For float and bfloat16, registering (and removing) the parametrization
        must not change the output of the layer, and the stored ``original0`` /
        ``original1`` tensors must have the expected sizes.
        """
        weight_norm = torch.nn.utils.parametrizations.weight_norm

        for dtype in (torch.float, torch.bfloat16):
            tensor_opts = {"dtype": dtype, "device": device}
            sample = torch.randn(3, 4, **tensor_opts)

            layer = nn.Linear(4, 5, **tensor_opts)
            reference = layer(sample)

            # Default dim: original0 has size (5, 1) for the (5, 4) weight.
            layer = weight_norm(layer)
            self.assertEqual(layer.parametrizations.weight.original1.size(), layer.weight.size())
            self.assertEqual(layer.parametrizations.weight.original0.size(), (5, 1))
            self.assertEqual(layer(sample), reference)

            # Removing the parametrization restores a plain module with the same output.
            torch.nn.utils.parametrize.remove_parametrizations(layer, "weight")
            self.assertFalse(hasattr(layer, "parametrizations"))
            self.assertEqual(layer(sample), reference)

            # dim=1: original0 has size (1, 4).
            layer = weight_norm(layer, dim=1)
            self.assertEqual(layer.parametrizations.weight.original1.size(), layer.weight.size())
            self.assertEqual(layer.parametrizations.weight.original0.size(), (1, 4))
            self.assertEqual(layer(sample), reference)

            # dim=None on a fresh layer: only the output is compared.
            layer = nn.Linear(4, 5, **tensor_opts)
            reference = layer(sample)
            layer = weight_norm(layer, dim=None)
            self.assertEqual(layer(sample), reference)
# Device types passed as `only_for` when instantiating the device-type tests below.
only_for = ("cpu", "npu")
# Instantiate the device-specific variants of TestNNParametrizationDevice in this module's globals.
instantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for)
# Instantiate any parametrized tests declared on TestNNParametrization.
instantiate_parametrized_tests(TestNNParametrization)
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()