import unittest
from itertools import product
import torch
from parameterized import parameterized
from tensor_cast.core.quantization.datatypes import QuantizeLinearAction
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.device import (
TEST_DEVICE,
CommGrid,
DeviceProfile,
InterconnectTopology,
)
from tensor_cast.layers.attention import AttentionTensorCast
from tensor_cast.layers.sampler import SamplingMetadata
from tensor_cast.model_config import (
AttentionQuantConfig,
ModelConfig,
MultiheadLatentAttentionQuantConfig,
ParallelConfig,
QuantConfig,
)
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.performance_model.bound_analyzer import StatsKey
from tensor_cast.quantize_utils import AttentionQuantType, LinearQuantType
from tensor_cast.runtime import Runtime
from tensor_cast.transformers.custom_model_registry import get_moe_config, get_mtp_block_module_name
from tensor_cast.transformers.model import TransformerModel
from tensor_cast.utils import DTYPE_FP8, is_fp8_dtype, performance_dtype
from .conftest import get_session_hf_config
from .test_common import (
create_attn_metadata_and_kv_cache,
create_mla_metadata_and_kv_cache,
get_cached_build_model,
has_submodule_with_cls_name,
)
def get_quant_config(
    start_layer_id=-1,
    end_layer_id=-1,
    attn_quant_type: AttentionQuantType = AttentionQuantType.INT8,
):
    """Return a ``QuantConfig`` carrying a unit-scale attention quant config.

    A single ``AttentionQuantConfig`` (query, KV and attention-prob scales all
    1.0) is shared by every entry written below.

    Args:
        start_layer_id: First layer id (inclusive) assigned the config.
        end_layer_id: Layer id (exclusive) where per-layer assignment stops.
        attn_quant_type: Attention quantization type to configure.

    Returns:
        The populated ``QuantConfig``. If either bound is -1 the config is
        stored under key -1; independently, it is stored for every id in
        ``range(start_layer_id, end_layer_id)``.
    """
    result = QuantConfig()
    attn_config = AttentionQuantConfig(
        quant_type=attn_quant_type,
        query_scale=torch.tensor(1.0),
        kv_scale=torch.tensor(1.0),
        attention_prob_scale=torch.tensor(1.0),
    )
    # Key -1 is written whenever either bound is left at its -1 default.
    if -1 in (start_layer_id, end_layer_id):
        result.attention_configs[-1] = attn_config
    for layer_idx in range(start_layer_id, end_layer_id):
        result.attention_configs[layer_idx] = attn_config
    return result
def get_mla_quant_config(
    start_layer_id=-1,
    end_layer_id=-1,
    attn_quant_type: AttentionQuantType = AttentionQuantType.INT8,
):
    """Return a W8A8 ``QuantConfig`` with a unit-scale MLA attention quant config.

    Starts from the shared W8A8 linear quant config in ``test_common`` and adds
    one ``MultiheadLatentAttentionQuantConfig`` whose scales are all 1.0.

    Args:
        start_layer_id: First layer id (inclusive) assigned the MLA config.
        end_layer_id: Layer id (exclusive) where per-layer assignment stops.
        attn_quant_type: Attention quantization type for the MLA config.
            Defaults to INT8 (the value previously hard-coded here), matching
            the ``attn_quant_type`` parameter of ``get_quant_config``.

    Returns:
        The populated ``QuantConfig``. If either bound is -1 the MLA config is
        stored under key -1; independently, it is stored for every id in
        ``range(start_layer_id, end_layer_id)``.
    """
    # Aliased because this module defines its own, differently-shaped
    # ``get_quant_config``.
    from .test_common import get_quant_config as get_quant_config_common

    quant_config = get_quant_config_common(quant_type=LinearQuantType.W8A8)
    config = MultiheadLatentAttentionQuantConfig(
        quant_type=attn_quant_type,
        query_scale=torch.tensor(1.0),
        kv_scale=torch.tensor(1.0),
        attention_prob_scale=torch.tensor(1.0),
        kv_projected_scale=torch.tensor(1.0),
        qk_scale=torch.tensor(1.0),
        v_scale=torch.tensor(1.0),
        out_scale=torch.tensor(1.0),
    )
    # Key -1 is written whenever either bound is left at its -1 default.
    if start_layer_id == -1 or end_layer_id == -1:
        quant_config.attention_configs[-1] = config
    for i in range(start_layer_id, end_layer_id):
        quant_config.attention_configs[i] = config
    return quant_config
class TestQuantAttention(unittest.TestCase):
    """Tests for quantized attention ops and FP8 device-profile handling."""

    # Attention quantization types exercised by the parameterized model tests.
    QUANT_TYPES = [AttentionQuantType.INT8, AttentionQuantType.FP8]
    # MLA quant op name as it appears in recorded events and table averages.
    MLA_QUANT_OP = "multihead_latent_attention_quant.default"

    @classmethod
    def setUpClass(cls):
        """Create the model caches shared by every test in this class."""
        super().setUpClass()
        cls._model_cache = {}
        cls._transformer_cache = {}

    @classmethod
    def _get_transformer_model(cls, model_id: str, model_config: ModelConfig) -> TransformerModel:
        """Return a cached ``TransformerModel`` for ``(model_id, model_config)``.

        The cache key uses ``repr(model_config)``; this assumes the repr reflects
        every field that affects model construction (TODO confirm).
        """
        key = (model_id, repr(model_config))
        if key not in cls._transformer_cache:
            cls._transformer_cache[key] = TransformerModel(model_id, model_config)
        return cls._transformer_cache[key]

    def test_all_torch_float8_dtypes_share_fp8_performance_dtype(self):
        """Every ``torch.float8*`` dtype is recognized as FP8 and maps to ``DTYPE_FP8``."""
        fp8_dtypes = [
            dtype for name, dtype in vars(torch).items() if name.startswith("float8") and isinstance(dtype, torch.dtype)
        ]
        self.assertGreater(len(fp8_dtypes), 0)
        for dtype in fp8_dtypes:
            # subTest reports exactly which float8 variant failed.
            with self.subTest(dtype=dtype):
                self.assertTrue(is_fp8_dtype(dtype))
                self.assertEqual(performance_dtype(dtype), DTYPE_FP8)

    def assert_mma_ops_time_positive(self, runtime, op_name):
        """Assert that events matching ``op_name`` report a positive MMA statistic.

        Sums ``StatsKey.MMA_OPS`` over every perf result of every recorded event
        whose ``str(event.op_invoke_info.func)`` contains ``op_name``.

        Args:
            runtime: The ``Runtime`` whose ``event_list`` is inspected.
            op_name: Substring identifying the op. Because this is a substring
                match, "attention_quant.default" also matches
                "multihead_latent_attention_quant.default".
        """
        matched_events = 0
        total_mma_ops_time_s = 0
        for event in runtime.event_list:
            if op_name not in str(event.op_invoke_info.func):
                continue
            matched_events += 1
            for result in event.perf_results.values():
                total_mma_ops_time_s += result.statistics.get(StatsKey.MMA_OPS, 0)
        # Distinguish "op never ran" from "op ran but reported no MMA time".
        self.assertGreater(matched_events, 0, f"no recorded events matched {op_name!r}")
        self.assertGreater(total_mma_ops_time_s, 0, f"MMA ops time for {op_name!r} is not positive")

    def _run_fp8_mla_quant(self, device):
        """Invoke ``multihead_latent_attention_quant`` once on FP8 inputs.

        Args:
            device: Device profile used for both the analytic performance model
                and the ``Runtime``.

        Returns:
            The ``Runtime`` that recorded the call, for event inspection.
        """
        # Two requests with one query token each (query_lens == [1, 1]) and 32
        # total tokens each; "meta" tensors carry shape/dtype but no data.
        q = torch.empty((2, 2, 16), dtype=torch.float8_e4m3fn, device="meta")
        kv_cache = torch.empty((2, 32, 12), dtype=torch.float8_e4m3fn, device="meta")
        block_table = torch.empty((2, 1), dtype=torch.int32, device="meta")
        query_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32)
        request_total_seq_lens = torch.tensor([32, 32], dtype=torch.int32)
        query_lens = torch.tensor([1, 1], dtype=torch.int32)
        W_UK_T = torch.empty((2, 12, 8), dtype=torch.bfloat16, device="meta")
        W_UV = torch.empty((2, 8, 16), dtype=torch.bfloat16, device="meta")
        scale = torch.tensor(1.0)
        perf_model = AnalyticPerformanceModel(device)
        with Runtime(perf_model, device) as runtime, torch.no_grad():
            torch.ops.tensor_cast.multihead_latent_attention_quant(
                q,
                kv_cache,
                block_table,
                query_start_loc,
                request_total_seq_lens,
                query_lens,
                W_UK_T,
                W_UV,
                None,
                16,
                None,
                None,
                # Eight (scale, None) argument pairs, all with unit scale.
                scale,
                None,
                scale,
                None,
                scale,
                None,
                scale,
                None,
                scale,
                None,
                scale,
                None,
                scale,
                None,
                scale,
                None,
                torch.bfloat16,
            )
        return runtime

    def test_fp8_mla_quant_mma_ops_time_is_nonzero(self):
        """FP8 MLA quant on the default test device reports positive MMA time."""
        runtime = self._run_fp8_mla_quant(TEST_DEVICE)
        self.assert_mma_ops_time_positive(runtime, self.MLA_QUANT_OP)

    def test_fp8_mla_quant_uses_custom_device_fp8_variant(self):
        """A device declaring only e5m2 MMA throughput still costs e4m3fn MLA inputs."""
        device = DeviceProfile(
            name=f"TEST_DEVICE_CUSTOM_FP8_E5M2_{id(self)}",
            vendor="TEST_VENDOR",
            mma_ops={torch.float8_e5m2: 100 * 1e12},
            gp_ops={torch.float32: 10 * 1e12, torch.half: 10 * 1e12},
            memory_size_bytes=64 * (1024**3),
            memory_bandwidth_bytes_ps=1.6 * (1024**4),
            comm_grid=CommGrid(
                grid=torch.arange(2),
                topologies={0: InterconnectTopology(1e9, 1e-6)},
            ),
        )
        # The e5m2 entry is stored under the shared DTYPE_FP8 key.
        self.assertEqual(device.mma_ops, {DTYPE_FP8: 100 * 1e12})
        runtime = self._run_fp8_mla_quant(device)
        self.assert_mma_ops_time_positive(runtime, self.MLA_QUANT_OP)

    def test_device_profile_rejects_mismatched_fp8_perf_values(self):
        """``DeviceProfile`` raises when FP8 variants declare different MMA throughput."""
        with self.assertRaisesRegex(ValueError, "FP8 variants must share the same performance value"):
            DeviceProfile(
                name=f"TEST_DEVICE_FP8_CONFLICT_{id(self)}",
                vendor="TEST_VENDOR",
                mma_ops={
                    torch.float8_e4m3fn: 100 * 1e12,
                    torch.float8_e5m2: 120 * 1e12,
                },
                gp_ops={torch.float32: 10 * 1e12},
                memory_size_bytes=64 * (1024**3),
                memory_bandwidth_bytes_ps=1.6 * (1024**4),
                comm_grid=CommGrid(
                    grid=torch.arange(2),
                    topologies={0: InterconnectTopology(1e9, 1e-6)},
                ),
            )

    @parameterized.expand(
        list(
            product(
                ["Qwen/Qwen3-32B", "Qwen/Qwen3-235B-A22B", "zai-org/GLM-4.5"],
                QUANT_TYPES,
            )
        )
    )
    def test_standard_attention(self, model_id, attn_quant_type):
        """Run a 2-layer model with layer 0's attention quantized; check quant ops ran."""
        kv_quant_start_idx = 0
        kv_quant_end_idx = 1
        hf_config = get_session_hf_config(model_id)
        moe_config = get_moe_config(hf_config.model_type)
        model_config = ModelConfig(
            ParallelConfig(),
            get_quant_config(kv_quant_start_idx, kv_quant_end_idx, attn_quant_type),
            attention_cls=AttentionTensorCast,
            num_hidden_layers_override=2,
            moe_config=moe_config,
            hf_config=hf_config,
        )
        model = self._get_transformer_model(model_id, model_config)
        attn_meta, kv_cache_by_layers, num_tokens = create_attn_metadata_and_kv_cache(model, model_config)
        inputs = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        position_ids = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        machine_config = TEST_DEVICE
        perf_model = AnalyticPerformanceModel(machine_config)
        with Runtime(perf_model, machine_config) as runtime, torch.no_grad():
            outputs = model.forward(
                inputs,
                position_ids,
                attention_meta=attn_meta,
                kv_cache_by_layers=kv_cache_by_layers,
            )
        self.assertEqual(outputs.shape, (1, num_tokens, model.vocab_size))
        result = runtime.table_averages()
        self.assertIn("quantize.default", result)
        self.assertIn("reshape_and_cache.default", result)
        self.assertIn("attention_quant.default", result)
        self.assert_mma_ops_time_positive(runtime, "attention_quant.default")

    @parameterized.expand(list(product(["deepseek-ai/DeepSeek-V3.1"], QUANT_TYPES)))
    def test_mla(self, model_id, attn_quant_type):
        """Run a W8A8-static MLA model with one MTP token; check MLA quant ops ran."""
        num_mtp_layers = 1
        user_config = UserInputConfig(
            model_id=model_id,
            num_mtp_tokens=num_mtp_layers,
            quantize_linear_action=QuantizeLinearAction.W8A8_STATIC,
            quantize_attention_action=attn_quant_type,
        )
        model = get_cached_build_model(self._model_cache, user_config)
        mtp_block_module_name = get_mtp_block_module_name(model.model_config.hf_config.model_type)
        self.assertIsNotNone(mtp_block_module_name)
        attn_meta, kv_cache_by_layers, num_tokens = create_mla_metadata_and_kv_cache(model, model.model_config)
        self.assertTrue(has_submodule_with_cls_name(model, "MultiheadLatentAttentionTensorCast"))
        inputs = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        position_ids = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        machine_config = TEST_DEVICE
        perf_model = AnalyticPerformanceModel(machine_config)
        with Runtime(perf_model, machine_config) as runtime, torch.no_grad():
            outputs = model.forward(
                inputs,
                position_ids,
                attention_meta=attn_meta,
                kv_cache_by_layers=kv_cache_by_layers,
                sampling_metadata=SamplingMetadata(),
            )
        # Presumably one sampled token for the main step plus one per MTP token.
        self.assertEqual(outputs.shape, (1, num_mtp_layers + 1))
        result = runtime.table_averages()
        self.assertIn("quantize.default", result)
        self.assertIn("concat_and_cache_mla.default", result)
        self.assertIn(self.MLA_QUANT_OP, result)
        self.assert_mma_ops_time_positive(runtime, self.MLA_QUANT_OP)