from typing import List
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("_internal_mark_region_begin")
def _(x: torch.Tensor, id: int) -> torch.Tensor:
    """Mark the beginning of a region of execution.

    This implementation is a pure pass-through: ``id`` is not inspected and
    ``x`` is handed back as the very same tensor object (no copy is made).

    NOTE(review): ``id`` presumably names the region so that it can be matched
    with the corresponding end marker -- confirm against the consumer of this
    op. The parameter name shadows the ``id`` builtin but is kept because it
    is part of the op's signature.

    Args:
        x: Tensor the marker is attached to.
        id: Region identifier; unused by this body.

    Returns:
        ``x`` itself, unmodified.
    """
    return x
@register_tensor_cast_op("_internal_mark_region_end")
def _(x: torch.Tensor, id: int) -> torch.Tensor:
    """Mark the end of a region of execution.

    This implementation is a pure pass-through: ``id`` is not inspected and
    ``x`` is handed back as the very same tensor object (no copy is made).

    NOTE(review): ``id`` presumably matches the identifier given to the
    corresponding begin marker -- confirm against the consumer of this op.
    The parameter name shadows the ``id`` builtin but is kept because it is
    part of the op's signature.

    Args:
        x: Tensor the marker is attached to.
        id: Region identifier; unused by this body.

    Returns:
        ``x`` itself, unmodified.
    """
    return x
@register_tensor_cast_op("_internal_copy_region")
def _(x: torch.Tensor, id: int) -> torch.Tensor:
    """Copy a region of execution marked previously.

    This implementation performs no copying itself: ``id`` is not inspected
    and ``x`` is handed back as the very same tensor object.

    NOTE(review): the actual region replay is presumably carried out by
    whatever interprets this op, with ``id`` selecting the marked region --
    confirm against that consumer. The parameter name shadows the ``id``
    builtin but is kept because it is part of the op's signature.

    Args:
        x: Tensor the op is anchored on.
        id: Identifier of the region to copy; unused by this body.

    Returns:
        ``x`` itself, unmodified.
    """
    return x
@register_tensor_cast_op("_internal_wait_and_bind")
def _(x: torch.Tensor, stream_id: int, deps: List[torch.Tensor]) -> torch.Tensor:
    """Bind the next real op on ``x`` to ``stream_id`` after waiting on ``deps``.

    This op is a control-flow anchor for multistream lowering and leaves the
    data carried by ``x`` untouched. The runtime is expected to read it as:

    1. the next real op consuming ``x`` should execute on ``stream_id``;
    2. that real op must wait until every dependency token in ``deps`` is
       ready.

    Example:
        y = _internal_wait_and_bind(x, 1, [token0])
        z = real_op(y)
        token1 = _internal_record(z, 1)

    In the example, ``real_op`` runs on stream 1 only once ``token0`` is ready.

    Args:
        x: Tensor whose next consumer is being bound.
        stream_id: Target stream for the next real op; not read by this body.
        deps: Dependency tokens to wait on; not read by this body.

    Returns:
        A clone of ``x`` (a new tensor holding the same values).
    """
    # The result is a fresh tensor rather than ``x`` itself.
    bound = x.clone()
    return bound
@register_tensor_cast_op("_internal_record")
def _(x: torch.Tensor, stream_id: int) -> torch.Tensor:
    """Create a control token marking completion of the preceding real op.

    Multistream lowering pairs this op with ``_internal_wait_and_bind``. The
    scalar tensor it returns is a runtime control token, not a model
    activation.

    Example:
        y = real_op(x)
        token = _internal_record(y, 0)
        z = _internal_wait_and_bind(other, 1, [token])

    The wait op can take ``token`` to express that a later op must not start
    until ``real_op`` on stream 0 has completed.

    Args:
        x: Output of the real op being recorded; only its device is read.
        stream_id: Stream the recorded op ran on; not read by this body.

    Returns:
        A zero-dimensional ``torch.int64`` tensor on ``x.device``. It is built
        with ``torch.empty``, so its contents are uninitialized.
    """
    token_shape = ()  # zero-dimensional (scalar) tensor
    return torch.empty(token_shape, dtype=torch.int64, device=x.device)