from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
import torch

from tensor_cast.core.input_generator import (
    RequestInfo,
    _is_v4_model,
    _layer_uses_sparse_attention_indexer,
    _load_preprocessor_pixel_limits,
    _read_preprocessor_config,
    _resolve_decoder_layers,
    _resolve_indexer_cache_dtype,
    _resolve_local_preprocessor_config,
    _resolve_main_kv_cache_dtype,
    _resolve_sparse_attention_indexer_cache_width,
    _resolve_sparse_attention_kv_cache_width,
    _resolve_v4_kv_cache_size,
    generate_image_inputs,
    generate_inputs,
    generate_inputs_varlen,
    get_sparse_attention_indexer_cache_info,
    resize_image,
)
from tensor_cast.device import TEST_DEVICE
from tensor_cast.layers.deepseek_v4 import DeepseekV4SparseAttention
from tensor_cast.layers.sampler import Sampler
from tensor_cast.model_config import MtpConfig
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.runtime import Runtime
from tensor_cast.transformers.model import TransformerModel
def _fake_mtp_input_model(num_mtp_tokens=2):
    """Build a lightweight stand-in model object with an MTP configuration.

    Args:
        num_mtp_tokens: Value passed to ``MtpConfig`` as ``num_mtp_layers``.

    Returns:
        A ``SimpleNamespace`` exposing ``is_vl_model``, ``num_hidden_layers``
        and a nested ``model_config`` namespace.
    """
    mtp = MtpConfig(
        num_mtp_layers=num_mtp_tokens,
        mtp_block_module_name="DeepseekV3DecoderLayer",
    )
    parallel = SimpleNamespace(data_parallel_size=1, tensor_parallel_size=1)
    config = SimpleNamespace(
        mtp_config=mtp,
        parallel_config=parallel,
        mla_config=None,
        hf_config=SimpleNamespace(model_type="deepseek_v3"),
    )
    return SimpleNamespace(is_vl_model=False, num_hidden_layers=0, model_config=config)
def _proposal_indices(spec_metadata):
    """Return the last logits index of every request's speculative window.

    ``logits_indices`` is viewed as ``(num_active_requests, num_speculative_tokens + 1)``
    and the final column is selected.
    """
    window = 1 + spec_metadata.num_speculative_tokens
    per_request = spec_metadata.logits_indices.view(spec_metadata.num_active_requests, window)
    return per_request[:, -1]
@pytest.mark.parametrize("is_decode", [True, False])
def test_selected_token_indices_for_lmhead(qwen3_32b_lmhead_attention_transformer: TransformerModel, is_decode):
    """Check the forward output shape for inputs built by ``generate_inputs``.

    For decode requests the expected output has one row per packed query
    token; otherwise it has one row per request.
    """
    transformer = qwen3_32b_lmhead_attention_transformer
    tokens_per_request, num_requests = 100, 2
    request = RequestInfo(
        query_len=tokens_per_request,
        seq_len=tokens_per_request,
        concurrency=num_requests,
        is_decode=is_decode,
    )
    model_inputs = generate_inputs(transformer, [request])

    num_rows = num_requests * tokens_per_request if is_decode else num_requests
    expected_shape = (1, num_rows, transformer.vocab_size)

    device = TEST_DEVICE
    analytic_model = AnalyticPerformanceModel(device)
    with Runtime(analytic_model, device), torch.no_grad():
        logits = transformer.forward(**model_inputs)
    assert logits.shape == expected_shape
def test_generate_inputs_mtp_decode_does_not_select_all_packed_rows_for_target_lm_head():
    """MTP decode via ``generate_inputs`` selects only the trailing rows of each request.

    Two requests of query length 5 with two MTP tokens must yield logits
    indices ``[2, 3, 4, 7, 8, 9]`` rather than all ten packed rows.
    """
    request = RequestInfo(query_len=5, seq_len=16, concurrency=2, is_decode=True)
    sampling_metadata = generate_inputs(_fake_mtp_input_model(), [request])["sampling_metadata"]
    spec = sampling_metadata.spec_decode_metadata

    selected_rows = spec.logits_indices.tolist()
    assert selected_rows != list(range(10))
    assert selected_rows == [2, 3, 4, 7, 8, 9]
    assert _proposal_indices(spec).tolist() == [4, 9]
    assert spec.num_active_requests == 2
    assert spec.num_speculative_tokens == 2
    assert sampling_metadata.selected_token_indices is None
def test_generate_inputs_varlen_mtp_decode_does_not_reuse_padded_prefix_rows():
    """MTP decode via ``generate_inputs_varlen`` with unequal query lengths.

    Requests of query length 5 and 3 (eight packed rows) must yield logits
    indices ``[2, 3, 4, 5, 6, 7]`` with proposal rows ``[4, 7]``.
    """
    decode_requests = [
        RequestInfo(query_len=5, seq_len=16, is_decode=True),
        RequestInfo(query_len=3, seq_len=12, is_decode=True),
    ]
    generated = generate_inputs_varlen(_fake_mtp_input_model(), decode_requests, block_size=128)
    sampling_metadata = generated["sampling_metadata"]
    spec = sampling_metadata.spec_decode_metadata

    selected_rows = spec.logits_indices.tolist()
    assert selected_rows != list(range(8))
    assert selected_rows == [2, 3, 4, 5, 6, 7]
    assert _proposal_indices(spec).tolist() == [4, 7]
    assert spec.num_active_requests == 2
    assert spec.num_speculative_tokens == 2
    assert sampling_metadata.selected_token_indices is None
def test_generate_inputs_varlen_mtp_decode_uses_ordinary_selection_for_short_query_window():
    """Varlen decode with query lengths 3 and 2 produces no spec-decode metadata.

    The sampler is run on a meta-device logits tensor and must return one
    token per request.
    """
    decode_requests = [
        RequestInfo(query_len=3, seq_len=16, is_decode=True),
        RequestInfo(query_len=2, seq_len=12, is_decode=True),
    ]
    generated = generate_inputs_varlen(_fake_mtp_input_model(), decode_requests, block_size=128)
    sampling_metadata = generated["sampling_metadata"]

    # Run the sampler before inspecting the metadata, as the original test does.
    meta_logits = torch.empty(1, 5, 8, device="meta")
    sampler = Sampler()
    next_tokens = sampler(meta_logits, sampling_metadata)

    assert sampling_metadata.spec_decode_metadata is None
    assert sampling_metadata.selected_token_indices is None
    assert next_tokens.shape == (2, 1)
@pytest.mark.parametrize("is_decode", [True, False])
def test_varlen_selected_token_indices_for_lmhead(qwen3_32b_lmhead_attention_transformer: TransformerModel, is_decode):
    """Check the forward output shape for inputs built by ``generate_inputs_varlen``.

    For decode requests the expected output has one row per packed query
    token; otherwise it has one row per request.
    """
    transformer = qwen3_32b_lmhead_attention_transformer
    query_lens = [90, 110]
    num_requests = len(query_lens)
    requests = [RequestInfo(query_len=length, seq_len=length, is_decode=is_decode) for length in query_lens]
    model_inputs = generate_inputs_varlen(transformer, requests, 128)

    num_rows = sum(query_lens) if is_decode else num_requests
    expected_shape = (1, num_rows, transformer.vocab_size)

    device = TEST_DEVICE
    analytic_model = AnalyticPerformanceModel(device)
    with Runtime(analytic_model, device), torch.no_grad():
        logits = transformer.forward(**model_inputs)
    assert logits.shape == expected_shape
@patch("tensor_cast.core.input_generator.get_sparse_attention_indexer_cache_info", return_value={})
@patch("tensor_cast.core.input_generator._get_kv_cache_info", return_value=({}, 0))
def test_varlen_qwen3_5_cache_position_starts_at_context(_mock_kv_cache, _mock_sparse_cache):
    """Varlen qwen3_5 decode: ``cache_position`` starts at each request's context length.

    Both cache-info helpers are patched out so only input generation runs.
    """
    hf_config = SimpleNamespace(model_type="qwen3_5")
    fake_model = SimpleNamespace(model_config=SimpleNamespace(hf_config=hf_config, mtp_config=None))
    decode_requests = [
        RequestInfo(query_len=1, seq_len=2199, is_decode=True, context_length=2198),
        RequestInfo(query_len=2, seq_len=12, is_decode=True, context_length=10),
    ]

    cache_position = generate_inputs_varlen(fake_model, decode_requests, 128)["cache_position"]

    expected_positions = torch.tensor([2198, 10, 11], dtype=torch.long)
    assert torch.equal(cache_position, expected_positions)
    assert cache_position.tensor_cast_query_lens == (1, 2)
    assert cache_position.tensor_cast_is_decode == (True, True)
    assert cache_position.tensor_cast_has_previous_state
@patch("tensor_cast.core.input_generator.get_sparse_attention_indexer_cache_info", return_value={})
@patch("tensor_cast.core.input_generator._get_kv_cache_info", return_value=({}, 0))
def test_qwen3_5_decode_mtp_cache_position_metadata(_mock_kv_cache, _mock_sparse_cache):
    """qwen3_5 decode with three MTP layers: check ``cache_position`` and its attached metadata.

    Both cache-info helpers are patched out so only input generation runs.
    """
    model_config = SimpleNamespace(
        hf_config=SimpleNamespace(model_type="qwen3_5"),
        mtp_config=SimpleNamespace(num_mtp_layers=3),
        parallel_config=SimpleNamespace(data_parallel_size=1),
    )
    fake_model = SimpleNamespace(is_vl_model=False, model_config=model_config)
    request = RequestInfo(
        query_len=4,
        seq_len=2202,
        concurrency=21,
        is_decode=True,
        context_length=2198,
    )

    cache_position = generate_inputs(fake_model, [request])["cache_position"]

    # 84 positions = 21 concurrent requests x 4 query tokens, starting at the context length.
    expected_positions = torch.arange(2198, 2198 + 84, dtype=torch.long)
    assert torch.equal(cache_position, expected_positions)
    assert cache_position.tensor_cast_query_lens == (4,) * 21
    assert cache_position.tensor_cast_is_decode == (True,) * 21
    assert cache_position.tensor_cast_has_previous_state
    assert cache_position.tensor_cast_base_decode_query_len == 1
    assert cache_position.tensor_cast_num_mtp_tokens == 3
    assert cache_position.tensor_cast_effective_decode_steps == 4
@patch(
    "tensor_cast.core.input_generator.get_sparse_attention_indexer_cache_info",
    return_value={},
)
@patch("tensor_cast.core.input_generator._get_kv_cache_info", return_value=({}, 0))
def test_generate_inputs_sets_max_total_seq_len(_mock_kv_cache, _mock_sparse_cache):
    """``generate_inputs`` sets ``max_total_seq_len`` to the largest value in ``seq_lens``."""
    model_config = SimpleNamespace(
        hf_config=SimpleNamespace(model_type="deepseek_v4"),
        mtp_config=None,
        parallel_config=SimpleNamespace(data_parallel_size=1),
    )
    fake_model = SimpleNamespace(is_vl_model=False, model_config=model_config)
    prefill_request = RequestInfo(
        query_len=102400,
        context_length=921600,
        seq_len=1024000,
        concurrency=1,
        is_decode=False,
    )

    meta = generate_inputs(fake_model, [prefill_request])["attention_meta"]

    assert meta.max_total_seq_len == 1024000
    longest = int(meta.seq_lens.max().item())
    assert longest == meta.max_total_seq_len
@patch(
    "tensor_cast.core.input_generator.get_sparse_attention_indexer_cache_info",
    return_value={},
)
@patch("tensor_cast.core.input_generator._get_kv_cache_info", return_value=({}, 0))
def test_generate_inputs_varlen_sets_max_total_seq_len(_mock_kv_cache, _mock_sparse_cache):
    """``generate_inputs_varlen`` sets ``max_total_seq_len`` to the largest value in ``seq_lens``."""
    hf_config = SimpleNamespace(model_type="deepseek_v4")
    fake_model = SimpleNamespace(model_config=SimpleNamespace(hf_config=hf_config, mtp_config=None))
    prefill_requests = [
        RequestInfo(query_len=102400, context_length=921600, seq_len=1024000, is_decode=False),
        RequestInfo(query_len=8192, context_length=253952, seq_len=262144, is_decode=False),
    ]

    meta = generate_inputs_varlen(fake_model, prefill_requests, 128)["attention_meta"]

    assert meta.max_total_seq_len == 1024000
    longest = int(meta.seq_lens.max().item())
    assert longest == meta.max_total_seq_len
def test_resize_image_uses_local_preprocessor_config(tmp_path):
    """``resize_image`` on a 1080x1920 input with a local preprocessor config yields 1088x1920.

    The pixel-limit cache is cleared before and after the call.
    """
    _load_preprocessor_pixel_limits.cache_clear()
    config_file = tmp_path / "preprocessor_config.json"
    config_file.write_text(
        '{"size": {"shortest_edge": 65536, "longest_edge": 16777216}}',
        encoding="utf-8",
    )

    model_dir = str(tmp_path)
    height, width = resize_image(model_dir, "qwen3_5", 1080, 1920, patch_size=16, merge_size=2, temporal_patch_size=2)
    _load_preprocessor_pixel_limits.cache_clear()

    assert (height, width) == (1088, 1920)
def test_read_preprocessor_config_invalid_json_returns_none(tmp_path):
    """``_read_preprocessor_config`` returns ``None`` and logs once at debug level on invalid JSON.

    The module-level ``logger`` of ``tensor_cast.core.input_generator`` is
    patched so the debug call can be observed. ``_read_preprocessor_config``
    is imported at module level; the patch still applies because the target
    is the attribute on the ``input_generator`` module itself.
    """
    _load_preprocessor_pixel_limits.cache_clear()
    config_file = tmp_path / "preprocessor_config.json"
    config_file.write_text("not valid json", encoding="utf-8")

    with patch("tensor_cast.core.input_generator.logger") as mock_logger:
        result = _read_preprocessor_config(config_file)

    mock_logger.debug.assert_called_once()
    assert result is None
def test_read_preprocessor_config_missing_file_returns_none(tmp_path):
    """``_read_preprocessor_config`` returns ``None`` for a path that does not exist.

    ``_read_preprocessor_config`` comes from the module-level import rather
    than a function-scope import.
    """
    _load_preprocessor_pixel_limits.cache_clear()
    missing_file = tmp_path / "nonexistent.json"
    assert _read_preprocessor_config(missing_file) is None
def test_resolve_local_preprocessor_config_non_dir_returns_none():
    """``_resolve_local_preprocessor_config`` returns ``None`` for ``"not/a/real/path"``.

    ``_resolve_local_preprocessor_config`` comes from the module-level import
    rather than a function-scope import.
    """
    _load_preprocessor_pixel_limits.cache_clear()
    assert _resolve_local_preprocessor_config("not/a/real/path") is None
def test_load_preprocessor_pixel_limits_no_config_json_returns_none(tmp_path):
    """With no ``preprocessor_config.json`` written to the directory, both limits are ``None``."""
    _load_preprocessor_pixel_limits.cache_clear()
    empty_dir = str(tmp_path)
    min_px, max_px = _load_preprocessor_pixel_limits(empty_dir)
    assert min_px is None
    assert max_px is None
# Sizing constants shared by the DSA indexer-cache dtype tests below.
_DSA_INDEXER_CACHE_QUERY_LEN = 32
_DSA_INDEXER_CACHE_NUM_MTP_TOKENS = 2
_DSA_INDEXER_CACHE_BLOCK_SIZE = 128
# Ceiling division of (query_len + num_mtp_tokens + 1) by the block size.
_DSA_INDEXER_CACHE_NUM_BLOCKS = -(
    -(_DSA_INDEXER_CACHE_QUERY_LEN + _DSA_INDEXER_CACHE_NUM_MTP_TOKENS + 1) // _DSA_INDEXER_CACHE_BLOCK_SIZE
)
def test_dsa_indexer_cache_dtype_follows_attention_quant_config(
    deepseek_v32_build_model_int8,
):
    """For the int8 model fixture, the layer-0 indexer cache has dtype ``torch.int8``."""
    cache_info = get_sparse_attention_indexer_cache_info(
        deepseek_v32_build_model_int8,
        num_blocks=_DSA_INDEXER_CACHE_NUM_BLOCKS,
        block_size=_DSA_INDEXER_CACHE_BLOCK_SIZE,
    )
    first_layer_cache = cache_info["indexer_cache_by_layers"][0]
    assert first_layer_cache.dtype == torch.int8
def test_dsa_indexer_cache_dtype_uses_fp8_when_attention_quant_is_fp8(
    deepseek_v32_build_model_fp8,
):
    """For the fp8 model fixture, the layer-0 indexer cache has dtype ``torch.float8_e4m3fn``."""
    cache_info = get_sparse_attention_indexer_cache_info(
        deepseek_v32_build_model_fp8,
        num_blocks=_DSA_INDEXER_CACHE_NUM_BLOCKS,
        block_size=_DSA_INDEXER_CACHE_BLOCK_SIZE,
    )
    first_layer_cache = cache_info["indexer_cache_by_layers"][0]
    assert first_layer_cache.dtype == torch.float8_e4m3fn
def test_qwen3_vl_1080p_resize_to_1088x1920(tmp_path):
    """A 1080x1920 image on a qwen3_vl stand-in yields ``image_grid_thw == [[1, 68, 120]]``.

    With the configured patch size of 16, a 68 x 120 grid corresponds to
    1088 x 1920 pixels.
    """
    _load_preprocessor_pixel_limits.cache_clear()
    config_file = tmp_path / "preprocessor_config.json"
    config_file.write_text(
        '{"size": {"shortest_edge": 65536, "longest_edge": 16777216}}',
        encoding="utf-8",
    )

    vision_config = SimpleNamespace(
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        in_channels=3,
    )
    model_config = SimpleNamespace(
        dtype=torch.bfloat16,
        parallel_config=SimpleNamespace(data_parallel_size=1),
        hf_config=SimpleNamespace(model_type="qwen3_vl", vision_config=vision_config),
    )
    fake_model = SimpleNamespace(model_id=str(tmp_path), model_config=model_config)

    image_kwargs = generate_image_inputs(
        model=fake_model,
        image_batch_size=1,
        image_height=1080,
        image_width=1920,
        concurrency=1,
    )

    expected_grid = torch.tensor([[1, 68, 120]])
    assert torch.equal(image_kwargs["image_grid_thw"], expected_grid)
class TestSparseAttentionCacheHelpers:
    """Mock-driven tests for the sparse-attention cache width and layer-resolution helpers."""

    def test_resolve_kv_cache_width_from_attention_head_dim(self):
        """With an attention layer whose ``head_dim`` is 480, the resolved width is 480."""
        fake_model = MagicMock()
        fake_model.text_config.kv_lora_rank = 512
        fake_model.text_config.qk_rope_head_dim = 64
        attn = MagicMock(head_dim=480, _head_dim=None)

        width = _resolve_sparse_attention_kv_cache_width(fake_model, attn)

        assert width == 480

    def test_resolve_kv_cache_width_fallback_without_layer(self):
        """Without a layer the resolved width is 576 (512 + 64 from the mocked text config)."""
        fake_model = MagicMock()
        fake_model.text_config.kv_lora_rank = 512
        fake_model.text_config.qk_rope_head_dim = 64

        width = _resolve_sparse_attention_kv_cache_width(fake_model, None)

        assert width == 576

    def test_resolve_indexer_cache_width_prefers_index_head_dim(self):
        """The layer's ``_index_head_dim`` (128) wins over the text config value (999)."""
        fake_model = MagicMock()
        fake_model.text_config.index_head_dim = 999
        attn = MagicMock(_index_head_dim=128, indexer=None)

        width = _resolve_sparse_attention_indexer_cache_width(fake_model, attn)

        assert width == 128

    def test_resolve_indexer_cache_width_from_indexer_module(self):
        """With no ``_index_head_dim`` set, the indexer module's ``head_dim`` (64) is used."""
        fake_model = MagicMock()
        fake_model.text_config.index_head_dim = None
        indexer_module = MagicMock(head_dim=64)
        attn = MagicMock(_index_head_dim=None, indexer=indexer_module)

        width = _resolve_sparse_attention_indexer_cache_width(fake_model, attn)

        assert width == 64

    def test_layer_uses_sparse_attention_indexer(self):
        """Cover the explicit-flag cases and the case where only an indexer object is present."""
        flagged_on = MagicMock(use_indexer=True)
        assert _layer_uses_sparse_attention_indexer(flagged_on)

        flagged_off = MagicMock(use_indexer=False, indexer=None)
        assert not _layer_uses_sparse_attention_indexer(flagged_off)

        indexer_only = MagicMock(use_indexer=None, indexer=object())
        assert _layer_uses_sparse_attention_indexer(indexer_only)

    def test_resolve_decoder_layers_direct_layout(self):
        """Layers exposed directly on the unwrapped model are returned as the same object."""
        decoder_layers = [MagicMock(), MagicMock()]
        fake_model = MagicMock()
        fake_model.unwrap.return_value = SimpleNamespace(layers=decoder_layers)

        resolved = _resolve_decoder_layers(fake_model)

        assert resolved is decoder_layers

    def test_resolve_decoder_layers_nested_causal_lm_layout(self):
        """Layers nested under ``.model.layers`` of the unwrapped model are returned as the same object."""
        decoder_layers = [MagicMock()]
        inner_model = SimpleNamespace(layers=decoder_layers)
        fake_model = MagicMock()
        fake_model.unwrap.return_value = SimpleNamespace(model=inner_model)

        resolved = _resolve_decoder_layers(fake_model)

        assert resolved is decoder_layers

    @patch("tensor_cast.core.input_generator.get_attention_quant_config", return_value=None)
    def test_get_sparse_attention_indexer_cache_info_v4_layers(self, _mock_attn_quant):
        """Only the layer flagged with ``use_indexer=True`` (index 1) is checked for an indexer cache."""
        fake_model = MagicMock()
        fake_model.num_hidden_layers = 2
        fake_model.model_config.mla_config = MagicMock(mla_cls=DeepseekV4SparseAttention)
        fake_model.model_config.dtype = torch.bfloat16

        plain_attn = MagicMock(use_indexer=False, indexer=None)
        indexed_attn = MagicMock(use_indexer=True, _index_head_dim=128, indexer=None)
        decoder_layers = [MagicMock(self_attn=plain_attn), MagicMock(self_attn=indexed_attn)]
        fake_model.unwrap.return_value = MagicMock(layers=decoder_layers)

        cache_info = get_sparse_attention_indexer_cache_info(fake_model, num_blocks=4, block_size=16)

        caches_by_layer = cache_info["indexer_cache_by_layers"]
        assert 1 in caches_by_layer
        assert caches_by_layer[1].shape == (4, 16, 128)
        assert cache_info["indexer_cache_per_token"] > 0
class TestDeepseekV4KvCacheHelpers:
    """Mock-driven tests for the DeepSeek V4 model check and the KV / indexer cache helpers."""

    @staticmethod
    def _bf16_model(hf_model_type):
        """Return a mock model with bfloat16 dtype, the given HF ``model_type`` and no text config."""
        fake_model = MagicMock()
        fake_model.model_config.dtype = torch.bfloat16
        fake_model.model_config.hf_config = MagicMock(model_type=hf_model_type)
        fake_model.text_config = None
        return fake_model

    @pytest.mark.parametrize(
        ("hf_model_type", "text_model_type", "expected"),
        [
            ("deepseek_v4", None, True),
            (None, "deepseek_v4", True),
            ("deepseek_v32", "deepseek_v32", False),
            (None, None, False),
        ],
    )
    def test_is_v4_model(self, hf_model_type, text_model_type, expected):
        """``_is_v4_model`` is true when either the HF config or the text config says ``deepseek_v4``."""
        fake_model = MagicMock()
        hf_config = None if hf_model_type is None else MagicMock(model_type=hf_model_type)
        text_config = None if text_model_type is None else MagicMock(model_type=text_model_type)
        fake_model.model_config.hf_config = hf_config
        fake_model.text_config = text_config

        assert _is_v4_model(fake_model) is expected

    @patch("tensor_cast.core.input_generator.get_attention_quant_config")
    def test_resolve_main_kv_cache_dtype_v4_ignores_attention_quant(self, mock_get_attn_quant):
        """For ``deepseek_v4`` the model dtype is returned even though the quant config reports fp8."""
        quant_config = MagicMock(get_quant_dtype=lambda: torch.float8_e4m3fn)
        mock_get_attn_quant.return_value = quant_config
        fake_model = self._bf16_model("deepseek_v4")

        resolved = _resolve_main_kv_cache_dtype(fake_model, 0)

        assert resolved == torch.bfloat16

    @patch("tensor_cast.core.input_generator.get_attention_quant_config")
    def test_resolve_main_kv_cache_dtype_non_v4_uses_attention_quant(self, mock_get_attn_quant):
        """For ``deepseek_v32`` the quant config's fp8 dtype is returned."""
        quant_config = MagicMock(get_quant_dtype=lambda: torch.float8_e4m3fn)
        mock_get_attn_quant.return_value = quant_config
        fake_model = self._bf16_model("deepseek_v32")

        resolved = _resolve_main_kv_cache_dtype(fake_model, 0)

        assert resolved == torch.float8_e4m3fn

    @patch("tensor_cast.core.input_generator.get_attention_quant_config", return_value=None)
    def test_resolve_main_kv_cache_dtype_non_v4_fallback_to_model_dtype(self, _mock_get_attn_quant):
        """For ``deepseek_v32`` with no quant config, the model dtype is returned."""
        fake_model = self._bf16_model("deepseek_v32")

        resolved = _resolve_main_kv_cache_dtype(fake_model, 0)

        assert resolved == torch.bfloat16

    @patch("tensor_cast.core.input_generator.get_attention_quant_config")
    def test_resolve_indexer_cache_dtype_uses_attention_quant(self, mock_get_attn_quant):
        """The indexer cache dtype is the quant config's int8 dtype."""
        quant_config = MagicMock(get_quant_dtype=lambda: torch.int8)
        mock_get_attn_quant.return_value = quant_config
        fake_model = MagicMock()
        fake_model.model_config.dtype = torch.bfloat16

        resolved = _resolve_indexer_cache_dtype(fake_model, 0)

        assert resolved == torch.int8

    @patch(
        "tensor_cast.core.input_generator._resolve_sparse_attention_kv_cache_width",
        return_value=576,
    )
    def test_resolve_v4_kv_cache_size_compressed_sparse_layer(self, _mock_head_dim):
        """With batch and token info supplied, the resolved shape is ``[18, 128, 576]``."""
        fake_model = MagicMock()
        text_config = fake_model.text_config
        text_config.sliding_window = 128
        text_config.kv_lora_rank = 512
        text_config.qk_rope_head_dim = 64
        compressed_attn = MagicMock(compress_ratio=4, head_dim=576)

        cache_shape = _resolve_v4_kv_cache_size(
            fake_model,
            attention_layer=compressed_attn,
            num_blocks=100,
            block_size=128,
            batch_size=2,
            total_kv_tokens=8192,
        )

        assert cache_shape == [18, 128, 576]

    @patch(
        "tensor_cast.core.input_generator._resolve_sparse_attention_kv_cache_width",
        return_value=480,
    )
    def test_resolve_v4_kv_cache_size_fallback_without_batch_info(self, _mock_head_dim):
        """Without batch or token info, the shape uses the passed ``num_blocks``: ``[42, 128, 480]``."""
        fake_model = MagicMock()
        fake_model.text_config.sliding_window = 128
        compressed_attn = MagicMock(compress_ratio=4)

        cache_shape = _resolve_v4_kv_cache_size(
            fake_model,
            attention_layer=compressed_attn,
            num_blocks=42,
            block_size=128,
        )

        assert cache_shape == [42, 128, 480]

    @patch("tensor_cast.core.input_generator.get_attention_quant_config", return_value=None)
    def test_v4_indexer_cache_compression_only_for_v4_model(self, _mock_attn_quant):
        """For a ``deepseek_v32`` model the indexer cache keeps all 100 requested blocks."""
        fake_model = self._bf16_model("deepseek_v32")
        fake_model.num_hidden_layers = 1
        fake_model.model_config.mla_config = MagicMock(mla_cls=DeepseekV4SparseAttention)

        indexed_attn = MagicMock(
            use_indexer=True,
            _index_head_dim=128,
            indexer=None,
            compress_ratio=4,
        )
        fake_model.unwrap.return_value = MagicMock(layers=[MagicMock(self_attn=indexed_attn)])

        cache_info = get_sparse_attention_indexer_cache_info(
            fake_model,
            num_blocks=100,
            block_size=128,
            batch_size=2,
            total_kv_tokens=8192,
        )

        first_layer_cache = cache_info["indexer_cache_by_layers"][0]
        assert first_layer_cache.shape[0] == 100