图捕获模式算子开发示例 (Graph Capturing Mode)
本示例在自定义的Softmax算子基础上,展示了如何开启图捕获模式,以及执行捕获图。
总览介绍
随着性能优化不断深入,Eager模式带来的Host侧开销逐渐成为瓶颈。本目录目前聚焦于:
- ACLGraph:通过图捕获模式(
ACLGraph)将相关任务下沉到Device上并执行,减少Host侧开销,当捕获图需要多次执行时,无需再下发任务,只需要多次调用replay接口即可。
样例代码特性
- 装饰器修饰:PyPTO 自定义算子对接
torch.compile,需要使用@allow_in_graph装饰器修饰,避免图分割报错。 - FakeTensor处理 PyPTO 自定义算子对接
torch.compile,需要判断输入是否为FakeTensor,若是则直接返回空tensor。 - 编译模型: 调用
torch.compile编译模型model。 - 图捕获:使用
with torch.npu.graph(g):捕获model第一次执行的任务。 - 执行捕获图 调用
replay执行捕获图,计算Softmax结果;
代码结构
aclgraph/:aclgraph.py: 包含图捕获模式开启、捕获图执行以及与 PyTorch 原生算子的精度比对。
运行方法
环境准备
# 配置 CANN 环境变量
# 安装完成后请配置环境变量,请用户根据set_env.sh的实际路径执行如下命令。
# 上述环境变量配置只在当前窗口生效,用户可以按需将以上命令写入环境变量配置文件(如.bashrc文件)。
# 默认路径安装,以root用户为例(非root用户,将/usr/local替换为${HOME})
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 设置设备 ID
export TILE_FWK_DEVICE_ID=0
执行脚本
python3 examples/03_advanced/aclgraph/aclgraph.py
注意事项
- 当捕获图需要多次执行时,需要更新对应输入数据。
- 在捕获过程中,调用内存同步操作类函数是非法的,会检验失败导致捕获失败。