import collections
import unittest
import torch
from torch.testing._internal.autocast_test_lists import (
AutocastCPUTestLists,
TestAutocast,
)
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
import torch_npu
import torch_npu.testing
from torch.utils._python_dispatch import TorchDispatchMode
class TestAutocastCPU(TestAutocast):
def setUp(self):
super().setUp()
self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))
def tearDown(self):
del self.autocast_lists
super().tearDown()
@skipIfTorchDynamo()
def test_autocast_torch_expect_builtin_promote(self):
for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(
op, args1, torch.float32, device="cpu", out_type=out_type
)
self._run_autocast_outofplace(
op,
args2,
torch.float32,
device="cpu",
out_type=out_type,
amp_dtype=torch.float16,
)
def test_autocast_methods_expect_builtin_promote(self):
for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(
op, args1, torch.float32, device="cpu", module=None, out_type=out_type
)
self._run_autocast_outofplace(
op,
args2,
torch.float32,
device="cpu",
module=None,
out_type=out_type,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_torch_16(self):
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
device="cpu",
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_nn_16(self):
for op_with_args in self.autocast_lists.nn_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op,
args,
torch.bfloat16,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_torch_fp32(self):
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
device="cpu",
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_nn_fp32(self):
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op,
args,
torch.float32,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
def test_autocast_torch_need_autocast_promote(self):
for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
self._run_autocast_outofplace(
op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
)
def test_autocast_rnn(self):
if torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported():
x = torch.randn(1, 2, 1)
hx = torch.randn(2, 2, 1)
cx = torch.randn(2, 2, 1)
m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
with self.assertRaisesRegex(ValueError, "input must have the type"):
m(x, (hx, cx))
with torch.amp.autocast(device_type="cpu"):
m(x, (hx, cx))
def test_autocast_disabled_with_fp32_dtype(self):
with torch.autocast(device_type='cpu', dtype=torch.float32, enabled=False):
_ = torch.ones(10)
def test_generic_autocast(self):
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
with torch.amp.autocast(device_type="cpu"):
generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
with torch.amp.autocast(device_type="cpu"):
cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
self.assertEqual(generic_autocast_output, cpu_autocast_output)
class CustomLinear(torch.autograd.Function):
@staticmethod
def forward(ctx, x, w_t):
ctx.save_for_backward(x, w_t)
return torch.nn.functional.linear(x, w_t)
@staticmethod
def backward(ctx, grad_output):
x, w_t = ctx.saved_tensors
with torch.autocast(device_type='npu'):
dL_dX = torch.matmul(grad_output, w_t)
dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
return dL_dX, dL_dW
class WeightDTypeCastCounterMode(TorchDispatchMode):
def __init__(self, weight):
super().__init__()
self.dtype_cast_counter = 0
self.weight = weight
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if (
func is torch.ops.aten._to_copy.default and
args[0] is self.weight and
kwargs['dtype'] is torch.float16
):
self.dtype_cast_counter += 1
return func(*args, **kwargs)
def __enter__(self):
self.old_clear_cache = torch.clear_autocast_cache
torch.clear_autocast_cache = lambda: None
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
torch.clear_autocast_cache = self.old_clear_cache
return super().__exit__(exc_type, exc_val, exc_tb)
@unittest.skipIf(not torch.npu.is_available(), "requires npu")
class TestAutocastNPU(TestCase):
def test_cast_cache_is_global(self):
"""
Verifies that the autocast cache is global. This is done by
mocking out cache clearing at the end of the forward pass,
running forward+backward with an explicit call to autocast in the
backward, and verifying that the weight only get cast to float16 once.
"""
data = torch.randn(2, 3).npu()
weight = torch.nn.Parameter(torch.randn(4, 3).npu())
with WeightDTypeCastCounterMode(weight) as mode:
with torch.autocast(device_type='npu'):
output = CustomLinear.apply(data, weight)
s = output.sum()
s.backward()
self.assertEqual(mode.dtype_cast_counter, 0)
def test_cache_disabled(self):
data = torch.randn(2, 3).npu()
weight = torch.nn.Parameter(torch.randn(4, 3).npu())
try:
torch._C._set_cached_tensors_enabled(True)
torch._C._add_cached_tensor(weight)
with WeightDTypeCastCounterMode(weight) as mode:
with torch.autocast(device_type='npu'):
output = CustomLinear.apply(data, weight)
s = output.sum()
s.backward()
self.assertEqual(mode.dtype_cast_counter, 0)
finally:
torch._C._set_cached_tensors_enabled(False)
def test_autocast_prioritize(self):
device = "npu"
dtype = torch.bfloat16
with torch.autocast(device_type=device, enabled=True, dtype=dtype):
t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True)
index = torch.randint(
low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device
)
val = torch.randn(1, dtype=dtype, device=device)
res = torch.index_put(t, [index], val)
loss = res.mean()
loss.backward()
def test_set_autocast_dtype(self):
torch_npu.npu.set_autocast_dtype(torch.float16)
self.assertTrue(torch_npu.npu.get_autocast_dtype(), torch.float16)
torch_npu.npu.set_autocast_dtype(torch.float32)
self.assertTrue(torch_npu.npu.get_autocast_dtype(), torch.float32)
def test_set_autocast_enable(self):
torch_npu.npu.set_autocast_enabled(True)
self.assertTrue(torch_npu.npu.is_autocast_enabled())
torch_npu.npu.set_autocast_enabled(False)
self.assertTrue(torch_npu.npu.is_autocast_enabled() is not True)
@unittest.skipIf(not torch.npu.is_available(), "requires npu")
class TestAutocastNPUfp32(TestCase):
def test_autocast_fp32_when_origin_dtype_is_float16(self):
device = "npu"
a = torch.rand((8, 8), device=device, dtype=torch.float16)
with torch.autocast(device_type=device, dtype=torch.float32):
b = torch.mm(a, a)
self.assertEqual(b.dtype, torch.float32)
def test_autocast_fp32_when_origin_dtype_is_bfloat16(self):
device = "npu"
a = torch.rand((8, 8), device=device, dtype=torch.bfloat16)
with torch.autocast(device_type=device, dtype=torch.float32):
b = torch.mm(a, a)
self.assertEqual(b.dtype, torch.float32)
def test_autocast_fp32_when_origin_dtype_is_float32(self):
device = "npu"
a = torch.rand((8, 8), device=device, dtype=torch.float32)
with torch.autocast(device_type=device, dtype=torch.float32):
b = torch.mm(a, a)
self.assertEqual(b.dtype, torch.float32)
def test_autocast_fp32_when_disabled(self):
device = "npu"
a = torch.rand((8, 8), device=device, dtype=torch.bfloat16)
with torch.autocast(device_type=device, dtype=torch.float32, enabled=False):
b = torch.mm(a, a)
self.assertEqual(b.dtype, torch.bfloat16)
class TestTorchAutocast(TestCase):
def test_autocast_fast_dtype(self):
npu_fast_dtype = torch.get_autocast_dtype(device_type="privateuseone")
cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu")
self.assertEqual(npu_fast_dtype, torch.float32)
self.assertEqual(cpu_fast_dtype, torch.bfloat16)
def test_invalid_device(self):
dev = "not a real device"
msg = f"Invalid device string: '{dev}'"
with self.assertRaisesRegex(RuntimeError, msg):
with torch.autocast(device_type=dev):
_ = torch.tensor(1)
with self.assertRaisesRegex(RuntimeError, msg):
assert torch.amp.is_autocast_available(device_type=dev)
def test_non_string_device(self):
"""Test that `autocast` throws a ValueError when provided a `torch.device` object for `device_type` instead of a string"""
dev = torch.device("cpu")
msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
torch.autocast(device_type=dev)
if __name__ == '__main__':
run_tests()