from collections.abc import Generator
from contextlib import contextmanager
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
__all__ = ["NPUGraphCaptureControlFlowOpDispatchMode", "ControlFlowOpWarmupDispatchMode"]
class NPUGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
@classmethod
def ignore_compile_internals(cls) -> bool:
return True
def __init__(self) -> None:
self.supports_higher_order_operators = True
super().__init__()
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if func is torch.ops.higher_order.cond:
with self:
return if_else_node(*args)
kwargs = {} if kwargs is None else kwargs
return func(*args, **kwargs)
class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
@classmethod
def ignore_compile_internals(cls) -> bool:
return True
def __init__(self) -> None:
super().__init__()
self.supports_higher_order_operators = True
self.capture_stream = torch.npu.Stream()
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.ops.higher_order.cond:
if torch.npu.is_current_stream_capturing():
with self:
return if_else_node(*args)
with (
torch.npu.graph(
torch.npu.NPUGraph(),
pool=None,
stream=self.capture_stream,
capture_error_mode="relaxed",
),
self,
):
if_else_node(*args)
return func(*args, **kwargs)
return func(*args, **kwargs)
@contextmanager
def _if_body(pred: torch.Tensor) -> Generator[None, None, None]:
    """Context manager that wraps its body in an if-node of the capturing NPU graph.

    On entry, the graph returned by
    ``torch.npu.NPUGraph.get_currently_capturing_graph()`` is told to begin
    capturing into an if-node keyed on ``pred``; on exit the conditional node
    is closed again.

    Args:
        pred: Tensor handed to ``begin_capture_to_if_node`` as the condition.

    Yields:
        Nothing; the ``with`` body is what gets recorded.
    """
    graph = torch.npu.NPUGraph.get_currently_capturing_graph()
    graph.begin_capture_to_if_node(pred)
    try:
        yield
    finally:
        # Always close the node, even when the body raised.
        graph.end_capture_to_conditional_node()
def if_else_node(pred: torch.Tensor, true_fn, false_fn, operands):
    """Record an if/else as two if-nodes in the NPU graph being captured.

    ``true_fn`` is recorded in an if-body conditioned on ``pred`` and
    ``false_fn`` in a second if-body conditioned on ``logical_not(pred)``.
    Inside the second body, every output leaf of ``false_fn`` is copied
    in place into the matching output leaf of ``true_fn``, so the outputs of
    ``true_fn`` serve as the result buffers for both branches.

    Args:
        pred: Condition tensor; must live on an NPU device.
        true_fn: Callable invoked as ``true_fn(*operands)``.
        false_fn: Callable invoked as ``false_fn(*operands)``.
        operands: Iterable of arguments unpacked into both callables.

    Returns:
        The output (pytree) produced by ``true_fn``.

    Raises:
        ValueError: If ``pred`` is not on an NPU device.
    """
    if not pred.is_npu:
        raise ValueError(
            "Conditions must be on an npu device to use conditional nodes in npu graphs"
        )

    with _if_body(pred):
        true_out = true_fn(*operands)

    # The negation is computed only after the first body has been closed, so
    # it is not recorded inside the true branch.
    negated_pred = torch.logical_not(pred)
    with _if_body(negated_pred):
        false_out = false_fn(*operands)
        # Issued while the false-branch body is still open, so the copy
        # belongs to that body rather than running unconditionally.
        for dst, src in zip(
            pytree.tree_iter(true_out), pytree.tree_iter(false_out)
        ):
            dst.copy_(src)

    return true_out