from typing import Sequence
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("cat")
def _(tensors: Sequence[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Shape-only replacement for aten.cat that preserves dtype on meta tensors.

    No data is copied: the result is an uninitialized ``torch.empty`` tensor
    whose shape is the first tensor's shape with the size along ``dim``
    replaced by the sum of every input's size along ``dim``. The dtype and
    device are taken from the first tensor.

    Sizes along the other dimensions are NOT cross-checked between inputs;
    they are copied from the first tensor as-is.

    Args:
        tensors: Non-empty sequence of tensors sharing a single dtype.
        dim: Dimension to concatenate along. Negative values count from the
            last dimension of the first tensor.

    Returns:
        An uninitialized tensor with the concatenated shape.

    Raises:
        ValueError: If ``tensors`` is empty or the dtypes differ.
        IndexError: If ``dim`` is outside ``[-rank, rank - 1]`` for the first
            tensor's rank (this includes 0-dim inputs, which have no valid
            ``dim``).
    """
    if not tensors:
        raise ValueError("tensor_cast.cat expects a non-empty tensor list")

    ref = tensors[0]
    if any(t.dtype != ref.dtype for t in tensors[1:]):
        raise ValueError("tensor_cast.cat expects all input tensors to have the same dtype")

    out_shape = list(ref.shape)
    rank = len(out_shape)

    # Reject out-of-range dims before normalizing. Without this check a value
    # below -rank stays negative after the shift and silently wraps around to
    # a different axis (e.g. rank 3, dim=-4 -> -1) instead of failing.
    if not -rank <= dim < rank:
        raise IndexError(
            f"tensor_cast.cat: dim {dim} is out of range for tensors of rank {rank}"
        )
    if dim < 0:
        dim += rank

    out_shape[dim] = sum(t.shape[dim] for t in tensors)
    return torch.empty(out_shape, dtype=ref.dtype, device=ref.device)