Sparse4D

目录

简介

模型介绍

近年来,基于鸟瞰图的方法在多视角三维检测任务中取得了很大进展。与基于BEV的方法相比,稀疏方法在性能上有所落后,但仍有许多不可忽视的优点。为了进一步推动稀疏3D检测,地平线提出一种名为Sparse4D的新方法,该方法通过稀疏采样和融合时空特征对锚框进行迭代细化。

  • 稀疏四维采样:对于每个3D锚点,作者分配多个四维关键点,然后将其投影到多视图/尺度/时间戳图像特征上,以采样相应的特征;
  • 层次特征融合:对不同视角/尺度、不同时间戳、不同关键点的采样特征进行层次融合,生成高质量的实例特征。

这样,Sparse4D可以高效有效地实现3D检测,而不依赖于密集的视图变换和全局关注,并且对边缘设备的部署更加友好。

代码实现

  • 参考实现:

    url=https://github.com/HorizonRobotics/Sparse4D
    commit_id=c41df4bbf7bc82490f11ff55173abfcb3fb91425
    
  • 适配昇腾 AI 处理器的实现:

    url=https://gitcode.com/Ascend/DrivingSDK.git
    code_path=model_examples/Sparse4D
    

准备训练环境

安装昇腾环境

请参考昇腾社区中《Pytorch框架训练环境准备》文档搭建昇腾环境。本仓已支持表1中软件版本。

表 1 昇腾软件版本支持表

软件类型 首次支持版本
FrameworkPTAdapter 7.0.0
CANN 8.1.RC1

安装模型环境

当前模型支持的 PyTorch 版本和已知三方库依赖如下表所示。

表 2 版本支持表

三方库 支持版本
PyTorch 2.1.0, 2.7.1
mmcv 1.x
mmdet 2.28.2
  • 安装Driving SDK

    请参考昇腾Driving SDK代码仓说明编译安装Driving SDK

  • 克隆代码仓到当前目录

  • 推荐使用依赖安装一键配置脚本,可使用如下指令完成后续基础依赖,mmcv, mmdet,模型代码patch的安装和更新:

    bash install_Sparse4D.sh
    

    一键安装默认使用 Pytorch 2.1.0 版本,如需更换 Pytorch 2.7.1 版本,请自行修改脚本中 requirements.txt 为 requirements_pytorch2.7.1.txt

  • 安装基础依赖

    在模型根目录下执行命令,根据 Pytorch 版本安装模型需要的依赖

    若服务器为A2/A3,安装步骤如下:

    cd DrivingSDK/model_examples/Sparse4D
    # PyTorch 2.1.0
    pip install -r requirements.txt
    # Pytorch 2.7.1
    pip install -r requirements_pytorch2.7.1.txt
    

    若服务器为A5,安装步骤如下:

    cd DrivingSDK/model_examples/Sparse4D
    pip install -r requirements_pytorch2.7.1_a5.txt
    
  • 源码安装mmcv

    若服务器为A2/A3,安装步骤如下:

    git clone -b 1.x https://github.com/open-mmlab/mmcv.git
    cp mmcv.patch mmcv
    cd mmcv
    git apply mmcv.patch
    MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install
    cd ..
    

    若服务器为A5,安装步骤如下:

    git clone -b 1.x https://github.com/open-mmlab/mmcv.git
    cp mmcv_a5.patch mmcv
    cd mmcv
    git apply mmcv_a5.patch
    MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install
    cd ..
    
  • 源码安装mmdet

    git clone -b v2.28.2 https://github.com/open-mmlab/mmdetection.git
    cp mmdet.patch mmdetection
    cd mmdetection
    git apply mmdet.patch
    pip install -e . --no-build-isolation
    cd ..
    
  • 模型代码使用Patch

    若服务器为A2/A3,步骤如下:

    git clone https://github.com/HorizonRobotics/Sparse4D.git
    cp Sparse4D.patch Sparse4D
    cp patch.py Sparse4D/tools
    cd Sparse4D
    git checkout c41df4bbf7bc82490f11ff55173abfcb3fb91425
    git apply Sparse4D.patch
    cp -rf ../test .
    

    若服务器为A5,步骤如下:

    git clone https://github.com/HorizonRobotics/Sparse4D.git
    cp Sparse4D.patch Sparse4D
    cp patch_a5.py Sparse4D/tools
    mv Sparse4D/tools/patch_a5.py Sparse4D/tools/patch.py
    cd Sparse4D
    git checkout c41df4bbf7bc82490f11ff55173abfcb3fb91425
    git apply Sparse4D.patch
    cp -rf ../test .
    

准备数据集

获取训练数据

用户自行获取nuscenes数据集,在源码目录创建软连接data/nuscenes指向解压后的nuscenes数据目录

mkdir data
ln -s path/to/nuscenes ./data/nuscenes

运行数据预处理脚本生成Sparse4D模型训练需要的pkl文件

pkl_path="data/nuscenes_anno_pkls"
mkdir -p ${pkl_path}
python3 tools/nuscenes_converter.py --version v1.0-trainval,v1.0-test --info_prefix ${pkl_path}/nuscenes

通过K-means生成初始锚框

export OPENBLAS_NUM_THREADS=2
export GOTO_NUM_THREADS=2
export OMP_NUM_THREADS=2
python3 tools/anchor_generator.py --ann_file ${pkl_path}/nuscenes_infos_train.pkl --output_file_name nuscenes_kmeans900.npy

获取预训练权重

下载backbone预训练权重

mkdir ckpt
wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O ckpt/resnet50-19c8e357.pth

使用高性能内存库(可选)

安装tcmalloc(适用OS: openEuler

mkdir gperftools
cd gperftools
wget https://github.com/gperftools/gperftools/releases/download/gperftools-2.16/gperftools-2.16.tar.gz
tar -zvxf gperftools-2.16.tar.gz
cd gperftools-2.16
./configure --prefix=/usr/local/lib --with-tcmalloc-pagesize=64
make
make install
echo '/usr/local/lib/lib/' >> /etc/ld.so.conf
ldconfig
export LD_LIBRARY_PATH=/usr/local/lib/lib/:$LD_LIBRARY_PATH
export PATH=/usr/local/lib/bin:$PATH
export LD_PRELOAD=/usr/local/lib/lib/libtcmalloc.so.4

注意:需要安装OS对应tcmalloc版本(以下以 Ubuntu 为例)

# 安装autoconf和libtool
apt-get update
apt install autoconf
apt install libtool
git clone https://github.com/libunwind/libunwind.git
cd libunwind
autoreconf -i
./configure --prefix=/usr/local
make -j128
make install
cd ..

# 安装tcmalloc
wget https://github.com/gperftools/gperftools/releases/download/gperftools-2.16/gperftools-2.16.tar.gz
tar -xf gperftools-2.16.tar.gz && cd gperftools-2.16
./configure --prefix=/usr/local/lib --with-tcmalloc-pagesize=64
make -j128
make install
export LD_PRELOAD="$LD_PRELOAD:/usr/local/lib/lib/libtcmalloc.so"

快速开始

训练模型

进入模型根目录model-root-path

  • 单机8卡精度训练

    bash test/train_full_8p.sh
    (option) bash test/train_full_8p.sh --batch-size=6 --num-npu=8
    
  • 单机8卡性能训练

    bash test/train_performance_8p.sh
    (option) bash test/train_performance_8p.sh --batch-size=12 --num-npu=8
    

    模型训练脚本参数说明如下:

    公共参数:
    --batch-size                             //指定batchsize,默认值如上指定值
    --num-npu                                //指定卡数,默认值为8
    
  • 多机多卡训练

    # 'XX.XX.XX.XX'为主节点的IP地址;端口号可以换成未被占用的可用端口
    bash test/train_multi_server.sh 8 2 0 'XX.XX.XX.XX' '3389' #主节点
    bash test/train_multi_server.sh 8 2 1 'XX.XX.XX.XX' '3389' #副节点
    

训练结果

表 3 训练结果展示表

A2单机八卡精度:

芯片 卡数 global batchsize Max epochs mAP
竞品A 8p 48 100 0.4534
Atlas 800T A2 8p 48 100 0.4509

A2单机八卡性能:

芯片 卡数 global batchsize Max steps FPS
竞品A 8p 96 500 65.75
Atlas 800T A2 8p 96 500 70.59

A5单机单卡性能:

芯片 卡数 global batchsize Max steps FPS
A5 1p 48 500 20.48

A2多机多卡线性度:

芯片 卡数 global batchsize 平均step耗时(s) Max epochs FPS 线性度
Atlas 800T A2 8p 96 1.27 4 75.59 -
Atlas 800T A2 16p 192 1.32 4 145.60 96.31%

版本说明

变更

2025.1.23: 首次发布。

2025.4.22: 更新训练脚本,刷新性能数据。

2025.5.6: 增加多机多卡训练脚本,增加多机多卡训练性能数据。

2025.7.22: 性能优化,更新patcher方式,刷新性能数据。

2025.8.5: 更新性能数据。

2025.8.26: 关闭图模式,更新性能数据。

2025.11.25: 新增 Pytorch 2.7.1 支持。

2026.05.12:新增 A5 模型适配步骤。

FAQ

Q: 训练时报错ImportError: cannot import name 'gcd' from 'fractions'

A: 报错原因是networkx版本低,使用pip install networkx==3.1升级依赖版本即可。

Q: 训练时报错torch没有uint64_t属性

A: 报错原因是safetensors版本与PyTorch版本不匹配,PyTorch版本为2.1.0,需匹配0.6.0以下的safetensors,使用pip install safetensors==0.5.1改变依赖版本即可。

Q: tcmalloc的动态库文件找不到报错

A: 报错原因是tcmalloc的动态库文件位置可能因环境配置会有所不同,找不到文件时可以进行搜索,一般安装在/usr/lib64或者/usr/local目录下:

find /usr -name libtcmalloc.so*

找到对应路径下的动态库文件,libtcmalloc.so或者libtcmalloc.so.版本号都可以使用。