SRGAN
This implements training of SRGAN on VOC2012 dataset, mainly modified from leftthomas/SRGAN. The Train dataset has 16700 images and The Val dataset has 425 images. Download the datasets from here(access code:5tzp).
Requirements
创建新的conda环境,安装下面列出的包。(@后面为包的版本信息)
- apex @20210930
- torch @20210930
pip install -r requirements.txt(若以安装高版本的对应的库,可跳过) 注:pillow建议安装较新版本, 与之对应的torchvision版本如果无法直接安装,可使用源码安装对应的版本,源码参考链接:https://github.com/pytorch/vision ,建议Pillow版本是9.1.0 torchvision版本是0.6.0- CANN @5.0.3
数据集准备
Train、Val Dataset
训练数据集有16700张图像,验证数据集有425张图像。可以通过这里进行 (5tzp)下载。创建一个名为 data 的文件夹,将下载好的数据集解压到该文件夹中。
Test Image Dataset
测试数据集共有9张图片,其中5张为 Set5 中的所有图片,其余四张为仓库中的示例图片。测试结果以Set5的结果为准。数据集可以在这里 (k9cj) 下载。将下载好的数据集同样解压到 data 文件夹中
最终数据集目录结构
data
|-- test
| |-- SRF_2
| | |-- target
| | |-- data
|-- VOC2012
| |-- train
| |-- val
Train
单p训练
注:若脚本不能正常运行,可以尝试使用 dos2unix test/* 命令转换后运行。
注:data_path可以是绝对路径,或者相对路径,由于脚本内会cd到test目录下,故相对路径为相对test目录的路径
注:性能日志、训练日志、精度结果都在SRGAN/test/output目录下
性能脚本
bash ./test/train_performance_1p.sh --data_path=../data
精度脚本
bash ./test/train_full_1p.sh --data_path=../data
多P训练
性能脚本:
bash ./test/train_performance_8p.sh --data_path=../data
精度脚本:
bash ./test/train_full_8p.sh --data_path=../data
使用train_full_xx.py 脚本时会自动运行测试脚本,训练结果和测试结果保存在 ./test/output 路径下。
SRGAN training result
| Device | FPS | Epochs | AMP_Type | PSNR | SSIM |
|---|---|---|---|---|---|
| NPU 1p_1.5 | 270 | 100 | O1 | 33.0558 | 0.9226 |
| NPU 8P_1.5 | 1200 | 100 | O1 | 32.1882 | 0.9172 |
| GPU 1p | 360 | 100 | O1 | 33.4604 | 0.9308 |
| GPU 8P | 1400 | 100 | O1 | 31.0824 | 0.9191 |
| NPU 1p_1.8 | 180 | 100 | O1 | 33.3234 | 0.9302 |
| NPU 8p_1.8 | 1200 | 100 | O1 | 33.2284 | 0.9312 |
训练结果示例 (npu_1p)
-
Set5_001.jpg

-
Set5_002.jpg

-
Set5_003.jpg

-
Set5_004.jpg

-
Set5_005.jpg
