GAN 训练

This implements training of RDN on the DIV2K_x2 dataset.

  • Reference implementation:
url=https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

Requirements

  • Install PyTorch (pytorch.org)
  • pip install -r requirements.txt
  • The MNIST Dataset can be downloaded from the links below.

Training

To train a model, run:

# 1p train perf
bash train_performance_1p.sh --data_path=data/mnist

# 8p train perf
bash train_performance_8p.sh --data_path=data/mnist

# 8p train full
bash train_full_8p.sh --data_path=data/mnist

# 8p eval
bash train_eval_8p.sh --data_path=data/mnist

# finetuning
bash train_finetune_1p.sh --data_path=data/mnist

After running,you can see the results in ./output

GAN training result

Acc@1 FPS Npu_nums Epochs AMP_Type
- 1642.130 1 200 O1
- 15275.049 8 200 O1

Statement

For details about the public address of the code in this repository, you can get from the file public_address_statement.md