"""Identity operation."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Optional
import torch
from ...tensor import Quantizer
from ..op import BasicOperation, OperationContext
class Identity(BasicOperation):
    """Identity operation.

    The forward pass hands back its input tensor untouched (the very same
    object, not a copy), and the backward pass hands back the incoming
    gradient untouched. No parameter gradients are returned.
    """

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        *,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        """Return ``input_`` unchanged.

        ``ctx`` and both quantizer keyword arguments are accepted
        (presumably to match the ``BasicOperation`` method signature) but
        are not used here.
        """
        # Pass-through: the result aliases the input tensor object.
        output = input_
        return output

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
        """Propagate ``grad_output`` unchanged as the input gradient.

        Returns a pair of (input gradient, empty tuple). The empty tuple
        presumably signals that this op reports no parameter gradients.
        """
        grad_input = grad_output
        param_grads = ()
        return grad_input, param_grads