import torch
import torch.nn as nn
from torch import Tensor


class MatmulAddLinear(nn.Linear):
    def forward(self, input_tensor: Tensor) -> Tensor:
        output = torch.matmul(input_tensor, self.weight.t())
        if self.bias is not None:
            output = output + self.bias
        return output