# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entrypoint for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor asis, if not a 'view'
    try:
        if inp._base is None:
            return inp
        else:          
            if keep_graph:
                return MakeViewlessTensor.apply(inp, requires_grad)
            else:
                return _kernel_make_viewless_tensor(inp, requires_grad)
    except AttributeError as e:
        return inp


def _kernel_make_viewless_tensor(inp, requires_grad):
    """Make a viewless tensor.

    View tensors have the undesirable side-affect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    """
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=False)
    out.data = inp.data
    out.requires_grad_(requires_grad)
    return out