模型编译缓存功能
功能简介
torch.compile是一种即时编译器(Just-In-Time compiler),成图首次编译时间通常较长,而大模型推理场景对时延敏感,因此优化首次编译时长显得尤为重要。在推理服务和弹性扩容等业务场景中,使用编译缓存可有效缩短服务启动后的首次推理时延。
成图编译涉及两段耗时,一段是Dynamo的编译耗时,另一段是基于Dynamo编译出的FX图进行再处理的耗时,再处理的行为会根据模式不同而有所变化。
为降低成图编译的耗时,TorchAir提供了模型编译缓存方案,通过cache_compile接口将首次编译的结果保存,从而加快torch.compile图模式的启动时间。
图 1 max-autotune模式执行时间分布示意图

以LLaMA 2-70B(Large Language Model Meta AI 2)为例,上图呈现了启动与未开启模型编译缓存的耗时分布。需要注意的是,该图不呈现与本功能无关的耗时细节。
-
原始推理任务执行,分为5个阶段:
-
开启模型编译缓存:
通过缓存Dynamo、Ascend IR图编译两个耗时占比最大环节,实现模型的加速启动。
使用约束
- 本功能仅适用于GE图模式场景,暂不支持同时配置Dynamo导图功能、RefData类型转换功能。
- 如果图中包含依赖随机数生成器(RNG)的算子(例如randn、bernoulli、dropout等),不支持使用本功能。
- 本功能跳过了Dynamo的JIT编译环节、Guards、Ascend IR图编译环节,与torch.compile原始方案相比多了如下限制:
- 缓存要与执行计算图一一对应,若重编译则缓存失效。
- Guards阶段被跳过且不会触发JIT编译,要求生成模型的脚本和加载缓存的脚本一致。
- CANN包跨版本缓存无法保证兼容性,如果版本升级,需要清理缓存目录并重新进行Ascend IR计算图编译生成缓存。
Dynamo编译缓存
本节提供一个简化版的模型编译缓存使用示例,同时展示缓存对特殊类型输入的处理能力(如Python Class类型、List类型等)。
-
准备PyTorch模型脚本。
假设在/home/workspace目录下定义了test.py模型脚本,代码示例如下:
import torch import dataclasses from typing import List import torch_npu import torchair from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() npu_backend = torchair.get_npu_backend(compiler_config=config) # InputMeta为仿照VLLM(Versatile Large Language Model)框架的入参结构 @dataclasses.dataclass class InputMeta: data: torch.Tensor is_prompt: bool class Model(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(2, 1) self.linear2 = torch.nn.Linear(2, 1) for param in self.parameters(): torch.nn.init.ones_(param) def forward(self, x: InputMeta, kv: List[torch.Tensor]): return self.linear2(x.data) + self.linear2(kv[0]) x = InputMeta(data=torch.randn(2, 2).npu(), is_prompt=True) kv = [torch.randn(2, 2).npu()] model = Model().npu() # 调用torch.compile编译 compiled_model = torch.compile(model, backend=npu_backend) # 执行prompt res_prompt = compiled_model(x, kv) x.is_prompt = False # 执行decode res_decode = compiled_model(x, kv) -
改造PyTorch模型脚本。
-
先处理forward函数。
将test.py中“forward”函数的实现提取为“_forward”函数。
@torch.inference_mode() def forward(self, x: InputMeta, kv: List[torch.Tensor]): return self._forward(x, kv) def _forward(self, x, kv): return self.linear2(x.data) + self.linear2(kv[0]) -
通过cache_compile接口实现编译缓存。
“_forward”函数是可以缓存编译的函数,但由于其会触发多次重新编译,所以要为每个场景封装一个新的func函数,然后func直接调用_forward函数。同时,forward函数中添加调用新函数的判断逻辑。如何封装新的func函数,取决于原始模型逻辑,请用户根据实际场景自行定义。
说明
- func函数只能被触发一次Dynamo trace,换言之如果func发生重编译,则会放弃缓存。
- 对于发生多次trace(Guards失效)的函数,需要进行一次函数封装来使缓存生效。
- func必须是method,即module实例对象的方法,且该方法未被其他装饰器修饰。
- func必须能形成整图,即必须支持full graph。
- 使用cache_compile接口后,原先脚本中的torch.compile编译流程不再需要。
test.py中只展示了prompt和decode的func函数封装,具体代码示例如下:
import dataclasses import logging from typing import List import torch import torch_npu import torchair from torchair import logger from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() logger.setLevel(logging.INFO) # InputMeta为仿照VLLM(Versatile Large Language Model)框架的入参结构 @dataclasses.dataclass class InputMeta: data: torch.Tensor is_prompt: bool class Model(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(2, 1) self.linear2 = torch.nn.Linear(2, 1) for param in self.parameters(): torch.nn.init.ones_(param) # 通过torchair.inference.cache_compile实现编译缓存 self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=config) self.cached_decode = torchair.inference.cache_compile(self.decode, config=config) def forward(self, x: InputMeta, kv: List[torch.Tensor]): # 添加调用新函数的判断逻辑 if x.is_prompt: return self.cached_prompt(x, kv) return self.cached_decode(x, kv) def _forward(self, x, kv): return self.linear2(x.data) + self.linear2(kv[0]) # 重新封装为prompt函数 def prompt(self, x, y): return self._forward(x, y) # 重新封装为decode函数 def decode(self, x, y): return self._forward(x, y) x = InputMeta(data=torch.randn(2, 2).npu(), is_prompt=True) kv = [torch.randn(2, 2).npu()] model = Model().npu() # 注意无需调用torch.compile进行编译,直接执行model # 执行prompt res_prompt = model(x, kv) x.is_prompt = False # 执行decode res_decode = model(x, kv)
-
-
模型脚本改造后,运行并生成封装func函数的缓存文件。
-
进入test.py所在目录,执行如下命令:
cd /home/workspace python3 test.py -
参考TorchAir Python层日志打印开启INFO日志,首次执行可以看到如下打印日志:
[INFO] TORCHAIR 2024-04-30 14:48:18 Cache ModelCacheMeta(name='CacheCompileSt.test_cache_hint.<locals>.Model.prompt(x, y)', date='2024-04-30 14:48:16.736731', version='1.0.0', fx=None) saved to /home/workspace/.torchair_cache/Model_dynamic_f2df0818d06118d4a83a6cacf8dc6d28/prompt/compiled_module [INFO] TORCHAIR 2024-04-30 14:48:20 Cache ModelCacheMeta(name='CacheCompileSt.test_cache_hint.<locals>.Model.decode(x, y)', date='2024-04-30 14:48:19.654573', version='1.0.0', fx=None) saved to /home/workspace/.torchair_cache/Model_dynamic_f2df0818d06118d4a83a6cacf8dc6d28/decode/compiled_module生成的各个func函数缓存文件路径由cache_compile中cache_dir参数指定,支持相对路径和绝对路径。
- 若cache_dir指定路径,且为绝对路径,则缓存文件路径为{cache_dir}/\{cache\_dir\}/{model_info}/${func}。
- 若cache_dir指定路径,且为相对路径,则缓存文件路径为{work_dir}/\{work\_dir\}/{cache_dir}/{model_info}/\{model\_info\}/{func}。
{cache_dir}默认为“.torchair_cache”(若无会新建,请确保有读写权限),\{cache\_dir\}默认为“.torchair\_cache”(若无会新建,请确保有读写权限),{work_dir}为当前工作目录,{model_info}为模型信息,\{model\_info\}为模型信息,{func}为封装的func函数。
说明
若编译缓存的模型涉及多机多卡,缓存路径包含集合通信相关的world_size以及global_rank信息,路径为{work_dir}/\{work\_dir\}/{cache_dir}/{model_info}/world\{model\_info\}/world{world_size}global_rank{global_rank}/\{global\_rank\}/{func}/。
-
-
再次执行脚本,验证模型启动时间。
重新执行test.py脚本,开启Python侧INFO日志,可以看到缓存命中的日志:
[INFO] TORCHAIR 2024-04-30 14:52:08 Cache ModelCacheMeta(name='CacheCompileSt.test_cache_hint.<locals>.Model.prompt(x, y)', date='2024-04-30 14:48:16.736731', version='1.0.0', fx=None) loaded from /home/workspace/.torchair_cache/Model_dynamic_f2df0818d06118d4a83a6cacf8dc6d28/prompt/compiled_module [INFO] TORCHAIR 2024-04-30 14:52:08 Cache ModelCacheMeta(name='CacheCompileSt.test_cache_hint.<locals>.Model.decode(x, y)', date='2024-04-30 14:48:19.654573', version='1.0.0', fx=None) loaded from /home/workspace/.torchair_cache/Model_dynamic_f2df0818d06118d4a83a6cacf8dc6d28/decode/compiled_module -
(可选)如需查看封装的func函数缓存文件compiled_module,通过readable_cache接口读取。
说明
compiled_module主要存储了torch.compile成图过程中模型脚本、模型结构、执行流程等相关信息,可用于问题定位分析。
接口调用示例如下:
import torch_npu, torchair torchair.inference.readable_cache("/home/workspace/.torchair_cache/Model_dynamic_f2df0818d06118d4a83a6cacf8dc6d28/prompt/compiled_module", file="prompt.py")compiled_module内容最终解析到可读文件prompt.py(格式不限,如py、txt等)中。
Ascend IR编译缓存
max-autotune模式下**,**除了优化Dynamo编译耗时,还支持优化Ascend IR图编译耗时,主要通过cache_compile中ge_cache参数实现,以进一步加快图模式启动时间。具体参见下方示例代码:
说明
- 默认情况下,ge_cache=False(功能不开启),因受CANN包版本变更影响,用户需根据实际情况手动开启该功能。
- CANN包跨版本的缓存无法保证兼容性,如果版本升级,需要清理缓存目录并重新GE编译生成缓存。
- 在单算子和图混跑场景下,开启该功能会增加通信域资源开销,有额外显存消耗。
import dataclasses
import logging
from typing import List
import torch
import torch_npu
import torchair
from torchair import logger
logger.setLevel(logging.INFO)
# InputMeta为仿照VLLM(Versatile Large Language Model)框架的入参结构
@dataclasses.dataclass
class InputMeta:
data: torch.Tensor
is_prompt: bool
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(2, 1)
self.linear2 = torch.nn.Linear(2, 1)
for param in self.parameters():
torch.nn.init.ones_(param)
# 通过torchair.inference.cache_compile实现编译缓存
config = torchair.CompilerConfig()
# 开启ge_cache
self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=config, ge_cache=True)
self.cached_decode = torchair.inference.cache_compile(self.decode, config=config, ge_cache=True)
def forward(self, x: InputMeta, kv: List[torch.Tensor]):
# 添加调用新函数的判断逻辑
if x.is_prompt:
return self.cached_prompt(x, kv)
return self.cached_decode(x, kv)
def _forward(self, x, kv):
return self.linear2(x.data) + self.linear2(kv[0])
# 重新封装为prompt函数
def prompt(self, x, y):
return self._forward(x, y)
# 重新封装为decode函数
def decode(self, x, y):
return self._forward(x, y)
x = InputMeta(data=torch.randn(2, 2).npu(), is_prompt=True)
kv = [torch.randn(2, 2).npu()]
model = Model().npu()
# 执行prompt
res_prompt = model(x, kv)
x.is_prompt = False
# 执行decode
res_decode = model(x, kv)
配置ge_cache=True后,缓存编译结果路径与封装的func函数缓存文件路径一致,即{work_dir}/\{work\_dir\}/{cache_dir}/{model_info}/\{model\_info\}/{func},注意此时缓存路径中的模型信息${model_info}里会自动增加ge_cache关键词。
缓存的编译结果文件包括:
-
graph_{key}_\{key\}\_{timestamp}.om:模型缓存文件。
-
graph_${key}.idx:索引文件,用户可通过graph_key快速找到对应的缓存文件。索引文件内容示例如下:
{ "cache_file_list":[ { "cache_file_name":"./cache_dir/graph_$key1_20230117202307.om", "graph_key":"graph_$key1", "var_desc_file_name":"./cache_dir/graph_$key1_20230117202307.rdcpkt" }, { "cache_file_name":"./cache_dir/graph_$key1_20230117203007.om", "graph_key":"graph_$key1", "var_desc_file_name":"./cache_dir/graph_$key1_20230117203007.rdcpkt" } ] } -
(可选)graph_{key}_\{key\}\_{timestamp}.rdcpkt:变量格式文件,仅在图中存在变量时生成。用于框架匹配模型缓存文件,如果graph_key对应的图内变量格式发生变更,则之前缓存的文件将无法直接恢复使用,该场景下会重新触发编译流程重新生成缓存文件。
说明
- 如果未生成“.om”和“.idx”文件,需要清理缓存目录并重新生成缓存。
- 文件名中的{key}表示graph的编号,\{key\}表示graph的编号,{timestamp}表示文件保存的时间戳。