import torch
from torch.autograd.variable import Variable
def custom_backward(custom_output, custom_grad_output):
    """Run a backward pass by calling the C++ autograd engine directly.

    Instead of ``torch.autograd.backward``, this calls
    ``Variable._execution_engine.run_backward``. Presumably the intent is to
    skip the Python-side argument checks (e.g. output/grad shape agreement),
    so that a single-element placeholder ``custom_output`` can be
    backpropagated with a differently-shaped ``custom_grad_output``. The
    first error message suggests the caller's schedule has "pseudo-freed"
    the output to save memory. NOTE(review): confirm against callers.

    Args:
        custom_output: Tensor to backpropagate from. Must be a
            ``torch.Tensor`` with exactly one element.
        custom_grad_output: Gradient w.r.t. ``custom_output``, or ``None`` to
            use a tensor of ones shaped like ``custom_output``.

    Raises:
        RuntimeError: if ``custom_output`` is not a tensor, does not have
            exactly one element, or ``custom_grad_output`` is neither a tensor
            nor ``None``.
    """
    # Validate the type first: calling .numel() on a non-tensor would raise
    # AttributeError and hide the intended, descriptive RuntimeError.
    if not isinstance(custom_output, torch.Tensor):
        raise RuntimeError("There has an error that output == '%s' ." % type(custom_output).__name__)
    if custom_output.numel() != 1:
        raise RuntimeError("output should be pseudo-'freed' in schedule, to optimize memory")
    if not isinstance(custom_grad_output, (torch.Tensor, type(None))):
        raise RuntimeError("There has an error that grad_output == '%s' ." % type(custom_grad_output).__name__)

    # Implicit gradient: the output is already known to be single-element
    # (checked above), so a ones tensor of the same shape is the seed grad.
    if custom_grad_output is None:
        custom_grad_output = torch.ones_like(custom_output, memory_format=torch.preserve_format)

    # NOTE(review): keep_graph=True retains the autograd graph (and its saved
    # tensors) after this pass, which seems at odds with the memory-saving
    # goal implied above. Confirm that callers really need to re-run backward.
    Variable._execution_engine.run_backward(
        tensors=(custom_output,),
        grad_tensors=(custom_grad_output,),
        keep_graph=True,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )