1. 动态分档
当模型输入不是固定shape,但可以固定为某几个shape时(通常是batch维度变化,有固定几种batch输入),如果使用动态图性能无法达到最优,但静态图每次遇到新的shape会重新编译。此时可以使用动态分档功能以获得静态图的性能优化,且只需编译一次。
示例:
import torch
import torch_npu
import torchair
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
return x1 + x2
input1 = torch.ones(2, 2).npu()
input2 = torch.ones(2, 2).npu()
config = torchair.CompilerConfig()
# zip方式,位置一一对应,支持(2, 2)(2, 2)和(4, 2)(4, 2)两种输入组合
config.inference_config.dynamic_gears_merge_policy = "zip" #缺省值
# product方式,排列组合,支持(2, 2)(2, 2)、(2, 2)(4, 2)、(4, 2)(2, 2)和(4, 2)(4, 2)四种输入组合
# config.inference_config.dynamic_gears_merge_policy = "product"
npu_bakcend = torchair.get_npu_backend(compiler_config=config)
model = Model().npu()
# 必须整图编译
model = torch.compile(model, fullgraph=True, backend=npu_bakcend)
# 设置档位
torchair.inference.set_dim_gears(input1, dim_gears={0:[2, 4]})
torchair.inference.set_dim_gears(input2, dim_gears={0:[2, 4]})
# 首次编译+执行,shape为(2, 2)、(2, 2)
print(model(input1, input2))
# 再次执行,shape为(4, 2)、(4, 2)在档位中,不会触发重新编译
input1 = torch.ones(4, 2).npu()
input2 = torch.ones(4, 2).npu()
print(model(input1, input2))
2. torch.dynamo.mark_static()
PyTorch提供的接口,标记某个输入的shape为固定shape。将输入全部使用该接口固定后,如果模型中无动态shape,torch.compile的dynamic参数设置为True时,也可以得到静态图。如果模型输入shape固定,但内部有动态shape,使用这个参数可以在GE编译时生成部分静态子图,提升性能。
示例:
import torch
import torch_npu
import torchair
class Decoder(nn.Module):
...
def decode(
self,
x: Tensor,
xa: Tensor,
positional_embedding: Tensor,
kv_cache: Optional[dict] = None,
updated_kv_positions: Optional[torch.LongTensor] = None,
actual_seq_len: Optional[list] = None,
kv_padding_size: Optional[torch.LongTensor] = None
):
...
return ...
def compute_logits(self, x, xa):
...
torch._dynamo.mark_static(x)
torch._dynamo.mark_static(xa)
torch._dynamo.mark_static(positional_embedding)
for i in range(n_layer):
torch._dynamo.mark_static(self.kv_cache[i]['attn']["key"])
torch._dynamo.mark_static(self.kv_cache[i]['attn']["value"])
torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["key"])
torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["value"])
torch._dynamo.mark_static(kv_padding_size)
return self.decode(x, xa, positional_embedding, self.kv_cache,
actual_seq_len=[actual_seq_len], kv_padding_size=kv_padding_size,
updated_kv_positions=updated_kv_positions)
config = torchair.CompilerConfig()
npu_bakcend = torchair.get_npu_backend(compiler_config=config)
decoder = Decoder().npu()
decoder.decode = torch.compile(decoder.decode, dynamic=True, fullgraph=True, backend=npu_backend)
3. 编译缓存
使用torch.compile编译时,每次启动程序都要重新编译,且编译时间通常较长,不方便调试,我们可以通过cache_compile接口,将首次编译结果落盘到磁盘,加速图模式的启动时间。
一般使用torch.compile的时候,首次执行需要Dynamo编译,Torch Guards,Ascend IR图编译等过程,后续每次执行的时候还需要运行Guards函数来判断是否需要重新编译。使用编译缓存后,首次执行直接load cache,后续再次执行的时候跳过Guards函数。
缓存编译接口:
torchair.inference.cache_compile
def cache_compile(
func, # 缓存编译的method,只支持module的method
*,
config: Optional[CompilerConfig] = None,
dynamic: bool = True, # 是否按照输入动态trace
cache_dir: Optional[str] = None, # 缓存根目录,默认.torchair_cache
global_rank: Optional[int] = None, # 分布式训练时的rank,默认torch.distributed.get_rank()
tp_rank: Optional[int] = None, # 指定的tp rank
pp_rank: Optional[int] = None, # 指定的pp rank
custom_decompositions: optional[dict] = None, # 用户自定义的decompose策略
ge_cache: bool = False, # 是否开启GE缓存
**kwargs
) -> Callable:
使用约束:
- func函数只能被处罚一次Dynamo trace, 如果func发生重编译,则会放弃缓存。
- 对于发生多次trace (Guards失效)的函数,需要进行一次函数封装来使缓存生效。
- func必须是module实例对象的method,且该方法未被其他装饰器修饰
- func必须能形成整图,即必须支持full graph
- 只支持推理模式,不支持带反向计算过程的func缓存。
Ascend IR编译缓存
- 除了优化Dynamo编译耗时,还支持优化Ascend IR图编译耗时,主要通过cache_compile中的ge_cache参数实现,以进一步加速图模式启动时间。
- 缺省情况下,ge_cache=False (功能不开启),因受CANN版本变更影响,用户需根据实际情况手动开启该功能
- 缓存的编译结果文件路径与封装的func函数缓存文件路径一致
使用示例
原模型代码:
import torch, torch_npu, torchair
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
@torch.inference_mode()
def forward(self, x, y):
return x + y
适配步骤
- 提取forward的实现为_forward函数,缓存编译_forward函数
- 初始化时将被编译的函数定义为Module自身的一个属性
- forward里执行被编译的函数,也就是初始化中设置的属性
import torch, torch_npu, torchair
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.cached_forward = torchair.Inference.cache_compile(self._forward)
def forward(self, x, y):
# 修改为编译后的模型
return self.cached_forward(x, y)
def _forward(self, xy, y):
return x + y
带有transformer结构的模型中常见forward中会在prefill和decode阶段中使用不同的算子或者计算逻辑,此时需要针对每个场景新增一个函数:
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
config = torchair.CompilerConfig()
self.cached_prefill = torchair.inference.cache_compile(self.prefill, config=config)
self.cached_decode = torchair.inference.cache_compile(self.decode, config=config)
def forward(self, x, ...):
if x.size(1) > 1:
return self.cached_prefill(x, ...)
return self.cached_decode(x, ...)
def prefill(self, x, ...):
return ...
def decode(self, x, kv_cache):
return ...