import torch
import torch.nn as nn
from torch.distributed._functional_collectives import (
all_to_all_single,
all_to_all_single_autograd,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
import triton
import triton.language as tl
# Alignment (in rows/tokens) for each local expert's token group after
# expert-parallel dispatch. `_permute` passes it to `generate_permute_indices`
# as `alignment`, so every per-expert size is rounded up to a multiple of it.
# It is also used to size the padded index buffer.
TOKEN_GROUP_ALIGN_SIZE_M = 8
@triton.jit
def _fill_indices_kernel(
    tokens_per_expert_group_ptr,
    start_index_values_ptr,
    write_offsets_ptr,
    output_ptr,
    experts_per_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write, for each local expert, the source indices of its tokens from all ranks.

    For expert ``e``, the segments contributed by ranks 0..num_ranks-1 are written
    back-to-back into ``output_ptr`` starting at ``write_offsets[e]``. Segment
    ``(r, e)`` lives at flat index ``r * experts_per_rank + e`` of
    ``tokens_per_expert_group`` (its length) and ``start_index_values`` (its
    first source index). Output slots not covered by any segment are left
    untouched by this kernel.

    NOTE(review): stores are masked only by ``chunk_offsets < length``, not by the
    size of the output buffer. The caller must make sure every
    ``write_offsets[e]`` plus that expert's total count fits in the output.
    """
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)
    # Program `pid` handles experts pid, pid + num_programs, pid + 2*num_programs, ...
    # so a grid smaller than experts_per_rank still covers every expert.
    for expert_id in range(pid, experts_per_rank, num_programs):
        write_offset = tl.load(write_offsets_ptr + expert_id)
        for r in range(num_ranks):
            # Flat index of the (rank r, expert expert_id) segment.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            offsets = tl.arange(0, BLOCK_SIZE)
            # Copy the segment in BLOCK_SIZE chunks; the mask trims the tail chunk.
            for chunk_start in range(0, length, BLOCK_SIZE):
                chunk_offsets = chunk_start + offsets
                mask = chunk_offsets < length
                values = start_index + chunk_offsets
                dest_indices = write_offset + chunk_offsets
                tl.store(output_ptr + dest_indices, values, mask=mask)
            # The next rank's segment for this expert starts right after this one.
            write_offset += length
def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
):
    """Build the expert-grouped permutation index vector on-device via Triton.

    Args:
        tokens_per_expert_group: per-(rank, expert) token counts, rank-major.
        start_index_values: first source index of each (rank, expert) segment.
        write_offsets: output position where each local expert's group begins.
        experts_per_rank: number of local experts.
        num_ranks: number of ranks contributing tokens.
        max_len: length of the returned index vector.
        block_size: number of indices each kernel program writes per chunk.
        max_blocks: upper bound on the number of kernel programs launched.

    Returns:
        int32 tensor of shape ``(max_len,)`` on the same device as
        ``tokens_per_expert_group``; slots the kernel does not write hold -1.
    """
    target_device = tokens_per_expert_group.device
    output = torch.full(
        (max_len,), -1, dtype=torch.int32, device=target_device
    )
    # One program per expert, capped at max_blocks; the kernel strides over
    # the remaining experts when the cap applies.
    launch_grid = (min(experts_per_rank, max_blocks),)
    _fill_indices_kernel[launch_grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        output,
        experts_per_rank,
        num_ranks,
        BLOCK_SIZE=block_size,
    )
    return output
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """CPU reference implementation of the expert-grouped index fill.

    For each local expert ``e``, the source indices of the tokens it receives
    from ranks 0..num_ranks-1 are written back-to-back starting at
    ``write_offsets[e]``. Writes are clipped to ``max_len``. A segment that
    would start at or past ``max_len`` is skipped entirely.

    Args:
        tokens_per_expert_group: per-(rank, expert) token counts; segment
            ``(r, e)`` is at flat index ``r * experts_per_rank + e``.
        start_index_values: first source index of each (rank, expert) segment.
        write_offsets: output position where each expert's group begins.
        experts_per_rank: number of local experts.
        num_ranks: number of ranks contributing tokens.
        max_len: length of the returned index vector.

    Returns:
        CPU int32 tensor of shape ``(max_len,)``; unwritten slots hold -1.
    """
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
    for e in range(experts_per_rank):
        write_start = int(write_offsets[e].item())
        for r in range(num_ranks):
            i = r * experts_per_rank + e
            length = int(tokens_per_expert_group[i].item())
            # Clip the segment to the output buffer. When it starts at or past
            # max_len, end_idx <= write_start and nothing is written. This
            # avoids building an arange with a negative extent.
            end_idx = min(write_start + length, max_len)
            if end_idx > write_start:
                start_index = int(start_index_values[i].item())
                permuted_indices[write_start:end_idx] = torch.arange(
                    start_index,
                    start_index + (end_idx - write_start),
                    dtype=torch.int32,
                )
            # Advance by the full (unclipped) length so later offsets stay consistent.
            write_start += length
    return permuted_indices
def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """
    Prepare permutation indices and the number of tokens for each expert.

    Args:
        tokens_per_expert_group: number of tokens for each expert from all ranks,
            shape ``(num_ranks * experts_per_rank,)`` in rank-major order.
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: maximum length of the output index vector.
        alignment: every entry of ``m_sizes`` is rounded up to a multiple of this,
            and an expert that receives zero tokens still gets ``alignment`` slots.
        use_cpu: use the CPU reference implementation instead of the Triton kernel.

    Returns:
        permuted_indices: int32 tensor of length ``max_len``; each position holds
            the index of the source token placed there, or -1 for padding.
        m_sizes: aligned number of tokens for each expert (int32).
        m_offsets: cumulative sum of ``m_sizes`` (int32), i.e. the exclusive end
            position of each expert's segment.

    Example layout of ``tokens_per_expert_group`` for 2 ranks x 4 experts::

        From: |       rank 0      |       rank 1      |
        To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
              |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |
    """
    # Exclusive prefix sum: where each (rank, expert) segment begins in the
    # incoming rank-major token order.
    inclusive_ends = torch.cumsum(tokens_per_expert_group, 0)
    start_index_values = inclusive_ends - tokens_per_expert_group

    # Total tokens per local expert across ranks, padded to at least `alignment`
    # and rounded up to the next multiple of `alignment`.
    per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    per_expert = torch.clamp_min(per_expert, alignment)
    rounded = (per_expert + alignment - 1) // alignment * alignment
    m_sizes = rounded.to(torch.int32)

    m_offsets = torch.cumsum(m_sizes, 0)
    write_offsets = m_offsets - m_sizes

    fill_fn = fill_indices_cpu if use_cpu else fill_indices_wrapper
    permuted_indices = fill_fn(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes, m_offsets.to(torch.int32)
def _round_up(x: int, y: int) -> int:
"""Round up x to the nearest multiple of y."""
x_ceil_div_y = (x + y - 1) // y
return x_ceil_div_y * y
def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
    """Reorder received tokens so each local expert's rows are contiguous and aligned.

    Args:
        x: received tokens, one row per token.
        num_tokens_per_expert: per-(rank, local expert) token counts.
        ep_degree: number of expert-parallel ranks.
        num_local_experts: number of experts hosted on this rank.

    Returns:
        ``(input_shape, permuted_x, permuted_indices, aligned_counts)``.
        ``input_shape`` is the shape of ``x`` after one all-zero row is appended;
        ``_unpermute`` needs it.
        ``aligned_counts`` are the per-expert sizes padded to
        ``TOKEN_GROUP_ALIGN_SIZE_M``.
    """
    align = TOKEN_GROUP_ALIGN_SIZE_M
    # Size the index buffer with one extra alignment block per local expert,
    # rounded up to the alignment.
    padded_max_len = _round_up(x.shape[0] + num_local_experts * align, align)
    with torch.no_grad():
        permuted_indices, aligned_counts, _unused_offsets = generate_permute_indices(
            num_tokens_per_expert,
            num_local_experts,
            ep_degree,
            padded_max_len,
            align,
        )
    # Padding slots carry index -1, which selects this appended zero row.
    zero_row = x.new_zeros((x.shape[-1]))
    x = torch.vstack((x, zero_row))
    input_shape = x.shape
    permuted_x = x[permuted_indices, :]
    return input_shape, permuted_x, permuted_indices, aligned_counts
def _unpermute(out, input_shape, permuted_indices):
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
return out
class TensorParallel(ParallelStyle):
    """Tensor-parallel sharding for a grouped-expert module.

    ``w1`` and ``w3`` are sharded along dim 1 and ``w2`` along dim 2 across the
    mesh. The routed input is treated as replicated, and its gradient placement
    is marked ``Partial``.
    """

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        routed_input, num_tokens_per_expert = inputs
        replicated = DTensor.from_local(routed_input, device_mesh, (Replicate(),))
        local_input = replicated.to_local(grad_placements=(Partial(),))
        return local_input, num_tokens_per_expert

    def _partition_fn(self, name, module, device_mesh):
        # (parameter name, dim to shard), registered in this order.
        shard_plan = (("w1", 1), ("w2", 2), ("w3", 1))
        for param_name, shard_dim in shard_plan:
            sharded = distribute_tensor(
                getattr(module, param_name), device_mesh, [Shard(shard_dim)]
            )
            module.register_parameter(param_name, nn.Parameter(sharded))

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._partition_fn,
            input_fn=self._prepare_input_fn,
        )
class ExpertParallel(ParallelStyle):
    """Expert parallelism over a 1-D device mesh.

    The wrapped module's direct parameters are sharded along dim 0 across the
    mesh. On input, tokens are exchanged with an all-to-all so each rank gets
    the tokens routed to its local experts; they are then regrouped by expert.
    On output, the regrouping is undone and results are sent back with the
    reverse all-to-all.

    The split sizes and permutation computed during dispatch are stored on the
    instance and reused by the matching combine call.
    """

    def __init__(self):
        super().__init__()
        # State written by `_token_dispatch` and read by `_token_combine`.
        self.input_splits = None
        self.output_splits = None
        self.input_shape = None
        self.permuted_indices = None

    def _token_dispatch(self, mod, inputs, device_mesh: DeviceMesh):
        """Send tokens to the ranks owning their experts and regroup them by expert.

        Args:
            mod: the wrapped module (unused here).
            inputs: ``(routed_input, num_tokens_per_expert)``. The counts cover
                ``ep_degree * num_local_experts`` experts, grouped by destination
                rank.
            device_mesh: 1-D expert-parallel mesh.

        Returns:
            ``(routed_input, num_tokens_per_expert_group)``. The tokens are
            permuted so each local expert's rows are contiguous and padded, and
            the counts are the aligned per-local-expert sizes.
        """
        routed_input, num_tokens_per_expert = inputs
        ep_degree = device_mesh.shape[0]
        num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

        with torch.no_grad():
            # Exchange per-expert counts. Afterwards this rank knows how many
            # tokens each peer will send to each of its local experts.
            num_tokens_per_expert_group = all_to_all_single(
                num_tokens_per_expert,
                None,
                None,
                group=device_mesh.get_group(),
            )
            # Wait explicitly so the counts are materialized before they are
            # consumed below (including by the index-generation kernel).
            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
                num_tokens_per_expert_group
            )
            # Per-peer totals become the all-to-all split sizes.
            input_splits = (
                num_tokens_per_expert.view(ep_degree, -1)
                .sum(dim=1)
                .to(torch.device("cpu"), non_blocking=True)
            )
            # NOTE: this blocking copy forces a device-to-host sync.
            # NOTE(review): reading `input_splits` (copied with
            # non_blocking=True) presumably relies on this sync; keep the order.
            output_splits = (
                num_tokens_per_expert_group.view(ep_degree, -1)
                .sum(dim=1)
                .to(torch.device("cpu"), non_blocking=False)
            )
            self.input_splits = input_splits.tolist()
            self.output_splits = output_splits.tolist()

        # The token exchange stays outside no_grad so gradients flow back.
        routed_input = all_to_all_single_autograd(
            routed_input,
            self.output_splits,
            self.input_splits,
            device_mesh.get_group(),
        )

        (
            self.input_shape,
            routed_input,
            self.permuted_indices,
            num_tokens_per_expert_group,
        ) = _permute(
            routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
        )
        return routed_input, num_tokens_per_expert_group

    @staticmethod
    def _partition_fn(name, mod, device_mesh):
        """Shard every direct parameter of ``mod`` along dim 0 across ``device_mesh``.

        Dim 0 is presumably the expert dimension. TODO: confirm against the
        expert module's weight layout.
        """
        # Snapshot the parameters before re-registering them, so the module's
        # parameter dict is not mutated while a lazy iterator is walking it.
        # A distinct loop name also avoids shadowing the `name` argument.
        for param_name, param in list(mod.named_parameters(recurse=False)):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            mod.register_parameter(param_name, dist_param)

    def _token_combine(self, mod, routed_output, device_mesh):
        """Undo the expert regrouping and return outputs to their source ranks.

        Uses the permutation and split sizes stored by `_token_dispatch`, with
        the split sizes swapped for the reverse all-to-all.
        """
        routed_output = _unpermute(
            routed_output, self.input_shape, self.permuted_indices
        )
        routed_output = all_to_all_single_autograd(
            routed_output,
            self.input_splits,
            self.output_splits,
            device_mesh.get_group(),
        )
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
class ExpertTensorParallel(ExpertParallel):
    """Expert parallelism on ``device_mesh["ep"]`` combined with TP on ``device_mesh["tp"]``.

    Weights are sharded along dim 0 over "ep" and along dim 1 (``w1``, ``w3``)
    or dim 2 (``w2``) over "tp". Token routing is delegated to
    ``ExpertParallel`` on the "ep" sub-mesh.
    """

    def _token_dispatch(self, mod, inputs, device_mesh):
        routed_input, num_tokens_per_expert = inputs
        tp_mesh = device_mesh["tp"]
        # Treat the input as replicated over TP; its gradient is marked Partial.
        replicated = DTensor.from_local(routed_input, tp_mesh, (Replicate(),))
        local_input = replicated.to_local(grad_placements=(Partial(),))
        ep_mesh = device_mesh["ep"]
        return super()._token_dispatch(
            mod, (local_input, num_tokens_per_expert), ep_mesh
        )

    def _partition_fn_2d(self, name, mod, ep_tp_mesh):
        # (parameter name, TP shard dim), registered in this order.
        tp_shard_plan = (("w1", 1), ("w2", 2), ("w3", 1))
        for param_name, tp_dim in tp_shard_plan:
            sharded = distribute_tensor(
                getattr(mod, param_name), ep_tp_mesh, [Shard(0), Shard(tp_dim)]
            )
            mod.register_parameter(param_name, nn.Parameter(sharded))

    def _token_combine(self, mod, routed_output, device_mesh):
        ep_mesh = device_mesh["ep"]
        return super()._token_combine(mod, routed_output, ep_mesh)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._partition_fn_2d,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
class ReordererSequenceParallel(ParallelStyle):
    """Split router outputs along the token dim across the mesh, then re-offset indices.

    The input hook keeps only this rank's contiguous slice of ``top_scores`` and
    ``selected_experts_indices``. The output hook shifts
    ``token_indices_experts_sorted`` by this rank's slice offset, so the indices
    refer to positions in the unsplit token order.
    """

    def __init__(self):
        super().__init__()

    def _prepare_input_fn(self, mod, inputs, device_mesh: DeviceMesh):
        """Return this rank's slice (along dim 0) of both router outputs.

        Raises:
            ValueError: if the number of tokens is not divisible by the mesh size.
        """
        top_scores, selected_experts_indices = inputs
        num_tokens, _ = top_scores.shape
        mesh_size = device_mesh.size()
        # Validate once, rather than once per tensor being split.
        if num_tokens % mesh_size != 0:
            raise ValueError(
                "Uneven split of tokens is not supported yet. "
                "Requires EP degree dividing batch size * seq len."
            )
        local_num_tokens = num_tokens // mesh_size
        offset = device_mesh.get_local_rank() * local_num_tokens

        def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
            # Take this rank's contiguous block of rows.
            assert x.is_contiguous()
            return x[offset : offset + local_num_tokens]

        top_scores = _split_along_first_dim(top_scores)
        selected_experts_indices = _split_along_first_dim(selected_experts_indices)
        return top_scores, selected_experts_indices

    # Backward-compatible alias for the previous (misspelled) method name.
    _prepare_inputput_fn = _prepare_input_fn

    def _prepare_output_fn(self, mod, outputs, device_mesh: DeviceMesh):
        """Shift local token indices in place by this rank's slice offset.

        Raises:
            ValueError: if ``mod`` has no ``top_k`` attribute.
        """
        top_scores, token_indices_experts_sorted, num_tokens_per_expert = outputs
        local_rank = device_mesh.get_local_rank()
        if not hasattr(mod, "top_k"):
            raise ValueError(
                "TokenReorderer class in MoE should always have top_k attribute."
            )
        # NOTE(review): presumably top_scores holds top_k entries per local
        # token, so `top_scores.shape[0] // mod.top_k` is the local token count
        # (the slice size used in `_prepare_input_fn`). Confirm.
        # This updates the module's output tensor in place.
        token_indices_experts_sorted += top_scores.shape[0] // mod.top_k * local_rank
        return top_scores, token_indices_experts_sorted, num_tokens_per_expert

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=None,
            input_fn=self._prepare_input_fn,
            output_fn=self._prepare_output_fn,
        )