e7d0df49创建于 2025年9月19日历史提交

1. 动态分档

当模型输入不是固定shape,但可以固定为某几个shape时(通常是batch维度变化,有固定几种batch输入),如果使用动态图性能无法达到最优,但静态图每次遇到新的shape会重新编译。此时可以使用动态分档功能以获得静态图的性能优化,且只需编译一次。

示例:

import torch
import torch_npu
import torchair

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x1, x2):
        return x1 + x2

input1 = torch.ones(2, 2).npu()
input2 = torch.ones(2, 2).npu()
config = torchair.CompilerConfig()
# zip方式,位置一一对应,支持(2, 2)(2, 2)和(4, 2)(4, 2)两种输入组合
config.inference_config.dynamic_gears_merge_policy = "zip" #缺省值
# product方式,排列组合,支持(2, 2)(2, 2)、(2, 2)(4, 2)、(4, 2)(2, 2)和(4, 2)(4, 2)四种输入组合
# config.inference_config.dynamic_gears_merge_policy = "product"
npu_bakcend = torchair.get_npu_backend(compiler_config=config)
model = Model().npu()
# 必须整图编译
model = torch.compile(model, fullgraph=True, backend=npu_bakcend)

# 设置档位
torchair.inference.set_dim_gears(input1, dim_gears={0:[2, 4]})
torchair.inference.set_dim_gears(input2, dim_gears={0:[2, 4]})

# 首次编译+执行,shape为(2, 2)、(2, 2)
print(model(input1, input2))

# 再次执行,shape为(4, 2)、(4, 2)在档位中,不会触发重新编译
input1 = torch.ones(4, 2).npu()
input2 = torch.ones(4, 2).npu()
print(model(input1, input2))

2. torch.dynamo.mark_static()

PyTorch提供的接口,标记某个输入的shape为固定shape。将输入全部使用该接口固定后,如果模型中无动态shape,torch.compile的dynamic参数设置为True时,也可以得到静态图。如果模型输入shape固定,但内部有动态shape,使用这个参数可以在GE编译时生成部分静态子图,提升性能。

示例:

import torch
import torch_npu
import torchair

class Decoder(nn.Module):
    ...

    def decode(
        self,
        x: Tensor,
        xa: Tensor,
        positional_embedding: Tensor,
        kv_cache: Optional[dict] = None,
        updated_kv_positions: Optional[torch.LongTensor] = None,
        actual_seq_len: Optional[list] = None,
        kv_padding_size: Optional[torch.LongTensor] = None
        ):
        ...
        return ...
    
    def compute_logits(self, x, xa):
        ...
        torch._dynamo.mark_static(x)
        torch._dynamo.mark_static(xa)
        torch._dynamo.mark_static(positional_embedding)
        for i in range(n_layer):
            torch._dynamo.mark_static(self.kv_cache[i]['attn']["key"])
            torch._dynamo.mark_static(self.kv_cache[i]['attn']["value"])
            torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["key"])
            torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["value"])
        torch._dynamo.mark_static(kv_padding_size)

        return self.decode(x, xa, positional_embedding, self.kv_cache,
                            actual_seq_len=[actual_seq_len], kv_padding_size=kv_padding_size,
                            updated_kv_positions=updated_kv_positions)

config = torchair.CompilerConfig()
npu_bakcend = torchair.get_npu_backend(compiler_config=config)
decoder = Decoder().npu()
decoder.decode = torch.compile(decoder.decode, dynamic=True, fullgraph=True, backend=npu_backend)

3. 编译缓存

使用torch.compile编译时,每次启动程序都要重新编译,且编译时间通常较长,不方便调试,我们可以通过cache_compile接口,将首次编译结果落盘到磁盘,加速图模式的启动时间。

一般使用torch.compile的时候,首次执行需要Dynamo编译,Torch Guards,Ascend IR图编译等过程,后续每次执行的时候还需要运行Guards函数来判断是否需要重新编译。使用编译缓存后,首次执行直接load cache,后续再次执行的时候跳过Guards函数。

缓存编译接口:

torchair.inference.cache_compile

def cache_compile(
    func, # 缓存编译的method,只支持module的method
    *,
    config: Optional[CompilerConfig] = None,
    dynamic: bool = True, # 是否按照输入动态trace
    cache_dir: Optional[str] = None, # 缓存根目录,默认.torchair_cache
    global_rank: Optional[int] = None, # 分布式训练时的rank,默认torch.distributed.get_rank()
    tp_rank: Optional[int] = None, # 指定的tp rank
    pp_rank: Optional[int] = None, # 指定的pp rank
    custom_decompositions: optional[dict] = None, # 用户自定义的decompose策略
    ge_cache: bool = False, # 是否开启GE缓存
    **kwargs
) -> Callable:

使用约束:

  • func函数只能被处罚一次Dynamo trace, 如果func发生重编译,则会放弃缓存。
  • 对于发生多次trace (Guards失效)的函数,需要进行一次函数封装来使缓存生效。
  • func必须是module实例对象的method,且该方法未被其他装饰器修饰
  • func必须能形成整图,即必须支持full graph
  • 只支持推理模式,不支持带反向计算过程的func缓存。

Ascend IR编译缓存

  • 除了优化Dynamo编译耗时,还支持优化Ascend IR图编译耗时,主要通过cache_compile中的ge_cache参数实现,以进一步加速图模式启动时间。
  • 缺省情况下,ge_cache=False (功能不开启),因受CANN版本变更影响,用户需根据实际情况手动开启该功能
  • 缓存的编译结果文件路径与封装的func函数缓存文件路径一致

使用示例

原模型代码:

import torch, torch_npu, torchair
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    @torch.inference_mode()
    def forward(self, x, y):
        return x + y

适配步骤

  • 提取forward的实现为_forward函数,缓存编译_forward函数
  • 初始化时将被编译的函数定义为Module自身的一个属性
  • forward里执行被编译的函数,也就是初始化中设置的属性
import torch, torch_npu, torchair
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cached_forward = torchair.Inference.cache_compile(self._forward)
    
    def forward(self, x, y):
        # 修改为编译后的模型
        return self.cached_forward(x, y)

    def _forward(self, xy, y):
        return x + y

带有transformer结构的模型中常见forward中会在prefill和decode阶段中使用不同的算子或者计算逻辑,此时需要针对每个场景新增一个函数:

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        config = torchair.CompilerConfig()
        self.cached_prefill = torchair.inference.cache_compile(self.prefill, config=config)
        self.cached_decode = torchair.inference.cache_compile(self.decode, config=config)
    
    def forward(self, x, ...):
        if x.size(1) > 1:
            return self.cached_prefill(x, ...)
        return self.cached_decode(x, ...)
    
    def prefill(self, x, ...):
        return ...
    
    def decode(self, x, kv_cache):
        return ...