import torch
import torch.nn.functional as F
from torch import nn
class OjasRule(nn.Module):
"""
OjasRule class implementation.
"""
_sequence_axis: int
def __init__(self, sequence_axis: int = 0) -> None:
"""
Initialize an instance of OjasRule.
Args:
sequence_axis (int): axis of a sequence in the input tensor.
"""
super().__init__()
self._sequence_axis = sequence_axis
@torch.no_grad()
def forward(self, x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Forward of the OjasRule.
Args:
x (torch.Tensor): Input for the Linear layer with weights W.
y (torch.Tensor): Result of the forward the Linear layer: y=xW^t.
weight (torch.Tensor): Weights of the Linear layer.
Returns:
torch.Tensor: weights delta
"""
if self._sequence_axis == 0:
x = torch.permute(x, (1, 0, 2))
y = torch.permute(y, (1, 0, 2))
yx = y.mean(self._sequence_axis).transpose(0, 1) @ x.mean(self._sequence_axis)
yw = (y**2).mean(1).mean(0).unsqueeze(1) * weight
delta = yx - yw
return F.normalize(delta, dim=0)