# Names exported by ``from <this module> import *``.
__all__ = [
    # Capture-state and memory-pool helpers.
    "is_current_stream_capturing",
    "graph_pool_handle",
    # Task-group capture / update primitives.
    "graph_task_group_begin",
    "graph_task_group_end",
    "graph_task_update_begin",
    "graph_task_update_end",
    # High-level graph API.
    "NPUGraph",
    "graph",
    "make_graphed_callables",
    # Super-kernel scope markers.
    "super_kernel_scope_begin",
    "super_kernel_scope_end",
]
import gc
import os
import re
import threading
import typing
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, Optional, Tuple
import torch
from torch.fx.node import has_side_effect
import torch_npu._C
from torch_npu._C import _weak_ref_tensor as TensorWeakRef
from torch_npu.npu._npugraph_handlers.npugraph_handler import _NPU_GRAPH_OP_HANDLERS
from torch_npu.utils._error_code import ErrCode, pta_error
from torch_npu._compiler._config import force_npugraph_gc
from .utils import _dummy_type
if not hasattr(torch_npu._C, "_NPUStreamBase"):
    # torch_npu._C does not expose `_NPUStreamBase` (presumably a build without
    # NPU support). Install a placeholder type under every name imported from
    # torch_npu._C right below so that the import still succeeds.
    for _placeholder_name in (
        "_NPUGraph",
        "_graph_pool_handle",
        "_npu_isCurrentStreamCapturing",
        "_graph_task_group_begin",
        "_graph_task_group_end",
        "_graph_task_update_begin",
        "_graph_task_update_end",
        "_super_kernel_scope_begin",
        "_super_kernel_scope_end",
    ):
        torch_npu._C.__dict__[_placeholder_name] = _dummy_type(_placeholder_name)
    # Do not leave the loop variable behind as a module global.
    del _placeholder_name
from torch_npu._C import (
_npu_isCurrentStreamCapturing,
_NPUGraph,
_graph_pool_handle,
_graph_task_group_begin,
_graph_task_group_end,
_graph_task_update_begin,
_graph_task_update_end,
_super_kernel_scope_begin,
_super_kernel_scope_end,
)
# Module logger for NPU graph lifecycle / capture / replay messages. It is
# registered as an artifact logger under the "cudagraphs" artifact name, so it is
# presumably switched on the same way as CUDA-graph logging (e.g. TORCH_LOGS) —
# TODO confirm.
log = torch._logging.getArtifactLogger("torch_npu.npugraph", "cudagraphs")
def is_current_stream_capturing():
    r"""Report whether NPU graph capture is active on the current NPU stream.

    Returns ``False`` when no NPU context exists on the current device, and does
    so without initializing one.
    """
    capturing = _npu_isCurrentStreamCapturing()
    return capturing
def graph_pool_handle():
    r"""Create an opaque token identifying a graph memory pool.

    Pass the token as ``pool`` to :class:`graph` or
    :meth:`NPUGraph.capture_begin` to hint that captures may share memory.
    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    pool_token = _graph_pool_handle()
    return pool_token
def graph_task_group_begin(stream):
    r"""Open a task group on ``stream`` for the work captured next.

    Close it with :func:`graph_task_group_end`, which returns the group's handle.
    """
    _graph_task_group_begin(stream)
    return None
def graph_task_group_end(stream):
    r"""Close the task group opened on ``stream`` and return its handle.

    The handle is what :func:`graph_task_update_begin` takes to update the group.
    """
    group_handle = _graph_task_group_end(stream)
    return group_handle
def graph_task_update_begin(stream, handle):
    r"""Begin updating the task group identified by ``handle`` on ``stream``.

    ``handle`` is a value returned by :func:`graph_task_group_end`. Finish the
    update with :func:`graph_task_update_end`.
    """
    _graph_task_update_begin(stream, handle)
    return None
def graph_task_update_end(stream):
    r"""Finish the task-group update started on ``stream``."""
    _graph_task_update_end(stream)
    return None
def _check_scope_name(scope_name: Optional[str]) -> None:
    """Reject a ``scope_name`` that is neither ``None`` nor a non-blank string.

    Raises:
        RuntimeError: if ``scope_name`` is not a string, or is empty / whitespace.
    """
    if scope_name is None:
        return
    # The isinstance check turns a non-str value into the documented error
    # instead of an AttributeError from ``.strip()``.
    if not isinstance(scope_name, str) or not scope_name.strip():
        raise RuntimeError(
            "scope_name should be None or a non-empty string.",
            pta_error(ErrCode.PARAM),
        )


def _super_kernel_scope_begin_impl(scope_name: Optional[str] = None) -> None:
    """Validate ``scope_name``, then forward to the native ``_super_kernel_scope_begin``.

    Raises:
        RuntimeError: if ``scope_name`` is not ``None`` or a non-blank string.
    """
    _check_scope_name(scope_name)
    _super_kernel_scope_begin(scope_name)


def _super_kernel_scope_end_impl(scope_name: Optional[str] = None) -> None:
    """Validate ``scope_name``, then forward to the native ``_super_kernel_scope_end``.

    Raises:
        RuntimeError: if ``scope_name`` is not ``None`` or a non-blank string.
    """
    _check_scope_name(scope_name)
    _super_kernel_scope_end(scope_name)
# Per-path counters used to number saved files when overwrite=False, guarded
# by their own lock.
_save_npugraph_tensor_lock = threading.Lock()
_save_npugraph_tensor_counters = {}
# One side stream per device index, created lazily by _get_save_tensor_stream.
_save_tensor_streams: Dict[int, "torch_npu.npu.Stream"] = {}
_save_tensor_stream_lock = threading.Lock()
# Tags identifying the tuple specs that describe host tensors for host callbacks.
_NPUGRAPH_TENSOR_PTR_SPEC_MARKER = "host_tensor_ptr"
_NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER = "host_tensor_buffer"


def _get_save_tensor_stream(device_index: int):
    """Return the shared side stream for ``device_index``, creating it on first use.

    The lookup and creation run under ``_save_tensor_stream_lock``, so
    concurrent callers for the same device get the same stream object. New
    streams are created with ``priority=-1``.
    """
    with _save_tensor_stream_lock:
        stream = _save_tensor_streams.get(device_index)
        if stream is None:
            stream = torch.npu.Stream(device=device_index, priority=-1)
            _save_tensor_streams[device_index] = stream
        return stream
def _build_save_npugraph_tensor_path(save_path=None, device_index=None, overwrite=False):
    """Compute the file path a saved NPU-graph tensor is written to.

    Args:
        save_path: Target path (str), or ``None`` for ``tensor_<timestamp>.pt``
            in the current directory. A missing extension defaults to ``.pt``.
        device_index: When not ``None``, ``_device_<index>`` is appended to the stem.
        overwrite: When ``True``, the same name is returned on every call.
            Otherwise a per-path counter ``_<n>`` (starting at 0, shared
            process-wide) is appended, so successive calls yield distinct names.

    Raises:
        TypeError: ``save_path`` is not str/None, or ``overwrite`` is not bool.
        ValueError: ``save_path`` has no file-name component.
    """
    if save_path is None:
        save_path = "tensor_" + datetime.now().strftime("%Y%m%d_%H%M%S_%f") + ".pt"
    if not isinstance(save_path, str):
        raise TypeError(f"save_path must be str or None, but got {type(save_path).__name__}")
    if not isinstance(overwrite, bool):
        raise TypeError(f"overwrite must be bool, but got {type(overwrite).__name__}")

    parent, basename = os.path.split(save_path)
    stem, ext = os.path.splitext(basename)
    if not stem:
        raise ValueError("save_path must include a file name")
    ext = ext or ".pt"
    if device_index is not None:
        stem += f"_device_{device_index}"

    # os.path.join("", name) == name, so a bare file name needs no special case.
    if overwrite:
        return os.path.join(parent, stem + ext)

    key = os.path.join(parent, stem + ext)
    with _save_npugraph_tensor_lock:
        index = _save_npugraph_tensor_counters.get(key, 0)
        _save_npugraph_tensor_counters[key] = index + 1
    return os.path.join(parent, f"{stem}_{index}{ext}")
def _print_callback_pending(tensor_name, tensor_arg):
    """Host callback: print a tensor followed by its shape and dtype.

    ``tensor_arg`` is first passed through ``_materialize_npugraph_tensor_arg``,
    so it may be a tensor or a host-buffer spec tuple. The printed line is
    ``[<tensor_name>=]<str(tensor)>, shape=<shape tuple>, dtype=<dtype>``,
    flushed immediately.
    """
    tensor = _materialize_npugraph_tensor_arg(tensor_arg)
    # str() on purpose: an f-string would format a 0-d tensor as a bare number.
    text = str(tensor)
    prefix = "" if tensor_name is None else f"{tensor_name}="
    print(f"{prefix}{text}, shape={tuple(tensor.shape)}, dtype={tensor.dtype}", flush=True)
def _save_callback_pending(tensor_arg, str_arg):
    """Host callback: materialize ``tensor_arg`` and ``torch.save`` it to path ``str_arg``.

    ``tensor_arg`` may be a tensor, a host-buffer spec, or a list of them (see
    ``_materialize_npugraph_tensor_arg``).
    """
    materialized = _materialize_npugraph_tensor_arg(tensor_arg)
    torch.save(materialized, str_arg)
def _make_npugraph_tensor_ptr_spec(tensor):
    """Describe a host tensor by raw address for a pending host callback.

    Returns ``(_NPUGRAPH_TENSOR_PTR_SPEC_MARKER, data_ptr, nbytes, shape, dtype)``.
    Only the address is captured, not a reference to ``tensor``.
    NOTE(review): something must keep that memory alive until the callback
    reads it; verify how ``_launch_host_func_pending`` guarantees this.
    """
    nbytes = tensor.numel() * tensor.element_size()
    shape = tuple(tensor.shape)
    return (_NPUGRAPH_TENSOR_PTR_SPEC_MARKER, tensor.data_ptr(), nbytes, shape, tensor.dtype)
def _materialize_npugraph_tensor_arg(tensor_arg):
    """Turn a host-buffer spec (or a list of them) back into CPU tensors.

    A buffer spec is the 4-tuple
    ``(_NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER, buffer, shape, dtype)``. It becomes a
    tensor created by ``torch.frombuffer`` over ``buffer`` and reshaped to
    ``shape``, or ``torch.empty(shape, dtype=dtype)`` when the buffer has
    length 0. Lists are converted element-wise (recursively). Any other value
    is returned unchanged.
    """
    if isinstance(tensor_arg, list):
        return [_materialize_npugraph_tensor_arg(item) for item in tensor_arg]
    is_buffer_spec = (
        isinstance(tensor_arg, tuple)
        and len(tensor_arg) == 4
        and tensor_arg[0] == _NPUGRAPH_TENSOR_BUFFER_SPEC_MARKER
    )
    if not is_buffer_spec:
        return tensor_arg
    _, raw, shape, dtype = tensor_arg
    if len(raw) == 0:
        # torch.frombuffer raises on a zero-length buffer.
        return torch.empty(shape, dtype=dtype)
    return torch.frombuffer(raw, dtype=dtype).reshape(shape)
def _validate_tensor_list(inputs):
    """Check that ``inputs`` is a non-empty list/tuple of tensors on one device.

    Returns:
        The shared ``torch.device`` of the tensors.

    Raises:
        TypeError: ``inputs`` is not a list/tuple, or contains a non-tensor.
        ValueError: ``inputs`` is empty, or its tensors span several devices.
    """
    if not isinstance(inputs, (list, tuple)):
        raise TypeError(f"input must be Tensor or TensorList, but got {type(inputs).__name__}")
    if not inputs:
        raise ValueError("input tensor list must not be empty")
    if any(not isinstance(item, torch.Tensor) for item in inputs):
        raise TypeError("all elements in input tensor list must be Tensor")
    shared_device = inputs[0].device
    if any(item.device != shared_device for item in inputs[1:]):
        raise ValueError("all tensors in input tensor list must be on the same device")
    return shared_device
def _queue_host_callback_on_save_stream(device_index, tensors, callback, build_callback_args):
    """Copy NPU ``tensors`` to host on the per-device save stream and queue ``callback``.

    Shared by the print/save kernels below. Stream ordering:

    * the save stream first waits for work already queued on the caller's
      stream, so the copies observe the values produced so far;
    * the non-blocking device-to-host copies, and then the host callback, are
      enqueued on the save stream (assuming the pending host func is
      stream-ordered, it runs after the copies);
    * the caller's stream then waits for the save stream, so later work on it
      cannot overwrite ``tensors`` before they have been copied.

    Args:
        device_index: NPU device whose save stream is used.
        tensors: sequence of NPU tensors to copy.
        callback: host function handed to ``torch_npu.npu._launch_host_func_pending``.
        build_callback_args: maps the list of host pointer specs (one per tensor,
            same order) to the argument tuple passed to ``callback``.
    """
    save_stream = _get_save_tensor_stream(device_index)
    producers_done = torch.npu.Event()
    save_done = torch.npu.Event()
    producers_done.record()
    with torch.npu.stream(save_stream):
        producers_done.wait()
        host_tensors = [tensor.to("cpu", non_blocking=True) for tensor in tensors]
        # NOTE(review): only raw addresses of host_tensors reach the callback;
        # confirm the launcher keeps that host memory alive until it runs.
        host_specs = [_make_npugraph_tensor_ptr_spec(host) for host in host_tensors]
        # Inside this context current_stream() is the save stream.
        torch_npu.npu._launch_host_func_pending(
            torch.npu.current_stream(),
            callback,
            build_callback_args(host_specs),
        )
        save_done.record()
    save_done.wait()


def _print_npugraph_tensor_impl(input, tensor_name=None):
    """Kernel for ``npu::print_npugraph_tensor``.

    CPU tensors are printed immediately. NPU tensors are copied to host on a
    side stream and printed by a host callback queued there. Any other input
    (non-tensor, or another device type) is ignored.
    """
    if not isinstance(input, torch.Tensor):
        return
    device = input.device
    if device.type == "cpu":
        _print_callback_pending(tensor_name, input)
        return
    if device.type != "npu":
        return
    _queue_host_callback_on_save_stream(
        device.index,
        [input],
        _print_callback_pending,
        lambda specs: (tensor_name, specs[0]),
    )


def _save_npugraph_tensor_impl(input, save_path=None, overwrite=False):
    """Kernel for ``npu::save_npugraph_tensor``.

    CPU tensors are ``torch.save``-d immediately. NPU tensors are copied to host
    on a side stream and saved by a host callback queued there, to a path with
    ``_device_<index>`` appended (see ``_build_save_npugraph_tensor_path``).
    Other inputs are ignored.
    """
    if not isinstance(input, torch.Tensor):
        return
    device = input.device
    if device.type == "cpu":
        torch.save(input, _build_save_npugraph_tensor_path(save_path, overwrite=overwrite))
        return
    if device.type != "npu":
        return
    # Resolve (and validate) the path before touching any stream.
    final_path = _build_save_npugraph_tensor_path(save_path, device.index, overwrite)
    _queue_host_callback_on_save_stream(
        device.index,
        [input],
        _save_callback_pending,
        lambda specs: (specs[0], final_path),
    )


def _save_npugraph_tensor_tensor_list_impl(input, save_path=None, overwrite=False):
    """Kernel for ``npu::save_npugraph_tensor.tensor_list``.

    ``input`` must be a non-empty list/tuple of tensors on one device (validated
    by ``_validate_tensor_list``). The whole list is saved as a single
    ``torch.save`` payload, either immediately (CPU) or from a host callback
    on a side stream (NPU). Other device types are ignored.
    """
    device = _validate_tensor_list(input)
    if device.type == "cpu":
        torch.save(list(input), _build_save_npugraph_tensor_path(save_path, overwrite=overwrite))
        return
    if device.type != "npu":
        return
    final_path = _build_save_npugraph_tensor_path(save_path, device.index, overwrite)
    _queue_host_callback_on_save_stream(
        device.index,
        input,
        _save_callback_pending,
        lambda specs: (specs, final_path),
    )
# FRAGMENT library used to add the Python-implemented ops below to the "npu"
# operator namespace. It is kept as a module-level reference because torch.library
# drops a Library's registrations when the object is destroyed.
_npu_lib = torch.library.Library("npu", "FRAGMENT")
# Define npu::super_kernel_scope_begin only if no op of that name exists yet
# (presumably to tolerate re-import or a native definition — confirm).
if not hasattr(torch.ops.npu, "super_kernel_scope_begin"):
    _npu_lib.define("super_kernel_scope_begin(str? scope_name=None) -> ()")
    # The op has no tensor arguments; the same Python kernel is bound to the
    # PrivateUse1, BackendSelect and CPU dispatch keys.
    _npu_lib.impl("super_kernel_scope_begin", _super_kernel_scope_begin_impl, "PrivateUse1")
    _npu_lib.impl("super_kernel_scope_begin", _super_kernel_scope_begin_impl, "BackendSelect")
    _npu_lib.impl("super_kernel_scope_begin", _super_kernel_scope_begin_impl, "CPU")
    # Mark as side-effecting so FX dead-code elimination keeps the call even
    # though it returns nothing.
    has_side_effect(torch.ops.npu.super_kernel_scope_begin.default)

    @torch.library.register_fake("npu::super_kernel_scope_begin")
    def _super_kernel_scope_begin_meta(scope_name: Optional[str] = None):
        # Fake kernel for tracing: no outputs, nothing to compute.
        pass
# Define npu::super_kernel_scope_end only if no op of that name exists yet.
if not hasattr(torch.ops.npu, "super_kernel_scope_end"):
    _npu_lib.define("super_kernel_scope_end(str? scope_name=None) -> ()")
    # Same Python kernel bound to the PrivateUse1, BackendSelect and CPU keys.
    _npu_lib.impl("super_kernel_scope_end", _super_kernel_scope_end_impl, "PrivateUse1")
    _npu_lib.impl("super_kernel_scope_end", _super_kernel_scope_end_impl, "BackendSelect")
    _npu_lib.impl("super_kernel_scope_end", _super_kernel_scope_end_impl, "CPU")
    # Keep the call through FX dead-code elimination despite the empty return.
    has_side_effect(torch.ops.npu.super_kernel_scope_end.default)

    @torch.library.register_fake("npu::super_kernel_scope_end")
    def _super_kernel_scope_end_meta(scope_name: Optional[str] = None):
        # Fake kernel for tracing: no outputs, nothing to compute.
        pass
# Define npu::print_npugraph_tensor (tensor_name is keyword-only) only if it
# does not exist yet.
if not hasattr(torch.ops.npu, "print_npugraph_tensor"):
    _npu_lib.define("print_npugraph_tensor(Tensor input, *, str? tensor_name=None) -> ()")
    # Python kernel for NPU (PrivateUse1) and CPU tensors.
    _npu_lib.impl("print_npugraph_tensor", _print_npugraph_tensor_impl, "PrivateUse1")
    _npu_lib.impl("print_npugraph_tensor", _print_npugraph_tensor_impl, "CPU")
    # Keep the call through FX dead-code elimination despite the empty return.
    has_side_effect(torch.ops.npu.print_npugraph_tensor.default)

    @torch.library.register_fake("npu::print_npugraph_tensor")
    def _print_npugraph_tensor_meta(input, *, tensor_name=None):
        # Fake kernel for tracing: no outputs, nothing to compute.
        pass
# Define npu::save_npugraph_tensor with two overloads — a single tensor
# (.default) and a tensor list (.tensor_list) — only if it does not exist yet.
if not hasattr(torch.ops.npu, "save_npugraph_tensor"):
    _npu_lib.define("save_npugraph_tensor(Tensor input, *, str? save_path=None, bool overwrite=False) -> ()")
    _npu_lib.define("save_npugraph_tensor.tensor_list(Tensor[] input, *, str? save_path=None, bool overwrite=False) -> ()")
    # Python kernels for NPU (PrivateUse1) and CPU inputs, one per overload.
    _npu_lib.impl("save_npugraph_tensor", _save_npugraph_tensor_impl, "PrivateUse1")
    _npu_lib.impl("save_npugraph_tensor", _save_npugraph_tensor_impl, "CPU")
    _npu_lib.impl("save_npugraph_tensor.tensor_list", _save_npugraph_tensor_tensor_list_impl, "PrivateUse1")
    _npu_lib.impl("save_npugraph_tensor.tensor_list", _save_npugraph_tensor_tensor_list_impl, "CPU")
    # Both overloads write files and return nothing: keep them through FX DCE.
    has_side_effect(torch.ops.npu.save_npugraph_tensor.default)
    has_side_effect(torch.ops.npu.save_npugraph_tensor.tensor_list)

    @torch.library.register_fake("npu::save_npugraph_tensor")
    def _save_npugraph_tensor_meta(input, *, save_path=None, overwrite=False):
        # Fake kernel for tracing: no outputs, nothing to compute.
        pass

    @torch.library.register_fake("npu::save_npugraph_tensor.tensor_list")
    def _save_npugraph_tensor_tensor_list_meta(input, *, save_path=None, overwrite=False):
        # Fake kernel for tracing: no outputs, nothing to compute.
        pass
def super_kernel_scope_begin(scope_name: Optional[str] = None):
    r"""Mark the start of a super-kernel scope.

    Dispatches to the ``npu::super_kernel_scope_begin`` operator and returns its
    result. Pair with :func:`super_kernel_scope_end`.
    """
    begin_op = torch.ops.npu.super_kernel_scope_begin
    return begin_op(scope_name)
def super_kernel_scope_end(scope_name: Optional[str] = None):
    r"""Mark the end of a super-kernel scope.

    Dispatches to the ``npu::super_kernel_scope_end`` operator and returns its
    result. Pair with :func:`super_kernel_scope_begin`.
    """
    end_op = torch.ops.npu.super_kernel_scope_end
    return end_op(scope_name)
@dataclass
class _GraphDispatchRecord:
    """Record of a single dispatched operator call during graph capture.

    Fields (in constructor order):
        event: event created at capture time; the captured stream waits on it
            before the op, and it is recorded again after each update.
        handle: task-group handle returned by ``graph_task_group_end``.
        kwargs: keyword arguments, each wrapped by the op handler's
            ``record_wrap_kwarg``.
        args: positional arguments (NPU tensors as weak references, everything
            else deep-copied). Defaults to an empty tuple, but records built
            during capture store a list, hence the ``Sequence`` annotation.
        op_cache_entry: the op that was actually invoked; it is called again on update.
    """
    event: Any = None
    handle: Any = None
    kwargs: Dict[str, Any] = field(default_factory=dict)
    args: typing.Sequence[Any] = field(default_factory=tuple)
    op_cache_entry: Any = None
class _GraphDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Template-method skeleton for NPU Graph capture and update.

    The skeleton keeps the common stream / event / task-group orchestration
    and delegates operator-specific logic to
    :class:`~torch_npu.npu.NpuGraphOpHandler` classes
    looked up in ``_NPU_GRAPH_OP_HANDLERS``.

    During capture (``__torch_dispatch__``), every op with a registered handler
    runs in its own task group and a :class:`_GraphDispatchRecord` is kept.
    :meth:`update_capture_record` later calls those ops again with new keyword
    arguments between ``graph_task_update_begin`` and ``graph_task_update_end``.
    """

    # op name -> names of its Tensor-typed parameters, parsed from the schema.
    # Class attribute: shared by all instances.
    tensor_schema_name = {}
    # Side stream for updates; created on first instantiation and shared.
    update_stream = None
    # Matches "Tensor name", "Tensor? name" and "Tensor(a!) name" in a schema.
    _TENSOR_PARAM_PATTERN = re.compile(r'Tensor(?:\(a!\)|\?)?\s+(\w+)')

    def __new__(cls):
        if cls.update_stream is None:
            cls.update_stream = torch_npu.npu.Stream()
        return super().__new__(cls)

    def __init__(self):
        # Initialize TorchDispatchMode's own state explicitly instead of relying
        # on its lazy initialization.
        super().__init__()
        self.graph_dispatch_records = []

    @classmethod
    def is_infra_mode(cls):
        return True

    @classmethod
    def update_schema(cls, name, schema):
        """Cache the Tensor parameter names of op ``name`` parsed from ``schema``.

        The first schema seen for a given name wins; later calls are no-ops.
        """
        if name not in cls.tensor_schema_name:
            cls.tensor_schema_name[name] = cls._TENSOR_PARAM_PATTERN.findall(schema)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # kwargs may arrive as None; normalize so handlers and ``**`` work.
        kwargs = {} if kwargs is None else kwargs
        handler_cls = _NPU_GRAPH_OP_HANDLERS.get(func.__name__)
        if not handler_cls:
            return func(*args, **kwargs)
        if hasattr(handler_cls, 'should_handle') and not handler_cls.should_handle(func, args, kwargs):
            return func(*args, **kwargs)

        stream = torch_npu.npu.current_stream()
        # The captured stream waits on this event before the op;
        # update_capture_record() records it again after patching the op.
        # NOTE(review): reset() presumably returns the event to its initial
        # state for capture — confirm ExternalEvent semantics.
        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        actual_func, args, kwargs = handler_cls.prepare_capture(func, args, kwargs)
        self.update_schema(str(actual_func.__name__), str(actual_func._schema))
        # Give the op its own task group so it can be updated independently.
        graph_task_group_begin(stream)
        result = actual_func(*args, **kwargs)
        handle = graph_task_group_end(stream)
        self.graph_dispatch_records.append(
            self._append_dispatch_record(event, handle, args, kwargs, actual_func, handler_cls)
        )
        return handler_cls.postprocess_result(result, kwargs)

    def _append_dispatch_record(
        self, event, handle, args, kwargs, func, handler_cls
    ):
        """Create a dispatch record, converting args / kwargs to weak-refs or deep copies.

        ``handler_cls`` is a required parameter (class object) guaranteed by
        the capture skeleton. Despite the name, the record is returned; the
        caller appends it to ``graph_dispatch_records``.
        """
        # NPU tensors are stored as weak references (presumably so records do
        # not keep capture-time tensors alive); everything else is deep-copied.
        args_ref = [
            TensorWeakRef(element)
            if torch.is_tensor(element) and "npu" in str(element.device)
            else deepcopy(element)
            for element in args
        ]
        tensor_param_names = self.tensor_schema_name.get(str(func.__name__), [])
        kwargs_ref = {
            key: handler_cls.record_wrap_kwarg(key, value, tensor_param_names)
            for key, value in kwargs.items()
        }
        return _GraphDispatchRecord(
            event=event,
            handle=handle,
            kwargs=kwargs_ref,
            args=args_ref,
            op_cache_entry=func,
        )

    def update_capture_record(self, cpu_update_input):
        """Call every captured op again with new keyword arguments.

        Args:
            cpu_update_input: list of dicts, one per captured record in capture
                order, mapping kwarg name -> new value. A one-element list is
                applied (as shallow copies) to every record. Keys the record
                does not have are ignored when overriding kwargs, but are still
                passed to the handler's ``update_args``.

        Raises:
            RuntimeError: if the number of dicts does not match the number of
                records, or a recorded op has no registered handler. Both checks
                run before any task update begins, so they leave every record
                untouched.
        """
        records = self.graph_dispatch_records
        if len(cpu_update_input) == 1:
            cpu_update_input = [cpu_update_input[0].copy() for _ in records]
        if len(cpu_update_input) != len(records):
            raise RuntimeError(
                f"Currently, there are {len(self.graph_dispatch_records)} "
                f"operators that need to be updated by capture, and there "
                f"are only {len(cpu_update_input)} elements in the incoming "
                f"cpu_update_input list",
                pta_error(ErrCode.PARAM),
            )
        # Resolve all handlers first. Raising between graph_task_update_begin and
        # graph_task_update_end would leave an update open on the stream, and
        # raising mid-loop would leave earlier records already updated.
        handlers = []
        for record in records:
            handler_cls = _NPU_GRAPH_OP_HANDLERS.get(record.op_cache_entry.__name__)
            if handler_cls is None:
                raise RuntimeError(
                    f"No handler for recorded op: {record.op_cache_entry.__name__}. "
                    f"This indicates the handler was unregistered between capture and update.",
                    pta_error(ErrCode.PARAM),
                )
            handlers.append(handler_cls)
        with torch.npu.stream(self.update_stream):
            for record, update_input, handler_cls in zip(records, cpu_update_input, handlers):
                graph_task_update_begin(self.update_stream, record.handle)
                for key in update_input:
                    if key in record.kwargs:
                        record.kwargs[key] = update_input[key]
                handler_cls.update_args(record, update_input)
                record.op_cache_entry(*record.args, **record.kwargs)
                graph_task_update_end(self.update_stream)
                # Signal the event the captured graph waits on for this op.
                record.event.record(self.update_stream)
class NPUGraph(torch_npu._C._NPUGraph):
    r"""Wrapper around a NPU graph.

    Besides the native capture/replay API, each instance owns a
    ``_GraphDispatchMode``. When the graph is captured with
    ``auto_dispatch_capture=True`` (see :class:`graph`), the recorded ops can
    later be re-parameterized with :meth:`update`.

    .. warning::
        This API is in beta and may change in future releases.
    """

    # Accepted keys and value types for super_kernel_optimize(). Built once at
    # class creation instead of on every validation call.
    _VALID_OPTIONS = {
        'optimize_options': {
            'preload_code': {'value_type': int},
            'split_mode': {'value_type': int},
            'stream_fusion': {'value_type': int},
            'constant_codegen': {'value_type': int},
            'auto_op_parallel': {'value_type': int},
            'opt_extend': {'value_type': str},
            'dcci_before_kernel_start': {'value_type': list, 'element_type': str},
            'dcci_after_kernel_end': {'value_type': list, 'element_type': str},
            'dcci_disable_on_kernel': {'value_type': list, 'element_type': str},
            'aggressive_opt_strategies': {
                'value_type': dict,
                'sub_options': {
                    'event_breaker_bypass': {'value_type': int},
                    'value_breaker_bypass': {'value_type': int},
                    'task_breaker_bypass': {'value_type': int},
                },
            },
            'ubuf_lock_ignore_kernel': {'value_type': list, 'element_type': str},
            'early_start': {'value_type': int},
        },
        'debug_options': {
            'debug_sync_all': {'value_type': int},
            'debug_op_exec_trace': {'value_type': int},
            'debug_cross_core_sync_check': {'value_type': int},
            'debug_extend': {'value_type': str},
            'debug_per_op_max_core_num': {'value_type': int},
        },
    }

    def __new__(cls):
        return super().__new__(cls)

    def __init__(self):
        self.graph_dispatch_mode = _GraphDispatchMode()
        # Set to True by the ``graph`` context manager when requested; update()
        # checks it.
        self.auto_dispatch_capture = False
        log.debug("NPUGRAPH Lifecycle NPUGraph created, graph_id=%s, auto_dispatch=%s",
                  id(self), self.auto_dispatch_capture)
        super().__init__()

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Begin capturing NPU work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.npu.graph` or :func:`~torch.npu.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.npu.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.npu.NPUGraph.pool>`) that hints this graph may share memory
                with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the aclmdlRICaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During npu graph capture, some actions, such as npuMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `aclmdlRICaptureMode`_
        """
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode, report_shape=True)

    def capture_end(self):
        r"""End NPU graph capture on the current stream.

        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.npu.graph` or :func:`~torch.npu.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the NPU work captured by this graph."""
        log.debug("NPUGRAPH Replay graph_id=%s", id(self))
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        log.debug("NPUGRAPH Lifecycle NPUGraph reset, graph_id=%s", id(self))
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def update(self, cpu_update_input):
        r"""Re-parameterize the ops recorded during an ``auto_dispatch_capture`` capture.

        Arguments:
            cpu_update_input (list[dict]): one dict of new keyword arguments per
                recorded op, in capture order, or a one-element list whose dict
                is applied to every recorded op.

        Raises:
            RuntimeError: if the graph was not captured with
                ``auto_dispatch_capture=True``, or the input does not match the
                recorded ops.
        """
        log.debug("NPUGraph: updating graph (%s inputs)...", len(cpu_update_input))
        if not self.auto_dispatch_capture:
            raise RuntimeError(
                "The current graph configuration does not support update. "
                "Try to capture by setting auto_dispatch_capture=True during capture",
                pta_error(ErrCode.PARAM),
            )
        self.graph_dispatch_mode.update_capture_record(cpu_update_input)

    def debug_dump(self, debug_path):
        r"""Calls a function to dump the graph in JSON format.

        Arguments:
            debug_path (required): Path to dump the graph to.
        """
        return super().debug_dump(debug_path)

    def super_kernel_optimize(self, optimize_options=None, debug_options=None):
        r"""Calls a function to optimize graph by super kernel.

        Both option dicts are validated against ``_VALID_OPTIONS`` first.

        Arguments:
            optimize_options (optional):
                preload_code (int) - Controls code preloading strategy.
                split_mode (int) - Controls kernel splitting mode for better performance.
                stream_fusion (int) - Enables/disables stream fusion optimization.
                constant_codegen (int) - Enables/disables constant code generation.
                auto_op_parallel (int) - Enables/disables auto op parallel optimization.
                opt_extend (str) - Extended optimization option string.
                dcci_before_kernel_start (list[str]) - List of kernel names to insert DCCI before kernel start.
                dcci_after_kernel_end (list[str]) - List of kernel names to insert DCCI after kernel end.
                dcci_disable_on_kernel (list[str]) - List of kernel names to disable DCCI.
                aggressive_opt_strategies (dict) - Aggressive optimization strategies:
                    event_breaker_bypass (int) - Keep fusing across event waits
                        (by default, fusion happens when the deadlock check passes).
                    value_breaker_bypass (int) - Fusion strategy for value waits.
                    task_breaker_bypass (int) - Keep fusing across unsupported operators
                        (by default, fusion is broken there).
                ubuf_lock_ignore_kernel (list[str]) - List of kernel names; semantics are defined
                    by the native optimizer (not documented here).
                early_start (int) - Integer flag; semantics are defined by the native optimizer.
            debug_options (optional):
                debug_sync_all (int) - Enables debug synchronization for all operations.
                debug_op_exec_trace (int) - Enables/disables op execution trace.
                debug_cross_core_sync_check (int) - Enables/disables cross core sync check.
                debug_extend (str) - Extended debug option string.
                debug_per_op_max_core_num (int) - Integer; semantics are defined by the native optimizer.
        """
        self._validate_options("optimize_options", optimize_options)
        self._validate_options("debug_options", debug_options)
        return super().super_kernel_optimize(optimize_options, debug_options)

    def _validate_options(self, option_name, options=None):
        """Check ``options`` against ``_VALID_OPTIONS[option_name]``.

        ``None`` is accepted as-is. Otherwise ``options`` must be a dict whose
        keys are known for ``option_name`` and whose values have the expected
        type. List values are also checked element-wise, and dict values are
        checked against their allowed sub-keys.

        Raises:
            RuntimeError: on an unknown option group/key or a type mismatch.
        """
        if options is None:
            return
        if not isinstance(options, dict):
            raise RuntimeError(f"{option_name} param must be dict or None.", pta_error(ErrCode.PARAM))
        schema = self._VALID_OPTIONS.get(option_name)
        if schema is None:
            raise RuntimeError(f"Invalid {option_name} param.", pta_error(ErrCode.PARAM))
        for key, value in options.items():
            spec = schema.get(key)
            if spec is None:
                raise RuntimeError(f"Invalid {option_name} param key: '{key}'.", pta_error(ErrCode.PARAM))
            expected_type = spec['value_type']
            if not isinstance(value, expected_type):
                raise RuntimeError(f"{option_name} param['{key}'] must be {expected_type.__name__}, "
                                   f"got {type(value).__name__}", pta_error(ErrCode.PARAM))
            if expected_type is dict and 'sub_options' in spec:
                sub_schema = spec['sub_options']
                for sub_key, sub_value in value.items():
                    if sub_key not in sub_schema:
                        raise RuntimeError(f"Invalid sub_key '{sub_key}' in {option_name}['{key}'].",
                                           pta_error(ErrCode.PARAM))
                    sub_expected_type = sub_schema[sub_key]['value_type']
                    if not isinstance(sub_value, sub_expected_type):
                        raise RuntimeError(f"{option_name}['{key}']['{sub_key}'] must be {sub_expected_type.__name__}, "
                                           f"got {type(sub_value).__name__}", pta_error(ErrCode.PARAM))
            if expected_type is list and 'element_type' in spec:
                element_type = spec['element_type']
                for i, elem in enumerate(value):
                    if not isinstance(elem, element_type):
                        raise RuntimeError(f"{option_name}['{key}'][{i}] must be {element_type.__name__}, "
                                           f"got {type(elem).__name__}", pta_error(ErrCode.PARAM))
class graph:
    r"""Context-manager that captures NPU work into a :class:`torch.npu.NPUGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        npu_graph (torch.npu.NPUGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.npu.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.npu.NPUGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.npu.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        auto_dispatch_capture (bool, optional): If ``True``, capture runs under the graph's
            dispatch mode, so ops with registered NPU graph handlers are recorded and can
            later be re-parameterized with :meth:`NPUGraph.update`. Default: ``False``.
        capture_error_mode (str, optional): specifies the aclmdlRICaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During npu graph capture, some actions, such as npuMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
            unless you're familiar with `aclmdlRICaptureMode`_

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.
    """

    default_capture_stream: typing.Optional["torch.npu.Stream"] = None

    def __init__(
        self,
        npu_graph,
        pool=None,
        stream=None,
        auto_dispatch_capture=False,
        capture_error_mode: str = "global",
    ):
        # Lazily create one side stream shared by every capture that does not
        # pass ``stream``.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.npu.Stream()
        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        if self.capture_stream is None:
            raise RuntimeError("capture stream is None")
        self.stream_ctx = torch.npu.stream(self.capture_stream)
        self.npu_graph = npu_graph
        self.capture_error_mode = capture_error_mode
        # Stored on the graph itself so NPUGraph.update() can check it later.
        self.npu_graph.auto_dispatch_capture = auto_dispatch_capture

    def __enter__(self):
        log.debug("NPUGRAPH Capture device=%s, pool=%s, mode=%s, auto_dispatch=%s",
                  torch.npu.current_device(), self.pool, self.capture_error_mode,
                  self.npu_graph.auto_dispatch_capture)
        # Free as much memory as possible before capture; the (expensive)
        # gc.collect() runs only when force_npugraph_gc is set.
        torch.npu.synchronize()
        if force_npugraph_gc:
            gc.collect()
        torch.npu.empty_cache()
        torch_npu.npu.host_empty_cache()

        self.stream_ctx.__enter__()
        dispatch_mode_entered = False
        try:
            if self.npu_graph.auto_dispatch_capture:
                self.npu_graph.graph_dispatch_mode.__enter__()
                dispatch_mode_entered = True
            self.npu_graph.capture_begin(
                *self.pool, capture_error_mode=self.capture_error_mode
            )
        except BaseException as err:
            # __exit__ does not run when __enter__ raises, so undo what was
            # entered here; otherwise the capture stream / dispatch mode leak.
            if dispatch_mode_entered:
                self.npu_graph.graph_dispatch_mode.__exit__(type(err), err, err.__traceback__)
            self.stream_ctx.__exit__(type(err), err, err.__traceback__)
            raise

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            log.error("NPUGraph: ERROR — capture failed: %s: %s, "
                      "graph_id=%s, stream=%s, pool=%s, device=%s",
                      exc_type.__name__, exc_value,
                      id(self.npu_graph), self.capture_stream,
                      self.pool, torch.npu.current_device())
            log.warning("NPUGRAPH DFX common causes: default stream, CPU ops, multi-device tensors, "
                        "dynamic control flow on tensor values")
        else:
            log.debug("NPUGraph: capture completed, graph_id=%s, device=%s",
                      id(self.npu_graph), torch.npu.current_device())
        # capture_end runs even when the body raised, to end capture on the
        # stream. Unwind in reverse order of __enter__; the finally clauses make
        # sure the dispatch mode and stream context are exited even if
        # capture_end itself raises. Exceptions are never suppressed.
        try:
            self.npu_graph.capture_end()
        finally:
            try:
                if self.npu_graph.auto_dispatch_capture:
                    self.npu_graph.graph_dispatch_mode.__exit__(exc_type, exc_value, traceback)
            finally:
                self.stream_ctx.__exit__(exc_type, exc_value, traceback)
def make_graphed_callables(
callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
Each graphed callable's forward pass runs its source callable's
forward CUDA work as a CUDA graph inside a single autograd node.
The graphed callable's forward pass also appends
a backward node to the autograd graph. During backward, this node runs the
callable's backward work as a CUDA graph.
Therefore, each graphed callable should be a drop-in replacement for its source callable
in an autograd-enabled training loop.
See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
If you pass a tuple of several callables, their captures will use the same memory pool.
See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
Arguments:
callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
they'll run in the live workload.
sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
11 iterations for warm up. Default: ``3``.
allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
(and therefore their grad is always zero) is an error. Defaults to False.
pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
.. note::
The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
that's expected for the corresponding real input in the training loop.
.. warning::
This API is in beta and may change in future releases.
.. warning::
``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
.. warning::
Returned callables do not support higher order differentiation (e.g., double backward).
.. warning::
In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
may be trainable. Buffers must have ``requires_grad=False``.
.. warning::
After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
you may not add or remove any of that Module's parameters or buffers.
.. warning::
:class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
registered on them at the time they are passed. However, registering hooks on modules *after* passing them
through :func:`~torch.cuda.make_graphed_callables` is allowed.
.. warning::
When running a graphed callable, you must pass its arguments in the same order and format
they appeared in that callable's ``sample_args``.
.. warning::
The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
"""
if torch_npu.npu.is_autocast_enabled() and torch.is_autocast_cache_enabled():
log.error("NPUGraph: ERROR — autocast caching not supported")
log.warning("NPUGRAPH DFX Fix: set cache_enabled=False in torch.autocast()")
raise RuntimeError(
"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
)
just_one_callable = False
if not isinstance(callables, tuple):
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
flatten_sample_args = []
for c, args in zip(callables, sample_args):
if isinstance(c, torch.nn.Module):
if len(c._backward_hooks) > 0 or len(c._forward_hooks) > 0 or len(c._forward_pre_hooks) > 0:
log.error("NPUGraph: ERROR — hooks registered on module %s", type(c).__name__)
log.warning("NPUGRAPH DFX Fix: remove hooks before make_graphed_callables, re-register after")
raise RuntimeError(
"Modules must not have hooks registered at the time they are passed. However, "
+ "registering hooks on modules after passing them through make_graphed_callables is allowed."
)
if any(b.requires_grad for b in c.buffers()):
raise RuntimeError(
"In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`,"
+ " only parameters may be trainable. All buffers must have ``requires_grad=False``."
)
flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
flatten_sample_args.append(tuple(flatten_arg))
if not all(isinstance(arg, torch.Tensor) for arg in flatten_arg):
raise RuntimeError(
"In the beta API, sample_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
)
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
per_callable_module_params = [
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
for c in callables
]
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(callables))
]
fwd_graphs = [torch_npu.npu.NPUGraph() for _ in range(len(callables))]
bwd_graphs = [torch_npu.npu.NPUGraph() for _ in range(len(callables))]
mempool = graph_pool_handle() if pool is None else pool
torch_npu.npu.synchronize()
with torch_npu.npu.stream(torch_npu.npu.Stream()):
for func, args, static_input_surface in zip(
callables, sample_args, per_callable_static_input_surfaces
):
grad_inputs, outputs, outputs_grad = None, None, None
for _ in range(num_warmup_iters):
outputs = torch.utils._pytree.tree_leaves(func(*args))
outputs_grad = tuple(o for o in outputs if o.requires_grad)
if len(outputs_grad) > 0:
grad_inputs = torch.autograd.grad(
outputs=outputs_grad,
inputs=tuple(
i for i in static_input_surface if i.requires_grad
),
grad_outputs=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
only_inputs=True,
allow_unused=allow_unused_input,
)
for v in [outputs, outputs_grad, grad_inputs]:
del v
torch_npu.npu.synchronize()
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
with torch_npu.npu.graph(fwd_graph, pool=mempool):
outputs = func(*args)
flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
per_callable_static_outputs.append(tuple(flatten_outputs))
per_callable_output_unflatten_spec.append(spec)
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph, module_params in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
reversed(per_callable_module_params),
):
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
grad_inputs = None
if len(outputs_grad) > 0:
with torch_npu.npu.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=outputs_grad,
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
)
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad and grad_inputs is not None:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None)
static_grad_inputs = tuple(static_grad_inputs)
per_callable_static_grad_outputs.append(static_grad_outputs)
per_callable_static_grad_inputs.append(static_grad_inputs)
per_callable_static_grad_outputs.reverse()
per_callable_static_grad_inputs.reverse()
def make_graphed_autograd_function(
fwd_graph,
bwd_graph,
module_params,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
static_grad_outputs,
static_grad_inputs,
):
class Graphed(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs):
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
fwd_graph.replay()
if not isinstance(static_outputs, tuple):
raise RuntimeError("static_outputs is not tuple.")
return tuple(o.detach() for o in static_outputs)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
if (len(grads) != len(static_grad_outputs)):
raise RuntimeError(
"The length of grads"
+ " is not equal with the length of static_grad_outputs."
)
for g, grad in zip(static_grad_outputs, grads):
if g is not None:
if g.data_ptr() != grad.data_ptr():
g.copy_(grad)
bwd_graph.replay()
if not isinstance(static_grad_inputs, tuple):
raise RuntimeError("static_grad_inputs is not tuple.")
return tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)
def functionalized(*user_args):
flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
return functionalized
ret = []
for i, func in enumerate(callables):
graphed = make_graphed_autograd_function(
fwd_graphs[i],
bwd_graphs[i],
per_callable_module_params[i],
per_callable_len_user_args[i],
per_callable_output_unflatten_spec[i],
per_callable_static_input_surfaces[i],
per_callable_static_outputs[i],
per_callable_static_grad_outputs[i],
per_callable_static_grad_inputs[i],
)
if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
    """Build a forward that dispatches on the module's current training mode.

    Args:
        func: object whose ``training`` attribute is read on every call.
        graph_training_state: the ``training`` value for which ``graphed``
            should be used.
        graphed: callable used when ``func.training`` matches.
        orig_fwd: callable used otherwise.

    Returns:
        ``new_fwd(*user_args)``, which forwards its positional arguments to
        whichever of the two callables is selected at call time.
    """
    def new_fwd(*user_args):
        # Re-check on every call so toggling train()/eval() after graphing
        # falls back to the original forward.
        target = graphed if func.training == graph_training_state else orig_fwd
        return target(*user_args)
    return new_fwd
func.forward = make_graphed_forward(func, func.training, graphed, func.forward)
ret.append(func)
else:
ret.append(graphed)
log.debug("NPUGraph: graphed callables ready")
if just_one_callable:
return ret[0]
return tuple(ret)