import sys
import logging
import re
import functools
import textwrap
import traceback
import inspect
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
import torch
import torch.cuda._sanitizer as csan
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
import torch_npu
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Pattern applied with ``.match`` to an operator's schema name; a hit marks the
# operator as a factory function (``new_*`` / ``*_like`` style names).
FACTORY_FUNCTION_REGEX = re.compile(r"(new_.*|.*_like)")
@dataclass
class CrossStreamUsage:
    """One use of a tensor on a stream other than the one it was allocated on.

    Instances are mutable: a later use on the same stream may overwrite the
    sequence number, operator and stack trace in place.
    """

    # Stream on which the tensor was used.
    usage_stream: csan.StreamId
    # Handler sequence number at the time of the use.
    seq_num: csan.SeqNum
    # Textual description of the operator that touched the tensor.
    operator: str
    # Python stack captured when the use was recorded.
    stack_trace: traceback.StackSummary
@dataclass
class NPUTensorInfo:
    """Per-allocation bookkeeping used for missing-``record_stream`` detection.

    This record is kept separately from the parent ``EventHandler``'s own
    ``TensorInfo``, which tracks read/write accesses for data race detection.
    """

    # Stream that was current when the memory was allocated (None if unknown).
    allocation_stream: Optional[csan.StreamId] = None
    # Stack trace of the allocation, when one is available.
    allocation_stack_trace: Optional[traceback.StackSummary] = None
    # Streams announced for this allocation via record_stream.
    recorded_streams: Set[csan.StreamId] = field(default_factory=set)
    # Latest cross-stream use seen on each non-allocation stream.
    cross_stream_usages: Dict[csan.StreamId, CrossStreamUsage] = field(default_factory=dict)
class MissingRecordStreamError(csan.SynchronizationError):
    """Tensor was used across streams without record_stream or proper sync.

    Detected at tensor deallocation or via flush_record_stream_warnings().

    Per PyTorch docs, record_stream is NOT required if creation_stream has been
    synchronized to wait for the usage_stream before deallocation:
      - creation_stream.wait_stream(usage_stream)
      - creation_stream.wait_event(event_on_usage_stream)
      - torch.npu.synchronize() (device-level sync covers all directions)
    """

    def __init__(
        self,
        data_ptr: csan.DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        allocation_stream: csan.StreamId,
        usage_stream: csan.StreamId,
        usage: CrossStreamUsage,
        recorded_streams: set[csan.StreamId],
    ):
        """Store the facts needed to render the report.

        Args:
            data_ptr: Data pointer of the affected allocation.
            allocation_stack_trace: Where the tensor was allocated, if known.
            allocation_stream: Stream the tensor was allocated on.
            usage_stream: Stream on which the unsynchronized use happened.
            usage: Details (operator, stack trace) of that use.
            recorded_streams: Streams that were announced via record_stream.
        """
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.allocation_stream = allocation_stream
        self.usage_stream = usage_stream
        self.usage = usage
        self.recorded_streams = recorded_streams

    def __repr__(self):
        """Return a compact one-line summary of the error."""
        return (
            f"MissingRecordStreamError(data_ptr={self.data_ptr}, "
            f"alloc_stream={self.allocation_stream}, "
            f"usage_stream={self.usage_stream}, "
            f"operator='{self.usage.operator}')"
        )

    def __str__(self):
        """Return the full multi-line, human-readable report."""
        # The report is assembled from explicit pieces instead of dedenting an
        # already-interpolated f-string: dedent would also inspect (and could
        # alter) the substituted operator text.
        parts = [
            "============================\n"
            "NPUSanitizer: missing record_stream detected!\n"
            f"Tensor (data ptr: {self.data_ptr}) allocated on stream {self.allocation_stream}\n"
            f"was used on stream {self.usage_stream} without record_stream or\n"
            "creation_stream.wait_stream(usage_stream).\n"
            "This may cause use-after-free if the caching allocator reuses memory\n"
            "on the allocation stream before the usage stream finishes.\n"
            "Fix with ONE of:\n"
            "A) tensor.record_stream(stream) — tell allocator about the usage\n"
            "B) creation_stream.wait_stream(usage_stream) before deallocation\n"
            "Cross-stream usage during kernel:\n"
            f"{self.usage.operator}\n",
            f"With stack trace:\n{''.join(self.usage.stack_trace.format())}\n",
        ]
        if self.recorded_streams:
            parts.append(f"Streams recorded via record_stream: {self.recorded_streams}\n")
        else:
            parts.append("No streams were recorded via record_stream.\n")
        if self.allocation_stack_trace:
            parts.append(
                "Tensor was allocated with stack trace:\n"
                f"{''.join(self.allocation_stack_trace.format())}"
            )
        return "".join(parts)
class NPURecordStreamHandler(csan.EventHandler):
    """EventHandler extended with deferred missing-``record_stream`` detection.

    The check is postponed until the tensor is deallocated (or until
    ``flush_record_stream_warnings`` is called), because what matters for
    memory safety is whether the creation stream has synced with the usage
    stream BEFORE the memory is reused, not at the moment the cross-stream
    kernel is launched. Deferring avoids false positives for code that syncs
    after the launch but before the tensor is freed.
    """

    def __init__(self) -> None:
        super().__init__()
        # Bookkeeping per tracked allocation, keyed by data pointer.
        self._npu_tensors: dict[csan.DataPtr, NPUTensorInfo] = {}
        # Every missing-record_stream error reported so far.
        self.record_stream_errors: list[MissingRecordStreamError] = []

    def _handle_memory_allocation(self, data_ptr: csan.DataPtr) -> None:
        """Start tracking ``data_ptr`` together with its allocation stream."""
        super()._handle_memory_allocation(data_ptr)

        try:
            trace = self.tensors_accessed.get_allocation_stack_trace(data_ptr)
        except KeyError:
            trace = None

        stream_id: Optional[csan.StreamId]
        try:
            stream_id = int(torch_npu.npu.current_stream().npu_stream)
        except RuntimeError:
            # Without a known allocation stream the entry is never checked.
            stream_id = None

        self._npu_tensors[data_ptr] = NPUTensorInfo(
            allocation_stream=stream_id,
            allocation_stack_trace=trace,
        )

    def _handle_memory_deallocation(self, data_ptr: csan.DataPtr) -> None:
        """Report outstanding cross-stream uses of ``data_ptr`` and stop tracking it."""
        if data_ptr in self._npu_tensors:
            pending = self._get_record_stream_errors(data_ptr)
            for err in pending:
                print(err, file=sys.stderr)
                self.record_stream_errors.append(err)
            del self._npu_tensors[data_ptr]
        super()._handle_memory_deallocation(data_ptr)

    def _handle_kernel_launch(
        self,
        stream: csan.StreamId,
        read_only: set[csan.DataPtr],
        read_write: set[csan.DataPtr],
        outputs: set[csan.DataPtr],
        operator: str,
        tensor_aliases: dict[int, list[str]],
        storage_dataptrs_accessed: Optional[Set[csan.DataPtr]] = None
    ) -> List[csan.SynchronizationError]:
        """Run the parent's checks, then note any cross-stream tensor use.

        When ``storage_dataptrs_accessed`` is omitted, the union of
        ``read_only`` and ``read_write`` is used as the set of touched
        pointers. The parent's error list is returned unchanged.
        """
        errors = super()._handle_kernel_launch(
            stream, read_only, read_write, outputs, operator, tensor_aliases
        )
        if storage_dataptrs_accessed is None:
            touched = read_only | read_write
        else:
            touched = storage_dataptrs_accessed
        self._record_cross_stream_usage(stream, touched, operator)
        return errors

    def _record_cross_stream_usage(
        self, stream: csan.StreamId, all_accessed: set[csan.DataPtr], operator: str
    ) -> None:
        """Note every tracked tensor in ``all_accessed`` that is used off its allocation stream.

        The stack trace is captured at most once per call and shared by all
        affected tensors. The current ``seq_num`` is stored with the usage so
        that a later sync check can require a sync that happened AFTER it.
        """
        trace = None
        launch_seq = self.seq_num
        for ptr in all_accessed:
            info = self._npu_tensors.get(ptr)
            if (
                info is None
                or info.allocation_stream is None
                or info.allocation_stream == stream
            ):
                continue

            if trace is None:
                trace = traceback.StackSummary.extract(
                    traceback.walk_stack(inspect.currentframe()),
                    lookup_lines=False,
                )
                # walk_stack yields innermost-first; flip to the usual order.
                trace.reverse()

            usage = info.cross_stream_usages.get(stream)
            if usage is None:
                info.cross_stream_usages[stream] = CrossStreamUsage(
                    usage_stream=stream,
                    seq_num=launch_seq,
                    operator=operator,
                    stack_trace=trace,
                )
            else:
                # Keep only the most recent use per stream.
                usage.seq_num = launch_seq
                usage.operator = operator
                usage.stack_trace = trace

    def _get_record_stream_errors(self, data_ptr: csan.DataPtr) -> List[MissingRecordStreamError]:
        """Return an error for each cross-stream use of ``data_ptr`` that is neither recorded nor synced."""
        info = self._npu_tensors.get(data_ptr)
        if info is None or info.allocation_stream is None:
            return []
        return [
            MissingRecordStreamError(
                data_ptr=data_ptr,
                allocation_stack_trace=info.allocation_stack_trace,
                allocation_stream=info.allocation_stream,
                usage_stream=usage_stream,
                usage=usage,
                recorded_streams=info.recorded_streams.copy(),
            )
            for usage_stream, usage in info.cross_stream_usages.items()
            if usage_stream not in info.recorded_streams
            and not self._is_creation_stream_synced_to_usage(
                info.allocation_stream, usage_stream, usage.seq_num
            )
        ]

    def flush_record_stream_warnings(self) -> List[MissingRecordStreamError]:
        """Check every tracked tensor now and report missing record_stream calls.

        Intended to be called once all stream work (including synchronization)
        is finished. Each error is printed to stderr, appended to
        ``self.record_stream_errors`` and returned. Tracked tensors are left in
        place, so their usages remain subject to the deallocation-time check.
        """
        flushed = []
        for data_ptr in list(self._npu_tensors):
            found = self._get_record_stream_errors(data_ptr)
            for err in found:
                print(err, file=sys.stderr)
            flushed.extend(found)
        self.record_stream_errors.extend(flushed)
        return flushed

    def _is_creation_stream_synced_to_usage(
        self,
        creation_stream: csan.StreamId,
        usage_stream: csan.StreamId,
        usage_seq_num: csan.SeqNum = 0,
    ) -> bool:
        """Tell whether ``creation_stream`` has waited for the given use on ``usage_stream``.

        The sync state that the creation stream holds for the usage stream is
        compared against ``usage_seq_num``; a missing entry counts as -1, i.e.
        "never synced". Comparing against the usage's sequence number is meant
        to tell a real synchronization apart from the initial state a stream
        inherits at creation (per the original author's note: initial state 0,
        kernel sequence numbers starting at 1).
        """
        try:
            known = self.syncs.current_sync_states.get(creation_stream, {})
            return known.get(usage_stream, -1) >= usage_seq_num
        except (AttributeError, KeyError):
            return False

    def _handle_record_stream(self, data_ptr: csan.DataPtr, stream: csan.StreamId) -> None:
        """Remember that ``record_stream(stream)` was called for ``data_ptr``."""
        info = self._npu_tensors.get(data_ptr)
        if info is None:
            # Untracked pointer: create an entry with no allocation stream.
            info = self._npu_tensors[data_ptr] = NPUTensorInfo()
        info.recorded_streams.add(stream)

    def _handle_erase_stream(
        self, data_ptr: csan.DataPtr, stream: csan.StreamId
    ) -> None:
        """Forget a recorded stream once a communication work no longer owns it."""
        info = self._npu_tensors.get(data_ptr)
        if info is not None:
            info.recorded_streams.discard(stream)
class NPUArgumentHandler:
def __init__(self):
self.dataptrs_read: set[csan.DataPtr] = set()
self.dataptrs_written: set[csan.DataPtr] = set()
self.tensor_aliases: dict[int, list[str]] = {}
self.outputs: set[csan.DataPtr] = set()
self.storage_dataptrs_accessed: set[csan.DataPtr] = set()
def _handle_argument(
self,
value,
is_write: bool,
metadata_only: bool,
name: Optional[str] = None,
is_output: bool = False,
) -> None:
if not isinstance(value, torch.Tensor) or not value.is_npu:
return
if metadata_only:
return
data_ptr = value.data_ptr() if value.data_ptr() else id(value)
if is_write:
self.dataptrs_written.add(data_ptr)
else:
self.dataptrs_read.add(data_ptr)
self.tensor_aliases.setdefault(data_ptr, [])
if name is not None:
self.tensor_aliases[data_ptr].append(name)
if is_output:
self.outputs.add(data_ptr)
try:
storage = value.untyped_storage()
if storage is not None:
storage_ptr = storage.data_ptr()
if storage_ptr:
self.storage_dataptrs_accessed.add(storage_ptr)
except (RuntimeError, AttributeError):
pass
def parse_inputs(self, schema, args, kwargs, *, is_factory: bool = False) -> None:
from torch.cuda._sanitizer import zip_arguments
for argument, value in zip_arguments(schema, args, kwargs):
is_write = argument.alias_info is not None and argument.alias_info.is_write
metadata_only = is_factory or (
argument.alias_info is not None and not argument.alias_info.is_write
)
pytree.tree_map_(
functools.partial(
self._handle_argument,
is_write=is_write,
name=argument.name,
metadata_only=metadata_only,
),
value,
)
def parse_outputs(self, schema, outputs, *, is_factory: bool = False) -> None:
from torch.cuda._sanitizer import zip_arguments
for res, value in zip(schema.returns, (outputs,)):
metadata_only = res.alias_info is not None and not res.alias_info.is_write
pytree.tree_map_(
functools.partial(
self._handle_argument,
is_write=True,
metadata_only=metadata_only,
is_output=True,
),
value,
)
class NPUSanitizerDispatchMode(TorchDispatchMode):
    """Dispatch mode that reports each dispatched operator call to an event handler.

    For every call it collects the NPU tensors involved, runs the operator,
    and forwards the touched data pointers to
    ``event_handler._handle_kernel_launch``. The handler must accept the
    ``storage_dataptrs_accessed`` keyword, as ``NPURecordStreamHandler`` does.
    """

    def __init__(self, event_handler: csan.EventHandler):
        super().__init__()
        self.event_handler = event_handler
        # Replaced by a fresh NPUArgumentHandler for every dispatched call.
        self.args_handler = None
        # Operator base names for which enable_autograd() takes effect.
        self.npu_adjust_autograd = [
            "adaptive_avg_pool2d", "batch_norm",
            "log_softmax", "nll_loss", "to"
        ]

    def enable_autograd(self, aten_api):
        """Set the AutogradFunctionality key's excluded flag to False for listed operators."""
        if aten_api in self.npu_adjust_autograd:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and report its tensor accesses to the event handler.

        Returns whatever ``func`` returns.

        Raises:
            csan.CUDASanitizerErrors: if the event handler reports errors.
        """
        kwargs = {} if kwargs is None else kwargs
        func_name = func.__name__ if hasattr(func, '__name__') else str(func)
        if "record_stream" in func_name:
            return self._handle_record_stream_op(func, args, kwargs)

        is_factory = bool(FACTORY_FUNCTION_REGEX.match(func._schema.name))
        self.args_handler = NPUArgumentHandler()
        # Reuse the guarded name so callables without __name__ do not crash here.
        aten_api = func_name.split(".")[0]
        self.enable_autograd(aten_api)

        self.parse_inputs(func._schema, args, kwargs, is_factory=is_factory)
        outputs = func(*args, **kwargs)
        self.parse_outputs(func._schema, outputs, is_factory=is_factory)

        handler = self.args_handler
        if not (
            handler.dataptrs_read
            or handler.dataptrs_written
            or handler.outputs
            or handler.storage_dataptrs_accessed
        ):
            # No NPU tensor data was touched: nothing to report.
            return outputs

        try:
            npu_stream = int(torch_npu.npu.current_stream().npu_stream)
        except RuntimeError as err:
            logger.info(
                "Failed to get current stream, ignore this kernel launch record. error info is: %s",
                err
            )
            return outputs

        self.check_errors(func, npu_stream)
        return outputs

    def _handle_record_stream_op(self, func, args, kwargs):
        """Short-circuit record_stream so it isn't treated as a regular kernel launch.

        Tracking is done in C++ via the NPURecordStreamCallbacks trace, which is fired
        from NpuCachingAllocator::recordStream / NPUPluggableAllocator::recordStream —
        the chokepoint that all entry points (aten op, NPUGuardImpl, HCCL/LCCL, RPC,
        pluggable allocator) funnel through. Doing tracking here would only cover the
        aten-op path and would use tensor.data_ptr() (with view offset), which would
        not match the storage data_ptr that the allocation callback uses.
        """
        return func(*args, **kwargs)

    def parse_inputs(self, schema, args, kwargs, is_factory=False):
        """Delegate input parsing to the current argument handler."""
        self.args_handler.parse_inputs(schema, args, kwargs, is_factory=is_factory)

    def parse_outputs(self, schema, outputs, is_factory=False):
        """Delegate output parsing to the current argument handler."""
        self.args_handler.parse_outputs(schema, outputs, is_factory=is_factory)

    def check_errors(self, func, npu_stream):
        """Report the collected accesses as a kernel launch on ``npu_stream``.

        Pointers that were both read and written are reported as written only.

        Raises:
            csan.CUDASanitizerErrors: if the handler returns any errors; each
                one is printed to stderr first.
        """
        handler = self.args_handler
        errors = self.event_handler._handle_kernel_launch(
            npu_stream,
            handler.dataptrs_read - handler.dataptrs_written,
            handler.dataptrs_written,
            handler.outputs,
            str(func._schema),
            handler.tensor_aliases,
            storage_dataptrs_accessed=handler.storage_dataptrs_accessed,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise csan.CUDASanitizerErrors(errors)