import unittest
import pytest
import torch
from parameterized import parameterized
from tensor_cast.compilation import get_backend
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.device import TEST_DEVICE
from tensor_cast.layers.attention import AttentionTensorCast
from tensor_cast.layers.internal import CopyLayerWrapper, RegionMarkerWrapper
from tensor_cast.layers.sampler import SamplingMetadata
from tensor_cast.model_config import ModelConfig, ParallelConfig, QuantConfig
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.performance_model.memory_tracker import MemoryTracker
from tensor_cast.runtime import Runtime
from tensor_cast.transformers.custom_model_registry import get_mtp_block_module_name
from tensor_cast.transformers.model import TransformerModel
from .conftest import get_session_hf_config
from .test_common import (
assert_close,
create_mla_metadata_and_kv_cache,
get_cached_build_model,
has_submodule_with_cls_name,
)
class RepetitionTestMixin:
    """Shared fixtures and assertions for layer-repetition tests.

    Intended to be mixed into ``unittest.TestCase`` subclasses. Provides a
    per-class cache of ``TransformerModel`` instances and helpers that inspect
    a model's layer list for ``CopyLayerWrapper`` / ``RegionMarkerWrapper``
    entries.
    """

    # Lives on the mixin class itself, so it is shared by every subclass;
    # callers reach it as ``RepetitionTestMixin._model_cache``.
    _model_cache: dict = {}

    @classmethod
    def setUpClass(cls):
        """Create a fresh ``TransformerModel`` cache for this test class."""
        # Cooperate with the rest of the MRO (e.g. unittest.TestCase).
        super().setUpClass()
        cls._transformer_cache = {}

    @classmethod
    def _get_transformer_model(cls, model_id: str, model_config: ModelConfig) -> TransformerModel:
        """Return a cached ``TransformerModel`` for ``model_id`` and ``model_config``.

        The cache key uses ``repr(model_config)``, so configs with identical
        reprs share one model instance. Requires ``setUpClass`` to have run.
        """
        key = (model_id, repr(model_config))
        if key not in cls._transformer_cache:
            cls._transformer_cache[key] = TransformerModel(model_id, model_config)
        return cls._transformer_cache[key]

    def setUp(self):
        """Clear torch.compile caches so each test starts from a clean state."""
        super().setUp()
        torch.compiler.reset()

    def check_num_effective_layers(self, layers, expected_num):
        """Assert that exactly ``expected_num`` entries of ``layers`` are not ``CopyLayerWrapper``s."""
        count = sum(1 for layer in layers if not isinstance(layer, CopyLayerWrapper))
        self.assertEqual(count, expected_num, f"{layers}")

    def check_representative_layers(self, layers, expected_repeat_counts):
        """Assert the ``RegionMarkerWrapper`` entries of ``layers`` have the given repeat counts, in order."""
        region_layers = [layer for layer in layers if isinstance(layer, RegionMarkerWrapper)]
        self.assertEqual(len(region_layers), len(expected_repeat_counts), f"{layers}")
        for index, (layer, expected_repeat_count) in enumerate(zip(region_layers, expected_repeat_counts)):
            self.assertEqual(layer.repeat_count, expected_repeat_count, f"region marker #{index}")

    def check_copy_layers_hidden(self, layers):
        """Assert ``layers`` has at least one ``CopyLayerWrapper`` and none of them expose child modules."""
        copy_layers = [layer for layer in layers if isinstance(layer, CopyLayerWrapper)]
        self.assertTrue(copy_layers, f"{layers}")
        for layer in copy_layers:
            self.assertEqual(list(layer.children()), [])
            self.assertEqual(list(layer.named_children()), [])
class RepetitionTestCase(RepetitionTestMixin, unittest.TestCase):
    """Compare models built with and without layer repetition.

    The ``_run_test_*`` helpers are also called unbound from
    ``RepetitionNightlyTestCase`` with that class's instance as ``self``. They
    must therefore only use attributes provided by ``RepetitionTestMixin`` /
    ``unittest.TestCase``, and must not use other methods of this class.
    """

    def _run_test_vanilla_transformer_model(self, model_id, do_compile):
        """Check that enabling repetition preserves event count, cost and peak memory.

        Builds ``model_id`` twice with a ``num_layers``-layer override, once
        plain and once with ``enable_repetition=True``. It then checks the
        repeated model's layer layout and, if ``do_compile``, wraps both models
        in ``torch.compile``. Finally it runs one forward pass of each under a
        fresh ``Runtime`` and compares the two runtimes.
        """
        num_tokens = 100
        num_layers = 3

        # Local helper (not a method) because ``self`` may be a
        # RepetitionNightlyTestCase instance.
        def make_config(**extra_kwargs):
            config = ModelConfig(
                ParallelConfig(),
                QuantConfig(),
                attention_cls=AttentionTensorCast,
                num_hidden_layers_override=num_layers,
                **extra_kwargs,
            )
            config.hf_config = get_session_hf_config(model_id)
            return config

        model_config = make_config()
        model_config_with_repeats = make_config(enable_repetition=True)
        model = self._get_transformer_model(model_id, model_config)
        model_with_repeats = self._get_transformer_model(model_id, model_config_with_repeats)

        # Expect one representative layer marked as repeated num_layers times;
        # every other entry is a CopyLayerWrapper.
        repeated_layers = model_with_repeats.unwrap().layers
        self.check_num_effective_layers(repeated_layers, 1)
        self.assertEqual(len(repeated_layers), num_layers)
        self.check_representative_layers(repeated_layers, [num_layers])
        self.check_copy_layers_hidden(repeated_layers)

        if do_compile:
            model = torch.compile(model, backend=get_backend(), dynamic=True, fullgraph=True)
            model_with_repeats = torch.compile(model_with_repeats, backend=get_backend(), dynamic=True, fullgraph=True)

        inputs = torch.empty([2, num_tokens], dtype=torch.long, device="meta")
        position_ids = torch.empty([2, num_tokens], dtype=torch.long, device="meta")
        device_profile = TEST_DEVICE
        perf_model = AnalyticPerformanceModel(device_profile)

        def run_forward(target):
            """Run one forward pass of ``target`` under a fresh Runtime and return that Runtime."""
            with (
                Runtime(perf_model, device_profile, MemoryTracker(device_profile)) as rt,
                torch.no_grad(),
            ):
                outputs = target.forward(inputs, position_ids)
                self.assertEqual(outputs.shape, (2, num_tokens, target.vocab_size))
            return rt

        runtime = run_forward(model)
        runtime_with_repeats = run_forward(model_with_repeats)

        # Eager runs must match exactly; compiled runs get a small relative tolerance.
        assert_close(
            self,
            len(runtime.event_list),
            len(runtime_with_repeats.event_list),
            rtol=0.027 if do_compile else 0,
        )
        assert_close(
            self,
            runtime.total_execution_time_s()[perf_model.name],
            runtime_with_repeats.total_execution_time_s()[perf_model.name],
            rtol=0.01 if do_compile else 0,
        )
        assert_close(
            self,
            runtime.memory_tracker.peak_mem_usage(),
            runtime_with_repeats.memory_tracker.peak_mem_usage(),
        )

    def _run_test_deepseek_with_kvcache(self, model_id):
        """Check the repetition layout and an MLA forward pass with KV caches.

        Builds ``model_id`` with ``num_mtp_tokens`` MTP tokens through the
        shared model cache and checks the main and MTP layer layouts. It then
        runs one forward pass with MLA attention metadata and KV caches, and
        looks for the MLA op in ``runtime.table_averages()``.
        """
        num_mtp_tokens = 3
        user_config = UserInputConfig(
            model_id=model_id,
            num_mtp_tokens=num_mtp_tokens,
        )
        model = get_cached_build_model(RepetitionTestMixin._model_cache, user_config)
        self.assertIsNotNone(get_mtp_block_module_name(model.model_config.hf_config.model_type))

        layers = model.unwrap().layers
        num_hidden_layers = model.text_config.num_hidden_layers
        # Exactly two non-copy layers are expected in the main stack.
        self.check_num_effective_layers(layers, 2)
        self.assertEqual(len(layers), num_hidden_layers)
        self.check_copy_layers_hidden(layers)
        if model_id == "deepseek-ai/DeepSeek-V3.1":
            self.check_representative_layers(layers, [3, 58])
        else:
            # For other models only require the repeat counts to cover every layer.
            self.assertEqual(
                sum(layer.repeat_count for layer in layers if isinstance(layer, RegionMarkerWrapper)),
                num_hidden_layers,
            )
        self.check_num_effective_layers(model._inner.mtp.layers, 1)

        attn_meta, kv_cache_by_layers, num_tokens = create_mla_metadata_and_kv_cache(model, model.model_config)
        self.assertTrue(has_submodule_with_cls_name(model, "MultiheadLatentAttentionTensorCast"))
        inputs = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        position_ids = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        device_profile = TEST_DEVICE
        perf_model = AnalyticPerformanceModel(device_profile)
        with (
            Runtime(perf_model, device_profile) as runtime,
            torch.no_grad(),
        ):
            outputs = model.forward(
                inputs,
                position_ids,
                attention_meta=attn_meta,
                kv_cache_by_layers=kv_cache_by_layers,
                sampling_metadata=SamplingMetadata(
                    query_start_loc=attn_meta.query_start_loc,
                    selected_token_indices=None,
                ),
            )
            self.assertEqual(outputs.shape, (2, num_mtp_tokens + 1))

        result = runtime.table_averages()
        # Require a line that starts with the MLA op name and ends with "64".
        # NOTE(review): the meaning of "64" (e.g. a call count) is not visible here — confirm.
        start_str = "tensor_cast.multihead_latent_attention.default"
        end_str = "64"
        found = any(
            stripped.startswith(start_str) and stripped.endswith(end_str)
            for stripped in map(str.strip, result.splitlines())
        )
        self.assertTrue(found, result)

    @parameterized.expand(
        [
            ["Qwen/Qwen3-32B", False],
        ]
    )
    def test_vanilla_transformer_model(self, model_id, do_compile):
        """Repetition must not change events, cost or peak memory (eager)."""
        self._run_test_vanilla_transformer_model(model_id, do_compile)

    @parameterized.expand(
        [
            ["deepseek-ai/DeepSeek-V3.1"],
        ]
    )
    def test_deepseek_with_kvcache(self, model_id):
        """Repetition layout and MLA forward pass with KV caches."""
        self._run_test_deepseek_with_kvcache(model_id)
@pytest.mark.nightly
class RepetitionNightlyTestCase(RepetitionTestMixin, unittest.TestCase):
    """Slower repetition checks, selected through the ``nightly`` pytest marker.

    Reuses the helper functions defined on ``RepetitionTestCase`` by calling
    them unbound with this test case as ``self``. The vanilla check here always
    runs with ``torch.compile`` enabled.
    """

    @parameterized.expand([("Qwen/Qwen3-32B",)])
    def test_vanilla_transformer_model(self, model_id):
        """Compare repeated vs. non-repeated models of ``model_id`` under torch.compile."""
        vanilla_check = RepetitionTestCase._run_test_vanilla_transformer_model
        vanilla_check(self, model_id, do_compile=True)

    @parameterized.expand([("moonshotai/Kimi-K2-Base",)])
    def test_deepseek_with_kvcache(self, model_id):
        """Run the MLA / KV-cache repetition check on ``model_id``."""
        kvcache_check = RepetitionTestCase._run_test_deepseek_with_kvcache
        kvcache_check(self, model_id)