from collections.abc import Sequence
from functools import partial
import warnings
import sys
import threading
import time
import unittest
import inspect
import itertools
import contextlib
import re
import os
import stat
from typing import Dict
from collections import defaultdict
from importlib import import_module
import torch
import torch_npu
import torch_npu.testing
from torch.utils._pytree import tree_map
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
floating_and_complex_types_and,
all_types_and_complex_and,
)
from torch.testing._internal.common_utils import (
TestCase,
is_iterable_of_tensors,
run_tests,
IS_SANDCASTLE,
clone_input_helper,
set_default_dtype,
suppress_warnings,
noncontiguous_like,
parametrize,
skipIfTorchInductor,
)
from torch.testing._internal.common_methods_invocations import (
op_db,
UnaryUfuncInfo,
ReductionOpInfo,
ReductionPythonRefInfo,
SpectralFuncInfo,
ops_and_refs,
python_ref_db,
BinaryUfuncInfo,
xfail,
skip,
skipOps
)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
OpDTypes,
skipMeta,
)
from torch._subclasses.fake_tensor import (
FakeTensor,
FakeTensorMode,
)
from torch._subclasses.fake_utils import outputs_alias_inputs
import torch._prims as prims
from torch._prims.context import TorchRefsMode
from torch.testing._internal import opinfo
from torch.testing._internal import composite_compliance
from torch.utils._pytree import tree_flatten
from torch.utils._python_dispatch import TorchDispatchMode
# Refuse to load this module unless the global default dtype is float32.
if torch.get_default_dtype() != torch.float32:
    raise RuntimeError("default dtype not equals to float32")

# `ops` decorator pre-bound for the variant tests: dtypes come from the set an
# op supports, restricted to float and cfloat.
_variant_ops = partial(
    ops,
    dtypes=OpDTypes.supported,
    allowed_dtypes=(torch.float, torch.cfloat),
)

# Ops from op_db that carry a reference implementation (`op.ref`) and are not
# one of the unary / reduction / spectral / binary OpInfo families.
_ref_test_ops = tuple(
    op
    for op in op_db
    if not isinstance(
        op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
    )
    and op.ref is not None
)

# Every OpInfo followed by every Python reference OpInfo.
_ops_and_refs = op_db + python_ref_db
def reduction_dtype_filter(op):
    """Tell whether ``op`` should be fed to ``test_out_integral_dtype``.

    An op qualifies only when all of the following hold:
      * it is a ``ReductionPythonRefInfo``;
      * it declares ``supports_out``;
      * ``torch.int16`` is among ``op.dtypes``;
      * the wrapped callable ``op.op`` has a keyword-only ``dtype`` argument.

    Returns:
        bool: True if the op qualifies, False otherwise.
    """
    eligible = (
        isinstance(op, ReductionPythonRefInfo)
        and op.supports_out
        and torch.int16 in op.dtypes
    )
    if not eligible:
        return False
    # The signature is only inspected once the cheaper checks have passed.
    return 'dtype' in inspect.getfullargspec(op.op).kwonlyargs
# OpInfos (and Python refs) that have no reference implementation attached;
# these are the ones compared against the CPU in TestCommon.test_compare_cpu.
_ops_and_refs_with_no_numpy_ref = list(
    filter(lambda op: op.ref is None, _ops_and_refs)
)

# Shorthand for the aten operator namespace.
aten = torch.ops.aten
class TestCommon(TestCase):
    """OpInfo-driven checks that are instantiated per device type.

    Each test receives an ``op`` (an OpInfo entry) through the ``ops``
    decorator and exercises it with the sample inputs the OpInfo provides.
    """

    # NOTE(review): presumably makes ``assertEqual`` treat dtype mismatches as
    # failures -- confirm against ``common_utils.TestCase``.
    exact_dtype = True

    @classmethod
    def tearDownClass(cls):
        """Defer to the base-class teardown; nothing extra is cleaned up here."""
        super().tearDownClass()

    @unittest.skip("NPU doesn't support yet.")
    def test_pointwise_tag_coverage(self):
        """Check that kernels declared via DEFINE_DISPATCH carry the pointwise tag.

        Scans a fixed list of ATen source files for ``DEFINE_DISPATCH(<name>_stub``
        declarations, maps each kernel name to an aten overload packet and
        asserts that every overload which has a kernel, is not generated and is
        not in ``allowed_functions`` is tagged ``torch.Tag.pointwise``.
        """
        pytorch_dir = os.path.abspath(__file__ + "/../../")
        files = [
            "aten/src/ATen/native/UnaryOps.cpp",
            "aten/src/ATen/native/BinaryOps.cpp",
            "aten/src/ATen/native/PointwiseOps.cpp",
            "aten/src/ATen/native/TensorCompare.cpp",
        ]

        # Overloads that are exempt from the pointwise-tag requirement.
        allowed_functions = (
            "aten.max.default",
            "aten.max.dim",
            "aten.max.dim_max",
            "aten.max.names_dim",
            "aten.max.names_dim_max",
            "aten.max.unary_out",
            "aten.min.default",
            "aten.min.dim",
            "aten.min.dim_min",
            "aten.min.names_dim",
            "aten.min.names_dim_min",
            "aten.min.unary_out",
            "aten.isin.Tensor_Tensor",
            "aten.isin.Tensor_Tensor_out",
            "aten.isin.Tensor_Scalar",
            "aten.isin.Tensor_Scalar_out",
            "aten.isin.Scalar_Tensor",
            "aten.isin.Scalar_Tensor_out",
            "aten.mode.default",
            "aten.mode.dimname",
            "aten.mode.dimname_out",
            "aten.mode.values",
        )

        regex = re.compile(r"DEFINE_DISPATCH\(.*_stub")

        def get_opoverloadpacket_from_dispatch(kernel):
            """Map a dispatch-stub kernel name to an attribute of ``torch.ops.aten``.

            Tries, in order: the name itself, ``__name__``, ``special_name`` and
            the name with its last ``_``-separated component dropped.  Fails the
            test with a descriptive message when none of them exists (previously
            ``None`` could leak out and crash in ``getattr`` with a TypeError).
            """
            candidates = [kernel, f"__{kernel}__", f"special_{kernel}"]
            if "_" in kernel:
                candidates.append("_".join(kernel.split("_")[:-1]))
            for candidate in candidates:
                if hasattr(torch.ops.aten, candidate):
                    return candidate
            self.fail(f"could not find an aten op for dispatch kernel '{kernel}'")

        for file_name in files:
            with open(os.path.join(pytorch_dir, file_name), encoding="utf-8") as f:
                matches = regex.findall(f.read())
            for match in matches:
                # Strip the "DEFINE_DISPATCH(" prefix and the "_stub" suffix.
                kernel = match[len("DEFINE_DISPATCH("):-len("_stub")]

                # trigamma is explicitly excluded from the scan.
                if kernel == "trigamma":
                    continue

                kernel = get_opoverloadpacket_from_dispatch(kernel)
                overloadpacket = getattr(torch.ops.aten, kernel)

                for overload_name in overloadpacket.overloads():
                    overload = getattr(overloadpacket, overload_name)

                    if not torch._C._dispatch_has_kernel(overload.name()):
                        continue
                    if torch.Tag.generated in overload.tags:
                        continue
                    if str(overload) in allowed_functions:
                        continue

                    self.assertTrue(torch.Tag.pointwise in overload.tags)

    @suppress_warnings
    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
    def test_compare_cpu(self, device, dtype, op):
        """Compare device results with CPU results on the op's reference inputs.

        Results are compared with ``atol=1e-3`` and ``rtol=1e-3`` after being
        passed through each sample's ``output_process_fn_grad``.
        """

        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device='cpu')
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            npu_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            npu_results = sample.output_process_fn_grad(npu_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            self.assertEqual(npu_results, cpu_results, atol=1e-3, rtol=1e-3)

    @unittest.skip("NPU doesn't support yet.")
    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
    def test_errors(self, device, op):
        """Check that each error input raises the declared exception type and message."""
        error_inputs = op.error_inputs(device)
        for ei in error_inputs:
            si = ei.sample_input
            with self.assertRaisesRegex(ei.error_type, ei.error_regex):
                out = op(si.input, *si.args, **si.kwargs)
                self.assertFalse(isinstance(out, type(NotImplemented)))

    @ops(_ops_and_refs, dtypes=OpDTypes.any_one)
    @skipIfTorchInductor("Inductor does not support complex dtype yet")
    @unittest.skip("skip test_out now")
    def test_out(self, device, dtype, op):
        """Validate the behaviour of the ``out=`` keyword argument.

        For every sample the op is first called normally, then with several
        kinds of ``out=`` tensors:
          0. correct shape/dtype/device but filled with max (integer) or NaN;
          1. correct shape/dtype/device but noncontiguous;
          2. correct dtype/device but zero elements (must resize silently);
          3. correct shape/dtype but the wrong device (expects RuntimeError,
             unless the op is a factory function called without ``device``);
          4. correct shape/device but dtype ``torch.long`` for floating or
             complex results (expects RuntimeError, unless the op is a factory
             function called without ``dtype``).
        Ops that declare ``supports_out=False`` must raise when given ``out=``.
        """

        # The three helpers below do not depend on the sample, so they are
        # defined once instead of once per loop iteration.
        def _apply_out_transform(fn, out):
            """Apply ``fn`` to a tensor, or to each tensor of an iterable (as a tuple)."""
            if isinstance(out, torch.Tensor):
                return fn(out)
            return tuple(map(fn, out))

        def _extract_strides(out):
            """Return the strides of a tensor or of each tensor in an iterable, as a tuple."""
            if isinstance(out, torch.Tensor):
                return (out.stride(),)
            return tuple(t.stride() for t in out)

        def _extract_data_ptrs(out):
            """Return the data pointers of a tensor or of each tensor in an iterable."""
            if isinstance(out, torch.Tensor):
                return (out.data_ptr(),)
            return tuple(t.data_ptr() for t in out)

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            # Plain call gives the expected result; op_out binds the same inputs
            # so only out= has to be supplied later.
            expected = op(sample.input, *sample.args, **sample.kwargs)
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # An op that claims not to support out= must raise when given one.
            if not op.supports_out:
                with self.assertRaises(Exception):
                    if op_out(out=expected) == NotImplemented:
                        raise RuntimeError("Except to support out but get not implemented")
                return

            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out_ = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out_)
                original_ptrs = _extract_data_ptrs(out_)

                op_out(out=out_)
                final_strides = _extract_strides(out_)
                final_ptrs = _extract_data_ptrs(out_)
                self.assertEqual(expected, out_)

                if compare_strides_and_data_ptrs:
                    stride_msg = (
                        f"Strides are not the same! Original strides were {original_strides} "
                        f"and strides are now {final_strides}"
                    )
                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case 0: out= pre-filled with max (integer dtypes) or NaN (others).
            def _case_zero_transform(t):
                try:
                    info = torch.iinfo(t.dtype)
                    return torch.full_like(t, info.max)
                except TypeError:
                    # torch.iinfo rejects non-integer dtypes; fill with NaN instead.
                    return torch.full_like(t, float("nan"))

            _compare_out(_case_zero_transform)

            # Case 1: out= with the right shape, dtype and device, but noncontiguous.
            def _case_one_transform(t):
                return make_tensor(
                    t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
                )

            _compare_out(_case_one_transform)

            # Case 2: out= with the right dtype and device, but no elements.
            def _case_two_transform(t):
                return make_tensor((0,), dtype=t.dtype, device=t.device)

            _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False)

            # Resizing an empty out= must not emit the resize warning.
            out = _apply_out_transform(_case_two_transform, expected)
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                op_out(out=out)

            for w in caught:
                if "An output with one or more elements" in str(w.message):
                    self.fail(
                        "Resizing an out= argument with no elements threw a resize warning!"
                    )

            # Case 3: out= with the right shape and dtype, but on another device.
            wrong_device = None
            if torch.device(device).type != "cpu":
                wrong_device = "cpu"
            elif torch.npu.is_available():
                wrong_device = "npu"

            factory_fn_msg = (
                "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its "
                "OpInfo with `is_factory_function=True`."
            )
            if wrong_device is not None:

                def _case_three_transform(t):
                    return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

                out = _apply_out_transform(_case_three_transform, expected)

                if op.is_factory_function and sample.kwargs.get("device", None) is None:
                    op_out(out=out)
                else:
                    msg_fail = (
                        f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}."
                    ) + factory_fn_msg
                    with self.assertRaises(RuntimeError, msg=msg_fail):
                        op_out(out=out)

            # Case 4: out= with the right shape and device, but dtype long while
            # the result (or any result in an iterable) is floating or complex.
            _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
            if (
                isinstance(expected, torch.Tensor)
                and expected.dtype in _dtypes
                or (
                    not isinstance(expected, torch.Tensor)
                    and any(t.dtype in _dtypes for t in expected)
                )
            ):

                def _case_four_transform(t):
                    return make_tensor(t.shape, dtype=torch.long, device=t.device)

                out = _apply_out_transform(_case_four_transform, expected)
                msg_fail = "Expected RuntimeError when doing an unsafe cast!"
                msg_fail = (
                    msg_fail
                    if not isinstance(expected, torch.Tensor)
                    else (
                        "Expected RuntimeError when doing an unsafe cast from a result of dtype "
                        f"{expected.dtype} into an out= with dtype torch.long"
                    )
                ) + factory_fn_msg

                if op.is_factory_function and sample.kwargs.get("dtype", None) is None:
                    op_out(out=out)
                else:
                    with self.assertRaises(RuntimeError, msg=msg_fail):
                        op_out(out=out)

    @ops(filter(reduction_dtype_filter, _ops_and_refs), dtypes=(torch.int16,))
    def test_out_integral_dtype(self, device, dtype, op):
        """Check how reductions combine the ``dtype=`` and ``out=`` arguments.

        ``out`` is always an empty int32 tensor.  The call must fail with
        "dtype argument and out dtype must match in reduction" exactly when an
        explicit ``dtype`` other than int32 is given together with ``out``.
        """

        def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
            out = None
            try:
                if with_out:
                    out = torch.empty(0, dtype=torch.int32, device=device)
                    op_to_test(inputs, *args, out=out, **kwargs)
                else:
                    out = op_to_test(inputs, *args, **kwargs)
                self.assertFalse(expectFail)
            except RuntimeError as err:
                self.assertEqual(
                    str(err), "dtype argument and out dtype must match in reduction")
                self.assertTrue(expectFail)
            return out

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            if 'dtype' not in sample.kwargs:
                # No dtype given: both plain and out= calls must work.
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
                # dtype=int16 conflicts with the int32 out tensor.
                sample.kwargs['dtype'] = torch.int16
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, True, op, sample.input, *sample.args, **sample.kwargs)
                # dtype=int32 agrees with the out tensor.
                sample.kwargs['dtype'] = torch.int32
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, False, op, sample.input, *sample.args, **sample.kwargs)
            else:
                helper(False, False, op, sample.input, *sample.args, **sample.kwargs)
                helper(True, sample.kwargs['dtype'] != torch.int32, op, sample.input,
                       *sample.args, **sample.kwargs)

    @_variant_ops(op_db)
    @skipIfTorchInductor("Inductor does not support complex dtype yet")
    def test_variant_consistency_eager(self, device, dtype, op):
        """Check that method, inplace, operator and alias variants agree with the op.

        Forward results of every variant are compared with the function
        variant; when the dtype supports backward, the gradient on the first
        input tensor is compared too.  Inplace variants are additionally
        checked to return a tensor that shares the input's data pointer.
        """
        method = op.method_variant
        inplace = op.inplace_variant
        operator = op.operator_variant
        inplace_operator = op.inplace_operator_variant

        # These lists may contain None for variants the op does not have.
        inplace_ops = [inplace, inplace_operator]
        variants_tmp = [method, inplace, operator, inplace_operator]
        operators = [operator, inplace_operator]

        for a_op in op.aliases:
            variants_tmp.append(a_op.op)
            variants_tmp.append(a_op.method_variant)
            variants_tmp.append(a_op.inplace_variant)
            inplace_ops.append(a_op.inplace_variant)

        inplace_variants = tuple(filter(None, inplace_ops))
        variants = tuple(filter(None, variants_tmp))
        operators = tuple(filter(None, operators))

        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs,
        )
        # Materialized because the samples are iterated more than once below.
        samples = list(samples)

        def _test_consistency_helper(samples, variants):
            for sample in samples:
                tensor = (
                    sample.input
                    if isinstance(sample.input, torch.Tensor)
                    else sample.input[0]
                )

                # Reference forward (and, if supported, backward) values.
                tensor.grad = None
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                expected_grad = None

                output_process_fn_grad = (
                    sample.output_process_fn_grad
                    if sample.output_process_fn_grad
                    else lambda x: x
                )

                # Inplace variants are skipped when the output dtype differs
                # from the input dtype.
                skip_inplace = False
                if (
                    isinstance(expected_forward, torch.Tensor)
                    and expected_forward.dtype is not tensor.dtype
                ):
                    skip_inplace = True

                if isinstance(
                    expected_forward, torch.Tensor
                ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
                    out = output_process_fn_grad(expected_forward).sum()
                    if out.dtype.is_complex:
                        out = out.abs()
                    out.backward()
                    expected_grad = tensor.grad

                for variant in variants:
                    if variant in inplace_ops and skip_inplace:
                        continue

                    # The input is cloned for inplace variants so the sample
                    # itself is not modified.
                    tensor.grad = None
                    cloned = (
                        clone_input_helper(sample.input)
                        if variant in inplace_ops
                        else sample.input
                    )

                    if variant in inplace_ops and sample.broadcasts_input:
                        with self.assertRaises(
                            RuntimeError,
                            msg=(
                                "inplace variant either incorrectly allowed "
                                f"resizing or you have marked the sample {sample.summary()}"
                                " incorrectly with `broadcasts_self=True"
                            ),
                        ):
                            variant_forward = variant(
                                cloned, *sample.args, **sample.kwargs
                            )
                        continue

                    # Operator variants are not called with keyword arguments.
                    if variant in operators and sample.kwargs:
                        continue

                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                    self.assertEqual(expected_forward, variant_forward)

                    if expected_grad is not None and (
                        variant not in inplace_ops or op.supports_inplace_autograd
                    ):
                        out = output_process_fn_grad(variant_forward).sum()
                        if out.dtype.is_complex:
                            out = out.abs()
                        out.backward()
                        self.assertEqual(expected_grad, tensor.grad)

        _test_consistency_helper(samples, variants)

        def _test_inplace_preserve_storage(samples, variants):
            for sample in samples:
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                tensor = (
                    sample.input
                    if isinstance(sample.input, torch.Tensor)
                    else sample.input[0]
                )
                skip_inplace = False
                if (
                    isinstance(expected_forward, torch.Tensor)
                    and expected_forward.dtype is not tensor.dtype
                ):
                    skip_inplace = True
                # NOTE: this ends the whole helper (not just this sample) on
                # the first dtype mismatch; kept as in the original.
                if skip_inplace:
                    return
                for variant in variants:
                    cloned = (
                        clone_input_helper(sample.input)
                        if variant in inplace_ops
                        else sample.input
                    )
                    inp_tensor = (
                        cloned if isinstance(cloned, torch.Tensor) else cloned[0]
                    )
                    data_ptr = inp_tensor.data_ptr()
                    if variant in operators and sample.kwargs:
                        continue

                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                    if isinstance(variant_forward, torch.Tensor):
                        self.assertEqual(
                            data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
                        )
                    else:
                        self.fail("Non-tensor outputs for inplace ops are not supported")

        if inplace_ops:
            inplace_samples = list(
                filter(lambda sample: not sample.broadcasts_input, samples)
            )
            _test_inplace_preserve_storage(inplace_samples, inplace_variants)

    @ops(op_db, allowed_dtypes=(torch.bool,))
    @skipIfTorchInductor("Inductor does not support view with dtype yet")
    def test_non_standard_bool_values(self, device, dtype, op):
        """Check ops give the same result for bool tensors whose True bytes are not 1."""

        def convert_boolean_tensors(x):
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
                return x

            # Build a uint8 tensor holding a random value in [2, 255) where x is
            # True and 0 elsewhere, then reinterpret it as bool.
            true_vals = torch.randint(2, 255, x.shape, dtype=torch.uint8, device=x.device)
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
            x_int = torch.where(x, true_vals, false_vals)

            ret = x_int.view(torch.bool)
            self.assertEqual(ret, x)
            return ret

        for sample in op.sample_inputs(device, dtype):
            expect = op(sample.input, *sample.args, **sample.kwargs)

            transformed = sample.transform(convert_boolean_tensors)
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)

            self.assertEqual(expect, actual)

    @unittest.skip("NPU doesn't support yet.")
    @skipMeta
    @ops(ops_and_refs, dtypes=OpDTypes.none)
    def test_dtypes(self, device, op):
        """Check that the dtypes an OpInfo claims match the dtypes that actually work.

        Every candidate dtype is tried in forward and, where inputs require
        grad, in backward.  The test fails with a detailed message when a dtype
        works but is unclaimed, or is claimed but does not work.  For Python
        reference OpInfos only the "claimed but unsupported" case is a failure.
        """
        device_type = torch.device(device).type
        include_complex32 = (
            (torch.complex32,)
            if op.supports_dtype(torch.complex32, device_type)
            else ()
        )

        # dtypes for which gradients are requested on the sample inputs.
        allowed_backward_dtypes = floating_and_complex_types_and(
            *((torch.half, torch.bfloat16) + include_complex32)
        )

        supported_dtypes = set()
        unsupported_dtypes = set()
        supported_backward_dtypes = set()
        unsupported_backward_dtypes = set()
        # Last exception seen per dtype, reported in the failure message.
        dtype_error: Dict[torch.dtype, Exception] = dict()

        def unsupported(dtype, e):
            dtype_error[dtype] = e
            unsupported_dtypes.add(dtype)
            if dtype in allowed_backward_dtypes:
                unsupported_backward_dtypes.add(dtype)

        def _tensor_requires_grad(x):
            """Return True if ``x`` is, or (recursively) contains, a tensor requiring grad."""
            if isinstance(x, dict):
                for v in x.values():
                    if _tensor_requires_grad(v):
                        return True
            if isinstance(x, (list, tuple)):
                for a in x:
                    if _tensor_requires_grad(a):
                        return True
            if isinstance(x, torch.Tensor) and x.requires_grad:
                return True

            return False

        for dtype in all_types_and_complex_and(
            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)
        ):
            # Failing to even build samples counts as lack of support.
            requires_grad = dtype in allowed_backward_dtypes
            try:
                samples = tuple(
                    op.sample_inputs(device, dtype, requires_grad=requires_grad)
                )
            except Exception as e:
                unsupported(dtype, e)
                continue

            for sample in samples:
                # Forward.
                try:
                    result = op(sample.input, *sample.args, **sample.kwargs)
                    supported_dtypes.add(dtype)
                except Exception as e:
                    unsupported(dtype, e)
                    continue

                # Backward is only attempted when some input requires grad.
                requires_grad = _tensor_requires_grad(sample.input) \
                    or _tensor_requires_grad(sample.args) or _tensor_requires_grad(sample.kwargs)
                if not requires_grad:
                    continue

                try:
                    result = sample.output_process_fn_grad(result)
                    if isinstance(result, torch.Tensor):
                        backward_tensor = result
                    elif isinstance(result, Sequence) and isinstance(
                        result[0], torch.Tensor
                    ):
                        backward_tensor = result[0]
                    else:
                        continue

                    grad = torch.randn_like(backward_tensor)
                    backward_tensor.backward(grad)
                    supported_backward_dtypes.add(dtype)
                except Exception as e:
                    dtype_error[dtype] = e
                    unsupported_backward_dtypes.add(dtype)

        # A dtype is "partially supported" when it worked on some samples and
        # failed on others.
        supported_forward = supported_dtypes - unsupported_dtypes
        partially_supported_forward = supported_dtypes & unsupported_dtypes
        unsupported_forward = unsupported_dtypes - supported_dtypes
        supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
        partially_supported_backward = (
            supported_backward_dtypes & unsupported_backward_dtypes
        )
        unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes

        claimed_forward = set(op.supported_dtypes(device_type))
        supported_but_unclaimed_forward = supported_forward - claimed_forward
        claimed_but_unsupported_forward = claimed_forward & unsupported_forward

        claimed_backward = set(op.supported_backward_dtypes(device_type))
        supported_but_unclaimed_backward = supported_backward - claimed_backward
        claimed_but_unsupported_backward = claimed_backward & unsupported_backward

        # Partial support is only reported, it does not fail the test.
        if partially_supported_forward or partially_supported_backward:
            msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n"
            if partially_supported_forward:
                msg += (
                    "The following dtypes only worked on some samples during forward: "
                    f"{partially_supported_forward}.\n"
                )
            if partially_supported_backward:
                msg += (
                    "The following dtypes only worked on some samples during backward: "
                    f"{partially_supported_backward}.\n"
                )
            print(msg)

        if not (
            supported_but_unclaimed_forward
            or claimed_but_unsupported_forward
            or supported_but_unclaimed_backward
            or claimed_but_unsupported_backward
        ):
            return

        # Python references may work on more dtypes than they claim.
        if op in python_ref_db:
            if not (claimed_but_unsupported_forward or claimed_but_unsupported_backward):
                return

        msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n"
        if supported_but_unclaimed_forward:
            msg += (
                "The following dtypes worked in forward but are not listed by the OpInfo: "
                f"{supported_but_unclaimed_forward}.\n"
            )
        if supported_but_unclaimed_backward:
            msg += (
                "The following dtypes worked in backward but are not listed by the OpInfo: "
                f"{supported_but_unclaimed_backward}.\n"
            )
        if claimed_but_unsupported_forward:
            msg += (
                "The following dtypes did not work in forward but are listed by the OpInfo: "
                f"{claimed_but_unsupported_forward}.\n"
            )
        if claimed_but_unsupported_backward:
            msg += (
                "The following dtypes did not work in backward but are listed by the OpInfo: "
                f"{claimed_but_unsupported_backward}.\n"
            )

        all_claimed_but_unsupported = set.union(claimed_but_unsupported_backward, claimed_but_unsupported_forward)
        if all_claimed_but_unsupported:
            msg += "Unexpected failures raised the following errors:\n"
            for dtype in all_claimed_but_unsupported:
                msg += f"{dtype} - {dtype_error[dtype]}\n"

        self.fail(msg)
class TestCompositeCompliance(TestCase):
    """Runs the ``composite_compliance`` checks over the OpInfo database (float only)."""

    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_operator(self, device, dtype, op):
        """Run the forward compliance checks on every sample of ``op``."""
        samples = op.sample_inputs(device, dtype, requires_grad=False)

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual)
            composite_compliance.check_all_permutations(op, args, kwargs, self.assertEqual)

    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    def test_backward(self, device, dtype, op):
        """Run the backward-formula compliance check on every sample of ``op``."""
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            composite_compliance.check_backward_formula(
                op.get_op(), args, kwargs,
                sample.output_process_fn_grad,
                op.gradcheck_wrapper, self.assertEqual)

    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_forward_ad(self, device, dtype, op):
        """Run the forward-AD compliance check on every sample of ``op``.

        Skipped when the op does not list ``torch.float`` among its backward
        dtypes or does not support forward AD.
        """
        # Pass the device *type* (e.g. "npu"), as every other call to
        # supported_backward_dtypes in this file does, rather than the full
        # device string (e.g. "npu:0").
        if torch.float not in op.supported_backward_dtypes(torch.device(device).type):
            raise unittest.SkipTest("Does not support autograd")

        if not op.supports_forward_ad:
            raise unittest.SkipTest("Does not support forward_ad")

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            composite_compliance.check_forward_ad_formula(
                op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual)
class TestMathBits(TestCase):
    """Checks that ops behave the same on conj / neg *views* as on plain tensors."""

    def _test_math_view(
        self,
        device,
        dtype,
        op,
        samples,
        math_op_physical,
        math_op_view,
        is_bit_set,
        out_type,
    ):
        """Shared driver for the conj / neg / neg-conj view tests.

        Args:
            device, dtype: forwarded by the ``ops`` decorator (not read here).
            op: the OpInfo under test.
            samples: iterable of sample inputs.
            math_op_physical: materializing form of the math op, applied under
                ``torch.no_grad``.
            math_op_view: view form of the math op, applied on top of the
                physical result.
            is_bit_set: predicate checked on the result of the inplace variant.
            out_type: predicate on the forward result; when true, backward is
                repeated with an explicit random gradient.
        """
        inplace_variant = op.inplace_variant

        def first_tensor(value):
            # A sample input is either a tensor or a sequence led by a tensor.
            return value if isinstance(value, torch.Tensor) else value[0]

        def clone_and_perform_view(input_, **kwargs):
            """Return a copy of ``input_`` whose (first) tensor is a math view of the same values."""
            if isinstance(input_, torch.Tensor):
                wants_grad = kwargs.get("requires_grad", input_.requires_grad)
                with torch.no_grad():
                    # Physical op first, so that the view applied afterwards
                    # is built from the transformed copy.
                    physical = math_op_physical(input_)
                viewed = math_op_view(physical)
                if not viewed.is_leaf:
                    raise RuntimeError("input is not leaf node")
                return viewed.requires_grad_(wants_grad)

            if isinstance(input_, Sequence):
                cloned = [clone_input_helper(item) for item in input_]
                cloned[0] = clone_and_perform_view(cloned[0])
                return tuple(cloned)

        for sample in samples:
            tensor = first_tensor(sample.input)
            cloned1 = clone_and_perform_view(sample.input)

            # Forward on the original input and on the math view must agree.
            expected_forward = op(sample.input, *sample.args, **sample.kwargs)
            forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs)
            self.assertEqual(expected_forward, forward_with_mathview)

            # Inplace variant, when it exists and the sample does not broadcast.
            if inplace_variant is not None and not sample.broadcasts_input:
                cloned2 = clone_and_perform_view(tensor, requires_grad=False)
                same_dtype_tensor = (
                    isinstance(expected_forward, torch.Tensor)
                    and expected_forward.dtype is tensor.dtype
                )
                if same_dtype_tensor:
                    inplace_forward = inplace_variant(
                        cloned2, *sample.args, **sample.kwargs
                    )
                    self.assertTrue(is_bit_set(inplace_forward))
                    self.assertEqual(inplace_forward, expected_forward)

            # Backward is only compared for single-tensor outputs requiring grad.
            if not (
                isinstance(expected_forward, torch.Tensor)
                and expected_forward.requires_grad
            ):
                continue

            output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
            expected_forward = output_process_fn_grad(expected_forward)
            forward_with_mathview = output_process_fn_grad(forward_with_mathview)

            expected_forward.sum().abs().backward(retain_graph=True)
            forward_with_mathview.sum().abs().backward(retain_graph=True)
            if tensor.grad is None:
                continue

            cloned1_tensor = first_tensor(cloned1)
            self.assertEqual(tensor.grad, cloned1_tensor.grad)

            tensor.grad, cloned1_tensor.grad = None, None

            # Repeat backward with an explicit gradient when out_type allows it.
            if out_type(expected_forward):
                grad = torch.randn_like(expected_forward)
                expected_forward.backward(grad)
                forward_with_mathview.backward(
                    math_op_view(math_op_physical(grad))
                )

                self.assertEqual(tensor.grad, cloned1_tensor.grad)

    @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
    @skipIfTorchInductor("Inductor does not support complex dtype yet")
    def test_conj_view(self, device, dtype, op):
        """Ops must give the same results on conjugate views (cfloat samples)."""
        if not op.test_conjugated_samples:
            self.skipTest("Operation doesn't support conjugated inputs.")

        backward_dtypes = op.supported_backward_dtypes(torch.device(device).type)
        samples = op.sample_inputs(
            device, dtype, requires_grad=torch.cfloat in backward_dtypes
        )
        self._test_math_view(
            device,
            dtype,
            op,
            samples,
            math_op_physical=torch.conj_physical,
            math_op_view=torch.conj,
            is_bit_set=torch.is_conj,
            out_type=torch.is_complex,
        )

    @ops(ops_and_refs, allowed_dtypes=(torch.double,))
    @skipIfTorchInductor("Inductor does not support complex dtype yet")
    def test_neg_view(self, device, dtype, op):
        """Ops must give the same results on negative views (double samples)."""
        if not op.test_neg_view:
            self.skipTest("Operation not tested with tensors with negative bit.")

        samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
        self._test_math_view(
            device,
            dtype,
            op,
            samples,
            math_op_physical=torch.neg,
            math_op_view=torch._neg_view,
            is_bit_set=torch.is_neg,
            out_type=lambda x: True,
        )

    @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
    @skipIfTorchInductor("Inductor does not support complex dtype yet")
    def test_neg_conj_view(self, device, dtype, op):
        """Ops must give the same results on combined neg+conj views (cdouble, first sample only)."""
        if not op.test_neg_view:
            self.skipTest("Operation not tested with tensors with negative bit.")
        if not op.test_conjugated_samples:
            self.skipTest("Operation doesn't support conjugated inputs.")

        def math_op_physical(t):
            return -t.conj_physical()

        def math_op_view(t):
            return torch._neg_view(t).conj()

        def is_bit_set(t):
            return torch.is_neg(t) and torch.is_conj(t)

        backward_dtypes = op.supported_backward_dtypes(torch.device(device).type)
        samples = op.sample_inputs(
            device, dtype, requires_grad=dtype in backward_dtypes
        )
        # Only the first sample is exercised.
        samples = itertools.islice(samples, 1)
        self._test_math_view(
            device,
            dtype,
            op,
            samples,
            math_op_physical=math_op_physical,
            math_op_view=math_op_view,
            is_bit_set=is_bit_set,
            out_type=torch.is_complex,
        )
def check_inplace_view(func, input_, rs, input_size, input_strides):
    """Verify that an op which changed its input's metadata is tagged ``inplace_view``.

    Args:
        func: the op that was called (an overload or an overload packet), or None.
        input_: the first argument that was passed to ``func``.
        rs: the value ``func`` returned.
        input_size: size of ``input_`` recorded before the call.
        input_strides: strides of ``input_`` recorded before the call.

    Raises:
        RuntimeError: ``rs`` is the very same tensor object as ``input_``, its
            size or strides differ from the recorded ones, ``func`` is not
            ``aten.resize_.default`` and ``func`` lacks ``torch.Tag.inplace_view``.
    """
    if func is None:
        return

    # Only results that are the input object itself are of interest.
    if not (isinstance(rs, torch.Tensor) and rs is input_):
        return

    size_changed = rs.size() != input_size
    strides_changed = rs.stride() != input_strides
    if not (size_changed or strides_changed):
        return

    # An overload packet is resolved to its default overload before the checks.
    overload = func.default if isinstance(func, torch._ops.OpOverloadPacket) else func

    # resize_ is exempt from the tag requirement.
    if overload is torch.ops.aten.resize_.default:
        return

    if torch.Tag.inplace_view not in overload.tags:
        raise RuntimeError("torch.Tag.inplace_view not in func.tags")
class TestTagsMode(TorchDispatchMode):
    """Dispatch mode that runs ``check_inplace_view`` on every op whose first argument is a tensor."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Call ``func`` and validate its inplace-view tagging.

        Args:
            func: the op being dispatched.
            types: tensor subclass types involved in the call (unused).
            args: positional arguments for ``func``; may be empty.
            kwargs: keyword arguments for ``func``; ``None`` means none.

        Returns:
            Whatever ``func`` returns.
        """
        # ``**None`` is a TypeError, so normalize the default first.
        kwargs = {} if kwargs is None else kwargs

        # Ops without a leading tensor argument (including ops called with no
        # positional arguments at all) are simply forwarded.
        if not (args and isinstance(args[0], torch.Tensor)):
            return func(*args, **kwargs)

        first = args[0]
        # Metadata must be captured before the call, since the op may change it.
        old_size = first.size()
        old_stride = first.stride()
        rs = func(*args, **kwargs)
        check_inplace_view(func, first, rs, old_size, old_stride)
        return rs
# Op names (including the ".<variant_test_name>" suffix where present) for
# which TestFakeTensor._test_fake_helper and TestFakeTensor.test_pointwise_ops
# call skipTest.
# NOTE(review): "sparse.sampled.addmm" here vs "sparse.sampled_addmm" in
# fake_autocast_backward_xfails below -- confirm which spelling is intended
# (both tests also skip any name containing "sparse").
fake_skips = (
"aminmax",
"cov",
"istft",
"linalg.eigvals",
"linalg.eigvalsh",
"linalg.matrix_power",
"linalg.matrix_rank.hermitian",
"linalg.pinv.hermitian",
"linalg.solve",
"linalg.tensorsolve",
"lu_solve",
"multinomial",
"mvlgamma.mvlgamma_p_1",
"mvlgamma.mvlgamma_p_3",
"mvlgamma.mvlgamma_p_5",
"nanmean",
"quantile",
"nanquantile",
"nn.functional.ctc_loss",
"nn.functional.embedding_bag",
"nn.functional.nll_loss",
"nn.functional.max_pool1d",
"to_sparse",
"tensor_split",
"repeat_interleave",
"_segment_reduce.lengths",
"sparse.sampled.addmm",
"nn.functional.one_hot",
"narrow",
)
# Per-device collections of op names that TestFakeTensor.test_fake_autocast
# skips; devices without an entry get an empty default.
fake_autocast_device_skips = defaultdict(dict)
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
# Ops for which a DynamicOutputShapeException is accepted in
# _test_fake_helper; a sample of such an op that completes without an
# exception fails the final assertion there.
dynamic_output_op_tests = (
"argwhere",
"bincount",
"combinations",
"linalg.lstsq",
"masked_select",
"nonzero",
"unique_consecutive",
"unique",
"linalg.lstsq.grad_oriented",
)
# Ops that may, but need not, raise DynamicOutputShapeException in
# _test_fake_helper.
sometimes_dynamic_output_op_test = (
"__getitem__",
"index_select",
)
# Ops for which a DataDependentOutputException is accepted in
# _test_fake_helper; completing without an exception fails the final
# assertion there.
data_dependent_op_tests = (
"equal",
"corrcoef",
"nn.functional.gaussian_nll_loss",
"allclose",
)
# Ops for which _test_fake_helper does not compare fake vs. real
# output/input aliasing.
aliasing_failures = (
"histogramdd",
)
# Op names that are turned into skip(...) entries of fake_backward_xfails.
fake_backward_skips = {
"linalg.cond",
"linalg.matrix_norm",
"linalg.norm",
"linalg.svd",
"linalg.svdvals",
"pca_lowrank",
"roll",
"svd_lowrank",
"sgn",
}
# Passed to skipOps for both test_fake_crossref_backward_* tests.
fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
xfail("_segment_reduce", "lengths"),
xfail("fft.ihfftn"),
xfail("fft.ihfft2"),
skip('nn.functional.ctc_loss'),
}
# Additional entries, unioned with fake_backward_xfails, for
# test_fake_crossref_backward_amp only.
fake_autocast_backward_xfails = {
skip("nn.functional.binary_cross_entropy"),
skip("sparse.sampled_addmm"),
skip("linalg.pinv"),
skip("linalg.pinv", "hermitian"),
skip("linalg.pinv", "singular"),
skip('pinverse'),
}
class TestFakeTensor(TestCase):
    """Compares ops run on real tensors with the same ops run under ``FakeTensorMode``."""

    def _test_fake_helper(self, device, dtype, op, context):
        """Run ``op`` on real and fake inputs and compare output metadata and aliasing.

        Args:
            device, dtype: used to generate the sample inputs.
            op: the OpInfo under test.
            context: zero-argument callable returning a context manager that
                wraps both the real and the fake call (e.g. an autocast
                context or ``contextlib.nullcontext``).

        Samples on which the real op raises are ignored.  Unsupported-fake
        exceptions are tolerated; dynamic-shape and data-dependent exceptions
        are only tolerated for ops listed in the corresponding module tables.
        """
        name = op.name
        if op.variant_test_name:
            name += "." + op.variant_test_name
        if name in fake_skips or "sparse" in name or "jiterator" in name:
            self.skipTest("Skip failing test")

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample in samples:
            try:
                mode = FakeTensorMode()

                def map_to_fake(e):
                    return mode.from_tensor(e) if isinstance(e, torch.Tensor) else e

                input_ = tree_map(map_to_fake, sample.input)
                args = tree_map(map_to_fake, sample.args)
                kwargs = tree_map(map_to_fake, sample.kwargs)

                # If the real op fails on this sample there is nothing to compare.
                try:
                    with context():
                        res = op(sample.input, *sample.args, **sample.kwargs)
                except Exception:
                    continue

                with context():
                    with mode:
                        res_fake = op(input_, *args, **kwargs)

                for fake_out, real_out in zip(
                    tree_flatten(res_fake)[0], tree_flatten(res)[0]
                ):
                    if not isinstance(fake_out, torch.Tensor):
                        self.assertTrue(not isinstance(real_out, torch.Tensor))
                        continue

                    self.assertTrue(isinstance(fake_out, FakeTensor))
                    prims.utils.compare_tensor_meta(fake_out, real_out, True)

                    if name not in aliasing_failures:
                        fake_aliasing = outputs_alias_inputs((input_, args, kwargs), res_fake)
                        # Compare against the *real* inputs only: the original
                        # passed ``sample, args`` (the SampleInput object plus
                        # the fake args) instead of ``sample.args``.
                        real_aliasing = outputs_alias_inputs(
                            (sample.input, sample.args, sample.kwargs), res
                        )
                        self.assertEqual(fake_aliasing, real_aliasing)

                # Reaching this point means no dynamic-shape / data-dependent
                # exception was raised, which ops in these tables must not do.
                self.assertTrue(name not in dynamic_output_op_tests and name not in data_dependent_op_tests)

            except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
                pass
            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                pass
            except torch._subclasses.fake_tensor.DynamicOutputShapeException:
                self.assertTrue(name in dynamic_output_op_tests or name in sometimes_dynamic_output_op_test)
            except torch._subclasses.fake_tensor.DataDependentOutputException:
                self.assertTrue(name in data_dependent_op_tests)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_pointwise_ops(self, device, dtype, op):
        """Check that ops tagged pointwise return the broadcast shape of their tensor inputs."""
        name = op.name
        if op.variant_test_name:
            name += "." + op.variant_test_name
        if name in fake_skips or "sparse" in name or "jiterator" in name:
            self.skipTest("Skip failing test")

        test_self = self

        class TestPointwiseMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                out = func(*args, **kwargs)

                if torch.Tag.pointwise in func.tags:
                    # tree_flatten returns (leaves, spec); only the leaves are
                    # wanted.  The original iterated over the 2-tuple itself,
                    # so no tensor was ever seen and the check was vacuous.
                    shapes = [
                        inp.shape
                        for inp in tree_flatten((args, kwargs))[0]
                        if isinstance(inp, torch.Tensor)
                    ]

                    out_shape = torch._refs._broadcast_shapes(*shapes)

                    for out_elem in tree_flatten(out)[0]:
                        if isinstance(out_elem, torch.Tensor):
                            test_self.assertEqual(out_elem.shape, out_shape)

                return out

        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample in samples:
            mode = FakeTensorMode()

            def map_to_fake(e):
                return mode.from_tensor(e) if isinstance(e, torch.Tensor) else e

            input_ = tree_map(map_to_fake, sample.input)
            args = tree_map(map_to_fake, sample.args)
            kwargs = tree_map(map_to_fake, sample.kwargs)

            # Samples on which the plain call fails are ignored.
            try:
                op(input_, *args, **kwargs)
            except Exception:
                continue

            with TestPointwiseMode():
                with mode:
                    op(input_, *args, **kwargs)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        """Fake-vs-real comparison without any surrounding context."""
        self._test_fake_helper(device, dtype, op, contextlib.nullcontext)

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_fake_autocast(self, device, dtype, op):
        """Fake-vs-real comparison under the autocast context of the device type."""
        # ``device`` may carry an index (e.g. "npu:0"); the skip table and the
        # autocast selection are keyed by device type, so normalize first.
        device_type = torch.device(device).type
        if op.name in fake_autocast_device_skips[device_type]:
            self.skipTest("Skip failing test")
        context = torch.npu.amp.autocast if device_type == "npu" else torch.cpu.amp.autocast
        self._test_fake_helper(device, dtype, op, context)

    def _test_fake_crossref_helper(self, device, dtype, op, context):
        """Compute expected grads for each sample under ``CrossRefFakeMode``.

        Args:
            device, dtype: used to generate the sample inputs.
            op: the OpInfo under test.
            context: zero-argument callable returning a context manager that
                wraps the gradient computation.
        """
        # Ops ignored by the cross-reference mode.
        common_skip_ops = (
            aten.detach.default,
            aten.empty_strided.default,
            aten.copy_.default,
            aten.is_same_size.default,
        )

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs

            try:
                with torch._subclasses.CrossRefFakeMode(ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True):
                    with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(False):
                        composite_compliance.compute_expected_grads(
                            op.get_op(), args, kwargs,
                            sample.output_process_fn_grad,
                            op.gradcheck_wrapper)
            except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                pass

    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    @skipOps('TestFakeTensor', 'test_fake_crossref_backward_no_amp', fake_backward_xfails)
    @unittest.skip("skip test fake crossref backward no amp now")
    def test_fake_crossref_backward_no_amp(self, device, dtype, op):
        """Cross-reference backward check without autocast."""
        self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)

    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
    @skipOps('TestFakeTensor', 'test_fake_crossref_backward_amp', fake_backward_xfails | fake_autocast_backward_xfails)
    @unittest.skip("skip test fake crossref backward amp now")
    def test_fake_crossref_backward_amp(self, device, dtype, op):
        """Cross-reference backward check under NPU autocast."""
        self._test_fake_crossref_helper(device, dtype, op, torch.npu.amp.autocast)
# Instantiate each generic test class for the 'privateuse1' device type only.
# get_list() below refers to the resulting classes by the names
# TestCommonPRIVATEUSE1, TestCompositeCompliancePRIVATEUSE1, etc.
# These are kept as four explicit calls: a loop would leave a module-level
# variable bound to one of the generic classes.
instantiate_device_type_tests(TestCommon, globals(), only_for='privateuse1')
instantiate_device_type_tests(TestCompositeCompliance, globals(), only_for='privateuse1')
instantiate_device_type_tests(TestMathBits, globals(), only_for='privateuse1')
instantiate_device_type_tests(TestFakeTensor, globals(), only_for='privateuse1')
"""
The parameters below are needed to run ALL test suites and to collect the corresponding failed cases.
Test logs are stored in <test_name>.log. The file is removed automatically if the process exits
with zero.
IO_path: path where the logs are stored. Users need to manually create this folder for the log files.
res_log: name of the file that stores the names of all failed tests.
"""
# Directory that holds the per-test "<test_name>.log" files read by _read_file.
IO_path = "logs"
# File to which the names of failed tests are appended.
res_log = "result.log"
# os.open() flags shared by every writer in this file.  All of them wrap the
# descriptor with mode "a", so O_APPEND is set explicitly: without it each
# writer only seeks to the end once, at open time, and concurrent writers can
# overwrite each other's lines.  (O_RDONLY is 0 and added nothing next to
# O_WRONLY, so it was dropped.)
flags = os.O_WRONLY | os.O_CREAT | os.O_APPEND
# Newly created files are readable and writable by the owner only.
modes = stat.S_IWUSR | stat.S_IRUSR
def get_list(all_test_name_log):
    """Append the name of every instantiated test to ``all_test_name_log``.

    The names are the attributes starting with ``test_`` of the four
    PRIVATEUSE1 test classes, written one per line in class order.

    Args:
        all_test_name_log: path of the file to append to; created if missing.
    """
    test_classes = (
        TestCommonPRIVATEUSE1,
        TestCompositeCompliancePRIVATEUSE1,
        TestMathBitsPRIVATEUSE1,
        TestFakeTensorPRIVATEUSE1,
    )
    with os.fdopen(os.open(all_test_name_log, flags, modes), "a") as out:
        for test_class in test_classes:
            for attr_name in dir(test_class):
                if attr_name.startswith("test_"):
                    out.write(attr_name + "\n")
def check_file_IO(log_file, interval=30):
    """Report whether ``log_file`` kept the same size over a waiting period.

    Args:
        log_file: path of the file to watch.
        interval: seconds to wait between the two size measurements.  Defaults
            to 30, the value that used to be hard-coded.

    Returns:
        bool: True if the size was the same before and after the wait.

    Raises:
        OSError: if the file cannot be stat-ed (e.g. it does not exist).
    """
    size_before = os.path.getsize(log_file)
    time.sleep(interval)
    return os.path.getsize(log_file) == size_before
def _read_file(t_name):
    """Inspect the log of test ``t_name`` and record the test as failed if needed.

    The log is ``<IO_path>/<t_name>.log``.  If it exists, this waits until
    ``check_file_IO`` reports that its size stopped changing and then looks
    for a line containing "OK".  On a match the log is deleted and nothing
    else happens.  Otherwise (no match, or no log file at all) ``t_name`` is
    appended to ``res_log`` and the log, if present, is deleted.
    """
    log_name = os.path.join(IO_path, '{}.log'.format(t_name))

    if os.path.exists(log_name):
        # Poll until the size is stable between two measurements.
        while not check_file_IO(log_name):
            pass
        with open(log_name, 'r', encoding='utf-8') as log:
            lines = log.readlines()
            passed = any("OK" in line for line in lines)
            if passed:
                os.remove(log_name)
                return

    # Reaching this point means the test did not report "OK".
    with os.fdopen(os.open(res_log, flags, modes), "a") as failed_log:
        failed_log.write(t_name + '\n')
    if os.path.exists(log_name):
        os.remove(log_name)
def start_thread(t_name):
    """Start a background thread that runs ``_read_file(t_name)``.

    The thread is started and not joined here.
    """
    threading.Thread(target=_read_file, args=(t_name,)).start()
if __name__ == "__main__":
    # Two invocation forms are supported:
    #   <script> [args...]                         -> plain run_tests
    #   <script> [args...] <test_name> <device_id> -> pin the NPU device, start
    #       the log-watcher thread for <test_name>, then run the tests with
    #       the trailing device id removed from the argument list.
    if not sys.argv[-1].isdigit():
        run_tests(sys.argv[:])
    else:
        device_id = sys.argv[-1]
        test_name = sys.argv[-2]
        torch_npu.npu.set_device(int(device_id))
        start_thread(test_name)
        run_tests(sys.argv[:-1])