(beta)torch_npu.jit.optimize

产品支持情况

产品 是否支持
Atlas A3 训练系列产品
Atlas A2 训练系列产品
Atlas 推理系列产品
Atlas 训练系列产品

功能说明

主要用于优化ScriptFunction或ScriptModule,以获取更好的性能。

函数原型

torch_npu.jit.optimize(jit_mod)

参数说明

jit_mod:必选参数。用于被优化的ScriptFunction或ScriptModule。

调用示例

import torch
import torch_npu
from torch_npu import jit

class SimpleModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        return torch.relu(z)

model = SimpleModel().eval()
traced_model = torch.jit.trace(model, (torch.rand(1, 3), torch.rand(1, 3)))

torch_npu.jit.optimize(traced_model)