import unittest
from unittest.mock import patch, MagicMock
import torch
from mindie_llm.runtime.layers.attention.backend.sparse_attention import SfaMetadata
class TestSfaMetadata(unittest.TestCase):
def setUp(self):
self.dummy_mask = torch.ones(1, 1)
self.dummy_cos = torch.randn(1, 1)
self.dummy_sin = torch.randn(1, 1)
@patch('mindie_llm.runtime.layers.attention.backend.sparse_attention.get_parallel_info_manager')
def test_prefill_without_q_lens(self, mock_get_parallel_info_manager):
mock_manager = MagicMock()
mock_manager.attn_cp = MagicMock()
mock_manager.attn_cp.group_size = 1
mock_get_parallel_info_manager.return_value = mock_manager
model_inputs = MagicMock()
model_inputs.is_prefill = True
model_inputs.context_length = [3, 5, 2]
model_inputs.q_lens = None
metadata = SfaMetadata.from_model_input(
model_inputs, self.dummy_cos, self.dummy_sin, self.dummy_mask
)
expected_kv = torch.tensor([3, 5, 2], dtype=torch.int32)
expected_query = torch.tensor([3, 8, 10], dtype=torch.int32)
self.assertTrue(torch.equal(metadata.actual_seq_lengths_kv, expected_kv))
self.assertTrue(torch.equal(metadata.actual_seq_lengths_query, expected_query))
@patch('mindie_llm.runtime.layers.attention.backend.sparse_attention.get_parallel_info_manager')
def test_prefill_with_q_lens(self, mock_get_parallel_info_manager):
mock_manager = MagicMock()
mock_manager.attn_cp = MagicMock()
mock_manager.attn_cp.group_size = 1
mock_get_parallel_info_manager.return_value = mock_manager
model_inputs = MagicMock()
model_inputs.is_prefill = True
model_inputs.context_length = [10, 20]
model_inputs.q_lens = torch.tensor([4, 6], dtype=torch.int32)
metadata = SfaMetadata.from_model_input(
model_inputs, self.dummy_cos, self.dummy_sin, self.dummy_mask
)
expected_kv = torch.tensor([10, 20], dtype=torch.int32)
expected_query = torch.tensor([4, 10], dtype=torch.int32)
self.assertTrue(torch.equal(metadata.actual_seq_lengths_kv, expected_kv))
self.assertTrue(torch.equal(metadata.actual_seq_lengths_query, expected_query))
@patch('mindie_llm.runtime.layers.attention.backend.sparse_attention.get_parallel_info_manager')
def test_prefill_single_sequence(self, mock_get_parallel_info_manager):
mock_manager = MagicMock()
mock_manager.attn_cp = MagicMock()
mock_manager.attn_cp.group_size = 1
mock_get_parallel_info_manager.return_value = mock_manager
model_inputs = MagicMock()
model_inputs.is_prefill = True
model_inputs.context_length = [7]
model_inputs.q_lens = None
metadata = SfaMetadata.from_model_input(
model_inputs, self.dummy_cos, self.dummy_sin, self.dummy_mask
)
expected_kv = torch.tensor([7], dtype=torch.int32)
expected_query = torch.tensor([7], dtype=torch.int32)
self.assertTrue(torch.equal(metadata.actual_seq_lengths_kv, expected_kv))
self.assertTrue(torch.equal(metadata.actual_seq_lengths_query, expected_query))
@patch('mindie_llm.runtime.layers.attention.backend.sparse_attention.get_parallel_info_manager')
def test_decode_mode(self, mock_get_parallel_info_manager):
mock_manager = MagicMock()
mock_manager.attn_cp = MagicMock()
mock_manager.attn_cp.group_size = 1
mock_get_parallel_info_manager.return_value = mock_manager
batch_size = 4
model_inputs = MagicMock()
model_inputs.is_prefill = False
model_inputs.block_tables = torch.zeros(batch_size, 8)
metadata = SfaMetadata.from_model_input(
model_inputs, self.dummy_cos, self.dummy_sin, self.dummy_mask
)
expected_kv = torch.tensor([1, 1, 1, 1], dtype=torch.int32)
expected_query = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
self.assertTrue(torch.equal(metadata.actual_seq_lengths_kv, expected_kv))
self.assertTrue(torch.equal(metadata.actual_seq_lengths_query, expected_query))
@patch('mindie_llm.runtime.layers.attention.backend.sparse_attention.get_parallel_info_manager')
def test_decode_single_sequence(self, mock_get_parallel_info_manager):
mock_manager = MagicMock()
mock_manager.attn_cp = MagicMock()
mock_manager.attn_cp.group_size = 1
mock_get_parallel_info_manager.return_value = mock_manager
model_inputs = MagicMock()
model_inputs.is_prefill = False
model_inputs.block_tables = torch.zeros(1, 8)
metadata = SfaMetadata.from_model_input(
model_inputs, self.dummy_cos, self.dummy_sin, self.dummy_mask
)
expected_kv = torch.tensor([1], dtype=torch.int32)
expected_query = torch.tensor([1], dtype=torch.int32)
self.assertTrue(torch.equal(metadata.actual_seq_lengths_kv, expected_kv))
self.assertTrue(torch.equal(metadata.actual_seq_lengths_query, expected_query))
if __name__ == "__main__":
unittest.main()