from typing import Optional, Tuple

import torch

from ..utils import register_tensor_cast_op


@register_tensor_cast_op("linear_attn_apply_padding_mask")
def _(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Shape-only implementation of ``linear_attn_apply_padding_mask``.

    No masking is performed here: the result is an uninitialized tensor that
    only mirrors the metadata of ``hidden_states``.

    Args:
        hidden_states: Tensor whose shape, dtype and device the output copies.
        attention_mask: Accepted for signature compatibility; not read.

    Returns:
        An uninitialized, contiguous tensor shaped like ``hidden_states``.
    """
    # The mask never influences the output metadata.
    del attention_mask
    placeholder = torch.empty_like(hidden_states)
    # empty_like may keep the input's strides, so normalize the layout here.
    return placeholder.contiguous()


@register_tensor_cast_op("linear_attn_causal_conv")
def _(
    mixed_qkv: torch.Tensor,
    conv_kernel_size: int,
) -> torch.Tensor:
    """Shape-only implementation of ``linear_attn_causal_conv``.

    No convolution is computed: the output is an uninitialized tensor with the
    same shape, dtype and device as ``mixed_qkv``.

    Args:
        mixed_qkv: Tensor whose metadata the output copies.
        conv_kernel_size: Accepted for signature compatibility; not read.

    Returns:
        An uninitialized, contiguous tensor shaped like ``mixed_qkv``.
    """
    # The kernel size does not change the output shape in this implementation.
    del conv_kernel_size
    placeholder = torch.empty_like(mixed_qkv)
    # empty_like may keep the input's strides, so normalize the layout here.
    return placeholder.contiguous()


@register_tensor_cast_op("linear_attn_causal_conv_update")
def _(
    mixed_qkv: torch.Tensor,
    conv_kernel_size: int,
) -> torch.Tensor:
    """Shape-only implementation of ``linear_attn_causal_conv_update``.

    No convolution state is read or updated: the output is an uninitialized
    tensor with the same shape, dtype and device as ``mixed_qkv``.

    Args:
        mixed_qkv: Tensor whose metadata the output copies.
        conv_kernel_size: Accepted for signature compatibility; not read.

    Returns:
        An uninitialized, contiguous tensor shaped like ``mixed_qkv``.
    """
    # The kernel size does not change the output shape in this implementation.
    del conv_kernel_size
    placeholder = torch.empty_like(mixed_qkv)
    # empty_like may keep the input's strides, so normalize the layout here.
    return placeholder.contiguous()


@register_tensor_cast_op("linear_attn_fused_gdn_gating")
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    a_log: torch.Tensor,
    dt_bias: torch.Tensor,
    num_v_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only implementation of ``linear_attn_fused_gdn_gating``.

    ``query`` must be 4-D. Its third dimension is discarded and replaced by
    ``num_v_heads`` in every output; no gating math is performed.

    Args:
        query: 4-D tensor; supplies dims 0, 1 and 3 of the outputs plus the
            dtype/device of the first output.
        key: Supplies only the dtype/device of the second output.
        b: Supplies only the dtype/device of the third output.
        a: Supplies only the device of the fourth output.
        a_log: Accepted for signature compatibility; not read.
        dt_bias: Accepted for signature compatibility; not read.
        num_v_heads: Size of the head dimension in every output.

    Returns:
        A 4-tuple of uninitialized tensors ``(query_out, key_out, beta, g)``:
        the first two are ``(d0, d1, num_v_heads, d3)`` and the last two are
        ``(d0, d1, num_v_heads)``, where ``d*`` are dims of ``query``. ``g`` is
        always ``torch.float32``.
    """
    # Neither tensor contributes to the output metadata.
    del a_log, dt_bias

    dim0, dim1, _input_heads, last_dim = query.shape
    per_head_shape = (dim0, dim1, num_v_heads)
    expanded_shape = (*per_head_shape, last_dim)

    return (
        torch.empty(expanded_shape, dtype=query.dtype, device=query.device),
        torch.empty(expanded_shape, dtype=key.dtype, device=key.device),
        torch.empty(per_head_shape, dtype=b.dtype, device=b.device),
        # The gate is float32 regardless of the input dtypes.
        torch.empty(per_head_shape, dtype=torch.float32, device=a.device),
    )


@register_tensor_cast_op("linear_attn_chunk_gated_delta_rule")
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    beta: torch.Tensor,
    g: torch.Tensor,
    chunk_size: int,
    state_read_passes: int,
    state_write_passes: int,
) -> torch.Tensor:
    """Shape-only implementation of ``linear_attn_chunk_gated_delta_rule``.

    The output takes its first three dimensions from ``query`` (which must be
    4-D) and its last dimension from ``value``; no attention is computed.

    Args:
        query: 4-D tensor; supplies the leading three output dims, the dtype
            and the device.
        key: Accepted for signature compatibility; not read.
        value: Supplies the size of the output's last dimension.
        beta: Accepted for signature compatibility; not read.
        g: Accepted for signature compatibility; not read.
        chunk_size: Accepted for signature compatibility; not read.
        state_read_passes: Accepted for signature compatibility; not read.
        state_write_passes: Accepted for signature compatibility; not read.

    Returns:
        An uninitialized tensor of shape ``(*query.shape[:3], value.shape[-1])``
        with ``query``'s dtype and device.
    """
    del key, beta, g, chunk_size, state_read_passes, state_write_passes

    # Unpacking (rather than slicing) keeps the rank-4 requirement on query.
    dim0, dim1, heads, _query_last = query.shape
    out_shape = (dim0, dim1, heads, value.shape[-1])
    return torch.empty(out_shape, dtype=query.dtype, device=query.device)


@register_tensor_cast_op("linear_attn_recurrent_gated_delta_rule")
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    beta: torch.Tensor,
    g: torch.Tensor,
    state_read_passes: int,
    state_write_passes: int,
) -> torch.Tensor:
    """Shape-only implementation of ``linear_attn_recurrent_gated_delta_rule``.

    The output takes its first three dimensions from ``query`` (which must be
    4-D) and its last dimension from ``value``; no recurrence is computed.

    Args:
        query: 4-D tensor; supplies the leading three output dims, the dtype
            and the device.
        key: Accepted for signature compatibility; not read.
        value: Supplies the size of the output's last dimension.
        beta: Accepted for signature compatibility; not read.
        g: Accepted for signature compatibility; not read.
        state_read_passes: Accepted for signature compatibility; not read.
        state_write_passes: Accepted for signature compatibility; not read.

    Returns:
        An uninitialized tensor of shape ``(*query.shape[:3], value.shape[-1])``
        with ``query``'s dtype and device.
    """
    del key, beta, g, state_read_passes, state_write_passes

    # Unpacking (rather than slicing) keeps the rank-4 requirement on query.
    dim0, dim1, heads, _query_last = query.shape
    out_shape = (dim0, dim1, heads, value.shape[-1])
    return torch.empty(out_shape, dtype=query.dtype, device=query.device)


@register_tensor_cast_op("linear_attn_gated_rmsnorm")
def _(
    core_attn_out: torch.Tensor,
    z: torch.Tensor,
    weight: Optional[torch.Tensor],
    eps: float,
) -> torch.Tensor:
    """Shape-only implementation of ``linear_attn_gated_rmsnorm``.

    No normalization or gating is computed: the output is an uninitialized
    tensor with the same shape, dtype and device as ``core_attn_out``.

    Args:
        core_attn_out: Tensor whose metadata the output copies.
        z: Accepted for signature compatibility; not read.
        weight: Accepted for signature compatibility; not read.
        eps: Accepted for signature compatibility; not read.

    Returns:
        An uninitialized, contiguous tensor shaped like ``core_attn_out``.
    """
    del z, weight, eps
    placeholder = torch.empty_like(core_attn_out)
    # empty_like may keep the input's strides, so normalize the layout here.
    return placeholder.contiguous()