
GAN Training

This implements training of a GAN on the MNIST dataset.

  • Reference implementation:
url=https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

Requirements

  • Install PyTorch (pytorch.org)
  • pip install -r requirements.txt
  • The MNIST Dataset can be downloaded from the links below.

Training

To train a model, run:

# 1p train perf
bash train_performance_1p.sh --data_path=data/mnist

# 8p train perf
bash train_performance_8p.sh --data_path=data/mnist

# 8p train full
bash train_full_8p.sh --data_path=data/mnist

# 8p eval
bash train_eval_8p.sh --data_path=data/mnist

# finetuning
bash train_finetune_1p.sh --data_path=data/mnist

After running, you can see the results in ./output.
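For scripted runs, the commands above can be wrapped in a small Python helper. This is only a sketch: the script names and the `--data_path` flag come from the commands listed above, while the mode keys (`perf_1p`, `full_8p`, and so on), the function names, and the assumption that the scripts sit in the current working directory are hypothetical choices made for illustration. The layout of `./output` is not described here, so the helper just lists whatever files it finds.

```python
import subprocess
from pathlib import Path

# Script names are taken from the commands listed above; the mode keys are
# hypothetical labels chosen for this sketch.
SCRIPTS = {
    "perf_1p": "train_performance_1p.sh",
    "perf_8p": "train_performance_8p.sh",
    "full_8p": "train_full_8p.sh",
    "eval_8p": "train_eval_8p.sh",
    "finetune_1p": "train_finetune_1p.sh",
}


def build_command(mode, data_path="data/mnist"):
    """Return the argv list for one of the documented scripts."""
    if mode not in SCRIPTS:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(SCRIPTS)}")
    return ["bash", SCRIPTS[mode], f"--data_path={data_path}"]


def list_outputs(output_dir="./output"):
    """Return all files under output_dir, sorted, or [] if it does not exist."""
    root = Path(output_dir)
    if not root.is_dir():
        return []
    return sorted(p for p in root.rglob("*") if p.is_file())


def run(mode, data_path="data/mnist"):
    """Run one script (assumed to be in the current directory), then list ./output."""
    subprocess.run(build_command(mode, data_path), check=True)
    return list_outputs()


if __name__ == "__main__":
    # Only print the command here; call run("perf_1p") to actually launch it.
    print(" ".join(build_command("perf_1p")))
```

Calling `run("perf_1p")` would launch the 1p performance script and return the files found under `./output` afterwards; `check=True` makes a non-zero exit status raise `subprocess.CalledProcessError` instead of passing silently.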

GAN training result

| Acc@1 | FPS       | Npu_nums | Epochs | AMP_Type |
| :---: | :-------: | :------: | :----: | :------: |
| -     | 1642.130  | 1        | 200    | O1       |
| -     | 15275.049 | 8        | 200    | O1       |
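To put the two FPS figures in context, the short calculation below derives the 8p speedup and scaling efficiency from them. It assumes the 8p FPS is the aggregate throughput over all 8 NPUs, which is not stated here; the function names are hypothetical.

```python
# FPS values copied from the results table above.
FPS_1P = 1642.130
FPS_8P = 15275.049
NPUS = 8


def speedup(fps_multi, fps_single):
    """Throughput ratio of the multi-device run to the single-device run."""
    return fps_multi / fps_single


def scaling_efficiency(fps_multi, fps_single, n_devices):
    """Speedup divided by device count; 1.0 means perfectly linear scaling."""
    return speedup(fps_multi, fps_single) / n_devices


if __name__ == "__main__":
    print(f"speedup: {speedup(FPS_8P, FPS_1P):.2f}x")                          # speedup: 9.30x
    print(f"efficiency: {scaling_efficiency(FPS_8P, FPS_1P, NPUS):.1%}")       # efficiency: 116.3%
```

An efficiency above 100% may mean the 1p and 8p runs were measured under different settings (for example, a different per-device batch size); the table does not give enough detail to tell.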

Statement

For details about the public network addresses referenced by the code in this repository, see the file public_address_statement.md.