import os
from functools import wraps

import torch


class FusedLayerNormAffineFunction:
    """Pure-PyTorch layer norm exposed through a static ``apply`` entry point.

    Both ``apply`` and ``forward`` compute
    ``torch.nn.functional.layer_norm`` with an affine ``weight``/``bias``.
    No custom autograd logic is involved; gradients come from the
    functional op itself.
    """

    @staticmethod
    def apply(input_, weight, bias, normalized_shape, eps):
        """Layer-normalize ``input_`` over its trailing ``normalized_shape`` dims.

        Args:
            input_: tensor to normalize.
            weight: elementwise scale (passed straight through to layer_norm).
            bias: elementwise shift (passed straight through to layer_norm).
            normalized_shape: trailing dimensions to normalize over.
            eps: value added to the variance for numerical stability.

        Returns:
            The normalized (and scaled/shifted) tensor.
        """
        layer_norm = torch.nn.functional.layer_norm
        return layer_norm(
            input=input_,
            normalized_shape=normalized_shape,
            weight=weight,
            bias=bias,
            eps=eps,
        )

    @staticmethod
    def forward(*args, **kwargs):
        """Alias for :meth:`apply`; forwards every argument unchanged."""
        target = FusedLayerNormAffineFunction.apply
        return target(*args, **kwargs)


class FastLayerNormFN:
    """Pure-PyTorch layer norm whose normalized size is inferred from ``weight``.

    ``apply`` normalizes over a single trailing dimension whose length equals
    the number of elements in ``weight``.
    """

    @staticmethod
    def apply(input_, weight, bias, eps):
        """Layer-normalize ``input_`` over its last dimension.

        Args:
            input_: tensor whose last dimension has ``weight.numel()`` elements.
            weight: affine scale; its element count sets the normalized size
                (presumably a 1-D tensor of the hidden size — confirm at callers).
            bias: affine shift, same shape as ``weight``.
            eps: value added to the variance for numerical stability.

        Returns:
            The normalized (and scaled/shifted) tensor, same shape as ``input_``.
        """
        # torch.Size takes an iterable of ints; torch.Size(n) with a bare int
        # raises TypeError, so wrap the element count in a 1-element list.
        normalized_shape = torch.Size([weight.numel()])
        return torch.nn.functional.layer_norm(input_, normalized_shape, weight, bias, eps)


def fused_layer_norm_affine(input_, weight, bias, normalized_shape, eps):
    """Apply affine layer normalization to ``input_``.

    Thin functional wrapper over ``torch.nn.functional.layer_norm``.

    Args:
        input_: tensor to normalize.
        weight: elementwise scale.
        bias: elementwise shift.
        normalized_shape: trailing dimensions to normalize over.
        eps: value added to the variance for numerical stability.

    Returns:
        The normalized (and scaled/shifted) tensor.
    """
    functional = torch.nn.functional
    result = functional.layer_norm(
        input=input_,
        normalized_shape=normalized_shape,
        weight=weight,
        bias=bias,
        eps=eps,
    )
    return result