from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Optional
import torch
from .. import ops
from ..model_config import MlaConfig, MultiheadLatentAttentionQuantConfig
from ..parallel_group import ParallelGroup
from ..utils import exact_division
from .attention import AttentionMetadataBase
from .quant_linear import TensorCastQuantLinear
from .utils import get_partial_sharded, ModelWrapperBase
def tp_plan_module_path(prefix: str, relative_path: str) -> str:
    """Build the TP-plan glob that addresses ``relative_path`` below ``prefix``.

    For a regular decoder stack the layer index sits between ``prefix`` and the
    submodule, so a ``*`` wildcard segment is inserted:
    ``{prefix}.*.{relative_path}``. A prefix containing ``.mtp_block`` (for
    example ``mtp.layers.*.mtp_block``) already addresses the block itself, so
    the submodule path is appended directly without another wildcard.

    Prefer :func:`tp_plan_nested_module_path` when the target may be hidden
    behind extra wrapper segments such as ``self_attn._inner`` (e.g. ``o_proj``).
    """
    addresses_mtp_block = prefix.find(".mtp_block") != -1
    if addresses_mtp_block:
        segments = (prefix, relative_path)
    else:
        segments = (prefix, "*", relative_path)
    return ".".join(segments)
def tp_plan_nested_module_path(prefix: str, relative_path: str) -> str:
    """Build a TP-plan glob with a wildcard segment between ``prefix`` and ``relative_path``.

    The wildcard lets the pattern absorb optional wrapper segments in front of
    the target module. For instance ``model.layers.*.o_proj`` matches
    ``model.layers.0.self_attn._inner.o_proj``, and
    ``mtp.layers.*.mtp_block.*.o_proj`` matches MTP attention outputs.
    """
    return ".".join((prefix, "*", relative_path))
class MultiheadLatentAttentionBase(torch.nn.Module, ABC):
    """Abstract wrapper around a model's multi-head latent attention module.

    The wrapped module is kept as ``_inner``. Any attribute name that exists on
    ``mla_config.field_names`` is resolved on ``_inner`` under the mapped name
    (see :meth:`__getattr__`), so subclasses can refer to projections and
    dimensions by one set of names regardless of how the inner module calls
    them.
    """

    supports_parallel_group_manager: bool = False

    def __init__(
        self,
        mla_config: MlaConfig,
        mla_module: torch.nn.Module,
        decode_only: bool = False,
    ) -> None:
        """Store the configuration and the wrapped attention module.

        Args:
            mla_config: Configuration whose ``field_names`` attribute drives the
                attribute forwarding in :meth:`__getattr__`.
            mla_module: The attention module being wrapped (stored as ``_inner``).
            decode_only: Flag stored on the instance; not interpreted here.
        """
        super().__init__()
        self.mla_config = mla_config
        self._inner = mla_module
        self.decode_only = decode_only
        # Stays ``None`` until a quantization config is attached from outside.
        self.quant_config: Optional[MultiheadLatentAttentionQuantConfig] = None

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attention_meta: Optional[AttentionMetadataBase] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Run attention; concrete subclasses must implement this."""
        pass

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` through ``mla_config.field_names`` before falling back.

        ``__getattr__`` only runs after normal attribute lookup has failed. If
        ``field_names`` has an attribute called ``name``, its value is used as
        the attribute name to read from ``_inner``; otherwise the lookup is
        delegated to ``torch.nn.Module.__getattr__``.

        Raises:
            AttributeError: if the attribute cannot be resolved either way.
        """
        # ``self.mla_config`` below re-enters this method when ``mla_config``
        # has not been assigned yet (e.g. an instance created via ``__new__``).
        # Answer that lookup directly so the failure is a clean AttributeError
        # rather than unbounded recursion.
        if name == "mla_config":
            return super().__getattr__(name)
        field_names = self.mla_config.field_names
        if hasattr(field_names, name):
            return getattr(self._inner, getattr(field_names, name))
        return super().__getattr__(name)

    def quantize_params(self):
        """
        Called during the initialization phase after the inner module is quantized.
        This allows quantization for extra parameters in this module.
        """
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the upper half.

    With ``h = x.shape[-1] // 2`` the result is ``cat(-x[..., h:], x[..., :h])``
    along the last dimension.
    """
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((-upper, lower), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos`` and ``sin`` are first flattened to 2-D (all leading dimensions
    merged). For ``q`` the tables get an extra axis inserted at dim 1 so they
    broadcast over it (presumably the head axis -- confirm against callers);
    ``k`` is combined with the 2-D tables directly.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos_2d = cos.view(-1, cos.shape[-1])
    sin_2d = sin.view(-1, sin.shape[-1])

    q_cos = cos_2d.unsqueeze(1)
    q_sin = sin_2d.unsqueeze(1)
    q_embed = (q * q_cos) + (rotate_half(q) * q_sin)

    k_embed = (k * cos_2d) + (rotate_half(k) * sin_2d)
    return q_embed, k_embed
def apply_rotary_pos_emb_interleave(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embedding to ``q``/``k`` stored in interleaved layout.

    ``q`` must be 3-D and ``k`` 2-D (their shapes are unpacked below). The last
    dimension of each is regrouped from adjacent pairs into "all first elements,
    then all second elements" before the rotation is applied.

    ``unsqueeze_dim`` is accepted but not used by this function.

    Returns:
        ``(q_embed, k_embed)``.
    """
    q_rows, q_groups, q_width = q.shape
    q_split = q.view(q_rows, q_groups, q_width // 2, 2).transpose(-1, -2).reshape(q_rows, q_groups, q_width)

    k_rows, k_width = k.shape
    k_split = k.view(k_rows, k_width // 2, 2).transpose(-1, -2).reshape(k_rows, k_width)

    # Once de-interleaved, the rotation is exactly the one implemented by
    # apply_rotary_pos_emb (which also flattens cos/sin to 2-D itself).
    return apply_rotary_pos_emb(q_split, k_split, cos, sin)
class MultiheadLatentAttentionTensorCast(MultiheadLatentAttentionBase):
    """MLA wrapper that runs attention through the ``torch.ops.tensor_cast`` ops.

    At construction ``kv_b_proj`` is sharded over the tensor-parallel group and
    split into ``W_UK_T`` / ``W_UV``. :meth:`forward` calls the ``mlapo`` (or
    ``mlapo_quant``) op for the q/kv projections, writes the latent KV into the
    cache and invokes the MLA attention op. Subclasses customise behaviour via
    the ``_setup_kv_b_decomposition``, ``_pre_attention_forward``,
    ``_get_backend_kwargs`` and ``_quantize_kv_b_decomposition`` hooks.
    """

    def __init__(
        self,
        mla_config: MlaConfig,
        mla_module: torch.nn.Module,
        tp_group: ParallelGroup,
        decode_only: bool = False,
    ):
        """Wrap ``mla_module`` and prepare the per-rank ``kv_b_proj`` decomposition.

        Args:
            mla_config: MLA configuration (see the base class).
            mla_module: The attention module being wrapped.
            tp_group: Tensor-parallel group; ``num_heads`` must divide evenly by
                its ``world_size`` (enforced through ``exact_division``).
            decode_only: Forwarded to the base class.
        """
        super().__init__(mla_config, mla_module, decode_only)
        self.tp_group = tp_group
        self._num_heads_per_rank = exact_division(self.num_heads, tp_group.world_size)
        self._setup_kv_b_decomposition(tp_group)

    def _setup_kv_b_decomposition(self, tp_group: ParallelGroup) -> None:
        """
        Hook: shard ``kv_b_proj`` across TP ranks and split it into ``W_UK``/``W_UV``.
        Default behaviour matches the legacy MLA path used by V3 / V3.2: it requires
        ``self.kv_b_proj`` to be present on the inner module. Subclasses whose
        attention does not have a ``kv_b_proj`` (e.g. DeepSeek V4 with shared KV) can
        override this hook with a no-op.
        """
        sharded_weight = get_partial_sharded(
            self.kv_b_proj.weight.data,
            tp_group.world_size,
            tp_group.rank_in_group,
            unit_num=self.num_heads,
        )
        self.kv_b_proj_weight_t = sharded_weight.transpose(0, 1)
        # View as (kv_lora_rank, heads_per_rank, qk_nope_head_dim + v_head_dim).
        kv_b_proj_view = self.kv_b_proj_weight_t.view(
            self.kv_lora_rank,
            self._num_heads_per_rank,
            self.qk_nope_head_dim + self.v_head_dim,
        )
        W_UK, W_UV = kv_b_proj_view.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # (heads_per_rank, kv_lora_rank, v_head_dim)
        self.W_UV = W_UV.transpose(0, 1)
        # (heads_per_rank, qk_nope_head_dim, kv_lora_rank)
        self.W_UK_T = W_UK.permute(1, 2, 0)

    @classmethod
    def requires_indexer_cache(cls) -> bool:
        """
        Hook for subclasses that require the auxiliary sparse-attention indexer
        cache allocated by input_generator.
        The default MLA path does not use an indexer cache.
        """
        return False

    @classmethod
    def build_tp_plan_extras(cls, prefix: str, params: dict, config_info) -> dict[str, tuple[str, dict]]:
        """
        Hook for subclasses to inject extra TP sharding rules into the generic MLA
        q/kv shard plan without adding model-specific branches to transformations.py.
        The default MLA path has no extra modules beyond q/kv_b projections.
        Subclasses can override this to register additional learned linears that
        belong to the attention block (e.g. V4 indexer projections).
        """
        return {}

    @classmethod
    def build_o_proj_tp_plan_extras(cls, prefix: str, params: dict, config_info) -> dict[str, tuple[str, dict]]:
        """
        Hook for subclasses to inject extra O-projection-related TP sharding rules
        into the generic MLA shard plan.
        """
        return {}

    @staticmethod
    def extract_qparams(
        module: torch.nn.Module,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return ``(weight, scale, offset)`` for a possibly wrapped linear module.

        If ``module`` has an ``_inner`` attribute, that inner module is inspected
        instead. A ``TensorCastQuantLinear`` yields its ``qweight``,
        ``weight_scale`` and ``weight_offset``; any other module yields
        ``weight.data`` with ``None`` for scale and offset.

        Raises:
            AttributeError: if the inspected module has no ``weight`` attribute
                (or it is ``None``).
        """
        target = module
        if hasattr(module, "_inner"):
            target = module._inner
        if isinstance(target, TensorCastQuantLinear):
            return target.qweight, target.weight_scale, target.weight_offset
        weight = getattr(target, "weight", None)
        if weight is None:
            raise AttributeError(f"Module {module.__class__.__name__} does not expose a weight tensor. ")
        return weight.data, None, None

    def _pre_attention_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_meta: Optional[AttentionMetadataBase] = None,
        qa_normed: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Pre-attention processing hook that runs before core attention computation.
        This hook is INTENDED FOR IN-PLACE CACHE PREPARATION (e.g., writing precomputed key
        features or index data into pre-allocated cache tensors such as indexer_cache).
        """
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        kv_cache_unused: Optional[torch.Tensor] = None,
        attention_meta: Optional[AttentionMetadataBase] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Run MLA attention for one layer.

        Args:
            hidden_states: 3-D input; the first two dimensions are treated as
                ``(batch_size, seq_length)`` and flattened into tokens.
            position_embeddings: ``(cos, sin)`` tables passed to the projection op.
            attention_mask: Not read by this implementation.
            kv_cache_unused: Not read; the cache is taken from
                ``kwargs["kv_cache_by_layers"][self.layer_idx]`` instead.
            attention_meta: Optional metadata providing ``slot_mapping``,
                ``block_table_tensor``, ``query_start_loc``, ``seq_lens`` and
                ``query_lens``. When ``None`` the cache write is skipped and
                ``None`` is passed for those backend arguments.
            **kwargs: ``kv_cache_by_layers`` is consumed here; the remainder is
                forwarded to :meth:`_pre_attention_forward`.

        Returns:
            ``(attn_output, None)`` where ``attn_output`` is the ``o_proj``
            output for a ``(batch_size, seq_length, -1)`` shaped attention result.
        """
        kv_cache_by_layers = kwargs.pop("kv_cache_by_layers", None)
        kv_cache = kv_cache_by_layers[self.layer_idx] if kv_cache_by_layers else None

        batch_size, seq_length = hidden_states.shape[:-1]
        num_tokens = batch_size * seq_length
        hidden_states_view = hidden_states.view(num_tokens, -1)
        cos, sin = position_embeddings

        # The extracted tensors are kept on ``self`` exactly as before so that
        # code outside this method which reads them keeps working.
        self.q_a_proj_weight, self.q_a_proj_scale, self.q_a_proj_offset = self.extract_qparams(self.q_a_proj)
        self.q_b_proj_weight, self.q_b_proj_scale, self.q_b_proj_offset = self.extract_qparams(self.q_b_proj)
        self.kv_a_proj_weight, self.kv_a_proj_scale, self.kv_a_proj_offset = self.extract_qparams(
            self.kv_a_proj_with_mqa
        )
        self.q_a_layernorm_weight = self.q_a_layernorm.weight.data
        self.kv_a_layernorm_weight = self.kv_a_layernorm.weight.data

        # The quantized projection op is used only when all three linears
        # expose a weight scale.
        linear_quant_enabled = (
            self.q_a_proj_scale is not None
            and self.q_b_proj_scale is not None
            and self.kv_a_proj_scale is not None
        )

        # Positional arguments shared by ``mlapo`` and ``mlapo_quant``; the
        # quantized variant appends the scale/offset pairs after them.
        mlapo_args = (
            hidden_states_view,
            cos,
            sin,
            self.q_a_proj_weight,
            self.q_a_layernorm_weight,
            self.q_b_proj_weight,
            self.kv_a_proj_weight,
            self.kv_a_layernorm_weight,
            self._num_heads_per_rank,
            self.qk_head_dim,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.kv_lora_rank,
            self.q_lora_rank,
        )
        if linear_quant_enabled:
            q_states, kv_c_normed, k_rot, qa_normed = torch.ops.tensor_cast.mlapo_quant(
                *mlapo_args,
                self.q_a_proj_scale,
                self.q_a_proj_offset,
                self.q_b_proj_scale,
                self.q_b_proj_offset,
                self.kv_a_proj_scale,
                self.kv_a_proj_offset,
            )
        else:
            q_states, kv_c_normed, k_rot, qa_normed = torch.ops.tensor_cast.mlapo(*mlapo_args)

        if self.q_lora_rank is not None:
            qa_normed = qa_normed.view(batch_size, seq_length, -1)
        else:
            qa_normed = None

        pre_attn_out = self._pre_attention_forward(
            hidden_states=hidden_states,
            qa_normed=qa_normed,
            position_embeddings=position_embeddings,
            attention_meta=attention_meta,
            **kwargs,
        )

        # Use one explicit ``is not None`` test everywhere so the cache write
        # and the metadata passed to the backend can never disagree.
        has_meta = attention_meta is not None
        query_start_loc = attention_meta.query_start_loc if has_meta else None
        seq_lens = attention_meta.seq_lens if has_meta else None
        query_lens = attention_meta.query_lens if has_meta else None
        block_table = attention_meta.block_table_tensor if has_meta else None

        quant_config = self.quant_config
        if quant_config is not None:
            # Quantize the attention inputs before they are cached / consumed.
            out_dtype = quant_config.get_quant_dtype()
            q_states = torch.ops.tensor_cast.quantize(
                q_states,
                quant_config.query_scale,
                quant_config.query_offset,
                out_dtype,
            )
            kv_c_normed = torch.ops.tensor_cast.quantize(
                kv_c_normed,
                quant_config.kv_scale,
                quant_config.kv_offset,
                out_dtype,
            )
            k_rot = torch.ops.tensor_cast.quantize(
                k_rot,
                quant_config.kv_scale,
                quant_config.kv_offset,
                out_dtype,
            )

        # The cache write is the same for the quantized and unquantized paths.
        if has_meta:
            torch.ops.tensor_cast.concat_and_cache_mla(kv_c_normed, k_rot, kv_cache, attention_meta.slot_mapping)

        # Defaults first, so a subclass hook can override either entry.
        extra_backend_kwargs = {
            "topk_limit": None,
            "topk_indices": None,
            **self._get_backend_kwargs(pre_attn_out),
        }

        if quant_config is not None:
            attention_backend = partial(
                torch.ops.tensor_cast.multihead_latent_attention_quant,
                W_UK_T=self.W_UK_T,
                W_UV=self.W_UV,
                kv_b_proj=self.kv_b_proj_weight_t,
                v_head_dim=self.v_head_dim,
                query_scale=quant_config.query_scale,
                query_offset=quant_config.query_offset,
                kv_scale=quant_config.kv_scale,
                kv_offset=quant_config.kv_offset,
                kv_projected_scale=quant_config.kv_projected_scale,
                kv_projected_offset=quant_config.kv_projected_offset,
                qk_scale=quant_config.qk_scale,
                qk_offset=quant_config.qk_offset,
                v_scale=quant_config.v_scale,
                v_offset=quant_config.v_offset,
                attention_prob_scale=quant_config.attention_prob_scale,
                attention_prob_offset=quant_config.attention_prob_offset,
                kv_b_proj_scale=self.kv_b_proj_scale,
                kv_b_proj_offset=self.kv_b_proj_offset,
                out_scale=quant_config.out_scale,
                out_offset=quant_config.out_offset,
                out_dtype=hidden_states.dtype,
                **extra_backend_kwargs,
            )
        else:
            attention_backend = partial(
                torch.ops.tensor_cast.multihead_latent_attention,
                W_UK_T=self.W_UK_T,
                W_UV=self.W_UV,
                kv_b_proj=self.kv_b_proj_weight_t,
                v_head_dim=self.v_head_dim,
                **extra_backend_kwargs,
            )

        attn_output = attention_backend(
            q=q_states,
            kv_cache=kv_cache,
            block_table=block_table,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            query_lens=query_lens,
        )
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None

    def _get_backend_kwargs(self, pre_attn_out) -> dict:
        """
        Hook for subclasses to inject additional arguments into the attention backend.
        Default implementation returns an empty dict (standard dense attention).
        Subclasses can override this to pass specific parameters (e.g., top-k, window size).
        """
        return {}

    def quantize_params(self):
        """Quantize the ``kv_b_proj`` decomposition; requires ``quant_config`` to be set."""
        assert self.quant_config is not None, "quant_config must be set before quantization"
        self._quantize_kv_b_decomposition()

    def _quantize_kv_b_decomposition(self) -> None:
        """
        Hook: quantize the ``kv_b_proj`` decomposition tensors set up by
        :meth:`_setup_kv_b_decomposition`. Subclasses that override the setup hook
        should typically override this one as well (see DeepSeek V4 wrapper).

        Raises:
            ValueError: if ``kv_b_proj`` is not a ``TensorCastQuantLinear``.
        """
        out_dtype = self.quant_config.get_quant_dtype()
        kv_b_proj = self.kv_b_proj
        if not isinstance(kv_b_proj, TensorCastQuantLinear):
            raise ValueError("MLA quantization requires kv_b_proj to be quantized")
        self.kv_b_proj_scale = kv_b_proj.weight_scale
        self.kv_b_proj_offset = kv_b_proj.weight_offset
        # All three derived tensors are quantized with kv_b_proj's own scale/offset.
        self.kv_b_proj_weight_t = torch.ops.tensor_cast.quantize(
            self.kv_b_proj_weight_t,
            self.kv_b_proj_scale,
            self.kv_b_proj_offset,
            out_dtype,
        )
        self.W_UK_T = torch.ops.tensor_cast.quantize(
            self.W_UK_T,
            self.kv_b_proj_scale,
            self.kv_b_proj_offset,
            out_dtype,
        )
        self.W_UV = torch.ops.tensor_cast.quantize(
            self.W_UV,
            self.kv_b_proj_scale,
            self.kv_b_proj_offset,
            out_dtype,
        )
def _resolve_sparse_topk_limit(
    indexer: torch.nn.Module,
    config: Optional[torch.nn.Module] = None,
    topk_limit: Optional[int] = None,
) -> int:
    """Pick the sparse-attention top-k limit from the first source that defines one.

    Sources are consulted in this order, and the first value that is not
    ``None`` wins:

    1. the explicit ``topk_limit`` argument,
    2. ``indexer.topk_limit``,
    3. ``config.topk_limit``, then ``config.index_topk``,
    4. ``indexer.config.topk_limit``, then ``indexer.config.index_topk``.

    Raises:
        AttributeError: if none of the sources provides a value.
    """

    def _candidates():
        # Lazy on purpose: later sources are only touched if earlier ones are None.
        yield topk_limit
        yield getattr(indexer, "topk_limit", None)
        for source in (config, getattr(indexer, "config", None)):
            yield getattr(source, "topk_limit", None)
            yield getattr(source, "index_topk", None)

    for value in _candidates():
        if value is not None:
            return value
    raise AttributeError("topk_limit")
class DeepseekSparseAttention(MultiheadLatentAttentionTensorCast):
    """MLA variant that feeds top-k indices from an indexer into the attention backend.

    The inner module's ``indexer`` is wrapped in a
    :class:`DeepseekSparseAttentionIndexer`. Its output is produced in the
    pre-attention hook and handed to the backend as ``topk_indices``.
    """

    def __init__(
        self,
        mla_config: MlaConfig,
        mla_module: torch.nn.Module,
        tp_group: ParallelGroup,
        decode_only: bool = False,
    ):
        """Set up the base MLA wrapper, then wrap the inner module's indexer."""
        super().__init__(mla_config, mla_module, tp_group, decode_only)
        inner_indexer = self._inner.indexer
        # The limit may live on the indexer itself or on a ``config`` object.
        resolved_limit = _resolve_sparse_topk_limit(
            inner_indexer,
            config=getattr(self._inner, "config", None),
        )
        self.indexer = DeepseekSparseAttentionIndexer(inner_indexer, topk_limit=resolved_limit)

    @classmethod
    def requires_indexer_cache(cls) -> bool:
        """Sparse attention needs the auxiliary indexer cache."""
        return True

    def _pre_attention_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_meta: Optional[AttentionMetadataBase] = None,
        qa_normed: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the sparse-attention indexer and return whatever it produces."""
        indexer_output = self._run_sparse_attention_indexer(
            hidden_states, qa_normed, position_embeddings, attention_meta, **kwargs
        )
        return indexer_output

    def _run_sparse_attention_indexer(
        self,
        hidden_states: torch.Tensor,
        qa_normed: Optional[torch.Tensor],
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_meta: Optional[AttentionMetadataBase] = None,
        **kwargs,
    ):
        """Call the indexer with this layer's indexer cache.

        Returns ``None`` without calling the indexer when ``qa_normed`` is
        ``None``. The cache is read from ``kwargs["indexer_cache_by_layers"]``
        at ``self.layer_idx``; a missing or empty container yields ``None``.
        """
        if qa_normed is None:
            return None

        caches_by_layer = kwargs.pop("indexer_cache_by_layers", None)
        layer_cache = None
        if caches_by_layer:
            layer_cache = caches_by_layer[self.layer_idx]

        return self.indexer(hidden_states, qa_normed, position_embeddings, layer_cache, attention_meta)

    def _get_backend_kwargs(self, pre_attn_out) -> dict:
        """
        Sparse Attention Args:
        - topk_limit: int, number of selected sparse tokens
        - topk_indices: Tensor, precomputed sparse position indices
        """
        return dict(
            topk_limit=self.indexer.topk_limit,
            topk_indices=pre_attn_out,
        )
class DeepseekSparseAttentionIndexer(ModelWrapperBase):
    """Wrapper that runs a sparse-attention indexer through ``tensor_cast.dsa_indexer``.

    Attributes such as ``wq_b``, ``wk``, ``weights_proj``, ``k_norm`` and
    ``qk_rope_head_dim`` are not defined here; presumably ``ModelWrapperBase``
    forwards them to the wrapped module -- confirm against its implementation.
    """

    def __init__(self, indexer, topk_limit: Optional[int] = None):
        """Wrap ``indexer`` and resolve its top-k limit once.

        Args:
            indexer: The indexer module to wrap.
            topk_limit: Explicit limit; when ``None`` it is looked up on the
                indexer / its config by ``_resolve_sparse_topk_limit``.
        """
        super().__init__(indexer)
        self._topk_limit = _resolve_sparse_topk_limit(indexer, topk_limit=topk_limit)

    @property
    def num_heads(self) -> int:
        """Head count, read from ``num_heads`` or else ``n_heads`` on the inner module."""
        if hasattr(self._inner, "num_heads"):
            return self._inner.num_heads
        return self._inner.n_heads

    @property
    def head_dim(self) -> int:
        """Head size, read from ``head_dim`` or else ``index_head_dim`` on the inner module."""
        if hasattr(self._inner, "head_dim"):
            return self._inner.head_dim
        return self._inner.index_head_dim

    @property
    def topk_limit(self) -> int:
        """The top-k limit resolved at construction time."""
        return self._topk_limit

    def forward(
        self,
        hidden_states: torch.Tensor,
        qa_normed: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        indexer_cache: Optional[torch.Tensor],
        attention_meta: Optional[AttentionMetadataBase] = None,
    ):
        """Invoke the ``dsa_indexer`` op and return its result.

        Args:
            hidden_states: Passed through to the op unchanged.
            qa_normed: Passed through to the op unchanged.
            position_embeddings: ``(cos, sin)`` tables.
            indexer_cache: Cache handed to the op; may be ``None`` when the
                caller has no per-layer cache.
            attention_meta: Optional metadata supplying ``slot_mapping``,
                ``block_table_tensor`` and ``seq_lens``; all three are passed
                as ``None`` when it is absent.
        """
        cos, sin = position_embeddings

        extract = MultiheadLatentAttentionTensorCast.extract_qparams
        # Only the weights are used; scale and offset are discarded.
        wq_b_weight, _, _ = extract(self.wq_b)
        wk_weight, _, _ = extract(self.wk)
        weights_proj_weight, _, _ = extract(self.weights_proj)

        # A single explicit ``is not None`` check keeps the three metadata
        # fields consistent with one another.
        if attention_meta is not None:
            slot_mapping = attention_meta.slot_mapping
            block_table = attention_meta.block_table_tensor
            seq_lens = attention_meta.seq_lens
        else:
            slot_mapping = block_table = seq_lens = None

        return torch.ops.tensor_cast.dsa_indexer(
            hidden_states,
            qa_normed,
            cos,
            sin,
            indexer_cache,
            slot_mapping,
            block_table,
            seq_lens,
            wq_b_weight,
            wk_weight,
            weights_proj_weight,
            self.k_norm.weight,
            self.num_heads,
            self.head_dim,
            self.qk_rope_head_dim,
            self.topk_limit,
        )