import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn
from tensor_cast.layers.mla import (
DeepseekSparseAttention,
DeepseekSparseAttentionIndexer,
MultiheadLatentAttentionTensorCast,
)
from tensor_cast.layers.quant_linear import TensorCastQuantLinear
from tensor_cast.model_config import LinearQuantConfig
from tensor_cast.quantize_utils import LinearQuantType, QuantGranularity, QuantScheme
from tensor_cast.transformers.builtin_model.deepseek_v32 import DeepseekV32Config
class TestMlaIndexerCacheHooks(unittest.TestCase):
    """Tests for the class-level hooks exposed by the MLA attention layers.

    Covers the ``requires_indexer_cache`` flag, the tensor-parallel plan
    extension hooks of the base class, and the kv_b projection
    decomposition / quantization helpers.
    """

    @staticmethod
    def _kv_b_stub(projection, heads, lora_rank, nope_dim, v_dim):
        """Return a duck-typed stand-in carrying the attributes these tests
        hand to ``_setup_kv_b_decomposition`` (all heads on a single rank)."""
        return SimpleNamespace(
            kv_b_proj=projection,
            num_heads=heads,
            kv_lora_rank=lora_rank,
            qk_nope_head_dim=nope_dim,
            v_head_dim=v_dim,
            _num_heads_per_rank=heads,
        )

    @staticmethod
    def _single_rank_group():
        """Return a mocked tensor-parallel group with one member at rank 0."""
        return MagicMock(world_size=1, rank_in_group=0)

    def test_base_requires_indexer_cache_is_false(self):
        # Plain MLA must not request an indexer cache.
        needs_cache = MultiheadLatentAttentionTensorCast.requires_indexer_cache()
        self.assertFalse(needs_cache)

    def test_deepseek_sparse_requires_indexer_cache(self):
        # The sparse-attention subclass must request one.
        needs_cache = DeepseekSparseAttention.requires_indexer_cache()
        self.assertTrue(needs_cache)

    def test_base_build_tp_plan_extras_empty(self):
        # Both TP-plan extension hooks return an empty mapping on the base class.
        for hook_name in ("build_tp_plan_extras", "build_o_proj_tp_plan_extras"):
            hook = getattr(MultiheadLatentAttentionTensorCast, hook_name)
            self.assertEqual(hook("layers", {}, SimpleNamespace()), {})

    def test_setup_kv_b_decomposition_splits_projection(self):
        heads, lora_rank, nope_dim, v_dim = 4, 64, 32, 16
        projection = nn.Linear(lora_rank, heads * (nope_dim + v_dim), bias=False)
        stub = self._kv_b_stub(projection, heads, lora_rank, nope_dim, v_dim)

        MultiheadLatentAttentionTensorCast._setup_kv_b_decomposition(
            stub, self._single_rank_group()
        )

        # Each decomposed factor must have the head count as its leading dim.
        for factor_name in ("W_UV", "W_UK_T"):
            self.assertEqual(getattr(stub, factor_name).shape[0], heads)

    @patch("torch.ops.tensor_cast.quantize", side_effect=lambda t, *args, **kwargs: t)
    def test_quantize_kv_b_decomposition(self, _mock_quantize):
        # The quantize op is patched to return its first argument unchanged,
        # so only the bookkeeping done by the helper is exercised.
        heads, lora_rank, nope_dim, v_dim = 2, 32, 16, 8
        quant_cfg = LinearQuantConfig(
            weight_scale=torch.ones(1),
            quant_type=LinearQuantType.W8A16,
            weight_quant_granularity=QuantGranularity.PER_TENSOR,
            weight_quant_scheme=QuantScheme.SYMMETRIC,
        )
        projection = nn.Linear(lora_rank, heads * (nope_dim + v_dim), bias=False)

        # First derive the float decomposition from the unquantized projection ...
        float_stub = self._kv_b_stub(projection, heads, lora_rank, nope_dim, v_dim)
        MultiheadLatentAttentionTensorCast._setup_kv_b_decomposition(
            float_stub, self._single_rank_group()
        )

        # ... then pair it with a quantized kv_b_proj for the quantize step.
        quantized_proj = TensorCastQuantLinear(projection, quant_cfg)
        quant_stub = SimpleNamespace(
            kv_b_proj=quantized_proj,
            kv_b_proj_weight_t=float_stub.kv_b_proj_weight_t,
            W_UK_T=float_stub.W_UK_T,
            W_UV=float_stub.W_UV,
            quant_config=MagicMock(get_quant_dtype=MagicMock(return_value=torch.int8)),
        )
        MultiheadLatentAttentionTensorCast._quantize_kv_b_decomposition(quant_stub)

        # The helper must reuse the quantized layer's scale object itself.
        self.assertIs(quant_stub.kv_b_proj_scale, quantized_proj.weight_scale)
class TestDeepseekSparseAttentionIndexer(unittest.TestCase):
    """Tests for ``DeepseekSparseAttentionIndexer`` and the raw ``dsa_indexer`` op.

    The wrapper tests cover how ``topk_limit`` is resolved and what is
    forwarded to the op. The direct-op tests check only output shapes.
    """

    def setUp(self):
        self.batch_size, self.seq_len = 2, 10
        hidden, heads, head_dim, rope_dim, q_rank = 16, 4, 8, 4, 4

        # Minimal attention module exposing the attributes the indexer is built from.
        attn = nn.Module()
        attn.hidden_size = hidden
        attn.num_heads = heads
        attn.head_dim = head_dim
        attn.qk_rope_head_dim = rope_dim
        attn.topk_limit = 2
        attn.q_lora_rank = q_rank
        attn.wq_b = nn.Linear(q_rank, heads * head_dim, bias=False)
        attn.wk = nn.Linear(hidden, head_dim, bias=False)
        attn.k_norm = nn.LayerNorm(head_dim)
        attn.weights_proj = nn.Linear(hidden, heads, bias=False)
        attn.softmax_scale = head_dim**-0.5

        self.inner_module = attn
        self.indexer = DeepseekSparseAttentionIndexer(attn)

        batch_seq = (self.batch_size, self.seq_len)
        self.hidden_states = torch.randn(*batch_seq, hidden)
        self.qa_normed = torch.randn(*batch_seq, q_rank)
        # Two rotary tables, presumably (cos, sin); the tests only check shapes.
        rope_first = torch.randn(self.seq_len, rope_dim)
        rope_second = torch.randn(self.seq_len, rope_dim)
        self.position_embeddings = (rope_first, rope_second)
        self.indexer_cache = torch.empty(*batch_seq, head_dim)

    # ------------------------------------------------------------------ #
    # helpers for direct ``dsa_indexer`` calls
    # ------------------------------------------------------------------ #
    def _random_op_inputs(self, batch_size, seq_len):
        """Return fresh activation tensors for a direct op call.

        The order is hidden_states, qa_normed, two rotary tables and the
        indexer cache. The cache's second dimension is fixed at 5.
        """
        attn = self.inner_module
        return (
            torch.randn(batch_size, seq_len, attn.hidden_size),
            torch.randn(batch_size, seq_len, attn.q_lora_rank),
            torch.randn(seq_len, attn.qk_rope_head_dim),
            torch.randn(seq_len, attn.qk_rope_head_dim),
            torch.empty(batch_size, 5, attn.head_dim),
        )

    def _op_weights(self):
        """Return the projection and norm weights of ``inner_module`` in op order."""
        attn = self.inner_module
        return (
            attn.wq_b.weight,
            attn.wk.weight,
            attn.weights_proj.weight,
            attn.k_norm.weight,
        )

    def _run_op(self, batch_size, seq_len, num_heads, topk_limit, seq_lens=None):
        """Call ``dsa_indexer`` eagerly on random inputs.

        Positional slots 5 and 6 are passed as ``None``; ``seq_lens`` goes
        into slot 7.
        """
        attn = self.inner_module
        return torch.ops.tensor_cast.dsa_indexer(
            *self._random_op_inputs(batch_size, seq_len),
            None,
            None,
            seq_lens,
            *self._op_weights(),
            num_heads,
            attn.head_dim,
            attn.qk_rope_head_dim,
            topk_limit,
        )

    # ------------------------------------------------------------------ #
    # topk_limit resolution
    # ------------------------------------------------------------------ #
    def test_topk_limit_is_available_on_wrapper(self):
        self.assertEqual(self.indexer.topk_limit, 2)

    def test_topk_limit_is_cached_on_wrapper(self):
        # The limit comes from ``config`` here. It must still be reported after
        # the config is deleted, i.e. it is captured when the wrapper is built.
        attn = nn.Module()
        attn.config = type("Config", (), {"topk_limit": 7})()
        wrapper = DeepseekSparseAttentionIndexer(attn)
        del attn.config
        self.assertEqual(wrapper.topk_limit, 7)

    def test_topk_limit_can_be_passed_explicitly(self):
        wrapper = DeepseekSparseAttentionIndexer(nn.Module(), topk_limit=11)
        self.assertEqual(wrapper.topk_limit, 11)

    def test_deepseek_config_ignores_glm5_only_field(self):
        # ``index_topk`` must not show up on a default DeepSeek V3.2 config.
        cfg = DeepseekV32Config(topk_limit=33)
        self.assertEqual(cfg.topk_limit, 33)
        self.assertFalse(hasattr(cfg, "index_topk"))

    def test_glm5_index_topk_config_falls_back_when_topk_limit_is_none(self):
        # With no module-level limit, the config's ``index_topk`` is used instead.
        attn = nn.Module()
        attn.topk_limit = None
        attn.config = type("GlmMoeDsaConfig", (), {"index_topk": 21})()
        self.assertEqual(DeepseekSparseAttentionIndexer(attn).topk_limit, 21)

    # ------------------------------------------------------------------ #
    # wrapper forward (op mocked)
    # ------------------------------------------------------------------ #
    def test_forward(self):
        width = min(self.indexer.topk_limit, self.seq_len)
        expected_shape = (self.batch_size, self.seq_len, width)
        with patch("torch.ops.tensor_cast.dsa_indexer") as fake_op:
            fake_op.return_value = torch.randn(*expected_shape)
            result = self.indexer.forward(
                self.hidden_states,
                self.qa_normed,
                self.position_embeddings,
                self.indexer_cache,
            )
            self.assertEqual(result.shape, expected_shape)
            fake_op.assert_called_once()

    def test_forward_passes_seq_lens_to_op_after_block_tables(self):
        active_lens = torch.tensor([17, 19], dtype=torch.long)
        meta = SimpleNamespace(
            slot_mapping=None,
            block_table_tensor=None,
            seq_lens=active_lens,
        )
        with patch("torch.ops.tensor_cast.dsa_indexer") as fake_op:
            fake_op.return_value = torch.randn(
                self.batch_size,
                self.seq_len,
                self.indexer.topk_limit,
            )
            self.indexer.forward(
                self.hidden_states,
                self.qa_normed,
                self.position_embeddings,
                self.indexer_cache,
                meta,
            )
            # seq_lens must land in positional slot 7, right after the two
            # optional slots (5, 6) that the direct-op calls fill with None.
            forwarded = fake_op.call_args.args[7]
            self.assertTrue(torch.equal(forwarded, meta.seq_lens))

    # ------------------------------------------------------------------ #
    # raw op
    # ------------------------------------------------------------------ #
    def test_dsa_indexer_op_returns_query_major_topk_shape(self):
        out = self._run_op(
            batch_size=2,
            seq_len=3,
            num_heads=self.inner_module.num_heads,
            topk_limit=2,
        )
        self.assertEqual(out.shape, (2, 3, 2))

    def test_dsa_indexer_op_uses_active_sequence_length_for_topk_width(self):
        # One query token but active lengths up to 4: the expected width is the
        # full topk_limit (4), not the query length (1).
        active_lens = torch.tensor([3, 4], dtype=torch.long)
        out = self._run_op(
            batch_size=2,
            seq_len=1,
            num_heads=1,
            topk_limit=4,
            seq_lens=active_lens,
        )
        self.assertEqual(out.shape, (2, 1, 4))

    def test_dsa_indexer_op_compiles_when_seq_lens_is_provided(self):
        # The op must trace under a full-graph compile when seq_lens is a tensor.
        head_dim = self.inner_module.head_dim
        rope_dim = self.inner_module.qk_rope_head_dim

        def indexer_fn(
            hidden_states,
            qa_normed,
            cos,
            sin,
            indexer_cache,
            seq_lens,
            wq_b_weight,
            wk_weight,
            weights_proj_weight,
            k_norm_weight,
        ):
            return torch.ops.tensor_cast.dsa_indexer(
                hidden_states,
                qa_normed,
                cos,
                sin,
                indexer_cache,
                None,
                None,
                seq_lens,
                wq_b_weight,
                wk_weight,
                weights_proj_weight,
                k_norm_weight,
                1,
                head_dim,
                rope_dim,
                4,
            )

        compiled = torch.compile(indexer_fn, backend="eager", fullgraph=True)
        out = compiled(
            *self._random_op_inputs(2, 1),
            torch.tensor([3, 4], dtype=torch.long),
            *self._op_weights(),
        )
        self.assertEqual(out.shape, (2, 1, 4))