import torch
from .. import ops
from .utils import ModelWrapperBase
class RegionMarkerWrapper(ModelWrapperBase):
    """Wrap a layer so its forward pass is bracketed by region marker ops.

    The incoming hidden states are passed through
    ``torch.ops.tensor_cast._internal_mark_region_begin`` before the wrapped
    layer runs. The layer's primary output is passed through
    ``torch.ops.tensor_cast._internal_mark_region_end`` afterwards.

    The format of the wrapped layer's most recent return value is recorded in
    ``returns_tuple`` / ``return_length`` so that ``CopyLayerWrapper``
    instances can mirror it.
    """

    def __init__(
        self,
        region_id: int,
        layer: torch.nn.Module,
        repeat_count: int = 1,
    ):
        """
        Wrap a layer with region markers.

        Args:
            region_id: The id of the region to mark.
            layer: Original layer instance to wrap.
            repeat_count: Number of structurally equivalent layers represented by this
                wrapper. The first occurrence runs the real layer, the rest replay the
                marked region.
        """
        # NOTE(review): assumes ModelWrapperBase stores `layer` as `self._inner`,
        # which forward() relies on — confirm in .utils.
        super().__init__(layer)
        self.region_id = region_id
        # Stored for external callers; not read anywhere in this class.
        self.repeat_count = repeat_count
        # Return-format metadata. These defaults hold until the first forward()
        # call overwrites them with what the wrapped layer actually returned.
        self.returns_tuple = True
        self.return_length = 1

    def _mark_begin(self, hidden_states):
        """Pass ``hidden_states`` through the region-begin marker op for this region."""
        return torch.ops.tensor_cast._internal_mark_region_begin(
            hidden_states,
            self.region_id,
        )

    def _mark_end(self, hidden_states):
        """Pass ``hidden_states`` through the region-end marker op for this region."""
        return torch.ops.tensor_cast._internal_mark_region_end(
            hidden_states,
            self.region_id,
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped layer between the region begin/end markers.

        The hidden states are taken from the first positional argument. When no
        positional arguments are given, they come from the ``hidden_states``
        keyword argument instead. All other arguments are forwarded unchanged.

        Returns:
            The wrapped layer's result, with the hidden states replaced by the
            end-marked tensor. For a tuple result this is element 0; otherwise
            it is the result itself. Tuple results keep their remaining
            elements unchanged.

        Raises:
            TypeError: If no hidden states are supplied positionally or by keyword.
        """
        if args:
            args = (self._mark_begin(args[0]),) + args[1:]
        elif "hidden_states" in kwargs:
            # `**kwargs` is always a fresh dict, so rebinding the entry is local.
            kwargs["hidden_states"] = self._mark_begin(kwargs["hidden_states"])
        else:
            raise TypeError(
                f"{type(self).__name__}.forward() requires hidden states as the "
                "first positional argument or the 'hidden_states' keyword argument"
            )

        # Calls forward() directly rather than the module itself, so the inner
        # layer's nn.Module.__call__ machinery (e.g. hooks) is bypassed.
        result = self._inner.forward(*args, **kwargs)

        if isinstance(result, tuple):
            # Record the format so copies of this region can mirror it.
            self.returns_tuple = True
            self.return_length = len(result)
            return (self._mark_end(result[0]),) + result[1:]

        self.returns_tuple = False
        return self._mark_end(result)
class CopyLayerWrapper(torch.nn.Module):
    """Stand-in layer that replays a previously marked region instead of computing.

    ``forward`` passes the hidden states through
    ``torch.ops.tensor_cast._internal_copy_region`` and returns a value shaped
    like the representative ``RegionMarkerWrapper``'s output. Every position
    other than the hidden states is filled with ``None``.
    """

    def __init__(
        self,
        region_id: int,
        layer: torch.nn.Module,
        representative: RegionMarkerWrapper,
    ):
        """
        Wrap a layer with a copy operation that copies a previously marked region.

        Args:
            region_id: The id of the range to repeat.
            layer: Original layer instance used only to copy lightweight metadata.
            representative: The representative layer whose return format should be mirrored.
        """
        super().__init__()
        self.region_id = region_id
        # Bypass nn.Module.__setattr__ so the representative is NOT registered as
        # a child module. Its parameters therefore do not show up again under
        # this module's parameters() / state_dict().
        object.__setattr__(self, "representative", representative)
        # Carry over per-layer metadata attributes when the original layer has
        # them (presumably inspected by the surrounding model code — confirm).
        for attr_name in ("attention_type", "layer_type"):
            if hasattr(layer, attr_name):
                setattr(self, attr_name, getattr(layer, attr_name))

    def forward(self, *args, **kwargs):
        """Replay the marked region and mimic the representative's return format.

        The hidden states are taken from the first positional argument. When no
        positional arguments are given, they come from the ``hidden_states``
        keyword argument instead.

        Returns:
            The copied hidden states alone if the representative last returned a
            non-tuple. Otherwise a tuple ``(hidden_states, None, ...)`` whose
            length is the larger of:

            * 1 plus the number of truthy ``output_attentions`` / ``use_cache`` /
              ``output_router_logits`` kwargs, and
            * the representative's recorded ``return_length``.

        Raises:
            TypeError: If no hidden states are supplied positionally or by keyword.
        """
        if args:
            hidden_states = args[0]
        elif "hidden_states" in kwargs:
            hidden_states = kwargs["hidden_states"]
        else:
            raise TypeError(
                f"{type(self).__name__}.forward() requires hidden states as the "
                "first positional argument or the 'hidden_states' keyword argument"
            )

        hidden_states = torch.ops.tensor_cast._internal_copy_region(
            hidden_states,
            self.region_id,
        )

        if not self.representative.returns_tuple:
            return hidden_states

        # One placeholder for each optional output the caller asked for.
        requested_extras = sum(
            1
            for flag in ("output_attentions", "use_cache", "output_router_logits")
            if kwargs.get(flag, False)
        )
        return_length = getattr(self.representative, "return_length", 1)
        # NOTE(review): when the flags imply more outputs than the representative
        # actually returned, this tuple is longer than the representative's.
        # Confirm that downstream code tolerates the extra trailing Nones.
        total_length = max(1 + requested_extras, return_length)
        return (hidden_states,) + (None,) * (total_length - 1)