import pytest
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.model_config import ParallelConfig, WordEmbeddingTPMode
def _get_parallel_config(parallel_configuration: tuple):
    """Build a ``ParallelConfig`` from a positional tuple.

    Layout: ``(world_size, tensor_parallel_size, o_proj_tensor_parallel_size,
    mlp_tensor_parallel_size, lmhead_tensor_parallel_size[, col_embedding])``.
    When the optional sixth element is present, a truthy value sets
    ``embedding_parallel`` to ``WordEmbeddingTPMode.col`` and a falsy value
    sets it to ``None``. Without the sixth element, ``embedding_parallel`` is
    left at whatever ``ParallelConfig`` assigns by default.
    """
    positional_fields = (
        "world_size",
        "tensor_parallel_size",
        "o_proj_tensor_parallel_size",
        "mlp_tensor_parallel_size",
        "lmhead_tensor_parallel_size",
    )
    # Indexing (rather than unpacking) keeps IndexError for tuples shorter than 5.
    config = ParallelConfig(
        **{field: parallel_configuration[idx] for idx, field in enumerate(positional_fields)}
    )
    if len(parallel_configuration) <= 5:
        return config
    config.embedding_parallel = WordEmbeddingTPMode.col if parallel_configuration[5] else None
    return config
def _has_dp_transform(parallel_config: ParallelConfig):
    """Return whether any per-layer data-parallel size differs from the base one.

    Compares ``data_parallel_size`` against ``mlp_data_parallel_size``,
    ``o_proj_data_parallel_size`` and ``lmhead_data_parallel_size``, in that
    order, stopping at the first mismatch.
    """
    global_dp = parallel_config.data_parallel_size
    if global_dp != parallel_config.mlp_data_parallel_size:
        return True
    if global_dp != parallel_config.o_proj_data_parallel_size:
        return True
    return global_dp != parallel_config.lmhead_data_parallel_size
def test_parallel_config_layer_split_flags():
    """Per-layer TP predicates follow the per-layer TP sizes.

    tensor_parallel_size and mlp_tensor_parallel_size are > 1, so the attention
    and MLP predicates must be truthy. lmhead_tensor_parallel_size is 1, so the
    LM-head predicate must be falsy.
    """
    sizes = dict(
        world_size=8,
        tensor_parallel_size=2,
        data_parallel_size=4,
        mlp_tensor_parallel_size=4,
        lmhead_tensor_parallel_size=1,
    )
    cfg = ParallelConfig(**sizes)
    assert cfg.has_attn_tp() and cfg.has_mlp_tp()
    assert not cfg.has_lmhead_tp()
@pytest.mark.parametrize(
    "parallel_configuration, expected_dp_transform",
    [
        ((16, 1, 1, 1, 1), False),
        ((16, 4, 2, 8, 16), True),
        ((16, 4, 2, 8, 16, True), True),
    ],
)
def test_parallel_config_topology_flags(parallel_configuration, expected_dp_transform):
    """Check basic sizes and DP-transform detection for tuple-built configs.

    Tuples use the ``_get_parallel_config`` layout:
    ``(world, tp, o_proj_tp, mlp_tp, lmhead_tp[, col_embedding])``.
    """
    cfg = _get_parallel_config(parallel_configuration)
    assert cfg.world_size == parallel_configuration[0]
    assert cfg.tensor_parallel_size == parallel_configuration[1]
    assert _has_dp_transform(cfg) is expected_dp_transform
    # The optional sixth element requests column-parallel word embedding.
    # Without this check the third case would assert nothing beyond the second.
    if len(parallel_configuration) > 5 and parallel_configuration[5]:
        assert cfg.embedding_parallel is WordEmbeddingTPMode.col
def test_mtp_ep_user_config_parallel_fields():
    """UserInputConfig keeps the MTP-token count and EP sizing it is given."""
    user_config = UserInputConfig(
        model_id="deepseek-ai/DeepSeek-V3.1",
        num_mtp_tokens=2,
        world_size=16,
        ep_size=16,
        moe_dp_size=1,
        moe_tp_size=1,
    )
    observed = (user_config.num_mtp_tokens, user_config.world_size, user_config.ep_size)
    assert observed == (2, 16, 16)