import unittest
import pytest
import torch
from tensor_cast.core.input_generator import RequestInfo, generate_inputs
from tensor_cast.core.quantization.datatypes import QuantizeAttentionAction
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.device import TEST_DEVICE
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.runtime import Runtime
from .conftest import get_session_model
class TestDeepseekV32Model(unittest.TestCase):
    """Smoke test: build DeepSeek-V3.2 and run it under the analytic model."""

    def test_model_init(self):
        """Run one forward pass and check the expected ops were profiled.

        The model is configured with INT8 attention quantization and two MTP
        tokens. After the forward pass the runtime's averages table must
        contain both the quantized MLA op and the DSA indexer op, and the
        total execution time reported for the performance model must be
        positive.
        """
        token_count = 3500
        config = UserInputConfig(
            model_id="deepseek-ai/DeepSeek-V3.2",
            num_queries=1,
            query_len=token_count,
            context_length=token_count,
            device="TEST_DEVICE",
            num_mtp_tokens=2,
            quantize_attention_action=QuantizeAttentionAction.INT8,
        )
        model = get_session_model(config)

        # A single request whose query length equals its sequence length.
        request = RequestInfo(
            query_len=token_count,
            seq_len=token_count,
            concurrency=1,
            is_decode=True,
        )
        inputs = generate_inputs(model, [request])

        device_profile = TEST_DEVICE
        analytic_model = AnalyticPerformanceModel(device_profile)
        with Runtime(analytic_model, device_profile) as runtime, torch.no_grad():
            model.forward(**inputs)

        averages = runtime.table_averages()
        for op_name in (
            "tensor_cast.multihead_latent_attention_quant.default",
            "tensor_cast.dsa_indexer.default",
        ):
            self.assertIn(op_name, averages)

        # Total times are keyed by the performance model's name.
        elapsed_s = runtime.total_execution_time_s()[analytic_model.name]
        self.assertGreater(elapsed_s, 0)
@pytest.mark.nightly
class TestDeepseekV32ModelNightly(unittest.TestCase):
    """Nightly comparison of quantized MLA op time across DeepSeek versions."""

    def test_deepseek_v32_mla_performance(self):
        """Check that the quantized MLA op is faster on V3.2 than on V3.1.

        Both models are simulated with the same sequence length and
        configuration. Each must report a positive MLA time, and the V3.2
        time must be strictly lower than the V3.1 time.
        """

        def get_mla_time(model_id: str, seq_len: int) -> float:
            """Return the summed analytic time of all quantized MLA events.

            Args:
                model_id: Model identifier passed to ``UserInputConfig``.
                seq_len: Used as query length, context length and request
                    sequence length.

            Returns:
                Sum of ``execution_time_s`` over every recorded event whose
                op name contains ``multihead_latent_attention_quant``;
                ``0.0`` when no such event was recorded.
            """
            user_input = UserInputConfig(
                model_id=model_id,
                num_queries=1,
                query_len=seq_len,
                context_length=seq_len,
                device="TEST_DEVICE",
                num_mtp_tokens=2,
                quantize_attention_action=QuantizeAttentionAction.INT8,
            )
            model = get_session_model(user_input)
            inputs = generate_inputs(
                model,
                [
                    RequestInfo(
                        query_len=seq_len,
                        seq_len=seq_len,
                        concurrency=1,
                        is_decode=True,
                    )
                ],
            )
            machine_config = TEST_DEVICE
            perf_model = AnalyticPerformanceModel(machine_config)
            with Runtime(perf_model, machine_config) as runtime, torch.no_grad():
                model.forward(**inputs)

            total_time = 0.0
            for event in runtime.event_list:
                func_name = str(event.op_invoke_info.func)
                if "multihead_latent_attention_quant" in func_name:
                    # Accumulate instead of overwriting, so every matching
                    # event contributes; a plain assignment would report only
                    # the last matching event.
                    total_time += event.perf_results.get("analytic").execution_time_s
            return total_time

        seq_len = 3500
        time_v31 = get_mla_time("deepseek-ai/DeepSeek-V3.1", seq_len)
        time_v32 = get_mla_time("deepseek-ai/DeepSeek-V3.2", seq_len)
        self.assertGreater(time_v31, 0)
        self.assertGreater(time_v32, 0)
        self.assertLess(time_v32, time_v31)