from typing import Any, Dict, List, Literal, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer


class Transformer(nn.Module):
    """Wrap a Hugging Face model with task adapters as an embedding module.

    The module loads a config, model and tokenizer with the ``transformers``
    ``Auto*`` classes, selects a task adapter on every forward pass, pools the
    hidden state of the last attended token and L2-normalises the result.
    """

    # NOTE(review): presumably read by the surrounding framework's save logic
    # to decide where the module is stored - confirm against the caller.
    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = "jinaai/jina-embeddings-v5-text-small",
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
        **kwargs,
    ) -> None:
        """Load the config, model and tokenizer.

        Args:
            model_name_or_path: Identifier or path given to ``from_pretrained``.
            max_seq_length: Truncation length used by :meth:`tokenize`. Falls
                back to ``config.max_position_embeddings`` when falsy.
            config_args: Extra keyword arguments for ``AutoConfig.from_pretrained``.
            model_args: Extra keyword arguments for ``AutoModel.from_pretrained``.
                A ``"default_task"`` entry is consumed here and used when
                :meth:`forward` is called without an explicit task.
            tokenizer_args: Extra keyword arguments for
                ``AutoTokenizer.from_pretrained``.
            cache_dir: Cache directory shared by config, model and tokenizer.
            backend: Only ``"torch"`` is supported.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            ValueError: If ``backend`` is not ``"torch"`` or the default task is
                not listed in ``config.task_names``.
        """
        super().__init__()
        if backend != "torch":
            raise ValueError(
                f"Backend '{backend}' is not supported, please use 'torch' instead"
            )

        # Work on copies: the dicts are modified below and must not leak
        # those modifications back to the caller. This also makes a ``None``
        # argument safe to pop from.
        config_kwargs = dict(config_args or {})
        model_kwargs = dict(model_args or {})
        tokenizer_kwargs = dict(tokenizer_args or {})

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )

        # "default_task" is an option of this wrapper, not of the model, so it
        # has to be removed before the remaining kwargs reach from_pretrained.
        self.default_task = model_kwargs.pop("default_task", None)
        if self.default_task and self.default_task not in self.config.task_names:
            raise ValueError(
                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
            )

        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
        )

        # Keep the previous default of trusting remote code, but let the
        # caller override it (and the cache directory) via tokenizer_args.
        tokenizer_kwargs.setdefault("trust_remote_code", True)
        tokenizer_kwargs.setdefault("cache_dir", cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, **tokenizer_kwargs
        )
        self.max_seq_length = max_seq_length or self.config.max_position_embeddings

    def tokenize(
        self, texts: List[str], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Tokenize ``texts`` into PyTorch tensors, truncated to ``max_seq_length``.

        Args:
            texts: Input strings.
            padding: Padding strategy forwarded to the tokenizer.

        Returns:
            The tokenizer output requested with ``return_tensors="pt"``.
        """
        return self.tokenizer(
            texts,
            max_length=self.max_seq_length,
            truncation=True,
            padding=padding,
            return_tensors="pt",
        )

    @staticmethod
    def _last_token_pool(
        hidden: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Return the hidden state of the last attended token of each row.

        Handles both padding sides: when every row attends to the final
        position (left padding, or no padding at all) that position is used
        directly; otherwise the batch is treated as right-padded and the
        index is derived from the number of attended tokens.
        """
        if mask is None or bool(mask[:, -1].all()):
            return hidden[:, -1]
        last_index = mask.sum(dim=1) - 1
        rows = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[rows, last_index]

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        task: Optional[str] = None,
        truncate_dim: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute normalised sentence embeddings for a tokenized batch.

        The model is switched to eval mode and run under ``torch.no_grad()``.

        Args:
            features: Tokenized batch; only tensor values are passed to the
                model. The dict is updated in place.
            task: Adapter to activate. Falls back to the default task given at
                construction time.
            truncate_dim: If set, keep only the first ``truncate_dim``
                dimensions of the pooled vector before normalising.

        Returns:
            ``features`` with an added ``"sentence_embedding"`` entry.

        Raises:
            ValueError: If no task is available or the task is not listed in
                ``config.task_names``.
        """
        self.model.eval()
        if task is None:
            if self.default_task is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
                )
            task = self.default_task
        elif task not in self.config.task_names:
            raise ValueError(
                f"Invalid task: {task}. Must be one of {self.config.task_names}."
            )
        self.model.set_adapter(task)

        device = self.model.device

        with torch.no_grad():
            # Non-tensor entries are dropped: they cannot be moved to the
            # device and are not forwarded to the model.
            batch = {k: v.to(device) for k, v in features.items() if torch.is_tensor(v)}
            hidden = self.model(**batch).last_hidden_state
            pooled = self._last_token_pool(hidden, batch.get("attention_mask"))

            # Truncate before normalising so the shortened vector is unit-length.
            if truncate_dim is not None:
                pooled = pooled[:, :truncate_dim]
            embeddings = F.normalize(pooled, p=2, dim=-1)

        features["sentence_embedding"] = embeddings
        return features

    @classmethod
    def load(cls, input_path: str) -> "Transformer":
        """Build an instance from ``input_path`` with default settings."""
        return cls(model_name_or_path=input_path)