DiT for PyTorch
目录
简介
模型介绍
Scalable Diffusion Models with Transformers,是完全基于transformer架构的扩散模型,这个工作不仅将transformer成功应用在扩散模型,还探究了transformer架构在扩散模型上的scalability能力,其中最大的模型DiT-XL/2在ImageNet 256x256的类别条件生成上达到了SOTA。
支持任务列表
本仓已经支持以下模型任务类型
| 模型 | 任务列表 | 是否支持 |
|---|---|---|
| DiT-XL/2 | 训练 | ✔ |
代码实现
-
参考实现:
url=https://github.com/facebookresearch/DiT commit_id=ed81ce2229091fd4ecc9a223645f95cf379d582b -
适配昇腾 AI 处理器的实现:
url=https://gitcode.com/ascend/ModelZoo-PyTorch.git code_path=PyTorch/built-in/mlm
DiT
准备训练环境
安装环境
表 1 三方库版本支持表
| 三方库 | 支持版本 |
|---|---|
| PyTorch | 2.1.0 |
-
在模型根目录下执行以下命令,安装模型对应PyTorch版本需要的依赖。
pip install -r requirements.txt
安装昇腾环境
请参考昇腾社区中《Pytorch框架训练环境准备》文档搭建昇腾环境,本仓已支持表2中软件版本。
表 2 昇腾软件版本支持表
| 软件类型 | 支持版本 |
|---|---|
| FrameworkPTAdaper | 6.0.RC2 |
| CANN | 8.0.RC2 |
| 昇腾NPU固件 | 24.1.RC2 |
| 昇腾NPU驱动 | 24.1.RC2 |
准备预训练权重
-
联网环境下使用以下命令会自动下载stabilityai/sd-vae-ft-mse预训练模型。如果网络问题无法自动下载,需要在官网手动下载,存放在任意文件夹中,文件夹内容如下所示,并修改train.py--line174指向上述路径
Your-sd-vae-ft-mse-PATH ├── config.json ├── diffusion_pytorch_model.bin ├── diffusion_pytorch_model.safetensors ├── README.md
准备数据集
- 自行下载准备imageNet2012数据集,目录结构如下。
├── ImageNet2012
├──train
├──类别1
│──图片1
│──图片2
│ ...
├──类别2
│──图片1
│──图片2
│ ...
├──...
├──val
├──类别1
│──图片1
│──图片2
│ ...
├──类别2
│──图片1
│──图片2
│ ...
说明:
该数据集的训练过程脚本只作为一种参考示例。
快速开始
训练任务
本任务主要提供单机的8卡训练脚本。
开始训练
-
进入源码根目录。
cd /${模型文件夹名称} -
运行训练脚本。
该模型支持单机8卡训练。
- 单机8卡训练
bash test/train_8p.sh --data_path=/PATH/ImageNet2012/train --image_size=256 --global_batch_size=256 --precision=fp32 --epochs=1模型训练脚本参数说明如下。
train_8p.sh --data_path //训练数据集实际所在路径,请用户根据实际情况修改 --image_size //图片大小,支持256和512 --global_batch_size //全局batch size设置 --precision // 训练精度,支持fp32和bf16 --epochs //训练轮数
训练结果
| 芯片 | 卡数 | image size | global batch size | Precision | 性能FPS |
|---|---|---|---|---|---|
| GPU | 8p | 256 | 256 | fp32 | 432 |
| Atlas 800T A2 | 8p | 256 | 256 | fp32 | 376 |
| GPU | 8p | 256 | 512 | bf16 | 727 |
| Atlas 800T A2 | 8p | 256 | 512 | bf16 | 586 |
| GPU | 8p | 512 | 64 | fp32 | 80 |
| Atlas 800T A2 | 8p | 512 | 64 | fp32 | 77 |
| GPU | 8p | 512 | 128 | bf16 | 151 |
| Atlas 800T A2 | 8p | 512 | 128 | bf16 | 122 |
在线推理
本任务主要提供单卡推理功能。
开始推理
- 单卡推理命令
python sample.py --model DiT-XL/2 --image-size 256 --ckpt /path/to/model.pt
脚本入参说明如下。
sample.py
--model //模型结构
--image-size //图片大小,支持256和512
--ckpt //权重路径,支持官方开源权重和自己训练的权重
公网地址说明
代码涉及公网地址参考 public_address_statement.md
变更说明
2024.05.15:首次发布。
FAQ
无