from typing import Optional, Callable
from functools import partial
from itertools import repeat
import torch
import torch.nn as nn
from diffusers.models.activations import GELU, GEGLU, ApproximateGELU
from megatron.core import mpu, tensor_parallel
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed_mm.models.common.linear import MatmulAddLinear
def modulate(norm_func, x, shift, scale):
    """Normalize ``x`` in float32, then apply an affine shift/scale modulation.

    The normalization runs on a float32 copy of ``x`` and is cast back to
    ``x``'s original dtype. The result is ``normed * (1 + scale) + shift``, with
    ``scale`` and ``shift`` unsqueezed at dim 1 so they broadcast across that
    axis of ``x`` (presumably the token/sequence axis — confirm with callers).

    Args:
        norm_func: callable applied to the float32 copy of ``x``.
        x: input tensor.
        shift: additive modulation term (unsqueezed at dim 1).
        scale: multiplicative modulation term (unsqueezed at dim 1, used as 1 + scale).

    Returns:
        The modulated tensor, cast to ``x``'s original dtype.
    """
    orig_dtype = x.dtype
    normed = norm_func(x.to(torch.float32)).to(orig_dtype)
    gain = 1 + scale.unsqueeze(1)
    modulated = normed * gain + shift.unsqueeze(1)
    # shift/scale may have a wider dtype than x; restore the input dtype.
    return modulated.to(orig_dtype)
def t2i_modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift``.

    No reshaping is done here: ``shift`` and ``scale`` must already be
    broadcastable against ``x``.
    """
    gain = 1 + scale
    return x * gain + shift
class MatmulAddFeedForward(nn.Module):
    r"""
    A feed-forward layer with MatmulAddLinear.

    The network is: activation module (built as ``dim -> inner_dim``) -> dropout
    -> ``MatmulAddLinear(inner_dim, dim_out)`` -> optional final dropout.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): Hidden dimension. If not given, defaults to `int(dim * mult)`.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.

    Raises:
        ValueError: if `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # A single if/elif chain: previously "gelu-approximate" started a new
        # `if`, and an unknown name left `act_fn` unbound (UnboundLocalError later).
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        self.net.append(act_fn)
        self.net.append(nn.Dropout(dropout))
        self.net.append(MatmulAddLinear(inner_dim, dim_out, bias=bias))
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply each submodule of ``self.net`` in order; extra args are ignored."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    LayerNorm (no affine params, eps=1e-6) -> shift/scale modulation derived
    from the conditioning ``c`` -> linear projection to
    ``num_patch * out_channels`` features.
    """

    def __init__(self, hidden_size, num_patch, out_channels):
        super().__init__()
        # Construction order is kept fixed: it determines RNG consumption during
        # parameter init and the state_dict key order.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        silu = nn.SiLU()
        to_shift_scale = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        self.adaLN_modulation = nn.Sequential(silu, to_shift_scale)

    def forward(self, x, c):
        """Modulate the normalized ``x`` with shift/scale computed from ``c``, then project.

        ``adaLN_modulation(c)`` has ``2 * hidden_size`` features on dim 1, which
        are split into ``shift`` then ``scale``.
        """
        shift_scale = self.adaLN_modulation(c)
        shift, scale = shift_scale.chunk(2, dim=1)
        modulated = modulate(self.norm_final, x, shift, scale)
        return self.linear(modulated)
class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.

    LayerNorm (no affine params, eps=1e-6) -> shift/scale modulation from a
    learned ``(2, hidden_size)`` table plus the timestep embedding ``t`` ->
    linear projection to ``num_patch * out_channels`` features. Optionally
    blends in a second modulation (from ``t0``) selected by ``x_mask``.
    """

    def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
        super().__init__()
        # Keep construction order: linear init and the randn table both draw from the RNG.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        init_table = torch.randn(2, hidden_size) / hidden_size**0.5
        self.scale_shift_table = nn.Parameter(init_table)
        self.out_channels = out_channels
        # Default temporal / spatial sizes; stored but not used by forward() here.
        self.d_t = d_t
        self.d_s = d_s

    def t_mask_select(self, x_mask, x, masked_x):
        """Blend via ``torch.lerp(masked_x, x, x_mask)``.

        A weight of 1 selects ``x`` and 0 selects ``masked_x``. ``x_mask`` must be a
        float tensor broadcastable to ``x`` (assumed already expanded by the caller
        — TODO confirm).
        """
        return torch.lerp(masked_x, x, x_mask)

    def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
        """Normalize, modulate with ``t`` (and optionally ``t0``), then project.

        ``T``/``S`` are accepted for interface compatibility; this implementation
        does not use them (nor ``self.d_t``/``self.d_s``). ``t0`` is required
        whenever ``x_mask`` is given.
        """
        table = self.scale_shift_table[None]
        shift, scale = (table + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        if x_mask is not None:
            shift_zero, scale_zero = (table + t0[:, None]).chunk(2, dim=1)
            # NOTE(review): norm_final is applied to the already-modulated x here,
            # not to the original input. This may be unintended, but pretrained
            # checkpoints likely depend on it — confirm before changing.
            x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
            x = self.t_mask_select(x_mask, x, x_zero)
        return self.linear(x)
class ModulateDiT(nn.Module):
    """Modulation layer for DiT.

    Computes ``linear(act(x))`` with ``linear: hidden_size -> factor * hidden_size``.
    With tensor parallelism enabled, ``linear`` is a Megatron
    ``ColumnParallelLinear`` with ``gather_output=False``. Its per-rank output is
    optionally gathered afterwards (controlled by ``gather_tensor_parallel_output``).
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        enable_tensor_parallel: bool = False,
        gather_tensor_parallel_output: bool = True,
        zero_initialize: bool = True,
    ):
        super().__init__()
        self.enable_tensor_parallel = enable_tensor_parallel
        # Gathering only makes sense when the output is actually sharded.
        self.gather_tensor_parallel_output = gather_tensor_parallel_output and enable_tensor_parallel
        self.act = act_layer()
        out_features = factor * hidden_size
        if not self.enable_tensor_parallel:
            self.linear = nn.Linear(hidden_size, out_features, bias=True)
        else:
            args = get_args()
            config = core_transformer_config_from_args(args)
            # The linear itself never runs in sequence-parallel mode; the global
            # flag is remembered only to choose the gather op in forward().
            config.sequence_parallel = False
            self.linear = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                out_features,
                bias=True,
                config=config,
                init_method=config.init_method,
                gather_output=False,
            )
            self.sequence_parallel = args.sequence_parallel
        if zero_initialize:
            for tensor in (self.linear.weight, self.linear.bias):
                nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear(act(x))``, gathered across TP ranks when configured."""
        activated = self.act(x)
        if not self.enable_tensor_parallel:
            return self.linear(activated)
        # Megatron parallel linears return (output, bias); keep the output.
        output = self.linear(activated)[0]
        if not self.gather_tensor_parallel_output:
            return output
        if self.sequence_parallel:
            gather = tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region
        else:
            gather = tensor_parallel.mappings.gather_from_tensor_model_parallel_region
        return gather(output)
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Pipeline: fc1 -> act -> dropout -> norm (or identity) -> fc2 -> dropout.

    Args:
        in_channels: input feature size.
        hidden_channels: hidden size; defaults to ``in_channels``.
        out_features: output size; defaults to ``in_channels``.
        act_layer: activation class, instantiated with no arguments.
        norm_layer: optional norm class, instantiated as ``norm_layer(hidden_channels)``.
        bias: bool or ``(fc1_bias, fc2_bias)`` tuple.
        drop: float or ``(drop1, drop2)`` tuple of dropout probabilities.
        use_conv: use 1x1 ``nn.Conv2d`` instead of ``nn.Linear`` (disables TP).
        enable_tensor_parallel: use Megatron Column/RowParallelLinear for fc1/fc2.
        enable_tp_sp: allow sequence parallelism when TP is enabled and the global
            ``args.sequence_parallel`` flag is set.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        enable_tensor_parallel=False,
        enable_tp_sp=True
    ):
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = bias if isinstance(bias, tuple) else tuple(repeat(bias, 2))
        drop_probs = drop if isinstance(drop, tuple) else tuple(repeat(drop, 2))
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.enable_tensor_parallel = not use_conv and enable_tensor_parallel

        # Sequence-parallel flag, computed once. Without TP there is no Megatron
        # `args`, so it is False (previously reading `args` there raised
        # UnboundLocalError whenever norm_layer was given).
        sequence_parallel = False
        config = None
        if self.enable_tensor_parallel:
            args = get_args()
            config = core_transformer_config_from_args(args)
            sequence_parallel = enable_tp_sp and args.sequence_parallel
            config.sequence_parallel = sequence_parallel
            self.fc1 = tensor_parallel.ColumnParallelLinear(
                in_channels,
                hidden_channels,
                config=config,
                init_method=config.init_method,
                bias=bias[0],
                gather_output=False
            )
        else:
            self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0])

        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        # NOTE(review): under TP, fc1 output is column-sharded (gather_output=False)
        # while the norm is sized for the full hidden_channels — verify norm_layer
        # is never combined with enable_tensor_parallel.
        self.norm = (
            norm_layer(hidden_channels)
            if norm_layer is not None
            else nn.Identity()
        )
        if norm_layer is not None:
            # Megatron reads this attribute on params to decide sequence-parallel grad handling.
            for param in self.norm.parameters():
                param.sequence_parallel = sequence_parallel

        if self.enable_tensor_parallel:
            self.fc2 = tensor_parallel.RowParallelLinear(
                hidden_channels,
                out_features,
                config=config,
                init_method=config.init_method,
                bias=bias[1],
                input_is_parallel=True,
                skip_bias_add=False
            )
        else:
            self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        """Run fc1 -> act -> drop1 -> norm -> fc2 -> drop2."""
        # Megatron parallel linears return (output, bias); keep the output.
        x = self.fc1(x)[0] if self.enable_tensor_parallel else self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)[0] if self.enable_tensor_parallel else self.fc2(x)
        x = self.drop2(x)
        return x