"""
Graph compilation utilities for NPU inference.
"""
import os
import logging
import torch
from torchair.configs.compiler_config import CompilerConfig
def compile_model_forward(
model_forward,
exe_mode="ge_graph",
enable_cache_compile=False,
cache_dir=None,
frozen_parameter=True,
tiling_schedule_optimize=True,
topology_sorting_strategy="StableRDFS",
enable_static_kernel=False,
):
"""
Compile a model.forward method for graph execution.
Args:
model_forward: The model.forward callable to compile.
exe_mode: Execution mode ("ge_graph", "npugraph_ex", etc.).
enable_cache_compile: Whether to use cache compilation.
cache_dir: Directory for cache compilation.
frozen_parameter: Whether to freeze parameters.
tiling_schedule_optimize: Whether to enable tiling schedule optimization.
topology_sorting_strategy: Strategy for topology sorting.
enable_static_kernel: Whether to enable static kernel compile for npugraph_ex.
Returns:
Compiled forward function.
"""
import torchair as tng
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
tng.patch_for_hcom()
torch._dynamo.config.inline_inbuilt_nn_modules = False
if exe_mode == "npugraph_ex":
compile_options = {
"frozen_parameter": True,
"static_kernel_compile": enable_static_kernel,
}
if enable_cache_compile:
compiled = torch.npu.npugraph_ex.inference.cache_compile(model_forward, cache_dir=cache_dir,
dynamic=True, options=compile_options)
else:
compiled = torch.compile(model_forward, dynamic=True, fullgraph=True, backend="npugraph_ex",
options=compile_options)
return compiled
compiler_config = CompilerConfig()
compiler_config.experimental_config.frozen_parameter = frozen_parameter
compiler_config.experimental_config.tiling_schedule_optimize = tiling_schedule_optimize
compiler_config.experimental_config.topology_sorting_strategy = topology_sorting_strategy
if enable_cache_compile:
if cache_dir is None:
raise ValueError("cache_dir must be provided when enable_cache_compile=True")
compiled = tng.inference.cache_compile(
model_forward,
cache_dir=cache_dir,
config=compiler_config,
dynamic=False,
fullgraph=True,
ge_cache=True
)
else:
npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
compiled = torch.compile(model_forward, dynamic=False, fullgraph=True, backend=npu_backend)
return compiled