"""
NPU graph trees are a safety abstraction over NPUGraphs, similar to make_graph_callables,
which share the same memory pool. Sharing a memory pool is an extremely
important optimization when chaining multiple NPU graphs together, as it
prevents you from needing to copy intermediate tensors from one graph to the
next, and reduces overall memory usage by allowing dead memory from the first
pool to be reused in the second.
The standard graph/make_graph_callables support sharing memory pool, but
with a lot of caveats. NPU graph trees remove these restrictions:
* Previously, if you recorded graphs A, B, you had to replay A, B in that
order. With NPU graph trees, after replaying A, you can change your
mind and record/replay a different graph B'; we will support efficient
execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In
other words: we support arbitrary trees of NPU graph operations, not just
sequences (this is why this feature is called NPU graph trees.)
* Previously, if you executed graph A, some non-NPU graph code, and then
graph B, after executing graph B, it was not safe to retain any references
to intermediates produced by A. With NPU graph trees, we track if any
outputs of graph A are still live by the time graph B is run, and make
sure graph B doesn't clobber their memory when reusing the NPU graphs
pool. You'll get a separate recording of B depending on what tensors
stay live or dead.
NPU graph trees are flexible enough to be used in Dynamo across graph breaks,
which is their primary use case.
The ability to switch from replay to record is fairly nontrivial: remember that
when you replay an NPU graph, you only replay NPU operations; no CPU side state
is updated. In particular, the CPU-side book-keeping for the allocator is not
reconstructed. However, to record a new child NPU graph, we must restore this
book-keeping. This is what checkpoint pool state is used for.
"""
from __future__ import annotations
import contextlib
import dataclasses
import functools
import gc
import itertools
import operator
import sys
import threading
import traceback
import warnings
import weakref
import logging
from collections import defaultdict
from enum import auto, Enum
from typing import (
Any,
Callable,
cast,
ContextManager,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
import torch.fx
from torch import Tensor
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state
from torch._inductor import config
from torch._inductor.compile_fx import (
align_inputs_from_check_idxs,
copy_misaligned_inputs,
get_expanded_dims,
get_input_idxs_to_check,
index_expanded_dims,
remove_unaligned_input_idxs,
static_input,
)
from torch._inductor.cudagraph_utils import (
check_for_mutation,
CheckInvariantStatus,
FunctionID,
log_cudagraph_skip_and_bump_counter,
log_data_ptr_mismatch,
maybe_warning_due_to_dynamic_shape,
ModelType,
OutputType,
PlaceholderInfo,
WrappedFunction,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.storage import UntypedStorage
from torch.utils import _pytree as pytree
from torch.utils.weak import TensorWeakRef
import torch_npu
from torch_npu._C import (
_npu_NPUAllocator_AllocatorState as AllocatorState,
_set_cached_tensors_enabled as _set_cached_tensors_enabled)
import torch_npu.npu.aclnn
if TYPE_CHECKING:
from torch._inductor.utils import InputType
from torch.types import _bool
# Integer aliases used only to make type annotations below more readable.
# Value returned by StorageWeakRefWrapper.__call__ (the wrapped ``StorageWeakRef.cdata``).
StorageWeakRefPointer = int
# A storage data pointer, as returned by ``data_ptr()``.
StorageDataPtr = int
# Presumably a byte count; not used in the visible part of this file — TODO confirm.
NBytes = int
# Lets StorageWeakRefWrapper.from_weakref_and_data_ptr be typed on ``cls``.
S = TypeVar("S", bound="StorageWeakRefWrapper")
log = logging.getLogger("torch_npu.npugraph")
@dataclasses.dataclass(frozen=True)
class GraphID:
    """Immutable, hashable identifier of a single npu graph recording."""

    # Unique counter value assigned to the recording.
    id: int
def clear_cublass_cache() -> None:
    """
    Hook for clearing persistent matmul-library workspaces around warmup/recording.

    Background (from the CUDA implementation this file is ported from): cuBLAS keeps
    a persistent workspace allocation for running matmuls. That is a problem for
    warmup inside a graph's private memory pool, because allocations must not
    persist from one run to the next. When a new generation of a graph-tree path
    starts, all tensors from the previous generation are freed back to the pool.
    A workspace allocated there would still be in use by the library while the
    pool hands the same memory out again, so the memory would be in use in two
    places. Clearing before and after warmup or recording forces any required
    workspace to be allocated into the private pool and accounted for by the
    allocator for the rest of the program.

    On NPU this function is currently a no-op: nothing is cleared here. It is kept,
    misspelled name included, so that ``clear_cublas_manager`` and other callers
    retain the structure of the CUDA implementation.

    NOTE(review): if the NPU runtime keeps a persistent matmul workspace, it may
    need to be cleared here — confirm.
    """
    return None
@contextlib.contextmanager
def clear_cublas_manager() -> Generator[None, None, None]:
    """
    Clear the matmul workspace cache when entering and again when leaving the block.

    The exit-time clear is registered on an ExitStack, so it also runs when the
    managed block raises.
    """
    with contextlib.ExitStack() as on_exit:
        clear_cublass_cache()
        on_exit.callback(clear_cublass_cache)
        yield
@contextlib.contextmanager
def disable_conv_cache_emptying() -> Generator[None, None, None]:
    """
    Placeholder context manager: on this backend it only runs the managed block.

    Exceptions raised inside the block propagate unchanged.
    """
    yield
@contextlib.contextmanager
def enable_history_recording() -> Generator[None, None, None]:
    """
    Turn on NPU caching-allocator history recording for the duration of the block.

    If recording was already enabled on entry, nothing is changed. Otherwise it is
    enabled on entry and disabled again on exit, even if the block raises.
    """
    needs_toggle = not torch_npu._C._npu_isHistoryEnabled()
    try:
        if needs_toggle:
            torch.npu.memory._record_memory_history()
        yield
    finally:
        if needs_toggle:
            torch.npu.memory._record_memory_history(None)
def get_history_recording() -> ContextManager[None]:
    """
    Pick the history-recording context for a warmup/record step.

    Returns ``enable_history_recording()`` when
    ``config.triton.cudagraph_trees_history_recording`` is set, otherwise a no-op context.
    """
    if config.triton.cudagraph_trees_history_recording:
        return enable_history_recording()
    return contextlib.nullcontext()
class TreeManagerContainer:
    """
    Owns one device's npugraph tree manager and keeps it alive while it is needed.

    Like ``PrivatePool`` in the npu caching allocator, the tree and its memory pool
    should stay alive as long as any graph, or any tensor output of a graph, is
    still alive. There is a single container per device.

    Lifecycle as implemented here:
      - constructed with no tree manager, no registered functions and no tensors;
      - ``get_tree_manager`` lazily creates the manager;
      - every compiled function is registered with ``add_strong_reference``, which
        arranges for ``finalize_npugraphify_fn`` to run when that function is
        garbage collected;
      - when the last registered function dies, the manager is dropped.

    ``live_storages_count`` / ``_finalize_tensor`` are hooks for tracking live output
    storages; nothing in this class increments that count.
    """

    def __init__(self, device_index: int) -> None:
        self.device_index = device_index
        # Created lazily by get_tree_manager().
        self.tree_manager: Optional[NPUGraphTreeManager] = None
        self.graph: Optional[torch.npu.NPUGraph] = None
        # Number of registered compiled functions that are still alive.
        self.live_npugraphify_fns = 0
        self.live_storages_count = 0
        self.lock = threading.Lock()

    def _finalize_tensor(self) -> None:
        """
        Account for one dead output storage.

        When none remain, drop the graph. Also drop the manager if no registered
        compiled function is alive.
        """
        with self.lock:
            self.live_storages_count -= 1
            if self.live_storages_count != 0:
                return
            self.graph = None
            # Keep the manager while any registered compiled function still exists.
            if self.live_npugraphify_fns == 0:
                self.tree_manager = None

    def finalize_npugraphify_fn(self) -> None:
        """Finalizer callback: one registered compiled function has been collected."""
        with self.lock:
            remaining = self.live_npugraphify_fns - 1
            self.live_npugraphify_fns = remaining
            if remaining == 0:
                self._finalize_tree_manager()

    def _finalize_tree_manager(self) -> None:
        """Drop the tree manager; raises RuntimeError unless ``self.lock`` is held."""
        if self.lock.locked():
            self.tree_manager = None
            return
        raise RuntimeError("check self.lock.locked() fail")

    def add_strong_reference(self, fn: Callable[..., Any]) -> None:
        """Register ``fn``; the live-function count drops again once it is collected."""
        with self.lock:
            self.live_npugraphify_fns += 1
        # The finalizer callback acquires ``self.lock`` itself.
        weakref.finalize(fn, self.finalize_npugraphify_fn)

    def get_tree_manager(self) -> NPUGraphTreeManager:
        """Return this device's tree manager, creating it on first use."""
        with self.lock:
            manager = self.tree_manager
            if manager is None:
                manager = NPUGraphTreeManager(self.device_index)
                self.tree_manager = manager
            return manager
# Registry of per-device tree-manager containers and per-device locks.
# Attributes set on a ``threading.local`` here exist only in the importing thread.
local = threading.local()
local.npu_tree_manager_containers = {}
local.npu_tree_manager_locks = defaultdict(threading.Lock)


class MarkStepBox:
    # Class-level counter: decremented by mark_step_begin() and reset to 0 by
    # reset_npugraph_trees().
    mark_step_counter = 0


# The same dict objects are stashed in torch's TLS. get_obj() falls back to them
# when ``local`` lacks the attribute (i.e. in another thread), presumably for
# threads that inherit torch's TLS such as autograd workers — TODO confirm.
torch._C._stash_obj_in_tls("npu_tree_manager_containers", local.npu_tree_manager_containers)
torch._C._stash_obj_in_tls("npu_tree_manager_locks", local.npu_tree_manager_locks)
def mark_step_begin() -> None:
    """Indicates that a new iteration of inference or training is about to begin."""
    # The step counter only ever moves downward here; reset_npugraph_trees()
    # puts it back to 0.
    MarkStepBox.mark_step_counter = MarkStepBox.mark_step_counter - 1
def reset_npugraph_trees() -> None:
    """
    Clear all npugraph trees.

    Each existing tree manager is shut down while holding its device lock. Then
    cached tensors are disabled, all containers are dropped, and the mark-step
    counter is reset.
    """
    containers = get_obj(local, "npu_tree_manager_containers")
    device_locks = get_obj(local, "npu_tree_manager_locks")
    for device_index, device_lock in device_locks.items():
        with device_lock:
            owner = containers.get(device_index)
            manager = owner.tree_manager if owner else None
            if manager:
                manager.shutdown()
    _set_cached_tensors_enabled(False)
    containers.clear()
    MarkStepBox.mark_step_counter = 0
def get_obj(thread_local: Any, attr_name: str) -> Any:
    """
    Read ``attr_name`` from ``thread_local``, falling back to torch's TLS stash.

    Raises:
        RuntimeError: if the attribute is missing on ``thread_local`` and no object
            was stashed in torch's TLS under that key.
    """
    if not hasattr(thread_local, attr_name):
        if not torch._C._is_key_in_tls(attr_name):
            raise RuntimeError("check torch._C._is_key_in_tls(attr_name) fail")
        return torch._C._get_obj_in_tls(attr_name)
    return getattr(thread_local, attr_name)
def get_container(device_index: int) -> TreeManagerContainer:
    """Return this device's TreeManagerContainer, creating it under the device lock if absent."""
    containers = get_obj(local, "npu_tree_manager_containers")
    device_lock = get_obj(local, "npu_tree_manager_locks")[device_index]
    with device_lock:
        try:
            return containers[device_index]
        except KeyError:
            fresh = TreeManagerContainer(device_index)
            containers[device_index] = fresh
            return fresh
def get_manager(
    device_index: int, create_if_none_exists: bool = True
) -> Optional[NPUGraphTreeManager]:
    """
    Return the tree manager for ``device_index``.

    With ``create_if_none_exists`` (the default) the manager is created on demand.
    Otherwise the current value, possibly None, is returned.
    """
    container = get_container(device_index)
    return container.get_tree_manager() if create_if_none_exists else container.tree_manager
def npugraphify_impl(
    model: ModelType,
    inputs: List[InputType],
    static_input_idxs: Sequence[int],
    *args: Any,
    **kwargs: Any,
) -> ModelType:
    """
    Wrap ``model`` so it is recorded into an npugraph tree lazily, once per
    distinct combination of values of its ``int`` inputs.

    The returned callable derives a cache key from the int inputs.
    - On a hit it calls the cached recorded function.
    - On a miss it fixes up misaligned inputs and records via ``npugraphify``.
      It then caches the aligned recorded function and returns the outputs of
      the recording run.
    """
    recorded: Dict[Tuple[int, ...], Callable[..., Any]] = {}
    # Positions of plain ``int`` inputs; their values form the cache key.
    int_positions = [pos for pos, value in enumerate(inputs) if isinstance(value, int)]
    if int_positions:
        key_of: Any = operator.itemgetter(*int_positions)
    else:
        key_of = lambda _: None  # noqa: E731
    warned = False
    # Do not keep the example inputs alive in this closure.
    del inputs

    def deferred_npugraphify(inputs: List[InputType]) -> OutputType:
        nonlocal warned
        key = key_of(inputs)
        cached_fn = recorded.get(key)
        if cached_fn is not None:
            return cached_fn(inputs)
        if key is None:
            log.debug("NPUGRAPH-TREE Compile recording, key=None")
        else:
            log.debug("NPUGRAPH-TREE Compile recording, key=%s", key)
        if not warned:
            warned = maybe_warning_due_to_dynamic_shape(recorded, key)
        # Order: indices to check, then pruned static idxs, then copy misaligned inputs.
        idxs_to_check = get_input_idxs_to_check(inputs, static_input_idxs)
        aligned_static_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
        copy_misaligned_inputs(inputs, idxs_to_check)
        new_fn, outputs = npugraphify(model, inputs, aligned_static_idxs, *args, **kwargs)
        recorded[key] = align_inputs_from_check_idxs(new_fn, inputs_to_check=idxs_to_check)
        return outputs

    return deferred_npugraphify
def npugraphify(
    model: ModelType,
    inputs: List[InputType],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    is_backward: bool,
    is_inference: bool,
    stack_traces: Optional[StackTraces] = None,
    constants: Tuple[torch.Tensor, ...] = (),
    placeholders: Tuple[PlaceholderInfo, ...] = (),
    mutated_input_idxs: Tuple[int, ...] = (),
) -> Tuple[ModelType, OutputType]:
    """
    Register ``model`` with this device's tree manager and return its result.

    The tree manager is fetched (created if needed) before the flags are checked.
    The compilation mode is BACKWARD, INFERENCE or FORWARD according to the flags.

    Raises:
        RuntimeError: if both ``is_backward`` and ``is_inference`` are set.
    """
    manager = get_container(device_index).get_tree_manager()
    if is_backward and is_inference:
        raise RuntimeError("check is_backward and is_inference fail")
    if is_backward:
        mode = CompilationMode.BACKWARD
    elif is_inference:
        mode = CompilationMode.INFERENCE
    else:
        mode = CompilationMode.FORWARD
    return manager.add_function(
        model,
        inputs,
        static_input_idxs,
        stack_traces,
        mode,
        constants,
        placeholders,
        mutated_input_idxs,
    )
class StorageWeakRefWrapper:
"""
Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
"""
__slots__ = ["ref", "_data_ptr", "extra_ref_check"]
storage_ref: Optional[StorageWeakRef]
def __init__(
self,
inp: Union[Tensor, UntypedStorage],
extra_ref_check: Optional[Callable[[], bool]] = None,
) -> None:
"""
extra_ref_check is an additional check we need to run to check if the
weak ref has expired. in checking storage use count we assume extra_ref_check
will hold an additional reference to the storage.
"""
if isinstance(inp, Tensor):
stor = inp.untyped_storage()
else:
if not isinstance(inp, UntypedStorage):
raise RuntimeError("check isinstance(inp, UntypedStorage) fail")
stor = inp
self.ref = StorageWeakRef(stor)
self._data_ptr = stor.data_ptr()
self.extra_ref_check = extra_ref_check
@classmethod
def from_weakref_and_data_ptr(
cls: Type[S],
cdata: Any,
data_ptr: int,
extra_ref_check: Optional[Callable[[], bool]] = None,
) -> StorageWeakRefWrapper:
instance = cls.__new__(cls)
instance._data_ptr = data_ptr
instance.ref = StorageWeakRef.from_weakref(cdata)
instance.extra_ref_check = extra_ref_check
return instance
def __call__(self) -> Optional[StorageWeakRefPointer]:
if self.expired():
return None
return self.ref.cdata
def swap_weakref(self, cdata: Any) -> None:
self.ref.__del__()
self.ref.cdata = cdata
def data_ptr(self) -> int:
"NB: returns the data ptr even if the storage has expired"
return self._data_ptr
def remove_extra_reference(self) -> None:
self.extra_ref_check = None
def expired(self) -> bool:
if self.extra_ref_check is not None and not self.extra_ref_check():
return False
stor_count = torch_npu._C._storage_Use_Count(self.ref.cdata)
return (stor_count - (self.extra_ref_check is not None)) == 0
def __repr__(self) -> str:
if self.ref is None or self.ref.expired():
return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
else:
return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool:
    """True when ``weak_ref`` is not None and still dereferences to a live storage."""
    dereferenced = maybe_deref(weak_ref)
    return dereferenced is not None
def maybe_deref(
    weak_ref: Optional[StorageWeakRefWrapper],
) -> Optional[Tuple[StorageWeakRefPointer, int]]:
    """
    Dereference a storage weak-ref wrapper.

    Returns:
        ``(handle, recorded_data_ptr)`` while the storage is alive; None if
        ``weak_ref`` is None or has expired.
    """
    if weak_ref is None:
        return None
    handle = weak_ref()
    return None if handle is None else (handle, weak_ref.data_ptr())
@contextlib.contextmanager
def _use_npu_memory_pool_manager(
    device: int, mem_pool: Tuple[int, int], stream: torch.npu.Stream
) -> Generator[None, None, None]:
    """
    Context manager to use npu graph pool for new allocations. If you use this manager
    all npugraph tensors in use should be reflected in the allocator or they will be overwritten.
    existing_graph should already have been used in a capture, and the mem_pool must already exist,
    because this manager will not preserve a reference to the pool which keeps it alive.

    Sequence:
      1. Synchronize the device and make ``stream`` wait on the current stream.
      2. With ``stream`` as the current stream and ``device`` as the default device,
         call ``_npu_beginAllocateCurrentThreadToPool`` before running the block.
      3. After the block, call ``_npu_endAllocateCurrentStreamToPool`` and
         ``_npu_releasePool``. These run even if the block raises.
      4. Make the current stream wait on ``stream``. This is not in a ``finally``,
         so it is skipped when the block raises.
    """
    torch.npu.synchronize()
    # The side stream must observe all work already queued on the current stream.
    stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(stream), torch.device(device):
        # NOTE(review): begin uses the "...CurrentThreadToPool" API while the end call
        # is "...CurrentStreamToPool" — confirm torch_npu pairs these correctly.
        torch_npu._C._npu_beginAllocateCurrentThreadToPool(device, mem_pool)
        try:
            yield
        finally:
            torch_npu._C._npu_endAllocateCurrentStreamToPool(device, mem_pool)
            torch_npu._C._npu_releasePool(device, mem_pool)
    torch.npu.current_stream().wait_stream(stream)
def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
    """
    Wrap a tensor's storage in a StorageWeakRefWrapper; None maps to None.

    Raises:
        RuntimeError: if ``t`` is neither a tensor nor None.
    """
    if isinstance(t, torch.Tensor):
        return StorageWeakRefWrapper(t)
    if t is not None:
        raise RuntimeError("check t is None fail")
    return None
# (depth, output_index): ``depth`` counts nodes from the tree root along the current
# path, ``output_index`` indexes that node's outputs (used as path_weakrefs[depth][output_index]).
PathOutputIndex = Tuple[int, int]
# Presumably: per path level, per output, whether that output is live — TODO confirm.
PathLiveness = List[List[bool]]
# Optional stack-trace strings attached to a node (stored as NPUWarmupNode/NPUGraphNode.stack_traces).
StackTraces = List[Optional[str]]
class NPUWarmupNode:
    """
    Simplified Wrapper around A NPU Model that wraps outputs in storage refs and exposes
    apis to get the live storages in the current chain of warmup.
    A NPUWarmupNode may have either NPUGraphNode or NPUWarmupNode as a parent, but may only have
    NPUWarmupNode as children, because we cannot record or execute with tensors which do not have stable
    memory addresses.
    NPUWarmupNode and NPUGraphNode have a number of differences that make it easier to use separate classes.
    - Much of the NPUGraphNode logic & initialization is based on the tensor properties of first recording. In the
    first instance of warmup, these are not finalized yet.
    - All Inputs to the RecordedFunction must be copied over to the npu graph memory pool, this is unnecessary in warmup.
    - NPUWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler.
    NB: this class and NPUGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and
    `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility.
    """

    def __init__(
        self,
        wrapped_function: WrappedFunction,
        parent: Optional[Union[NPUGraphNode, NPUWarmupNode]],
        npu_graphs_pool: Tuple[int, int],
        existing_npu_graph: Optional[torch.npu.NPUGraph],
        device_index: int,
        stack_traces: Optional[StackTraces],
        stream: torch.npu.Stream,
        already_warm: bool,
        graph_id: GraphID,
    ) -> None:
        """Store the configuration; nothing is executed until ``run``."""
        self.wrapped_function = wrapped_function
        self.parent: Optional[Union[NPUGraphNode, NPUWarmupNode]] = parent
        self.npu_graphs_pool = npu_graphs_pool
        # One entry per output of ``run``: a weakref for tracked NPU outputs, else None.
        self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
        self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
        self.existing_npu_graph = existing_npu_graph
        # NOTE(review): nothing in this class sets has_run to True, so the
        # "never run twice" guard in run() cannot fire as written — confirm intent.
        self.has_run = False
        self.device_index = device_index
        self.stack_traces = stack_traces
        self.stream = stream
        # When True, the slow-path memory-pool checks in run() are skipped.
        self.already_warm = already_warm
        self.id = graph_id

    def run(self, new_inputs: Any) -> OutputType:
        """
        Execute the wrapped model once with allocations routed into the graph pool.

        Outputs are tracked with weakrefs only when all of the following hold:
        they are NPU tensors, their storage is not one of the non-npugraph input
        or constant storages, and their data pointer is non-zero. The model is
        expected to empty ``new_inputs``.

        Raises:
            RuntimeError: if ``has_run`` is set, or if ``new_inputs`` is not empty
                after the model returns.
        """
        if self.has_run:
            raise RuntimeError("Wrapped function should never be run twice")
        # Data pointers of outputs of this path that are still alive.
        existing_path_data_ptrs = {
            t.data_ptr()
            for t in self.path_live_weakrefs()
            if t()
        }

        def get_non_npugraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
            # Weakrefs to storages of tensor inputs/constants that were not produced
            # by this path; outputs sharing these storages are not tracked below.
            non_npugraph_inps = []
            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
                if (
                    isinstance(t, torch.Tensor)
                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
                ):
                    non_npugraph_inps.append(weakref.ref(t.untyped_storage()))
            return non_npugraph_inps
        non_npugraph_inps_storages = get_non_npugraph_inps()
        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            refs = list(self.path_live_weakrefs())
            check_memory_pool(self.device_index, self.npu_graphs_pool, refs)
        with torch.npu.device(
            self.device_index
        ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_npu_memory_pool_manager(
            self.device_index, self.npu_graphs_pool, self.stream
        ), get_history_recording():
            # With static aclnn kernels enabled, the model runs inside a StaticKernelCompiler context.
            if torch_npu.npu.aclnn._use_static_aclnn_kernel:
                from torch_npu._inductor.npu_static_kernel import StaticKernelCompiler
                static_kernel_complier = StaticKernelCompiler()
                with static_kernel_complier:
                    out = self.wrapped_function.model(new_inputs)
            else:
                out = self.wrapped_function.model(new_inputs)
        # Resolve the weakrefs now; storages that died during the run are ignored.
        non_npugraph_inps_storage_ptrs = set()
        for storage in non_npugraph_inps_storages:
            s = storage()
            if s is not None:
                non_npugraph_inps_storage_ptrs.add(s._cdata)
        # The model is expected to have consumed (cleared) its input list.
        if not len(new_inputs) == 0:
            raise RuntimeError("check len(new_inputs) == 0 fail")

        def add_ref(out_tensor: Any) -> bool:
            # Track only NPU tensors with real, non-input storage.
            return (
                out_tensor is not None
                and isinstance(out_tensor, torch.Tensor)
                and out_tensor.is_npu
                and out_tensor.untyped_storage()._cdata not in non_npugraph_inps_storage_ptrs
                and out_tensor.untyped_storage().data_ptr() != 0
            )
        self.outputs_weakrefs.extend(
            [map_to_ref(out_) if add_ref(out_) else None for out_ in out]
        )
        self.tensor_weakrefs.extend(
            [TensorWeakRef(out_) if add_ref(out_) else None for out_ in out]
        )
        if config.triton.slow_path_cudagraph_asserts and not self.already_warm:
            out_refs = list(self.path_live_weakrefs())
            check_memory_pool(self.device_index, self.npu_graphs_pool, out_refs)
        return out

    @property
    def _path_from_root(
        self,
    ) -> Generator[Union[NPUGraphNode, NPUWarmupNode], None, None]:
        """Yield the nodes from the root down to (and including) this node."""
        nodes = []
        node: Union[NPUGraphNode, NPUWarmupNode] = self
        while node:
            nodes.append(node)
            node = node.parent
        yield from reversed(nodes)

    def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
        "Returns all live storages weakrefs that created by nodes in this path"
        for node in self._path_from_root:
            for output in node.outputs_weakrefs:
                if is_live(output):
                    yield output

    def all_outputs_are_dead(self) -> bool:
        """True when no output of any node on the path from the root is still live."""
        return not list(self.path_live_weakrefs())

    def _is_npu_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
        """True when ``t``'s storage data pointer matches a live output on this path."""
        for storage_weak_ref in self.path_live_weakrefs():
            if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr():
                return True
        return False
# Readability-only aliases for ``List`` that hint at what a list is indexed by.
# Nominal usage, not enforced:
InputList = List  # indexed by graph input position
OutputList = List  # indexed by graph output position
LevelList = List  # nominally indexed by depth along the path from the root
class OutputAliasInfo:
    """Base class for the alias information attached to each recorded graph output."""
class _UnaliasedStorage(OutputAliasInfo):
    """Singleton type: the graph output constructs a new alias (fresh storage) or is None."""


# The one instance; compared by identity (``is UnaliasedStorage``).
UnaliasedStorage = _UnaliasedStorage()
class AliasesPriorGraphOutput(OutputAliasInfo):
    """
    The graph output aliases an output of a prior graph on the current path.

    ``index`` is a ``(depth, output_index)`` pair into the path's output weakrefs.
    """

    __slots__ = ["index"]
    index: PathOutputIndex

    def __init__(self, index: PathOutputIndex) -> None:
        if isinstance(index, tuple):
            self.index = index
            return
        raise RuntimeError("check isinstance(index, tuple) fail")
class AliasesNewOutput(OutputAliasInfo):
    """
    The graph output aliases another output of the same invocation.

    ``index`` is the position of that output in the newly returned outputs.
    """

    __slots__ = ["index"]
    index: int

    def __init__(self, index: int) -> None:
        if isinstance(index, int):
            self.index = index
            return
        raise RuntimeError("check isinstance(index, int) fail")
class NPUGraphNode:
"""
A single recording of a function into a NPU Graph. Recordings of NPU Graphs share a single memory pool
and are structured into a tree, where there is a single recording that can precede it (parent) and multiple
subsequent recordings that may follow (children). A node will have no parent if it is the first recording
in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which
would force a dependency.
On first recording, all of the live tensors in the current NPU Graph Node path will be
reflected in the corresponding private pool. On subsequent executions, the caching allocator
is unaffected when the graph is replayed.
In order to support recording a subsequent npu graph recording after execution of this graph,
we checkpoint the state of the memory pool so that it may later be resumed.
WrappedFunction should have already been warmed up prior to invocation.
"""
def __init__(
    self,
    wrapped_function: WrappedFunction,
    graph_id: GraphID,
    parent: Optional[NPUGraphNode],
    inputs: List[Tensor],
    npu_graphs_pool: Tuple[int, int],
    device_index: int,
    stack_traces: Optional[StackTraces],
    stream: torch.npu.Stream,
) -> None:
    """
    Build this node's bookkeeping and record ``wrapped_function.model`` into a new NPU graph.
    The freshly captured graph is replayed once at the end of the constructor.

    Args:
        wrapped_function: the (already warmed-up) function to record.
        graph_id: unique id of this recording.
        parent: preceding node in the tree, or None for a root.
        inputs: example inputs; this list is cleared by the constructor.
        npu_graphs_pool: memory pool shared by the tree.
        device_index: device to record on.
        stack_traces: optional stack traces stored on the node.
        stream: stream used for recording.

    Raises:
        RuntimeError: if ``inputs`` is not a list/tuple, if recording returned None,
            or if a recorded output is neither a tensor, an int nor None.
    """
    if not isinstance(inputs, (list, tuple)):
        raise RuntimeError("check isinstance(inputs, (list, tuple))")
    self.wrapped_function = wrapped_function
    self.id = graph_id
    self.device = device_index
    self.stack_traces = stack_traces
    self.stream = stream
    # When True, check_static_inputs_are_stable skips its data-pointer check.
    self.rerecord_if_static_inputs_change = (
        torch._dynamo.config.inline_inbuilt_nn_modules
        or torch._inductor.config.triton.cudagraph_support_input_mutation
    )
    # Only a weak reference to the parent is kept.
    self._parent = weakref.ref(parent) if parent is not None else None
    self.npu_graphs_pool = npu_graphs_pool
    self.children: Dict[FunctionID, List[NPUGraphNode]] = defaultdict(list)
    self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = []
    # The outputs_weakrefs list objects (not copies) of every node from the root down to this one.
    self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
        node.outputs_weakrefs
        for node in self._path_from_root
    ]
    self.path_stacktraces: LevelList[Optional[StackTraces]] = [
        node.stack_traces
        for node in self._path_from_root
    ]
    self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
    # Input positions whose tensors are (presumably) live outputs of earlier nodes on this path.
    self.npugraph_managed_idxs: List[int] = [
        idx
        for idx, t in enumerate(inputs)
        if isinstance(t, torch.Tensor) and self._is_npu_graph_recorded_tensor(t)
    ]
    self.live_npugraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
        (
            self._is_alias_of_live_recorded_tensor(t)
            if isinstance(t, torch.Tensor)
            else None
        )
        for t in inputs
    ]
    self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
    # Static = declared static by the wrapped function, or managed by this tree.
    self.static_input_idxs: List[int] = list(
        set(wrapped_function.static_input_idxs) | set(self.npugraph_managed_idxs)
    )
    # All other inputs; these are copied into reconstructed_inputs on each run.
    self.non_static_input_idx: LevelList[int] = [
        i
        for i in range(len(inputs))
        if i not in self.static_input_idxs
    ]
    counters["inductor"]["npugraph_recorded_non_static_inputs"] += len(
        self.non_static_input_idx
    )
    self.non_managed_static_input_idxs: LevelList[int] = [
        i
        for i in wrapped_function.static_input_idxs
        if i not in self.npugraph_managed_idxs
    ]

    def maybe_get_static_data_ptr(
        idx: int,
        inputs: List[InputType],
        static_input_idxs: List[int],
    ) -> Optional[int]:
        # Data pointer of a static tensor input; None for anything else.
        inp = inputs[idx]
        if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
            return inp.data_ptr()
        return None
    # Compared against later inputs by check_static_inputs_are_stable.
    self.static_input_data_ptrs: InputList[Optional[int]] = [
        maybe_get_static_data_ptr(i, inputs, self.static_input_idxs)
        for i in range(len(inputs))
    ]
    # Precomputed expanded dims of non-static tensor inputs ([] elsewhere).
    self.expanded_dims: List[List[int]] = [
        get_expanded_dims(x)
        if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
        else []
        for idx, x in enumerate(inputs)
    ]
    self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = []
    self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = []
    self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
    self.expected_dead_indices_after_graph: List[PathOutputIndex] = []
    self.live_indices_after_graph: List[PathOutputIndex] = []
    if self.parent is not None:
        # Record liveness now, and which path outputs changed liveness since the
        # parent's recording finished.
        previous_liveness = self.parent.recorded_liveness_after_graph
        curr_liveness = self._get_liveness(self.path_weakrefs)
        different_indices = self._get_different_indices(
            previous_liveness, curr_liveness
        )
        self.recorded_liveness_before_graph = curr_liveness
        self.expected_dead_indices_before_graph = different_indices
    recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
    # Drop the caller's references to the original inputs.
    inputs.clear()
    del inputs
    self.graph: Optional[torch.npu.NPUGraph] = torch.npu.NPUGraph()
    # Tensors rebuilt from the recording inputs' metadata; run() copies new
    # non-static inputs into them.
    self.reconstructed_inputs: List[InputType] = [
        self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
        if isinstance(x, torch.Tensor)
        else x
        for x in recording_inputs
    ]
    self.checkpointed_caching_state: Optional[AllocatorState] = None
    self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = []
    self.unaliased_in_all_paths: OutputList[bool] = []
    self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []
    self.static_output_tensors: OutputList[Optional[Tensor]] = []
    # Capture happens here; the outputs are handed out once by run_first_inputs().
    self.recording_outputs: Optional[OutputType] = self._record(
        wrapped_function.model, recording_inputs)
    # Keep metadata rather than the output tensors themselves.
    self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []
    if self.recording_outputs is None:
        raise RuntimeError("check self.recording_outputs is not None fail")
    for out in self.recording_outputs:
        if isinstance(out, torch.Tensor):
            self.outputs_metadata.append(
                self._tensor_metadata(out, ignore_storage_offset=False)
            )
        else:
            if not isinstance(out, (int, type(None))):
                raise RuntimeError("check isinstance(out, (int, type(None))) fail")
            self.outputs_metadata.append(out)
    # NOTE(review): reason for replaying immediately after capture is not visible here.
    self.graph.replay()
def _copy_inputs_and_remove_from_src(
self, dsts: List[InputType], srcs: List[InputType]
) -> None:
dst_tensors = []
src_tensors = []
for idx in self.non_static_input_idx:
if not isinstance(srcs[idx], torch.Tensor):
continue
expanded_dims = self.expanded_dims[idx]
dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims))
src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims))
srcs[idx] = None
if dst_tensors:
torch._foreach_copy_(dst_tensors, src_tensors)
def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None:
    """
    Fail via ``torch._check`` if a non-managed static input moved since recording.

    Skipped when ``self.rerecord_if_static_inputs_change`` is set. The error
    message comes from ``log_data_ptr_mismatch``.
    """
    if self.rerecord_if_static_inputs_change:
        return
    if torch_npu._C._tensors_data_ptrs_at_indices_equal(
        new_inputs,
        self.static_input_data_ptrs,
        self.non_managed_static_input_idxs,
    ):
        return
    error_msg = log_data_ptr_mismatch(
        self.wrapped_function.placeholders,
        new_inputs,
        self.static_input_data_ptrs,
        self.non_managed_static_input_idxs,
        CheckInvariantStatus.StaticInputIdxMismatch,
    )
    torch._check(False, lambda: error_msg)
def _copy_input(self, idx, dst, src):
expanded_dims = self.expanded_dims[idx]
dst = index_expanded_dims(dst, expanded_dims)
src = index_expanded_dims(src, expanded_dims)
dst.copy_(src)
def _record_input(self, idx, dst, src, dst_record, src_record):
expanded_dims = self.expanded_dims[idx]
dst = index_expanded_dims(dst, expanded_dims)
src = index_expanded_dims(src, expanded_dims)
dtype = dst.dtype
if dtype not in dst_record.keys():
dst_record[dtype] = []
src_record[dtype] = []
dst_record[dtype].append(dst)
src_record[dtype].append(src)
def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType:
if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_before_invocation()
if not len(new_inputs) == 0:
raise RuntimeError("check len(new_inputs) == 0 fail")
outputs = self.recording_outputs
self.recording_outputs = None
if outputs is None:
raise RuntimeError("check outputs is not None fail")
return outputs
def run(self, new_inputs: List[InputType]) -> OutputType:
    """
    Replay the recorded graph on ``new_inputs`` and return the reconstructed outputs.

    Steps:
      1. Verify that static inputs have not moved.
      2. Call ``graph.update(cpu_update_input=...)`` once per int32 CPU tensor input.
      3. Copy the non-static inputs into the recorded input slots.
      4. Replay and rebuild the outputs.

    ``new_inputs`` is cleared.
    """
    log.debug("NPUGRAPH-TREE Node Run node=%s", self.id)
    self.check_static_inputs_are_stable(new_inputs)
    for candidate in new_inputs:
        is_host_int32 = (
            isinstance(candidate, torch.Tensor)
            and candidate.dtype == torch.int32
            and candidate.device.type == "cpu"
        )
        if not is_host_int32:
            continue
        self.graph.update(cpu_update_input=[{"context_lens": candidate}, {"actual_seq_lengths_kv": candidate}])
    self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
    self.run_graph()
    outputs = self.reconstruct_outputs()
    new_inputs.clear()
    if config.triton.fast_path_cudagraph_asserts:
        self.debug_check_invariants_after_invocation()
    if config.triton.force_cudagraph_sync:
        torch.npu.synchronize()
    # NOTE(review): static_inputs_stable is not read in the visible code; presumably
    # reset so a later stability check runs again.
    self.static_inputs_stable = False
    return outputs
def reconstruct_outputs(self) -> OutputType:
    """
    Reconstruct output tensors according to their saved metadata and alias information.

    For each recorded output:
      - non-dict metadata (an int or None) is returned as-is;
      - a cached tensor output is returned directly, after clearing its ``_backward_hooks``;
      - a static output tensor is returned directly;
      - otherwise a tensor is rebuilt from metadata on the storage chosen by
        ``prepare_alias_info_for_tensor_construction``, and this node's output
        weakref is re-pointed at the new storage.

    Raises:
        RuntimeError: when one of the invariants checked below does not hold.
    """
    # Presumably populates cached_tensor_outputs on first use — TODO confirm.
    if not self.cached_tensor_outputs:
        self._initialize_cached_tensors()
    outputs: OutputType = []
    for i, (storage_info, metadata) in enumerate(
        zip(self.output_storage_alias, self.outputs_metadata)
    ):
        # Non-tensor outputs (ints / None) were stored verbatim as their metadata.
        if not isinstance(metadata, dict):
            if not isinstance(metadata, (int, type(None))):
                raise RuntimeError("check isinstance(metadata, (int, type(None))) fail")
            outputs.append(metadata)
            continue
        cached_t = self.cached_tensor_outputs[i]
        if cached_t is not None:
            # The same tensor object is handed out again; drop hooks left on it.
            if cached_t._backward_hooks is not None:
                cached_t._backward_hooks = None
            outputs.append(cached_t)
            continue
        static_t = self.static_output_tensors[i]
        if static_t is not None:
            # Static outputs must not be tracked through outputs_weakrefs.
            if self.outputs_weakrefs[i] is not None:
                raise RuntimeError("check self.outputs_weakrefs[i] is None fail")
            outputs.append(static_t)
            continue
        storage = self.prepare_alias_info_for_tensor_construction(
            storage_info, metadata
        )
        if isinstance(storage, UntypedStorage) or storage is None:
            out = self._reconstruct_from_tensor_metadata(metadata, storage)
        else:
            # An int means: share the storage of the earlier output at that index.
            if not isinstance(storage, int):
                raise RuntimeError("check isinstance(storage, int) fail")
            out = self._reconstruct_from_tensor_metadata(
                metadata, cast(torch.Tensor, outputs[storage]).untyped_storage()
            )
        outputs.append(out)
        w = self.outputs_weakrefs[i]
        if w is None:
            raise RuntimeError("check w is not None fail")
        # Re-point the existing wrapper at the storage of the rebuilt tensor.
        w.swap_weakref(out.untyped_storage()._weak_ref())
    return outputs
def prepare_alias_info_for_tensor_construction(
    self,
    out_alias_info: Optional[OutputAliasInfo],
    metadata: Union[Dict[str, Any], int, None],
) -> Union[UntypedStorage, None, int]:
    """
    Decide which storage an output tensor should be rebuilt on.

    Returns:
        - None for int/None metadata or for ``UnaliasedStorage``;
        - an ``UntypedStorage`` created from the weak pointer of the prior-graph
          output referenced by ``AliasesPriorGraphOutput.index``;
        - for ``AliasesNewOutput``, the index of the earlier output of this same
          invocation whose storage should be shared.

    Raises:
        RuntimeError: if the aliased prior output has no weakref, or the alias
            info is of an unexpected kind.
    """
    if isinstance(metadata, (int, type(None))) or out_alias_info is UnaliasedStorage:
        return None
    if isinstance(out_alias_info, AliasesPriorGraphOutput):
        level, output_idx = out_alias_info.index
        prior_ref = self.path_weakrefs[level][output_idx]
        if prior_ref is None:
            raise RuntimeError("check ref is not None fail")
        return torch.UntypedStorage._new_with_weak_ptr(prior_ref())
    if isinstance(out_alias_info, AliasesNewOutput):
        return out_alias_info.index
    raise RuntimeError("check isinstance(out_alias_info, AliasesNewOutput) fail")
def prepare_storages_for_construction(
    self,
) -> List[Union[UntypedStorage, None, int]]:
    """Apply ``prepare_alias_info_for_tensor_construction`` to each (alias info, metadata) output pair."""
    return [
        self.prepare_alias_info_for_tensor_construction(alias_info, meta)
        for alias_info, meta in zip(self.output_storage_alias, self.outputs_metadata)
    ]
def run_graph(self) -> None:
    """Replay the recorded graph; raises RuntimeError when ``self.graph`` is None."""
    graph = self.graph
    if graph is None:
        raise RuntimeError("check self.graph is not None fail")
    graph.replay()
def all_outputs_are_dead(self) -> bool:
    """
    All outputs of the path from this node to its root are dead.

    Only positions recorded in ``live_indices_after_graph`` are checked. The scan
    stops at the first live one.
    """
    return not any(
        is_live(self.path_weakrefs[depth][output_index])
        for depth, output_index in self.live_indices_after_graph
    )
def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType:
    """
    Record ``model(inputs)`` into ``self.graph`` and register its first outputs.

    Steps:
      1. Collect StorageWeakRefWrappers, keyed by storage data pointer, for:
         - static tensor inputs that are not outputs of earlier recordings on
           this path and are not on the CPU;
         - the wrapped function's constants.
      2. When ``config.triton.slow_path_cudagraph_asserts`` is set, check the
         memory pool against the live tensors.
      3. Keep a clone of the last int32 CPU tensor input, if any. It is passed to
         ``graph.update(cpu_update_input=...)`` once capture has finished.
      4. Capture ``model(inputs)`` on ``self.stream`` into the shared pool.
      5. Pass the outputs, wrapped in a tuple if needed, to ``_add_first_outputs``.

    The model is expected to consume ``inputs`` (empty the list).

    Raises:
        RuntimeError: if ``inputs`` is not empty after the model has run.
    """

    def static_input_iter() -> Generator[torch.Tensor, None, None]:
        for i in self.wrapped_function.static_input_idxs:
            _inp = inputs[i]
            if isinstance(
                _inp, torch.Tensor
            ) and not self._is_npu_graph_recorded_tensor(_inp) and _inp.device.type != "cpu":
                yield _inp

    static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {}
    for inp in itertools.chain(static_input_iter(), self.wrapped_function.constants):
        static_input_persistent_storage_ptrs[inp.untyped_storage().data_ptr()] = StorageWeakRefWrapper(inp)
    if config.triton.slow_path_cudagraph_asserts:
        memory = (
            [] if self.parent is None else list(self.parent.path_live_weakrefs())
        )

        def _check_elem(idxs, elem):
            return (
                isinstance(elem, torch.Tensor)
                and idxs not in self.wrapped_function.static_input_idxs
                and elem.untyped_storage().data_ptr() != 0
            )
        memory += [
            StorageWeakRefWrapper(elem)
            for i, elem in enumerate(inputs)
            if _check_elem(i, elem) and elem.device.type != "cpu"
        ]
        check_memory_pool(self.device, self.npu_graphs_pool, memory)
    # Select the int32 CPU tensor used for the host-side graph update. As before,
    # the last match wins, but only that one tensor is cloned. The comprehension
    # leaves no loop variable behind that could keep an input tensor alive during
    # capture, and it also works when ``inputs`` is empty.
    host_int32_inputs = [
        item
        for item in inputs
        if isinstance(item, torch.Tensor) and item.dtype == torch.int32 and item.device.type == "cpu"
    ]
    cpu_tensor = host_int32_inputs[-1].clone() if host_int32_inputs else None
    # Release our references so that model(inputs) can free the inputs.
    del host_int32_inputs
    with preserve_rng_state(), torch.npu.device(
        self.device
    ), clear_cublas_manager(), torch.npu.graph(
        self.graph,
        stream=self.stream,
        pool=self.npu_graphs_pool,
        capture_error_mode="thread_local",
        auto_dispatch_capture=True,
    ), get_history_recording():
        static_outputs = model(inputs)
    # Applied after capture has finished.
    if cpu_tensor is not None:
        self.graph.update(cpu_update_input=[{"context_lens": cpu_tensor},
                                            {"actual_seq_lengths_kv": cpu_tensor}])
    if not len(inputs) == 0:
        raise RuntimeError("check len(inputs) == 0 fail")
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)
    self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs)
    log.debug("NPUGRAPH-TREE Node Record node=%s recorded: outputs=%d, "
              "non_static_inputs=%d, static_input_idxs=%d",
              self.id, len(static_outputs),
              len(self.non_static_input_idx), len(self.static_input_idxs))
    return static_outputs
def _add_first_outputs(
    self,
    outputs: OutputType,
    static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
) -> None:
    """
    Add the outputs from the first invocation of the node and set up metadata.

    Each output is classified in ``self.output_storage_alias``:
    * ``UnaliasedStorage`` for non-tensors (and CPU tensors when the cpu input check is
      disabled) and for tensors owning a storage not seen before in this node;
    * ``None`` for outputs aliasing a live persistent static input/constant, or with an
      empty storage (these are kept in ``static_output_tensors``);
    * ``AliasesPriorGraphOutput`` for outputs aliasing a live output of the path;
    * ``AliasesNewOutput`` for outputs sharing storage with an earlier output here.
    Afterwards weak references to the outputs, the path liveness after the graph and
    the allocator checkpoint state are recorded.

    Args:
        outputs: outputs produced by the graph recording.
        static_input_persistent_storage_ptrs: data_ptr -> weak ref of static inputs and
            constants.

    Raises:
        RuntimeError: if output weakrefs already exist, or ``stack_traces`` has a
            different length than ``outputs``.
    """
    # Imported once instead of on every loop iteration.
    from torch_npu._inductor import config as npu_config

    prev_liveness = self.recorded_liveness_before_graph
    curr_liveness = self._get_liveness(self.path_weakrefs)
    # Path outputs whose liveness changed while recording this graph.
    delta = self._get_different_indices(prev_liveness, curr_liveness)
    self.expected_dead_indices_after_graph = delta
    if not len(self.outputs_weakrefs) == 0:
        raise RuntimeError("check len(self.outputs_weakrefs) == 0 fail")
    output_new_storages_index: Dict[StorageDataPtr, int] = {}
    self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
    self.static_output_tensors = [None for _ in range(len(outputs))]
    for index_, out_ in enumerate(outputs):
        if out_ is None or not isinstance(out_, torch.Tensor) or (
            npu_config.npugraph_trees.disable_cpu_input_check and out_.is_cpu
        ):
            self.output_storage_alias.append(UnaliasedStorage)
            continue
        # (A stray trailing comma used to turn this call into a discarded tuple.)
        torch._check(
            out_.is_npu or out_.untyped_storage().data_ptr() == 0,
            lambda: (
                "Expected all npu outputs in npu graph recording. Non npu output "
                f"from {self.stack_traces[index_] if self.stack_traces else '(unknown)'}"
            ),
        )
        storage_ptr = out_.untyped_storage().data_ptr()
        ref = static_input_persistent_storage_ptrs.get(storage_ptr, None)
        is_empty_storage = storage_ptr == 0
        if (ref and ref() is not None) or is_empty_storage:
            # Aliases a persistent static input/constant (or has no storage): keep the
            # tensor itself instead of tracking it in the pool.
            self.output_storage_alias.append(None)
            self.static_output_tensors[index_] = out_
            continue
        path_ref = self._is_alias_of_live_recorded_tensor(out_)
        if path_ref is not None:
            self._mark_prior_graph_output_as_aliased(path_ref)
            for idx, inp_path_ref in enumerate(
                self.live_npugraph_managed_path_refs
            ):
                if path_ref == inp_path_ref:
                    self.preserved_aliased_inputs[idx] = True
            self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
            continue
        if storage_ptr in output_new_storages_index:
            index = output_new_storages_index[storage_ptr]
            self.unaliased_in_all_paths[index] = False
            self.output_storage_alias.append(AliasesNewOutput(index))
            continue
        output_new_storages_index[storage_ptr] = index_
        self.output_storage_alias.append(UnaliasedStorage)
        self.unaliased_in_all_paths[index_] = True
    if self.stack_traces is None:
        self.stack_traces = [None for _ in range(len(outputs))]
    elif len(self.stack_traces) != len(outputs):
        raise RuntimeError("Wrong number of stack traces passed in")
    if self.outputs_weakrefs:
        raise RuntimeError("check self.outputs_weakrefs is empty fail")
    for out, static_output_tensor in zip(outputs, self.static_output_tensors):
        if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
            self.outputs_weakrefs.append(None)
            self.tensor_weakrefs.append(None)
        else:
            self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
            self.tensor_weakrefs.append(TensorWeakRef(out))
    self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
    self.checkpointed_caching_state = torch_npu._C._npu_getCheckpointState(
        self.device, self.npu_graphs_pool
    )
    # Liveness of the whole path, now including this node's outputs.
    for depth, path_weakref in enumerate(self.path_weakrefs):
        for output_index, weakref_ in enumerate(path_weakref):
            if is_live(weakref_):
                self.live_indices_after_graph.append((depth, output_index))
    self.debug_check_invariants_after_invocation()
    if config.triton.slow_path_cudagraph_asserts:
        check_memory_pool(
            self.device, self.npu_graphs_pool, list(self.path_live_weakrefs())
        )
def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None:
"Remove a graph output from the unaliased, cached tensors in an ancestor node"
depth, output_index = index
node = list(self._path_from_root)[depth]
node.unaliased_in_all_paths[output_index] = False
x = self.path_weakrefs[depth][output_index]
if x is None:
raise RuntimeError("check x is not None fail")
x.remove_extra_reference()
def _initialize_cached_tensors(self) -> None:
    """
    Materialize cached tensors for outputs that are unaliased in every path.

    For each output flagged in ``unaliased_in_all_paths``:
    * a tensor is rebuilt from its recorded metadata, over a storage created from the
      recorded data pointer;
    * the tensor is registered with ``torch_npu._C._add_cached_tensor`` and appended
      to ``cached_tensor_outputs``;
    * its ``outputs_weakrefs`` entry is replaced by a wrapper whose extra-ref check
      returns True only while ``get_output_refcount`` is exactly 2.

    Outputs that are not cached get a ``None`` placeholder in ``cached_tensor_outputs``.

    Raises:
        RuntimeError: if weakrefs/metadata lengths differ, or a cacheable output is not
            ``UnaliasedStorage`` or has non-dict metadata.
    """
    if len(self.outputs_weakrefs) != len(self.outputs_metadata):
        raise RuntimeError("check len(self.outputs_weakrefs) == len(self.outputs_metadata) fail")

    def check_refcount(ref_count, ref_self):
        # Holds the node only weakly; reports False once the node is gone.
        owner = ref_self()
        if owner is None:
            return False
        return owner.get_output_refcount(ref_count) == 2

    self_ref = weakref.ref(self)
    rows = zip(
        self.output_storage_alias,
        self.outputs_metadata,
        self.unaliased_in_all_paths,
    )
    for i, (storage_info, metadata, make_cached) in enumerate(rows):
        if not make_cached:
            self.cached_tensor_outputs.append(None)
            continue
        if storage_info is not UnaliasedStorage:
            raise RuntimeError("check storage_info is UnaliasedStorage fail")
        if not isinstance(metadata, dict):
            raise RuntimeError("check isinstance(metadata, dict) fail")
        out = self._reconstruct_from_tensor_metadata(
            metadata, storage=self.create_storage(metadata)
        )
        torch_npu._C._add_cached_tensor(out)
        extra_check = functools.partial(check_refcount, ref_count=i, ref_self=self_ref)
        self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=extra_check)
        self.cached_tensor_outputs.append(out)
def get_output_refcount(self, index: int) -> int:
    """
    Return ``sys.getrefcount`` of the cached output at ``index``.

    The element is passed directly to ``sys.getrefcount``, with no intermediate local,
    so the count is not inflated by this method.
    """
    cached_outputs = self.cached_tensor_outputs
    return sys.getrefcount(cached_outputs[index])
@property
def parent(self) -> Optional[NPUGraphNode]:
    """Dereference the ``_parent`` weakref; ``None`` when there is no parent reference."""
    parent_ref = self._parent
    if parent_ref is None:
        return None
    return parent_ref()
@property
def _path_to_root(self) -> Generator[NPUGraphNode, None, None]:
"Returns all nodes in the path starting at self and ending at root"
node = self
while node:
yield node
node = node.parent
@property
def _path_from_root(self) -> Generator[NPUGraphNode, None, None]:
"Returns all nodes in the path starting at the root and ending at self"
nodes = reversed(list(self._path_to_root))
yield from nodes
def _is_npu_graph_recorded_tensor(self, t: torch.Tensor) -> bool:
"Is this tensor an output of a node in this path"
for output_refs in self.path_weakrefs:
for storage_weak_ref in output_refs:
if storage_weak_ref is None:
continue
data_ptr = storage_weak_ref.data_ptr()
if t.untyped_storage().data_ptr() == data_ptr:
return True
return False
def _is_alias_of_live_recorded_tensor(
    self, t: torch.Tensor
) -> Optional[PathOutputIndex]:
    """
    Find a recorded output of this path that shares ``t``'s storage.

    Returns the ``(depth, output_index)`` of the first output, in path order, for which
    ``maybe_deref`` yields a ``(storage, data_ptr)`` pair whose data pointer equals
    ``t``'s storage data pointer. Returns ``None`` if no such output exists.
    """
    # Loop-invariant: compute the tensor's storage pointer once.
    target_ptr = t.untyped_storage().data_ptr()
    for depth, output_refs in enumerate(self.path_weakrefs):
        for output_index, storage_ref in enumerate(output_refs):
            storage_and_ptr = maybe_deref(storage_ref)
            if storage_and_ptr is None:
                continue
            _, ptr = storage_and_ptr
            if ptr == target_ptr:
                return (depth, output_index)
    return None
@staticmethod
def _check_liveness(
indices: List[PathOutputIndex],
output_refs: List[List[Optional[StorageWeakRefWrapper]]],
) -> bool:
"Check that all of the indices specified are dead references"
for depth, output_index in indices:
w = output_refs[depth][output_index]
if w is None:
raise RuntimeError("check w is not None fail")
if w() is not None:
return False
return True
def add_child(self, function_id: FunctionID, node: NPUGraphNode) -> None:
    """Register ``node`` as a child of this node under ``function_id``."""
    siblings = self.children[function_id]
    siblings.append(node)
@staticmethod
def _get_different_indices(
prev: List[List[bool]], curr: List[List[bool]]
) -> List[PathOutputIndex]:
"Find indices where the two lists differ."
dead_indices = []
if not len(prev) <= len(curr):
raise RuntimeError("check len(prev) <= len(curr) fail")
for i, (outputs1, outputs2) in enumerate(zip(prev, curr)):
if not len(outputs1) == len(outputs2):
raise RuntimeError("check len(outputs1) == len(outputs2) fail")
for j, (output1, output2) in enumerate(zip(outputs1, outputs2)):
if output1 != output2:
dead_indices.append((i, j))
return dead_indices
@staticmethod
def _get_liveness(
weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
) -> List[List[bool]]:
"Maps weakrefs to true if the reference is alive and false otherwise"
if len(weakrefs) == 0:
return []
return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
def debug_assert_invariants(
    self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
) -> None:
    """
    Validate path and liveness invariants.

    Does nothing unless ``config.triton.fast_path_cudagraph_asserts`` is set. Checks that:
    * ``self.path_weakrefs[i]`` is the ``outputs_weakrefs`` list of the i-th node on
      the path from the root;
    * no output still alive was expected to be dead. Outputs may die early, but must
      not be alive when ``expected_liveness`` says they are dead;
    * two live outputs share a data pointer exactly when they share a storage weak pointer;
    * live outputs that are persistent static aliases are not active blocks of the pool;
    * every index in ``newly_dead`` is no longer live.

    Raises:
        RuntimeError: when any of the above is violated.
    """
    if not config.triton.fast_path_cudagraph_asserts:
        return
    for i, node in enumerate(self._path_from_root):
        if self.path_weakrefs[i] is not node.outputs_weakrefs:
            raise RuntimeError("check self.path_weakrefs[i] is node.outputs_weakrefs fail")
    nodes = list(self._path_from_root)
    # A set gives O(1) membership tests in the loop below.
    live_blocks = set(get_block_addrs(self.npu_graphs_pool))
    live_storage_data_ptrs = set()
    live_storage_weak_ptrs = set()
    for depth, outputs_liveness in enumerate(expected_liveness):
        for output_idx, output_liveness in enumerate(outputs_liveness):
            w = self.path_weakrefs[depth][output_idx]
            stor_weak_ptr_and_data_ptr = maybe_deref(w)
            if stor_weak_ptr_and_data_ptr is None:
                continue
            # Liveness values are bools, so the former ``is None`` test could never
            # catch a storage that is alive although it was recorded as dead.
            if not output_liveness:
                raise RuntimeError("check output_liveness fail")
            stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr
            if (stor_data_ptr in live_storage_data_ptrs) != (stor_weak_ptr in live_storage_weak_ptrs):
                raise RuntimeError("check (stor_data_ptr in live_storage_data_ptrs) == (stor_weak_ptr in live_storage_weak_ptrs) fail")
            live_storage_data_ptrs.add(stor_data_ptr)
            live_storage_weak_ptrs.add(stor_weak_ptr)
            is_persistent_alias = (
                nodes[depth].static_output_tensors[output_idx] is not None
            )
            if is_persistent_alias and stor_data_ptr in live_blocks:
                raise RuntimeError("check stor_data_ptr not in live_blocks fail")
    for depth, output_index in newly_dead:
        if is_live(self.path_weakrefs[depth][output_index]):
            raise RuntimeError("check not is_live(self.path_weakrefs[depth][output_index]) fail")
def debug_check_invariants_before_invocation(self) -> None:
    """Run ``debug_assert_invariants`` with the liveness and expected-dead set recorded before the graph."""
    expected_liveness = self.recorded_liveness_before_graph
    newly_dead = self.expected_dead_indices_before_graph
    self.debug_assert_invariants(expected_liveness, newly_dead)
def debug_check_invariants_after_invocation(self) -> None:
    """
    Run ``debug_assert_invariants`` with the expected-dead set recorded after the graph.

    Note: the liveness snapshot passed is ``recorded_liveness_before_graph``, the same
    snapshot the before-invocation check uses.
    """
    expected_liveness = self.recorded_liveness_before_graph
    newly_dead = self.expected_dead_indices_after_graph
    self.debug_assert_invariants(expected_liveness, newly_dead)
def data_ptrs_dead_since_invocation(self) -> List[int]:
    """
    Since this node was invoked, return data ptrs of all tensor outputs that have died
    in the current executing tree path.

    These are the outputs whose liveness differs from ``recorded_liveness_after_graph``.
    Each pointer is read from the ``"data_ptr"`` entry of the owning node's
    ``outputs_metadata``.
    """
    current_liveness = self._get_liveness(self.path_weakrefs)
    changed_indices = self._get_different_indices(
        self.recorded_liveness_after_graph, current_liveness
    )
    path_nodes = list(self._path_from_root)
    return [
        path_nodes[depth].outputs_metadata[output_index]["data_ptr"]
        for depth, output_index in changed_indices
    ]
def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]:
    """Yield the non-None wrappers at ``live_indices_after_graph`` that ``is_live`` still accepts."""
    for depth, output_index in self.live_indices_after_graph:
        candidate = self.path_weakrefs[depth][output_index]
        if candidate is None or not is_live(candidate):
            continue
        yield candidate
def remove_node_cached_tensors(self) -> None:
    """
    Drop this node's cached output tensors.

    Every non-None cached tensor is unregistered through
    ``torch_npu._C._remove_cached_tensor``, and ``cached_tensor_outputs`` is cleared.
    Then, for each output flagged in ``unaliased_in_all_paths``, the output's weak-ref
    wrapper drops its extra reference.

    Raises:
        RuntimeError: if a flagged output has no weak-ref wrapper.
    """
    for cached in self.cached_tensor_outputs:
        if cached is None:
            continue
        torch_npu._C._remove_cached_tensor(cached)
    self.cached_tensor_outputs.clear()
    for i, unaliased in enumerate(self.unaliased_in_all_paths):
        if not unaliased:
            continue
        wrapper = self.outputs_weakrefs[i]
        if wrapper is None:
            raise RuntimeError("check n is not None fail")
        wrapper.remove_extra_reference()
def remove_path_cached_tensors(self) -> None:
    """Call ``remove_node_cached_tensors`` on every node, from the root down to this one."""
    for path_node in self._path_from_root:
        path_node.remove_node_cached_tensors()
def clear_path_state(self) -> None:
    """Clear the path state in this current executing node (currently a no-op)."""
@staticmethod
def _tensor_metadata(
x: torch.Tensor, ignore_storage_offset: bool = True
) -> Dict[str, Any]:
if not isinstance(x, torch.Tensor):
raise RuntimeError("check isinstance(x, torch.Tensor) fail")
return {
"nbytes": x.untyped_storage().nbytes(),
"data_ptr": x.untyped_storage().data_ptr(),
"size": x.shape,
"stride": x.stride(),
"dtype": x.dtype,
"device": x.device,
"storage_offset": x.storage_offset() if not ignore_storage_offset else 0,
}
def _reconstruct_from_tensor_metadata(
    self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None
) -> Tensor:
    """
    Build a tensor from ``metadata`` via ``torch_npu._C._construct_NPU_Tensor_From_Storage_And_Metadata``.

    When no ``storage`` is given, one is created from the metadata with ``create_storage``.
    """
    if storage is None:
        storage = self.create_storage(metadata)
    return torch_npu._C._construct_NPU_Tensor_From_Storage_And_Metadata(metadata, storage)
def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:
    """Create a storage from the metadata's ``data_ptr``, ``device`` and ``nbytes`` entries."""
    data_ptr = metadata["data_ptr"]
    device = metadata["device"]
    nbytes = metadata["nbytes"]
    return torch_npu._C._construct_storage_from_data_pointer(data_ptr, device, nbytes)
def _allocate_and_copy_recording_inputs(
    self, inputs: List[InputType]
) -> List[InputType]:
    """
    Allocate inputs for non static, non npugraph managed tensors in the memory pool
    and copy over the tensor values.

    Non-tensor inputs must be ints and are passed through unchanged. Tensors at
    ``static_input_idxs`` are reused as-is. Every other tensor is replaced by
    ``static_input(inp)``, called while allocations go to ``self.npu_graphs_pool``.
    The values are then copied over with ``_copy_inputs_and_remove_from_src``.

    Args:
        inputs: the caller's inputs; handed to ``_copy_inputs_and_remove_from_src`` as the source.

    Returns:
        The list of inputs to use for recording, in the same order as ``inputs``.

    Raises:
        RuntimeError: if a non-tensor input is not an int.
    """
    torch.npu.synchronize()
    # Order the recording stream after the work already queued on the current stream.
    self.stream.wait_stream(torch.npu.current_stream())
    recording_inputs: List[InputType] = []
    with warnings.catch_warnings(record=True), torch.npu.device(
        self.device
    ), _use_npu_memory_pool_manager(
        self.device,
        mem_pool=self.npu_graphs_pool,
        stream=self.stream,
    ):
        for i, inp in enumerate(inputs):
            if not isinstance(inp, torch.Tensor):
                if not isinstance(inp, int):
                    raise RuntimeError("check isinstance(inp, int) fail")
                recording_inputs.append(inp)
            elif i not in self.static_input_idxs:
                # NOTE(review): presumably static_input allocates a fresh tensor like inp
                # (in the pool, because of the enclosing manager) — confirm.
                recording_inputs.append(static_input(inp))
            else:
                recording_inputs.append(inp)
        self._copy_inputs_and_remove_from_src(recording_inputs, inputs)
    return recording_inputs
def check_invariants(
    self, inputs: List[InputType]
) -> Tuple[CheckInvariantStatus, Callable[..., str]]:
    """
    Checks if this node can be run. The same pattern of tensor liveness and tensors
    managed in the npugraph private pool must remain stable.

    Checks run in this order and stop at the first failure:
    1. data pointers at ``npugraph_managed_idxs`` match ``static_input_data_ptrs``;
    2. outputs expected dead before the graph are dead;
    3. if ``rerecord_if_static_inputs_change``, data pointers at ``static_input_idxs`` match.
    On success, the ``inputs`` entries at npugraph-managed indices that are not
    preserved aliases are set to None. Then the liveness expected after the graph is
    re-checked with ``torch._check``.

    Args:
        inputs: the candidate inputs; mutated in place on success (see above).

    Returns:
        ``(status, logger)``. ``logger`` is a zero-argument callable producing the
        mismatch description, or the status text.
    """
    _logger = functools.partial(
        log_data_ptr_mismatch,
        self.wrapped_function.placeholders,
        inputs,
        self.static_input_data_ptrs,
    )
    # 1. npugraph-managed inputs must sit at the same addresses as when recorded.
    if not torch_npu._C._tensors_data_ptrs_at_indices_equal(
        inputs,
        self.static_input_data_ptrs,
        self.npugraph_managed_idxs,
    ):
        status = CheckInvariantStatus.CudagraphManagedIdxMismatch
        log.debug("NPUGRAPH-TREE Invariant node=%s mismatch: CudagraphManagedIdx, "
                  "expected_ptrs=%s",
                  self.id,
                  [self.static_input_data_ptrs[i] for i in self.npugraph_managed_idxs
                   if i < len(self.static_input_data_ptrs)])
        _logger = functools.partial(
            _logger,
            self.npugraph_managed_idxs,
            status,
        )
        return status, _logger
    # 2. Outputs that were dead when this node was recorded must be dead now too.
    if not self._check_liveness(
        self.expected_dead_indices_before_graph, self.path_weakrefs
    ):
        status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch
        log.debug("NPUGRAPH-TREE Invariant node=%s mismatch: ExpectedDeadIndicesBeforeGraph, "
                  "expected_dead=%s",
                  self.id, self.expected_dead_indices_before_graph)
        return status, lambda: f"{status}"
    # 3. Optionally require static inputs at the same addresses as well.
    if (
        self.rerecord_if_static_inputs_change
        and not torch_npu._C._tensors_data_ptrs_at_indices_equal(
            inputs,
            self.static_input_data_ptrs,
            self.static_input_idxs,
        )
    ):
        status = CheckInvariantStatus.StaticInputIdxMismatch
        log.debug("NPUGRAPH-TREE Invariant node=%s mismatch: StaticInputIdx, "
                  "expected_ptrs_at_static=%s",
                  self.id,
                  [self.static_input_data_ptrs[i] for i in self.static_input_idxs
                   if i < len(self.static_input_data_ptrs)])
        _logger = functools.partial(
            _logger,
            self.static_input_idxs,
            status,
        )
        return status, _logger
    # Drop the caller list's references to managed inputs that no recorded output aliases.
    for idx in self.npugraph_managed_idxs:
        if not self.preserved_aliased_inputs[idx]:
            inputs[idx] = None
    torch._check(
        self._check_liveness(
            self.expected_dead_indices_after_graph, self.path_weakrefs
        ),
        lambda: "graph recording observed an input tensor deallocate during graph "
        " recording that did not occur during replay. Please file an issue.",
    )
    return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}"
def num_descendants(self) -> int:
    """Total number of descendants of this node (every node below it, not itself)."""
    return sum(
        1 + child.num_descendants()
        for siblings in self.children.values()
        for child in siblings
    )
def get_npugraph_segments(pool_id: Tuple[int, int]) -> Any:
    """Return the ``torch.npu.memory_snapshot()`` segments whose ``segment_pool_id`` equals ``pool_id``."""
    matching = []
    for segment in torch.npu.memory_snapshot():
        if segment["segment_pool_id"] == pool_id:
            matching.append(segment)
    return matching
def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[int]:
    """
    Return the start addresses of blocks in the segments of pool ``pool_id``.

    Block addresses are computed by walking each segment from its ``address`` and
    adding each block's ``size``. When ``live_only`` is True, only blocks in state
    ``"active_allocated"`` are reported.
    """
    addrs: List[int] = []
    for segment in get_npugraph_segments(pool_id):
        cursor = segment["address"]
        for block in segment["blocks"]:
            if block["state"] == "active_allocated" or not live_only:
                addrs.append(cursor)
            cursor += block["size"]
    return addrs
def format_tb(frames: List[Any]) -> str:
    """Render frame dicts (``filename``, ``line``, ``name`` keys) as a traceback string."""
    summaries = [
        traceback.FrameSummary(frame["filename"], frame["line"], frame["name"])
        for frame in frames
    ]
    return "".join(traceback.format_list(summaries))
def check_memory_pool(
    device: int,
    pool_id: Tuple[int, int],
    live_storages_ptrs: List[StorageWeakRefWrapper],
) -> None:
    """
    Check the live allocations of npu graph pool ``pool_id`` against ``live_storages_ptrs``.

    The data pointers of the still-alive wrappers are first handed to
    ``torch_npu._C._npu_checkPoolLiveAllocations``; if that returns truthy, nothing
    else is done. Otherwise ``gc.collect()`` runs and the pool's segments are walked
    block by block:
    * an expected storage that is not an active block fails the ``torch._check``;
    * an active block that is not an expected storage raises RuntimeError with its
      allocation history.

    Raises:
        RuntimeError: if an element is not a ``StorageWeakRefWrapper``, or the pool has
            unaccounted active allocations.
    """
    if not all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs):
        raise RuntimeError("check all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) fail")
    unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}
    if torch_npu._C._npu_checkPoolLiveAllocations(device, pool_id, unique_storages):
        return
    # Slow path: collect garbage, then inspect the pool's blocks one by one.
    gc.collect()
    allocated_not_in_live_storages = {}
    for segment in get_npugraph_segments(pool_id):
        block_addr = segment["address"]
        for block in segment["blocks"]:
            if block["state"] == "active_allocated":
                if block_addr in unique_storages:
                    unique_storages.remove(block_addr)
                else:
                    allocated_not_in_live_storages[block_addr] = block
            block_addr += block["size"]
    torch._check(
        len(unique_storages) == 0,
        lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
    )
    if allocated_not_in_live_storages:
        formatted_s = "\n".join(
            f"Data Pointer: {dp}, history: \n{format_tb(block.get('frames', []))}"
            for dp, block in allocated_not_in_live_storages.items()
        )
        msg = (
            f"These live storage data ptrs are in the npugraph pool but not "
            f"accounted for as an output of npugraph trees: \n\n{formatted_s}"
        )
        raise RuntimeError(msg)
class ExecutionState(Enum):
    """
    State of the NPUGraph tree's current path.

    ``NONE`` means there is no live current memory allocated in the npu graph pool.
    Otherwise the state reflects the kind of the most recently executed node.
    Values match what ``auto()`` would assign (1, 2, 3, 4).
    """

    NONE = 1
    WARMUP = 2
    RECORDING = 3
    EXECUTION = 4
class CompilationMode(Enum):
    """
    Kind of compiled graph a registered function belongs to.

    Values match what ``auto()`` would assign (1, 2, 3).
    """

    FORWARD = 1
    BACKWARD = 2
    INFERENCE = 3
class NPUGraphTreeManager:
"""
Groups individual recordings or executions of npu graphs into a tree of recordings,
and checks required invariants, and manages warmups of graphs.
When graphs are recorded in the same tree, it enforces subsequent execution
to follow the same order and have the same output tensor lifespans. To remove
unnecessary coupling of npu graphs (and additional imposed invariants),
the tree manager will end a currently recording tree whenever it is valid - when
the memory pool no longer has any live allocations.
We ignore outputs from a previous generation that correspond to prior model outputs.
Currently this is hardcoded to `GenerationTracker.generation` tracked in torch dynamo
(or to the mark-step counter once it is non-zero).
TODO: make generation increment configurable, warn on overwrite.
We run graph warmups in the npugraph memory pool and return the result on the first invocation
of a function. For many models it is important to reclaim activations as you run the backward.
If we were to warm up the model and keep an extra copy of the inputs around to subsequently
use for recording, we would incur a memory penalty. Additionally, if we are part way through training
your model and need to recompile, memory will be allocated to the npu graph pool, so we run this
warmup run in the npu graph memory pool. As for recording, warm up needs the state of live tensors
to be accurately reflected so we checkpoint the allocator state if we need to warm up following graph
replay.
"""
def __init__(self, device_index: int) -> None:
    """
    Create the tree manager for NPU device ``device_index``.

    Sets up a dedicated stream and a private graph memory pool. Also initializes the
    per-function bookkeeping: roots, wrapped functions, stack traces, warmup and
    mutation state, and re-record counters.
    """
    # Recorded tree roots, keyed by function id (one function may have several roots).
    self.roots: Dict[FunctionID, List[NPUGraphNode]] = defaultdict(list)
    self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
    self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {}
    self.warmed_up_functions: Set[FunctionID] = set()
    self.warned_functions: Set[FunctionID] = set()
    torch_npu._C._set_cached_tensors_enabled(True)
    self.warned_mutation: Set[FunctionID] = set()
    with torch.npu.device(device_index):
        torch.npu.synchronize()
        self.stream = torch.npu.Stream()
        self.stream.wait_stream(torch.npu.current_stream())
        # ``run`` treats ``self.graph is None`` as "shut down".
        self.graph: Optional[torch.npu.NPUGraph] = torch.npu.NPUGraph()
        self.npu_graphs_thread_pool = torch.npu.graph_pool_handle()
        # Empty capture into the private pool.
        # NOTE(review): presumably this keeps the pool alive for the manager's lifetime — confirm.
        with warnings.catch_warnings(record=True), torch.npu.graph(
            self.graph,
            pool=self.npu_graphs_thread_pool,
            stream=self.stream,
            capture_error_mode="thread_local",
        ):
            pass
    self.graph_counter = itertools.count(0)
    self.func_counter = itertools.count(0)
    log.debug("NPUGRAPH-TREE Manager init, device=%s, pool=(%s,%s), stream=%s",
              device_index,
              self.npu_graphs_thread_pool[0], self.npu_graphs_thread_pool[1],
              self.stream)
    # node id -> function id -> "mutates inputs not managed by npu graphs".
    self.non_npugraph_managed_mutation_hint: Dict[
        Optional[GraphID], Dict[FunctionID, bool]
    ] = defaultdict(dict)
    # Warmup node ids count down from -1, so they never collide with recorded graph ids.
    self.warmup_node_counter = itertools.count(start=-1, step=-1)
    self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict(
        lambda: defaultdict(lambda: 0)
    )
    self.path_state = ExecutionState.NONE
    self.device_index = device_index
    self.current_node: Optional[Union[NPUGraphNode, NPUWarmupNode]] = None
    self.current_gen: int = -1
    self.debug_fail_counter = 0
    self.debug_checkpointing_counter = 0
    self.id_to_mode: Dict[FunctionID, CompilationMode] = {}
    self.running_forwards_with_pending_backwards = False
    self.mode: Optional[CompilationMode] = None
    self.disable_invalidate_aliases = False
def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
    """
    Entry point for one call of ``function_id``.

    Sets ``self.mode`` from the function's registered compilation mode and delegates
    to ``_run``. Afterwards, the pending-backward flag is set when the (possibly
    updated) mode is FORWARD and cleared when it is BACKWARD.

    Raises:
        RuntimeError: if the manager was shut down.
    """
    if self.graph is None:
        raise RuntimeError("Running NPUGraph after shutdown")
    self.mode = self.id_to_mode[function_id]
    result = self._run(new_inputs, function_id)
    # Re-read self.mode: it may have been changed while _run executed.
    current_mode = self.mode
    if current_mode == CompilationMode.FORWARD:
        self.running_forwards_with_pending_backwards = True
    elif current_mode == CompilationMode.BACKWARD:
        self.running_forwards_with_pending_backwards = False
    return result
def set_to_running_backward(self) -> None:
    """Switch to BACKWARD mode and clear the pending-backward flag."""
    self.mode = CompilationMode.BACKWARD
    self.running_forwards_with_pending_backwards = False
def _get_npu_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]:
    """
    Return a predicate telling whether a tensor is a recorded output of the current path.

    Without a current graph or warmup node, the predicate always returns False.
    """
    node = self.current_node
    if isinstance(node, (NPUGraphNode, NPUWarmupNode)):
        return node._is_npu_graph_recorded_tensor
    return lambda _: False
def new_warmup_node_id(self) -> GraphID:
    """Allocate the next warmup node id from ``warmup_node_counter``."""
    next_id = next(self.warmup_node_counter)
    return GraphID(next_id)
def _update_non_npugraph_managed_mutation(
    self, function_id: FunctionID, inputs: List[InputType]
) -> None:
    """
    Record whether ``function_id`` mutates inputs not managed by npu graphs.

    The verdict is stored under the current node id. The first time a mutation is
    detected for a function, the skip reason is reported once via
    ``log_cudagraph_skip_and_bump_counter``.
    """
    node_id = self._get_node_id()
    maybe_mutation_str = check_for_mutation(
        self.ids_to_funcs[function_id],
        inputs,
        self._get_npu_graph_recorded_tensor_checker(),
    )
    hints = self.non_npugraph_managed_mutation_hint[node_id]
    if not maybe_mutation_str:
        hints[function_id] = False
        return
    hints[function_id] = True
    if function_id in self.warned_mutation:
        return
    self.warned_mutation.add(function_id)
    log_cudagraph_skip_and_bump_counter(maybe_mutation_str)
def _get_node_id(self) -> Optional[GraphID]:
    """
    Return the id of the current node, or ``None`` when there is no current node.

    Raises:
        RuntimeError: if the current node is neither a graph node nor a warmup node.
    """
    node = self.current_node
    if node is None:
        return None
    if isinstance(node, (NPUGraphNode, NPUWarmupNode)):
        return node.id
    raise RuntimeError(f"Unknown node type {type(self.current_node)}")
def exceed_rerecord_limit(
self, node_id: Optional[GraphID], function_id: FunctionID
) -> bool:
if torch._dynamo.config.inline_inbuilt_nn_modules:
return False
return (
self.num_rerecord[node_id][function_id]
> torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
)
def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
    """
    Dispatch one call of ``function_id``.

    Steps, in order:
    1. End a finished recording or warmup path, if possible.
    2. Run the plain model if the function mutates inputs not managed by npu graphs,
       or if it has exceeded the re-record limit.
    3. Run eagerly in a warmup node if the function is not warmed up (unless warmup
       is skipped by config), if a warmup path is active, or if warmup is forced.
    4. Otherwise replay the first child of the current node (or root) whose
       invariants hold.
    5. Failing that: possibly end the current execution and retry from a root, or
       fall back to the plain model; finally record a new graph.

    Raises:
        RuntimeError: if the current node is a warmup node when replay/record is attempted.
    """
    # Recording and warmup are ended lazily, before deciding how to run.
    if self.in_recording:
        self.try_end_curr_recording(function_id)
    if self.in_warmup:
        self.try_end_curr_warmup(function_id)
    # Read after the try_end_* calls above, which may reset the current node.
    node_id = self._get_node_id()
    if function_id not in self.non_npugraph_managed_mutation_hint[node_id]:
        self._update_non_npugraph_managed_mutation(function_id, new_inputs)
    if self.non_npugraph_managed_mutation_hint[node_id][function_id] or \
            self.exceed_rerecord_limit(node_id, function_id):
        return self.ids_to_funcs[function_id].model(new_inputs)
    if (
        not (
            function_id in self.warmed_up_functions
            or config.triton.skip_cudagraph_warmup
        )
        or self.in_warmup
        or config.triton.force_cudagraphs_warmup
    ):
        # While a replayed path is current, apply its checkpointed allocator state
        # before running eagerly.
        if self.path_state == ExecutionState.EXECUTION:
            self.apply_checkpoint_execution_state_in_allocator()
        with dynamo_timed(
            "NPUGraphTreeManager.run_eager",
            log_pt2_compile_event=True,
        ):
            out = self.run_eager(new_inputs, function_id)
        return out
    if isinstance(self.current_node, NPUWarmupNode):
        raise RuntimeError("self.current_node is NPUWarmupNode object")
    # Replay candidates: the roots when idle, otherwise the current node's children.
    child_nodes = (
        self.roots if self.current_node is None else self.current_node.children
    )
    if not self.in_recording:
        unexpected_rerecord, unexpected_rerecord_reason = False, lambda: ""
        for child in child_nodes[function_id]:
            status, status_logger = child.check_invariants(new_inputs)
            if status == CheckInvariantStatus.SUCCESS:
                return self.execute_node(child, new_inputs)
            # Data-pointer mismatches count toward the unexpected re-record limit below.
            if (
                status == CheckInvariantStatus.StaticInputIdxMismatch
                or status == CheckInvariantStatus.CudagraphManagedIdxMismatch
            ):
                unexpected_rerecord = True
                unexpected_rerecord_reason = status_logger
        # No child matched. If the function is also a root, try to end the current
        # execution; if that succeeded, rerun so the root lookup is used.
        if self.current_node is not None and function_id in self.roots:
            self.try_end_curr_execution()
            if self.current_node is None:
                return self.run(new_inputs, function_id)
        if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0:
            self._update_non_npugraph_managed_mutation(function_id, new_inputs)
            if self.non_npugraph_managed_mutation_hint[self._get_node_id()][
                function_id
            ]:
                return self.ids_to_funcs[function_id].model(new_inputs)
        if unexpected_rerecord:
            curr_node_id = self._get_node_id()
            self.num_rerecord[curr_node_id][function_id] += 1
            if self.exceed_rerecord_limit(curr_node_id, function_id):
                _id = curr_node_id.id if curr_node_id else None
                # NOTE(review): "npuagraph" in this message looks like a typo for "npugraph".
                log_cudagraph_skip_and_bump_counter(
                    f"skipping npuagraph due to function {function_id.id} exceeding max "
                    f"re-recording limit "
                    f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) "
                    f"on npugraph node {_id} due to {unexpected_rerecord_reason()}."
                )
                return self.ids_to_funcs[function_id].model(new_inputs)
    # From here on a new graph is recorded.
    self.debug_fail_counter += 1
    self.try_end_curr_execution()
    if self.current_node is not None:
        self.apply_checkpoint_execution_state_in_allocator()
    with dynamo_timed(
        "NPUGraphTreeManager.record_function",
        log_pt2_compile_event=True,
    ):
        out = self.record_function(new_inputs, function_id)
    return out
def shutdown(self) -> None:
    """
    Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
    might reference a backward which invokes a NPU Graph Node, we have to manually clear them on shutdown
    to avoid a reference cycle.

    Every node of every tree is visited depth-first; its cached tensors are removed and
    its graph dropped. Afterwards the manager's graph, roots and current node are cleared.
    """
    log.debug("NPUGRAPH-TREE Manager shutdown begin, device=%s", self.device_index)
    pending = [root for roots in self.roots.values() for root in roots]
    while pending:
        node = pending.pop()
        for children in node.children.values():
            pending.extend(children)
        node.remove_node_cached_tensors()
        node.graph = None
    self.graph = None
    self.roots = None
    self.current_node = None
    log.debug("NPUGRAPH-TREE Manager shutdown done, device=%s", self.device_index)
def record_function(
    self, new_inputs: List[InputType], function_id: FunctionID
) -> OutputType:
    """
    Record ``function_id`` as a new NPU graph node and run it on ``new_inputs``.

    The node becomes a root when there is no current node, otherwise a child of the
    current node. It then becomes the current node and the manager enters RECORDING.

    Raises:
        RuntimeError: if the current node is a warmup node.
    """
    if isinstance(self.current_node, NPUWarmupNode):
        raise RuntimeError("self.current_node is NPUWarmupNode object")
    parent = self.current_node
    graph_id = self.new_graph_id()
    log.debug("NPUGRAPH-TREE Node Record function=%s, graph=%s", function_id.id, graph_id.id)
    torch.npu.synchronize()
    node = NPUGraphNode(
        self.ids_to_funcs[function_id],
        graph_id,
        parent,
        new_inputs,
        self.npu_graphs_thread_pool,
        self.device_index,
        self.ids_to_stack_traces[function_id],
        self.stream,
    )
    if parent is None:
        self.roots[function_id].append(node)
    else:
        parent.add_child(function_id, node)
    self.current_node = node
    self.path_state = ExecutionState.RECORDING
    self.update_generation()
    log.debug("NPUGRAPH-TREE State state=RECORDING, node=%s, gen=%d",
              graph_id.id, self.current_gen)
    torch.npu.synchronize()
    return node.run_first_inputs(new_inputs)
def execute_node(
    self, node: NPUGraphNode, new_inputs: List[InputType]
) -> OutputType:
    """Make ``node`` current, enter the EXECUTION state, and replay it on ``new_inputs``."""
    self.current_node = node
    self.path_state = ExecutionState.EXECUTION
    self.update_generation()
    log.debug("NPUGRAPH-TREE State state=EXECUTION, node=%s, gen=%d",
              node.id, self.current_gen)
    log.debug("NPUGRAPH-TREE Execute node=%s", node.id)
    return node.run(new_inputs)
def run_eager(
    self, new_inputs: List[InputType], function_id: FunctionID
) -> OutputType:
    """
    Run ``function_id`` eagerly inside a new warmup node.

    The function is marked as warmed up. A new ``NPUWarmupNode`` parented to the
    current node becomes the current node, and the manager enters WARMUP. The node
    is told whether the function had already been warmed up before this call.
    """
    already_warm = function_id in self.warmed_up_functions
    if already_warm:
        log.debug("NPUGRAPH-TREE Eager Running eager (ancestor warmup), function=%s", function_id)
    else:
        log.debug("NPUGRAPH-TREE Warmup Running warmup, function=%s", function_id)
    self.warmed_up_functions.add(function_id)
    warmup_node = NPUWarmupNode(
        self.ids_to_funcs[function_id],
        self.current_node,
        self.npu_graphs_thread_pool,
        self.graph,
        self.device_index,
        self.ids_to_stack_traces[function_id],
        self.stream,
        already_warm,
        self.new_warmup_node_id(),
    )
    self.current_node = warmup_node
    self.path_state = ExecutionState.WARMUP
    self.update_generation()
    log.debug("NPUGRAPH-TREE State state=WARMUP, node=%s, already_warm=%s, gen=%d",
              warmup_node.id, already_warm, self.current_gen)
    return warmup_node.run(new_inputs)
def new_graph_id(self) -> GraphID:
    """Allocate the next recorded-graph id from ``graph_counter``."""
    next_id = next(self.graph_counter)
    return GraphID(next_id)
def new_func_id(self) -> FunctionID:
    """Allocate the next function id from ``func_counter``."""
    next_id = next(self.func_counter)
    return FunctionID(next_id)
def add_function(
    self,
    model: ModelType,
    inputs: List[InputType],
    static_input_idxs: Sequence[int],
    stack_traces: Optional[StackTraces],
    mode: CompilationMode,
    constants: Tuple[torch.Tensor, ...],
    placeholders: Tuple[PlaceholderInfo, ...],
    mutated_input_idxs: Tuple[int, ...],
) -> Tuple[ModelType, OutputType]:
    """
    Register ``model`` with the tree manager and run it once on ``inputs``.

    Only NPU tensors among ``constants`` are kept on the wrapped function. The
    returned callable is ``self.run`` bound to the new function id; a strong
    reference to it is stored in the device's container.

    Returns:
        ``(callable, outputs of calling it on inputs)``.
    """
    func_id = self.new_func_id()
    self.ids_to_stack_traces[func_id] = stack_traces
    npu_constants = tuple(
        t for t in constants if isinstance(t, torch.Tensor) and t.is_npu
    )
    self.ids_to_funcs[func_id] = WrappedFunction(
        model,
        list(static_input_idxs),
        func_id,
        npu_constants,
        placeholders,
        mutated_input_idxs,
    )
    self.id_to_mode[func_id] = mode
    runner = functools.partial(self.run, function_id=func_id)
    get_container(self.device_index).add_strong_reference(runner)
    return runner, runner(inputs)
@property
def in_recording(self) -> bool:
    """Whether the path state is RECORDING."""
    is_recording = self.path_state == ExecutionState.RECORDING
    return is_recording
@property
def in_warmup(self) -> bool:
    """Whether the path state is WARMUP."""
    is_warming_up = self.path_state == ExecutionState.WARMUP
    return is_warming_up
def get_roots(self) -> Iterator[NPUGraphNode]:
    """Yield every root node of every function, in ``self.roots`` order."""
    for root_list in self.roots.values():
        for root in root_list:
            yield root
@property
def current_node(self) -> Optional[Union[NPUGraphNode, NPUWarmupNode]]:
    """The node the manager is currently positioned at, or ``None``."""
    return self._current_node

@current_node.setter
def current_node(
    self, value: Optional[Union[NPUGraphNode, NPUWarmupNode]]
) -> None:
    self._current_node = value
    if value is not None:
        return
    # Clearing the current node also resets the path state.
    self.path_state = ExecutionState.NONE
def update_generation(self) -> None:
    """Store the current global generation in ``current_gen``."""
    generation = self.get_curr_generation()
    self.current_gen = generation
@staticmethod
def get_curr_generation() -> int:
    """
    Return the mark-step counter if it is non-zero, otherwise ``GenerationTracker.generation``.
    """
    mark_step = MarkStepBox.mark_step_counter
    return mark_step if mark_step != 0 else GenerationTracker.generation
@staticmethod
def user_invoked_mark_step() -> bool:
    """True when ``MarkStepBox.mark_step_counter`` is non-zero."""
    counter = MarkStepBox.mark_step_counter
    return counter != 0
def can_start_new_generation(self) -> bool:
    """Decide whether a new generation may begin.

    This requires a new torch.compile invocation. An explicit mark step is
    then sufficient; otherwise there must be no forwards still waiting on
    their backwards.
    """
    if self.in_new_torch_compile_invocation():
        if self.user_invoked_mark_step():
            return True
        return not self.running_forwards_with_pending_backwards
    return False
def in_new_torch_compile_invocation(self) -> bool:
    """True when the live generation differs from the cached ``current_gen``."""
    latest_generation = self.get_curr_generation()
    return self.current_gen != latest_generation
def try_end_curr_recording(self, function_id: FunctionID) -> None:
    """Attempt to end the in-progress recording.

    The recording ends when a new generation may start, in which case the
    current path's live outputs are deallocated first. It also ends when
    every output of the current node is dead. On success the path state is
    cleared and ``current_node`` becomes None. Otherwise a potential
    fast-path warning is checked for ``function_id``.

    Raises:
        RuntimeError: if not recording, or if there is no current node.
    """
    if not self.in_recording:
        raise RuntimeError("check self.in_recording fail")
    node = self.current_node
    if node is None:
        raise RuntimeError("check self.current_node is not None fail")

    if self.can_start_new_generation():
        # Invalidate/free outputs still alive on this path before clearing it.
        self.dealloc_current_path_weakrefs()
    elif not node.all_outputs_are_dead():
        # Live outputs remain: the recording cannot be ended yet.
        self.check_warn_on_unable_to_start_executing(function_id)
        return
    self.clear_current_path_state_and_set_to_none()
def try_end_curr_execution(self) -> None:
    """Attempt to leave the node currently being executed (replayed).

    Nothing happens when there is no current node. Otherwise the path state
    is cleared and ``current_node`` becomes None if a new generation may
    start or every output of the current node is dead.

    Raises:
        RuntimeError: if called while recording.
    """
    if self.in_recording:
        raise RuntimeError("check not self.in_recording fail")
    node = self.current_node
    if node is None:
        return
    if self.can_start_new_generation() or node.all_outputs_are_dead():
        self.clear_current_path_state_and_set_to_none()
def try_end_curr_warmup(self, function_id: FunctionID) -> None:
    """Attempt to leave the current warmup node.

    If a new generation may start, the current path's live outputs are
    deallocated and ``current_node`` is cleared. Otherwise the node is
    cleared only once all its outputs are dead. When they are not, a
    potential fast-path warning is checked for ``function_id``.

    Raises:
        RuntimeError: if no new generation can start and there is no
            current node.
    """
    if self.can_start_new_generation():
        self.dealloc_current_path_weakrefs()
        self.current_node = None
        return

    node = self.current_node
    if node is None:
        raise RuntimeError("check self.current_node is not None fail")
    if not node.all_outputs_are_dead():
        self.check_warn_on_unable_to_start_executing(function_id)
        return
    self.current_node = None
def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None:
    """Warn, at most once per function, if we appear stuck off the fast path.

    Nothing is emitted unless this is a new torch.compile invocation and
    ``function_id`` appears more than once on the path to the current node.
    The parent function ids of those occurrences and of the current node
    (ignoring nodes without a parent) are then collected. The warning fires
    only when the number of distinct parent ids differs from the number of
    occurrences, i.e. the same call pattern is repeating along the path.

    Raises:
        RuntimeError: if the checks above pass but there is no current node.
    """
    if function_id in self.warned_functions:
        return
    if not self.in_new_torch_compile_invocation():
        return

    current = self.current_node
    if current is None:
        raise RuntimeError("check self.current_node is not None fail")

    occurrences = [
        candidate
        for candidate in current._path_from_root
        if candidate.wrapped_function.id == function_id
    ]
    if len(occurrences) <= 1:
        return

    parent_func_ids = set()
    for candidate in [*occurrences, current]:
        parent = candidate.parent
        if parent is not None:
            parent_func_ids.add(parent.wrapped_function.id)
    if len(parent_func_ids) == len(occurrences):
        return

    self.warned_functions.add(function_id)
    warnings.warn(
        "Unable to hit fast path of NPUGraphs because of pending, uninvoked backwards. "
        "Consider running with torch.no_grad() or using torch.compiler.npugraph_mark_step_begin() "
        "before each model invocation"
    )
@staticmethod
def format_dealloc_msg(stack_trace: Optional[str]) -> str:
    """Build the error shown when an overwritten NPUGraphs output is accessed.

    ``stack_trace`` is stripped of surrounding whitespace. If it is None or
    empty, a "[Could not find stack trace]" placeholder is used instead.
    """
    if stack_trace:
        trace_text = stack_trace.strip()
    else:
        trace_text = "[Could not find stack trace]"
    pieces = (
        "Error: accessing tensor output of NPUGraphs that has been overwritten by a subsequent run. ",
        f"Stack trace: {trace_text}. ",
        "To prevent overwriting, clone the tensor outside of torch.compile() ",
        "or call torch.compiler.npugraph_mark_step_begin() before each model invocation.",
    )
    return "".join(pieces)
def dealloc_current_path_weakrefs(self) -> None:
    """Invalidate and free the outputs still alive on the current node's path.

    For each node in ``self.current_node._path_from_root``:
      * every still-alive tensor in ``node.tensor_weakrefs`` gets an access
        error message (from ``format_dealloc_msg``) attached through
        ``torch_npu._C._set_storage_access_error_msg``;
      * unless ``self.disable_invalidate_aliases`` is set, the stack trace of
        each output in ``node.outputs_weakrefs`` is remembered by data pointer.

    Every storage reported by ``self.current_node.path_live_weakrefs()`` is
    then freed once per data pointer with
    ``torch_npu._C._free_And_Remove_DeleterFn``. Unless alias invalidation is
    disabled, each freed storage is also tagged with an access error message.

    Raises:
        RuntimeError: if there is no current node, if a node on the path has
            no stack traces, or if its tensor weakrefs and stack traces differ
            in length.
    """
    if self.current_node is None:
        raise RuntimeError("check self.current_node is not None fail")

    # storage data pointer -> stack trace of the output that owns it
    stor_stack_trace: Dict[int, Optional[str]] = {}
    for node in self.current_node._path_from_root:
        stack_traces = node.stack_traces
        # Guard explicitly; len(None) below would otherwise surface as an
        # unrelated TypeError.
        if stack_traces is None:
            raise RuntimeError("check node.stack_traces is not None fail")
        if len(node.tensor_weakrefs) != len(stack_traces):
            raise RuntimeError(
                "check len(node.tensor_weakrefs) == len(node.stack_traces) fail"
            )
        for t, stack_trace in zip(node.tensor_weakrefs, stack_traces):
            ten = None if t is None else t()
            if ten is None:
                continue
            torch_npu._C._set_storage_access_error_msg(
                ten, self.format_dealloc_msg(stack_trace)
            )
        if self.disable_invalidate_aliases:
            continue
        # outputs_weakrefs is zipped in a separate loop; its length is not
        # checked against stack_traces here.
        for storage_ref, stack_trace in zip(node.outputs_weakrefs, stack_traces):
            if not storage_ref:
                continue
            stor_stack_trace[storage_ref.data_ptr()] = stack_trace

    deleted = set()
    for storage_ref in self.current_node.path_live_weakrefs():
        _storage_deref = storage_ref()
        if not _storage_deref:
            continue
        ptr = storage_ref.data_ptr()
        # Several weakrefs can report the same data pointer; free each once.
        if ptr in deleted:
            continue
        deleted.add(ptr)
        torch_npu._C._free_And_Remove_DeleterFn(_storage_deref)
        if self.disable_invalidate_aliases:
            continue
        msg = self.format_dealloc_msg(stor_stack_trace.get(ptr))
        torch_npu._C._set_storage_data_ptr_access_error_msg(_storage_deref, msg)
def clear_current_path_state_and_set_to_none(self) -> None:
    """Clear the current node's path state and detach from that node.

    Assigning None through the ``current_node`` setter also resets
    ``path_state`` to ``ExecutionState.NONE``.

    Raises:
        RuntimeError: if the current node is not an ``NPUGraphNode``.
    """
    node = self.current_node
    if isinstance(node, NPUGraphNode):
        node.clear_path_state()
        self.current_node = None
        return
    raise RuntimeError("check self.current_node is NPUGraphNode object fail")
def apply_checkpoint_execution_state_in_allocator(self) -> None:
    """
    Checkpoint the current execution state in the caching allocator so that
    additional npugraph recordings can be made respecting existent live storages.

    The allocator state saved on the current node (``checkpointed_caching_state``)
    is applied through ``torch_npu._C._npu_setCheckpointPoolState``, passing
    the storages still alive on the current path. Every data pointer from
    ``data_ptrs_dead_since_invocation()`` is then raw-deleted exactly once.

    Raises:
        RuntimeError: if the current node is not an ``NPUGraphNode``, has no
            checkpointed state or device, or (only when
            ``config.triton.slow_path_cudagraph_asserts`` is enabled) a live
            storage fails the post-checkpoint consistency checks.
    """
    if not isinstance(self.current_node, NPUGraphNode):
        raise RuntimeError("check self.current_node is NPUGraphNode object fail")
    self.debug_checkpointing_counter += 1
    log.debug(
        "Checkpointing npu caching allocator state. Number of checkpoints %d",
        self.debug_checkpointing_counter,
    )
    state = self.current_node.checkpointed_caching_state
    device = self.current_node.device
    if state is None or device is None:
        raise RuntimeError("check state is not None and device is not None fail")

    # No stale storages are handed to the allocator; pointers that died since
    # the invocation are freed explicitly below instead.
    stale_storages: List[int] = []
    # Drop the path's cached tensors before collecting the live storages.
    self.current_node.remove_path_cached_tensors()
    live_storages_wrappers = list(self.current_node.path_live_weakrefs())
    live_storages_weak_refs: List[int] = [t() for t in live_storages_wrappers]
    ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation()
    # Deduplicate once (aliased outputs can repeat a pointer) and reuse the
    # set both for freeing and for O(1) membership checks below.
    unique_ptrs_to_deallocate = set(ptrs_to_deallocate)
    log.debug("NPUGRAPH-TREE Checkpoint detail: device=%s, pool=(%s,%s), "
              "live_storages=%d, dead_ptrs_count=%d",
              device,
              self.npu_graphs_thread_pool[0], self.npu_graphs_thread_pool[1],
              len(live_storages_weak_refs), len(ptrs_to_deallocate))
    torch_npu._C._npu_setCheckpointPoolState(
        device, state, stale_storages, live_storages_weak_refs
    )
    for ptr in unique_ptrs_to_deallocate:
        torch_npu._C._npu_npuCachingAllocator_raw_delete(ptr)

    # Optional (slow) consistency checks on the pool after checkpointing.
    if config.triton.slow_path_cudagraph_asserts:
        check_memory_pool(
            self.device_index, self.npu_graphs_thread_pool, live_storages_wrappers
        )
        for wrapper in live_storages_wrappers:
            storage_ptr = wrapper()
            if not storage_ptr:
                raise RuntimeError("check wrapper() fail")
            if not torch_npu._C._has_Standard_Deleter(storage_ptr):
                raise RuntimeError("check torch_npu._C._has_Standard_Deleter(wrapper()) fail")
            if wrapper.data_ptr() in unique_ptrs_to_deallocate:
                raise RuntimeError("check wrapper.data_ptr() not in ptrs_to_deallocate fail")
def live_npugraph_pool_storages_in_curr_execution(
    self,
) -> List[StorageWeakRefPointer]:
    """Dereference the live weakrefs on the current node's path.

    Returns an empty list when there is no current node. Otherwise it returns
    the result of calling each weakref from ``path_live_weakrefs()``.
    """
    node = self.current_node
    return [] if node is None else [ref() for ref in node.path_live_weakrefs()]