FLUX DanceGRPO 使用指南

目录

简介

以 MindSpeed MM 仓库复现 DanceGRPO 后训练方法来帮助用户快速入门,前期需要完成代码仓、环境、数据集以及权重等准备工作,再按照说明中的启动方式启动训练,以下为具体的操作说明。

参考实现

DanceGRPO开源代码仓以及对应commit id如下:

url=https://github.com/XueZeyue/DanceGRPO
commit_id=2149f36f22db601f9dbf70472fea11576f62a0f6

环境安装

【模型开发时推荐使用配套的环境版本】

请参考安装指南

DanceGRPO场景下,Python版本推荐3.10

1. 仓库拉取

git clone --branch 26.0.0 https://gitcode.com/Ascend/MindSpeed-MM.git
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
git checkout core_v0.12.1
cp -r megatron ../MindSpeed-MM/
cd ..

cd MindSpeed-MM
mkdir -p logs data ckpt
cd ..

2. 环境搭建

# python3.10
conda create -n test python=3.10
conda activate test

# 对于x86的设备,若遇到有关torchvision的导包问题,建议优先检查环境中的torchvision版本是否为`+cpu`版本,建议使用以下源配置解决此类问题
# pip config set global.extra-index-url "https://download.pytorch.org/whl/cpu/ https://mirrors.huaweicloud.com/ascend/repos/pypi"
# 安装torch和torch_npu
pip install torch-2.7.1+cpu-cp310-cp310-*.whl
pip install torch_npu-2.7.1*.whl

# 安装加速库
git clone https://gitcode.com/Ascend/MindSpeed.git
cd MindSpeed
git checkout 5176c6f5f133111e55a404d82bd2dc14a809a6ab
cp -r mindspeed ../MindSpeed-MM/
cd ..

# 安装dance grpo依赖库
cd MindSpeed-MM
pip install -r ./examples/dancegrpo/requirements-lint.txt
cd ..

git clone https://github.com/tgxs002/HPSv2.git
cd HPSv2
git checkout 866735ecaae999fa714bd9edfa05aa2672669ee3
pip install -e . 
cd ..

3.Decord搭建

【X86版安装】

pip install decord==0.6.0

【ARM版安装】

apt方式安装请参考链接

yum方式安装请参考脚本

权重下载

创建保存权重的目录:

cd MindSpeed-MM
mkdir ckpt/flux
mkdir ckpt/hps_ckpt
cd ..

下载FLUX预训练权重 FLUX预训练权重 ,下载至MindSpeed MM工程根目录下的ckpt/flux目录中。

下载HPS-v2.1预训练权重 HPS-v2.1预训练权重 ,将其中的HPS_v2.1_compressed.pt下载至MindSpeed MM工程根目录下的ckpt/hps_ckpt目录中。

下载CLIP预训练权重 CLIP预训练权重 ,将其中的open_clip_pytorch_model.bin下载至MindSpeed MM工程根目录下的ckpt/hps_ckpt目录中。

数据集准备及处理

下载FLUX DanceGRPO使用的提示词数据集。在文件页面点击download raw file下载文件至MindSpeed MM工程根目录的data目录下。

数据集下载完成后要对数据进行预处理,在启动预处理之前,可以根据自身训练配置需要修改数据预处理脚本的配置,以FLUX模型为例:

  1. vae模型权重所在路径为LOAD_PATH,默认为ckpt/flux;
  2. 预处理后的数据集存放路径为OUTPUT_DIR,默认为data/rl_embeddings;
  3. 提示词文件路径为PROMPT_DIR,默认为data/prompts.txt。

上述注意点修改完毕后,可启动脚本进行数据预处理:

cd MindSpeed-MM
bash examples/dancegrpo/preprocess_flux_rl_embeddings.sh

处理后的数据默认会存储在MindSpeed MM根目录下的data/rl_embeddings目录中。

训练

1. 准备工作

配置脚本前需要完成前置准备工作,包括:环境安装权重下载数据集准备及处理,详情可查看对应章节。

2. 三方库修改

找到使用的Python环境的根目录,对于使用conda安装的环境,可以使用如下指令找到:

echo $(conda info --envs | grep test) | awk '{print $NF}'
  1. 将文件lib/python3.10/site-packages/diffusers/models/embeddings.pyFluxPosEmbed类的forward函数的如下代码:

    is_mps = ids.device.type == "mps"
    freqs_dtype = torch.float32 if is_mps else torch.float64
    

    修改为:

    is_mps = ids.device.type == "mps"
    is_npu = ids.device.type == "npu"
    freqs_dtype = torch.float32 if is_mps or is_npu else torch.float64
    
  2. 将文件lib/python3.10/site-packages/diffusers/models/embeddings.py中的get_1d_rotary_pos_embed函数的如下代码:

    freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
    freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
    

    修改为:

    freqs_cos = freqs.cos().T.repeat_interleave(2, dim=0).T.contiguous().float()
    freqs_sin = freqs.sin().T.repeat_interleave(2, dim=0).T.contiguous().float()
    
  3. 将文件lib/python3.10/site-packages/diffusers/models/attention_processor.pyAttention类的__init__函数的如下代码:

    elif qk_norm == "rms_norm":
        self.norm_q = RMSNorm(dim_head, eps=eps)
        self.norm_k = RMSNorm(dim_head, eps=eps)
    

    修改为:

    elif qk_norm == "rms_norm":
        self.norm_q = NpuFusedRMSNorm(dim_head, eps=eps)
        self.norm_k = NpuFusedRMSNorm(dim_head, eps=eps)
    

    增加如下类:

    class NpuFusedRMSNorm(torch.nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps
    
        def forward(self, x):
            return torch_npu.npu_rms_norm(x.to(self.weight.dtype), self.weight, epsilon=self.eps)[0]
    

3. 启动训练

以 FLUX 模型为例,在启动训练之前,可根据自身训练配置需要修改启动脚本的配置:

  1. 根据使用机器的情况,修改NNODESNPUS_PER_NODE配置, 例如单机8卡 可设置NNODES为 1 、NPUS_PER_NODE为8;
  2. 如果为多机训练,需要保证各个节点的MASTER_ADDR一致,且为其中一台节点的IP;各节点的MASTER_PORT 配置为相同端口号;从IP为MASTER_ADDR的节点开始,将各节点的NODE_RANK配置为从0开始依次递增的整数;
  3. 数据集配置信息路径为MM_DATA,默认路径为./examples/dancegrpo/data_dancegrpo.json;
  4. 模型配置信息路径为MM_MODEL,默认路径为./examples/dancegrpo/model_dancegrpo.json;
  5. DiT模型预训练权重加载路径为LOAD_PATH,默认路径为ckpt/flux,用户也可以根据自身权重存放位置进行调整;
  6. 训练权重的保存路径为SAVE_PATH,默认为save_dir;
  7. 模型训练过程的reward值保存文件的路径为HPS_REWARD_SAVE_PATH,默认为./hps_reward.txt。

在启动训练前,可根据自身训练配置需要修改数据集配置data_dancegrpo.json

  1. dataset_param.basic_parameters.data_path表示预处理数据中的元数据文件videos2caption.json的路径。

在启动训练前,可根据自身训练配置需要修改模型配置model_dancegrpo.json

  1. reward.ckpt_dir表示奖励模型预训练权重的路径。

上述注意点修改完毕后,可启动脚本开启训练:

bash examples/dancegrpo/posttrain_flux_dancegrpo.sh

注意:所有节点的代码、权重、数据等路径的层级要保持一致,且启动训练脚本的时候都位于MindSpeed MM目录下

训练完成后,会在logs目录中生成运行日志文件,生成训练reward记录文件。


性能数据

模型 机器型号 集群 任务 GBS 端到端 SPS
FLUX DanceGRPO Atlas 200T A2 Box16 1*8 微调 32 0.1123

注:此处 SPS 代表 Samples per Second。


FAQ

  1. 对于CPU型号为x86的设备,建议使用torchvision版本为0.22.1+cpu,若遇到有关torchvision的导包问题,建议优先检查环境中的torchvision版本是否为+cpu版本。