from typing import Sequence, List
from functools import partial
import torch
from torch.testing._internal.common_methods_invocations import SampleInput
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import (clone_input_helper,
first_sample,
is_iterable_of_tensors)
import torch_npu
from torch_npu.testing.common_methods_invocations import op_db, tocpu_db
from torch_npu.testing.decorator import Dtypes, Formats, instantiate_ops_tests
from torch_npu.testing.testcase import TestCase, run_tests
def trans_device_and_dtype(sample, origin, target, npu_format=2, to_npu=False):
    """Return a new ``SampleInput`` built from ``sample`` with its tensors converted.

    Every tensor reached through ``sample.transform`` is processed as follows:

    1. moved to the ``'npu'`` device when ``to_npu`` is true;
    2. cast to ``target`` when its dtype equals ``origin``;
    3. passed to ``torch_npu.npu_format_cast_`` with ``npu_format`` when
       ``to_npu`` is true.

    Non-tensor values are returned unchanged.

    Args:
        sample: the sample to convert; must provide ``transform``,
            ``broadcasts_input`` and ``output_process_fn_grad``.
        origin: dtype that triggers the cast.
        target: dtype to cast matching tensors to.
        npu_format: format identifier forwarded to ``npu_format_cast_``.
        to_npu: whether tensors are moved to the NPU and format-cast.

    Returns:
        A ``SampleInput`` carrying the converted input/args/kwargs together
        with the ``broadcasts_input`` flag and the ``output_process_fn_grad``
        callable of ``sample``.
    """
    def _trans_helper(arg):
        # Only tensors are converted; everything else passes through as-is.
        if not isinstance(arg, torch.Tensor):
            return arg
        if to_npu:
            arg = arg.to('npu')
        if arg.dtype == origin:
            arg = arg.to(target)
        if to_npu:
            # The return value is discarded, so this relies on the call
            # changing ``arg`` in place.
            torch_npu.npu_format_cast_(arg, npu_format)
        return arg

    # NOTE(review): assumes ``transform`` returns an indexable
    # (input, args, kwargs) triple -- confirm against the installed torch.
    transformed = sample.transform(_trans_helper)
    return SampleInput(input=transformed[0],
                       args=transformed[1],
                       kwargs=transformed[2],
                       # Forward the gradient post-processing hook as well;
                       # otherwise the rebuilt sample silently falls back to
                       # the constructor default and can disagree with the
                       # original sample during backward checks.
                       output_process_fn_grad=sample.output_process_fn_grad,
                       broadcasts_input=sample.broadcasts_input)
# Append the tocpu_db entries to op_db so that the tests instantiated from
# op_db below cover both collections.
# NOTE(review): if op_db is a list this mutates the object imported from
# torch_npu.testing in place -- confirm that is intended.
op_db += tocpu_db
@instantiate_ops_tests(op_db)
class TestOps(TestCase):
    """Operator tests comparing NPU results with CPU references and op variants.

    Each test receives ``dtype``, ``op`` and ``npu_format`` arguments
    (presumably supplied by ``instantiate_ops_tests`` together with the
    ``Formats``/``Dtypes`` decorators -- confirm in torch_npu.testing).
    """

    def test_correctness(self, dtype, op, npu_format):
        """Compare forward outputs and input gradients of ``op`` on NPU vs. CPU."""

        def _collect_grad_inputs(sample_input, args):
            """Gather the tensors that gradients are requested for."""
            grad_inputs = []
            if isinstance(sample_input, torch.Tensor):
                grad_inputs.append(sample_input)
            elif isinstance(sample_input, Sequence) and isinstance(sample_input[0], torch.Tensor):
                grad_inputs.extend(sample_input)
            if isinstance(args, torch.Tensor):
                grad_inputs.append(args)
            elif isinstance(args, Sequence):
                # Positional tensors only count when they take part in autograd.
                grad_inputs.extend(
                    arg for arg in args
                    if isinstance(arg, torch.Tensor) and (arg.grad_fn or arg.requires_grad)
                )
            return grad_inputs

        # Dtypes listed for NPU but not listed in ``op.dtypes``.
        unsupported_dtypes_cpu = {dt for dt in op.dtypesIfNPU if dt not in op.dtypes}
        allowed_backward_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
        requires_grad = (dtype in allowed_backward_dtypes and op.supports_autograd)
        # Loop-invariant decisions, computed once.
        skipped = op.skipSample.get('test_correctness', {}) if op.skipSample else {}
        upcast_on_cpu = dtype in unsupported_dtypes_cpu and dtype == torch.float16

        samples = op.sample_inputs('cpu', dtype, requires_grad=requires_grad)
        for index, sample in enumerate(samples):
            if index in skipped:
                continue
            fail_msg = f'sampleinput #{index} fail'

            # The CPU reference runs in float32 when float16 is only listed
            # for NPU.
            cpu_sample = sample
            if upcast_on_cpu:
                cpu_sample = trans_device_and_dtype(sample, dtype, torch.float32)
            expected = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            npu_sample = trans_device_and_dtype(sample, dtype, dtype, npu_format, to_npu=True)
            actual = op(npu_sample.input, *npu_sample.args, **npu_sample.kwargs)
            self.assertRtolEqual(expected, actual, auto_trans_dtype=True, message=fail_msg)

            if not requires_grad:
                continue

            # Post-process both outputs with the hook of the ORIGINAL sample:
            # the rebuilt CPU/NPU samples may not carry it, which would make
            # the two backward passes differentiate different outputs.
            process_output = sample.output_process_fn_grad or (lambda x: x)
            expected = process_output(expected)
            actual = process_output(actual)

            if isinstance(expected, torch.Tensor):
                backward_cpu_outputs = expected.sum()
                backward_npu_outputs = actual.sum()
            elif isinstance(expected, Sequence) and isinstance(expected[0], torch.Tensor):
                backward_cpu_outputs = [tensor.sum() for tensor in expected]
                backward_npu_outputs = [tensor.sum() for tensor in actual]
            else:
                raise TypeError("Unsupported {} output".format(type(expected)))

            grads_cpu = torch.autograd.grad(
                outputs=backward_cpu_outputs,
                inputs=_collect_grad_inputs(cpu_sample.input, cpu_sample.args))
            grads_npu = torch.autograd.grad(
                outputs=backward_npu_outputs,
                inputs=_collect_grad_inputs(npu_sample.input, npu_sample.args))
            self.assertRtolEqual(grads_cpu, grads_npu, auto_trans_dtype=True, message=fail_msg)

    @Formats(2)
    @Dtypes(torch.float32)
    def test_variant_consistency_eager(self, dtype, op, npu_format):
        """Check that method/in-place/alias variants of ``op`` agree with ``op``.

        Forward results (and gradients, when autograd applies) of every
        variant are compared with those of ``op`` itself; afterwards the
        in-place variants are checked to return a tensor that shares the
        data pointer of their input.
        """
        method = op.method_variant
        inplace = op.inplace_variant
        # May contain ``None`` entries; used for membership tests only.
        inplace_ops = [inplace, ]
        variants_tmp = [method, inplace, ]
        for a_op in op.aliases:
            variants_tmp.append(a_op.op)
            variants_tmp.append(a_op.method_variant)
            variants_tmp.append(a_op.inplace_variant)
            inplace_ops.append(a_op.inplace_variant)
        # Drop the variants the op does not define.
        inplace_variants = tuple(filter(None, inplace_ops))
        variants = tuple(filter(None, variants_tmp))

        allowed_backward_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
        requires_grad = (dtype in allowed_backward_dtypes and op.supports_autograd)
        # Materialize the samples: they are walked twice below, and a
        # one-shot iterator would leave the second pass with nothing to do.
        samples = list(op.sample_inputs('cpu', dtype, requires_grad=requires_grad))
        skipped = (op.skipSample.get('test_variant_consistency_eager', {})
                   if op.skipSample else {})

        def _test_consistency_helper(samples, variants):
            """Compare every variant's forward/backward with the op itself."""
            for index, sample in enumerate(samples):
                if index in skipped:
                    continue
                fail_msg = f'sampleinput #{index} fail'
                sample = trans_device_and_dtype(sample, dtype, dtype, npu_format, to_npu=True)
                tensor = (
                    sample.input
                    if isinstance(sample.input, torch.Tensor)
                    else sample.input[0]
                )

                # Reference forward (and backward) through the op itself.
                tensor.grad = None
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                expected_grad = None
                output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)

                # In-place variants cannot be used when the result dtype
                # differs from the input dtype.
                skip_inplace = (isinstance(expected_forward, torch.Tensor)
                                and expected_forward.dtype is not tensor.dtype)

                if isinstance(expected_forward, torch.Tensor) and requires_grad:
                    output_process_fn_grad(expected_forward).sum().backward()
                    expected_grad = tensor.grad

                for variant in variants:
                    is_inplace = variant in inplace_ops
                    if is_inplace and skip_inplace:
                        continue
                    tensor.grad = None
                    if is_inplace and sample.broadcasts_input:
                        continue
                    # In-place variants get a clone so the sample itself is
                    # left untouched for the next variant.
                    cloned = clone_input_helper(sample.input) if is_inplace else sample.input

                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                    self.assertRtolEqual(expected_forward, variant_forward, message=fail_msg)

                    if not requires_grad:
                        continue
                    if expected_grad is not None and (
                        not is_inplace or op.supports_inplace_autograd
                    ):
                        output_process_fn_grad(variant_forward).sum().backward()
                        self.assertRtolEqual(expected_grad, tensor.grad, message=fail_msg)

        _test_consistency_helper(samples, variants)

        def _test_inplace_preserve_storage(indexed_samples, variants):
            """Check that in-place variants return a tensor sharing the input's data pointer.

            ``indexed_samples`` holds ``(original_index, sample)`` pairs so the
            skip list keeps referring to positions in the full sample list.
            """
            for index, sample in indexed_samples:
                if index in skipped:
                    continue
                sample = trans_device_and_dtype(sample, dtype, dtype, npu_format, to_npu=True)
                expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
                # Skip only THIS sample when the output dtype differs from the
                # input dtype; later samples are still checked.
                if isinstance(expected_forward, torch.Tensor) and expected_forward.dtype is not tensor.dtype:
                    continue
                for variant in variants:
                    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
                    inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
                    data_ptr = inp_tensor.data_ptr()
                    variant_forward = variant(cloned,
                                              *sample.args,
                                              **sample.kwargs)
                    if isinstance(variant_forward, torch.Tensor):
                        self.assertRtolEqual(data_ptr, variant_forward.data_ptr())
                    else:
                        self.fail("Non-tensor outputs for inplace ops are not supported")

        # ``inplace_ops`` always holds at least one (possibly None) entry, so
        # test the filtered tuple to see whether any in-place variant exists.
        if inplace_variants:
            inplace_samples = [(index, sample) for index, sample in enumerate(samples)
                               if not sample.broadcasts_input]
            _test_inplace_preserve_storage(inplace_samples, inplace_variants)

    @Formats(2)
    @Dtypes(torch.float32)
    def test_out(self, op, dtype, npu_format):
        """Check the ``out=`` keyword of ``op`` on the first sample input.

        The test is skipped when the op has no ``out=`` support, reports no
        supported dtypes for ``'npu'``, or returns something other than a
        tensor or an iterable of tensors.
        """
        if not op.supports_out:
            self.skipTest("Skipped! Op doesn't support out= kwarg.")

        supported_dtypes = op.supported_dtypes('npu')
        if not supported_dtypes:
            self.skipTest("Skipped! Op has not supported dtypes on this device.")
        # Prefer float32; otherwise fall back to whichever supported dtype
        # comes first (the incoming ``dtype`` argument is replaced).
        dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]

        samples = op.sample_inputs('cpu', dtype)
        sample = first_sample(self, samples)
        sample = trans_device_and_dtype(sample, dtype, dtype, npu_format, to_npu=True)
        expected = op(sample.input, *sample.args, **sample.kwargs)
        # Same call as above with everything but ``out=`` already bound.
        op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

        if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
            self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")

        def _apply_out_transform(fn, out):
            """Apply ``fn`` to a tensor, or to each tensor of an iterable."""
            if isinstance(out, torch.Tensor):
                return fn(out)
            return tuple(map(fn, out))

        def _extract_strides(out):
            """Return the strides of ``out`` as a tuple with one entry per tensor."""
            if isinstance(out, torch.Tensor):
                return (out.stride(),)
            return tuple(t.stride() for t in out)

        def _extract_data_ptrs(out):
            """Return the data pointers of ``out`` as a tuple with one entry per tensor."""
            if isinstance(out, torch.Tensor):
                return (out.data_ptr(),)
            return tuple(t.data_ptr() for t in out)

        def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
            """Run ``op`` with ``out=`` and compare values, strides and data pointers."""
            out = _apply_out_transform(transform, expected)
            original_strides = _extract_strides(out)
            original_ptrs = _extract_data_ptrs(out)

            op_out(out=out)
            final_strides = _extract_strides(out)
            final_ptrs = _extract_data_ptrs(out)

            self.assertRtolEqual(expected, out)
            if compare_strides_and_data_ptrs:
                self.assertRtolEqual(original_strides, final_strides)
                self.assertRtolEqual(original_ptrs, final_ptrs)

        def _case_zero_transform(t):
            # NOTE(review): the transform returns ``t`` itself, so for a
            # single-tensor result ``out`` is the very same object as
            # ``expected`` and the value comparison in ``_compare_out`` cannot
            # fail -- confirm whether a fresh tensor was intended here.
            return t

        _compare_out(_case_zero_transform)
# Run the tests through torch_npu's run_tests() when this file is executed
# directly as a script.
if __name__ == "__main__":
    run_tests()