import math
from contextlib import nullcontext
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch_npu
from einops import rearrange
from megatron.core import mpu, tensor_parallel
from megatron.legacy.model.enums import AttnType
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import (
all_to_all,
gather_forward_split_backward,
split_forward_gather_backward,
)
from mindspeed.core.context_parallel.model_parallel_utils import get_context_parallel_group_for_hybrid_ulysses
from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import UlyssesContextAttention
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.attention import FlashAttention, ParallelAttention
from mindspeed_mm.models.common.embeddings import TextProjection
from mindspeed_mm.models.common.normalize import normalize, FP32LayerNorm
from mindspeed_mm.models.common.fpdt_layer import (
FPDTFlashAttention,
split_forward_gather_backward_FPDT_tensors,
gather_forward_split_backward_FPDT_tensors)
from mindspeed_mm.utils.utils import change_tensor_layout
class WanDiT(MultiModalModule):
    """Wan video diffusion transformer (DiT) predictor.

    Patch-embeds video latents with a strided ``Conv3d``, runs the tokens
    through ``num_layers`` ``WanDiTBlock`` layers conditioned on a timestep
    embedding and on text (plus, for ``i2v``/``flf2v``, CLIP image)
    embeddings, and projects the tokens back to latent space with ``Head``.

    Pipeline parallelism: only ``pre_process`` stages build the embedding
    modules and only ``post_process`` stages build the output head;
    intermediate tensors travel between stages via
    ``pipeline_set_next_stage_tensor`` / ``pipeline_set_prev_stage_tensor``.
    """

    def __init__(
        self,
        model_type: str = "t2v",
        patch_size: Tuple[int] = (1, 2, 2),
        text_len: int = 512,
        in_dim: int = 16,
        hidden_size: int = 2048,
        ffn_dim: int = 8192,
        freq_dim: int = 256,
        text_dim: int = 4096,
        img_dim: int = 1280,
        out_dim: int = 16,
        num_heads: int = 16,
        num_layers: int = 32,
        qk_norm: bool = True,
        qk_norm_type: str = "rmsnorm",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        max_seq_len: int = 1024,
        fa_layout: str = "bnsd",
        clip_token_len: int = 257,
        pre_process: bool = True,
        post_process: bool = True,
        global_layer_idx: Optional[Tuple] = None,
        attention_async_offload: bool = False,
        fp32_calculate: bool = False,
        seperated_timestep: bool = False,
        **kwargs,
    ):
        """Build the predictor.

        Raises:
            ValueError: unknown ``model_type``; ``hidden_size`` not divisible
                by ``num_heads`` or an odd per-head dimension; selective
                recomputation requested.
            NotImplementedError: ``distribute_saved_activations`` enabled or
                an unsupported context-parallel algorithm configured.
        """
        super().__init__(config=None)
        if model_type not in ["t2v", "i2v", "flf2v", "ti2v", "wan2.2-t2v", "wan2.2-i2v"]:
            raise ValueError("Please only select among 't2v', 'i2v', 'ti2v', 'flf2v', 'wan2.2-t2v' and 'wan2.2-i2v' tasks")
        if not ((hidden_size % num_heads) == 0 and (hidden_size // num_heads) % 2 == 0):
            raise ValueError(
                "The dimension must be divisible by num_heads, and result of 'dim // num_heads' must be even"
            )
        self.model_type = model_type
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.img_dim = img_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.qk_norm = qk_norm
        self.qk_norm_type = qk_norm_type
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.max_seq_len = max_seq_len
        self.fa_layout = fa_layout
        self.clip_token_len = clip_token_len
        self.pre_process = pre_process
        self.post_process = post_process
        self.global_layer_idx = global_layer_idx if global_layer_idx is not None else tuple(range(num_layers))
        self.head_dim = hidden_size // num_heads
        args = get_args()
        config = core_transformer_config_from_args(args)
        self.recompute_granularity = args.recompute_granularity
        self.distribute_saved_activations = args.distribute_saved_activations
        self.recompute_method = args.recompute_method
        # Number of layers per checkpoint chunk ("uniform") or number of
        # checkpointed leading layers ("block"); defaults to all layers.
        self.recompute_layers = (
            args.recompute_num_layers
            if args.recompute_num_layers is not None
            else num_layers
        )
        self.recompute_skip_core_attention = args.recompute_skip_core_attention
        self.recompute_num_layers_skip_core_attention = args.recompute_num_layers_skip_core_attention
        self.attention_async_offload = attention_async_offload
        self.fp32_calculate = fp32_calculate
        # Dedicated host<->device copy streams shared by all blocks when
        # attention activations are offloaded asynchronously.
        self.h2d_stream = torch_npu.npu.Stream() if attention_async_offload else None
        self.d2h_stream = torch_npu.npu.Stream() if attention_async_offload else None
        if self.recompute_granularity == "selective":
            raise ValueError(
                "recompute_granularity does not support selective mode in wanVideo"
            )
        if self.distribute_saved_activations:
            raise NotImplementedError(
                "distribute_save_activation is currently not supported"
            )
        self.enable_tensor_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        self.sequence_parallel = args.sequence_parallel and self.enable_tensor_parallel
        self.context_parallel_algo = (
            args.context_parallel_algo
            if mpu.get_context_parallel_world_size() > 1
            else None
        )
        if (
            self.context_parallel_algo is not None
            and self.context_parallel_algo
            not in ["ulysses_cp_algo", "hybrid_cp_algo", "megatron_cp_algo"]
        ):
            raise NotImplementedError(
                f"Context_parallel_algo {self.context_parallel_algo} is not implemented"
            )
        predictor_cfg = args.mm.model.to_dict().get('predictor', {})
        self.FPDT = predictor_cfg.get('FPDT', False)
        self.FPDT_chunk_number = predictor_cfg.get('FPDT_chunk_number', None)
        if self.pre_process:
            self.time_embedding = nn.Sequential(
                nn.Linear(self.freq_dim, self.hidden_size),
                nn.SiLU(),
                nn.Linear(self.hidden_size, self.hidden_size),
            )
            if self.fp32_calculate:
                self.time_embedding = self.time_embedding.to(torch.float32)
            # Produces the 6 modulation vectors (shift/scale/gate for the
            # self-attention and FFN sub-layers) consumed by every block.
            self.time_projection = nn.Sequential(
                nn.SiLU(), nn.Linear(self.hidden_size, self.hidden_size * 6)
            )
            self.text_embedding = TextProjection(
                self.text_dim, self.hidden_size, partial(nn.GELU, approximate="tanh")
            )
            if model_type in ["i2v", "flf2v"]:
                self.img_emb = MLPProj(self.img_dim, self.hidden_size, model_type == 'flf2v', clip_token_len, self.fp32_calculate)
            self.patch_embedding = nn.Conv3d(
                self.in_dim,
                self.hidden_size,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            )
        self.rope = RoPE3DWan(head_dim=self.head_dim, max_seq_len=self.max_seq_len)
        self.blocks = nn.ModuleList(
            [
                WanDiTBlock(
                    model_type,
                    self.hidden_size,
                    self.ffn_dim,
                    self.num_heads,
                    self.qk_norm,
                    self.qk_norm_type,
                    self.cross_attn_norm,
                    self.eps,
                    rope=self.rope,
                    fa_layout=self.fa_layout,
                    clip_token_len=clip_token_len,
                    attention_async_offload=self.attention_async_offload,
                    h2d_stream=self.h2d_stream,
                    d2h_stream=self.d2h_stream,
                    layer_idx=index,
                    num_layers=self.num_layers,
                    fp32_calculate=self.fp32_calculate,
                )
                for index in range(self.num_layers)
            ]
        )
        if self.post_process:
            # Propagate fp32_calculate like every other fp32-aware submodule.
            self.head = Head(
                self.hidden_size, self.out_dim, self.patch_size, self.eps,
                fp32_calculate=self.fp32_calculate,
            )
        self.use_dpo = getattr(args.mm.model, "dpo", None)
        self.seperated_timestep = seperated_timestep

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the module (assuming that all the module parameters have the same dtype)."""
        params = tuple(self.parameters())
        if len(params) > 0:
            return params[0].dtype
        else:
            buffers = tuple(self.buffers())
            return buffers[0].dtype

    @property
    def device(self) -> torch.device:
        """The device of the module (assuming that all the module parameters are in the same device)."""
        params = tuple(self.parameters())
        if len(params) > 0:
            return params[0].device
        else:
            buffers = tuple(self.buffers())
            return buffers[0].device

    def sinusoidal_embedding_1d(self, dim, position, theta=10000):
        """Return a ``(len(position), dim)`` sinusoidal embedding ``[cos | sin]``.

        Computed in float64 and cast back to ``position.dtype``.
        """
        sinusoid = torch.outer(
            position.type(torch.float64),
            torch.pow(
                theta,
                -torch.arange(
                    dim // 2, dtype=torch.float64, device=position.device
                ).div(dim // 2),
            ),
        )
        embs = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
        return embs.to(position.dtype)

    def _checkpointed_forward(self, blocks, x, *args):
        """Run ``blocks`` over ``x`` with activation checkpointing.

        ``args`` are the per-layer side inputs (prompt embedding, time
        embedding, rotary embedding) passed unchanged to every block.

        * ``"uniform"``: layers are grouped into consecutive chunks of
          ``recompute_layers`` and each chunk is checkpointed as a unit.
        * ``"block"``: the first ``recompute_layers`` layers are checkpointed
          individually; the next ``recompute_num_layers_skip_core_attention``
          layers (when enabled) recompute everything except core attention;
          the rest run normally.

        Raises:
            ValueError: unknown recompute method, or a non-positive chunk
                size for the uniform method.
        """
        num_layers = len(blocks)
        recompute_layers = self.recompute_layers
        recompute_num_layers_skip_core_attention = (
            self.recompute_num_layers_skip_core_attention
            if self.recompute_skip_core_attention
            else 0
        )

        def custom(start, end):
            def custom_forward(*inputs):
                # Thread the hidden states through the chunk: each layer must
                # consume the previous layer's output, while the side inputs
                # are shared by all layers.
                x_, *rest = inputs
                for index in range(start, end):
                    x_ = blocks[index](x_, *rest)
                return x_
            return custom_forward

        if self.recompute_method == "uniform":
            if recompute_layers < 1:
                raise ValueError(
                    f"recompute_num_layers must be positive for uniform recompute, got {recompute_layers}."
                )
            _layer_num = 0
            while _layer_num < num_layers:
                # Clamp the last chunk so num_layers need not be a multiple
                # of the chunk size.
                end = min(_layer_num + recompute_layers, num_layers)
                x = tensor_parallel.checkpoint(
                    custom(_layer_num, end),
                    self.distribute_saved_activations,
                    x,
                    *args,
                )
                _layer_num = end
        elif self.recompute_method == "block":
            for _layer_num in range(num_layers):
                if _layer_num < recompute_layers:
                    x = tensor_parallel.checkpoint(
                        custom(_layer_num, _layer_num + 1),
                        self.distribute_saved_activations,
                        x,
                        *args,
                    )
                elif _layer_num < recompute_layers + recompute_num_layers_skip_core_attention:
                    block = blocks[_layer_num]
                    x = block(x, *args, recompute_skip_core_attention=True)
                else:
                    block = blocks[_layer_num]
                    x = block(x, *args)
        else:
            raise ValueError(
                f"Invalid activation recompute method {self.recompute_method}."
            )
        return x

    def patchify(self, embs: torch.Tensor):
        """Flatten a ``(b, c, f, h, w)`` patch grid into ``(b, f*h*w, c)`` tokens.

        Returns the tokens and the grid sizes ``(f, h, w)``.
        """
        grid_sizes = embs.shape[2:]
        patch_out = rearrange(embs, "b c f h w -> b (f h w) c").contiguous()
        return patch_out, grid_sizes

    def unpatchify(self, embs, frames, height, width):
        """Inverse of ``patchify`` + patch projection: fold per-token patch
        values back into a ``(b, c, F, H, W)`` video latent."""
        patch_out = rearrange(
            embs,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=frames,
            h=height,
            w=width,
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2],
        )
        return patch_out

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        prompt: torch.Tensor,
        prompt_mask: torch.Tensor = None,
        i2v_clip_feature: torch.Tensor = None,
        i2v_vae_feature: torch.Tensor = None,
        **kwargs,
    ):
        """Predict the denoising target for latent video ``x``.

        On a ``pre_process`` stage ``x`` is a latent video (presumably
        ``(b, c, f, h, w)`` given the ``Conv3d`` patch embedding) and the
        timestep/prompt embeddings are computed here. On later pipeline
        stages ``x`` is the previous stage's hidden states and
        ``kwargs`` must carry ``ori_shape``, ``prompt_emb``, ``time_emb``
        and ``times`` (see ``pipeline_set_prev_stage_tensor``).

        Note: when ``prompt_mask`` is given, padded positions of ``prompt``
        are zeroed in place.

        Returns:
            ``(out, prompt, prompt_emb, time_emb, times, prompt_mask)`` where
            ``out`` is the unpatchified prediction on ``post_process`` stages
            and the hidden states otherwise.
        """
        if self.pre_process:
            timestep = timestep.to(x[0].device)
            if self.seperated_timestep:
                # Per-token timesteps: tokens of the first latent frame get
                # t=0 (presumably the conditioning frame), all other frames get
                # the sampled timestep. Token count per frame follows the
                # spatial patch size used by patch_embedding.
                tokens_per_frame = (x.shape[3] // self.patch_size[1]) * (x.shape[4] // self.patch_size[2])
                timestep = torch.concat([
                    torch.zeros((1, tokens_per_frame), dtype=x.dtype, device=x.device),
                    torch.ones((x.shape[2] - 1, tokens_per_frame), dtype=x.dtype,
                               device=x.device) * timestep
                ])
            if timestep.ndim == 2:
                ts_seq_len = timestep.shape[1]
                timestep = timestep.flatten()
            else:
                ts_seq_len = None
            timestep = self.sinusoidal_embedding_1d(self.freq_dim, timestep)
            if ts_seq_len is not None and not self.seperated_timestep:
                timestep = timestep.unflatten(0, (-1, ts_seq_len))
            times = self.time_embedding(timestep)
            if self.seperated_timestep:
                times = times.unsqueeze(0)
            time_emb = self.time_projection(times)
            if ts_seq_len is None:
                time_emb = time_emb.unflatten(1, (6, self.hidden_size))
            else:
                time_emb = time_emb.unflatten(2, (6, self.hidden_size))
            bs = prompt.size(0)
            prompt = prompt.view(bs, -1, prompt.size(-1))
            if prompt_mask is not None:
                seq_lens = prompt_mask.view(bs, -1).sum(dim=-1)
                seq_lens = seq_lens.to(torch.int64)
                for i, seq_len in enumerate(seq_lens):
                    prompt[i, seq_len:] = 0
            prompt_emb = self.text_embedding(prompt)
            if self.model_type in ["i2v", "flf2v"]:
                i2v_clip_feature = i2v_clip_feature.to(x)
                i2v_vae_feature = i2v_vae_feature.to(x)
                x = torch.cat([x, i2v_vae_feature], dim=1)
                clip_embedding = self.img_emb(i2v_clip_feature.float() if self.fp32_calculate else i2v_clip_feature.to(time_emb.dtype))
                # CLIP tokens are prepended to the text tokens; WanDiTBlock
                # splits them apart again for cross attention.
                prompt_emb = torch.cat([clip_embedding, prompt_emb], dim=1)
            elif self.model_type in ["wan2.2-i2v"]:
                i2v_vae_feature = i2v_vae_feature.to(x)
                x = torch.cat([x, i2v_vae_feature], dim=1)
            patch_emb = self.patch_embedding(x.to(time_emb.dtype))
            embs, grid_sizes = self.patchify(patch_emb)
            batch_size, frames, height, width = (
                embs.shape[0],
                grid_sizes[0],
                grid_sizes[1],
                grid_sizes[2],
            )
        else:
            batch_size, _, frames, height, width = kwargs["ori_shape"]
            height, width = height // self.patch_size[1], width // self.patch_size[2]
            prompt_emb = kwargs['prompt_emb']
            time_emb = kwargs['time_emb']
            times = kwargs['times']
            embs = x
        rotary_pos_emb = self.rope(batch_size, frames, height, width)
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()
        if self.context_parallel_algo is not None:
            # Only the first stage splits the token sequence across the CP
            # group; later stages already receive their local shard.
            if self.pre_process:
                if self.FPDT:
                    embs = split_forward_gather_backward_FPDT_tensors(embs, seq_dim=1, chunk_number=self.FPDT_chunk_number,
                                                                      group=mpu.get_context_parallel_group(), grad_scale="down")
                    if time_emb.ndim == 4:
                        time_emb = split_forward_gather_backward_FPDT_tensors(time_emb, seq_dim=1, chunk_number=self.FPDT_chunk_number,
                                                                              group=mpu.get_context_parallel_group(), grad_scale="down")
                else:
                    embs = split_forward_gather_backward(
                        embs, mpu.get_context_parallel_group(), dim=1, grad_scale="down"
                    )
                    if time_emb.ndim == 4:
                        time_emb = split_forward_gather_backward(
                            time_emb, mpu.get_context_parallel_group(), dim=1, grad_scale="down"
                        )
            # The rotary table is rebuilt on every stage, so every stage
            # must shard it to match its local tokens.
            if self.FPDT:
                rotary_pos_emb = split_forward_gather_backward_FPDT_tensors(rotary_pos_emb, seq_dim=0, chunk_number=self.FPDT_chunk_number,
                                                                            group=mpu.get_context_parallel_group(), grad_scale="down")
            else:
                rotary_pos_emb = split_forward_gather_backward(
                    rotary_pos_emb,
                    mpu.get_context_parallel_group(),
                    dim=0,
                    grad_scale="down",
                )
        with rng_context:
            if self.recompute_granularity == "full":
                embs = self._checkpointed_forward(
                    self.blocks,
                    embs,
                    prompt_emb,
                    time_emb,
                    rotary_pos_emb,
                )
            else:
                for block in self.blocks:
                    embs = block(embs, prompt_emb, time_emb, rotary_pos_emb)
        out = embs
        if self.post_process:
            if self.context_parallel_algo is not None:
                if self.FPDT:
                    embs = gather_forward_split_backward_FPDT_tensors(embs, seq_dim=1, chunk_number=self.FPDT_chunk_number,
                                                                      group=mpu.get_context_parallel_group(), grad_scale="up")
                    if time_emb.ndim == 4:
                        time_emb = gather_forward_split_backward_FPDT_tensors(time_emb, seq_dim=1, chunk_number=self.FPDT_chunk_number,
                                                                              group=mpu.get_context_parallel_group(), grad_scale="up")
                else:
                    embs = gather_forward_split_backward(
                        embs, mpu.get_context_parallel_group(), dim=1, grad_scale="up"
                    )
                    if time_emb.ndim == 4:
                        time_emb = gather_forward_split_backward(
                            time_emb, mpu.get_context_parallel_group(), dim=1, grad_scale="up"
                        )
            embs_out = self.head(embs, times)
            out = self.unpatchify(embs_out, frames, height, width)
        rtn = (out, prompt, prompt_emb, time_emb, times, prompt_mask)
        return rtn

    def pipeline_set_prev_stage_tensor(self, input_tensor_list, extra_kwargs):
        """
        Implemented for pipeline parallelism. The input tensor is got from last PP stage.
        Args:
            input_tensor_list: same as the return value of pipeline_set_next_stage_tensor
            extra_kwargs: kwargs for forward func; receives prompt_emb, time_emb, times
                and ori_shape (mutated in place).
        Returns:
            predictor_input_list: values for predictor forward.
            training_loss_input_list: values to calculate loss.
            score_list: only when DPO is enabled, [score, score_lose].
        """
        score, score_lose = None, None
        if self.use_dpo is not None:
            (prev_output, prompt, prompt_emb, time_emb, times, prompt_mask, score, score_lose,
             latents, noised_latents, timesteps, noise) = input_tensor_list
        else:
            (prev_output, prompt, prompt_emb, time_emb, times, prompt_mask,
             latents, noised_latents, timesteps, noise) = input_tensor_list
        predictor_input_list = [prev_output, timesteps, prompt, None, prompt_mask]
        training_loss_input_list = [latents, noised_latents, timesteps, noise, None]
        extra_kwargs['prompt_emb'] = prompt_emb
        extra_kwargs['time_emb'] = time_emb
        extra_kwargs['times'] = times
        extra_kwargs["ori_shape"] = latents.shape
        if self.use_dpo is not None:
            score_list = [score, score_lose]
            return predictor_input_list, training_loss_input_list, score_list
        return predictor_input_list, training_loss_input_list

    def pipeline_set_next_stage_tensor(self, input_list, output_list, extra_kwargs=None):
        """
        input_list: [latents, noised_latents, timesteps, noise, video_mask]
        output_list (predict_output): [out, prompt, prompt_emb, time_emb, times, prompt_mask]
        return as
            prev_output, prompt, prompt_emb, time_emb, times, prompt_mask,
            latents, noised_latents, timesteps, noise
        which should be corresponded with initialize_pipeline_tensor_shapes.
        timesteps are cast to float32 to match the declared pipeline dtype.
        """
        latents, noised_latents, timesteps, noise, _ = input_list
        if timesteps.dtype != torch.float32:
            timesteps = timesteps.to(torch.float32)
        return list(output_list) + [latents, noised_latents, timesteps, noise]

    @staticmethod
    def initialize_pipeline_tensor_shapes():
        """Return the shape/dtype list of tensors exchanged between PP stages.

        The order matches ``pipeline_set_next_stage_tensor``; with DPO two
        float64 score tensors are inserted after the first six entries.
        The hidden-state sequence length is the per-rank shard when context
        parallelism is used.
        """
        args = get_args()
        micro_batch_size = args.micro_batch_size
        dtype = args.params_dtype
        model_cfg = args.mm.model
        data_cfg = args.mm.data.dataset_param.preprocess_parameters
        predictor_cfg = model_cfg.predictor
        hidden_size = predictor_cfg.hidden_size
        height = getattr(data_cfg, "max_height", 480)
        width = getattr(data_cfg, "max_width", 832)
        vae_scale_factor = getattr(predictor_cfg, "vae_scale_factor", [4, 8, 8])
        latent_size = ((data_cfg.num_frames + 3) // vae_scale_factor[0], height // vae_scale_factor[1], width // vae_scale_factor[2])
        divisor = predictor_cfg.patch_size[0] * predictor_cfg.patch_size[1] * predictor_cfg.patch_size[2]
        seq_len = latent_size[0] * latent_size[1] * latent_size[2] // divisor // mpu.get_context_parallel_world_size()
        channels = predictor_cfg.out_dim
        text_dim = predictor_cfg.text_dim
        text_len = predictor_cfg.text_len
        # CLIP image tokens prepended to prompt_emb: one image for i2v, the
        # first and last frame (2 * clip_token_len tokens) for flf2v.
        if predictor_cfg.model_type == 'i2v':
            img_token_len = predictor_cfg.clip_token_len
        elif predictor_cfg.model_type == 'flf2v':
            img_token_len = 2 * predictor_cfg.clip_token_len
        else:
            img_token_len = 0
        rtn_size = 1
        use_dpo = getattr(model_cfg, "dpo", None)
        if use_dpo is not None:
            micro_batch_size = micro_batch_size * 2
            rtn_size = 2
        pipeline_tensor_shapes = [
            {'shape': (micro_batch_size * rtn_size, seq_len, hidden_size), 'dtype': dtype},
            {'shape': (micro_batch_size * rtn_size, text_len, text_dim), 'dtype': dtype},
            {'shape': (micro_batch_size * rtn_size, text_len + img_token_len, hidden_size), 'dtype': dtype},
            {'shape': (micro_batch_size * rtn_size, 6, hidden_size), 'dtype': dtype},
            {'shape': (micro_batch_size * rtn_size, hidden_size), 'dtype': dtype},
            {'shape': (micro_batch_size * rtn_size, 1, text_len), 'dtype': dtype},
            {'shape': (micro_batch_size, channels, *latent_size), 'dtype': dtype},
            {"shape": (micro_batch_size, channels, *latent_size), "dtype": dtype},
            {'shape': (micro_batch_size,), 'dtype': torch.float32},
            {'shape': (micro_batch_size, channels, *latent_size), 'dtype': dtype},
        ]
        if use_dpo is not None:
            score_shape = [
                {'shape': (1,), 'dtype': torch.float64},
                {'shape': (1,), 'dtype': torch.float64},
            ]
            pipeline_tensor_shapes = pipeline_tensor_shapes[:6] + score_shape + pipeline_tensor_shapes[6:]
        return pipeline_tensor_shapes
class WanDiTBlock(nn.Module):
    """One Wan DiT layer: modulated self-attention, cross-attention to the
    prompt (text, plus CLIP image tokens for ``i2v``/``flf2v``) and a
    modulated feed-forward network.

    The six modulation vectors (shift/scale/gate for the self-attention and
    FFN sub-layers) are ``self.modulation + time_emb``.
    """

    def __init__(
        self,
        model_type: str,
        hidden_size: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rmsnorm",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        dropout: float = 0.0,
        rope=None,
        fa_layout=None,
        clip_token_len: int = 257,
        attention_async_offload: bool = False,
        layer_idx: int = 0,
        num_layers: int = 40,
        fp32_calculate: bool = False,
        h2d_stream: Optional[torch_npu.npu.Stream] = None,
        d2h_stream: Optional[torch_npu.npu.Stream] = None,
        **kwargs
    ):
        super().__init__()
        self.model_type = model_type
        self.rope = rope
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.clip_token_len = clip_token_len
        self.fp32_calculate = fp32_calculate
        args = get_args()
        predictor_cfg = args.mm.model.to_dict().get('predictor', {})
        self.FPDT = predictor_cfg.get('FPDT', False)
        self.FPDT_chunk_number = predictor_cfg.get('FPDT_chunk_number', None)
        self.distribute_saved_activations = args.distribute_saved_activations
        # Extra kwargs forwarded to the core attention call when core
        # attention is excluded from recomputation (see forward).
        self.attention_async_offload_param = {
            "async_offload": attention_async_offload,
            "block_idx": layer_idx,
            "depth": num_layers,
            "h2d_stream": h2d_stream,
            "d2h_stream": d2h_stream,
        }
        self.modulation = nn.Parameter(
            torch.randn(1, 6, self.hidden_size) / self.hidden_size**0.5
        )
        self.norm1 = nn.LayerNorm(self.hidden_size, eps=eps, elementwise_affine=False)
        self.self_attn = WanVideoParallelAttention(
            query_dim=hidden_size,
            key_dim=None,
            num_attention_heads=num_heads,
            hidden_size=hidden_size,
            proj_q_bias=attention_bias,
            proj_k_bias=attention_bias,
            proj_v_bias=attention_bias,
            proj_out_bias=attention_out_bias,
            dropout=dropout,
            use_qk_norm=qk_norm,
            norm_type=qk_norm_type,
            norm_eps=eps,
            rope=rope,
            attention_type=AttnType.self_attn,
            has_img_input=False,
            fa_layout=fa_layout,
        )
        # NOTE(review): cross_attn_norm is accepted but not consulted here;
        # the cross-attention pre-norm is always applied.
        self.norm3 = FP32LayerNorm(self.hidden_size, eps=eps) if fp32_calculate else nn.LayerNorm(self.hidden_size, eps=eps)
        self.cross_attn = WanVideoParallelAttention(
            query_dim=hidden_size,
            key_dim=None,
            num_attention_heads=num_heads,
            hidden_size=hidden_size,
            proj_q_bias=attention_bias,
            proj_k_bias=attention_bias,
            proj_v_bias=attention_bias,
            proj_out_bias=attention_out_bias,
            dropout=dropout,
            use_qk_norm=qk_norm,
            norm_type=qk_norm_type,
            norm_eps=eps,
            attention_type=AttnType.cross_attn,
            has_img_input=model_type in ["i2v", "flf2v"],
            fa_layout=fa_layout,
        )
        self.norm2 = nn.LayerNorm(self.hidden_size, eps=eps, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(self.hidden_size, self.ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.ffn_dim, self.hidden_size),
        )

    def modulate(self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
        """adaLN modulation: ``x * (1 + scale) + shift``."""
        return x * (1 + scale) + shift

    def forward(
        self,
        latents,
        prompt,
        time_emb,
        rotary_pos_emb,
        recompute_skip_core_attention=False
    ):
        """Apply the block to ``latents`` (batch-first ``bsh`` hidden states).

        With ``recompute_skip_core_attention`` the parts before and after the
        core attention kernel are checkpointed while the core attention
        output is kept, and the async-offload parameters are passed to the
        core attention call.
        """
        # NOTE(review): full device synchronization on every block call;
        # confirm it is required (e.g. by the async offload streams).
        torch.npu.synchronize()
        if recompute_skip_core_attention:
            query, key, value, gate_msa, shift_mlp, scale_mlp, gate_mlp = tensor_parallel.checkpoint(
                self._before_self_attention,
                self.distribute_saved_activations,
                time_emb,
                latents,
                rotary_pos_emb,
            )
        else:
            query, key, value, gate_msa, shift_mlp, scale_mlp, gate_mlp = self._before_self_attention(
                time_emb,
                latents,
                rotary_pos_emb
            )
        attention_async_offload_param = (
            self.attention_async_offload_param
            if recompute_skip_core_attention
            else {}
        )
        self_attn_out = self.self_attn.core_attention_flash(
            query=query,
            key=key,
            value=value,
            **attention_async_offload_param
        )
        if recompute_skip_core_attention:
            latents = tensor_parallel.checkpoint(
                self._after_self_attention,
                self.distribute_saved_activations,
                self_attn_out,
                latents,
                prompt,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp
            )
        else:
            latents = self._after_self_attention(
                self_attn_out,
                latents,
                prompt,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp
            )
        return latents

    def _before_self_attention(
        self,
        time_emb,
        latents,
        rotary_pos_emb
    ):
        """Compute modulation vectors and the self-attention Q/K/V.

        ``time_emb`` is either per-sample (3-D, modulation on dim 1) or
        per-token (4-D, modulation on dim 2).

        Returns:
            ``(query, key, value, gate_msa, shift_mlp, scale_mlp, gate_mlp)``.
        """
        dtype = time_emb.dtype
        modu_dtype = torch.float32 if self.fp32_calculate else dtype
        device = time_emb.device
        has_seq = time_emb.ndim == 4
        chunk_dim = 2 if has_seq else 1
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=modu_dtype, device=device) + time_emb.to(modu_dtype)
        ).chunk(6, dim=chunk_dim)
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
            )
        self_attn_input = self.modulate(
            self.norm1(latents.to(torch.float32)), shift_msa, scale_msa
        ).to(dtype)
        query, key, value = self.self_attn.function_before_core_attention(
            query=self_attn_input,
            input_layout="bsh",
            rotary_pos_emb=rotary_pos_emb.to(time_emb.device)
        )
        return (
            query, key, value,
            gate_msa, shift_mlp, scale_mlp, gate_mlp
        )

    def _after_self_attention(
        self,
        self_attn_out,
        latents,
        prompt,
        gate_msa,
        shift_mlp,
        scale_mlp,
        gate_mlp
    ):
        """Finish self-attention, then run cross-attention and the FFN.

        For ``i2v``/``flf2v`` the first ``clip_token_len`` (``flf2v``: twice
        that) prompt tokens are CLIP image tokens and the rest are text.
        """
        dtype = torch.float32 if self.fp32_calculate else latents.dtype
        self_attn_out = self.self_attn.function_after_core_attention(self_attn_out, output_layout="bsh")
        latents = (latents + gate_msa * self_attn_out).to(latents.dtype)
        crs_attn_input = self.norm3(latents.to(dtype)).to(latents.dtype)
        if self.model_type in ["i2v", "flf2v"]:
            img_clip_token_len = 2 * self.clip_token_len if self.model_type == "flf2v" else self.clip_token_len
            img = prompt[:, :img_clip_token_len]
            txt = prompt[:, img_clip_token_len:]
            crs_attn_out = self.cross_attn(
                query=crs_attn_input,
                key=(img, txt),
                input_layout="bsh",
            )
        else:
            txt = prompt
            crs_attn_out = self.cross_attn(
                query=crs_attn_input,
                key=txt,
                input_layout="bsh",
            )
        latents = latents + crs_attn_out
        modu_out = self.modulate(self.norm2(latents.to(dtype)), shift_mlp, scale_mlp).to(latents.dtype)
        if self.FPDT:
            latents = ((latents.to(dtype)) + gate_mlp * self.fpdt_ffn(modu_out).to(dtype)).to(latents.dtype)
        else:
            latents = ((latents.to(dtype)) + gate_mlp * self.ffn(modu_out).to(dtype)).to(latents.dtype)
        return latents

    def fpdt_ffn(self, x):
        """Apply ``self.ffn`` chunk by chunk along dim 1 and re-concatenate."""
        outs = []
        inputs = torch.chunk(x, dim=1, chunks=self.FPDT_chunk_number)
        for input_chunk in inputs:
            outs.append(self.ffn(input_chunk))
        output = torch.concat(outs, dim=1).contiguous()
        return output
class RoPE3DWan(nn.Module):
    """3D (frame, height, width) rotary position embedding used by Wan.

    The per-head dimension is split into three bands, one per axis. Each band
    stores complex rotation factors for positions ``0 .. max_seq_len - 1``.
    """

    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.freqs = self.get_freq(head_dim)
        # Kept as a plain list (not registered buffers) and moved to the NPU
        # eagerly, so module.to(...) does not move these tables.
        self.freqs = [x.npu() for x in self.freqs]

    def get_freq(self, head_dim):
        """Return complex rotation tables ``(freqs_f, freqs_h, freqs_w)``.

        Height and width bands each get ``2 * (head_dim // 6)`` channels and
        the frame band gets the remainder, so every band is even whenever
        ``head_dim`` is even. This equals a ``head_dim // 3`` split whenever
        that value is even (e.g. 96, 128) and also supports head sizes such as
        64 where ``head_dim // 3`` would be odd.

        Raises:
            ValueError: ``head_dim`` is not positive, or a band is odd.
        """
        if head_dim <= 0:
            raise ValueError("head dimension must be greater than 0")
        dim2 = 2 * (head_dim // 6)
        dim1 = head_dim - 2 * dim2
        freqs1 = self.rope_params(self.max_seq_len, dim1)
        freqs2 = self.rope_params(self.max_seq_len, dim2)
        freqs3 = self.rope_params(self.max_seq_len, dim2)
        return freqs1, freqs2, freqs3

    def rope_params(self, max_seq_len, dim, theta=10000):
        """Return a ``(max_seq_len, dim // 2)`` complex64 table of ``exp(i * pos * freq)``."""
        if dim % 2 != 0:
            raise ValueError("Dimension must be even")
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim))
        freqs = torch.outer(torch.arange(max_seq_len, device=freqs.device), freqs)
        freqs = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)
        return freqs

    def apply_rotary_pos_emb(self, tokens, freqs):
        """Rotate adjacent channel pairs ``(x0, x1)`` of ``tokens`` by ``freqs``.

        Equivalent to multiplying ``x0 + i*x1`` by the complex factor; the
        result keeps ``tokens.dtype``. ``freqs`` must broadcast against
        ``tokens`` with its last dim equal to ``tokens.shape[-1] // 2``.
        """
        dtype = tokens.dtype
        cos, sin = torch.chunk(torch.view_as_real(freqs.to(torch.complex64)), 2, dim=-1)
        B, S, N, D = tokens.shape

        def rotate_half(x):
            half_1, half_2 = torch.chunk(x.reshape((B, S, N, D // 2, 2)), 2, dim=-1)
            return torch.cat((-half_2, half_1), dim=-1).reshape((B, S, N, D))

        # Repeat each cos/sin value for both channels of its pair.
        cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2)
        sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2)
        res = tokens * cos + rotate_half(tokens) * sin
        return res.to(dtype)

    def forward(self, b, f, h, w):
        """Build the rotary table for an ``f x h x w`` token grid.

        Returns a ``(f*h*w, b, 1, head_dim // 2)`` complex tensor; token order
        is frame-major, then height, then width.
        """
        seq_len = f * h * w
        freqs = (
            torch.cat(
                [
                    self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                    self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                    self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
                ],
                dim=-1,
            )
            .reshape(seq_len, 1, 1, -1)
            .expand(seq_len, b, 1, -1)
        )
        return freqs
class WanVideoParallelAttention(ParallelAttention):
    """``ParallelAttention`` variant used by Wan DiT blocks.

    Differences from the base class:
      * the core attention is re-created for this model and, for
        self-attention under context parallelism, wrapped with Ulysses (or
        replaced by FPDT) attention;
      * optional QK normalization;
      * with ``has_img_input`` (image-conditioned cross attention) an extra
        key/value projection is applied to the CLIP image tokens and the
        image attention output is added to the text attention output.
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: Optional[int],
        num_attention_heads: int,
        hidden_size: int,
        proj_q_bias: bool = False,
        proj_k_bias: bool = False,
        proj_v_bias: bool = False,
        proj_out_bias: bool = False,
        dropout: float = 0.0,
        use_qk_norm: bool = False,
        norm_type: str = None,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: int = AttnType.self_attn,
        has_img_input: bool = False,
        fa_layout: str = "bnsd",
        rope=None,
        **kwargs,
    ):
        super().__init__(
            query_dim=query_dim,
            key_dim=key_dim,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            proj_q_bias=proj_q_bias,
            proj_k_bias=proj_k_bias,
            proj_v_bias=proj_v_bias,
            proj_out_bias=proj_out_bias,
            dropout=dropout,
            use_qk_norm=use_qk_norm,
            norm_type=norm_type,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            is_qkv_concat=False,
            attention_type=attention_type,
            is_kv_concat=False,
            fa_layout=fa_layout,
            rope=rope,
            **kwargs,
        )
        args = get_args()
        predictor_cfg = args.mm.model.to_dict().get('predictor', {})
        self.FPDT = predictor_cfg.get('FPDT', False)
        self.FPDT_chunk_number = predictor_cfg.get('FPDT_chunk_number', None)
        self.FPDT_with_offload = predictor_cfg.get('FPDT_with_offload', False)
        # Ring-style CP algorithms run self-attention in "sbh" layout.
        if self.cp_size > 1 and attention_type == AttnType.self_attn \
                and args.context_parallel_algo in ["megatron_cp_algo", "hybrid_cp_algo"]:
            fa_layout = "sbh"
        self.core_attention_flash = FlashAttention(
            attention_dropout=dropout,
            fa_layout=fa_layout,
            softmax_scale=1 / math.sqrt(self.head_dim),
        )
        if self.cp_size > 1 and attention_type == AttnType.self_attn \
                and args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"]:
            if args.context_parallel_algo == "hybrid_cp_algo":
                ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
            else:
                ulysses_group = mpu.get_context_parallel_group()
            if self.FPDT:
                self.core_attention_flash = FPDTFlashAttention(
                    ulysess_context_parallel_group=ulysses_group,
                    hidden_size=hidden_size,
                    head_dim=hidden_size // num_attention_heads,
                    chunk_number=self.FPDT_chunk_number,
                    with_offload=self.FPDT_with_offload
                )
            else:
                self.core_attention_flash = UlyssesContextAttention(self.core_attention_flash, ulysses_group)
        if self.cp_size > 1 and attention_type == AttnType.cross_attn:
            self.core_attention_flash.context_parallel_algo = "ulysses_cp_algo"
        if self.use_qk_norm:
            self.q_norm = normalize(
                norm_type=norm_type,
                in_channels=hidden_size,
                eps=norm_eps,
                affine=norm_elementwise_affine,
                **kwargs,
            )
            self.k_norm = normalize(
                norm_type=norm_type,
                in_channels=hidden_size,
                eps=norm_eps,
                affine=norm_elementwise_affine,
                **kwargs,
            )
            if isinstance(self.q_norm, nn.LayerNorm):
                for param in self.q_norm.parameters():
                    setattr(param, "sequence_parallel", self.sequence_parallel)
            if isinstance(self.k_norm, nn.LayerNorm):
                for param in self.k_norm.parameters():
                    setattr(param, "sequence_parallel", self.sequence_parallel)
        self.has_img_input = has_img_input
        if self.has_img_input:
            config = core_transformer_config_from_args(args)
            # Image key/value projections honour the key/value bias flags
            # (previously both used proj_q_bias).
            self.k_img = tensor_parallel.ColumnParallelLinear(
                query_dim,
                hidden_size,
                config=config,
                init_method=config.init_method,
                bias=proj_k_bias,
                gather_output=False,
            )
            self.v_img = tensor_parallel.ColumnParallelLinear(
                query_dim,
                hidden_size,
                config=config,
                init_method=config.init_method,
                bias=proj_v_bias,
                gather_output=False,
            )
            self.k_norm_img = normalize(
                norm_type=norm_type,
                in_channels=hidden_size,
                eps=norm_eps,
                affine=norm_elementwise_affine,
                **kwargs,
            )

    def function_after_core_attention(
        self,
        core_attn_out,
        output_layout: str = "sbh"
    ):
        """Output projection + dropout; ``core_attn_out`` is in ``sbh`` layout.

        With FPDT the projection runs chunk by chunk along the sequence
        dimension (presumably to bound peak memory). The loop iterates over
        the chunks ``torch.chunk`` actually returns, which can be fewer than
        ``FPDT_chunk_number`` for short sequences.
        """
        if self.FPDT:
            chunks = torch.chunk(core_attn_out, chunks=self.FPDT_chunk_number, dim=0)
            output = torch.cat([self.proj_out(chunk)[0] for chunk in chunks], dim=0)
        else:
            output, _ = self.proj_out(core_attn_out)
        output = change_tensor_layout(output, "sbh", output_layout)
        output = self.dropout(output)
        return output

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        Derives `query` tensor from `hidden_states`, and `key`/`value` tensor
        from `hidden_states` or `key_value_states`.

        With image input, ``key_value_states`` is ``(image_states,
        text_states)`` and ``key``/``value`` are returned as
        ``[image, text]`` lists. Tensors are reshaped to
        ``(seq, batch, heads_per_partition, head_dim)``.
        """
        if self.has_img_input:
            img_key_value_states, context_key_value_states = key_value_states
            query = self.proj_q(hidden_states)[0]
            img_key = self.k_img(img_key_value_states)[0]
            img_value = self.v_img(img_key_value_states)[0]
            key = self.proj_k(context_key_value_states)[0]
            value = self.proj_v(context_key_value_states)[0]
        else:
            query = self.proj_q(hidden_states)[0]
            key = self.proj_k(key_value_states)[0]
            value = self.proj_v(key_value_states)[0]
        if self.use_qk_norm:
            query = self.q_norm(query)
            key = self.k_norm(key)
            if self.has_img_input:
                img_key = self.k_norm_img(img_key)
        batch_size = query.shape[1]
        query = query.view(
            -1, batch_size, self.num_attention_heads_per_partition, self.head_dim
        )
        key = key.view(
            -1, batch_size, self.num_attention_heads_per_partition, self.head_dim
        )
        value = value.view(
            -1, batch_size, self.num_attention_heads_per_partition, self.head_dim
        )
        if self.has_img_input:
            img_key = img_key.view(
                -1, batch_size, self.num_attention_heads_per_partition, self.head_dim
            )
            img_value = img_value.view(
                -1, batch_size, self.num_attention_heads_per_partition, self.head_dim
            )
            key = [img_key, key]
            value = [img_value, value]
        return query, key, value

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        mask: Optional[torch.Tensor] = None,
        input_layout: str = "sbh",
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ):
        """Attention forward; with image input the image and text attention
        outputs are summed before the output projection."""
        if self.has_img_input:
            query, key, value = self.function_before_core_attention(
                query, key, input_layout, rotary_pos_emb
            )
            img_core_attn_out = self.core_attention_flash(query, key[0], value[0], mask)
            core_attn_out = self.core_attention_flash(query, key[1], value[1], mask)
            core_attn_out = img_core_attn_out + core_attn_out
            out = self.function_after_core_attention(core_attn_out, input_layout)
            return out
        else:
            return super().forward(query, key, mask, input_layout, rotary_pos_emb)
class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a linear projection from
    ``dim`` to ``out_dim * prod(patch_size)`` values per token."""

    def __init__(
        self, dim: int, out_dim: int, patch_size: List[int], eps: float = 1e-6, fp32_calculate: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.fp32_calculate = fp32_calculate
        self.norm = FP32LayerNorm(dim, eps=eps, elementwise_affine=False) if fp32_calculate else nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        # Learned (shift, scale) offsets added to the time embedding.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, latents, times):
        """Project ``latents`` (``(B, L, dim)``) to per-token patch values.

        Args:
            latents: token hidden states.
            times: time embedding, either per sample ``(B, dim)`` or per
                token ``(B, L, dim)``.
        """
        modu_dtype = torch.float32 if self.fp32_calculate else times.dtype
        modulation = self.modulation.to(dtype=modu_dtype, device=times.device)
        if times.ndim == 3:
            # (1, 1, 2, dim) + (B, L, 1, dim) -> per-token shift/scale (B, L, dim).
            shift, scale = (modulation.unsqueeze(0) + times.unsqueeze(2)).chunk(2, dim=2)
            shift, scale = shift.squeeze(2), scale.squeeze(2)
        else:
            # (1, 2, dim) + (B, 1, dim) -> per-sample shift/scale (B, 1, dim).
            # Without the unsqueeze, B == 2 would silently use sample 0 as the
            # shift and sample 1 as the scale, and B > 2 would not broadcast.
            shift, scale = (modulation + times.unsqueeze(1)).chunk(2, dim=1)
        normed = self.norm(latents.float() if self.fp32_calculate else latents)
        return self.head((normed * (1 + scale) + shift).to(latents.dtype))
class MLPProj(nn.Module):
    """Project CLIP image features into the transformer hidden size.

    The projection is LayerNorm -> Linear -> GELU -> Linear -> LayerNorm. When
    ``flf_pos_emb`` is set (first/last-frame mode), consecutive pairs of
    image token sequences are concatenated along the token axis and a learned
    positional embedding covering ``2 * clip_token_len`` tokens is added
    before projecting.
    """

    def __init__(self, in_dim: int, out_dim: int, flf_pos_emb=False, clip_token_len=257, fp32_calculate=False):
        super().__init__()
        norm_cls = FP32LayerNorm if fp32_calculate else nn.LayerNorm
        self.proj = nn.Sequential(
            norm_cls(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            norm_cls(out_dim),
        )
        if flf_pos_emb:
            self.emb_pos = nn.Parameter(torch.zeros(1, clip_token_len * 2, in_dim))

    def forward(self, image_emb):
        """Project ``image_emb`` (3-D: batch, tokens, features)."""
        if not hasattr(self, 'emb_pos'):
            return self.proj(image_emb)
        _, num_tokens, feat_dim = image_emb.shape
        # Merge each pair of image sequences into one doubled-length sequence.
        paired = image_emb.view(-1, 2 * num_tokens, feat_dim)
        return self.proj(paired + self.emb_pos)