import torch
from torch import nn
from torch.nn import functional as F


def fp32_layer_norm_forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """Replacement for ``nn.LayerNorm.forward`` that normalizes in float32.

    The input and the optional affine ``weight`` / ``bias`` are upcast to
    float32, layer normalization runs at that precision, and the output is
    cast back to the dtype the input arrived in.

    Args:
        self: The ``nn.LayerNorm`` module; its ``normalized_shape``,
            ``weight``, ``bias`` and ``eps`` attributes are read.
        inputs: Tensor to normalize.

    Returns:
        The normalized tensor, in ``inputs.dtype``.
    """
    # Affine parameters are optional (e.g. elementwise_affine=False leaves
    # them as None); only upcast the ones that exist.
    weight_fp32 = None if self.weight is None else self.weight.float()
    bias_fp32 = None if self.bias is None else self.bias.float()

    normalized = F.layer_norm(
        inputs.float(),
        self.normalized_shape,
        weight=weight_fp32,
        bias=bias_fp32,
        eps=self.eps,
    )
    return normalized.to(inputs.dtype)


def fp32_silu_forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """Replacement for ``nn.SiLU.forward`` that evaluates SiLU in float32.

    The input is upcast to float32, SiLU is applied, and the result is cast
    back to the input's original dtype.

    Args:
        self: The ``nn.SiLU`` module; only its ``inplace`` flag is read.
        inputs: Tensor to activate.

    Returns:
        The activated tensor in ``inputs.dtype``. When ``self.inplace`` is
        true this is ``inputs`` itself, overwritten with the result, matching
        the in-place contract of the forward being replaced.
    """
    origin_dtype = inputs.dtype
    # For float32 inputs ``.float()`` returns ``inputs`` itself, so an in-place
    # SiLU already mutates the caller's tensor. For any other dtype it returns
    # a fresh float32 copy, and running in place only modifies that temporary.
    result = F.silu(inputs.float(), inplace=self.inplace).to(origin_dtype)
    if self.inplace and result is not inputs:
        # Honor inplace=True for non-float32 inputs by writing the result
        # (already cast back to origin_dtype) into the original tensor.
        inputs.copy_(result)
        return inputs
    return result


def fp32_gelu_forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """Replacement for ``nn.GELU.forward`` that evaluates GELU in float32.

    Args:
        self: The ``nn.GELU`` module; only its ``approximate`` setting is read.
        inputs: Tensor to activate.

    Returns:
        The activated tensor, cast back to ``inputs.dtype``.
    """
    target_dtype = inputs.dtype
    upcast = inputs.to(torch.float32)
    activated = F.gelu(upcast, approximate=self.approximate)
    return activated.to(target_dtype)


def replace_with_fp32_forwards():
    """Monkey-patch ``nn.GELU``, ``nn.SiLU`` and ``nn.LayerNorm`` globally.

    After this call, every instance of those classes (existing or new) runs
    the float32-upcasting forward defined in this module. The previous
    ``forward`` methods are not saved, so the patch cannot be undone here.
    """
    patches = (
        (nn.GELU, fp32_gelu_forward),
        (nn.SiLU, fp32_silu_forward),
        (nn.LayerNorm, fp32_layer_norm_forward),
    )
    for module_cls, fp32_forward in patches:
        module_cls.forward = fp32_forward