import torch
from torch import nn
from tensor_cast.layers.parallel_embedding import ParallelEmbedding
from tensor_cast.model_config import WordEmbeddingTPMode
class _FakeParallelGroup:
def __init__(self, world_size, rank_in_group=0):
self.world_size = world_size
self.rank_in_group = rank_in_group
def all_reduce(self, input_):
return input_
def all_gather(self, input_, dim=-1):
return torch.cat([input_] * self.world_size, dim=dim)
def _make_embedding(vocab_size: int, hidden_size: int, padding_idx):
embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
return embedding
def test_row_shard_padding_idx_inside_local_range_is_relativized():
    """On the rank whose vocab slice contains padding_idx, the index is rebased.

    With 32 rows split over 4 ranks, each rank holds 8 rows. Row 30 therefore
    belongs to rank 3. That rank's inner embedding must have an (8, hidden)
    weight, ``_row_start`` 24, ``_row_end`` 32 and local padding_idx
    30 - 24 == 6.
    """
    vocab_size, hidden_size, tp_size = 32, 8, 4
    padding_idx = 30
    block_size = vocab_size // tp_size
    owning_rank = padding_idx // block_size
    row_start = owning_rank * block_size

    sharded = ParallelEmbedding(
        embedding=_make_embedding(vocab_size, hidden_size, padding_idx),
        tp_group=_FakeParallelGroup(world_size=tp_size, rank_in_group=owning_rank),
        shard_mode=WordEmbeddingTPMode.row,
    )

    shard = sharded._inner
    assert shard.weight.shape == (block_size, hidden_size)
    assert shard.padding_idx == padding_idx - row_start
    assert sharded._row_start == row_start
    assert sharded._row_end == row_start + block_size
def test_row_shard_padding_idx_outside_local_range_becomes_none():
    """Every rank that does not hold row ``padding_idx`` must end up with ``None``."""
    vocab_size, hidden_size, tp_size = 32, 8, 4
    padding_idx = 30
    owning_rank = padding_idx // (vocab_size // tp_size)
    # The owning rank is covered by the "inside local range" test.
    other_ranks = [r for r in range(tp_size) if r != owning_rank]

    for rank in other_ranks:
        sharded = ParallelEmbedding(
            embedding=_make_embedding(vocab_size, hidden_size, padding_idx),
            tp_group=_FakeParallelGroup(world_size=tp_size, rank_in_group=rank),
            shard_mode=WordEmbeddingTPMode.row,
        )
        assert sharded._inner.padding_idx is None, f"rank {rank} should have padding_idx=None"
def test_row_shard_no_padding_idx_stays_none():
    """Row sharding an embedding built without padding_idx keeps ``None`` on all ranks."""
    vocab_size, hidden_size, tp_size = 32, 8, 4
    for rank in range(tp_size):
        source = _make_embedding(vocab_size, hidden_size, padding_idx=None)
        group = _FakeParallelGroup(world_size=tp_size, rank_in_group=rank)
        sharded = ParallelEmbedding(
            embedding=source,
            tp_group=group,
            shard_mode=WordEmbeddingTPMode.row,
        )
        assert sharded._inner.padding_idx is None
def test_col_shard_preserves_padding_idx():
    """Column sharding splits only the hidden dim.

    Every rank therefore keeps all vocabulary rows and the original
    padding_idx, while holding ``hidden_size // tp_size`` columns.
    """
    vocab_size, hidden_size, tp_size = 32, 8, 4
    padding_idx = 30
    expected_shape = (vocab_size, hidden_size // tp_size)

    for rank in range(tp_size):
        sharded = ParallelEmbedding(
            embedding=_make_embedding(vocab_size, hidden_size, padding_idx),
            tp_group=_FakeParallelGroup(world_size=tp_size, rank_in_group=rank),
            shard_mode=WordEmbeddingTPMode.col,
        )
        inner = sharded._inner
        assert inner.weight.shape == expected_shape
        assert inner.padding_idx == padding_idx
def test_tp_size_one_is_noop():
    """A single-rank group in row mode must leave weight shape and padding_idx untouched."""
    vocab_size, hidden_size, padding_idx = 32, 8, 30

    sharded = ParallelEmbedding(
        embedding=_make_embedding(vocab_size, hidden_size, padding_idx),
        tp_group=_FakeParallelGroup(world_size=1, rank_in_group=0),
        shard_mode=WordEmbeddingTPMode.row,
    )

    inner = sharded._inner
    assert inner.weight.shape == (vocab_size, hidden_size)
    assert inner.padding_idx == padding_idx
def test_row_shard_negative_padding_idx_is_normalized():
    """A negative padding_idx assigned after construction must land on the right rank.

    ``nn.Embedding`` normalizes a negative padding_idx only in ``__init__``.
    Here -1 is written onto the module afterwards, so ParallelEmbedding itself
    must treat it as row 31. Row 31 belongs to rank 3, which must get local
    index 31 - 24 == 7; every other rank must get ``None``.
    """
    vocab_size, hidden_size, tp_size = 32, 8, 4
    block_size = vocab_size // tp_size
    raw_negative = -1
    expected_positive = vocab_size + raw_negative
    owning_rank = expected_positive // block_size

    for rank in range(tp_size):
        embedding = _make_embedding(vocab_size, hidden_size, padding_idx=None)
        # Assign after construction to bypass nn.Embedding's __init__ normalization.
        embedding.padding_idx = raw_negative
        sharded = ParallelEmbedding(
            embedding=embedding,
            tp_group=_FakeParallelGroup(world_size=tp_size, rank_in_group=rank),
            shard_mode=WordEmbeddingTPMode.row,
        )
        if rank != owning_rank:
            assert sharded._inner.padding_idx is None
            continue
        assert sharded._inner.padding_idx == expected_positive - rank * block_size
def test_row_shard_padding_idx_assertion_does_not_trigger():
    """Regression: F.embedding's padding_idx bound check must hold on every rank
    after sharding (the failure mode that surfaced under torch.compile).

    On each rank the local padding_idx must be either ``None`` (the rank does
    not own the padding row) or a non-negative index into the local shard.
    Checking only ``padding_idx < num_embeddings`` is not enough. A wrongly
    rebased negative value would pass that check, and ``F.embedding`` would
    count it from the end of the shard, masking the wrong row without error.
    """
    vocab_size, hidden_size, tp_size = 32, 8, 4
    padding_idx = 30
    for rank in range(tp_size):
        embedding = _make_embedding(vocab_size, hidden_size, padding_idx)
        parallel = ParallelEmbedding(
            embedding=embedding,
            tp_group=_FakeParallelGroup(world_size=tp_size, rank_in_group=rank),
            shard_mode=WordEmbeddingTPMode.row,
        )
        local_rows = parallel._inner.weight.size(0)
        pad = parallel._inner.padding_idx
        if pad is not None:
            # Require a proper local row index, not just "less than size".
            assert 0 <= pad < local_rows, (
                f"rank {rank}: padding_idx {pad} must be in [0, {local_rows})"
            )
        # Index 0 is valid on every rank; the forward call runs F.embedding with
        # the rank's local padding_idx.
        local_idx = torch.zeros((1, 2), dtype=torch.long)
        out = parallel._inner(local_idx)
        assert out.shape == (1, 2, hidden_size)