from typing import Optional, Tuple
import torch
from torch._subclasses.fake_tensor import is_fake
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("kv_rmsnorm_rope_cache", mutates_args=("kv_cache",))
def _(
    kv: torch.Tensor,
    gamma: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    epsilon: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Meta implementation of the fused KV RmsNorm + RoPE + cache-write op used by MLA.

    Counterpart of vllm-ascend's ``torch_npu.npu_kv_rmsnorm_rope_cache()``, used by
    TensorCast for performance modeling. The modeled algorithm is:

    1. Split ``kv`` into ``kv_c`` (compressed part) and ``k_pe`` (rope part).
    2. RmsNorm ``kv_c``: ``kv_c_normed = kv_c * gamma / sqrt(mean(kv_c^2) + epsilon)``.
    3. Rotate ``k_pe`` with ``cos``/``sin`` (RoPE).
    4. Store ``kv_c_normed`` and the rotated ``k_pe`` into ``kv_cache`` at the
       positions given by ``slot_mapping``.

    Args:
        kv: (num_tokens, kv_lora_rank + qk_rope_head_dim); expected contiguous,
            BF16/FP16.
        gamma: RmsNorm weight, (kv_lora_rank,).
        cos, sin: Rotary embeddings, (1, seq_len, qk_rope_head_dim).
        kv_cache: (total_blocks, block_size, kv_lora_rank + qk_rope_head_dim);
            declared as mutated.
        slot_mapping: (num_tokens,) cache slot indices, expected in
            [0, total_blocks * block_size).
        kv_lora_rank: Width of the compressed KV part (expected > 0).
        qk_rope_head_dim: Width of the RoPE part (expected > 0).
        epsilon: RmsNorm epsilon (default 1e-6).

    Returns:
        A pair ``(k_pe, kv_c_normed)`` of uninitialized tensors shaped
        (num_tokens, qk_rope_head_dim) and (num_tokens, kv_lora_rank), with the
        dtype and device of ``kv``.

    Note:
        Only output shapes are modeled here: no values are computed and
        ``kv_cache`` is not written.
    """
    # Both outputs follow kv's dtype/device; only the trailing width differs.
    alloc_kwargs = {"dtype": kv.dtype, "device": kv.device}
    token_count = kv.size(0)
    rotated_pe = torch.empty((token_count, qk_rope_head_dim), **alloc_kwargs)
    normed_c = torch.empty((token_count, kv_lora_rank), **alloc_kwargs)
    return rotated_pe, normed_c
@register_tensor_cast_op("concat_and_cache_mla", mutates_args=("kv_cache",))
def _(
    kv_c_normed: torch.Tensor,
    k_rot: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Concatenate ``kv_c_normed`` and ``k_rot`` and store them into ``kv_cache``
    at the slots given by ``slot_mapping``.

    Args:
        kv_c_normed: (num_tokens, kv_lora_rank)
        k_rot: (num_tokens, qk_rope_head_dim)
        kv_cache: (total_num_blocks, block_size, kv_lora_rank + qk_rope_head_dim);
            declared as mutated.
        slot_mapping: see ``AttentionMetadataBase``.

    Note:
        This meta implementation performs no computation and writes nothing;
        it only declares the op and its in-place contract on ``kv_cache``.
    """
    return None
@register_tensor_cast_op("mlapo")
def _(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    q_a_proj_weight: Optional[torch.Tensor],
    q_a_layernorm_weight: Optional[torch.Tensor],
    q_b_proj_weight: Optional[torch.Tensor],
    kv_a_proj_weight: Optional[torch.Tensor],
    kv_a_layernorm_weight: torch.Tensor,
    num_heads: int,
    qk_head_dim: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    kv_lora_rank: int,
    q_lora_rank: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Meta implementation of the fused MLA preprocessing op (RMS norm, matmuls, RoPE).

    Args:
        hidden_states: (num_tokens, hidden_size) activations entering MLA.
        cos/sin: Rotary embedding caches, (1, seq_len, qk_rope_head_dim).
        q_a_proj_weight / q_b_proj_weight: LoRA weights shaped
            (q_lora_rank, hidden_size) and (num_heads * qk_head_dim, q_lora_rank).
        q_a_layernorm_weight: RMSNorm scale of the Q LoRA branch, (q_lora_rank,).
        kv_a_proj_weight: (kv_lora_rank + qk_rope_head_dim, hidden_size) projection
            producing the compressed key/value streams.
        kv_a_layernorm_weight: RMSNorm scale for the compressed KV stream;
            presumably (kv_lora_rank,) to match ``kv_c_normed`` -- confirm.
        num_heads / qk_* dims / kv_lora_rank / q_lora_rank: Structural scalars
            describing the MLA layout.

    Returns:
        Uninitialized tensors with the dtype/device of ``hidden_states``:
        q_states: (num_tokens, num_heads, qk_head_dim)
        kv_c_normed: (num_tokens, kv_lora_rank)
        k_rot: (num_tokens, qk_rope_head_dim)
        qa_normed: (num_tokens, q_lora_rank) when ``q_lora_rank`` is truthy;
            otherwise a zero-width last dimension that the caller converts
            back to None.

    Note:
        Only ``hidden_states`` and the integer layout arguments are read; the
        weight and rotary tensors do not influence the modeled output shapes.
    """
    rows = hidden_states.size(0)
    tensor_opts = {"dtype": hidden_states.dtype, "device": hidden_states.device}
    # A falsy q_lora_rank (absent Q LoRA branch) yields a zero-width placeholder.
    qa_width = q_lora_rank if q_lora_rank else 0

    q_states = torch.empty((rows, num_heads, qk_head_dim), **tensor_opts)
    kv_c_normed = torch.empty((rows, kv_lora_rank), **tensor_opts)
    k_rot = torch.empty((rows, qk_rope_head_dim), **tensor_opts)
    qa_normed = torch.empty((rows, qa_width), **tensor_opts)
    return q_states, kv_c_normed, k_rot, qa_normed
@register_tensor_cast_op("mlapo_quant")
def _(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    q_a_proj_weight: Optional[torch.Tensor],
    q_a_layernorm_weight: Optional[torch.Tensor],
    q_b_proj_weight: Optional[torch.Tensor],
    kv_a_proj_weight: Optional[torch.Tensor],
    kv_a_layernorm_weight: torch.Tensor,
    num_heads: int,
    qk_head_dim: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    kv_lora_rank: int,
    q_lora_rank: int,
    q_a_proj_scale: torch.Tensor,
    q_a_proj_offset: Optional[torch.Tensor],
    q_b_proj_scale: torch.Tensor,
    q_b_proj_offset: Optional[torch.Tensor],
    kv_a_proj_scale: torch.Tensor,
    kv_a_proj_offset: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Meta implementation of the quantized fused MLA preprocessing op.

    Arguments mirror ``mlapo``; the additional q_a / q_b / kv_a ``*_scale`` and
    ``*_offset`` tensors describe the quantization scheme (per-tensor /
    per-group) of the corresponding linear layers. They do not affect the
    modeled output shapes.

    Returns:
        Uninitialized tensors with the dtype/device of ``hidden_states``:
        q_states: (num_tokens, num_heads, qk_head_dim)
        kv_c_normed: (num_tokens, kv_lora_rank)
        k_rot: (num_tokens, qk_rope_head_dim)
        qa_normed: (num_tokens, q_lora_rank) when ``q_lora_rank`` is truthy;
            otherwise a zero-width last dimension that the caller converts
            back to None.
    """
    leading = hidden_states.size(0)
    trailing_dims = (
        (num_heads, qk_head_dim),  # q_states
        (kv_lora_rank,),  # kv_c_normed
        (qk_rope_head_dim,),  # k_rot
        (q_lora_rank or 0,),  # qa_normed; zero width when the Q LoRA branch is absent
    )
    return tuple(
        torch.empty(
            (leading, *tail),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        for tail in trailing_dims
    )
@register_tensor_cast_op("multihead_latent_attention")
def _(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    query_lens: Optional[torch.Tensor],
    W_UK_T: Optional[torch.Tensor],
    W_UV: Optional[torch.Tensor],
    kv_b_proj: Optional[torch.Tensor],
    v_head_dim: int,
    topk_limit: Optional[int] = None,
    topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Meta implementation of multi-head latent attention (MLA).

    The real op uses different algorithms for prefill and decode shapes. A
    single call may mix prefill and decode sequences, which are then handled
    separately. Each sequence's phase is judged from its query length as given
    by ``query_start_loc``. Presumably a query length of 1 means decode and
    longer queries mean prefill; confirm against the real kernel.

    For prefill (non-strict math/code):
        k_nope, v = (kv_c_normed @ kv_b_proj).view(-1, num_heads, qk_nope_head_dim + v_head_dim).split(dim=-1)
        softmax(q @ (k_nope, k_rot) + sparse_mask(topk_indices)) @ v
    For decode (non-strict math/code):
        softmax(q @ W_UK_T @ k_cache + sparse_mask(topk_indices)) @ v_cache @ W_UV
    ``sparse_mask(topk_indices)`` is omitted when ``topk_indices`` is None.

    Args:
        q: (num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)
            Query states after compression and decompression.
        kv_cache: (total_num_blocks, block_size, kv_lora_rank + qk_rope_head_dim)
            Cached key/value states, already updated with the current tokens.
        block_table/query_start_loc/seq_lens: see ``AttentionMetadataBase``.
        W_UK_T, W_UV: (num_heads, qk_nope_head_dim, kv_lora_rank) and
            (num_heads, kv_lora_rank, v_head_dim). Used in the decode phase;
            None if only prefill sequences are provided.
        kv_b_proj: (kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)).
            Used in the prefill phase; None if only decode sequences are provided.
        topk_limit: Number of top-K tokens for sparse attention.
        topk_indices: Preselected token positions for sparse attention.

    Returns:
        Uninitialized (num_tokens, num_heads, v_head_dim) tensor with ``q``'s
        dtype, allocated on ``q``'s device. This matches the other ops in this
        module, and on meta inputs it is still a meta tensor.
    """
    if topk_indices is not None:
        # NOTE(review): the value is discarded; this access only fails for a
        # 0-d ``topk_indices``. Confirm whether it is meant as a validation.
        _ = topk_indices.shape[-1]
    num_tokens, num_heads = q.shape[0], q.shape[1]
    # Allocate next to q rather than on a hard-coded "meta" device, so fake
    # tensors on a real device (and real inputs) stay device-consistent downstream.
    return torch.empty(num_tokens, num_heads, v_head_dim, dtype=q.dtype, device=q.device)
@register_tensor_cast_op("multihead_latent_attention_quant")
def _(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    query_lens: Optional[torch.Tensor],
    W_UK_T: Optional[torch.Tensor],
    W_UV: Optional[torch.Tensor],
    kv_b_proj: Optional[torch.Tensor],
    v_head_dim: int,
    topk_limit: Optional[int],
    topk_indices: Optional[torch.Tensor],
    query_scale: torch.Tensor,
    query_offset: Optional[torch.Tensor],
    kv_scale: torch.Tensor,
    kv_offset: Optional[torch.Tensor],
    kv_projected_scale: torch.Tensor,
    kv_projected_offset: Optional[torch.Tensor],
    qk_scale: torch.Tensor,
    qk_offset: Optional[torch.Tensor],
    v_scale: torch.Tensor,
    v_offset: Optional[torch.Tensor],
    attention_prob_scale: torch.Tensor,
    attention_prob_offset: Optional[torch.Tensor],
    kv_b_proj_scale: torch.Tensor,
    kv_b_proj_offset: Optional[torch.Tensor],
    out_scale: Optional[torch.Tensor],
    out_offset: Optional[torch.Tensor],
    out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
    """
    Meta implementation of ``multihead_latent_attention`` with quantization support.

    For prefill (non-strict math/code):
        quant_kv_proj = quant(kv_c_normed @ kv_b_proj, kv_projected_scale, kv_projected_offset)
        k_nope, v = quant_kv_proj.view(-1, num_heads, qk_nope_head_dim + v_head_dim).split(dim=-1)
        out_fp = quant(
            softmax(q @ (k_nope, k_rot) + sparse_mask(topk_indices)),
            attention_prob_scale,
            attention_prob_offset,
        ) @ v
        out = quant(out_fp, out_scale, out_offset)  # optional
    For decode (non-strict math/code):
        quant_qk = quant(q @ W_UK_T, qk_scale, qk_offset)
        quant_scores = quant(
            softmax(quant_qk @ k_cache + sparse_mask(topk_indices)),
            attention_prob_scale,
            attention_prob_offset,
        )
        out_fp = quant(quant_scores @ v_cache, v_scale, v_offset) @ W_UV
        out = quant(out_fp, out_scale, out_offset)  # optional
    ``sparse_mask(topk_indices)`` is omitted when ``topk_indices`` is None.

    Args:
        q ... v_head_dim: Same as ``multihead_latent_attention``.
        topk_limit: Number of top-K tokens for sparse attention.
        topk_indices: Preselected token positions for sparse attention.
        *_scale / *_offset: Quantization parameters of the stages shown above.
        out_dtype: dtype of the returned tensor; ``q.dtype`` when None.

    Returns:
        Uninitialized (num_tokens, num_heads, v_head_dim) tensor of ``out_dtype``,
        allocated on ``q``'s device. This matches the other ops in this module,
        and on meta inputs it is still a meta tensor.
    """
    if topk_indices is not None:
        # NOTE(review): the value is discarded; this access only fails for a
        # 0-d ``topk_indices``. Confirm whether it is meant as a validation.
        _ = topk_indices.shape[-1]
    if out_dtype is None:
        out_dtype = q.dtype
    # Allocate next to q rather than on a hard-coded "meta" device, so fake
    # tensors on a real device (and real inputs) stay device-consistent downstream.
    return torch.empty(q.shape[0], q.shape[1], v_head_dim, dtype=out_dtype, device=q.device)
@register_tensor_cast_op("dsa_indexer", mutates_args=("indexer_cache",))
def _(
    hidden_states: torch.Tensor,
    qa_normed: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    indexer_cache: torch.Tensor,
    slot_mapping: Optional[torch.Tensor],
    block_tables: Optional[torch.Tensor],
    seq_lens: Optional[torch.Tensor],
    wq_b_weight: torch.Tensor,
    wk_weight: torch.Tensor,
    weights_proj_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    num_heads: int,
    head_dim: int,
    qk_rope_head_dim: int,
    topk_limit: int,
) -> torch.Tensor:
    """
    Meta implementation of the fused DSA indexer semantic block.

    For the DeepSeek-V3.2-style fp8 path (non-strict math/code):
        q = rope(wq_b(qa_normed))
        k = rope(k_norm(wk(hidden_states)))
        q = rotate_activation(q)
        k = rotate_activation(k)
        q_fp8, q_scale = act_quant(q)
        k_fp8, k_scale = act_quant(k)
        k_cache, k_scale_cache = append(indexer_cache, k_fp8, k_scale)
        weights = weights_proj(hidden_states) * num_heads**-0.5
        weights = weights.unsqueeze(-1) * q_scale * head_dim**-0.5
        index_score = fp8_index(q_fp8, weights, k_cache, k_scale_cache)
        topk_indices = topk(index_score, k=min(topk_limit, active_seq_len), dim=-1).indices
    Compared with the fp8 path, the bf16 / GLM5-style path removes:
        - rotate_activation on q and k
        - act_quant on q and k
        - scale-cache writes alongside the key cache
        - fp8-specific relu / q-scale / k-scale score shaping
    and instead uses direct cache scoring plus head reduction:
        weights = weights_proj(hidden_states) * num_heads**-0.5
        head_scores = (q @ k_cache.transpose(-1, -2)) * head_dim**-0.5
        index_score = reduce_sum(head_scores * weights.unsqueeze(-1), dim=-2)
        topk_indices = topk(index_score, k=min(topk_limit, active_seq_len), dim=-1).indices

    ``active_seq_len`` is ``max(seq_lens)`` when ``seq_lens`` is given, otherwise
    the query length ``seq_len`` (``hidden_states.shape[1]``).

    Returns:
        topk_indices: Uninitialized int64 tensor of shape
        (batch, seq_len, min(topk_limit, active_seq_len)). If ``seq_lens`` is
        given but its values cannot be read (fake or meta tensor), the last
        dimension falls back to the upper bound ``topk_limit``.
    """
    batch, seq_len, _ = hidden_states.shape
    if seq_lens is None:
        # The active length is the (static) query length, so the clamp applies
        # on every path, fake or not. sym_min avoids guarding on a symbolic
        # seq_len and equals builtin min for plain ints.
        topk = torch.sym_min(topk_limit, seq_len)
    elif is_fake(hidden_states) or is_fake(seq_lens) or seq_lens.device.type == "meta":
        # seq_lens' contents are unavailable (fake/meta tensor: .item() would
        # fail), so only the static upper bound is known.
        topk = topk_limit
    else:
        # An empty seq_lens has no maximum; treat it as zero active tokens.
        active_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
        topk = min(topk_limit, active_seq_len)
    return torch.empty(batch, seq_len, topk, dtype=torch.long, device=hidden_states.device)