Developing a New Model
===========================
Last updated: 12/08/2025. Author: cxiaolong
This document describes how to use the MindSpeed-MM FSDP2 backend to develop a training demo for a DiT-based video generation model. For the complete code, see `develop a new model in MM <https://gitcode.com/cxiaolong/MindSpeed-MM_demo/commit/9adc91fe4430a68a30fb310ab2402ff309f3a745>`_.
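Development Workflow
--------------------
The workflow has seven steps: environment setup, dataset construction, model construction, configuration files, training entry point, training script, and launching training.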
.. image:: ../_static/dev_guide/new_model_dev/flow.png
:width: 1200px
:align: center
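Step1: Environment Setup
------------------------
1. Prepare the hardware: Ascend A3 or A2 accelerator cards, preferably 8 or more.
2. Follow the `Environment Setup <https://mindspeed-mm.readthedocs.io/zh-cn/latest/quick_start/%E7%8E%AF%E5%A2%83%E6%90%AD%E5%BB%BA.html>`_ section to install the base environment: HDK, CANN, PyTorch and TorchNPU.
3. Create and activate a Python virtual environment: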
.. code:: bash
conda create -n [env_name] python=3.10
conda activate [env_name]
# Install the te package shipped with CANN (latest); this also installs required dependencies such as sympy and decorator
pip install [CANN_HOME_PATH]/ascend-toolkit/latest/lib64/te-*-py3-none-any.whl
4. Install MindSpeed-MM and its dependencies as follows:
.. code:: bash
git clone --branch 26.0.0 https://gitcode.com/Ascend/MindSpeed-MM.git
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
git checkout core_v0.12.1
cp -r megatron ../MindSpeed-MM/
cd ..
cd MindSpeed-MM
mkdir logs data ckpt
# Install the acceleration library (MindSpeed)
git clone https://gitcode.com/Ascend/MindSpeed.git
cd MindSpeed
# Check out the MindSpeed commit that matches core_r0.12.1
git checkout 93c45456c7044bacddebc5072316c01006c938f9
# Install MindSpeed and its dependencies
pip install -e .
cd ..
# Install MindSpeed-MM and its dependencies
pip install -e .
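
After installation, a quick sanity check can confirm that the key packages are importable from the active environment. The sketch below is a generic helper built on ``importlib.util.find_spec``; the module names it checks (``torch``, ``torch_npu``, ``megatron``, ``mindspeed``, ``mindspeed_mm``) are the import names assumed for the packages installed above.

.. code:: python

    import importlib.util


    def check_modules(names):
        """Return {module_name: bool}, True when the top-level module can be found."""
        return {name: importlib.util.find_spec(name) is not None for name in names}


    if __name__ == "__main__":
        # Assumed import names of the packages installed in Step1
        required = ["torch", "torch_npu", "megatron", "mindspeed", "mindspeed_mm"]
        for name, found in check_modules(required).items():
            print(f"{name:<14} {'OK' if found else 'MISSING'}")

A ``MISSING`` entry points to an installation step above that did not complete in the current environment.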
Step2: Dataset Construction
---------------------------
Prepare the demo data
^^^^^^^^^^^^^^^^^^^^^
Prepare the demo training data and organize it as follows:
.. code:: text
data
├── data.json
└── videos
    ├── video000001.mp4
    ├── video000002.mp4
    └── ...
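``videos/`` holds the video files, and ``data.json`` holds every video-text pair in the dataset. ``path`` is resolved relative to the dataset folder, so the folder and annotation file may have any name as long as ``data_folder`` and ``json_path`` in the data configuration (Step4) point to them. An example: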
.. code:: json
[
{
"path": "./videos/video000001.mp4",
"cap": "A scenic view of mountains during sunrise.",
"num_frames": 81,
"fps": 24,
"resolution": {"height": 480, "width": 832}
},
{
"path": "./videos/video000002.mp4",
"cap": "A bustling city street with people walking and cars passing by.",
"num_frames": 81,
"fps": 24,
"resolution": {"height": 480, "width": 832}
},
...
]
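
Because ``CustomT2VDataset`` below reads only ``path`` and ``cap`` and truncates each clip to ``num_frames`` frames, a malformed entry or a clip that is too short usually surfaces only at training time (clips shorter than ``num_frames`` yield tensors of different lengths that cannot be batched together). The following is a minimal pre-flight check under those assumptions. It trusts the declared ``num_frames`` metadata instead of decoding videos, and the function name and minimum-frame rule are illustrative, not part of MindSpeed-MM.

.. code:: python

    import json
    import os


    def validate_annotations(json_path, data_folder, min_frames=49):
        """Return a list of human-readable problems found in a video-text annotation file."""
        with open(json_path, encoding="utf-8") as f:
            samples = json.load(f)
        problems = []
        for i, sample in enumerate(samples):
            # The fields the custom dataset actually reads
            for key in ("path", "cap"):
                if key not in sample:
                    problems.append(f"sample {i}: missing '{key}'")
            # Paths are joined onto data_folder, mirroring os.path.join in the dataset
            path = sample.get("path")
            if path is not None and not os.path.isfile(os.path.join(data_folder, path)):
                problems.append(f"sample {i}: file not found: {path}")
            # Declared frame count must cover the frames the dataset keeps
            frames = sample.get("num_frames")
            if frames is not None and frames < min_frames:
                problems.append(f"sample {i}: only {frames} frames (< {min_frames})")
        return problems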
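Building the Dataset
^^^^^^^^^^^^^^^^^^^^
You can use one of the existing T2V datasets in the MindSpeed-MM repository or define your own; this tutorial uses a custom ``CustomT2VDataset``.
Create ``mindspeed_mm/data/datasets/custom_t2v_dataset.py``: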
.. code:: python
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from transformers import AutoTokenizer
from mindspeed_mm.data.data_utils.data_transform import ResizeVideo, ToTensorVideo, CenterCropResizeVideo
from mindspeed_mm.data.data_utils.utils import TextProcesser
class CustomT2VDataset(Dataset):
def __init__(
self,
data_folder,
json_path,
tokenizer_config,
num_frames=49,
max_height=480,
max_width=832,
**kwargs
):
# Extra config keys are accepted via **kwargs but not forwarded: Dataset.__init__ takes no arguments
super().__init__()
self.data_samples = pd.read_json(json_path)
self.data_folder = data_folder
self.num_frames = num_frames
# Initialize tokenizer and text processor
tokenizer = AutoTokenizer.from_pretrained(**tokenizer_config)
self.text_processer = TextProcesser(
tokenizer=tokenizer,
text_preprocess_methods=[{"method": "basic_clean"},{"method": "whitespace_clean"}],
)
# Initialize video transforms
self.video_transforms = transforms.Compose([
ResizeVideo(
transform_size={"max_height": max_height, "max_width": max_width},
interpolation_mode="bilinear",
antialias=True,
mode="shortside"
),
ToTensorVideo(),
transforms.Normalize(mean=0.5, std=0.5),
CenterCropResizeVideo(transform_size={"max_height": max_height, "max_width": max_width}, antialias=True)
])
def __getitem__(self, index):
sample = self.data_samples.iloc[index]
file_name = sample["path"]
captions = sample["cap"]
video_path = os.path.join(self.data_folder, file_name)
prompt_ids, prompt_mask = self.text_processer(captions)
video_tensor = torchvision.io.read_video(video_path, pts_unit="sec", output_format="TCHW")[0][:self.num_frames]
video_tensor = self.video_transforms(video_tensor)
video_tensor = video_tensor.permute(1, 0, 2, 3) # TCHW -> CTHW
return {
"video": video_tensor,
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask
}
def __len__(self):
return len(self.data_samples)
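Building the DataLoader
^^^^^^^^^^^^^^^^^^^^^^^
MindSpeed-MM provides a range of DataLoader components. The entry point is ``build_mm_dataloader`` in ``mindspeed_mm/data``, and three types are currently available:
* ``base``: wraps the native ``torch.utils.data.DataLoader``
* ``sampler``: a DataLoader built around a custom sampler for distributed training
* ``variable``: a DataLoader that supports dynamic resolution
This tutorial uses the ``sampler`` DataLoader.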
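
To illustrate what a distributed ``sampler`` DataLoader has to do, the sketch below follows the sequential pretraining-sampler pattern used in Megatron-LM: each global step consumes ``micro_batch_size * dp_size`` consecutive indices, every data-parallel rank keeps its own slice, and iteration resumes from ``consumed_samples``. It is a simplified, hypothetical stand-in, not the sampler that ``build_mm_dataloader`` actually constructs.

.. code:: python

    def pretraining_batches(total_samples, consumed_samples, micro_batch_size,
                            dp_rank, dp_size, drop_last=True):
        """Yield the micro-batch indices owned by one data-parallel rank."""
        global_batch = micro_batch_size * dp_size  # indices consumed per step across all ranks
        start, end = dp_rank * micro_batch_size, (dp_rank + 1) * micro_batch_size
        batch = []
        for idx in range(consumed_samples, total_samples):
            batch.append(idx)
            if len(batch) == global_batch:
                yield batch[start:end]
                batch = []
        # A trailing partial step is kept only when drop_last is False
        if batch and not drop_last:
            yield batch[start:end]

For example, with 10 samples and two ranks, rank 0 sees indices 0, 2, 4, ... and rank 1 sees 1, 3, 5, ...; restarting with ``consumed_samples=4`` skips the first two steps without reshuffling.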
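Step3: Model Construction
-------------------------
MindSpeed-MM provides ``SoRAModel`` as the composite class for all diffusion-based video generation models; the class relationships are shown below. ``SoRAModel`` can be instantiated as concrete models such as Wan and HunyuanVideo, and is composed of four parts: TextEncoder, PredictModel, DiffusionModel and AEModel.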
.. image:: ../_static/dev_guide/new_model_dev/sora_model.png
:width: 800px
:align: center
This tutorial builds a new ``CustomModel`` as the composite class for a custom video generation model. It consists of four parts: ``PredictModel(CustomDiT)``, ``TextEncoder(UMT5)``, ``AEModel(WanVideoVAE)`` and ``DiffusionModel(WanFlowMatchScheduler)``.
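Building CustomModel
^^^^^^^^^^^^^^^^^^^^^
Create ``mindspeed_mm/models/custom_model.py`` (the training entry point in Step5 imports it from ``mindspeed_mm.models.custom_model``):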
.. code:: python
import torch
from torch import nn
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed_mm.models.ae import AEModel
from mindspeed_mm.models.diffusion import DiffusionModel
from mindspeed_mm.models.text_encoder import TextEncoder
from mindspeed_mm.models.predictor import PredictModel
from mindspeed_mm.models.transformers.base_model import FSDP2Mixin, WeightInitMixin
class CustomModel(nn.Module, FSDP2Mixin, WeightInitMixin):
def __init__(self, config):
super().__init__()
args = get_args()
self.config = core_transformer_config_from_args(args)
self.ae = AEModel(config.ae).eval()
self.ae.requires_grad_(False)
self.text_encoder = TextEncoder(config.text_encoder).eval()
self.text_encoder.requires_grad_(False)
self.diffusion = DiffusionModel(config.diffusion).get_model()
self.predictor = PredictModel(config.predictor).get_model()
def forward(self, video, prompt_ids, prompt_mask=None):
# encode vision and text
with torch.no_grad():
latents, _ = self.ae.encode(video)
prompt_embeds, prompt_mask = self.text_encoder.encode(prompt_ids, prompt_mask)
# q sample to add noise
noised_latents, noise, timesteps = self.diffusion.q_sample(latents)
# Diffusion Transformer forward to predict
output = self.predictor(noised_latents, timesteps, prompt_embeds)
loss = self._compute_loss(
output,
latents,
noised_latents,
timesteps,
noise
)
return loss
def _compute_loss(self, model_output, latents, noised_latents, timesteps, noise):
"""compute diffusion loss"""
loss_dict = self.diffusion.training_losses(
model_output=model_output,
x_start=latents,
x_t=noised_latents,
noise=noise,
t=timesteps
)
return loss_dict
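def train(self, mode=True):
# Only the predictor switches modes; the frozen AE and text encoder stay in eval mode
self.training = mode
self.predictor.train(mode)
return self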
def state_dict(self):
"""Customized state_dict for fsdp2"""
return self.predictor.state_dict()
def set_input_tensor(self, input_tensor):
self.input_tensor = input_tensor
self.predictor.set_input_tensor(input_tensor)
.. note::
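CustomModel must inherit from ``FSDP2Mixin`` and ``WeightInitMixin`` to gain FSDP2 and weight-initialization support.
* FSDP2Mixin: provides the basic ``fully_shard`` capability;
* WeightInitMixin: provides weight initialization; after FSDP2 initializes the model on the ``meta`` device, the weights must be re-initialized.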
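
The re-initialization mentioned in the note can be sketched with plain PyTorch (2.0 or later): parameters created under ``torch.device("meta")`` have shapes but no storage, ``Module.to_empty`` allocates uninitialized memory on a real device, and the weights must then be initialized explicitly. This only illustrates the idea; it is not the code of ``WeightInitMixin``, and the Xavier/zero initialization is an arbitrary choice for the example.

.. code:: python

    import torch
    from torch import nn


    def build_on_meta():
        # Parameters get shapes and dtypes but no storage, so construction is nearly free
        with torch.device("meta"):
            return nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))


    def materialize_and_init(model, device="cpu"):
        # Allocate real but uninitialized storage, then overwrite it with defined values
        model.to_empty(device=device)
        for module in model.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
        return model

With FSDP2, sharding (``fully_shard``) is typically applied between these two steps, so each rank only materializes and initializes its own shard.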
Building CustomDiT
^^^^^^^^^^^^^^^^^^^^^
Following the WanVideo model architecture, implement a custom DiT as follows.
Create ``mindspeed_mm/models/predictor/dits/custom_dit.py``:
.. code:: python
from typing import Tuple, Dict, Union
import math
from einops import rearrange
import torch
from torch import nn
from diffusers.models.attention import FeedForward
from diffusers.models.transformers.transformer_wan import (
WanRotaryPosEmbed,
WanTimeTextImageEmbedding,
WanAttention,
WanAttnProcessor
)
from mindspeed_mm.models.common.module import MultiModalModule
class CustomDiT(MultiModalModule):
def __init__(
self,
num_layers: int = 40,
num_heads: int = 40,
head_dim: int = 128,
patch_size: Tuple[int, ...] = (1, 2, 2),
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
rope_max_seq_len: int = 1024,
**kwargs
) -> None:
super().__init__(config=None)
self.patch_size = patch_size
self.hidden_size = num_heads * head_dim
out_channels = out_channels or in_channels
# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, self.hidden_size, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
self.condition_embedder = WanTimeTextImageEmbedding(
dim=self.hidden_size,
time_freq_dim=freq_dim,
time_proj_dim=self.hidden_size * 6,
text_embed_dim=text_dim,
)
# 3. DiT blocks
self.blocks = nn.ModuleList(
[
CustomDiTBlock(
hidden_size=self.hidden_size,
num_heads=num_heads,
head_dim=head_dim,
ffn_dim=ffn_dim,
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = nn.LayerNorm(self.hidden_size, elementwise_affine=False)
self.proj_out = nn.Linear(self.hidden_size, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, self.hidden_size) / self.hidden_size ** 0.5)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
# 1. Get conditioning embeddings
temb, timestep_proj, encoder_hidden_states, _ = self.condition_embedder(
timestep=timestep,
encoder_hidden_states=encoder_hidden_states.squeeze(1)
)
timestep_proj = timestep_proj.unflatten(1, (6, -1)) # [batch_size, 6, inner_dim]
rotary_emb = self.rope(hidden_states)
# 2. Patch embedding
hidden_states = self.patch_embedding(hidden_states.to(temb.dtype))
batch_size, _, frames, height, width = hidden_states.shape
# 3. Patchify
hidden_states = self.patchify(hidden_states)
# 4. Transformer blocks
for block in self.blocks:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=timestep_proj,
rotary_emb=rotary_emb
)
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift)
hidden_states = self.proj_out(hidden_states)
# Unpatchify
return self.unpatchify(hidden_states, frames, height, width)
def patchify(self, embs: torch.Tensor):
# b c f h w -> b (f h w) c
patch_out = rearrange(embs, "b c f h w -> b (f h w) c").contiguous()
return patch_out
def unpatchify(self, embs: torch.Tensor, frames: int, height: int, width: int):
# b (f h w) (p0 p1 p2 c) -> b c (f*p0) (h*p1) (w*p2)
return rearrange(
embs,
"b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
f=frames,
h=height,
w=width,
x=self.patch_size[0],
y=self.patch_size[1],
z=self.patch_size[2],
)
class CustomDiTBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
ffn_dim: int
):
super().__init__()
# 1. Self-attention
self.norm1 = nn.LayerNorm(hidden_size)
self.attn1 = WanAttention(hidden_size, num_heads, head_dim, cross_attention_dim_head=None, processor=WanAttnProcessor())
# 2. Cross-attention (with its own pre-norm, as in the Wan block)
self.norm2 = nn.LayerNorm(hidden_size)
self.attn2 = WanAttention(hidden_size, num_heads, head_dim, cross_attention_dim_head=hidden_size // num_heads, processor=WanAttnProcessor())
# 3. Feed-forward
self.ffn = FeedForward(hidden_size, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = nn.LayerNorm(hidden_size)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
# Split scale-shift table into shift/scale/gate for self-attention (*_msa) and feed-forward (c_*)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.to(temb.device) + temb
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa)
attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb)
hidden_states = (hidden_states + attn_output * gate_msa)
# 2. Cross-attention: normalize the updated hidden states instead of reusing the self-attention input
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states + ff_output * c_gate_msa)
return hidden_states
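
The shape bookkeeping in ``CustomDiT.forward`` (``patch_embedding``, ``patchify``, ``proj_out``, ``unpatchify``) can be checked without building the model. The pure-Python sketch below mirrors that arithmetic with the model.json values from Step4 (12 heads of dimension 128, ``patch_size`` ``(1, 2, 2)``, 16 output channels). The example latent shape ``(1, 16, 13, 60, 104)`` is an assumption: it corresponds to 49 frames at 480x832 if the VAE compresses 4x in time and 8x in space.

.. code:: python

    from math import prod


    def dit_shapes(latent_shape, patch_size=(1, 2, 2), num_heads=12, head_dim=128, out_channels=16):
        """Trace tensor shapes through CustomDiT.forward for a latent of shape (B, C, F, H, W)."""
        b, _, f, h, w = latent_shape
        pt, ph, pw = patch_size
        if f % pt or h % ph or w % pw:
            raise ValueError("latent F/H/W must be divisible by patch_size")
        hidden = num_heads * head_dim
        gf, gh, gw = f // pt, h // ph, w // pw  # grid after the strided Conv3d
        tokens = gf * gh * gw
        return {
            "patch_embedding": (b, hidden, gf, gh, gw),
            "patchify": (b, tokens, hidden),
            "proj_out": (b, tokens, out_channels * prod(patch_size)),
            "unpatchify": (b, out_channels, gf * pt, gh * ph, gw * pw),
        }

The last entry reproduces the input latent shape, which is what the flow-matching loss needs in order to compare the prediction against targets of the same shape.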
.. note::
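After building CustomDiT, register it in ``PREDICTOR_MODEL_MAPPINGS`` of PredictModel, under the key used as the predictor ``model_id`` in model.json (``custom_dit`` in this tutorial).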
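
How ``PredictModel`` consumes ``PREDICTOR_MODEL_MAPPINGS`` is not shown in this tutorial. A common pattern, assumed here for illustration only, is a dictionary from the ``model_id`` string in model.json to a model class, with the remaining predictor keys passed as constructor arguments. The sketch uses a hypothetical stand-in class and helper (``DummyDiT``, ``build_predictor``) rather than the real MindSpeed-MM code.

.. code:: python

    class DummyDiT:
        """Hypothetical stand-in for CustomDiT that just records its arguments."""

        def __init__(self, num_layers=40, num_heads=40, **kwargs):
            self.num_layers = num_layers
            self.num_heads = num_heads
            self.extra = kwargs


    # Hypothetical registry: predictor model_id in model.json -> predictor class
    PREDICTOR_MODEL_MAPPINGS = {"custom_dit": DummyDiT}


    def build_predictor(predictor_config):
        config = dict(predictor_config)  # copy so the caller's config is not mutated
        model_id = config.pop("model_id")
        if model_id not in PREDICTOR_MODEL_MAPPINGS:
            raise KeyError(f"unknown predictor model_id: {model_id!r}")
        return PREDICTOR_MODEL_MAPPINGS[model_id](**config)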
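Step4: Configuration Files
--------------------------
The configuration consists of the model configuration (model.json), the data configuration (data.json) and the FSDP2 configuration (fsdp2.yaml); for the meaning of each field, see the `Configuration Guide <https://mindspeed-mm.readthedocs.io/zh-cn/latest/config/基础配置.html>`_. Here the diffusion scheduler, AE and text encoder reuse the Wan2.2 models.
Create ``examples/custom_model/model.json``, ``examples/custom_model/data.json`` and ``examples/custom_model/fsdp2.yaml``.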
model.json
^^^^^^^^^^^^^
.. code:: json
{
"diffusion": {
"model_id": "wan_flow_match_scheduler",
"num_train_timesteps": 1000,
"shift": 5,
"sigma_min": 0.0,
"extra_one_step": true
},
"predictor": {
"model_id": "custom_dit",
"num_layers": 30,
"num_heads": 12,
"head_dim": 128,
"patch_size": [1, 2, 2],
"in_channels": 16,
"out_channels": 16,
"text_dim": 4096,
"freq_dim":256,
"ffn_dim": 8960,
"rope_max_seq_len": 1024
},
"text_encoder": {
"model_id": "UMT5",
"hub_backend": "hf",
"from_pretrained": "./ckpt/Wan-AI/Wan2.2-T2V-A14B-Diffusers/text_encoder/",
"dtype": "bf16"
},
"ae": {
"model_id": "wan_video_vae",
"from_pretrained": "./ckpt/Wan-AI/Wan2.2-T2V-A14B-Diffusers/vae/",
"dtype": "bf16",
"enable_tiling": false,
"tiling_param": {
"tile_sample_min_height": 256,
"tile_sample_min_width": 256,
"tile_sample_stride_height": 192,
"tile_sample_stride_width": 192
},
"norm_latents": true,
"norm_mode": "channel_specified_shift_scale",
"do_sample": false
}
}
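
The ``diffusion`` block configures ``wan_flow_match_scheduler`` with ``shift: 5``. As a rough illustration of what ``q_sample`` and ``training_losses`` compute in ``CustomModel.forward``, the sketch below assumes the standard rectified-flow formulation with the timestep shift commonly used by Wan-style schedulers: ``sigma' = shift * sigma / (1 + (shift - 1) * sigma)``, ``x_t = (1 - sigma') * x_0 + sigma' * noise``, and a velocity target ``noise - x_0``. The real ``WanFlowMatchScheduler`` may differ in timestep sampling, loss weighting and the handling of ``sigma_min`` and ``extra_one_step``; plain lists stand in for tensors.

.. code:: python

    def shift_sigma(sigma, shift=5.0):
        """Warp a noise level in [0, 1] toward 1 (noisier) when shift > 1."""
        return shift * sigma / (1.0 + (shift - 1.0) * sigma)


    def q_sample(x0, noise, sigma):
        """Linearly interpolate between clean data (sigma=0) and pure noise (sigma=1)."""
        return [(1.0 - sigma) * a + sigma * n for a, n in zip(x0, noise)]


    def flow_matching_loss(model_output, x0, noise):
        """Mean squared error between the predicted and the target velocity (noise - x0)."""
        target = [n - a for a, n in zip(x0, noise)]
        return sum((o - t) ** 2 for o, t in zip(model_output, target)) / len(target)

With ``shift=5`` a uniformly sampled ``sigma=0.5`` becomes about 0.83, so training spends more steps at high noise levels, which is the usual motivation for shifting at high resolutions.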
data.json
^^^^^^^^^^^^^
.. code:: json
{
"dataset_param": {
"data_folder": "./data/video_demo_dataset/",
"json_path": "./data/video_demo_dataset/annotation.json",
"num_frames": 49,
"max_height": 480,
"max_width": 832,
"tokenizer_config": {
"pretrained_model_name_or_path": "./ckpt/Wan-AI/Wan2.2-T2V-A14B-Diffusers/tokenizer/",
"model_max_length": 512
}
},
"dataloader_param": {
"dataloader_mode": "sampler",
"sampler_type": "SequentialSampler",
"drop_last": true,
"pin_memory": true,
"group_frame": false,
"group_resolution": false
}
}
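
``num_frames``, ``max_height`` and ``max_width`` determine the latent size and therefore the DiT sequence length. The sketch below assumes a Wan2.1-style VAE (the one bundled with Wan2.2-T2V-A14B) with 4x temporal and 8x spatial compression, which requires ``num_frames = 4k + 1``, and the ``patch_size`` of ``[1, 2, 2]`` from model.json; treat these factors as assumptions to verify against the VAE you actually load.

.. code:: python

    def latent_and_token_shape(num_frames, height, width, t_down=4, s_down=8,
                               patch_size=(1, 2, 2), z_dim=16):
        """Return ((C, F, H, W) of the latent, DiT sequence length) for one clip."""
        if (num_frames - 1) % t_down:
            raise ValueError(f"num_frames must be {t_down}k + 1, got {num_frames}")
        if height % (s_down * patch_size[1]) or width % (s_down * patch_size[2]):
            raise ValueError("height/width must be divisible by VAE stride * patch size")
        # The first frame is encoded alone; each further group of t_down frames adds one latent frame
        latent = (z_dim, (num_frames - 1) // t_down + 1, height // s_down, width // s_down)
        _, f, h, w = latent
        seq_len = (f // patch_size[0]) * (h // patch_size[1]) * (w // patch_size[2])
        return latent, seq_len

With this configuration (49 frames at 480x832) the predictor sees 20280 tokens per sample; the 81-frame clips in the demo data would give 32760 if used in full, which is why ``num_frames`` is the main memory knob.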
fsdp2.yaml
^^^^^^^^^^^^^
.. code:: yaml
sharding_size: 8
sub_modules_to_wrap:
- predictor.blocks.{*}
- predictor.proj_out
reshard_after_forward: True
param_dtype: "bf16"
reduce_dtype: "fp32"
ignored_modules:
- ae
- text_encoder
cast_forward_inputs: True
recompute_modules:
- predictor.blocks.{*}
offload_to_cpu: False
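
``sub_modules_to_wrap`` and ``recompute_modules`` select submodules by their dotted names. The exact matching rules are defined by MindSpeed-MM; the sketch below only assumes that ``{*}`` stands for one name component (such as a block index) and shows how such patterns would resolve against names that ``CustomDiT`` registers. It also shows that a pattern for a submodule that does not exist (for example ``predictor.head``) matches nothing.

.. code:: python

    import re


    def pattern_to_regex(pattern):
        # Escape literal dots, then let "{*}" match exactly one dotted name component
        escaped = re.escape(pattern).replace(re.escape("{*}"), r"[^.]+")
        return re.compile(rf"^{escaped}$")


    def match_modules(patterns, module_names):
        """Map each pattern to the module names it selects (assumed semantics)."""
        regexes = [pattern_to_regex(p) for p in patterns]
        return {p: [n for n in module_names if r.match(n)] for p, r in zip(patterns, regexes)}

In a real run the candidate names would come from ``model.named_modules()``.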
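Step5: Training Entry Point
---------------------------
Create ``pretrain_custom.py`` in the MindSpeed-MM root directory, which is where the training script in Step6 launches it from. The entry point follows the Megatron paradigm and mainly needs the following functions: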
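================================== ===========
Function                           Description
================================== ===========
model_provider                     Builds the model and returns the instance
get_batch                          Fetches a batch of input data; called in forward_step
loss_func                          Computes the loss reported in the logs and passes the actual loss on to Megatron for further processing
forward_step                       Runs the model forward pass and returns a tuple (model forward output, loss_func)
train_valid_test_datasets_provider Builds the data loaders
================================== ===========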
.. code:: python
import torch
import mindspeed.megatron_adaptor
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.training import get_args, print_rank_0
from megatron.training.utils import average_losses_across_data_parallel_group
from mindspeed_mm.configs.config import mm_extra_args_provider
from mindspeed_mm.training import pretrain
from mindspeed_mm.data import build_mm_dataloader
from mindspeed_mm.data.datasets.custom_t2v_dataset import CustomT2VDataset
from mindspeed_mm.data.data_utils.utils import build_iterations
from mindspeed_mm.models.custom_model import CustomModel
def model_provider(pre_process=True, post_process=True):
"""Builds custom model."""
args = get_args()
print_rank_0("building Custom model ...")
model = CustomModel(args.mm.model)
return model
def get_batch(data_iterator):
"""Generate a batch."""
if data_iterator is not None:
batch = next(data_iterator, None)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(torch.cuda.current_device())
return batch
else:
return None
def loss_func(output_tensor):
"""Loss function."""
loss = output_tensor.mean()
reporting_loss = average_losses_across_data_parallel_group([loss])
loss = loss.unsqueeze(0)
return loss, {"loss": reporting_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
batch = get_batch(data_iterator)
output_tensor = model(**batch)
return output_tensor, loss_func
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
data_config = args.mm.data
train_dataset = CustomT2VDataset(**data_config.dataset_param.to_dict())
train_dataloader = build_mm_dataloader(
train_dataset,
data_config.dataloader_param,
process_group=mpu.get_data_parallel_group(),
consumed_samples=args.consumed_train_samples,
dataset_param=data_config.dataset_param,
)
data_iterator, _, _ = build_iterations(train_dl=train_dataloader)
return data_iterator, None, None
if __name__ == "__main__":
train_valid_test_datasets_provider.is_distributed = True
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
extra_args_provider=mm_extra_args_provider,
args_defaults={"dataloader_type": "external", "vision_pretraining": False, "curr_forward_iteration": 0},
)
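
The division of labor in the table above can be seen in a stripped-down, framework-free loop: the trainer calls ``forward_step``, receives the output together with ``loss_func``, and calls ``loss_func`` to obtain the loss to backpropagate plus a dictionary for logging. This toy loop (``toy_train_loop``, ``toy_forward_step``, ``toy_loss_func``) is hypothetical and omits everything Megatron actually does around it (pipeline scheduling, gradient accumulation, data-parallel averaging, backward and optimizer steps).

.. code:: python

    def toy_loss_func(output):
        # Plays the role of loss_func: return (loss to optimize, metrics to report)
        loss = sum(output) / len(output)
        return loss, {"loss": loss}


    def toy_forward_step(data_iterator, model):
        # Plays the role of forward_step: run the model and hand back the loss function
        batch = next(data_iterator)
        output = model(**batch)
        return output, toy_loss_func


    def toy_train_loop(forward_step, data_iterator, model, num_steps):
        logs = []
        for _ in range(num_steps):
            output, loss_fn = forward_step(data_iterator, model)
            loss, report = loss_fn(output)
            # A real trainer would call loss.backward() and step the optimizer here
            logs.append(report)
        return logs

Passing ``loss_func`` back instead of calling it inside ``forward_step`` lets the framework decide where the loss is evaluated, for example only on the last pipeline stage.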
Step6: Training Script
----------------------
Create ``examples/custom_model/pretrain.sh``:
.. code:: bash
#!/bin/bash
source /usr/local/Ascend/cann/set_env.sh
export CUDA_DEVICE_MAX_CONNECTIONS=2 # must not be set to 1 when FSDP2 is enabled
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6007
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
MBS=1
GRAD_ACC_STEP=1
DP=$(($WORLD_SIZE))
GBS=$(($MBS*$GRAD_ACC_STEP*$DP))
MM_DATA="./examples/custom_model/data.json"
MM_MODEL="./examples/custom_model/model.json"
MM_TOOL="./mindspeed_mm/tools/tools.json"
fsdp2_config="./examples/custom_model/fsdp2.yaml"
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
GPT_ARGS="
--micro-batch-size ${MBS} \
--global-batch-size ${GBS} \
--num-workers 8 \
--lr 1e-5 \
--min-lr 1e-5 \
--adam-beta1 0.9 \
--adam-beta2 0.999 \
--adam-eps 1e-8 \
--lr-decay-style constant \
--weight-decay 1e-2 \
--lr-warmup-init 0 \
--lr-warmup-iters 0 \
--clip-grad 1.0 \
--train-iters 5000 \
--no-gradient-accumulation-fusion \
--use-torch-fsdp2 \
--fsdp2-config-path ${fsdp2_config} \
--untie-embeddings-and-output-weights \
"
MM_ARGS="
--mm-data $MM_DATA \
--mm-model $MM_MODEL \
--mm-tool $MM_TOOL
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 10000 \
--eval-interval 10000 \
--eval-iters 10 \
--ckpt-format torch_dcp \
"
logfile=$(date +%Y%m%d)_$(date +%H%M%S)
mkdir -p logs
torchrun $DISTRIBUTED_ARGS pretrain_custom.py \
$GPT_ARGS \
$MM_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
2>&1 | tee logs/train_${logfile}.log
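
The batch-size variables at the top of the script follow Megatron's relation ``GBS = MBS * GRAD_ACC_STEP * DP``, where DP equals the world size here because the script configures only FSDP2 data parallelism. The helpers below (hypothetical, for checking values before launch) reproduce that arithmetic and derive the number of gradient-accumulation steps from a chosen global batch size; Megatron rejects a global batch size that is not divisible by ``MBS * DP``.

.. code:: python

    def batch_sizes(mbs, grad_acc_steps, npus_per_node, nnodes):
        world_size = npus_per_node * nnodes
        dp = world_size  # pure data parallelism (FSDP2), as in the script
        return {"world_size": world_size, "dp": dp, "gbs": mbs * grad_acc_steps * dp}


    def grad_acc_for(gbs, mbs, dp):
        # Megatron requires GBS to be a multiple of MBS * DP
        if gbs % (mbs * dp):
            raise ValueError("GBS must be divisible by MBS * DP")
        return gbs // (mbs * dp)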
Step7: Launch Training
----------------------
Run the following script to start training:
.. code:: bash
bash examples/custom_model/pretrain.sh