"""Unit tests for MoE permutation API"""
import pytest
import torch
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_unpermute,
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
moe_permute_with_probs,
moe_permute_and_pad_with_probs,
)
def _seed() -> None:
"""Set random seed for reproducibility."""
seed = 1234
torch.manual_seed(seed)
if hasattr(torch, "npu"):
torch.npu.manual_seed(seed)
def _get_device() -> torch.device:
"""Return device based on availability (NPU if available, else CPU)."""
return torch.device("npu" if torch.npu.is_available() else "cpu")
class TestIndexMode:
    """Permute/unpermute checks where routing is given as per-token expert indices."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Seed the RNG and store the problem sizes used by the tests in this class."""
        _seed()
        self.num_tokens, self.hidden_size = 4, 4
        self.topK, self.num_experts = 2, 4

    def test_permute_unpermute_index_mode(self):
        """Round-trip tokens through permute and unpermute; check values and gradients."""
        device = _get_device()
        f32 = torch.float32

        # Four tokens; token t is a row filled with the value t + 1.
        inp = torch.tensor([[v] * 4 for v in (1.0, 2.0, 3.0, 4.0)], device=device, dtype=f32)
        # Tokens 0-1 select experts 0 and 1, tokens 2-3 select experts 2 and 3.
        expert_ids = torch.tensor(
            [[0, 1], [0, 1], [2, 3], [2, 3]], device=device, dtype=torch.int32
        )
        # The two merging weights of every token add up to one.
        probs = torch.tensor([[0.3, 0.7] for _ in range(4)], device=device, dtype=f32)
        num_out_tokens = self.num_tokens * self.topK

        # ---- forward: permute ----
        permuted, row_id_map = moe_permute(inp, expert_ids, num_out_tokens, map_type="index")
        assert permuted.shape == (num_out_tokens, self.hidden_size), (
            f"Expected permuted shape ({num_out_tokens}, {self.hidden_size}), got {permuted.shape}"
        )
        assert row_id_map.shape == (num_out_tokens,), (
            f"Expected row_id_map shape ({num_out_tokens},), got {row_id_map.shape}"
        )
        # Expected layout: (token 0, token 1) twice, then (token 2, token 3) twice.
        expected_permuted = torch.tensor(
            [[v] * 4 for v in (1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0)],
            device=device,
            dtype=f32,
        )
        assert torch.allclose(permuted, expected_permuted, rtol=1e-5, atol=1e-5), (
            f"Permuted values incorrect.\nExpected:\n{expected_permuted}\nGot:\n{permuted}"
        )

        # ---- forward: unpermute ----
        unpermuted = moe_unpermute(permuted, row_id_map, merging_probs=probs, map_type="index")
        assert unpermuted.shape == (self.num_tokens, self.hidden_size), (
            f"Expected unpermuted shape ({self.num_tokens}, {self.hidden_size}), "
            f"got {unpermuted.shape}"
        )
        # Expected result equals the original rows (weights 0.3 + 0.7 == 1).
        expected_unpermuted = torch.tensor(
            [[v] * 4 for v in (1.0, 2.0, 3.0, 4.0)], device=device, dtype=f32
        )
        assert torch.allclose(unpermuted, expected_unpermuted, rtol=1e-5, atol=1e-5), (
            f"Unpermuted values incorrect.\nExpected:\n{expected_unpermuted}\nGot:\n{unpermuted}"
        )

        # ---- backward: fresh leaf tensors so gradients accumulate on them ----
        inp_grad = inp.clone().detach().requires_grad_(True)
        probs_grad = probs.clone().detach().requires_grad_(True)
        dispatched, grad_row_map = moe_permute(
            inp_grad, expert_ids, num_out_tokens, map_type="index"
        )
        # Doubling stands in for the expert computation between permute and unpermute.
        doubled = dispatched * 2.0
        combined = moe_unpermute(
            doubled, grad_row_map, merging_probs=probs_grad, map_type="index"
        )
        total = combined.sum()
        total.backward()

        assert inp_grad.grad is not None, "Input gradient should not be None"
        assert probs_grad.grad is not None, "Probs gradient should not be None"
        assert inp_grad.grad.shape == inp.shape, (
            f"Expected input grad shape {inp.shape}, got {inp_grad.grad.shape}"
        )
        assert probs_grad.grad.shape == probs.shape, (
            f"Expected probs grad shape {probs.shape}, got {probs_grad.grad.shape}"
        )

        # Expected input gradient: 2 * (0.3 + 0.7) == 2 for every element.
        expected_inp_grad = torch.full((4, 4), 2.0, device=device, dtype=f32)
        assert torch.allclose(inp_grad.grad, expected_inp_grad, rtol=1e-5, atol=1e-5), (
            f"Input gradient incorrect.\nExpected:\n{expected_inp_grad}\nGot:\n{inp_grad.grad}"
        )
        # Expected probs gradient: 2 * token value * 4 hidden elements, per weight.
        expected_probs_grad = torch.tensor(
            [[g, g] for g in (8.0, 16.0, 24.0, 32.0)], device=device, dtype=f32
        )
        assert torch.allclose(probs_grad.grad, expected_probs_grad, rtol=1e-5, atol=1e-5), (
            f"Probs gradient incorrect.\nExpected:\n{expected_probs_grad}\nGot:\n{probs_grad.grad}"
        )

    def test_permute_with_zero_num_out_tokens(self):
        """Check that num_out_tokens=0 gives the same result as passing num_tokens * topK."""
        device = _get_device()
        inp = torch.tensor(
            [[v] * 4 for v in (1.0, 2.0, 3.0, 4.0)], device=device, dtype=torch.float32
        )
        expert_ids = torch.tensor(
            [[0, 1], [0, 1], [2, 3], [2, 3]], device=device, dtype=torch.int32
        )
        expected_num_out_tokens = self.num_tokens * self.topK

        # Zero asks the API to work out the output row count itself.
        permuted, row_id_map = moe_permute(inp, expert_ids, num_out_tokens=0, map_type="index")
        assert permuted.shape == (expected_num_out_tokens, self.hidden_size), (
            f"Expected permuted shape ({expected_num_out_tokens}, {self.hidden_size}), "
            f"got {permuted.shape}"
        )
        assert row_id_map.shape == (expected_num_out_tokens,), (
            f"Expected row_id_map shape ({expected_num_out_tokens},), got {row_id_map.shape}"
        )

        # Reference run with the row count spelled out.
        permuted_explicit, row_id_map_explicit = moe_permute(
            inp, expert_ids, num_out_tokens=expected_num_out_tokens, map_type="index"
        )
        assert torch.allclose(permuted, permuted_explicit, rtol=1e-5, atol=1e-5), (
            "Auto-inferred permute should match explicit permute"
        )
        assert torch.equal(row_id_map, row_id_map_explicit), (
            "Auto-inferred row_id_map should match explicit row_id_map"
        )
class TestMaskModeNonPad:
    """Permute/unpermute checks where routing is a token-by-expert mask, without padding."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Seed the RNG and store the problem sizes used by the tests in this class."""
        _seed()
        self.num_tokens, self.hidden_size = 4, 4
        self.topK, self.num_experts = 2, 4

    @staticmethod
    def _rows(values, device):
        """Build a float32 tensor whose i-th row repeats ``values[i]`` four times."""
        return torch.tensor([[v] * 4 for v in values], device=device, dtype=torch.float32)

    @classmethod
    def _make_inputs(cls, device):
        """Return the ``(inp, routing_map, probs)`` triple both tests start from."""
        inp = cls._rows((1.0, 2.0, 3.0, 4.0), device)
        # Tokens 0-1 are routed to experts 0 and 1, tokens 2-3 to experts 2 and 3.
        routing_map = torch.tensor(
            [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]],
            device=device,
            dtype=torch.int8,
        )
        # Weights are non-zero exactly where the mask is set and sum to one per token.
        probs = torch.tensor(
            [
                [0.3, 0.7, 0.0, 0.0],
                [0.3, 0.7, 0.0, 0.0],
                [0.0, 0.0, 0.3, 0.7],
                [0.0, 0.0, 0.3, 0.7],
            ],
            device=device,
            dtype=torch.float32,
        )
        return inp, routing_map, probs

    def test_permute_unpermute_mask_mode(self):
        """Round-trip tokens through permute and unpermute; check values and gradients."""
        device = _get_device()
        inp, routing_map, probs = self._make_inputs(device)
        num_out_tokens = int(routing_map.sum().item())
        # Keyword arguments shared by both unpermute calls below.
        merge_kwargs = {
            "restore_shape": (self.num_tokens, self.hidden_size),
            "map_type": "mask",
            "routing_map": routing_map,
        }

        # ---- forward: permute ----
        permuted, row_id_map = moe_permute(inp, routing_map, num_out_tokens, map_type="mask")
        assert permuted.shape == (num_out_tokens, self.hidden_size), (
            f"Expected permuted shape ({num_out_tokens}, {self.hidden_size}), got {permuted.shape}"
        )
        assert row_id_map.shape == (num_out_tokens,), (
            f"Expected row_id_map shape ({num_out_tokens},), got {row_id_map.shape}"
        )
        # Expected layout: (token 0, token 1) twice, then (token 2, token 3) twice.
        expected_permuted = self._rows((1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0), device)
        assert torch.allclose(permuted, expected_permuted, rtol=1e-5, atol=1e-5), (
            f"Permuted values incorrect.\nExpected:\n{expected_permuted}\nGot:\n{permuted}"
        )

        # ---- forward: unpermute ----
        unpermuted = moe_unpermute(permuted, row_id_map, merging_probs=probs, **merge_kwargs)
        assert unpermuted.shape == (self.num_tokens, self.hidden_size), (
            f"Expected unpermuted shape ({self.num_tokens}, {self.hidden_size}), "
            f"got {unpermuted.shape}"
        )
        # Expected result equals the original rows (weights 0.3 + 0.7 == 1).
        expected_unpermuted = self._rows((1.0, 2.0, 3.0, 4.0), device)
        assert torch.allclose(unpermuted, expected_unpermuted, rtol=1e-5, atol=1e-5), (
            f"Unpermuted values incorrect.\nExpected:\n{expected_unpermuted}\nGot:\n{unpermuted}"
        )

        # ---- backward: fresh leaf tensors so gradients accumulate on them ----
        inp_grad = inp.clone().detach().requires_grad_(True)
        probs_grad = probs.clone().detach().requires_grad_(True)
        dispatched, grad_row_map = moe_permute(
            inp_grad, routing_map, num_out_tokens, map_type="mask"
        )
        # Doubling stands in for the expert computation between permute and unpermute.
        doubled = dispatched * 2.0
        combined = moe_unpermute(doubled, grad_row_map, merging_probs=probs_grad, **merge_kwargs)
        total = combined.sum()
        total.backward()

        assert inp_grad.grad is not None, "Input gradient should not be None"
        assert probs_grad.grad is not None, "Probs gradient should not be None"
        assert inp_grad.grad.shape == inp.shape
        assert probs_grad.grad.shape == probs.shape

        # Expected input gradient: 2 * (0.3 + 0.7) == 2 for every element.
        expected_inp_grad = torch.full((4, 4), 2.0, device=device, dtype=torch.float32)
        assert torch.allclose(inp_grad.grad, expected_inp_grad, rtol=1e-5, atol=1e-5), (
            f"Input gradient incorrect.\nExpected:\n{expected_inp_grad}\nGot:\n{inp_grad.grad}"
        )

    def test_permute_with_probs_mask_mode(self):
        """Permute tokens together with their probs, then unpermute; check values and grads."""
        device = _get_device()
        inp, routing_map, probs = self._make_inputs(device)
        num_out_tokens = int(routing_map.sum().item())
        # Keyword arguments shared by both unpermute calls below.
        merge_kwargs = {
            "restore_shape": (self.num_tokens, self.hidden_size),
            "map_type": "mask",
            "routing_map": routing_map,
        }

        # ---- forward: permute tokens and probs ----
        permuted, permuted_probs, row_id_map = moe_permute_with_probs(
            inp, probs, routing_map, num_out_tokens
        )
        assert permuted.shape == (num_out_tokens, self.hidden_size)
        assert permuted_probs.shape == (num_out_tokens,)
        assert row_id_map.shape == (num_out_tokens,)

        expected_permuted = self._rows((1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0), device)
        assert torch.allclose(permuted, expected_permuted, rtol=1e-5, atol=1e-5), (
            f"Permuted values incorrect.\nExpected:\n{expected_permuted}\nGot:\n{permuted}"
        )
        # One weight per permuted row, in the same order as the rows above.
        expected_permuted_probs = torch.tensor(
            [0.3, 0.3, 0.7, 0.7, 0.3, 0.3, 0.7, 0.7], device=device, dtype=torch.float32
        )
        assert torch.allclose(permuted_probs, expected_permuted_probs, rtol=1e-5, atol=1e-5), (
            f"Permuted probs incorrect.\nExpected:\n{expected_permuted_probs}\n"
            f"Got:\n{permuted_probs}"
        )

        # ---- forward: unpermute ----
        unpermuted = moe_unpermute(permuted, row_id_map, merging_probs=probs, **merge_kwargs)
        assert unpermuted.shape == (self.num_tokens, self.hidden_size)
        expected_unpermuted = self._rows((1.0, 2.0, 3.0, 4.0), device)
        assert torch.allclose(unpermuted, expected_unpermuted, rtol=1e-5, atol=1e-5), (
            f"Unpermuted values incorrect.\nExpected:\n{expected_unpermuted}\nGot:\n{unpermuted}"
        )

        # ---- backward: only presence and shape of the gradients are checked ----
        inp_grad = inp.clone().detach().requires_grad_(True)
        probs_grad = probs.clone().detach().requires_grad_(True)
        dispatched, _dispatched_probs, grad_row_map = moe_permute_with_probs(
            inp_grad, probs_grad, routing_map, num_out_tokens
        )
        doubled = dispatched * 2.0
        combined = moe_unpermute(doubled, grad_row_map, merging_probs=probs_grad, **merge_kwargs)
        total = combined.sum()
        total.backward()

        assert inp_grad.grad is not None
        assert probs_grad.grad is not None
        assert inp_grad.grad.shape == inp.shape
        assert probs_grad.grad.shape == probs.shape
class TestMaskModePad:
    """Permute/unpermute checks for mask routing with per-expert padding."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Seed the RNG and store the problem sizes, including the padding alignment."""
        _seed()
        self.num_tokens, self.hidden_size = 4, 4
        self.topK, self.num_experts = 2, 4
        self.align_size = 2

    def test_permute_and_pad_with_probs(self):
        """Permute with padding, then unpermute; check pad metadata, values and gradients."""
        device = _get_device()
        f32 = torch.float32

        # Four tokens; token t is a row filled with the value t + 1.
        inp = torch.tensor([[v] * 4 for v in (1.0, 2.0, 3.0, 4.0)], device=device, dtype=f32)
        # Per-expert token counts are 2, 1, 2, 1 (column sums of the mask).
        routing_map = torch.tensor(
            [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]],
            device=device,
            dtype=torch.int8,
        )
        # Tokens 0 and 2 carry a single weight of 0.5; tokens 1 and 3 have weights summing to 1.
        probs = torch.tensor(
            [
                [0.5, 0.0, 0.0, 0.0],
                [0.3, 0.7, 0.0, 0.0],
                [0.0, 0.0, 0.5, 0.0],
                [0.0, 0.0, 0.3, 0.7],
            ],
            device=device,
            dtype=f32,
        )
        tokens_per_expert = routing_map.sum(dim=0)
        # Keyword arguments shared by both unpermute calls below.
        restore_kwargs = {
            "restore_shape": (self.num_tokens, self.hidden_size),
            "map_type": "mask",
            "routing_map": routing_map,
        }

        # ---- forward: permute and pad ----
        padded_result = moe_permute_and_pad_with_probs(
            inp, probs, routing_map, tokens_per_expert, self.align_size
        )
        permuted, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = padded_result

        total_padded_tokens = target_tokens_per_expert.sum().item()
        assert permuted.shape == (total_padded_tokens, self.hidden_size), (
            f"Expected permuted shape ({total_padded_tokens}, {self.hidden_size}), "
            f"got {permuted.shape}"
        )
        assert permuted_probs.shape == (total_padded_tokens,)
        assert row_id_map.shape == (total_padded_tokens,)
        assert pad_offsets.shape == (self.num_experts,)

        # With align_size 2 every expert is expected to end up with two rows.
        expected_target = torch.tensor([2] * 4, device=device)
        assert torch.equal(target_tokens_per_expert, expected_target), (
            f"Expected target_tokens_per_expert {expected_target}, got {target_tokens_per_expert}"
        )
        expected_pad_offsets = torch.tensor([0, 0, 1, 1], device=device)
        assert torch.equal(pad_offsets, expected_pad_offsets), (
            f"Expected pad_offsets {expected_pad_offsets}, got {pad_offsets}"
        )

        # ---- forward: unpermute ----
        unpermuted = moe_unpermute(
            permuted, row_id_map, merging_probs=probs, pad_offsets=pad_offsets, **restore_kwargs
        )
        assert unpermuted.shape == (self.num_tokens, self.hidden_size)
        # Expected rows: 0.5 * 1, (0.3 + 0.7) * 2, 0.5 * 3, (0.3 + 0.7) * 4.
        expected_unpermuted = torch.tensor(
            [[v] * 4 for v in (0.5, 2.0, 1.5, 4.0)], device=device, dtype=f32
        )
        assert torch.allclose(unpermuted, expected_unpermuted, rtol=1e-5, atol=1e-5), (
            f"Unpermuted values incorrect.\nExpected:\n{expected_unpermuted}\nGot:\n{unpermuted}"
        )

        # ---- backward: only presence and shape of the gradients are checked ----
        inp_grad = inp.clone().detach().requires_grad_(True)
        probs_grad = probs.clone().detach().requires_grad_(True)
        grad_result = moe_permute_and_pad_with_probs(
            inp_grad, probs_grad, routing_map, tokens_per_expert, self.align_size
        )
        dispatched, _dispatched_probs, grad_row_map, grad_pad_offsets, _grad_targets = grad_result
        # Doubling stands in for the expert computation between permute and unpermute.
        doubled = dispatched * 2.0
        combined = moe_unpermute(
            doubled,
            grad_row_map,
            merging_probs=probs_grad,
            pad_offsets=grad_pad_offsets,
            **restore_kwargs,
        )
        total = combined.sum()
        total.backward()

        assert inp_grad.grad is not None
        assert probs_grad.grad is not None
        assert inp_grad.grad.shape == inp.shape
        assert probs_grad.grad.shape == probs.shape
class TestSortChunks:
    """Tests for reordering contiguous row chunks with the sort-chunks-by-index API."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Seed the RNG and store the problem sizes used by the tests in this class."""
        _seed()
        self.num_tokens = 8
        self.hidden_size = 4

    @pytest.mark.skip(reason="Hanged to be fixed")
    def test_sort_chunks_by_index(self):
        """Sort four two-row chunks by index; check output values and input gradient."""
        device = _get_device()
        # Eight rows; row r is filled with the value r + 1.
        inp = torch.tensor(
            [[float(v)] * 4 for v in range(1, 9)], device=device, dtype=torch.float32
        )
        # Four chunks of two rows each, emitted in the order chunk 2, 0, 3, 1.
        split_sizes = torch.tensor([2, 2, 2, 2], device=device, dtype=torch.int32)
        sorted_index = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32)

        output = moe_sort_chunks_by_index(inp, split_sizes, sorted_index)
        assert output.shape == inp.shape, f"Expected output shape {inp.shape}, got {output.shape}"

        expected_output = torch.tensor(
            [[v] * 4 for v in (5.0, 6.0, 1.0, 2.0, 7.0, 8.0, 3.0, 4.0)],
            device=device,
            dtype=torch.float32,
        )
        assert torch.allclose(output, expected_output, rtol=1e-5, atol=1e-5), (
            f"Output values incorrect.\nExpected:\n{expected_output}\nGot:\n{output}"
        )

        # Gradient check on a fresh leaf tensor.
        inp_grad = inp.clone().detach().requires_grad_(True)
        output_g = moe_sort_chunks_by_index(inp_grad, split_sizes, sorted_index)
        output_processed = output_g * 2.0
        loss = output_processed.sum()
        loss.backward()

        assert inp_grad.grad is not None
        assert inp_grad.grad.shape == inp.shape
        # Every input element appears once in the output and is scaled by two.
        expected_grad = torch.full_like(inp, 2.0)
        assert torch.allclose(inp_grad.grad, expected_grad, rtol=1e-5, atol=1e-5), (
            f"Gradient values incorrect.\nExpected:\n{expected_grad}\nGot:\n{inp_grad.grad}"
        )

    @pytest.mark.skip(reason="Hanged to be fixed")
    def test_sort_chunks_by_index_with_probs(self):
        """Sort chunks of rows and their probs together; check values and gradient shapes."""
        device = _get_device()
        # Eight rows; row r is filled with the value r + 1.
        inp = torch.tensor(
            [[float(v)] * 4 for v in range(1, 9)], device=device, dtype=torch.float32
        )
        # Each row gets a distinct prob. If every chunk carried the same pair of values,
        # the sorted probs would equal the input and the check below could not tell
        # whether the probs were reordered at all.
        probs = torch.tensor(
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], device=device, dtype=torch.float32
        )
        # Four chunks of two rows each, emitted in the order chunk 2, 0, 3, 1.
        split_sizes = torch.tensor([2, 2, 2, 2], device=device, dtype=torch.int32)
        sorted_index = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32)

        output, permuted_probs = moe_sort_chunks_by_index_with_probs(
            inp, probs, split_sizes, sorted_index
        )
        assert output.shape == inp.shape
        assert permuted_probs.shape == (self.num_tokens,)

        expected_output = torch.tensor(
            [[v] * 4 for v in (5.0, 6.0, 1.0, 2.0, 7.0, 8.0, 3.0, 4.0)],
            device=device,
            dtype=torch.float32,
        )
        assert torch.allclose(output, expected_output, rtol=1e-5, atol=1e-5), (
            f"Output values incorrect.\nExpected:\n{expected_output}\nGot:\n{output}"
        )
        # The probs are expected to follow the same chunk order as the rows.
        expected_probs = torch.tensor(
            [0.5, 0.6, 0.1, 0.2, 0.7, 0.8, 0.3, 0.4], device=device, dtype=torch.float32
        )
        assert torch.allclose(permuted_probs, expected_probs, rtol=1e-5, atol=1e-5), (
            f"Permuted probs incorrect.\nExpected:\n{expected_probs}\nGot:\n{permuted_probs}"
        )

        # Gradient check on fresh leaf tensors; only presence and shape are verified.
        inp_grad = inp.clone().detach().requires_grad_(True)
        probs_grad = probs.clone().detach().requires_grad_(True)
        output_g, permuted_probs_g = moe_sort_chunks_by_index_with_probs(
            inp_grad, probs_grad, split_sizes, sorted_index
        )
        output_processed = output_g * 2.0
        loss = output_processed.sum()
        loss.backward()

        assert inp_grad.grad is not None
        assert probs_grad.grad is not None
        assert inp_grad.grad.shape == inp.shape
        assert probs_grad.grad.shape == probs.shape
class TestEdgeCases:
    """Tests for edge cases and error handling of the permutation API."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Seed the RNG before every test."""
        _seed()

    def test_empty_input_index_mode(self):
        """Permute and unpermute zero tokens in index mode; outputs must be empty too."""
        device = _get_device()
        num_tokens = 0
        hidden_size = 4
        topK = 2
        inp = torch.empty((num_tokens, hidden_size), device=device, dtype=torch.float32)
        index = torch.empty((num_tokens, topK), device=device, dtype=torch.int32)
        probs = torch.empty((num_tokens, topK), device=device, dtype=torch.float32)

        permuted, row_id_map = moe_permute(inp, index, 0, map_type="index")
        assert permuted.shape == (0, hidden_size)
        assert row_id_map.shape == (0,)

        unpermuted = moe_unpermute(permuted, row_id_map, merging_probs=probs, map_type="index")
        assert unpermuted.shape == (0, hidden_size)

    def test_empty_input_mask_mode(self):
        """Permute and unpermute zero tokens in mask mode; outputs must be empty too."""
        device = _get_device()
        num_tokens = 0
        hidden_size = 4
        num_experts = 4
        inp = torch.empty((num_tokens, hidden_size), device=device, dtype=torch.float32)
        routing_map = torch.empty((num_tokens, num_experts), device=device, dtype=torch.int8)
        probs = torch.empty((num_tokens, num_experts), device=device, dtype=torch.float32)

        permuted, row_id_map = moe_permute(inp, routing_map, 0, map_type="mask")
        assert permuted.shape == (0, hidden_size)
        assert row_id_map.shape == (0,)

        unpermuted = moe_unpermute(
            permuted,
            row_id_map,
            merging_probs=probs,
            restore_shape=(num_tokens, hidden_size),
            map_type="mask",
            routing_map=routing_map,
        )
        assert unpermuted.shape == (0, hidden_size)

    def test_single_token_index_mode(self):
        """Round-trip a single token routed to two experts in index mode."""
        device = _get_device()
        hidden_size = 4
        topK = 2
        inp = torch.tensor([[1.0, 1.0, 1.0, 1.0]], device=device, dtype=torch.float32)
        index = torch.tensor([[0, 1]], device=device, dtype=torch.int32)
        probs = torch.tensor([[0.3, 0.7]], device=device, dtype=torch.float32)

        permuted, row_id_map = moe_permute(inp, index, topK, map_type="index")
        assert permuted.shape == (topK, hidden_size)
        assert row_id_map.shape == (topK,)
        # Both permuted rows are copies of the only input token.
        assert torch.allclose(permuted[0], inp[0], rtol=1e-5, atol=1e-5)
        assert torch.allclose(permuted[1], inp[0], rtol=1e-5, atol=1e-5)

        unpermuted = moe_unpermute(permuted, row_id_map, merging_probs=probs, map_type="index")
        assert unpermuted.shape == (1, hidden_size)
        # Weights 0.3 + 0.7 == 1, so the expected merge equals the input row.
        expected = torch.tensor([[1.0, 1.0, 1.0, 1.0]], device=device, dtype=torch.float32)
        assert torch.allclose(unpermuted, expected, rtol=1e-5, atol=1e-5)

    def test_single_token_mask_mode(self):
        """Round-trip a single token routed to two experts in mask mode."""
        device = _get_device()
        hidden_size = 4
        inp = torch.tensor([[1.0, 1.0, 1.0, 1.0]], device=device, dtype=torch.float32)
        routing_map = torch.tensor([[1, 1, 0, 0]], device=device, dtype=torch.int8)
        probs = torch.tensor([[0.3, 0.7, 0.0, 0.0]], device=device, dtype=torch.float32)
        num_out_tokens = 2

        permuted, row_id_map = moe_permute(inp, routing_map, num_out_tokens, map_type="mask")
        assert permuted.shape == (num_out_tokens, hidden_size)
        assert row_id_map.shape == (num_out_tokens,)
        # Both permuted rows are copies of the only input token.
        assert torch.allclose(permuted[0], inp[0], rtol=1e-5, atol=1e-5)
        assert torch.allclose(permuted[1], inp[0], rtol=1e-5, atol=1e-5)

        unpermuted = moe_unpermute(
            permuted,
            row_id_map,
            merging_probs=probs,
            restore_shape=(1, hidden_size),
            map_type="mask",
            routing_map=routing_map,
        )
        assert unpermuted.shape == (1, hidden_size)
        # Weights 0.3 + 0.7 == 1, so the expected merge equals the input row.
        expected = torch.tensor([[1.0, 1.0, 1.0, 1.0]], device=device, dtype=torch.float32)
        assert torch.allclose(unpermuted, expected, rtol=1e-5, atol=1e-5)

    def test_invalid_map_type(self):
        """An unknown map_type must raise ValueError with the expected message."""
        device = _get_device()
        inp = torch.randn(4, 4, device=device, dtype=torch.float32)
        routing_map = torch.zeros(4, 4, device=device, dtype=torch.int8)
        with pytest.raises(ValueError, match="map_type should be one of"):
            moe_permute(inp, routing_map, 8, map_type="invalid")

    # One test case per dtype: a failure for one dtype is reported on its own and
    # does not prevent the remaining dtypes from being exercised.
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_different_dtypes(self, dtype):
        """Check that permute and unpermute keep the input dtype in index mode."""
        device = _get_device()
        num_tokens = 4
        topK = 2
        inp = torch.tensor(
            [[v] * 4 for v in (1.0, 2.0, 3.0, 4.0)], device=device, dtype=dtype
        )
        index = torch.tensor(
            [[0, 1], [0, 1], [2, 3], [2, 3]], device=device, dtype=torch.int32
        )
        # The merging weights stay float32 regardless of the activation dtype.
        probs = torch.tensor([[0.3, 0.7]] * num_tokens, device=device, dtype=torch.float32)

        permuted, row_id_map = moe_permute(inp, index, num_tokens * topK, map_type="index")
        assert permuted.dtype == dtype, f"Expected dtype {dtype}, got {permuted.dtype}"

        unpermuted = moe_unpermute(permuted, row_id_map, merging_probs=probs, map_type="index")
        assert unpermuted.dtype == dtype, f"Expected dtype {dtype}, got {unpermuted.dtype}"