HiVT for PyTorch

目录

简介

模型介绍

HiVT是一种面向自动驾驶Multi-Agent Motion Prediction任务的深度学习框架,它克服了传统vectorized approaches(将轨迹点和地图都转化为矢量化实体,如轨迹点、车道段,再利用GNN,Transformers 等方法进行建模)在同时建模Multi-Agent时效率低、难以做到实时预测的问题,实现了更快、更准确的行为预测。

支持任务列表

本仓已经支持以下模型任务类型

模型 任务列表 是否支持
HiVT 训练

代码实现

  • 参考实现:
https://github.com/ZikangZhou/HiVT.git
commit_id=6876656ce7671982ebdc29113aaaa028c2931518
  • 适配昇腾 AI 处理器的实现:
url=https://gitcode.com/Ascend/DrivingSDK.git
code_path=model_examples/HiVT

HiVT(在研版本)

准备训练环境

安装昇腾环境

请参考昇腾社区中《Pytorch框架训练环境准备》文档搭建昇腾环境,本仓已支持表1中软件版本。

表 1 昇腾软件版本支持表

软件类型 首次支持版本
FrameworkPTAdapter 7.0.0
CANN 8.1.RC1

安装模型环境

表 2 三方库版本支持表

三方库 支持版本
PyTorch 2.1.0
  1. 激活 CANN 环境(例如:source /usr/local/Ascend/ascend-toolkit/set_env.sh

  2. 参考《Pytorch框架训练环境准备》安装 2.1.0 版本的 PyTorch 框架和 torch_npu 插件。

  3. 安装 Driving SDK 加速库,安装方法参考官方文档

    # Driving SDK 加速库安装完成后到HiVT模型目录下
    cd DrivingSDK/model_examples/HiVT
    
  4. 安装argoverse-api,并且下载hd_maps解压到argoverse-api文件中

    git clone https://github.com/argoai/argoverse-api.git
    cd argoverse-api
    git checkout f886ac54fba9f06f8a7d109eb663c7f501b3aa8e
    git apply ../patch/argoverse-api.patch
    pip install -e .
    wget https://s3.amazonaws.com/argoverse/datasets/av1.1/tars/hd_maps.tar.gz
    tar -zxvf hd_maps.tar.gz
    cd ..
    
  5. 安装torch_scatter

    git clone https://github.com/rusty1s/pytorch_scatter.git -b 2.1.0
    cd pytorch_scatter
    git checkout fa4f442952955acf8fe9fcfb98b600f6ca6081b6
    git apply ../patch/torch_scatter.patch
    # 编译耗时较久,需要十分钟
    pip install -e .
    cd ..
    
  6. 根据操作系统安装 tcmalloc 高效内存资源分配库

    • OpenEuler系统

    在当前python环境和路径下执行以下命令,安装并使用tcmalloc动态库。

    mkdir gperftools
    cd gperftools
    wget https://github.com/gperftools/gperftools/releases/download/gperftools-2.16/gperftools-2.16.tar.gz
    tar -zvxf gperftools-2.16.tar.gz
    cd gperftools-2.16
    ./configure --prefix=/usr/local/lib --with-tcmalloc-pagesize=64
    make
    make install
    echo '/usr/local/lib/lib/' >> /etc/ld.so.conf
    ldconfig
    export LD_LIBRARY_PATH=/usr/local/lib/lib/:$LD_LIBRARY_PATH
    export PATH=/usr/local/lib/bin:$PATH
    export LD_PRELOAD=/usr/local/lib/lib/libtcmalloc.so.4
    
    • Ubuntu系统

    参考下载链接,下载三个文件。

    libgoogle-perftools4_2.7-1ubuntu2_arm64.deb

    libgoogle-perftools-dev_2.7-1ubuntu2_arm64.deb

    libtcmalloc-minimal4_2.7-1ubuntu2_arm64.deb

    安装三个文件:

    sudo dpkg -i libtcmalloc-minimal4_2.7-1ubuntu2_arm64.deb
    sudo dpkg -i libgoogle-perftools-dev_2.7-1ubuntu2_arm64.deb
    sudo dpkg -i libgoogle-perftools4_2.7-1ubuntu2_arm64.deb
    find /usr -name libtcmalloc.so*
    

    将find指令的输出路径记为libtcmalloc_dir,执行下列文件使用tcmalloc动态库。

    export LD_PRELOAD="$LD_PRELOAD:${libtcmalloc_dir}/libtcmalloc.so"
    
  7. 安装pip依赖

    pip install -r requirements.txt
    
  8. 拉取HiVT模型源代码

    git clone https://github.com/ZikangZhou/HiVT.git && cd HiVT
    git checkout 6876656ce7671982ebdc29113aaaa028c2931518
    git apply ../patch/HiVT.patch
    cd ..
    

模型数据准备

训练集:https://s3.amazonaws.com/argoverse/datasets/av1.1/tars/forecasting_train_v1.1.tar.gz

验证集:https://s3.amazonaws.com/argoverse/datasets/av1.1/tars/forecasting_val_v1.1.tar.gz

下载后解压到指定目录下:/path/to/Argoverse

  • 文件夹结构
Argoverse
├── train/
| └── data/
| ├── 1.csv
| ├── 2.csv
| ├── ...
└── val/
└── data/
├── 1.csv
├── 2.csv
├── ...
  • 数据预处理

当数据集解压后置于数据集路径下,pytorch-lightning框架会在第一次执行训练脚本时,自动开始数据预处理过程,处理总时长大约10小时。

快速开始

训练任务

本任务主要提供单机8卡训练脚本。

开始训练

在模型根目录下启动训练。

cd /path/DrivingSDK/model_examples/HiVT
  • 单机8卡性能

    # /path/to/Argoverse/ 请更改为存放数据的路径
    bash test/train_8p_performance.sh --data_path=/path/to/Argoverse/
    
  • 单机8卡精度

    # /path/to/Argoverse/ 请更改为存放数据的路径
    bash train_8p.sh --data_path=/path/to/Argoverse/
    

训练结果

芯片 卡数 global batch size epoch minFDE minADE 性能-单步迭代耗时(s) FPS
竞品A 8p 256 64 1.022 0.6845 0.392 652
Atlas 800T A2 8p 256 64 1.03 0.6858 0.397 645

变更说明

2025.4.22:首次发布

2025.5.28:优化模型性能,更新性能数据

FAQ

  1. pip安装omegaconf==2.1.0报错

    ERROR: Ignored the following yanked versions: 1.0.0, 1.0.1, 1.0.2, 2.0.0rc1, 2.0.0rc2, 2.0.0rc22, 2.0.0rc23, 2.0.0rc24, 2.0.0rc25, 2.0.0rc26, 2.0.0rc27, 2.0.0rc28, 2.0.0rc29, 2.0.1rc1, 2.0.1rc2, 2.0.1rc3, 2.0.1rc4, 2.0.1rc5, 2.2.0
    ERROR: Could not find a version that satisfies the requirement omegaconf==2.1.0
    

    解决方法:pip install pip==24.0

  2. pip安装h5py报错

    ERROR: Failed building wheel for h5py
    Failed to build h5py
    ERROR: Could not build wheels for h5py, which is required to install pyproject.toml-based projects
    

    解决方法:conda install h5py