文件最后提交记录最后更新时间
!4671 【fix】批量修改模型python版本,兼容环境上的python3.8版本 * fix python version 3 年前
init 4 年前
init 4 年前
init 4 年前
init 4 年前
init 4 年前
init 4 年前
!7376 optimize public_address_statement.md Merge pull request !7376 from 王凯宇/master 8 个月前
[众智][PyTorch]整改模型中的requirements.txt文件,删除torch,apex Signed-off-by: bailang <bailang12@h-partners.com> 3 年前
README.md

DCGAN

该目录为DCGAN在MNIST数据集上的训练与测试,主要参考实现eriklindernoren/PyTorch-GAN

DCGAN相关细节

  • 添加DDP多卡训练代码
  • 添加混合精度训练
  • 修改了训练逻辑,以解决原代码在多卡训练时的因数据被改动而无法反向传播的问题
  • 添加相关日志代码,模型保存代码

环境准备

  • 执行本样例前,请确保已安装有昇腾AI处理器的硬件环境,CANN包版本5.0.3
  • 该目录下的实现是基于PyTorch框架,其中torch版本为1.5.0+ascend.post3.20210930,使用的混合精度apex版本为0.1+ascend.20210930, torchvision版本为0.2.2.post3。

训练准备-关于数据集

  • 此DCGAN模型在MNIST的手写数据集上训练。
  • 数据集来自pytorch中的数据集接口,若指定的目录下无数据集,程序将会自动下载数据集
  • 参考一下命令下载数据集
bash test/get_mnist.sh --data_path=/path/to/data/root/
# 例:bash test/get_mnist.sh --data_path=./data/

快速运行

模型的训练文件详见main.py, 运行以下脚本能够进行单/多卡的训练和性能测试:

# data_path参数可任意指定一目录,建议为“./data/”
# train 1p performance
bash test/train_performance_1p.sh --data_path=/path/to/data/root/
# train 1p full
bash test/train_full_1p.sh --data_path=/path/to/data/root/
# 例:bash test/train_full_1p.sh --data_path=./data/

# train 8p performance
bash test/train_performance_8p.sh --data_path=/path/to/data/root/
# train 8p full
bash test/train_full_8p.sh --data_path=/path/to/data/root/

# 验证脚本使用保存的模型随机生成10张图片,
bash test/eval.sh --checkpoint_path=./checkpoint-amp-epoch_20.pth

DCGAN的训练结果

精度结果-生成图片展示

更多结果请参看./imgs文件夹

sample

性能结果

此处展示不同batchsize下,DCGAN模型在NPU下的FPS。提高batchsize的目的是在保证精度的情况下尽可能利用NPU性能。

Acc@1 FPS Npu_nums Epochs AMP_Type Batchsize
- 1008.340 1 20 O2 64
- 3099.795 1 20 O2 128
- 6258.619 8 20 O2 512
- 12742.566 8 20 O2 1024