import torch
def make_viewless_tensor(inp, requires_grad, keep_graph):
    """Entrypoint for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.

    Args:
        inp: Tensor to de-view. Objects that have no ``_base`` attribute
            (e.g. non-tensors) are returned unchanged.
        requires_grad: ``requires_grad`` flag forwarded to whichever helper
            builds the viewless tensor.
        keep_graph: If true, build the result through the autograd function
            ``MakeViewlessTensor.apply`` (defined elsewhere); otherwise use
            ``_kernel_make_viewless_tensor``.

    Returns:
        ``inp`` itself when it is not a view (``inp._base is None``) or has
        no ``_base`` attribute; otherwise the tensor produced by the
        selected helper.
    """
    # Only the attribute probe may fall back to "return as-is". An
    # AttributeError raised while actually building the viewless tensor
    # must propagate instead of silently handing back the original view.
    base = getattr(inp, "_base", None)
    if base is None:
        return inp
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    return _kernel_make_viewless_tensor(inp, requires_grad)
def _kernel_make_viewless_tensor(inp, requires_grad):
    """Build a tensor that shares ``inp``'s data but carries no ``_base`` link.

    A view tensor keeps a reference to the tensor it was viewed from (its
    ``._base``). That reference survives even after the view's ``.data``
    field is reassigned by hand, which is an unwanted side effect. This
    helper allocates a fresh one-element placeholder with ``inp``'s dtype
    and device, then points its ``.data`` at ``inp.data``. The result refers
    to the same data as ``inp`` without referencing the viewed tensor.

    Args:
        inp: The tensor whose data should be re-wrapped.
        requires_grad: Value assigned to the result's ``requires_grad`` flag.

    Returns:
        A new tensor object backed by ``inp.data``.
    """
    # Only the placeholder's dtype and device matter; its single element is
    # discarded immediately. It starts with autograd disabled.
    placeholder = torch.empty(
        (1,),
        dtype=inp.dtype,
        device=inp.device,
        requires_grad=False,
    )
    # Re-point the placeholder at inp's data without making it a view of inp.
    placeholder.data = inp.data
    # ``requires_grad_`` updates the flag in place and returns the same tensor.
    return placeholder.requires_grad_(requires_grad)