SalsaNext for PyTorch [终止随版本演进]
目录
简介
模型介绍
SalsaNext采用编码器-解码器架构,其中编码器单元包含一组ResNet块,解码部分则结合了残差块上采样的特征。SalsaNext引入了一个新的上下文模块,用带有逐渐增加的感受野的新型残差膨胀卷积堆栈替换ResNet编码器块,并在解码器中添加了像素shuffle层。此外,SalsaNext将步长卷积改为平均池化,并应用中心辍学处理。为了直接优化Jaccard指数,将加权交叉熵损失与Lovasz-Softmax损失相结合。最后,通过贝叶斯处理计算云中每个点的先验(Epistemic)和随机(Aleatoric)不确定性。模型在Semantic-KITTI数据集上提供了全面的定量评估。
支持任务列表
本仓已经支持以下模型任务类型
| 模型 | 任务列表 | 是否支持 |
|---|---|---|
| SalsaNext | 训练 | ✔ |
代码实现
- 参考实现:
url=https://github.com/TiagoCortinhal/SalsaNext
commit_id=7548c124b48f0259cdc40e98dfc3aeeadca6070c
- 适配昇腾 AI 处理器的实现:
url=https://gitcode.com/Ascend/DrivingSDK.git
code_path=model_examples/SalsaNext
SalsaNext
准备训练环境
昇腾环境说明
请参考昇腾社区中《Pytorch框架训练环境准备》文档搭建昇腾环境,本仓已支持表1中软件版本。
表 1 昇腾软件版本支持表
| 软件类型 | 首次支持版本 |
|---|---|
| FrameworkPTAdapter | 7.0.0 |
| CANN | 8.1.RC1 |
表 2 三方库版本支持表
| 三方库 | 支持版本 |
|---|---|
| PyTorch | 2.1.0 |
准备源代码
- 克隆代码仓并应用补丁。
git clone https://github.com/TiagoCortinhal/SalsaNext.git
cp -f ${DrivingSDK_root_dir}/model_examples/SalsaNext/salsanext.patch SalsaNext
cp -rf ${DrivingSDK_root_dir}/model_examples/SalsaNext/test SalsaNext
cp -rf ${DrivingSDK_root_dir}/model_examples/SalsaNext/train SalsaNext
cd SalsaNext
git checkout 7548c124b48f0259cdc40e98dfc3aeeadca6070c
git apply --whitespace=fix salsanext.patch
模型数据准备
进入semantic-kitti官网,下载Semantic Segmentation and Panoptic Segmentation数据集。将数据集解压后放置或者链接到DrivingSDK/model_examples/SalsaNext/datasets路径下,数据集结构排布成如下格式:
- 文件夹结构
dataset
└──sequences
├──00
| ├── labels
| ├── velodyne
| ├── calib.txt
| ├── poses.txt
| └── times.txt
├──01
| ├── labels
| ├── velodyne
| ├── calib.txt
| ├── poses.txt
| └── times.txt
...
└──21
├── labels
├── velodyne
├── calib.txt
├── poses.txt
└── times.txt
安装模型环境
-
准备conda
创建conda环境,环境内需要python>=3.9
conda create -n salsanext python=3.9 -
安装基础依赖
conda activate salsanext安装Driving SDK 请参考昇腾Driving SDK代码仓说明编译安装Driving SDK
进入模型源码根目录
cd {DrivingSDK_root_dir}/model_examples/SalsaNext/使用pip指令安装模型所需的其他代码库: pip install -r requirements.txt
快速开始
训练任务
本任务主要提供单机的8卡训练脚本。
开始训练
进入模型根目录,
cd /${Model_root_dir}/
salsanext.yml为默认配置文件。
-
单机8卡性能
# epoch = 20 bash test/train_8p_performance.sh -d /path/to/dataset -a ./salsanext.yml -l ./ -
单机8卡精度
# epoch = 150 bash test/train_8p.sh -d /path/to/dataset -a ./salsanext.yml -l ./ -
训练脚本参数说明
- d[String]:数据集的路径
- a[String]:配置文件的路径
- l[String]:主日志文件夹的路径
- c[String]:要使用的device_id,该参数仅单卡时生效
训练结果
| 芯片 | 卡数 | global batch size | Max epochs | mIoU | FPS |
|---|---|---|---|---|---|
| 竞品A | 8p | 192 | 150 | 0.577 | 241.6 |
| Atlas 800T A2 | 8p | 192 | 150 | 0.581 | 197.2 |
变更说明
2025.03.06:首次发布。
2025.04.08:修改性能测试脚本错误,刷新性能数据,优化loss函数提高训练速度
2025.04.22:补充global batch size数据,修复FPS计算错误
2025.06.11:修复多卡拉起时文件夹创建冲突问题
2025.07.04:性能优化
FAQ
暂无。