from typing import Any, Dict, List, Sequence
from torch_npu.utils._error_code import ErrCode, pta_error
ACLGRAPH_UPDATE_PLAN_GLOBAL = "_torch_npu_aclgraph_update_plan"
def _normalize_op_name(op_name: str) -> str:
for suffix in (".default", ".out"):
if op_name.endswith(suffix):
return op_name[: -len(suffix)]
return op_name
def _op_names_compatible(expected_op: str, actual_op: str) -> bool:
return expected_op == actual_op or _normalize_op_name(expected_op) == _normalize_op_name(actual_op)
def _get_update_specs(handler_cls: Any, op_name: str) -> List[Any]:
get_update_specs = getattr(handler_cls, "get_update_specs", None)
if get_update_specs is not None:
return get_update_specs(op_name)
return getattr(handler_cls, "UPDATE_SPECS", {}).get(op_name, [])
def _consumable_keys(record: Any) -> set:
from torch_npu.npu._npugraph_handlers.npugraph_handler import _NPU_GRAPH_OP_HANDLERS
op_name = record.op_cache_entry.__name__
keys = set(getattr(record, "kwargs", {}).keys())
handler_cls = _NPU_GRAPH_OP_HANDLERS.get(op_name)
if handler_cls is not None:
keys.update(key for _, _, key in _get_update_specs(handler_cls, op_name))
return keys
def resolve_aclgraph_update_plan(
plan: Sequence[Dict[str, Any]],
new_inputs: Sequence[Any],
) -> List[Dict[str, Any]]:
cpu_update_input: List[Dict[str, Any]] = []
for entry_idx, entry in enumerate(plan or []):
updates = entry.get("updates", {})
resolved = {}
for key, source in updates.items():
resolved[key] = _resolve_source(entry_idx, key, source, new_inputs)
cpu_update_input.append(resolved)
return cpu_update_input
def validate_aclgraph_update_plan(
plan: Sequence[Dict[str, Any]],
graph_dispatch_records: Sequence[Any],
) -> None:
plan = plan or []
if not plan and graph_dispatch_records:
ops = [record.op_cache_entry.__name__ for record in graph_dispatch_records]
raise RuntimeError(
"Captured updatable ACLGraph operators but missing ACLGraph update plan: "
f"{ops}. This may be caused by reusing cached compiled code generated "
"with ACLGraph disabled or by using the npugraphs backend, which does "
"not support ACLGraph update plans yet.",
pta_error(ErrCode.PARAM),
)
if len(plan) != len(graph_dispatch_records):
raise RuntimeError(
"ACLGraph update plan length mismatch: "
f"plan has {len(plan)} entries but graph captured "
f"{len(graph_dispatch_records)} updatable records",
pta_error(ErrCode.PARAM),
)
for idx, (entry, record) in enumerate(zip(plan, graph_dispatch_records)):
if not isinstance(entry, dict):
raise RuntimeError(
f"ACLGraph update plan has invalid plan entry at index {idx}: {entry!r}",
pta_error(ErrCode.PARAM),
)
expected_op = entry.get("op")
actual_op = record.op_cache_entry.__name__
if not isinstance(expected_op, str):
raise RuntimeError(
"ACLGraph update plan has invalid plan entry: "
f"entry {idx} op must be a string, got {expected_op!r}",
pta_error(ErrCode.PARAM),
)
if not _op_names_compatible(expected_op, actual_op):
raise RuntimeError(
"ACLGraph update plan op mismatch: "
f"entry {idx} expects {expected_op!r}, captured {actual_op!r}",
pta_error(ErrCode.PARAM),
)
updates = entry.get("updates", {})
if not isinstance(updates, dict):
raise RuntimeError(
"ACLGraph update plan has invalid plan entry: "
f"entry {idx} updates must be a dict, got {updates!r}",
pta_error(ErrCode.PARAM),
)
if not updates:
raise RuntimeError(
"ACLGraph update plan entry has no updates: "
f"entry {idx}, op {actual_op!r}",
pta_error(ErrCode.PARAM),
)
consumable = _consumable_keys(record)
unknown = sorted(set(updates) - consumable)
if unknown:
raise RuntimeError(
"ACLGraph update plan has key(s) that captured op cannot consume: "
f"entry {idx}, op {actual_op!r}, keys {unknown}",
pta_error(ErrCode.PARAM),
)
for key, source in updates.items():
_validate_source(idx, key, source)
def build_cpu_update_input_for_graph(
plan: Sequence[Dict[str, Any]],
new_inputs: Sequence[Any],
graph_dispatch_records: Sequence[Any],
) -> List[Dict[str, Any]]:
validate_aclgraph_update_plan(plan, graph_dispatch_records)
return resolve_aclgraph_update_plan(plan, new_inputs)
def validate_aclgraph_update_plan_for_graph(
plan: Sequence[Dict[str, Any]],
graph: Any,
) -> None:
if graph is None or not graph.auto_dispatch_capture:
return
validate_aclgraph_update_plan(
plan,
graph.graph_dispatch_mode.graph_dispatch_records,
)
def update_aclgraph_records_for_graph(
cpu_update_input: Sequence[Dict[str, Any]],
graph: Any,
) -> bool:
if graph is None or not graph.auto_dispatch_capture:
return False
if not cpu_update_input:
return False
graph.update(cpu_update_input)
return True
def _resolve_source(
entry_idx: int,
key: str,
source: Dict[str, Any],
new_inputs: Sequence[Any],
) -> Any:
if not isinstance(source, dict):
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has invalid source {source!r}",
pta_error(ErrCode.PARAM),
)
kind = source.get("kind")
if kind == "input":
index = source.get("index")
if not isinstance(index, int) or index < 0 or index >= len(new_inputs):
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"input index {index} is out of range for {len(new_inputs)} inputs",
pta_error(ErrCode.PARAM),
)
return new_inputs[index]
if kind == "constant":
return source.get("value")
if kind == "none":
return None
if kind == "list":
items = source.get("items")
if not isinstance(items, list):
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has invalid list source items {items!r}",
pta_error(ErrCode.PARAM),
)
return [_resolve_source(entry_idx, key, item, new_inputs) for item in items]
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has unsupported source kind {kind!r}",
pta_error(ErrCode.PARAM),
)
def _validate_literal_constant(value: Any) -> None:
import torch
if isinstance(value, torch.Tensor):
raise RuntimeError(
"Unsupported ACLGraph update Tensor constant; Tensor update values must be graph inputs",
pta_error(ErrCode.PARAM),
)
if value is None or isinstance(value, (int, float, bool, str)):
return
if isinstance(value, (list, tuple)):
for item in value:
_validate_literal_constant(item)
return
raise RuntimeError(
f"Unsupported ACLGraph update constant: {value!r}",
pta_error(ErrCode.PARAM),
)
def _validate_source(entry_idx: int, key: str, source: Dict[str, Any]) -> None:
if not isinstance(source, dict):
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has invalid source {source!r}",
pta_error(ErrCode.PARAM),
)
kind = source.get("kind")
if kind == "input":
index = source.get("index")
if not isinstance(index, int) or index < 0:
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has invalid input index {index}",
pta_error(ErrCode.PARAM),
)
elif kind == "constant":
try:
_validate_literal_constant(source.get("value"))
except RuntimeError as exc:
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has unsupported constant value {source.get('value')!r}",
pta_error(ErrCode.PARAM),
) from exc
elif kind == "none":
return
elif kind == "list":
items = source.get("items")
if not isinstance(items, list):
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has invalid list source items {items!r}",
pta_error(ErrCode.PARAM),
)
for item in items:
_validate_source(entry_idx, key, item)
else:
raise RuntimeError(
f"ACLGraph update plan entry {entry_idx} key {key} "
f"has unsupported source kind {kind!r}",
pta_error(ErrCode.PARAM),
)