import unittest
import io
import os
import sys
import copy
from collections import namedtuple
from itertools import chain
from typing import Dict, List, Optional, Tuple
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
from torch.testing import FileCheck
# Directory two levels above this file's resolved location; appended to
# sys.path (presumably so sibling test helpers are importable — TODO confirm).
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import suppress_warnings, \
skipIfCompiledWithoutNumpy, enable_profiling_mode_for_profiling_tests, \
IS_SANDCASTLE, TemporaryFileName, skipIfCrossRef, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \
_tmp_donotuse_dont_inline_everything, _trace, make_global
from torch.testing._internal.common_cuda import with_tf32_off
from torch import Tensor
# Device-availability flags. The names keep the "CUDA" spelling but the device
# queried is `torch.npu`. That attribute only exists once an NPU backend has
# registered itself with torch, and no such import is visible in this file, so
# guard the lookup instead of failing at import time with AttributeError.
RUN_CUDA = hasattr(torch, "npu") and torch.npu.is_available()
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.npu.device_count() > 1

# This module is meant to be collected through test/test_jit.py, not executed.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires NPU")
def test_large_nbr_kernel_args(self):
    """Trace, on the NPU, a module that concatenates 130 per-step slices."""
    class Recurrence(nn.Module):
        def __init__(self, seq_len):
            super().__init__()
            self.seq_len = seq_len

        def forward(self, input_):
            input_ = input_.transpose(0, 1)
            # One doubled slice per step; all of them feed a single cat.
            steps = [input_[i] * 2 for i in range(self.seq_len)]
            stacked = torch.cat(steps, 0).view(input_.size(0), *steps[0].size())
            return stacked.transpose(0, 1)

    batch_size, seq_len, input_size = 2, 130, 8
    module = Recurrence(seq_len)
    example = torch.rand(batch_size, seq_len, input_size)
    torch.npu.set_device(0)
    module = module.npu()
    example = example.npu()
    # Only verifies that tracing completes.
    torch.jit.trace(module, (example))
def test_trace_legacy_ctor(self):
class MyModule(nn.Module):
def forward(self, x):
return (x + 1, torch.FloatTensor([0]))
traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2))
def test_simple(self):
    """Run checkTrace on sigmoid(tanh(x * (x + y))) with two grad-requiring inputs."""
    def f(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    first = torch.tensor([0.4], requires_grad=True)
    second = torch.tensor([0.7], requires_grad=True)
    self.checkTrace(f, (first, second))
def test_trace_checking_with_global_name(self):
class MyClass(torch.nn.Module):
def forward(self, xs: List[Tensor]):
y = torch.cat(xs, dim=0)
return y
model = MyClass()
global input1, input2
input1 = torch.ones(2, 2)
input2 = torch.ones(2, 2)
m2 = torch.jit.trace(model, ((input1, input2),))
def test_trace_aliased_parameter(self):
    """Trace a module using its own parameter as the example input."""
    class M(nn.Module):
        def __init__(self, x):
            super().__init__()
            self.x = nn.Parameter(x)

        def forward(self, y):
            return self.x + y

    module = M(torch.rand(3, 4))
    traced = torch.jit.trace(module, module.x)
    fresh = torch.rand(3, 4)
    # With a different input the traced module must still add the parameter.
    self.assertEqual(traced(fresh), module.x + fresh)
def test_trace_nested_fn(self):
class TracedInlineDecision(torch.nn.Module):
def forward(self, x, flag):
@torch.jit.script
def make_decision(flag, x):
if flag:
return x
else:
return torch.zeros_like(x)
x = torch.neg(x)
return make_decision(flag, x)
decision = TracedInlineDecision()
torch.jit.trace(decision, (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)), check_trace=True)
def test_trace_single_tuple(self):
x = torch.tensor(2.)
def f2(x):
return (x,)
jit_f2 = torch.jit.trace(f2, x)
assert f2(x) == jit_f2(x)
def test_trace_out_operator_with_two_output(self):
example_input = torch.rand(2, 8)
out_1, out_2 = torch.cummax(example_input, 1)
def run_cummax(example_input, out_1, out_2):
output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2))
return output_1, output_2
trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2))
def test_trace_namedtuple(self):
    """Trace a function that takes a namedtuple of tensors."""
    Point = namedtuple('point', ['x', 'y'])

    def f(p):
        # A plain tuple is rebuilt into a Point before its fields are read.
        if type(p) is tuple:
            p = Point(*p)
        return p.x + p.y

    point = Point(torch.randn(1), torch.randn(1))
    traced = torch.jit.trace(f, (point,))
    self.assertEqual(f(point), traced(point))
def test_trace_topk(self):
    """topk with a tensor-valued k must trace without warnings and match eager."""
    class M(torch.nn.Module):
        def forward(self, x, y):
            return x.topk(y, dim=1)[1]

    mod = M()
    inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17))
    traced_func = torch.jit.trace(mod, inputs)

    # (randint upper bound, square side length, k): one case smaller and one
    # larger than the 20x20 example used for tracing.
    for high, side, k in ((9, 9, 8), (50, 50, 12)):
        test_inputs = (torch.randint(0, high, (side, side)), torch.tensor(k))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(lambda: traced_func(*test_inputs), "Shouldn't throw slicing related warn here")
        self.assertEqual(eager_out, traced_out)
def test_typeas_trace_check(self):
a = torch.tensor([0.4], requires_grad=True)
b = torch.tensor([0.7], requires_grad=True)
def f(x, y):
return x.type_as(y)
trace = torch.jit.trace(f, (a, b))
def test_trace_index(self):
    """Indexing by an int64 index tensor gives the same result when traced."""
    def fn(x, y):
        return x[y]

    data = torch.tensor([0.4], requires_grad=True)
    index = torch.tensor([0], dtype=torch.int64)
    fn_traced = torch.jit.trace(fn, (data, index,))
    self.assertEqual(fn(data, index), fn_traced(data, index))
def test_trace_index_constant(self):
    """Constant indexing yields the same value and gradient when traced."""
    leaf = torch.tensor([0.4], requires_grad=True)

    def fn(x):
        return x[0]

    def run(f):
        # Evaluate `f` on the shared leaf and return (output, d output / d leaf).
        out = f(leaf)
        return out, torch.autograd.grad(out, leaf)[0].clone()

    traced_fn = torch.jit.trace(fn, torch.ones(1))
    self.assertEqual(run(fn), run(traced_fn))
def test_index_put(self):
    """Boolean-mask assignment inside a traced function matches eager."""
    def test_fn(ten, mask):
        ten[mask] = torch.ones(6)
        return ten

    trace_target = torch.zeros(3, 3)
    # Six True entries, matching the six ones that `test_fn` writes.
    mask = torch.tensor([[True, True, True],
                         [True, False, False],
                         [True, True, False]])
    traced_test_fn = torch.jit.trace(test_fn, (trace_target, mask))

    ten = torch.rand(3, 3)
    self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
def test_canonicalize_tensor_iterator(self):
    """The optimized graph of four scalar arithmetic ops holds five int constants."""
    x = torch.randn(4, 4)

    def f(x):
        x = x + 2
        x = x - 4
        x = x * 6
        x = x / 8
        return x

    traced = torch.jit.trace(f, (x,))
    f(x)
    # The original called graph_for twice and discarded the first result; the
    # extra invocation is kept because running the function through the
    # executor may change which graph the second call reports — TODO confirm.
    traced.graph_for(x)
    constant_count = str(traced.graph_for(x)).count(': int = prim::Constant')
    # assertEqual (rather than assertTrue(a == b)) reports the actual count on
    # failure. Presumably 4 right-hand operands plus 1 shared alpha for
    # add/sub — verify against the printed graph.
    self.assertEqual(constant_count, 5)
@suppress_warnings
def test_constant(self):
    """checkTrace a matmul against a tensor constant built inside the function."""
    def f(x):
        return x.matmul(torch.diag(torch.tensor([2., 2.])))

    trace_input = torch.randn(2, 2, requires_grad=True)
    verify_input = torch.ones(2, 2, requires_grad=True)
    self.checkTrace(f, (trace_input,), (verify_input,))
def test_wrapped_number(self):
def foobar():
x = -10000.0
result = x * torch.ones(1, dtype=torch.float)
return result
scripted = torch.jit.trace(foobar, (), check_trace=True)
def test_inplace_transplant(self):
    """In-place adds on a clone stay in the traced graph after DCE."""
    def fn(x):
        y = x.clone()
        y.add_(2)
        y.add_(3)
        return y

    x = torch.tensor([0.], requires_grad=True)
    g, _ = torch.jit._get_trace_graph(fn, (x,))
    self.run_pass('dce', g)

    # Expect exactly one clone, then exactly two add_ nodes, then the return.
    checker = FileCheck()
    checker.check_count("aten::clone", 1, exactly=True)
    checker.check_count("aten::add_", 2, exactly=True)
    checker.check_next("return")
    checker.run(str(g))
    self.assertExportImport(g, (x,))
def test_inplace_flags(self):
    """Traced autograd Functions carry an 'inplace' attribute matching their behaviour."""
    class InplaceFn(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.mark_dirty(x)
            return x.add_(1)

        @staticmethod
        def backward(ctx, go):
            return go

    class RegularFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x.add(1)

        @staticmethod
        def backward(ctx, go):
            return go

    def fn(x):
        y = RegularFn.apply(x)
        y = InplaceFn.apply(y)
        y = InplaceFn.apply(y)
        y = RegularFn.apply(y)
        return y

    x = torch.tensor([0.], requires_grad=True)
    trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True)
    self.run_pass('dce', trace_graph)
    nodes = list(trace_graph.nodes())

    # First every node must expose the attribute at all ...
    for node in nodes:
        self.assertTrue(node.hasAttribute('inplace'))

    # ... then the flags must follow the Regular/Inplace/Inplace/Regular order.
    expected_flags = [False, True, True, False]
    for node, expected in zip(nodes, expected_flags):
        self.assertEqual(node.i('inplace'), expected)
def test_inplace_check(self):
class MyInplaceFn(Function):
@staticmethod
def forward(self, x):
x.add_(1)
self.mark_dirty(x)
return x
@staticmethod
def backward(self, grad):
return grad
def fn(x):
return MyInplaceFn.apply(x)
x = torch.randn(5, 5)
ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
ge(x)
def test_force_outplace_check_fill(self):
    """``fill_`` on a fresh tensor traces correctly with _force_outplace=True."""
    def f(x):
        return torch.empty(x.shape).fill_(7)

    sample = torch.randn(10, 15)
    traced = torch.jit.trace(f, sample, _force_outplace=True)
    self.assertEqual(f(sample), traced(sample))
def test_force_outplace_check_zero(self):
    """``zero_`` on a fresh tensor traces correctly with _force_outplace=True."""
    def f(x):
        return torch.empty(x.shape).zero_()

    sample = torch.randn(10, 15)
    traced = torch.jit.trace(f, sample, _force_outplace=True)
    self.assertEqual(f(sample), traced(sample))
def do_trace_size(self, requires_grad):
    """Check that sizes read while tracing stay dynamic.

    Args:
        requires_grad: passed to both the tracing input and the differently
            shaped verification input.
    """
    def fn(x):
        return x.view(x.shape[1] * 2, x.size(0), 2)

    traced_with = torch.randn(5, 2, 4, requires_grad=requires_grad)
    other_shape = torch.randn(4, 8, 4, requires_grad=requires_grad)
    traced_fn = torch.jit.trace(fn, traced_with)

    # The differently shaped input is checked first, then the original one.
    for sample in (other_shape, traced_with):
        self.assertEqual(traced_fn(sample), fn(sample))
def test_trace_size(self):
    """Run the dynamic-size trace check on inputs that do not require grad."""
    self.do_trace_size(requires_grad=False)
def test_trace_size_with_grad(self):
    """Run the dynamic-size trace check on inputs that require grad."""
    self.do_trace_size(requires_grad=True)
def test_trace_numel(self):
    """``numel()`` in a traced function follows the runtime input shape."""
    def fn(x):
        return x.numel()

    traced_with = torch.randn(2, 3, 4)
    other_shape = torch.randn(4, 5, 6)
    traced_fn = torch.jit.trace(fn, traced_with)

    for sample in (other_shape, traced_with):
        self.assertEqual(traced_fn(sample), fn(sample))
def do_trace_arange(self, requires_grad):
    """Check three ``torch.arange`` call forms under tracing.

    Args:
        requires_grad: passed to both the tracing input and the differently
            shaped verification input.
    """
    def arange(x):
        return torch.arange(x.shape[0])

    def arange_scalar(x):
        return torch.arange(12)

    def arange_start_end(x):
        return torch.arange(start=x.shape[0], end=x.shape[0] + 5)

    x = torch.randn(5, 3, 2, requires_grad=requires_grad)
    y = torch.randn(8, 2, 4, requires_grad=requires_grad)

    # Each variant is traced on x, then compared on y and on x, in that order.
    for variant in (arange, arange_scalar, arange_start_end):
        traced_variant = torch.jit.trace(variant, x)
        self.assertEqual(traced_variant(y), variant(y))
        self.assertEqual(traced_variant(x), variant(x))
def test_trace_arange(self):
    """Run the arange trace checks on inputs that do not require grad."""
    self.do_trace_arange(requires_grad=False)
def test_trace_arange_with_grad(self):
    """Run the arange trace checks on inputs that require grad."""
    self.do_trace_arange(requires_grad=True)
def test_trace_full_dynamic_shape(self):
    """``torch.full(x.shape, ...)`` keeps following the input shape after tracing."""
    def full_with_shape_like(x):
        return torch.full(x.shape, 2.)

    traced_with = torch.randn(3, 4)
    ge = torch.jit.trace(full_with_shape_like, example_inputs=traced_with)
    other_shape = torch.randn(2, 7)

    for sample in (other_shape, traced_with):
        self.assertEqual(ge(sample).shape, sample.shape)
def test_trace_slice_setitem_dynamic_shape(self):
    """Column assignment traced on a 3x4 input also works on a 10x5 input."""
    def slice_setitem(x, y):
        x[:, 2] = y + 1
        return x

    small = torch.randn(3, 4)
    traced = torch.jit.trace(slice_setitem, (small, small[:, 0]))

    large = torch.randn(10, 5)
    # Each side gets its own clone because slice_setitem writes into x.
    traced_result = traced(large.clone(), large[:, 0])
    eager_result = slice_setitem(large.clone(), large[:, 0])
    self.assertEqual(traced_result, eager_result)
@suppress_warnings
def do_trace_slice(self, requires_grad):
    """Check size-dependent slicing and select under tracing.

    Args:
        requires_grad: passed to both the tracing input and the differently
            shaped verification input.
    """
    def slice(x):
        results = []
        for i in range(4):
            results.append(x[:x.size(0) - i, i:x.size(2), i:3])
        return tuple(results)

    def slice_select(x):
        results = []
        for i in range(4):
            results.append(x[:, i:, x.size(2) - 5])
        return tuple(results)

    x = torch.randn(5, 6, 7, requires_grad=requires_grad)
    y = torch.randn(7, 8, 9, requires_grad=requires_grad)

    # Each variant is traced on x, then compared on y and on x, in that order.
    for variant in (slice, slice_select):
        traced_variant = torch.jit.trace(variant, x)
        self.assertEqual(traced_variant(y), variant(y))
        self.assertEqual(traced_variant(x), variant(x))
def test_trace_slice(self):
    """Run the slice trace checks on inputs that do not require grad."""
    self.do_trace_slice(requires_grad=False)
def test_trace_slice_with_grad(self):
    """Run the slice trace checks on inputs that require grad."""
    self.do_trace_slice(requires_grad=True)
def test_trace_casts(self):
    """Each cast form must trace to a graph with exactly one ``aten::to`` node."""
    def assertContainsCast(trace):
        to_nodes = [n for n in trace.graph.nodes() if n.kind() == 'aten::to']
        self.assertEqual(len(to_nodes), 1)

    casts = [
        lambda x: x.byte(),
        lambda x: x.float(),
        lambda x: x.cpu(),
        lambda x: x.to(device='cpu'),
        lambda x: x.to(dtype=torch.int64),
        lambda x: x.to(device='cpu', dtype=torch.float),
        lambda x: x.to(x)
    ]
    for cast in casts:
        trace = torch.jit.trace(cast, torch.randn(2, 2))
        assertContainsCast(trace)
        sample = torch.randn(2, 2)
        self.assertEqual(trace(sample), cast(sample))

    # Two-tensor form: cast x to match another tensor y.
    def to_tensor(x, y):
        return x.to(y)

    to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
    assertContainsCast(to_tensor_trace)
    x, y = torch.randn(2, 2), torch.randn(1, 10)
    self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
@skipIfCompiledWithoutNumpy
@skipIfCrossRef
def test_trace_warn(self):
    """Tracing tensor-to-Python conversions emits TracerWarnings in a fixed order."""
    def fn(x):
        int(x)
        y = x * 1
        if y:
            pass
        q = [x, x * 4]
        z = q[y]
        float(z)
        z.tolist()
        z.numpy()
        for _ in torch.ones(4, 4):
            pass
        return z + 4

    with warnings.catch_warnings(record=True) as warns:
        torch.jit.trace(fn, torch.tensor([1]))

    for warn in warns:
        self.assertIs(warn.category, torch.jit.TracerWarning)

    messages = [str(w.message) for w in warns]
    # Expected fragment for the warning recorded at each position.
    expected_fragments = (
        'a Python integer',
        'a Python boolean',
        'a Python float',
        'a Python list',
        'a NumPy array',
        'Iterating over',
    )
    for position, fragment in enumerate(expected_fragments):
        self.assertIn(fragment, messages[position])
def test_trace_tuple(self):
    """Nested tuple inputs and outputs survive tracing and export/import."""
    def fn(x, y):
        return x, (x * y[1], x * y[0])

    x = torch.randn(2, 2)
    y = (torch.ones(2, 2), torch.randn(2, 2))
    traced_fn = torch.jit.trace(fn, (x, y))
    self.assertEqual(traced_fn(x, y), fn(x, y))

    # Exactly two tuple constructions (inner and outer), then the return.
    checker = FileCheck().check_count("prim::TupleConstruct", 2, exactly=True)
    checker.check_next("return").run(str(traced_fn.graph))
    self.assertExportImport(traced_fn.graph, (x, y))
def test_trace_random(self):
    """A traced ``torch.normal`` draws the same sample as eager from the same RNG state."""
    def f(mean, std):
        return torch.normal(mean, std)

    example = (torch.zeros(2, 3), torch.ones(2, 3))
    traced = torch.jit.trace(f, example, check_trace=False)

    mean = torch.zeros(5, 5)
    std = torch.ones(5, 5)
    # fork_rng restores the RNG state on exit, so the traced call that follows
    # starts from the same state the eager call started from.
    with torch.random.fork_rng(devices=[]):
        output = f(mean, std)
    traced_output = traced(mean, std)
    self.assertEqual(output, traced_output)
def test_trace_tensor_factory(self):
    """``torch.ones`` inside a traced function stays an op for several kwarg sets."""
    def run(**kwargs):
        inputs_require_grads = kwargs.pop('inputs_require_grads', True)

        def fn(x):
            return x + torch.ones(2, 3, **kwargs)

        # The example input takes the same kwargs, minus any `out` buffer.
        input_kwargs = {key: value for key, value in kwargs.items() if key != 'out'}
        input_ = torch.ones(2, 3, **input_kwargs)
        self.checkTrace(fn, (input_,), inputs_require_grads=inputs_require_grads)

        tfn = torch.jit.trace(fn, input_)
        self.assertTrue("ones" in str(tfn.graph))

    run()
    run(dtype=torch.int, inputs_require_grads=False)
    run(out=torch.tensor([]))
    if RUN_CUDA:
        run(device="npu:0")
    if RUN_CUDA_MULTI_GPU:
        run(device="npu:1")
def test_trace_indexed_assignment(self):
    """checkTrace a function that assigns into row 0 of a clone of its input."""
    def stuff(x, y):
        x = x.clone()
        x[0] = y
        return x

    example = torch.rand(3, 4)
    replacement_row = example[0] + 1
    self.checkTrace(stuff, (example, replacement_row))
@unittest.expectedFailure
def test_output_unflatten(self):
    """Check that outputs of traced functions retain the original structure and nesting"""
    def fn(x):
        return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)

    sample = torch.randn(2, 2)
    self.checkTrace(fn, (sample,))
def test_input_flatten(self):
    """Check that inputs to traced functions are flattened"""
    def fn(x, t):
        y, z = t
        return x * y * z

    scalar = torch.randn(1)
    pair = (torch.randn(1), torch.randn(1))
    self.checkTrace(fn, (scalar, pair))
def test_input_dict_empty(self):
    """checkTrace given an empty dict as its inputs must raise RuntimeError."""
    def test(d):
        pass

    empty_inputs = {}
    with self.assertRaises(RuntimeError):
        self.checkTrace(test, empty_inputs)
def test_input_dict_remembers_keys(self):
"""Check that the trace remembers which keys were in a dict input"""
class TestModule(torch.nn.Module):
def forward(self, dict_input):
return dict_input['x']
input_1 = {'x': torch.tensor(1)}
m = TestModule()
m_traced = torch.jit.trace(m, (input_1, ))
self.assertEqual(m_traced(input_1), torch.tensor(1))
input_same_key_different_value = {'x': torch.tensor(2)}
self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2))
input_different_key = {'y': torch.tensor(3)}
with self.assertRaises(RuntimeError):
m_traced(input_different_key)
input_additional_key = {'x': torch.tensor(4), 'y': torch.tensor(3)}
self.assertEqual(m_traced(input_additional_key), torch.tensor(4))
def test_input_dict_insertion_order(self):
    """Check that dictionary access doesn't care about insertion order"""
    class TestModule(torch.nn.Module):
        def forward(self, dict_input):
            return dict_input['x'], dict_input['y']

    # Dict literals insert keys in written order: 'x' first here ...
    x_then_y = {'x': torch.tensor(1), 'y': torch.tensor(2)}
    m_traced = torch.jit.trace(TestModule(), (x_then_y, ))
    self.assertEqual(m_traced(x_then_y), (torch.tensor(1), torch.tensor(2)))

    # ... and 'y' first here.
    y_then_x = {'y': torch.tensor(4), 'x': torch.tensor(3)}
    self.assertEqual(m_traced(y_then_x), (torch.tensor(3), torch.tensor(4)))
def test_input_dict_recursive(self):
class TestModule(torch.nn.Module):
def forward(self, dict_input):
return dict_input['x'][1]
input_1 = {'x': {1: torch.tensor(1)}}
m = TestModule()
m_traced = torch.jit.trace(m, (input_1, ))
input_2 = {'x': {1: torch.tensor(2)}}
self.assertEqual(m_traced(input_2), torch.tensor(2))
def test_input_dict_checkTrace_mut(self):
    """checkTrace a function that mutates a tensor stored in its dict input."""
    def test(d):
        d['x'].tanh_()
        return d['x']

    dict_input = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
    self.checkTrace(test, (dict_input,), inputs_require_grads=False)
def test_input_dict_unify(self):
    """checkTrace a dict input whose values have different dtypes."""
    def test(d):
        return d['int'], d['float']

    dict_input = {
        'int': torch.ones((2, 2), dtype=torch.int32),
        'float': torch.ones((2, 2), dtype=torch.float32),
    }
    self.checkTrace(test, (dict_input,), inputs_require_grads=False)
def test_input_tuple_of_dicts(self):
    """checkTrace a tuple input whose elements are nested dicts."""
    def test(t):
        d = t[0]
        return d['x']['y']

    nested = {'x': {'y': torch.rand(2, 3)}}
    # The same dict object fills both tuple slots; only slot 0 is read.
    self.checkTrace(test, ((nested, nested),), allow_unused=True)
def test_input_dict_of_dicts(self):
    """checkTrace a dict of dicts whose inner tensors differ in shape."""
    def test(d):
        return d['x']['y']

    dict_input = {
        'x': {'y': torch.rand(2, 3)},
        'force_unify': {'y': torch.rand(3, 2)},
    }
    self.checkTrace(test, (dict_input,), allow_unused=True)
def test_input_dict_of_lists(self):
    """checkTrace a dict input whose value is a list of tensors."""
    def test(d):
        return d['x'][0]

    dict_input = {'x': [torch.rand(3, 2)]}
    self.checkTrace(test, (dict_input,))
def test_input_list_toplevel_flatten(self):
    """checkTrace accepts a top-level list (not a tuple) of positional inputs."""
    def test(t1, t2):
        return torch.add(t1, t2)

    # Kept as a list on purpose: the top-level container type is under test.
    list_inputs = [torch.ones(2, 2), torch.rand(2, 2)]
    self.checkTrace(test, list_inputs)
def test_input_list_toplevel_flatten_direct(self):
class Test(torch.nn.Module):
def forward(self, t1, t2):
return torch.add(t1, t2)
inputs = [torch.ones(2, 2), torch.rand(2, 2)]
torch.jit.trace(Test(), inputs)
def test_input_list_of_tuples(self):
    """checkTrace an input that is a list holding a tuple of tensors."""
    def test(data):
        return data[0][0]

    list_input = [(torch.ones(2, 2),)]
    self.checkTrace(test, (list_input,))
def test_input_dict_empty_list(self):
    """A dict input holding an empty list is rejected with a 'List trace' error."""
    def test(d):
        pass

    dict_input = {1: []}
    with self.assertRaisesRegex(RuntimeError, 'List trace'):
        self.checkTrace(test, (dict_input,))
def test_input_list_mixed_type(self):
    """A list input mixing a tensor and a tuple is rejected ('consistent' error)."""
    def test(d):
        pass

    mixed = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))]
    with self.assertRaisesRegex(RuntimeError, 'consistent'):
        self.checkTrace(test, (mixed,))
def test_conv(self):
    """A function rebuilt from a traced Conv2d graph reproduces the traced outputs."""
    x = torch.ones(20, 16, 50, 40)
    conv = nn.Conv2d(16, 13, 3, bias=False)
    g, outputs, inputs = torch.jit._get_trace_graph(conv, x, return_inputs=True)
    rebuilt = self.createFunctionFromGraph(g)
    self.assertEqual(outputs, rebuilt(*inputs))
def test_max_pool(self):
    """``F.max_pool2d`` shows up as ``aten::max_pool2d`` in the executed graph."""
    def max_pool2d(x):
        return F.max_pool2d(x, 2) + 2

    x = torch.rand(20, 16, 10, 10)
    trace = torch.jit.trace(max_pool2d, (x))
    executed_graph = trace.graph_for(x)
    FileCheck().check("aten::max_pool2d(").run(executed_graph)
    self.assertEqual(max_pool2d(x), trace(x))
def test_nested_inplace(self):
    """An in-place ``F.threshold`` is recorded as ``threshold_`` in the trace."""
    def inplace_threshold(x):
        return F.threshold(x, 0, 0, inplace=True)

    x = torch.randn(2, 2)
    g, outputs, inputs = torch.jit._get_trace_graph(
        inplace_threshold, (x, ), return_inputs=True)
    rebuilt = self.createFunctionFromGraph(g)
    self.assertEqual(outputs, rebuilt(*inputs))
    FileCheck().check("threshold_").run(str(g))
    self.assertExportImport(g, (x,))
def test_repeated_input(self):
    """Passing the same tensor object twice still yields distinct graph inputs."""
    def fn(a, b):
        return a + b

    shared = torch.randn(2, 2)
    ge = self.checkTrace(fn, [shared, shared])
    distinct_inputs = set(ge.graph.inputs())
    # NOTE(review): 3 rather than 2 — presumably the graph returned by
    # checkTrace carries an extra module/self input; confirm.
    self.assertTrue(len(distinct_inputs) == 3)
def test_repeated_output(self):
    """Returning the same value twice feeds one graph value into both tuple slots."""
    def fn(a, b):
        z = a + b
        return z, z

    ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
    # The single graph output is a tuple; inspect the node that builds it.
    tuple_output = next(iter(ge.graph.outputs()))
    first, second = list(tuple_output.node().inputs())[:2]
    self.assertTrue(first == second)
def test_inplace_copy(self):
    """``copy_`` into a fresh tensor traces, rebuilds and round-trips correctly."""
    def f(x):
        out = torch.zeros(x.size())
        out.copy_(x)
        return out

    x = torch.randn(4, 4, requires_grad=True)
    g, outputs, inputs = torch.jit._get_trace_graph(f, (x, ), return_inputs=True)
    self.run_pass('dce', g)
    rebuilt = self.createFunctionFromGraph(g)
    self.assertEqual(outputs, rebuilt(*inputs))
    self.assertExportImport(g, (x,))
def test_inplace_copy_force_outplace(self):
    """With _force_outplace, a traced ``copy_`` is expressed using ``expand_as``."""
    def f(x):
        out = torch.zeros(x.size())
        out.copy_(x)
        return out

    x = torch.randn(4, 4, requires_grad=True)
    g, outputs, inputs = torch.jit._get_trace_graph(
        f, (x, ), return_inputs=True, _force_outplace=True)
    self.run_pass('dce', g)
    rebuilt = self.createFunctionFromGraph(g)
    self.assertEqual(outputs, rebuilt(*inputs))
    self.assertExportImport(g, (x,))
    FileCheck().check("expand_as").run(str(g))
def test_shared_param(self):
    """A parameter registered under two names appears once among graph inputs."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            shared = nn.Parameter(torch.randn(2, 2))
            # Assigned to `b` first, then `a`, the order in which the original
            # chained assignment binds its targets.
            self.b = shared
            self.a = shared

        def forward(self, x):
            return x * self.a + self.b

    module = MyModule()
    g, _ = torch.jit._get_trace_graph(module, (torch.randn(2, 2),))
    self.run_pass('dce', g)
    # Two inputs: x plus the single shared parameter.
    self.assertEqual(len(list(g.inputs())), 2)
    FileCheck().check("mul").check("add").run(str(g))
def test_trace_c10_ops(self):
    """Trace and export a module calling ``_caffe2.GenerateProposals``, if registered."""
    try:
        _ = torch.ops._caffe2.GenerateProposals
    except AttributeError:
        self.skipTest("Skip the test since c2 ops are not registered.")

    class MyModel(torch.nn.Module):
        def forward(self, scores, bbox_deltas, im_info, anchors):
            a, b = torch.ops._caffe2.GenerateProposals(
                (scores), (bbox_deltas), (im_info), (anchors),
                2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
            )
            return a, b

    model = MyModel()
    # A anchors on an H x W grid, for img_count images.
    A, H, W, img_count = 4, 10, 8, 3
    scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
    flat_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
                                 dtype=torch.float32)
    bbox_deltas = flat_deltas.view(img_count, 4 * A, H, W)
    im_info = torch.ones(img_count, 3, dtype=torch.float32)
    anchors = torch.ones(A, 4, dtype=torch.float32)

    inputs = (scores, bbox_deltas, im_info, anchors)
    traced_model = torch.jit.trace(model, inputs)
    self.assertEqual(traced_model(*inputs), model(*inputs))
    self.assertExportImportModule(traced_model, (scores, bbox_deltas, im_info, anchors))
def run_ge_tests(self, optimize, use_cuda):
    """Run a fixed batch of checkTrace cases under the given executor settings.

    Args:
        optimize: forwarded to ``torch.jit.optimized_execution``.
        use_cuda: when true, the random inputs are moved with ``.npu()``.
    """
    with enable_profiling_mode_for_profiling_tests(), \
            torch.jit.optimized_execution(optimize):
        def rand(*args):
            sample = torch.rand(*args).float()
            return sample.npu() if use_cuda else sample

        def foo(a):
            t = a * a
            return t * t, 4 * t

        self.checkTrace(lambda a, b: a * b + b,
                        [rand(1), rand(1)], [rand(2, 3), rand(2, 3)])
        # Trivial identity: outputs are just the inputs, swapped.
        self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)])
        self.checkTrace(foo, [rand(1)])
        # Second argument is unused by the function.
        self.checkTrace(
            lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True)
        self.checkTrace(foo, [rand(1)], drop=1)
        self.checkTrace(lambda a, b: a * b /
                        (a - 2 * b) + b, [rand(1), rand(1)])
def test_ge_unoptimized(self):
    """Run the graph-executor cases with optimization off, without ``.npu()``."""
    self.run_ge_tests(optimize=False, use_cuda=False)
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
@enable_cpu_fuser
def test_ge_optimized(self):
    """Run the graph-executor cases with optimization on, without ``.npu()``."""
    with enable_profiling_mode_for_profiling_tests():
        self.run_ge_tests(optimize=True, use_cuda=False)
@unittest.skipIf(not RUN_CUDA, "requires NPU")
def test_ge_cuda(self):
    """Run the graph-executor cases with optimization on and ``.npu()`` inputs."""
    self.run_ge_tests(optimize=True, use_cuda=True)
def test_ge(self):
    """First- and second-order gradients through a traced function match eager."""
    def foo(a, b):
        return a * b / (a - b) + b

    V = Variable
    a, b = V(torch.rand(1)), V(torch.rand(1))
    ge = torch.jit.trace(foo, (a, b))

    # Fresh leaves that require grad, shared by the traced and eager runs.
    a = V(torch.rand(1), requires_grad=True)
    b = V(torch.rand(1), requires_grad=True)

    # Traced path.
    r, = ge(a, b)
    da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
    traced_loss = (da * db + db * db)
    g2result = torch.autograd.grad(traced_loss, [da, db])

    # Eager path.
    r = foo(a, b)
    da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
    self.assertEqual(da, da2)
    self.assertEqual(db, db2)
    eager_loss = (da2 * db2 + db2 * db2)
    g2result2 = torch.autograd.grad(eager_loss, [da2, db2])
    self.assertEqual(g2result, g2result2)
def test_trace_annotation(self):
    """A function wrapped with the ``_trace`` decorator computes the same result."""
    @_trace(torch.rand(1))
    def foo(a):
        return a + a + a

    sample = torch.randn(5, 5)
    self.assertEqual(foo(sample), sample + sample + sample)
@unittest.skipIf(not RUN_CUDA, "calls .npu()")
@with_tf32_off
def test_traced_module_cuda(self):
    # Exercises attribute access, device/dtype moves and state_dict round-trips
    # on a traced module.
    class Model(nn.Module):
        def __init__(self, num_features, num_layers):
            super().__init__()
            self.num_layers = num_layers
            # Flattened [Linear, Sigmoid, Linear, Sigmoid, ...] sequence.
            layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
                      for _ in range(num_layers)]
            self.submodule = nn.Sequential(*chain(*layers))
        def forward(self, x):
            for i in range(self.num_layers):
                x = self.submodule[i](x) + x
            return x
    model = Model(5, 3)
    x = torch.randn(2, 5)
    traced_model = torch.jit.trace(model, x)
    # Only checks that building the repr does not raise.
    model.__repr__()
    # First child of the traced Sequential (the first Linear).
    linear_submodule = next(iter(traced_model.submodule._modules.values()))
    # `in_features` is a plain attribute, not a parameter; reading it on the
    # traced submodule is expected to raise.
    with self.assertRaises(AttributeError):
        linear_submodule.in_features
    # Parameters can be read and reassigned, but not deleted.
    linear_submodule.weight
    linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
    with self.assertRaises(RuntimeError):
        del linear_submodule.weight
    # Calling the traced submodule directly is expected to raise.
    with self.assertRaises(RuntimeError):
        linear_submodule(x)
    # Device moves via .npu()/.cpu() must give matching results.
    linear_submodule.npu()
    traced_model.float().npu()
    cuda_out = traced_model(x.float().npu())
    traced_model.cpu()
    cpu_out = traced_model(x.float())
    self.assertEqual(cpu_out, cuda_out)
    # Same check using .to('npu') / .to('cpu').
    traced_model.to('npu')
    cuda_out = traced_model(x.float().npu())
    traced_model.to('cpu')
    cpu_out = traced_model(x.float())
    self.assertEqual(cpu_out, cuda_out)
    traced_model.to(torch.get_default_dtype())
    # state_dict round-trip: load all-ones, then restore the saved values.
    state = {k: v.clone() for k, v in traced_model.state_dict().items()}
    new_state = {k: v.clone().fill_(1) for k, v in state.items()}
    out = traced_model(x)
    traced_model.load_state_dict(new_state)
    out_ones = traced_model(x)
    traced_model.load_state_dict(state)
    out_state = traced_model(x)
    self.assertEqual(out, out_state)
    self.assertNotEqual(out, out_ones)
@unittest.skipIf(not RUN_CUDA, "uses npu")
def test_type_same_device(self):
    """``Tensor.type(dtype)`` on an NPU tensor traces without a device in the code."""
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dtype = torch.float16

        def forward(self, x=None):
            h = x.type(self.dtype)
            return h

    eager = Model()
    npu_example = torch.ones([1], device=torch.device("npu"))
    traced = torch.jit.trace(eager, example_inputs=(npu_example,))
    FileCheck().check_not("device").run(traced.code)
def test_export_no_reorder(self):
    """An export/import copy of a traced function gives equal outputs and grads."""
    def func(a, b):
        return a * b / (a - 2 * b) + b

    recording_inputs = [
        torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=True),
        torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=True),
    ]

    ge1 = torch.jit.trace(func, recording_inputs)
    ge2 = self.getExportImportCopy(ge1)

    outputs_ge1 = ge1(*recording_inputs)
    outputs_ge2 = ge2(*recording_inputs)
    grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
    grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)

    # Compared with ==, i.e. exactly rather than within a tolerance.
    self.assertTrue(outputs_ge1 == outputs_ge2)
    self.assertTrue(grad_ge1 == grad_ge2)
def test_python_function(self):
    """A traced function calling a custom autograd Function runs on new inputs."""
    class MyFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x + 1

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    @_trace(torch.zeros(2))
    def fn(x):
        return MyFn.apply(x + 2) + 3

    # Both inputs differ in shape from the traced example; the second also
    # requires grad. Only checks that the calls do not raise.
    plain = torch.tensor([1., 2., 3.])
    with_grad = torch.randn(2, 2, requires_grad=True)
    fn(plain)
    fn(with_grad)
def test_python_function_tup(self):
    """Like test_python_function, with a Function that returns two tensors."""
    class MyFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x + 1, x - 1

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output, grad_output

    @_trace(torch.zeros(2))
    def fn(x):
        a, b = MyFn.apply(x + 2)
        return a + b + 3

    # Only checks that the calls do not raise.
    plain = torch.tensor([1., 2., 3.])
    with_grad = torch.randn(2, 2, requires_grad=True)
    fn(plain)
    fn(with_grad)
def test_trace_detach(self):
    """``detach()`` is recorded in the trace and cuts the autograd graph."""
    def foo(x, w):
        return torch.matmul(x, w).detach()

    traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
    FileCheck().check("matmul").check("detach").run(str(traced.graph))

    x = torch.rand(3, 4)
    w = torch.rand(4, 5, requires_grad=True)
    traced_result = traced(x, w)
    self.assertEqual(foo(x, w), traced_result)
    # Although w requires grad, the detached result must not.
    self.assertFalse(traced_result.requires_grad)
    self.assertIsNone(traced_result.grad_fn)
def test_trace_detach_redispatch(self):
def foo(x, w):
y = torch.matmul(x, w)
assert y.requires_grad
y = y.detach()
assert not y.requires_grad
return y
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
torch.jit.trace(foo, (x, w), check_trace=False)
def test_trace_detach_inplace(self):
    """In-place ``detach_()`` is recorded as ``detach(`` and cuts the autograd graph."""
    def foo(x, w):
        y = torch.matmul(x, w)
        y.detach_()
        return y

    traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
    FileCheck().check("matmul").check("detach(").run(str(traced.graph))

    x = torch.rand(3, 4)
    w = torch.rand(4, 5, requires_grad=True)
    traced_result = traced(x, w)
    self.assertEqual(foo(x, w), traced_result)
    # Although w requires grad, the detached result must not.
    self.assertFalse(traced_result.requires_grad)
    self.assertIsNone(traced_result.grad_fn)
def test_trace_detach_inplace_redispatch(self):
def foo(x, w):
y = torch.matmul(x, w)
assert y.requires_grad
y.detach_()
assert not y.requires_grad
return y
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
torch.jit.trace(foo, (x, w), check_trace=False)
def test_trace_slice_full_dim(self):
    """A slice covering the whole traced dim keeps its literal bounds on bigger inputs."""
    def foo(x):
        return x[0:5, 0] + 1.0

    # Traced with exactly 5 rows, so 0:5 spans the full first dimension ...
    traced = torch.jit.trace(foo, (torch.rand(5, 4),))
    # ... then checked with 6 rows, where 0:5 is a strict sub-range.
    test_x = torch.rand(6, 3)
    self.assertEqual(foo(test_x), traced(test_x))
def test_trace_dict_input(self):
    """checkTrace a module that passes a dict to, and gets a dict from, a submodule."""
    class Foo(torch.nn.Module):
        def forward(self, x):
            return {'a': x['a'] * x['b']}

    class Bar(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = Foo()

        def forward(self, a, b):
            return self.foo({'a': a, 'b': b})['a']

    example = (torch.rand(3), torch.rand(3))
    self.checkTrace(Bar(), example)
def test_trace_dict_output(self):
    """Dict outputs are rejected by default and accepted with ``strict=False``."""
    class TraceDictStrTensor(torch.nn.Module):
        def forward(self, a, b):
            return {'a': a, 'b': b}

    class TraceDictTensorTensor(torch.nn.Module):
        def forward(self, a, b):
            return {a: b, b: a}

    x = (torch.rand(3), torch.rand(3))

    # Default (strict) tracing refuses a dict output.
    with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"):
        torch.jit.trace(TraceDictStrTensor(), x)

    # Non-strict tracing accepts str-keyed and tensor-keyed dicts.
    traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False)
    self.assertEqual(traced_dict_str_mod(*x), {'a': x[0], 'b': x[1]})

    traced_dict_tensor_mod = torch.jit.trace(TraceDictTensorTensor(), x, strict=False)
    self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]})
def test_trace_with_tensor_list_output(self):
    """A list-of-tensors output warns by default and works with ``strict=False``."""
    def f():
        return [torch.zeros(1), torch.zeros(5)]

    with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
        torch.jit.trace(f, [])

    non_strict = torch.jit.trace(f, [], strict=False)
    self.assertEqual(non_strict(), f())
def test_trace_with_number_list_output(self):
def f():
return [1, 5]
with self.assertRaisesRegex(RuntimeError, r"Only tensors.+can be output from traced functions"):
traced_f = torch.jit.trace(f, [])
def test_trace_with_nested_tensor_list_output(self):
def f():
return [[torch.zeros(1)], [torch.zeros(5)]]
with self.assertRaisesRegex(RuntimeError, r"Only tensors.+can be output from traced functions"):
traced_f = torch.jit.trace(f, [])
def test_trace_with_nested_strided_tensor_output(self):
    """A traced function returning a nested tensor matches eager on new offsets."""
    @torch.jit.script
    def nt_construct(values, kv_lengths):
        kv_lengths_list: List[int] = kv_lengths.tolist()
        return torch._nested_tensor_from_tensor_list(
            list(values.split(kv_lengths_list, dim=0)), None, None, None, None
        )

    def f(x, offsets):
        # Consecutive offset differences are the per-piece lengths.
        kv_lengths = offsets[1:] - offsets[:-1]
        return nt_construct(x, kv_lengths).cos()

    x = torch.rand(5, 4)
    offsets = torch.tensor([0, 2, 5])
    eager_result = f(x, offsets)
    f_t = torch.jit.trace(f, (x, offsets))
    traced_result = f_t(x, offsets)
    self.assertEqual(eager_result, traced_result)

    # A different row count and a different number of pieces.
    x2 = torch.rand((8, 4))
    offsets2 = torch.tensor([0, 2, 4, 8])
    self.assertEqual(f(x2, offsets2), f_t(x2, offsets2))
def test_trace_variable_instantiation(self):
    """Wrapping tensors in ``Variable`` inside a traced function is harmless."""
    def random_foo(x):
        return Variable(Variable(x) + 1.0)

    random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
    sample = torch.rand(5, 6)
    self.assertEqual(random_foo(sample), random_foo_traced(sample))
def test_trace_slice_expr_complete_type(self):
    """A scripted function can slice the result of calling a traced function."""
    def random_foo(x):
        return x + 1.0

    random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

    # The scripted body refers to the traced callable by its enclosing name.
    @torch.jit.script
    def random_bar(x):
        return random_foo_traced(x)[0:1]

    sample = torch.rand(3, 4)
    expected = (sample + 1)[0:1]
    self.assertEqual(random_bar(sample), expected)
def test_trace_inline_shape(self):
    """Scripted size/device reads stay dynamic when the function is traced."""
    @torch.jit.script
    def tensor_size(x: torch.Tensor) -> torch.Tensor:
        return torch.tensor([x.size()[0]])

    self.assertEqual(tensor_size(torch.rand(15,)), torch.tensor([15]))

    # Traced with length 7, yet it must still report 15 for a length-15 input.
    traced_tensor_size = torch.jit.trace(tensor_size, torch.rand(7,))
    self.assertEqual(traced_tensor_size(torch.rand(15,)), torch.tensor([15]))

    @torch.jit.script
    def use_device(x):
        return torch.zeros_like(x, device=x.device)

    def foo(x):
        return use_device(x)

    # After inlining, the device lookup must still be an op in the graph.
    traced_foo = torch.jit.trace(foo, torch.rand(7,))
    self.run_pass('inline', traced_foo.graph)
    FileCheck().check("prim::device").run(traced_foo.graph)
def test_trace_save(self):
    """A traced function saved to a temporary file and reloaded gives the same output."""
    def fn(x):
        return x + 2

    def assert_round_trips(traced_fn):
        # All work happens while the temporary file name is still valid.
        with TemporaryFileName() as path:
            traced_fn.save(path)
            restored = torch.jit.load(path)
            sample = torch.randn(2, 2)
            before = traced_fn(sample)
            after = restored(sample)
            self.assertEqual(before, after)

    example_inputs = (torch.ones(2, 2),)
    assert_round_trips(torch.jit.trace(fn, example_inputs))
def test_trace_optioanl_dtype(self):
class Test(torch.nn.Module):
def forward(self):
return torch.arange(5)
traced = torch.jit.trace(Test(), ())
torch.allclose(traced(), Test()())
def test_trace_save_load_copy(self):
class Test(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
return self.conv(x)
traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
copy.copy(loaded)
copy.deepcopy(loaded)
def test_trace_export_fns(self):
    """Methods decorated with ``torch.jit.export`` are present after tracing.

    The check is run on the traced module and again on the copy returned by
    ``getExportImportCopy``.
    """
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = 3

        @torch.jit.export
        def __getstate__(self):
            return (3, self.training)

        @torch.jit.export
        def __setstate__(self, state):
            self.a = state[0]
            self.training = state[1]

        def forward(self, x):
            return x + self.a

    f = Foo()

    traced = torch.jit.trace(f, (torch.rand(3, 4),))

    expected_names = ['__getstate__', '__setstate__']

    def check(mod):
        # Every expected name must appear among the compiled module's methods.
        self.assertTrue(all(name in mod._c._method_names() for name in expected_names))

    check(traced)

    imported = self.getExportImportCopy(traced)
    check(imported)
def test_trace_export_fns_recursive(self):
    """Exported methods are kept on submodules and on the traced top-level module.

    First part: ``Foo`` is a submodule of ``Wrapper``; its exported methods are
    looked up on ``traced.foo`` and on the export/import copy. Second part: a
    wrapper with its own exported ``addOne`` is traced and checked directly.
    """
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = 3

        @torch.jit.export
        def __getstate__(self):
            return (3, self.training)

        @torch.jit.export
        def __setstate__(self, state):
            self.a = state[0]
            self.training = state[1]

        def forward(self, x):
            return x + self.a

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = Foo()

        def forward(self, x):
            return self.foo(x)

    f = Wrapper()

    traced = torch.jit.trace(f, (torch.rand(3, 4),))

    expected_names = ['__getstate__', '__setstate__']

    def check(mod):
        # `expected_names` is read when `check` is called, so rebinding it
        # below changes what the final check looks for.
        self.assertTrue(all(name in mod._c._method_names() for name in expected_names))

    check(traced.foo)

    imported = self.getExportImportCopy(traced)
    check(imported.foo)

    # Bar's forward calls a lambda; Bar also has an exported method `addTwo`.
    class Bar(nn.Module):
        @torch.jit.export
        def addTwo(self, x):
            return x + 2

        def forward(self, input_):
            return (lambda a: a + 1)(input_)

    class WrapperExports(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bar = Bar()

        @torch.jit.export
        def addOne(self, x):
            return x + 1

        def forward(self, x):
            return self.bar(x)

    f = WrapperExports()

    traced = torch.jit.trace(f, (torch.rand(3, 4),))
    # Only the wrapper's own exported method is checked here.
    expected_names = ['addOne']
    check(traced)
def test_trace_autograd_function(self):
    """Tracing a module whose submodule calls a custom ``autograd.Function`` must not raise."""
    class TestFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input_):
            return torch.neg(input_)

        @staticmethod
        def backward(ctx, grad_output):
            return torch.neg(grad_output)

    class TracedModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(TestFunc.apply(x))

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.tm = TracedModule()

        def forward(self, x):
            return self.tm(x)

    # Smoke test: the traced result is not inspected.
    traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),))
def test_trace_multi_output_function(self):
    """Trace a two-input, two-output ``autograd.Function`` that swaps its inputs.

    The traced graph must record the swapped output shapes on the ``^Foo``
    node, and calling the traced module must return the inputs swapped.
    """
    class Foo(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            return y, x

        @staticmethod
        def backward(ctx, du, dv):
            return dv, du

    class Bar(torch.nn.Module):
        def forward(self, x, y):
            x = x.relu()
            y = y.relu()
            z = Foo.apply(x, y)
            return z

    # Different leading dimensions so a shape mix-up would be visible.
    x = torch.rand(3, 2, dtype=torch.double)
    y = torch.rand(1, 2, dtype=torch.double)

    # A leftover debug print of the graph was removed from here.
    traced = torch.jit.trace(Bar(), (x, y))

    # Expected output types of the ^Foo node: first y's shape, then x's.
    schema = '(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), '\
        'Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) '\
        '= ^Foo'
    FileCheck().check(schema).run(traced.graph)

    # Inputs are non-negative (torch.rand), so relu leaves them unchanged.
    u, v = traced(x, y)
    self.assertEqual(u, y)
    self.assertEqual(v, x)
def test_interpolate_trace(self):
    """Trace conv + bilinear ``interpolate`` and run it on a different spatial size."""
    class test(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)

        def forward(self, x):
            y = self.conv(x)
            w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=3)
            return w

    f = test()
    # Traced with a 28x28 input ...
    g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
    # ... and compared against eager on a 14x14 input.
    x = torch.zeros(1, 1, 14, 14)
    self.assertEqual(g(x), f(x))
@_tmp_donotuse_dont_inline_everything
def test_trace_optional(self):
    """Trace callers of a scripted function that takes ``Optional[Tensor]``.

    One caller passes ``None`` and one passes a tensor; both traced results
    are checked, and the tensor caller's graph must reference the scripted
    function by name followed by a ``prim::CallFunction`` node.
    """
    @torch.jit.script
    def test(x: Optional[Tensor]):
        if x is None:
            return torch.zeros(1)
        else:
            return x

    def test_none():
        return test(None)

    def test_tensor():
        return test(torch.zeros(2))

    f_none = torch.jit.trace(test_none, ())
    self.assertEqual(f_none(), torch.zeros(1))

    f_tensor = torch.jit.trace(test_tensor, ())
    self.assertEqual(f_tensor(), torch.zeros(2))

    graph = f_tensor.graph
    FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph)
def test_trace_nested_datatypes(self):
    """Trace a function that indexes into a nested list returned by a scripted function."""
    @torch.jit.script
    def foo(x):
        return [[x + 1, x - 1], [x + 2, x - 2]]

    def bar(x):
        list_stuff = foo(x)
        return list_stuff[0][0], list_stuff[1][1]

    trace_example = torch.rand(3, 4)
    traced = torch.jit.trace(bar, trace_example)

    # Compare on a shape that differs from the tracing example.
    probe = torch.rand(5, 6)
    eager_out = bar(probe)
    traced_out = traced(probe)
    self.assertEqual(eager_out, traced_out)
@_tmp_donotuse_dont_inline_everything
def test_call_traced_fn_from_traced_module(self):
    """A traced module that calls a traced function keeps it as a function call in the graph."""
    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return torch.neg(x)

    class TracedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 5))

        def forward(self, x):
            return traced_fn(torch.mm(x, self.param))

    tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

    # Expected order in the graph text: the mm, then a reference to
    # `traced_fn`, immediately followed by a CallFunction node.
    FileCheck().check("aten::mm") \
        .check('name="traced_fn"') \
        .check_next("prim::CallFunction") \
        .run(str(tm.graph))
@_tmp_donotuse_dont_inline_everything
def test_call_traced_module_from_traced_module(self):
    """A traced module holding another traced module calls it via ``prim::CallMethod``."""
    class TracedModule1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(5, 7))

        def forward(self, x):
            return torch.mm(x, self.param)

    class TracedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 5))
            # The inner module is traced at construction time.
            self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))

        def forward(self, x):
            return self.mod(torch.mm(x, self.param)) + 1.0

    tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

    # Expected order: outer mm, a CallMethod mentioning "forward" on the
    # same line, then the add.
    FileCheck().check("aten::mm").check("prim::CallMethod").check_same("forward").check("aten::add").run(str(tm.graph))
def test_index_put_trace_with_view(self):
    """Indexed assignment with a (1, 1, 1, 4) right-hand side records a view before ``index_put_``."""
    @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
    def test_index_put(target, indices, rhs):
        target[indices] = rhs
        return target

    FileCheck().check("aten::view").check("index_put_").run(str(test_index_put.graph))
def test_index_put_trace_without_view(self):
    """Indexed assignment with a length-4 right-hand side records no view before ``index_put_``."""
    @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
    def test_index_put(target, indices, rhs):
        target[indices] = rhs
        return target

    FileCheck().check_not("aten::view").check("index_put_").run(str(test_index_put.graph))
@suppress_warnings
def test_trace_checker_dot_data(self):
    """Using ``x.data`` in a traced function must fail the trace check on new inputs."""
    with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
                                                             r'across invocations'):
        # Tracing happens inside the `with` because the decorator runs the
        # trace (and its check against `check_inputs`) at definition time.
        @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
        def foo(x):
            y = x.data
            return x + y
@suppress_warnings
def test_trace_checker_control_flow(self):
    """A loop whose trip count depends on the input size must fail the trace check."""
    def foo(x):
        for _ in range(x.size(0)):
            x = torch.neg(x)
        return x

    # Example input has 3 rows, check input has 4, so the loop runs a
    # different number of times.
    with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
        torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
@suppress_warnings
def test_trace_checker_memoization(self):
    """A function that caches a tensor on its first call must fail the trace check."""
    with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
        def foo(x):
            # The neg is only computed on the first call; later calls reuse
            # the value stored on the function object.
            if not hasattr(foo, 'cache'):
                foo.cache = torch.neg(x)
            return x + foo.cache

        traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def test_trace_checker_slice_lhs(self):
    """Row-wise slice assignment inside a loop passes ``checkTrace``."""
    def foo(x):
        for i in range(3):
            x[i, :] = torch.zeros(4)
        return x

    self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False)
def test_trace_checker_inplace_on_view(self):
    """An in-place op on a view, traced with ``_force_outplace=True``, warns about an output mismatch."""
    def foo(x):
        x.view(-1).add_(-x.view(-1))
        return x

    with self.assertWarnsRegex(torch.jit.TracerWarning,
                               'Output nr 1. of the traced function does not match the '
                               'corresponding output of the Python function'):
        torch.jit.trace(foo,
                        torch.rand(3, 4),
                        check_inputs=[torch.rand(5, 6)],
                        _force_outplace=True)
def test_lhs_index_fails(self):
    """Element assignment traced with ``_force_outplace=True`` emits a tracer warning."""
    def foo(x):
        x[0, 1] = 4
        return x

    with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
        torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
def test_lhs_index_trivial(self):
    """Assigning through ``y[...]`` passes ``checkTrace``."""
    def foo(y, x):
        y[...] = x
        return y

    # y is (3, 4) and x is (4,), so the assignment broadcasts x over rows.
    self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
def test_inplace_warn(self):
    """An in-place op on a view traced with ``_force_outplace=True`` emits a tracer warning."""
    def foo(x):
        x.view(-1).add_(-x.view(-1))
        return x

    with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
        torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
@suppress_warnings
def test_trace_checker_dropout_train(self):
    """Tracing dropout in training mode triggers two distinct tracer warnings.

    The same trace call is issued twice, once per expected warning message.
    """
    def foo(x):
        return torch.dropout(x, p=0.5, train=True)

    with self.assertWarnsRegex(torch.jit.TracerWarning,
                               'Output nr 1. of the traced function does not match the '
                               'corresponding output of the Python function'):
        torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])

    with self.assertWarnsRegex(torch.jit.TracerWarning,
                               'Trace had nondeterministic nodes'):
        torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
def test_trace_checker_dropout_notrain(self):
    """Dropout traced with ``train=False`` returns its input unchanged."""
    input_ = torch.rand(3, 4)

    @_trace(input_)
    def foo(x):
        return torch.dropout(x, p=0.5, train=False)

    self.assertEqual(foo(input_), input_)
def test_trace_contiguous(self):
def foo(x):
return x[:, :, ::2].contiguous().view(12)
x = torch.rand(2, 3, 4)
traced = torch.jit.trace(foo, (x,))
y = traced(x)
self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
def test_trace_contiguous_short_circuit(self):
    """``contiguous()`` on a freshly created tensor is still recorded in the traced graph."""
    def foo(x):
        return x.contiguous()

    x = torch.rand(2, 3, 4)
    traced = torch.jit.trace(foo, (x,))
    FileCheck().check("aten::contiguous").run(str(traced.graph))
def test_trace_inverse(self):
    """Bitwise inversion (``~``) of a uint8 tensor traces and matches eager."""
    def foo(x):
        return ~x

    trace_example = torch.zeros(3, 4, dtype=torch.uint8)
    foo_traced = torch.jit.trace(foo, trace_example)

    # The comparison input has a different shape than the tracing example.
    probe = torch.zeros(3, dtype=torch.uint8)
    traced_out = foo_traced(probe)
    eager_out = foo(probe)
    self.assertEqual(traced_out, eager_out)
def test_trace_modulelist(self):
    """Tracing a module that iterates over an ``nn.ModuleList`` must not raise."""
    class MySubmod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    class MyMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.ml = torch.nn.ModuleList([
                MySubmod(),
                MySubmod()
            ])

        def forward(self, x):
            for mod in self.ml:
                x = mod(x)
            return x

    # Smoke test: the traced result is not inspected.
    traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))
def test_trace_fork_join_and_module(self):
    """Tracing a module that forks submodule calls and waits on them must not raise."""
    class MySubmod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x), torch.neg(x)

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.ml = torch.nn.ModuleList([
                MySubmod() for i in range(2)
            ])

        def forward(self, x):
            # Launch both submodules first, then collect their results.
            futs = []
            for i in range(2):
                futs.append(torch.jit._fork(self.ml[i], x))

            results = []
            for i in range(2):
                # Only the first element (the relu output) of each result is kept.
                results.append(torch.jit._wait(futs[i])[0])

            return torch.stack(results)

    m = Mod()
    # Smoke test: the traced result is not inspected.
    traced = torch.jit.trace(m, torch.rand(3, 4))
def test_trace_invert_module_hierarchy(self):
    """Tracing works when one submodule is passed as an argument to a sibling submodule."""
    class MySubmod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x), torch.neg(x)

    class MyFunctionalMod(torch.nn.Module):
        # Receives the module to call as a forward argument.
        def forward(self, x, submod):
            return submod(x)

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sm = MySubmod()
            self.fm = MyFunctionalMod()

        def forward(self, x):
            return self.fm(x, self.sm)

    # Smoke test: the traced result is discarded.
    torch.jit.trace(Mod(), (torch.rand(3, 4),))
@skipIfCrossRef
def test_trace_records_names(self):
    """Local variable names from the traced function appear in the graph text."""
    def foo(bar, baz):
        # The names in this function are what the test looks for; do not
        # rename them.
        baz = bar + 3
        quick_brown_fox = torch.neg(baz)
        for _ in range(20):
            yeet = quick_brown_fox - 3.14
        return yeet

    traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
    graph_str = str(traced.graph)
    # unittest assertions instead of bare `assert`: they are not stripped
    # under `python -O` and report the graph text on failure.
    self.assertIn('bar', graph_str)
    self.assertIn('baz', graph_str)
    self.assertIn('quick_brown_fox', graph_str)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_tracing_hooks(self):
    """Forward hooks and forward pre-hooks are recorded when a module is traced.

    For each hook variant the traced graph is matched against an expected
    sequence of ops, and the traced module is compared with the eager module
    on both the (possibly mutated) input and the output.
    """
    class Net(nn.Module):
        def forward(self, x):
            return x + x

    def test_hook(is_post_hook, hook, fc):
        # Register `hook` on a fresh Net, trace it, and compare with eager.
        n = Net()
        if is_post_hook:
            n.register_forward_hook(hook)
        else:
            n.register_forward_pre_hook(hook)

        module = torch.jit.trace(n, (torch.tensor(1.0),))

        eager_input = torch.tensor(1.0)
        eager_out = n(eager_input)

        fc.run(module.forward.graph)
        input_ = torch.tensor(1.0)
        output = module(input_)

        # Inputs are compared as well because some hooks mutate them in place.
        self.assertEqual(input_, eager_input)
        self.assertEqual(output, eager_out)

    # Post-hook that mutates input and output in place and returns nothing.
    def hook_no_return(mod, input_, output):
        input_[0].add_(1)
        output.sub_(1)

    fc = FileCheck().check("add(").check("add_(").check("sub_(")
    test_hook(True, hook_no_return, fc)

    # Post-hook that mutates the input and returns a new output.
    def hook_return(mod, input_, output):
        input_[0].add_(1)
        return output - 3

    fc = FileCheck().check("add(").check("add_(").check("sub(")
    test_hook(True, hook_return, fc)

    b = torch.tensor(3.0)

    # Post-hook that uses a tensor captured from the enclosing scope.
    def captured_hook(mod, input_, output):
        return output - b

    fc = FileCheck().check("add(").check("sub(")
    test_hook(True, captured_hook, fc)

    # Pre-hook that mutates the input in place and returns nothing.
    def pre_hook_no_ret(mod, input_):
        input_[0].add_(3)

    fc = FileCheck().check("add_(").check("add(")
    test_hook(False, pre_hook_no_ret, fc)

    # Pre-hook that returns a replacement input.
    def pre_hook_ret(mod, input_):
        return input_[0] - 4

    fc = FileCheck().check("sub(").check("add(")
    test_hook(False, pre_hook_ret, fc)
def test_tracing_backward_hook_error(self):
    """Tracing a module that has a backward hook registered raises an error."""
    class Net(nn.Module):
        def forward(self, x):
            return x + x

    n = Net()

    # The hook body is irrelevant; only its registration matters.
    def backward_hook(module, grad_input, grad_output):
        pass

    n.register_backward_hook(backward_hook)
    with self.assertRaisesRegex(Exception, "backward hooks assigned"):
        torch.jit.trace(n, (torch.tensor(1.0),))
def test_tracing_multiple_methods(self):
    """``trace_module`` traces several methods; ``trace`` rejects a non-forward bound method."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 3)

        def forward(self, x):
            return self.conv(x)

        def weighted_kernel_sum(self, weight):
            return weight * self.conv.weight

    example_weight = torch.rand(1, 1, 3, 3)
    example_forward_input = torch.rand(1, 1, 3, 3)
    # Maps method name -> example input for that method.
    inputs = {'forward': example_forward_input, 'weighted_kernel_sum': example_weight}
    n = Net()
    module = torch.jit.trace_module(n, inputs)

    # Two extra input sets, in the same per-method dict form, for check_trace.
    check_inputs = []
    for i in range(2):
        check_weight = torch.rand(1, 1, 3, 3)
        check_forward_input = torch.rand(1, 1, 3, 3)
        check_inputs.append({'forward' : check_forward_input, 'weighted_kernel_sum' : check_weight})
    module = torch.jit.trace_module(n, inputs, check_trace=True, check_inputs=check_inputs)
    self.assertTrue(module._c._has_method("forward"))
    self.assertTrue(module._c._has_method("weighted_kernel_sum"))

    # Tracing the bound `forward` is accepted, with and without checking.
    module = torch.jit.trace(n.forward, example_forward_input)
    module = torch.jit.trace(n.forward, example_forward_input, check_trace=True, check_inputs=[example_forward_input])
    # Tracing any other bound method through `trace` must be rejected.
    with self.assertRaisesRegex(AttributeError, "trace doesn't support compiling individual module's functions"):
        module = torch.jit.trace(n.weighted_kernel_sum, inputs)
def test_tensor_with_grad_as_constant(self):
    """Closing over a tensor that requires grad makes tracing raise."""
    param = torch.randn(3).requires_grad_()
    x = torch.randn(3)

    def f(x):
        # `param` comes from the enclosing scope, not from the arguments.
        return x + param

    with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"):
        torch.jit.trace(f, x)
def test_non_tensor_tracing(self):
    """Passing a tuple of ints as example inputs makes tracing raise."""
    def f(x):
        # NOTE(review): `param` is not defined in this method; presumably
        # `f` is never executed because the input type is rejected first -
        # confirm.
        return x + param

    with self.assertRaisesRegex(RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"):
        torch.jit.trace(f, (1,))
def test_trace_skip_none_submodule(self):
    """A submodule attribute that was reset to ``None`` is absent from the traced module."""
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Register a real submodule first, then overwrite it with None.
            self.submod = torch.nn.Linear(3, 4)
            self.submod = None

        def forward(self, inputs):
            return inputs

    m = TestModule()
    tm = torch.jit.trace(m, torch.tensor(1.))
    self.assertFalse(hasattr(tm, "submod"))
def test_trace_with_conditional_property(self):
    """Tracing succeeds for a module with a property whose backing attribute is unset."""
    class Net(nn.Module):
        def __init__(self, attr=None):
            super().__init__()
            # `_attr` only exists when a value was supplied.
            if attr is not None:
                self._attr = attr
            self.attr_name = '_attr'

        @property
        def attr(self):
            return getattr(self, self.attr_name)

        def forward(self, x):
            return x

    x = torch.ones(1)
    # Net() is built without `attr`, so `_attr` is never assigned.
    torch.jit.trace(Net(), x)
def test_trace_func_argument_names_captured(self):
    """Argument names of a traced function appear, in order, in the graph text."""
    def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
        return first_arg + second_arg

    traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1)))
    FileCheck().check("first_arg").check_next("second_arg") \
        .run(str(traced_fn.graph))
def test_trace_partial_func_argument_names_captured(self):
    """An argument left at its default value is not named in the traced graph."""
    def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
        return first_arg + second_arg

    # Only `first_arg` is supplied as an example input.
    traced_fn = torch.jit.trace(fn, (torch.ones(1),))
    FileCheck().check("first_arg").check_not("second_arg") \
        .run(str(traced_fn.graph))
def test_trace_module_argument_names_captured(self):
    """``forward`` argument names appear in the graph for both ways of tracing a module."""
    class TestModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 3)

        def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor):
            return self.conv(first_arg) + second_arg

    m = TestModule()
    example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3))

    # Case 1: trace the bound forward method.
    traced_module_forward = torch.jit.trace(m.forward, example_input)
    FileCheck().check("first_arg").check_next("second_arg") \
        .run(str(traced_module_forward.graph))

    # Case 2: trace the module object itself.
    traced_module = torch.jit.trace(m, example_input)
    FileCheck().check("first_arg").check_next("second_arg") \
        .run(str(traced_module.graph))
def test_trace_checking_with_deprecated_name(self):
class MyClass(torch.nn.Module):
def __init__(self):
super(MyClass, self).__init__()
def forward(self, x, y, **deprecated_arguments):
if len(deprecated_arguments) > 0:
raise RuntimeError(f"Got unexpected arguments: {deprecated_arguments}")
return x + y
model = MyClass()
m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
m3 = torch.jit.trace(model, example_kwarg_inputs={'x': torch.ones(1), "y": torch.ones(1)}, strict=False)
def test_trace_with_tuple_tensor(self):
    """A module taking a tuple of tensors traces from positional and from keyword examples."""
    class MyClass(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return x + y[0] + y[1]

    model = MyClass()

    # First trace: positional example inputs.
    positional_examples = (torch.ones(1), (torch.ones(1), torch.ones(1)))
    traced_positional = torch.jit.trace(model, positional_examples)

    # Both traced modules are called with keyword arguments.
    call_kwargs = {"x": torch.tensor([2, 3]), "y": (torch.tensor([5, 6]), torch.tensor([7, 8]))}
    self.assertEqual(model(**call_kwargs), traced_positional(**call_kwargs))

    # Second trace: keyword example inputs.
    keyword_examples = {'x': torch.ones(1), "y": (torch.ones(1), torch.ones(1))}
    traced_keyword = torch.jit.trace(model, example_kwarg_inputs=keyword_examples)
    self.assertEqual(model(**call_kwargs), traced_keyword(**call_kwargs))
def test_trace_no_duplicated_lifted_input_output(self):
    """The traced graph of nested modules sharing a tuple result has no ``prim::TupleUnpack``."""
    class Normalize(nn.Module):
        def __init__(self):
            super().__init__()
            self.norm = nn.GroupNorm(num_groups=32, num_channels=32)

        def forward(self, x, y):
            # With y=None, x is passed through (times 2) without normalization.
            if y is None:
                y = x
            else:
                y = self.norm(y)
            y = y * 2
            return y

    class G(nn.Module):
        def __init__(self):
            super().__init__()
            self.norm = Normalize()

        def forward(self, x):
            A = self.norm(x, None)
            B = F.relu(A)
            # A is returned directly and is also the input that produced B.
            return A, B

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.g = G()
            self.norm_1 = Normalize()

        def forward(self, x):
            hs = self.g(x)
            A, B = hs
            h = self.norm_1(B, A)
            return h

    net = Net()
    net = net.eval()
    x = torch.randn(1, 32, 16, 16)
    traced = torch.jit.trace(net, x)
    FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestMixTracingScripting(JitTestCase):
def test_trace_script(self):
    """Run ``checkTrace`` over scripted functions with tuple, list and default arguments."""
    @torch.jit.script
    def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
        return x[0] + x[1]

    @torch.jit.script
    def func2(x: List[Tensor]) -> Tensor:
        return x[0] + x[1]

    a = torch.randn(5)
    b = torch.randn(5)

    # The same (a, b) tuple is passed to both the tuple and list variants.
    self.checkTrace(func1, ((a, b),))
    self.checkTrace(func2, ((a, b),))

    @torch.jit.script
    def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor:
        hw = x.shape[2:4]
        return F.interpolate(x, hw, mode=method, align_corners=align_corners)

    # Only the tensor is supplied; the other arguments keep their defaults.
    inp = torch.rand(1, 3, 6, 6)
    self.checkTrace(func3, (inp,))

    # func4 is only compiled by the decorator; it is never traced or called.
    @torch.jit.script
    def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
        if len(a) == 2:
            return x + 2
        else:
            return x
def test_trace_mixed_by_script_with_dict_output(self):
    """Trace a module that reads values from a dict returned by a scripted function."""
    @torch.jit.script
    def return_dict(input_: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"foo": input_ + 1}

    class TraceModule(torch.nn.Module):
        def forward(self, input_):
            dict_ = return_dict(input_)
            return dict_["foo"] + dict_["foo"]

    x = torch.ones(1)
    tm = torch.jit.trace(TraceModule(), x)
    # (x + 1) + (x + 1), written out term by term.
    self.assertEqual(tm(x), x + 1 + x + 1)
def test_trace_of_script(self):
    """Trace a function that calls a scripted function containing data-dependent control flow.

    The traced function is run on inputs that take each branch of the
    scripted ``if`` and must return the branch-specific result.
    """
    @torch.jit.script
    def foo(a, c):
        b = 0.0
        if bool(a == 0.0):
            b = 1.0
        return b + c

    a = torch.ones(1, dtype=torch.float)

    @_trace(torch.zeros(1, dtype=torch.float))
    def use(b):
        return foo(b - 1.0, a) + 1.0

    # assertNotIn reports the offending graph text on failure, unlike
    # assertTrue on a pre-computed boolean.
    self.assertNotIn("Dynamic", str(use.graph))

    # b = 1 -> a == 0 branch taken: 1.0 + 1 + 1.0 == 3
    self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
    # b = 0 -> branch not taken: 0.0 + 1 + 1.0 == 2
    self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
def test_trace_with_size(self):
    """A scripted function can call a traced function and convert its result to ``int``."""
    @_trace(torch.zeros(1, 1))
    def foo(x):
        return x + 1

    @torch.jit.script
    def bar(x):
        y = int(foo(x))
        if 1 == 1:
            y = 7
        return y + 1

    # The always-true branch sets y to 7, so the result is 8.
    self.assertEqual(8, bar(torch.ones(1, 1)))
def test_tracing_slicing(self):
    """Negative slicing gives the same result when traced, scripted and run eagerly."""
    @_trace(torch.zeros(10))
    def foo_trace(x):
        return x[-5:-3]

    @torch.jit.script
    def foo_script(x):
        return x[-5:-3]

    def foo(x):
        return x[-5:-3]

    # Both lengths differ from the length-10 tracing example.
    a = torch.arange(0, 8)
    b = torch.arange(0, 20)
    self.assertEqual(foo_trace(a), foo_script(a))
    self.assertEqual(foo_trace(a), foo(a))
    # Inputs of different length must give different slices.
    self.assertNotEqual(foo_trace(a), foo_trace(b))
def test_tracing_indexing(self):
    """Negative indexing gives the same result when traced, scripted and run eagerly."""
    @_trace(torch.zeros(10))
    def foo_trace(x):
        return x[-2]

    @torch.jit.script
    def foo_script(x):
        return x[-2]

    def foo(x):
        return x[-2]

    # Both lengths differ from the length-10 tracing example.
    a = torch.arange(0, 8)
    b = torch.arange(0, 20)
    self.assertEqual(foo_script(a), foo_trace(a))
    self.assertEqual(foo_trace(a), foo(a))
    # Inputs of different length must give different elements.
    self.assertNotEqual(foo_trace(a), foo_trace(b))
def test_trace_hierarchy(self):
    """Nested ``ScriptModule`` submodules keep their methods and parameters after tracing.

    Each lookup is done on the underlying ``_c`` object and on the Python
    wrapper, for the traced module and for its export/import copy.
    """
    class AnotherScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(1, 2, 3))

        @torch.jit.script_method
        def bar(self):
            return torch.zeros(4, 5)

    class SomeScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.asm = AnotherScriptMod()

        @torch.jit.script_method
        def foo(self):
            return torch.zeros(3, 4)

        @torch.jit.script_method
        def bar(self):
            return torch.zeros(4, 3)

    class TraceMe(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.ssm = SomeScriptMod()

        def forward(self, x):
            # Only `bar` is used by forward; `foo` is never called here.
            return self.ssm.bar() + x

    orig = TraceMe()
    traced = torch.jit.trace(orig, (torch.rand(4, 3),))
    # `foo` must still be present although forward never calls it.
    self.assertTrue(traced.ssm._c._has_method('foo'))
    self.assertTrue(hasattr(traced.ssm, 'foo'))

    imported = self.getExportImportCopy(traced)

    self.assertTrue(imported.ssm._c._has_method('foo'))
    self.assertTrue(hasattr(imported.ssm, 'foo'))

    # The doubly nested module keeps its method and its parameter as well.
    self.assertTrue(imported.ssm.asm._c._has_method('bar'))
    self.assertTrue(hasattr(imported.ssm.asm, 'bar'))

    self.assertTrue(hasattr(imported.ssm.asm, 'param'))
def test_trace_parameter(self):
    """A module with an unused parameter survives alternating trace/script nesting and saving.

    Nesting, from the inside out: Param (eager) -> traced inside M3 (script)
    -> M2 (eager) -> traced inside M1 (script). The result is saved to an
    in-memory buffer.
    """
    class Param(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered but not used by forward.
            self.register_parameter("bias", nn.Parameter(torch.empty(4, 4)))

        def forward(self, x):
            return x

    class M3(torch.jit.ScriptModule):
        def __init__(self, model):
            super().__init__()
            self.traced = torch.jit.trace(model, (torch.rand(3, 3)))

        @torch.jit.script_method
        def forward(self, x):
            return self.traced(x)

    class M2(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.module = M3(model)

        def forward(self, x):
            return self.module(x)

    class M1(torch.jit.ScriptModule):
        def __init__(self, model):
            super().__init__()
            self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))

        @torch.jit.script_method
        def forward(self, x):
            return self.traced(x)

    with torch.jit.optimized_execution(False):
        module = M1(Param())
        f = io.BytesIO()
        # Saving must not raise; the buffer contents are not inspected.
        torch.jit.save(module, f)
@_tmp_donotuse_dont_inline_everything
def test_call_script_fn_from_traced_module(self):
    """A traced module calling a scripted function keeps it as a function call in the graph."""
    @torch.jit.script
    def scripted_fn(x):
        return torch.neg(x)

    class TracedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 5))

        def forward(self, x):
            return scripted_fn(torch.mm(x, self.param))

    tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
    # Expected order: the mm, a reference to `scripted_fn`, then a CallFunction.
    FileCheck().check("aten::mm").check("name=\"scripted_fn\"").check("prim::CallFunction").run(str(tm.graph))
@_tmp_donotuse_dont_inline_everything
def test_call_script_module_from_traced_module(self):
    """A traced module holding a ``ScriptModule`` calls it via ``prim::CallMethod``."""
    class ScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.param_foo = torch.nn.Parameter(torch.rand(5, 7))

        @torch.jit.script_method
        def forward(self, x):
            return torch.mm(x, self.param_foo)

    class TracedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 5))
            self.mod = ScriptMod()

        def forward(self, x):
            return self.mod(torch.mm(x, self.param)) + 1.0

    tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

    # Expected order: outer mm, a CallMethod mentioning "forward" on the
    # same line, then the add.
    FileCheck().check("aten::mm").check("prim::CallMethod").check_same("forward").check("aten::add").run(str(tm.graph))
@_tmp_donotuse_dont_inline_everything
def test_call_traced_fn_from_script_fn(self):
    """A scripted function calling a traced function records a ``prim::CallFunction``."""
    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return torch.neg(x)

    @torch.jit.script
    def script_fn(x):
        return traced_fn(x) + 1

    FileCheck().check("prim::CallFunction").check("aten::add").run(str(script_fn.graph))
def test_call_traced_mod_from_script_fn(self):
    """Calling a traced module from a free scripted function is rejected."""
    with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
        class TracedModule(torch.nn.Module):
            def forward(self, x):
                return torch.mm(x, torch.zeros(4, 3))

        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

        # Scripting happens at decoration time, inside the `with` block.
        @torch.jit.script
        def script_fn(x):
            return tm(x) + 1
@_tmp_donotuse_dont_inline_everything
def test_call_tracing_fn_from_script_module(self):
    """A script method calling a traced function records a ``prim::CallFunction`` after the mm."""
    @_trace(torch.rand(3, 3))
    def traced_fn(x):
        return torch.neg(x)

    class ScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        @torch.jit.script_method
        def forward(self, x):
            return traced_fn(torch.mm(x, self.param))

    sm = ScriptMod()
    FileCheck().check("aten::mm").check("prim::CallFunction").run(str(sm.forward.graph))
@_tmp_donotuse_dont_inline_everything
def test_call_tracing_mod_from_script_module(self):
    """A script method calling a traced submodule records a ``prim::CallMethod`` after the mm."""
    class TracedMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 5))

        def forward(self, x):
            return torch.mm(x, self.param)

    class ScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))
            # The submodule is traced at construction time.
            self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))

        @torch.jit.script_method
        def forward(self, x):
            return self.tm(torch.mm(x, self.param))

    sm = ScriptMod()
    FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph))
def test_script_inline_trace_multiple_args(self):
    """A script method can call a traced two-argument submodule; run with optimized execution off."""
    class M(torch.nn.Module):
        def forward(self, input_, input2_):
            return input_ + input2_

    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))

        @torch.jit.script_method
        def forward(self, inp):
            # The same tensor is passed for both arguments.
            return self.m(inp, inp)

    with torch.jit.optimized_execution(False):
        m2 = M2()
        # Smoke test: the result is not inspected.
        m2(torch.zeros(4, 3))
def test_trace_dict_mix_script(self):
    """Trace a module that builds a dict of tensor lists and passes it to a scripted submodule."""
    class testB(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor:
            output = []
            for j in feature_map.values():
                output.append(self.linear(j[0]))

            return torch.stack(output)

    class testA(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # The inner module is scripted, the outer one is traced below.
            self.b = torch.jit.script(testB())

        def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor:
            feature_map = {}
            for i, j in input_map.items():
                # Keep only the first tensor of each list.
                feature_map[i] = [j[0]]

            return self.b(feature_map)

    input_map = {"1" : [torch.rand(2, 2), torch.rand(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]}
    model = testA()
    traced_model = torch.jit.trace(model, input_map)
    # Fresh values with the same keys and shapes as the tracing input.
    new_input_map = {"1" : [torch.rand(2, 2), torch.randn(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]}
    self.assertEqual(model(new_input_map), traced_model(new_input_map))
def test_trace_script_returning_complex_dict(self):
    """Tracing over a script function returning a dictionary should work.
    The dictionary should be able to contain other containers (like a tuple) recursively.
    """
    class ReturnsDict(torch.nn.Module):
        def forward(
            self, id_score_list: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            # Rebuild the tuple under the same key, with element 1 shifted.
            v = id_score_list["1000"]
            idx_keys = v[1] - 1500000
            weights = v[2]
            result = {
                "1000": (v[0], idx_keys, weights)
            }
            return result

    class ChecksDict(torch.nn.Module):
        def forward(self, input_: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
            v = input_["1000"]
            return v[1] + 1

    class TestModule(torch.nn.Module):
        def __init__(self, checks_dict, returns_dict):
            super().__init__()
            self.checks_dict = checks_dict
            self.returns_dict = returns_dict

        def forward(self, input_: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
            foo = self.returns_dict(input_)
            return self.checks_dict(foo)

    # Tracing input: elements 1 and 2 of the tuple are empty tensors.
    inputs1 = {
        "1000": (
            torch.tensor([0]),
            torch.tensor([], dtype=torch.int64),
            torch.tensor([])
        )
    }
    # Second input: same structure with non-empty tensors.
    inputs2 = {
        "1000": (
            torch.tensor([0]),
            torch.tensor([1500000, 1500004], dtype=torch.int64),
            torch.tensor([2.0, 3.0])
        )
    }

    # Both submodules are scripted; only the wrapper is traced.
    checks_dict = torch.jit.script(ChecksDict())
    returns_dict = torch.jit.script(ReturnsDict())
    eager_module = TestModule(checks_dict, returns_dict)
    traced_module = torch.jit.trace(eager_module, inputs1)
    self.assertEqual(traced_module(inputs1), eager_module(inputs1))
    self.assertEqual(traced_module(inputs2), eager_module(inputs2))
def test_trace_returning_dict_with_tensor_tuples(self):
    """Tracing over a module returning a dictionary whose values are tuples of tensors
    should work.

    A dictionary whose tuple values contain a non-tensor element must be
    rejected with a RuntimeError.
    """
    class ReturnsDict(torch.nn.Module):
        def forward(
            self, k: torch.Tensor, v: torch.Tensor
        ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
            x = 2 * k
            y = 3 * v
            result = {
                "imakey": (x, y)
            }
            return result

    class ReturnsBadDict(torch.nn.Module):
        def forward(
            self, k: torch.Tensor, v: torch.Tensor
        ) -> Dict[str, Tuple[torch.Tensor, float]]:
            x = 2 * k
            # The second tuple element is a Python number, not a tensor.
            result = {
                "imakey": (x, 1)
            }
            return result

    mod = ReturnsDict()
    traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False)
    out = traced_module(torch.ones(1), torch.ones(1))
    expected = {
        "imakey": (torch.tensor([2.]), torch.tensor([3.]))
    }
    self.assertEqual(out, expected)

    with self.assertRaisesRegex(RuntimeError, "cannot be understood by the tracer, only outputs matching"):
        mod = ReturnsBadDict()
        traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False)
def test_trace_linear(self):
    """``nn.Linear`` passes ``checkTrace`` and its traced graph contains ``aten::linear``."""
    linear = torch.nn.Linear(20, 20)
    sample = torch.rand([20, 20])

    self.checkTrace(linear, (sample,))

    # Trace again directly to inspect the recorded graph.
    traced = torch.jit.trace(linear, (sample,))
    FileCheck().check("aten::linear").run(traced.graph)
def test_traced_module_implements_interface(self):
    """A compiled module can be passed where a ``torch.jit.interface`` type is expected.

    NOTE(review): despite the test name, the module here is compiled with
    ``torch.jit.script``, not traced.
    """
    @torch.jit.interface
    class TestModuleInterface(nn.Module):
        def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
            pass

    make_global(TestModuleInterface)

    class TestModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 3)

        def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
            return self.conv(first_arg) + second_arg

    def fn_takes_interface(x: TestModuleInterface):
        ones = torch.ones(1, 1, 3, 3)
        return x.forward(ones, ones)

    scripted_test_module = torch.jit.script(TestModule())
    self.checkScript(fn_takes_interface, (scripted_test_module,))
def test_traced_module_contains_scripted_interface_types(self):
    """Tracing works when one scripted module with an interface-typed attribute is used twice.

    The same scripted ``MiddleModule`` is reachable directly (``sub1``) and
    through an eager wrapper (``sub2``).
    """
    class LeafModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.rand(19))

        def forward(self, input_: torch.Tensor):
            return input_ + self.weight

    class LowerModuleImpl(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.leaf = LeafModule()

        def forward(self, input_: torch.Tensor) -> torch.Tensor:
            return self.leaf(input_)

    @torch.jit.interface
    class LowerModuleInterface(torch.nn.Module):
        def forward(self, input_: torch.Tensor) -> torch.Tensor:
            pass

    class MiddleModule(torch.nn.Module):
        # Declared with the interface type; assigned a concrete impl below.
        lower: LowerModuleInterface

        def __init__(self, feature_processor_modules=None):
            super().__init__()
            self.lower = LowerModuleImpl()

        def forward(self, input_):
            return self.lower(input_)

    class WrapperModule(torch.nn.Module):
        def __init__(self, m):
            super().__init__()
            self.middle = m

        def forward(self, input_):
            return self.middle(input_)

    class TopModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            m = MiddleModule()
            m = torch.jit.script(m)
            # The same scripted instance is shared by both attributes.
            self.sub1 = m
            self.sub2 = WrapperModule(m)

        def forward(self, input_: torch.Tensor):
            return self.sub1(input_) + self.sub2(input_)

    top = TopModule()
    top_example_input = torch.ones(1)
    # Smoke test: the traced result is discarded.
    torch.jit.trace(top, top_example_input)
def test_jit_trace_callfunction_return_shapes(self):
    """Outputs of ``prim::CallFunction`` nodes in a traced graph carry complete tensor types."""
    # a torch.jit.script function gets inserted as a CallFunction node
    @torch.jit.script
    def inner_fn(x):
        return torch.cat((x, x))

    def outer_fn(x, y):
        return inner_fn(x + y).relu()

    x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)]
    fn_t = torch.jit.trace(outer_fn, (x, y))

    # Concatenating two (2, 2) tensors gives (4, 2); that shape must appear
    # in the graph text before the CallFunction.
    FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph)
    for n in fn_t.graph.nodes():
        if n.kind() == "prim::CallFunction":
            self.assertTrue(n.output().isCompleteTensor())