import torch
from torch.distributed._tensor import Partial, Replicate, Shard
from torch.distributed._tensor.experimental import register_sharding
# Shorthand for the `torch.ops.npu` operator namespace referenced by the
# sharding registrations in this module.
npu = torch.ops.npu
@register_sharding(npu.npu_moe_token_permute.default)
def npu_moe_token_permute_strategy(tokens, indices, num_out_tokens=None, padded_mode=False):
    """List the acceptable DTensor shardings for ``npu_moe_token_permute``.

    Every entry is a two-element tuple of placement lists. The second list
    has one slot per parameter of this function, in order (``tokens``,
    ``indices``, ``num_out_tokens``, ``padded_mode``); the last two slots
    are always ``None``. The first list describes the op's two outputs.

    Two layouts are offered:

    * everything replicated;
    * ``tokens`` and the first output sharded on dim 1 (the original code
      calls this the hidden-size sharding), with ``indices`` and the second
      output replicated.

    Returns:
        list: the two strategy tuples, replicated layout first.
    """
    fully_replicated = (
        [Replicate(), Replicate()],
        [Replicate(), Replicate(), None, None],
    )
    sharded_on_hidden = (
        [Shard(1), Replicate()],
        [Shard(1), Replicate(), None, None],
    )
    return [fully_replicated, sharded_on_hidden]
@register_sharding(npu.npu_moe_token_permute_grad.default)
def npu_moe_token_permute_grad_strategy(tokens, grad_permuted_tokens, indices, sorted_indices, padded_mode=False):
    """List the acceptable DTensor shardings for ``npu_moe_token_permute_grad``.

    Every entry is a two-element tuple of placement lists. The second list
    has one slot per parameter of this function, in order (``tokens``,
    ``grad_permuted_tokens``, ``indices``, ``sorted_indices``,
    ``padded_mode``); the ``padded_mode`` slot is always ``None``. The first
    list holds the placement of the op's single output.

    Two layouts are offered:

    * everything replicated;
    * ``tokens``, ``grad_permuted_tokens`` and the output sharded on dim 1
      (the original code calls this the hidden-size sharding), with both
      index arguments replicated.

    Returns:
        list: the two strategy tuples, replicated layout first.
    """
    fully_replicated = (
        [Replicate()],
        [Replicate(), Replicate(), Replicate(), Replicate(), None],
    )
    sharded_on_hidden = (
        [Shard(1)],
        [Shard(1), Shard(1), Replicate(), Replicate(), None],
    )
    return [fully_replicated, sharded_on_hidden]
@register_sharding(npu.npu_moe_token_unpermute.default)
def npu_moe_token_unpermute_strategy(permuted_tokens, sorted_indices, probs=None, padded_mode=False,
                                     restore_shape=None):
    """List the acceptable DTensor shardings for ``npu_moe_token_unpermute``.

    Every entry is a two-element tuple of placement lists. The second list
    has one slot per parameter of this function, in order
    (``permuted_tokens``, ``sorted_indices``, ``probs``, ``padded_mode``,
    ``restore_shape``); the last two slots are always ``None``. The first
    list holds the placement of the op's single output.

    The ``probs`` slot is ``Replicate()`` when ``probs`` is supplied and
    ``None`` when it is omitted, in both layouts.

    Two layouts are offered:

    * everything replicated;
    * ``permuted_tokens`` and the output sharded on dim 1 (the original code
      calls this the hidden-size sharding), with ``sorted_indices``
      replicated.

    Returns:
        list: the two strategy tuples, replicated layout first.
    """
    fully_replicated = (
        [Replicate()],
        [Replicate(), Replicate(), Replicate() if probs is not None else None, None, None],
    )
    sharded_on_hidden = (
        [Shard(1)],
        [Shard(1), Replicate(), Replicate() if probs is not None else None, None, None],
    )
    return [fully_replicated, sharded_on_hidden]
@register_sharding(npu.npu_moe_token_unpermute_grad.default)
def npu_moe_token_unpermute_grad_strategy(permuted_tokens, grad_unpermuted_tokens, sorted_indices, probs=None,
                                          padded_mode=False, restore_shape=None):
    """List the acceptable DTensor shardings for ``npu_moe_token_unpermute_grad``.

    Every entry is a two-element tuple of placement lists. The second list
    has one slot per parameter of this function, in order
    (``permuted_tokens``, ``grad_unpermuted_tokens``, ``sorted_indices``,
    ``probs``, ``padded_mode``, ``restore_shape``); the last two slots are
    always ``None``. The first list describes the op's two outputs.

    When ``probs`` is omitted, both the ``probs`` input slot and the second
    output slot are ``None`` in every strategy.

    Two layouts are offered:

    * everything replicated;
    * ``permuted_tokens``, ``grad_unpermuted_tokens`` and the first output
      sharded on dim 1 (the original code calls this the hidden-size
      sharding), with ``sorted_indices`` and ``probs`` replicated and the
      second output left as ``Partial()``.

    Returns:
        list: the two strategy tuples, replicated layout first.
    """
    has_probs = probs is not None
    strategies = []

    replicate_strategy = (
        [Replicate(), Replicate() if has_probs else None],
        [Replicate(), Replicate(), Replicate(), Replicate() if has_probs else None, None, None],
    )
    strategies.append(replicate_strategy)

    hidden_size_sharding_strategy = (
        # The second output only gets a placement when `probs` was supplied,
        # exactly as in the replicated strategy; previously it was Partial()
        # unconditionally, which disagreed with the replicated strategy.
        # NOTE(review): assumes the op yields no tensor for the second output
        # when `probs` is omitted -- confirm against the kernel.
        [Shard(1), Partial() if has_probs else None],
        [Shard(1), Shard(1), Replicate(), Replicate() if has_probs else None, None, None],
    )
    strategies.append(hidden_size_sharding_strategy)

    return strategies