Cat算子消除功能
功能简介
图模式场景下,当计算图中存在 [torch.cat](https://pytorch.org/docs/stable/generated/torch.cat.html)(对应 `torch.ops.aten.cat.default`)时,npugraph_ex 内部会通过 Cat 算子消除 Pass 将其替换为「预分配输出张量 + slice + 原地写入」的模式,以减少内存拷贝和临时张量分配,提升执行性能。默认情况下,本功能处于开启状态。
使用约束
本功能仅针对 `torch.ops.aten.cat.default` 进行优化。
拼接维度:当前仅支持 `dim=0` 的 cat,其他拼接维度不会触发本优化。
输入算子约束:cat的每个输入必须来自带 `.out` 变体(原地变体)的算子,且不能同时作为其他cat的输入,否则该 cat 节点会被跳过。
Shape 约束:参与 cat 的各输入张量在非拼接维度上的 shape 必须一致。
若用户需要进行精度比对或问题排查,可关闭本功能以避免优化影响分析结果。
本功能支持的产品型号参见使用说明。
使用方法
该功能通过npugraph_ex的options配置,示例如下,仅供参考不支持直接拷贝运行,参数说明参见下表。
import torch
import torch_npu
torch.compile(model, backend="npugraph_ex", options={"remove_cat_ops": True}, dynamic=False, fullgraph=True)
表 1 参数说明
| 参数名 | 参数说明 |
|---|---|
| remove_cat_ops | 是否开启Cat算子消除优化。 True(默认值):开启优化。 False:关闭优化。 |
开启Debug日志后,如果存在可优化的Cat节点,可以看到类似的信息:
[DEBUG] [remove_cat_ops] Found 1 cat node(s)
[DEBUG] Optimizing cat_1 (3 inputs)
[DEBUG] remove_cat_ops: Optimized 1 cat node(s)