import torch

from torch import Tensor
from torch.autograd.function import Function
from torch._dynamo.decorators import forbid_in_graph

__all__ = ["enable_deterministic_with_backward", "disable_deterministic_with_backward"]


class _DeterministicAlgorithmsBeginOp(Function):

    @staticmethod
    def forward(ctx, tensor):
        with torch.autograd.profiler.record_function("deterministic_algorithms_begin_op_forward"):
            torch.use_deterministic_algorithms(True)
        return tensor

    @staticmethod
    def backward(ctx, grad_outputs):
        with torch.autograd.profiler.record_function("deterministic_algorithms_begin_op_backward"):
            torch.use_deterministic_algorithms(False)
        return grad_outputs


class _DeterministicAlgorithmsEndOp(Function):
    @staticmethod
    def forward(ctx, tensor):
        with torch.autograd.profiler.record_function("deterministic_algorithms_end_op_forward"):
            torch.use_deterministic_algorithms(False)
        return tensor

    @staticmethod
    def backward(ctx, grad_outputs):
        with torch.autograd.profiler.record_function("deterministic_algorithms_end_op_backward"):
            torch.use_deterministic_algorithms(True)
        return grad_outputs


@forbid_in_graph
def enable_deterministic_with_backward(tensor: Tensor):
    return _DeterministicAlgorithmsBeginOp.apply(tensor)


@forbid_in_graph
def disable_deterministic_with_backward(tensor: Tensor):
    return _DeterministicAlgorithmsEndOp.apply(tensor)