import unittest
from unittest.mock import MagicMock, patch
import torch
from mindie_llm.runtime.layers.embedding.embedding import VocabParallelEmbedding, ParallelLMHead
from mindie_llm.runtime.layers.quantization.unquantized import (
UnquantizedEmbeddingMethod,
UnquantizedLinearMethod,
)
from mindie_llm.runtime.layers.parameter import BaseParameter, ColumnParameter
from mindie_llm.runtime.utils.distributed import set_parallel_info_manager
from mindie_llm.runtime.model_runner.forward_context import set_forward_context, ForwardContext
from mindie_llm.runtime.model_runner.forward_context_exp import BatchDescriptor
class TestVocabParallelEmbedding(unittest.TestCase):
"""Test cases for VocabParallelEmbedding with UnquantizedEmbeddingMethod."""
def setUp(self):
"""Set up test fixtures."""
self.mock_parallel_info = MagicMock()
self.mock_parallel_info.rank = 0
self.mock_parallel_info.group_size = 1
self.mock_parallel_info.process_group = None
self.mock_parallel_info_manager = MagicMock()
self.mock_parallel_info_manager.word_embed_tp = self.mock_parallel_info
set_parallel_info_manager(self.mock_parallel_info_manager)
def tearDown(self):
"""Clean up after tests."""
set_parallel_info_manager(None)
set_forward_context(None)
def _set_forward_context(self):
"""Helper method to set up forward context for tests."""
ctx = ForwardContext(
attn_metadata={},
lm_head_indices=None,
mtp_metadata=None,
is_prefill=False,
num_tokens_across_dp_cpu=torch.tensor([0]),
batch_descriptor=BatchDescriptor(num_tokens=0, is_flash_comm_enabled=False),
)
ctx.enable_flash_comm = False
set_forward_context(ctx)
def test_init_without_quant_config(self):
"""Test initialization without quantization config."""
num_embeddings = 1000
embedding_dim = 512
layer = VocabParallelEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
)
self.assertEqual(layer.num_embeddings, num_embeddings)
self.assertEqual(layer.embedding_dim, embedding_dim)
self.assertIsInstance(layer.quant_method, UnquantizedEmbeddingMethod)
self.assertEqual(layer.tp_rank, 0)
self.assertEqual(layer.tp_size, 1)
self.assertFalse(layer.is_parallel)
self.assertEqual(layer.output_partition_size, embedding_dim)
self.assertIsNotNone(layer.weight)
def test_init_with_custom_dtype(self):
"""Test initialization with custom dtype."""
num_embeddings = 1000
embedding_dim = 512
params_dtype = torch.float16
layer = VocabParallelEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
params_dtype=params_dtype,
)
self.assertEqual(layer.params_dtype, params_dtype)
self.assertEqual(layer.weight.data.dtype, params_dtype)
def test_init_with_prefix(self):
"""Test initialization with prefix."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
prefix="embed_tokens",
)
self.assertEqual(layer.prefix, "embed_tokens")
def test_init_with_partition_weights_false(self):
"""Test initialization with partition_weights=False."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=False,
)
self.assertFalse(layer.is_parallel)
self.assertEqual(layer.output_partition_size, 512)
def test_init_with_partition_weights_true_single_rank(self):
"""Test initialization with partition_weights=True but single rank."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=True,
)
self.assertFalse(layer.is_parallel)
self.assertEqual(layer.output_partition_size, 512)
def test_init_with_partition_weights_true_multi_rank(self):
"""Test initialization with partition_weights=True and multiple ranks."""
self.mock_parallel_info.group_size = 2
self.mock_parallel_info.rank = 0
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=True,
)
self.assertTrue(layer.is_parallel)
self.assertEqual(layer.output_partition_size, 256)
def test_weight_shape(self):
"""Test that weight shape is correct."""
num_embeddings = 1000
embedding_dim = 512
layer = VocabParallelEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
)
expected_shape = (num_embeddings, embedding_dim)
self.assertEqual(layer.weight.data.shape, expected_shape)
def test_weight_loader_non_parallel(self):
"""Test weight loading in non-parallel mode."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=False,
)
loaded_weight = torch.randn(1000, 512)
param = layer.weight
with patch.object(param, 'load_weight') as mock_load:
layer.weight_loader(param, loaded_weight)
mock_load.assert_called_once_with(loaded_weight)
def test_weight_loader_parallel(self):
"""Test weight loading in parallel mode."""
self.mock_parallel_info.group_size = 2
self.mock_parallel_info.rank = 0
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=True,
)
loaded_weight = torch.randn(1000, 512)
param = layer.weight
with patch.object(param, 'load_row_parallel_weight') as mock_load:
layer.weight_loader(param, loaded_weight)
mock_load.assert_called_once_with(loaded_weight=loaded_weight, tp_rank=0)
@patch('torch.nn.functional.embedding')
@patch('torch.distributed.all_gather_into_tensor')
def test_forward_non_parallel(self, mock_all_gather, mock_embedding):
"""Test forward pass in non-parallel mode."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=False,
)
layer.weight.data = torch.randn(1000, 512)
x = torch.tensor([0, 1, 2])
mock_embedding.return_value = torch.randn(3, 512)
self._set_forward_context()
output = layer.forward(x)
mock_embedding.assert_called_once()
call_args = mock_embedding.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.equal(call_args[0][0], x))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
self.assertEqual(output.shape, (3, 512))
mock_all_gather.assert_not_called()
@patch('torch.nn.functional.embedding')
@patch('torch.distributed.all_gather_into_tensor')
def test_forward_parallel(self, mock_all_gather, mock_embedding):
"""Test forward pass in parallel mode."""
self.mock_parallel_info.group_size = 2
self.mock_parallel_info.rank = 0
self.mock_parallel_info.process_group = MagicMock()
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
partition_weights=True,
)
layer.weight.data = torch.randn(1000, 256)
x = torch.tensor([0, 1, 2])
mock_embedding.return_value = torch.randn(3, 256)
def mock_all_gather_side_effect(output_tensor, input_tensor, group=None):
output_tensor.copy_(input_tensor.repeat(2, 1, 1))
mock_all_gather.side_effect = mock_all_gather_side_effect
self._set_forward_context()
output = layer.forward(x)
mock_embedding.assert_called_once()
call_args = mock_embedding.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.equal(call_args[0][0], x))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
mock_all_gather.assert_called_once()
self.assertEqual(output.shape, (3, 512))
def test_extra_repr(self):
"""Test extra_repr method."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
params_dtype=torch.float32,
)
repr_str = layer.extra_repr()
self.assertIn("num_embeddings=1000", repr_str)
self.assertIn("embedding_dim=512", repr_str)
self.assertIn("UnquantizedEmbeddingMethod", repr_str)
self.assertIn("tp_size=1", repr_str)
def test_unquantized_embedding_method_apply(self):
"""Test that UnquantizedEmbeddingMethod.apply is called correctly."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
)
layer.weight.data = torch.randn(1000, 512)
x = torch.tensor([0, 1, 2])
self._set_forward_context()
with patch.object(layer.quant_method, 'apply') as mock_apply:
mock_apply.return_value = torch.randn(3, 512)
output = layer.forward(x)
mock_apply.assert_called_once_with(layer, x)
self.assertEqual(output.shape, (3, 512))
@patch('torch.nn.functional.embedding')
def test_unquantized_embedding_method_apply_real_computation(self, mock_embedding):
"""Test UnquantizedEmbeddingMethod.apply with mocked embedding function."""
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
)
torch.manual_seed(42)
layer.weight.data = torch.randn(1000, 512)
x = torch.tensor([0, 1, 2])
mock_embedding.return_value = torch.randn(3, 512)
self._set_forward_context()
output = layer.forward(x)
mock_embedding.assert_called_once()
call_args = mock_embedding.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.equal(call_args[0][0], x))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
self.assertEqual(output.shape, (3, 512))
def test_weight_loader_w8a8sc_padding(self):
"""Test W8A8SC weight padding: F.pad(loaded_weight, (0, 0, 0, 1)) (embedding.py:90-94)"""
from mindie_llm.runtime.layers.parameter import ModelWeightParameter
mock_quant_config = MagicMock()
mock_quant_config.model_quant_type = "W8A8SC"
layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
quant_config=mock_quant_config,
)
layer.weight = ModelWeightParameter(torch.empty(1000, 512))
layer.weight.add_attrs({"input_dim": 1, "output_dim": 0})
loaded_weight = torch.randn(999, 512)
with patch.object(layer.weight, 'load_weight') as mock_load:
layer.weight_loader(layer.weight, loaded_weight)
called_weight = mock_load.call_args[0][0]
self.assertEqual(called_weight.shape, (1000, 512))
class TestParallelLMHead(unittest.TestCase):
"""Test cases for ParallelLMHead with UnquantizedLinearMethod."""
def setUp(self):
"""Set up test fixtures."""
self.mock_word_embed_tp = MagicMock()
self.mock_word_embed_tp.rank = 0
self.mock_word_embed_tp.group_size = 1
self.mock_lm_head_tp = MagicMock()
self.mock_lm_head_tp.rank = 0
self.mock_lm_head_tp.group_size = 1
self.mock_lm_head_tp.process_group = None
self.mock_parallel_info_manager = MagicMock()
self.mock_parallel_info_manager.word_embed_tp = self.mock_word_embed_tp
self.mock_parallel_info_manager.lm_head_tp = self.mock_lm_head_tp
set_parallel_info_manager(self.mock_parallel_info_manager)
def tearDown(self):
"""Clean up after tests."""
set_parallel_info_manager(None)
def test_init_without_bias(self):
"""Test initialization without bias."""
num_embeddings = 1000
embedding_dim = 512
layer = ParallelLMHead(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
bias=False,
)
self.assertEqual(layer.num_embeddings, num_embeddings)
self.assertEqual(layer.embedding_dim, embedding_dim)
self.assertFalse(layer.has_bias)
self.assertIsInstance(layer.quant_method, UnquantizedLinearMethod)
self.assertIsNone(layer.bias)
def test_init_with_bias(self):
"""Test initialization with bias."""
num_embeddings = 1000
embedding_dim = 512
layer = ParallelLMHead(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
bias=True,
)
self.assertTrue(layer.has_bias)
self.assertIsNotNone(layer.bias)
self.assertEqual(layer.bias.data.shape, (1000,))
def test_init_with_custom_dtypes(self):
"""Test initialization with custom weight and bias dtypes."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
bias=True,
weight_dtype=torch.float16,
bias_dtype=torch.float32,
)
self.assertEqual(layer.weight_dtype, torch.float16)
self.assertEqual(layer.bias_dtype, torch.float32)
self.assertEqual(layer.weight.data.dtype, torch.float16)
self.assertEqual(layer.bias.data.dtype, torch.float32)
def test_init_uses_lm_head_tp(self):
"""Test that ParallelLMHead uses lm_head_tp instead of word_embed_tp."""
self.mock_lm_head_tp.group_size = 2
self.mock_lm_head_tp.rank = 1
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
)
self.assertEqual(layer.tp_rank, 1)
self.assertEqual(layer.tp_size, 2)
self.assertEqual(layer.parallel_info, self.mock_lm_head_tp)
def test_weight_shape_linear(self):
"""Test that weight shape is correct for linear layer."""
num_embeddings = 1000
embedding_dim = 512
layer = ParallelLMHead(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
)
expected_shape = (num_embeddings, embedding_dim)
self.assertEqual(layer.weight.data.shape, expected_shape)
def test_weight_shape_parallel(self):
"""Test weight shape in parallel mode."""
self.mock_lm_head_tp.group_size = 2
self.mock_lm_head_tp.rank = 0
num_embeddings = 1000
embedding_dim = 512
layer = ParallelLMHead(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
)
expected_shape = (500, embedding_dim)
self.assertEqual(layer.weight.data.shape, expected_shape)
def test_weight_loader_column_parameter(self):
"""Test weight loading with ColumnParameter."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
)
param = ColumnParameter(torch.randn(1000, 512))
loaded_weight = torch.randn(1000, 512)
with patch.object(param, 'load_column_parallel_weight') as mock_load:
layer.weight_loader(param, loaded_weight)
mock_load.assert_called_once_with(loaded_weight=loaded_weight, tp_rank=0)
def test_weight_loader_base_parameter(self):
"""Test weight loading with BaseParameter."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
)
param = BaseParameter(torch.randn(1000))
loaded_weight = torch.randn(1000)
with patch.object(param, 'load_weight') as mock_load:
layer.weight_loader(param, loaded_weight)
mock_load.assert_called_once_with(loaded_weight)
def test_tie_weights(self):
"""Test tie_weights method."""
embed_layer = VocabParallelEmbedding(
num_embeddings=1000,
embedding_dim=512,
prefix="embed_tokens",
)
lm_head = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
prefix="lm_head",
)
result = lm_head.tie_weights(embed_layer)
self.assertEqual(lm_head.prefix, "embed_tokens")
self.assertEqual(result, lm_head)
@patch('torch.nn.functional.linear')
@patch('torch.distributed.all_gather_into_tensor')
def test_forward_without_indices(self, mock_all_gather, mock_linear):
"""Test forward pass without lm_head_indices."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
)
layer.weight.data = torch.randn(1000, 512)
hidden_states = torch.randn(2, 3, 512)
mock_linear.return_value = torch.randn(2, 3, 1000)
output = layer.forward(hidden_states)
mock_linear.assert_called_once()
call_args = mock_linear.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.allclose(call_args[0][0], hidden_states))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
self.assertEqual(output.shape, (2, 3, 1000))
mock_all_gather.assert_not_called()
@patch('torch.nn.functional.linear')
@patch('torch.distributed.all_gather_into_tensor')
def test_forward_with_indices(self, mock_all_gather, mock_linear):
"""Test forward pass with lm_head_indices."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
)
layer.weight.data = torch.randn(1000, 512)
hidden_states = torch.randn(1024, 512)
lm_head_indices = torch.tensor([0, 2])
mock_linear.return_value = torch.randn(2, 1000)
output = layer.forward(hidden_states, lm_head_indices)
mock_linear.assert_called_once()
self.assertEqual(output.shape, (2, 1000))
mock_all_gather.assert_not_called()
@patch('torch.nn.functional.linear')
@patch('torch.distributed.all_gather_into_tensor')
def test_forward_parallel(self, mock_all_gather, mock_linear):
"""Test forward pass in parallel mode."""
self.mock_lm_head_tp.group_size = 2
self.mock_lm_head_tp.rank = 0
self.mock_lm_head_tp.process_group = MagicMock()
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
)
layer.weight.data = torch.randn(500, 512)
hidden_states = torch.randn(3, 512)
mock_linear.return_value = torch.randn(3, 500)
def mock_all_gather_side_effect(output_tensor, input_tensor, group=None):
output_tensor.copy_(input_tensor.repeat(2, 1, 1))
mock_all_gather.side_effect = mock_all_gather_side_effect
output = layer.forward(hidden_states)
mock_linear.assert_called_once()
call_args = mock_linear.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.allclose(call_args[0][0], hidden_states))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
mock_all_gather.assert_called_once()
self.assertEqual(output.shape, (3, 1000))
def test_unquantized_linear_method_apply(self):
"""Test that UnquantizedLinearMethod.apply is called correctly."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
bias=True,
)
layer.weight.data = torch.randn(1000, 512)
layer.bias.data = torch.randn(1000)
hidden_states = torch.randn(2, 3, 512)
with patch.object(layer.quant_method, 'apply') as mock_apply:
mock_apply.return_value = torch.randn(2, 3, 1000)
output = layer.forward(hidden_states)
mock_apply.assert_called_once_with(layer, hidden_states)
self.assertEqual(output.shape, (2, 3, 1000))
@patch('torch.nn.functional.linear')
def test_unquantized_linear_method_with_bias(self, mock_linear):
"""Test UnquantizedLinearMethod with bias."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
bias=True,
)
layer.weight.data = torch.randn(1000, 512)
layer.bias.data = torch.randn(1000)
hidden_states = torch.randn(2, 3, 512)
mock_linear.return_value = torch.randn(2, 3, 1000)
output = layer.forward(hidden_states)
mock_linear.assert_called_once()
call_args = mock_linear.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.allclose(call_args[0][0], hidden_states))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
self.assertEqual(output.shape, (2, 3, 1000))
@patch('torch.nn.functional.linear')
def test_unquantized_linear_method_without_bias(self, mock_linear):
"""Test UnquantizedLinearMethod without bias."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
bias=False,
)
layer.weight.data = torch.randn(1000, 512)
hidden_states = torch.randn(2, 3, 512)
mock_linear.return_value = torch.randn(2, 3, 1000)
output = layer.forward(hidden_states)
mock_linear.assert_called_once()
call_args = mock_linear.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.allclose(call_args[0][0], hidden_states))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
self.assertEqual(output.shape, (2, 3, 1000))
self.assertIsNone(layer.bias)
@patch('torch.nn.functional.linear')
def test_unquantized_linear_method_skip_bias_add(self, mock_linear):
"""Test UnquantizedLinearMethod with skip_bias_add attribute."""
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
bias=True,
)
layer.weight.data = torch.randn(1000, 512)
layer.bias.data = torch.randn(1000)
layer.skip_bias_add = True
hidden_states = torch.randn(2, 3, 512)
mock_linear.return_value = torch.randn(2, 3, 1000)
output = layer.forward(hidden_states)
mock_linear.assert_called_once()
call_args = mock_linear.call_args
self.assertEqual(len(call_args[0]), 2)
self.assertTrue(torch.allclose(call_args[0][0], hidden_states))
self.assertTrue(torch.allclose(call_args[0][1], layer.weight.data))
self.assertEqual(output.shape, (2, 3, 1000))
def test_bias_shape_parallel(self):
"""Test bias shape in parallel mode."""
self.mock_lm_head_tp.group_size = 2
self.mock_lm_head_tp.rank = 0
layer = ParallelLMHead(
num_embeddings=1000,
embedding_dim=512,
bias=True,
)
expected_shape = (500,)
self.assertEqual(layer.bias.data.shape, expected_shape)
if __name__ == '__main__':
unittest.main()