from types import SimpleNamespace
import pytest
import torch
from tensor_cast.layers.mtp import MtpWrapper
from tensor_cast.layers.sampler import (
Sampler,
SamplingMetadata,
SpecDecodeMetadata,
select_lm_head_hidden_states,
)
from tensor_cast.model_config import MtpConfig
from tensor_cast.transformers.model import CausalLmWrapper
def _spec_metadata(logits_indices=None, num_active_requests=2, num_speculative_tokens=2):
    """Build the ``SpecDecodeMetadata`` fixture shared by the tests in this module.

    Args:
        logits_indices: Sequence of row indices; it is converted to a
            ``torch.long`` tensor. ``None`` selects ``[2, 3, 4, 5, 6, 7]``.
        num_active_requests: Passed through unchanged.
        num_speculative_tokens: Passed through unchanged.

    Returns:
        A ``SpecDecodeMetadata`` built from the three values above.
    """
    # The default holds six indices, which equals
    # num_active_requests * (num_speculative_tokens + 1) for the default
    # arguments -- presumably the length the selection code expects (confirm
    # against tensor_cast.layers.sampler).
    rows = [2, 3, 4, 5, 6, 7] if logits_indices is None else logits_indices
    index_tensor = torch.tensor(rows, dtype=torch.long)
    return SpecDecodeMetadata(
        logits_indices=index_tensor,
        num_active_requests=num_active_requests,
        num_speculative_tokens=num_speculative_tokens,
    )
def _lm_head_weight():
    """Return the fixed 3x2 LM-head weight matrix used by the tests.

    Row 0 picks the first hidden feature, row 1 the second, and row 2 their
    sum, so projected logits are easy to work out by hand.
    """
    weight_rows = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    return torch.tensor(weight_rows)
def _project(hidden_states):
    """Multiply ``hidden_states`` by the transposed fixture LM-head weight.

    Returns the matrix product ``hidden_states @ W.T`` where ``W`` is the
    tensor produced by ``_lm_head_weight()``.
    """
    weight = _lm_head_weight()
    return hidden_states @ weight.T
def test_spec_decode_selects_target_and_proposal_rows():
    """Check which hidden-state rows are selected in target vs. proposal mode.

    Hidden states are ``arange(16)`` viewed as ``(1, 8, 2)``, so row ``r`` of
    the flattened tensor is ``[2r, 2r + 1]``. With the default fixture indices
    ``[2..7]``, target mode must return rows 2-7 and proposal mode must return
    rows 4 and 7.
    """
    num_tokens, hidden_size = 8, 2
    metadata = SamplingMetadata(spec_decode_metadata=_spec_metadata())
    states = torch.arange(num_tokens * hidden_size, dtype=torch.float32).view(1, num_tokens, hidden_size)

    target_rows = select_lm_head_hidden_states(states, metadata, mode="target")
    proposal_rows = select_lm_head_hidden_states(states, metadata, mode="proposal")

    expected_target = [
        [4.0, 5.0],
        [6.0, 7.0],
        [8.0, 9.0],
        [10.0, 11.0],
        [12.0, 13.0],
        [14.0, 15.0],
    ]
    expected_proposal = [[8.0, 9.0], [14.0, 15.0]]
    assert target_rows.tolist() == expected_target
    assert proposal_rows.tolist() == expected_proposal
def test_lm_head_selection_ignores_default_selected_token_sentinel():
    """A default-constructed ``SamplingMetadata`` must leave hidden states untouched.

    The check is on identity (``is``), not equality: the very same tensor
    object has to come back.
    """
    states = torch.arange(2 * 3 * 4, dtype=torch.float32).view(2, 3, 4)
    default_metadata = SamplingMetadata()
    result = select_lm_head_hidden_states(states, default_metadata)
    assert result is states
def test_spec_decode_selection_rejects_wrong_logits_indices_length():
    """Selection must raise ``ValueError`` when ``logits_indices`` is too short.

    Only two indices are supplied for two active requests with two
    speculative tokens each; the error message must match
    ``"logits_indices length must equal"``.
    """
    short_spec = _spec_metadata(logits_indices=[2, 4], num_active_requests=2, num_speculative_tokens=2)
    metadata = SamplingMetadata(spec_decode_metadata=short_spec)
    states = torch.arange(8 * 2, dtype=torch.float32).view(1, 8, 2)

    # Only the selection call is inside the context manager, so an error
    # raised while building the fixtures would not satisfy the test.
    with pytest.raises(ValueError, match="logits_indices length must equal"):
        select_lm_head_hidden_states(states, metadata, mode="target")
def test_spec_decode_sampler_returns_target_tokens_plus_bonus_token():
    """The sampler must return three tokens per request in spec-decode mode.

    Row ``i`` of the logits has a single dominant entry at token ``i + 1``,
    so the expected output is those six token ids grouped as two rows of
    three (two requests in the default fixture).
    """
    winning_tokens = [1, 2, 3, 4, 5, 6]
    logits = torch.zeros(1, 6, 8)
    for position, token in enumerate(winning_tokens):
        # Make one vocabulary entry per row overwhelmingly likely.
        logits[0, position, token] = 100.0 + position

    spec = _spec_metadata(logits_indices=list(range(6)))
    metadata = SamplingMetadata(spec_decode_metadata=spec)
    sampled = Sampler()(logits, metadata)
    assert sampled.tolist() == [[1, 2, 3], [4, 5, 6]]
class _FixedCausalInner(torch.nn.Module):
    """Stand-in inner model that always yields a preset hidden-state tensor."""

    def __init__(self, hidden_states):
        """Store ``hidden_states`` so ``forward`` can hand it back unchanged."""
        super().__init__()
        self.hidden_states = hidden_states

    def forward(self, **kwargs):
        """Ignore every keyword argument and return a 1-tuple with the stored tensor."""
        outputs = (self.hidden_states,)
        return outputs
def test_causal_lm_wrapper_projects_spec_decode_verification_rows():
    """``CausalLmWrapper`` must project only the rows named by ``logits_indices``.

    The inner model is replaced with ``_FixedCausalInner`` and the LM-head
    weight is overwritten with the fixture matrix, so the expected logits can
    be recomputed here with ``_project``.
    """
    hidden_size, vocab_size = 2, 3
    inner_output = torch.arange(8 * hidden_size, dtype=torch.float32).view(1, 8, hidden_size)
    config = SimpleNamespace(hidden_size=hidden_size, vocab_size=vocab_size)
    wrapper = CausalLmWrapper(config, _FixedCausalInner(inner_output))
    with torch.no_grad():
        # Replace the LM-head weight in place with known values.
        wrapper.lm_head.weight.copy_(_lm_head_weight())

    metadata = SamplingMetadata(spec_decode_metadata=_spec_metadata())
    positions = torch.arange(8, dtype=torch.long).view(1, 8)
    logits = wrapper(input_ids=None, position_ids=positions, sampling_metadata=metadata)

    # Recompute the expected logits from a fresh copy of the fixture indices.
    flat_states = inner_output.view(-1, hidden_size)
    expected_rows = torch.index_select(flat_states, 0, _spec_metadata().logits_indices)
    assert logits.tolist() == _project(expected_rows).tolist()
class _UnusedMtpBlock(torch.nn.Module):
    """Module whose ``forward`` must never run; calling it fails the test."""

    def forward(self, *args, **kwargs):
        """Raise ``AssertionError`` unconditionally, whatever the arguments."""
        message = "test replaces the generated MTP layer"
        raise AssertionError(message)
class _FixedMtpLayer(torch.nn.Module):
    """Stand-in MTP layer that always returns a preset hidden-state tensor."""

    def __init__(self, hidden_states):
        """Store ``hidden_states`` for ``forward`` to return."""
        super().__init__()
        self.hidden_states = hidden_states

    def forward(self, *args, **kwargs):
        """Discard all inputs and return the stored tensor itself (not a copy)."""
        preset = self.hidden_states
        return preset
class _FakeMtpBlock(torch.nn.Module):
    """Identity block; its class name is the one passed as ``mtp_block_module_name``."""

    def __init__(self, hf_config, layer_idx=None):
        """Accept and ignore ``hf_config`` and ``layer_idx``.

        The parameters are unused here; presumably they mirror the constructor
        signature that ``MtpWrapper`` uses for real blocks (confirm in
        tensor_cast.layers.mtp).
        """
        super().__init__()

    def forward(self, hidden_states, *args, **kwargs):
        """Return ``hidden_states`` unchanged, ignoring any extra arguments."""
        passthrough = hidden_states
        return passthrough
class _FakeRotaryEmbedding(torch.nn.Module):
    """Placeholder rotary embedding that produces an uninitialised tensor."""

    def forward(self, hidden_states, position_ids):
        """Return ``torch.empty_like(hidden_states)``; ``position_ids`` is ignored.

        The returned values are uninitialised memory, so only the result's
        shape and dtype are meaningful.
        """
        placeholder = torch.empty_like(hidden_states)
        return placeholder
class _FixedMtpInner(torch.nn.Module):
    """Stand-in inner model for ``MtpWrapper`` that returns preset outputs.

    It exposes a ``block`` child (a ``_FakeMtpBlock``) and a ``layer`` child
    carrying a ``rotary_emb`` submodule; presumably ``MtpWrapper`` looks these
    up on the wrapped model (confirm in tensor_cast.layers.mtp).
    """

    def __init__(self, logits, hidden_states, hf_config):
        """Store the preset ``logits`` and ``hidden_states`` and attach the fake submodules."""
        super().__init__()
        self.logits = logits
        self.hidden_states = hidden_states

        rotary_holder = torch.nn.Module()
        rotary_holder.rotary_emb = _FakeRotaryEmbedding()
        # Children are registered as ``block`` first, then ``layer``.
        self.block = _FakeMtpBlock(hf_config)
        self.layer = rotary_holder

    def forward(self, input_ids, position_ids, inputs_embeds, **kwargs):
        """Return ``(logits, hidden_states)`` as stored at construction time.

        Fails if ``output_intermediate_hidden_states`` is missing from the
        keyword arguments (``KeyError``) or falsy (``AssertionError``).
        """
        wants_hidden_states = kwargs["output_intermediate_hidden_states"]
        assert wants_hidden_states
        return self.logits, self.hidden_states
def _mtp_wrapper(logits, hidden_states, hf_config):
    """Build an ``MtpWrapper`` around ``_FixedMtpInner`` with a known LM head.

    Args:
        logits: Tensor the fake inner model returns as its logits.
        hidden_states: Tensor the fake inner model returns as hidden states.
        hf_config: Config object passed to both ``MtpWrapper`` and the inner model.

    Returns:
        The wrapper, with ``wrapper.mtp.lm_head.weight`` overwritten in place
        by ``_lm_head_weight()``.
    """
    mtp_config = MtpConfig(num_mtp_layers=1, mtp_block_module_name="_FakeMtpBlock")
    inner_model = _FixedMtpInner(logits, hidden_states, hf_config)
    wrapper = MtpWrapper(mtp_config, hf_config, inner_model)
    with torch.no_grad():
        # In-place copy keeps the existing parameter object.
        wrapper.mtp.lm_head.weight.copy_(_lm_head_weight())
    return wrapper
def test_mtp_wrapper_forward_prefill_reuses_already_selected_logits():
    """Prefill path: inner logits already hold one row per request.

    Two requests of three tokens each (``query_start_loc = [0, 3, 6]``) with
    ``selected_token_indices = [2, 5]``. The expected output pairs, per
    request, the arg-max of the inner logits row (2, then 1) with the token
    derived from the MTP layer output at rows 2 and 5 (0, then 1).
    """
    hf_config = SimpleNamespace(hidden_size=2, vocab_size=3, rms_norm_eps=1e-6, num_hidden_layers=1)

    # One logits row per request: the maxima sit at token 2 and token 1.
    inner_logits = torch.zeros(1, 2, 3)
    inner_logits[0, 0, 2] = 10.0
    inner_logits[0, 1, 1] = 10.0

    inner_hidden_states = torch.zeros(1, 6, 2)

    # Only rows 2 and 5 of the MTP layer output are non-zero.
    layer_output = torch.zeros(1, 6, 2)
    layer_output[0, 2, 0] = 10.0
    layer_output[0, 5, 1] = 10.0

    wrapper = _mtp_wrapper(inner_logits, inner_hidden_states, hf_config)
    # Replace the generated MTP layers with one that returns ``layer_output``.
    wrapper.mtp.layers = torch.nn.ModuleList([_FixedMtpLayer(layer_output)])

    metadata = SamplingMetadata(
        query_start_loc=torch.tensor([0, 3, 6], dtype=torch.long),
        selected_token_indices=torch.tensor([2, 5], dtype=torch.long),
    )
    result = wrapper(
        input_ids=torch.zeros(1, 6, dtype=torch.long),
        position_ids=torch.arange(6, dtype=torch.long).view(1, 6),
        inputs_embeds=torch.zeros(1, 6, 2),
        sampling_metadata=metadata,
    )
    assert result.tolist() == [[2, 0], [1, 1]]
def test_mtp_wrapper_forward_feeds_bonus_token_and_projects_mtp_proposal_rows():
    """Spec-decode path: check the token pair returned for each request.

    Two requests of four tokens each (``query_start_loc = [0, 4, 8]``) with
    the default spec-decode fixture. The inner logits have six rows; rows 2
    and 5 peak at token 2, every other row at token 0. The MTP layer output
    is non-zero only at rows 4 and 7. Expected output: ``[[2, 0], [2, 1]]``.
    """
    hf_config = SimpleNamespace(hidden_size=2, vocab_size=3, rms_norm_eps=1e-6, num_hidden_layers=1)

    # Token 0 wins on every row except rows 2 and 5, where token 2 dominates.
    inner_logits = torch.zeros(1, 6, 3)
    inner_logits[0, :, 0] = 1.0
    inner_logits[0, 2, 2] = 10.0
    inner_logits[0, 5, 2] = 10.0

    inner_hidden_states = torch.zeros(1, 8, 2)

    # Rows 4 and 7 are the only ones with signal in the MTP layer output.
    layer_output = torch.zeros(1, 8, 2)
    layer_output[0, 4, 0] = 10.0
    layer_output[0, 7, 1] = 10.0

    wrapper = _mtp_wrapper(inner_logits, inner_hidden_states, hf_config)
    # Replace the generated MTP layers with one that returns ``layer_output``.
    wrapper.mtp.layers = torch.nn.ModuleList([_FixedMtpLayer(layer_output)])

    metadata = SamplingMetadata(
        query_start_loc=torch.tensor([0, 4, 8], dtype=torch.long),
        spec_decode_metadata=_spec_metadata(),
    )
    result = wrapper(
        input_ids=torch.zeros(1, 8, dtype=torch.long),
        position_ids=torch.arange(8, dtype=torch.long).view(1, 8),
        inputs_embeds=torch.zeros(1, 8, 2),
        sampling_metadata=metadata,
    )
    assert result.tolist() == [[2, 0], [2, 1]]