import copy
from contextlib import nullcontext, contextmanager
from typing import Optional, Tuple
import numpy as np
import torch
import torch_npu
import torch.nn as nn
from megatron.core import tensor_parallel, mpu
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.checkpoint import load_checkpoint
from mindspeed_mm.models.predictor.dits.wan_dit import WanDiTBlock, WanDiT, RoPE3DWan
class VaceWanAttentionBlock(nn.Module):
    """One VACE layer: a ``WanDiTBlock`` wrapped with input/skip projections.

    Between layers the VACE stream travels as a single stacked tensor: every
    skip output produced so far (along dim 0), followed by the running hidden
    state as the last entry.
    """

    def __init__(self, **kwargs):
        """Build the layer.

        Args:
            **kwargs: Must contain ``layer_idx`` and ``hidden_size``. The whole
                dict is forwarded unchanged to ``WanDiTBlock``.
        """
        super().__init__()
        layer_idx = kwargs['layer_idx']
        self.layer_idx = layer_idx
        width = kwargs['hidden_size']
        self.dim = width
        # Registration order (block, before_proj, after_proj) is kept stable so
        # state-dict keys and parameter initialisation order do not change.
        self.wan_dit_block = WanDiTBlock(**kwargs)
        if layer_idx == 0:
            # Only the first VACE layer mixes the Wan latents into the stream.
            self.before_proj = torch.nn.Linear(width, width)
        self.after_proj = torch.nn.Linear(width, width)

    def forward(
        self,
        vace_context,
        latents,
        prompt,
        time_emb,
        rotary_pos_emb,
        recompute_skip_core_attention=False
    ):
        """Run one VACE layer.

        Args:
            vace_context: For layer 0, the embedded context tokens. For later
                layers, the stacked output of the previous layer.
            latents: Wan tokens; added to the projected context at layer 0 only.
            prompt, time_emb, rotary_pos_emb, recompute_skip_core_attention:
                Forwarded to ``WanDiTBlock``.

        Returns:
            ``torch.stack`` of the earlier skip outputs, this layer's skip
            output (``after_proj`` of the new hidden state), and the new hidden
            state, in that order.
        """
        if self.layer_idx != 0:
            # Split the incoming stack into prior skips plus the running state.
            skips = list(torch.unbind(vace_context))
            hidden = skips.pop()
        else:
            skips = []
            hidden = self.before_proj(vace_context) + latents

        hidden = self.wan_dit_block(
            hidden, prompt, time_emb, rotary_pos_emb, recompute_skip_core_attention
        )
        return torch.stack(skips + [self.after_proj(hidden), hidden])
class VaceDit(nn.Module):
    """VACE context branch that produces hints for selected Wan DiT blocks.

    Context inputs are patch-embedded with ``vace_patch_embedding``, zero-padded
    to the Wan token length and run through one ``VaceWanAttentionBlock`` per
    entry of ``vace_layers``. ``vace_layers`` lists the Wan block indices that
    receive the resulting hints (see ``vace_to_wan``).
    """

    def __init__(
        self,
        vace_layers: Tuple[int] = (0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        has_image_input: bool = False,
        patch_size: Tuple[int] = (1, 2, 2),
        text_len: int = 512,
        in_dim: int = 96,
        hidden_size: int = 1536,
        ffn_dim: int = 8960,
        freq_dim: int = 256,
        text_dim: int = 4096,
        img_dim: int = 1280,
        out_dim: int = 16,
        num_heads: int = 12,
        num_layers: int = 32,
        qk_norm: bool = True,
        qk_norm_type: str = 'rmsnorm',
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        max_seq_len: int = 1024,
        fa_layout: str = "bnsd",
        clip_token_len: int = 257,
        pre_process: bool = True,
        post_process: bool = True,
        global_layer_idx: Optional[Tuple] = None,
        atention_async_offload: bool = False,  # (sic) public keyword name, kept for callers
        fp32_calculate: bool = False,
        **kwargs,
    ):
        super(VaceDit, self).__init__()
        self.vace_layers = vace_layers
        # Wan block index -> position of the matching hint returned by forward().
        self.vace_to_wan = {wan_block: hint_idx for hint_idx, wan_block in enumerate(self.vace_layers)}
        self.has_image_input = has_image_input
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.img_dim = img_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.qk_norm = qk_norm
        self.qk_norm_type = qk_norm_type
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.max_seq_len = max_seq_len
        self.fa_layout = fa_layout
        self.clip_token_len = clip_token_len
        self.pre_process = pre_process
        self.post_process = post_process
        self.global_layer_idx = global_layer_idx
        self.head_dim = hidden_size // num_heads

        args = get_args()
        # NOTE(review): `config` is not used below; the call is kept in case it
        # validates `args`.
        config = core_transformer_config_from_args(args)
        self.recompute_granularity = args.recompute_granularity
        self.distribute_saved_activations = args.distribute_saved_activations
        self.recompute_method = args.recompute_method
        # Plain int layer count (defaults to num_layers). The previous `{...}`
        # spelling built a one-element set instead.
        self.recompute_layers = (
            args.recompute_num_layers
            if args.recompute_num_layers is not None
            else num_layers
        )
        self.recompute_skip_core_attention = args.recompute_skip_core_attention
        self.recompute_num_layers_skip_core_attention = args.recompute_num_layers_skip_core_attention

        self.attention_async_offload = atention_async_offload
        self.fp32_calculate = fp32_calculate
        # Dedicated NPU streams are only created when async offload is enabled.
        self.h2d_stream = torch_npu.npu.Stream() if atention_async_offload else None
        self.d2h_stream = torch_npu.npu.Stream() if atention_async_offload else None

        # One RoPE module, shared by every VACE block.
        self.rope = RoPE3DWan(head_dim=self.head_dim, max_seq_len=self.max_seq_len)
        self.vace_blocks = torch.nn.ModuleList([
            VaceWanAttentionBlock(
                model_type="t2v",
                hidden_size=self.hidden_size,
                ffn_dim=self.ffn_dim,
                num_heads=self.num_heads,
                qk_norm=self.qk_norm,
                qk_norm_type=self.qk_norm_type,
                cross_attn_norm=self.cross_attn_norm,
                eps=self.eps,
                rope=self.rope,
                fa_layout=self.fa_layout,
                clip_token_len=self.clip_token_len,
                atention_async_offload=self.attention_async_offload,
                layer_idx=index,
                num_layers=self.num_layers,
                fp32_calculate=self.fp32_calculate,
                h2d_stream=self.h2d_stream,
                d2h_stream=self.d2h_stream
            )
            for index in range(len(self.vace_layers))
        ])
        self.vace_patch_embedding = torch.nn.Conv3d(
            self.in_dim,
            self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size
        )

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first parameter (or first buffer if there are no parameters)."""
        params = tuple(self.parameters())
        if len(params) > 0:
            return params[0].dtype
        else:
            buffers = tuple(self.buffers())
            return buffers[0].dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter (or first buffer if there are no parameters)."""
        params = tuple(self.parameters())
        if len(params) > 0:
            return params[0].device
        else:
            buffers = tuple(self.buffers())
            return buffers[0].device

    def forward(
        self, embs, vace_context, prompt_emb, time_emb, rotary_pos_emb
    ):
        """Compute VACE hints.

        Args:
            embs: Patchified Wan tokens. ``embs.shape[1]`` is the length every
                context sequence is zero-padded to (assumes (B, L, D) - confirm).
            vace_context: Iterable of per-sample context tensors. Each gets a
                leading batch dim and goes through ``vace_patch_embedding``.
            prompt_emb, time_emb, rotary_pos_emb: Passed to every VACE block.

        Returns:
            Tuple of tensors: all entries (along dim 0) of the last block's
            output except the final one, i.e. the accumulated skip outputs.
        """
        target_len = embs.shape[1]
        embedded = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        # (1, D, f, h, w) -> (1, f*h*w, D)
        sequences = [u.flatten(2).transpose(1, 2) for u in embedded]
        c = torch.cat([
            torch.cat([u, u.new_zeros(1, target_len - u.size(1), u.size(2))], dim=1)
            for u in sequences
        ])
        for block in self.vace_blocks:
            c = block(c, embs, prompt_emb, time_emb, rotary_pos_emb)
        hints = torch.unbind(c)[:-1]
        return hints
class VACEModel(MultiModalModule):
    """Frozen Wan DiT combined with a trainable VACE branch (``VaceDit``).

    The VACE branch produces hints that are added, scaled by ``vace_scale``, to
    the outputs of the Wan blocks listed in ``vace_dit.vace_layers``.
    """

    def __init__(
        self,
        **kwargs
    ):
        """Build both sub-models with parameters on the meta device.

        Args:
            **kwargs: Wan DiT constructor arguments plus a required ``vace_dit``
                entry holding the ``VaceDit`` arguments (optionally with a
                ``vace_pretrained`` checkpoint path, used by ``post_init``).
        """
        super().__init__(config=None)
        self.vace_config = kwargs.pop('vace_dit')
        self.wan_config = kwargs
        # Parameters are registered on "meta" (no storage); real weights must
        # be materialised later (post_init handles vace_dit).
        with self.meta_init():
            self.wan_dit = WanDiT(**self.wan_config)
            self.vace_dit = VaceDit(**self.vace_config)
        self.freeze()

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the Wan DiT (delegates to ``wan_dit.dtype``)."""
        return self.wan_dit.dtype

    @property
    def device(self) -> torch.device:
        """Device of the Wan DiT (delegates to ``wan_dit.device``)."""
        return self.wan_dit.device

    def post_init(self):
        """Materialise the VACE branch and seed its patch embedding from Wan's.

        If ``vace_config['vace_pretrained']`` is set, that checkpoint is loaded
        into ``vace_dit`` with ``assign=True``. Otherwise, if ``vace_dit`` is
        still on the meta device, it is rebuilt with freshly initialised weights.
        """
        if "vace_pretrained" in self.vace_config and self.vace_config["vace_pretrained"] is not None:
            load_checkpoint(self.vace_dit, self.vace_config['vace_pretrained'], assign=True)
        elif self.vace_dit.device.type == "meta":
            self.vace_dit = VaceDit(**self.vace_config)
        # NOTE(review): this also runs after a VACE checkpoint was loaded above,
        # so any vace_patch_embedding weights from that checkpoint are replaced.
        # Confirm that is intended.
        self.vace_patch_embedding_replace(self.wan_dit, self.vace_dit)

    def forward(
        self,
        latents: torch.Tensor = None,
        timestep: torch.Tensor = None,
        prompt: torch.Tensor = None,
        prompt_mask: torch.Tensor = None,
        vace_context=None,
        vace_scale=1.0,
        use_unified_sequence_parallel: bool = False,
        **kwargs
    ):
        """Run the Wan DiT, injecting VACE hints after selected blocks.

        Args:
            latents: Latent input fed to ``wan_dit.patch_embedding``.
            timestep: Diffusion timesteps; moved to ``latents[0].device``.
            prompt: Text-encoder states; viewed as ``(bs, -1, prompt.size(-1))``.
            prompt_mask: Optional. For sample ``i``, prompt positions at or
                beyond ``prompt_mask[i].sum()`` are zeroed. NOTE: this writes
                through a view, so the caller's ``prompt`` tensor is modified.
            vace_context: Per-sample VACE context tensors, or ``None`` to run
                the plain Wan DiT without hints.
            vace_scale: Multiplier applied to every hint before it is added.
            use_unified_sequence_parallel: Accepted for interface compatibility;
                not used here.
            **kwargs: Ignored.

        Returns:
            Tuple ``(out, prompt, prompt_emb, time_emb, times, prompt_mask)``.
        """
        timestep = timestep.to(latents[0].device)
        times = self.wan_dit.time_embedding(
            self.wan_dit.sinusoidal_embedding_1d(self.wan_dit.freq_dim, timestep)
        )
        # Split the projection into 6 chunks of hidden_size along dim 1.
        time_emb = self.wan_dit.time_projection(times).unflatten(1, (6, self.wan_dit.hidden_size))

        bs = prompt.size(0)
        prompt = prompt.view(bs, -1, prompt.size(-1))
        if prompt_mask is not None:
            seq_lens = prompt_mask.view(bs, -1).sum(dim=-1).to(torch.int64)
            for i, seq_len in enumerate(seq_lens):
                prompt[i, seq_len:] = 0
        prompt_emb = self.wan_dit.text_embedding(prompt)

        patch_emb = self.wan_dit.patch_embedding(latents.to(time_emb.dtype))
        embs, grid_sizes = self.wan_dit.patchify(patch_emb)
        batch_size = embs.shape[0]
        frames, height, width = grid_sizes[0], grid_sizes[1], grid_sizes[2]
        rotary_pos_emb = self.wan_dit.rope(batch_size, frames, height, width)

        # Without a context the VACE branch cannot run (it iterates over
        # vace_context), so the Wan DiT is used unmodified in that case.
        vace_hints = None
        if vace_context is not None:
            vace_hints = self.vace_dit(embs, vace_context, prompt_emb, time_emb, rotary_pos_emb)

        for block_id, block in enumerate(self.wan_dit.blocks):
            embs = block(embs, prompt_emb, time_emb, rotary_pos_emb)
            if vace_hints is not None and block_id in self.vace_dit.vace_to_wan:
                current_vace_hint = vace_hints[self.vace_dit.vace_to_wan[block_id]]
                embs = embs + current_vace_hint * vace_scale

        embs_out = self.wan_dit.head(embs, times)
        out = self.wan_dit.unpatchify(embs_out, frames, height, width)
        return (out, prompt, prompt_emb, time_emb, times, prompt_mask)

    def vace_patch_embedding_replace(self, wan_dit: WanDiT, vace_dit: VaceDit):
        """Initialise ``vace_dit.vace_patch_embedding`` from Wan's patch embedding.

        The new weight concatenates, along the input-channel dim (dim 1), two
        copies of Wan's patch-embedding weight followed by zeros that fill the
        remaining VACE input channels. The bias is copied from Wan's. Both are
        new trainable Parameters, matching ``freeze()``, which makes vace_dit
        trainable.

        Raises:
            ValueError: If the VACE embedding has fewer input channels than
                twice Wan's.
        """
        src = wan_dit.patch_embedding
        dst = vace_dit.vace_patch_embedding
        src_weight = src.weight.detach()

        # Conv3d weights are (out, in, kD, kH, kW); pad along the `in` dim.
        pad_shape = list(dst.weight.shape)
        pad_channels = pad_shape[1] - 2 * src_weight.shape[1]
        if pad_channels < 0:
            raise ValueError(
                f"vace_patch_embedding has {pad_shape[1]} input channels, fewer than "
                f"twice the Wan patch embedding's {src_weight.shape[1]}"
            )
        pad_shape[1] = pad_channels

        dst.weight = torch.nn.Parameter(
            torch.cat(
                (src_weight.clone(), src_weight.clone(), src_weight.new_zeros(pad_shape)),
                dim=1,
            )
        )
        # Fresh Parameter: a deepcopy would inherit requires_grad=False from the
        # frozen Wan DiT and leave this bias untrainable.
        dst.bias = None if src.bias is None else torch.nn.Parameter(src.bias.detach().clone())

    @contextmanager
    def meta_init(self, device=torch.device("meta"), include_buffers: bool = False):
        """Context manager that creates new module parameters on ``device``.

        While active, ``torch.nn.Module.register_parameter`` is patched so every
        newly registered parameter is moved to ``device`` (``meta`` by default,
        i.e. no storage is allocated). With ``include_buffers=True``, buffers are
        moved too and ``torch.empty/zeros/ones/full`` are patched to allocate on
        ``device``. All patches are undone on exit, even on error.
        """
        old_register_parameter = torch.nn.Module.register_parameter
        if include_buffers:
            old_register_buffer = torch.nn.Module.register_buffer

        def register_empty_parameter(module, name, param):
            old_register_parameter(module, name, param)
            if param is not None:
                # Rebuild with the same Parameter subclass and attributes.
                param_cls = type(module._parameters[name])
                kwargs = module._parameters[name].__dict__
                kwargs["requires_grad"] = param.requires_grad
                module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

        def register_empty_buffer(module, name, buffer, persistent=True):
            old_register_buffer(module, name, buffer, persistent=persistent)
            if buffer is not None:
                module._buffers[name] = module._buffers[name].to(device)

        def patch_tensor_constructor(fn):
            def wrapper(*args, **kwargs):
                kwargs['device'] = device
                return fn(*args, **kwargs)
            return wrapper

        if include_buffers:
            tensor_constructors_to_patch = {
                torch_function_name: getattr(torch, torch_function_name)
                for torch_function_name in ["empty", "zeros", "ones", "full"]
            }
        else:
            tensor_constructors_to_patch = {}
        try:
            torch.nn.Module.register_parameter = register_empty_parameter
            if include_buffers:
                torch.nn.Module.register_buffer = register_empty_buffer
            for torch_function_name in tensor_constructors_to_patch.keys():
                setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
            yield
        finally:
            torch.nn.Module.register_parameter = old_register_parameter
            if include_buffers:
                torch.nn.Module.register_buffer = old_register_buffer
            for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
                setattr(torch, torch_function_name, old_torch_function)

    def freeze(self):
        """Freeze the Wan DiT (eval mode, no grads); make the VACE branch trainable."""
        self.wan_dit.eval()
        self.wan_dit.requires_grad_(False)
        self.vace_dit.train()
        self.vace_dit.requires_grad_(True)