DETR3D for PyTorch [终止随版本演进]

目录

简介

模型介绍

DETR3D(3D Detection Transformer) 是一种基于 Transformer 的端到端 3D 目标检测模型,由 Waymo 提出,旨在通过轻量化设计实现高效的 3D 目标检测任务。与传统基于点云或 voxel 的 3D 检测方法不同,DETR3D 直接在多视图 2D 图像中生成查询(queries),利用 Transformer 结构对 2D 图像的特征进行全局建模。该模型在无须额外的后处理步骤(如非极大值抑制)的情况下即可生成精确的 3D 边界框。

支持任务列表

本仓已经支持以下模型任务类型

模型 任务列表 是否支持
DETR3D 训练

代码实现

  • 参考实现:
url=https://github.com/WangYueFt/detr3d
commit_id=34a47673011fe13593a3e594a376668acca8bddb
  • 适配昇腾 AI 处理器的实现:
url=https://gitcode.com/Ascend/DrivingSDK.git
code_path=model_examples/DETR3D

DETR3D

准备训练环境

安装昇腾环境

请参考昇腾社区中《Pytorch框架训练环境准备》文档搭建昇腾环境,本仓已支持表1中软件版本。

表 1 昇腾软件版本支持表

软件类型 首次支持版本
FrameworkPTAdapter 7.1.0
CANN 8.2.RC1

安装模型环境

表 2 三方库版本支持表

三方库 支持版本
PyTorch 2.1.0
  1. 激活 CANN 环境(例如:source /usr/local/Ascend/ascend-toolkit/set_env.sh

  2. 参考《Pytorch框架训练环境准备》安装 2.1.0 版本的 PyTorch 框架和 torch_npu 插件。

  3. 安装 Driving SDK 加速库

    安装方法参考原仓

  4. 安装基础依赖

    pip install mmsegmentation==0.29.1
    
  5. 安装mmcv

    git clone -b 1.x https://github.com/open-mmlab/mmcv.git
    cd mmcv
    cp -f ../mmcv.patch ./
    git apply --reject --whitespace=fix mmcv.patch
    pip install -r requirements/runtime.txt
    MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install
    cd ../
    
  6. 安装mmdet

    git clone -b v2.28.0 https://github.com/open-mmlab/mmdetection.git
    cd mmdetection
    cp -f ../mmdet.patch ./
    git apply --reject --whitespace=fix mmdet.patch
    pip install -e .
    cd ../
    
  7. 准备模型源码并安装mmdet3d

    git clone https://github.com/WangYueFt/detr3d
    cp -f detr3d.patch detr3d
    cd detr3d
    git checkout 34a47673011fe13593a3e594a376668acca8bddb
    git apply --reject --whitespace=fix detr3d.patch
    cp -fr ../test/ .
    pip install -r requirements.txt
    git clone -b v1.0.0rc6 https://github.com/open-mmlab/mmdetection3d.git
    cp -f ../mmdet3d.patch mmdetection3d
    cd mmdetection3d
    git apply --reject --whitespace=fix mmdet3d.patch
    pip install -r requirements/runtime.txt
    pip install -e .
    
  8. 配置tcmalloc

    配置方法可以参考昇腾社区

模型数据准备

进入NuScenes官网,下载 Nuscenes 数据集。将数据集解压至DETR3D/detr3d/data/nuscenes目录下。数据集与模型目录结构排布成如下格式:

  • 数据集排布结构
  data
    | -- nuscenes
      | -- lidarseg
      | -- maps
      | -- panoptic
      | -- samples (CAM_BACK, CAM_BACK_LEFT, CAM_BACK_RIGHT, ...)
      | -- sweeps (CAM_BACK, CAM_BACK_LEFT, CAM_BACK_RIGHT, ...)
      | -- v1.0-test
      | -- v1.0-trainval
  projects
  tools
  • 下载模型依赖的权重

根据原仓Evaluation using pretrained models章节通过此处fcos3d.pthdd3d_det_final.pthpillar.pthvoxel.pth自行下载并按如下目录组织:

  ckpts
    | -- dd3d_det_final.pth
    | -- fcos3d.pth
    | -- pillar.pth
  pretrained
    | -- fcos3d.pth
    | -- voxel.pth
  data
  projects
  • 生成模型训练数据
cd /path/to/detr3d/mmdetection3d
python3 tools/create_data.py nuscenes --root-path=../data/nuscenes --out-dir=../data/nuscenes --extra-tag nuscenes

快速开始

训练任务

本任务主要提供单机8卡训练脚本,以配置文件detr3d_res101_gridmask.py为例。

开始训练

cd model_examples/DETR3D/detr3d
  • 单机8卡性能

    bash test/train_8p_performance.sh # 默认跑1个epoch
    
  • 单机8卡精度

    bash test/train_8p_full.sh
    

训练结果

芯片 卡数 global batch size epoch mAP NDS FPS
竞品A 8p 24 24 0.3414 0.4133 14.28
Atlas 800T A2 8p 24 24 0.3420 0.4112 14.35

变更说明

2024.12.30:首次发布 2025.1.13: 性能优化 2025.9.1: 更改训练脚本配置

FAQ