from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from diffusers.models.activations import GEGLU, ApproximateGELU
from megatron.core import mpu, tensor_parallel
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: activation projection, dropout, then a Megatron
    `RowParallelLinear` output projection (optionally followed by a final dropout).

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): The hidden dimension. If not given, defaults to `int(dim * mult)`.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Select the input projection + activation. An unknown name is rejected
        # explicitly instead of surfacing later as an unbound local variable.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(
                f"Unsupported activation_fn: {activation_fn!r}. Expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))

        # project out
        args = get_args()
        config = core_transformer_config_from_args(args)
        # NOTE(review): `input_is_parallel=True` is passed for every activation,
        # but only the local `GELU` uses a `ColumnParallelLinear` with
        # `gather_output=False`; `GEGLU` / `ApproximateGELU` come from diffusers.
        # Confirm their output layout is what this layer expects when the
        # tensor-parallel size is greater than 1.
        self.net.append(
            tensor_parallel.RowParallelLinear(
                inner_dim,
                dim_out,
                config=config,
                init_method=config.init_method,
                bias=bias,
                input_is_parallel=True,
                skip_bias_add=False,
            )
        )
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run `hidden_states` through every layer in `self.net`, in order.

        A layer may return a tuple (the Megatron parallel linear layers return
        their output together with a bias entry); only the first element is
        forwarded to the next layer.
        """
        for module in self.net:
            hidden_states = module(hidden_states)
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]
        return hidden_states
class GELU(nn.Module):
    r"""
    Linear projection followed by a GELU activation.

    The projection is a Megatron `ColumnParallelLinear` created with
    `gather_output=False`. Pass `approximate="tanh"` to use the tanh
    approximation of GELU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        # The transformer config is derived from the global Megatron arguments.
        transformer_config = core_transformer_config_from_args(get_args())
        self.proj = tensor_parallel.ColumnParallelLinear(
            dim_in,
            dim_out,
            config=transformer_config,
            init_method=transformer_config.init_method,
            bias=bias,
            gather_output=False,
        )
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to `gate` using the configured approximation mode."""
        activated = F.gelu(gate, approximate=self.approximate)
        return activated

    def forward(self, hidden_states):
        """Project `hidden_states`, then apply GELU to the projection."""
        # The projection returns a pair; its second element is not used here.
        projected, _ = self.proj(hidden_states)
        return self.gelu(projected)