import functools
import logging
from collections import defaultdict
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
)
import torch
from torch.utils._ordered_set import OrderedSet
from torch._dynamo import utils as dynamo_utils
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.cudagraphs import (
check_for_mutation_ignore_cuda_graph_managed_tensor,
find_input_mutations,
get_device_node_mapping,
get_stack_traces,
)
from torch._dynamo.backends.debugging import boxed_nop
from torch._dynamo.backends.registry import register_backend
from torch._inductor import config
from torch._inductor.compile_fx import (
get_input_idxs_to_check,
index_expanded_dims_and_copy_,
static_input,
)
from torch._inductor.cudagraph_utils import (
_get_use_stack_trace,
format_default_skip_message,
get_mutation_stack_trace,
get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
BoxedDeviceIndex,
PlaceholderInfo,
)
from torch._inductor.output_code import get_expanded_dims
from torch._inductor.utils import (
align_inputs_from_check_idxs,
copy_misaligned_inputs,
count_tangents,
get_first_incompatible_cudagraph_node,
num_fw_fixed_arguments,
output_node,
remove_unaligned_input_idxs,
BoxedBool,
InputType,
)
from torch.multiprocessing.reductions import StorageWeakRef
import torch_npu.npu.aclnn
# Module-level logger; every record from this file is emitted under the
# "torch_npu.aclgraph" logger name.
log = logging.getLogger("torch_npu.aclgraph")
def npugraph_mark_step_begin():
    """Forward to ``torch_npu.npu._graph_tree.mark_step_begin``.

    The graph-tree module is imported when this function is called, not when
    the enclosing module is imported.
    """
    import torch_npu.npu._graph_tree as graph_tree

    graph_tree.mark_step_begin()
def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    """Return a skip message unless the mapping holds exactly one NPU device.

    Args:
        device_node_mapping: maps a device to an FX node associated with it.
            When ``npugraph_trees.disable_cpu_input_check`` is set, the CPU
            entry is removed from this mapping in place.

    Returns:
        ``None`` when the mapping contains a single device of type ``npu``.
        Otherwise a message produced by ``format_default_skip_message`` that
        names either the CPU node or the list of devices found.
    """
    from torch_npu._inductor import config as npu_config

    cpu = torch.device("cpu")
    if npu_config.npugraph_trees.disable_cpu_input_check:
        # Drop the CPU entry (if present) so it is not reported below.
        device_node_mapping.pop(cpu, None)

    offending_node = device_node_mapping.get(cpu)
    if offending_node:
        reason = f"cpu device ({offending_node.name})"
        trace = _get_use_stack_trace(offending_node)
        log.info("NPUGraph: skipped — CPU node detected (%s)", offending_node.name)
        log.debug("NPUGRAPH DFX Fix: use .to('npu'), stack:\n%s", trace)
        if not trace:
            return format_default_skip_message(reason)
        return format_default_skip_message(f"{reason}. Found from : \n {trace}")

    devices = list(device_node_mapping.keys())
    if len(devices) == 1 and devices[0].type == "npu":
        return None

    # Anything other than exactly one NPU device is reported as "multiple devices".
    log.info("NPUGraph: skipped — multiple devices (%s)", devices)
    log.warning("NPUGRAPH DFX Fix: ensure all tensors on same NPU device")
    joined = ", ".join(repr(device) for device in devices)
    return format_default_skip_message(f"multiple devices: {joined}")
def npugraphify(
    model: Callable[..., Any],
    static_input_idxs: Sequence[int] = (),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
    constants: Tuple[torch.Tensor, ...] = (),
    placeholders: Sequence[PlaceholderInfo] = (),
    mutated_input_idxs: Tuple[int, ...] = (),
) -> Callable[..., Any]:
    """Return a callable that graphifies ``model`` lazily on its first call.

    When ``config.triton.cudagraph_trees`` is set, the graph-tree
    implementation from ``torch_npu.npu._graph_tree`` is used and receives all
    keyword-only arguments. Otherwise this module's ``npugraphify_impl`` is
    used and the keyword-only arguments are not forwarded.

    Returns:
        ``run(new_inputs)``: on the first call it builds the graphified
        function from ``model``, ``new_inputs`` and ``static_input_idxs``;
        on every call it invokes that function with ``new_inputs``.
    """
    from torch_npu.npu._graph_tree import npugraphify_impl as new_npugraphify_impl

    graphify: Callable[..., Any]
    if not config.triton.cudagraph_trees:
        graphify = npugraphify_impl
    else:
        tree_kwargs = {
            "device_index": device_index,
            "stack_traces": stack_traces,
            "is_backward": is_backward,
            "is_inference": is_inference,
            "constants": constants,
            "placeholders": placeholders,
            "mutated_input_idxs": mutated_input_idxs,
        }
        graphify = functools.partial(new_npugraphify_impl, **tree_kwargs)

    # Filled in by the first invocation of ``run``.
    graphified = None

    def run(new_inputs: Sequence[InputType]) -> Any:
        nonlocal graphified
        if graphified is None:
            with dynamo_utils.dynamo_timed(
                "npugraphify",
                log_pt2_compile_event=True,
            ):
                with dynamo_utils.preserve_rng_state():
                    graphified = graphify(model, new_inputs, static_input_idxs)
        return graphified(new_inputs)

    return run
def npugraphify_impl(
    model: Callable[..., Any],
    inputs: List[torch.Tensor],
    static_input_idxs: Sequence[int] = (),
) -> Callable[[List[InputType]], Any]:
    """
    Run ``model`` on static copies of ``inputs``, record it into a
    ``torch.npu.NPUGraph`` and return a callable that replays that graph.

    Assumes inputs[static_input_idxs[i]] are always the same memory address

    Args:
        model: callable invoked with a single list of inputs.
        inputs: example inputs; must be a ``list``.
        static_input_idxs: indices of inputs that are used directly (detached)
            instead of being copied into a static buffer on every call.

    Returns:
        The inner ``run`` function wrapped by ``align_inputs_from_check_idxs``.
        ``run`` copies the non-static inputs into the static buffers, clears
        the list it was given, replays the graph and returns the outputs
        produced while the graph was recorded.

    Raises:
        RuntimeError: if ``inputs`` is not a list. Inside ``run``: if an input
            is not a tensor where one is expected and, when
            ``config.size_asserts`` is set, if the number of inputs differs or
            a static input's ``data_ptr()`` changed.
    """
    check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
    # From here on ``static_input_idxs`` is an OrderedSet of the indices kept by
    # ``remove_unaligned_input_idxs`` (presumably only aligned ones — confirm
    # against that helper).
    static_input_idxs= OrderedSet(
        remove_unaligned_input_idxs(inputs, static_input_idxs)
    )
    copy_misaligned_inputs(inputs, check_input_idxs)
    if not isinstance(inputs, list):
        raise RuntimeError("check isinstance(inputs, list) fail")
    # Expanded dims are only computed for non-static inputs; static ones get [].
    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]
    # Non-tensors are kept as-is, non-static tensors go through
    # ``static_input``, and static tensors are detached.
    static_inputs = [
        x
        if not isinstance(x, torch.Tensor)
        else static_input(x)
        if idx not in static_input_idxs
        else x.detach()
        for idx, x in enumerate(inputs)
    ]
    # Fill the static buffers of the non-static tensor inputs with the example data.
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)
    # Run the model once on a side stream before the graph is recorded.
    torch.npu.synchronize()
    stream = torch.npu.Stream()
    stream.wait_stream(torch.npu.current_stream())
    if torch_npu.npu.aclnn._use_static_aclnn_kernel:
        # Same run as the else-branch, but inside a StaticKernelCompiler context.
        from torch_npu._inductor.npu_static_kernel import StaticKernelCompiler
        static_kernel_complier = StaticKernelCompiler()
        with static_kernel_complier:
            with torch.npu.stream(stream):
                model(list(static_inputs))
    else:
        with torch.npu.stream(stream):
            model(list(static_inputs))
    stream.synchronize()
    torch.npu.current_stream().wait_stream(stream)
    torch.npu.synchronize()
    # Record the model into an NPU graph on the same side stream.
    graph = torch.npu.NPUGraph()
    with torch.npu.graph(graph, stream=stream, capture_error_mode="thread_local"):
        static_outputs = model(list(static_inputs))
    # Normalise a single output into a 1-tuple.
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)
    if config.size_asserts:
        # Checked variant: validates input count, tensor types and that static
        # inputs still have the same data pointer before replaying.
        # NOTE(review): the return annotation says Callable, but the function
        # returns ``static_outputs`` — confirm and correct the annotation.
        def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
            if not len(static_inputs) == len(new_inputs):
                raise RuntimeError("check len(static_inputs) == len(new_inputs) fail")
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                # Non-tensor static inputs have nothing to copy.
                if not isinstance(dst, torch.Tensor):
                    continue
                if not isinstance(src, torch.Tensor):
                    raise RuntimeError("check isinstance(src, torch.Tensor) fail")
                if idx in static_input_idxs:
                    if not dst.data_ptr() == src.data_ptr():
                        raise RuntimeError("check dst.data_ptr() == src.data_ptr() fail")
                else:
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            # Empty the caller's list so it no longer references the inputs.
            new_inputs.clear()
            graph.replay()
            return static_outputs
    else:
        # Unchecked variant: only the non-static indices are visited.
        copy_indices = [
            idx
            for idx in range(len(static_inputs))
            if idx not in static_input_idxs
        ]
        # NOTE(review): same return-annotation mismatch as the checked variant.
        def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                src = new_inputs[idx]
                if not isinstance(src, torch.Tensor):
                    raise RuntimeError("check isinstance(src, torch.Tensor) fail")
                index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
            # Empty the caller's list so it no longer references the inputs.
            new_inputs.clear()
            graph.replay()
            return static_outputs
    return align_inputs_from_check_idxs(run, check_input_idxs)
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return the first reason why ``aot_model`` must not be graphified.

    The checks run in this order and the first non-empty result wins:

    1. input mutation, only when
       ``torch._dynamo.config.cudagraph_backend_support_input_mutation`` is off;
    2. multiple devices or CPU nodes;
    3. an incompatible node reported by ``get_first_incompatible_cudagraph_node``.

    Returns:
        A skip message, or ``None`` when no check objects.
    """
    mutation_supported = torch._dynamo.config.cudagraph_backend_support_input_mutation
    if not mutation_supported:
        mutation_reason = check_for_mutation_ignore_cuda_graph_managed_tensor(
            aot_model, num_fixed
        )
        if mutation_reason:
            return mutation_reason

    device_reason = check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    )
    if device_reason:
        return device_reason

    bad_node = get_first_incompatible_cudagraph_node(aot_model)
    if not bad_node:
        return None
    return format_default_skip_message(f"incompatible op ({bad_node.name})")
def get_device_index(gm) -> int:
    """Return the index of the NPU device found first in ``gm``'s device mapping.

    The mapping comes from ``get_device_node_mapping(gm)``. When
    ``npugraph_trees.disable_cpu_input_check`` is set, the CPU entry is
    dropped before the first remaining device is inspected.

    Args:
        gm: graph module passed through to ``get_device_node_mapping``.

    Returns:
        ``device.index`` of the first device in the mapping.

    Raises:
        RuntimeError: if no device is left in the mapping, or if the first
            device is not of type ``npu``.
    """
    device_node_mapping = get_device_node_mapping(gm)
    from torch_npu._inductor import config as npu_config

    if npu_config.npugraph_trees.disable_cpu_input_check:
        device_node_mapping.pop(torch.device("cpu"), None)

    # Use a default so an empty mapping yields a clear RuntimeError instead of
    # leaking StopIteration into the caller.
    device = next(iter(device_node_mapping), None)
    if device is None:
        raise RuntimeError("check device_node_mapping is not empty fail")
    if device.type != "npu":
        raise RuntimeError("check device.type == npu fail")
    return device.index
def npugraphs(dynamo_model, dynamo_inputs):
    """Compile ``dynamo_model`` through AOT autograd with NPU graph wrapping.

    Forward, backward and inference graphs are each passed through
    ``check_for_skip``. Graphs that pass are wrapped with the graph-tree
    ``npugraphify_impl`` from ``torch_npu.npu._graph_tree``; graphs that do not
    pass run without graph wrapping.

    Args:
        dynamo_model: model handed over by dynamo.
        dynamo_inputs: example inputs handed over by dynamo.

    Returns:
        The result of calling the ``aot_autograd`` backend on the model and
        inputs.
    """
    from torch_npu.npu._graph_tree import npugraphify_impl as new_npugraphify_impl

    # Shared between the forward and backward compilers: whether graph wrapping
    # is still enabled, and which device index the forward pass used.
    do_npugraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_npugraphs(aot_model, aot_inputs, is_inference=False):
        # NOTE(review): ``is_inference`` is accepted but ``is_inference=False``
        # is what gets forwarded below — confirm this is intended.
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        skip_msg = check_for_skip(aot_model, fixed)
        if skip_msg:
            # Disable graph wrapping for the matching backward as well.
            BoxedBool.disable(do_npugraphs)
            log.info("NPUGraph: skipped — %s", skip_msg)
            log.warning("NPUGRAPH DFX Fix: check skip reason above, common causes: "
                        "input mutation, incompatible ops, data-dependent control flow")
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = new_npugraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    def backward_npugraphs(aot_model, aot_inputs):
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_npugraphs:
            # The forward pass was skipped, so the backward runs unwrapped too.
            return aot_model

        fixed = count_tangents(aot_model)
        skip_msg = check_for_skip(aot_model, fixed)
        if skip_msg:
            # The helper takes one pre-formatted message, so format it here
            # rather than passing logging-style arguments.
            log_cudagraph_skip_and_bump_counter(
                f"skipping npugraphs due to {skip_msg}"
            )
            from torch_npu.npu._graph_tree import get_manager

            manager = get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            if manager is None:
                raise RuntimeError("check manager is not None fail")

            def fn(inputs):
                # Tell the manager a backward is running before executing the
                # unwrapped backward graph.
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn

        out = new_npugraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    aot_npugraphs = aot_autograd(
        fw_compiler=forward_npugraphs,
        bw_compiler=backward_npugraphs,
        inference_compiler=functools.partial(forward_npugraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_npugraphs(dynamo_model, dynamo_inputs)
class NpugraphsBackend:
    """Callable wrapper around ``npugraphs`` with a ``reset`` hook."""

    compiler_name = "npugraphs"

    @staticmethod
    def reset():
        """Call ``reset_npugraph_trees`` from ``torch_npu.npu._graph_tree``."""
        import torch_npu.npu._graph_tree as graph_tree

        graph_tree.reset_npugraph_trees()

    @staticmethod
    def __call__(model, inputs):
        """Delegate to ``npugraphs(model, inputs)`` and return its result."""
        compiled = npugraphs(model, inputs)
        return compiled
def _apply_npugraph_tree_methods():
    """Register the ``npugraphs`` backend and install this module's hooks.

    Replaces ``torch._inductor.compile_fx.cudagraphify`` and
    ``torch._inductor.cudagraph_utils.check_multiple_devices_or_any_cpu_nodes``
    with the NPU versions defined in this module, and exposes
    ``npugraph_mark_step_begin`` on ``torch.compiler``.
    """
    backend = NpugraphsBackend()
    register_backend(name="npugraphs", compiler_fn=backend)

    inductor = torch._inductor
    inductor.compile_fx.cudagraphify = npugraphify
    inductor.cudagraph_utils.check_multiple_devices_or_any_cpu_nodes = (
        check_multiple_devices_or_any_cpu_nodes
    )
    torch.compiler.npugraph_mark_step_begin = npugraph_mark_step_begin