import torch
from torch import nn

from ..model_config import WordEmbeddingTPMode
from ..parallel_group import ParallelGroup
from .utils import get_partial_sharded, ModelWrapperBase


class ParallelEmbedding(ModelWrapperBase):
    """
    A parallel embedding layer that replaces a standard torch.nn.Embedding layer.

    The wrapped embedding's weight is split across the ranks of ``tp_group``:

    * ``WordEmbeddingTPMode.col`` shards the weight along dim 1 (the embedding
      dimension); ``forward`` gathers the per-rank slices along the last axis.
    * ``WordEmbeddingTPMode.row`` shards the weight along dim 0 (the
      vocabulary); ``forward`` looks up only the token ids owned by this rank,
      zeroes the rest and combines the ranks with ``tp_group.all_reduce``.

    When the group has a single rank the wrapped embedding is used unchanged.
    """

    def __init__(
        self,
        embedding: torch.nn.Embedding,
        tp_group: ParallelGroup,
        shard_mode: WordEmbeddingTPMode = WordEmbeddingTPMode.col,
    ):
        """
        Wrap ``embedding`` and shard its weight in place.

        Args:
            embedding: The embedding layer to wrap; its weight is replaced by
                this rank's shard when the group has more than one rank.
            tp_group: Parallel group supplying ``world_size``,
                ``rank_in_group`` and the collectives used in ``forward``.
            shard_mode: Sharding mode; any value accepted by the
                ``WordEmbeddingTPMode`` constructor.

        Raises:
            ValueError: If ``shard_mode`` is not a valid ``WordEmbeddingTPMode``.
        """
        super().__init__(embedding)
        self.tp_group = tp_group
        self.tp_size = tp_group.world_size
        self.tp_rank = tp_group.rank_in_group
        try:
            self.shard_mode = WordEmbeddingTPMode(shard_mode)
        except ValueError as err:
            raise ValueError(f"word embedding tp mode must be 'col' or 'row', got {shard_mode!r}.") from err
        # Record the full vocabulary size before the weight is sharded.
        # NOTE(review): `self.num_embeddings` is presumably forwarded to the
        # wrapped embedding by ModelWrapperBase — confirm.
        self._vocab_size = self.num_embeddings
        # Half-open range [_row_start, _row_end) of global token ids served by
        # this rank; covers the whole vocabulary unless row sharding narrows it.
        self._row_start = 0
        self._row_end = self._vocab_size
        self._orig_padding_idx = embedding.padding_idx
        self.create_weights()

    def create_weights(self):
        """
        Replace the wrapped embedding's weight with this rank's shard.

        Does nothing when the group has a single rank. In row mode it also
        records the range of global token ids held locally and remaps the
        wrapped embedding's ``padding_idx`` into local row coordinates (or
        clears it when the padding row lives on another rank).
        """
        if self.tp_size <= 1:
            return
        shard_dim = 1 if self.shard_mode == WordEmbeddingTPMode.col else 0
        full_weight = self._inner.weight
        shard_weight = get_partial_sharded(full_weight, self.tp_size, self.tp_rank, dim=shard_dim)
        # Carry over the original trainability flag: nn.Parameter defaults to
        # requires_grad=True, which would silently unfreeze a frozen embedding.
        self._inner.weight = nn.Parameter(
            shard_weight.contiguous(),
            requires_grad=full_weight.requires_grad,
        )
        if self.shard_mode == WordEmbeddingTPMode.row:
            # NOTE(review): deriving the offset from the local shard height
            # assumes get_partial_sharded returns equally sized shards on
            # every rank — confirm against the helper.
            block_size = self._inner.weight.shape[0]
            self._row_start = self.tp_rank * block_size
            # Clamp so rows past the real vocabulary are never addressed.
            self._row_end = min(self._row_start + block_size, self._vocab_size)
            orig_padding_idx = self._orig_padding_idx
            # Map a negative padding index to its non-negative equivalent.
            if orig_padding_idx is not None and orig_padding_idx < 0:
                orig_padding_idx = self._vocab_size + orig_padding_idx
            if orig_padding_idx is not None and self._row_start <= orig_padding_idx < self._row_end:
                # The padding row is stored on this rank: use its local row index.
                self._inner.padding_idx = orig_padding_idx - self._row_start
            else:
                # The padding row (if any) is stored on another rank.
                self._inner.padding_idx = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Look up embeddings for the token ids in ``x``.

        Args:
            x: Tensor of global token ids.

        Returns:
            The embeddings for ``x`` after combining the shards held by the
            ranks of ``tp_group``.
        """
        if self.tp_size == 1:
            return self._inner(x)
        if self.shard_mode == WordEmbeddingTPMode.row:
            in_local_vocab = (x >= self._row_start) & (x < self._row_end)
            # Ids owned by other ranks are redirected to local row 0 so the
            # lookup stays in bounds; their output rows are zeroed just below.
            safe_local_indices = torch.where(in_local_vocab, x - self._row_start, torch.zeros_like(x))
            x = self._inner(safe_local_indices)
            x = x.masked_fill_(~in_local_vocab.unsqueeze(-1), 0)
            # Each id is non-zero on at most one rank, so combining the ranks
            # (presumably a sum — confirm all_reduce's op) yields the full lookup.
            return self.tp_group.all_reduce(x)
        x = self._inner(x)
        x = self.tp_group.all_gather(x, dim=-1)
        # Trim to the embedding width; presumably this drops columns added as
        # padding when the weight was sharded — confirm.
        x = x[..., : self.embedding_dim]
        return x