from typing import Optional, Tuple, Union
import torch
from ..patch_torch import support_autocast_for_meta
class CachingRotaryEmb(torch.nn.Module):
    """Serve rotary cos/sin values from tables precomputed at construction.

    The wrapped ``rotary_emb`` is called once in ``__init__`` with every
    position in ``[0, max_position_embeddings)``; the resulting cos and sin
    tables are kept as non-persistent buffers so that ``forward`` can look
    positions up instead of recomputing the embeddings on every call.

    cos and sin are stored separately to align with NPU profiling shapes.
    NPU stores cos/sin as (max_pos, rope_dim) each; a previous implementation
    concatenated them into (max_pos, 2*rope_dim), which produced an
    aten.index.Tensor shape that never appears in profiling data.

    Attributes are documented here rather than inline:
        act_dtype: dtype the tables were built for; other dtypes bypass them.
        use_3d_position_index: True when the tables carry an extra leading
            axis that is indexed together with 3-D ``position_ids``.
        cos_cache / sin_cache: the lookup tables, or ``None`` when the
            wrapped module did not return a ``(cos, sin)`` pair.
        rotary_emb: the wrapped module, used as the fallback path.
    """

    def __init__(
        self,
        rotary_emb: torch.nn.Module,
        act_dtype: torch.dtype,
        max_position_embeddings: int,
        expand_to_3d_position_ids: bool = False,
    ):
        """Probe ``rotary_emb`` once and cache its cos/sin output.

        Args:
            rotary_emb: Module called as ``rotary_emb(x, position_ids)``.
            act_dtype: dtype of the dummy activation passed to ``rotary_emb``
                and the only input dtype served from the cache.
            max_position_embeddings: Number of positions to precompute.
            expand_to_3d_position_ids: When True, the probe position ids are
                expanded to shape ``(3, 1, max_position_embeddings)`` before
                calling ``rotary_emb``.
        """
        super().__init__()
        self.act_dtype = act_dtype
        self.use_3d_position_index = False

        # Dummy inputs live on the meta device, so the probe call allocates
        # no real storage; both carry a leading batch axis of size 1.
        x = torch.empty(max_position_embeddings, device="meta", dtype=act_dtype).unsqueeze(0)
        position_ids = torch.arange(0, max_position_embeddings, device="meta", dtype=torch.long).unsqueeze(0)
        if expand_to_3d_position_ids:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        with support_autocast_for_meta():
            position_embeddings = rotary_emb(x, position_ids)

        self.cos_cache: Optional[torch.Tensor]
        self.sin_cache: Optional[torch.Tensor]
        if isinstance(position_embeddings, (tuple, list)) and len(position_embeddings) == 2:
            cos, sin = position_embeddings
            # Drop only the dummy batch axis. A bare ``squeeze()`` would also
            # remove the position or rope axis whenever either has size 1,
            # leaving a table that ``forward`` would index along the wrong axis.
            cos = self._drop_dummy_batch_dims(cos)
            sin = self._drop_dummy_batch_dims(sin)
            if expand_to_3d_position_ids and cos.ndim == 3:
                self.use_3d_position_index = True
            self.register_buffer("cos_cache", cos, persistent=False)
            self.register_buffer("sin_cache", sin, persistent=False)
        else:
            # Unknown output format: nothing to cache, forward always delegates.
            self.cos_cache = None
            self.sin_cache = None
        self.rotary_emb = rotary_emb

    @staticmethod
    def _drop_dummy_batch_dims(table: torch.Tensor) -> torch.Tensor:
        """Squeeze size-1 axes that precede the trailing two axes of ``table``.

        The last two axes are always preserved, even when their size is 1;
        tensors with fewer than three axes are returned unchanged.
        """
        # Walk from the innermost candidate outwards so that removing an axis
        # never shifts the index of an axis still to be inspected.
        for dim in reversed(range(table.ndim - 2)):
            if table.size(dim) == 1:
                table = table.squeeze(dim)
        return table

    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Return position embeddings for ``position_ids``.

        Args:
            x: Activation tensor; on the cached path only its dtype is read.
            position_ids: Integer positions to look up.

        Returns:
            A ``(cos, sin)`` tuple gathered from the caches when they exist
            and ``x.dtype`` equals ``act_dtype``; otherwise whatever the
            wrapped ``rotary_emb(x, position_ids)`` returns.
        """
        if self.cos_cache is not None and x.dtype == self.act_dtype:
            if self.use_3d_position_index:
                # Pair every entry of position_ids[i] with slice i of the
                # cache's leading axis; the result has the shape of
                # position_ids followed by the cache's trailing axis.
                plane_idx = torch.arange(position_ids.size(0), device=position_ids.device)[:, None, None]
                return (
                    self.cos_cache[plane_idx, position_ids],
                    self.sin_cache[plane_idx, position_ids],
                )
            if position_ids.ndim == 3:
                # The cache has no leading axis to pair with, so only the
                # first slice of the 3-D ids is used.
                position_ids = position_ids[0]
            flat_ids = position_ids.flatten()
            # Gather rows with a flat index, then restore the leading axis of
            # position_ids; remaining axes are folded into the middle one.
            cos = self.cos_cache.index_select(0, flat_ids).reshape(position_ids.size(0), -1, self.cos_cache.size(-1))
            sin = self.sin_cache.index_select(0, flat_ids).reshape(position_ids.size(0), -1, self.sin_cache.size(-1))
            return cos, sin
        else:
            return self.rotary_emb(x, position_ids)